diff --git a/reasoning_gym/games/countdown.py b/reasoning_gym/games/countdown.py
index 75552494..c8da71a5 100644
--- a/reasoning_gym/games/countdown.py
+++ b/reasoning_gym/games/countdown.py
@@ -3,6 +3,7 @@ from dataclasses import dataclass
 from random import Random
 from typing import Any, Optional
 
+import numpy as np
 import sympy
 from sympy import Symbol, symbols
 from sympy.parsing.sympy_parser import parse_expr
@@ -23,6 +24,16 @@ Final answer format instructions:
 
 DATASET_NAME = "countdown"
 
+_num_re = re.compile(r"\b\d+\b")  # pre-compile once, reuse
+
+
+def _extract_ints(expr_str: str) -> list[int]:
+    """
+    Return the integer literals that appear in expr_str, in source order.
+    Duplicates are kept (e.g. "1 + 1 + 81" -> [1, 1, 81]).
+    """
+    return [int(m) for m in _num_re.findall(expr_str)]
+
 
 @dataclass
 class CountdownConfig:
@@ -192,12 +203,19 @@ class CountdownDataset(ProceduralDataset):
             return reward
 
         try:
-            answer = answer.strip()
-            user_answer = int(parse_expr(answer))
-            used_numbers = [int(num) for num in re.findall(r"\b\d+\b", answer)]
-            target_numbers = set(entry["metadata"]["numbers"])
+            # Evaluate to a float rather than an int so that a fractional result
+            # (e.g. 6*34/11-1) is not truncated onto the integer target.
+            user_answer = float(parse_expr(answer))
+            used_numbers = _extract_ints(answer)
+            target_numbers = entry["metadata"]["numbers"]
 
-            if (user_answer == entry["metadata"]["target"]) and (set(used_numbers) == target_numbers):
+            # Compare as multisets: duplicates in the given numbers must all be used.
+            if sorted(used_numbers) != sorted(target_numbers):
+                return 0.05
+
+            # rtol=0 keeps the check purely absolute; np.isclose's default rtol=1e-5
+            # scales with the target and would accept near-misses on large targets.
+            if np.isclose(user_answer, entry["metadata"]["target"], rtol=0, atol=1e-6):
                 return 1.0
 
             return 0.05 if answer else 0.01
diff --git a/tests/test_countdown.py b/tests/test_countdown.py
index 2070a6dc..6e3a1e0a 100644
--- a/tests/test_countdown.py
+++ b/tests/test_countdown.py
@@ -100,6 +100,22 @@ def test_answer_without_all_numbers():
     assert dataset.score_answer(answer=answer, entry=item) == 0.05
 
 
+def test_edge_cases_1():
+    """An answer that leaves out duplicated numbers must not get full credit."""
+    dataset = CountdownDataset(CountdownConfig(size=10, seed=42))
+    answer = "1*81"
+    item = {"metadata": {"numbers": [1, 1, 1, 81], "target": 81}}
+    assert dataset.score_answer(answer=answer, entry=item) != 1.0
+
+
+def test_edge_cases_2():
+    """A fractional result (6*34/11-1 is about 17.55) must not be truncated to the target."""
+    dataset = CountdownDataset(CountdownConfig(size=10, seed=42))
+    answer = "6*34/11-1"
+    item = {"metadata": {"numbers": [6, 34, 1, 11], "target": 17}}
+    assert dataset.score_answer(answer=answer, entry=item) != 1.0
+
+
 def test_countdown_game_randomization():
     """Test number randomization configuration"""
     config = CountdownConfig(min_numbers=4, max_numbers=4, shuffle=False, size=10, seed=42)  # Fixed size for testing