diff --git a/reasoning_gym/games/countdown.py b/reasoning_gym/games/countdown.py index 4721844d..87d1a793 100644 --- a/reasoning_gym/games/countdown.py +++ b/reasoning_gym/games/countdown.py @@ -1,9 +1,10 @@ from dataclasses import dataclass from random import Random -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Dict, Any import sympy from sympy import Symbol, symbols +from sympy.parsing.sympy_parser import parse_expr from ..factory import ProceduralDataset, register_dataset @@ -157,6 +158,23 @@ class CountdownDataset(ProceduralDataset): continue raise ValueError(f"Failed to generate valid expression after {max_attempts} attempts") + + def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float: + """Determine if the solution provided solves the problem""" + reward = 0.0 + if answer is not None: + try: + user_answer = int(parse_expr(answer)) + solved = user_answer == metadata["target"] + if solved: + reward = 1.0 + elif (len(answer.strip()) > 0): # encourage partial solutions + reward = 0.05 + else: + reward = 0.01 + except: + reward = 0.01 + return reward # Register the dataset