diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 0df90113..3075dd28 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -2,6 +2,8 @@ from dataclasses import dataclass from random import Random from typing import Any, Dict, Literal, Optional +from reasoning_gym import utils + from ..factory import ProceduralDataset, register_dataset @@ -88,7 +90,7 @@ class BasicArithmeticDataset(ProceduralDataset): else: expression, result = self._generate_simple_task(rng, num_terms, num_digits) - question = self._format_question(rng, expression) + "." + question = self._format_question(rng, expression) return { "question": question, @@ -233,19 +235,8 @@ class BasicArithmeticDataset(ProceduralDataset): return template.format(expression) def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: - """Overwrite this method in derived classes if a single oracle answer is not available.""" oracle_answer = entry["answer"].strip() - reward = 0.0 - if answer is not None and len(answer) > 0: - answer = answer.strip().replace(",", "") - if answer == oracle_answer: - reward = 1.0 - elif oracle_answer in answer: - reward = len(oracle_answer) / len(answer) - else: - reward = 0.01 - - return reward + return utils.compute_reward(answer, oracle_answer, allow_commas=False) # Register the dataset diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 9bf59c00..4a5d68c1 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -2,6 +2,8 @@ import random from dataclasses import dataclass from typing import Dict, Optional +from reasoning_gym import utils + from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -109,19 +111,8 @@ class ChainSumDataset(ProceduralDataset): return expression, result def 
score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: - """Overwrite this method in derived classes if a single oracle answer is not available.""" oracle_answer = entry["answer"].strip() - reward = 0.0 - if answer is not None and len(answer) > 0: - answer = answer.strip().replace(",", "") - if answer == oracle_answer: - reward = 1.0 - elif oracle_answer in answer: - reward = len(oracle_answer) / len(answer) - else: - reward = 0.01 - - return reward + return utils.compute_reward(answer, oracle_answer) class ChainSumCurriculum(BaseCurriculum): diff --git a/reasoning_gym/arithmetic/products.py b/reasoning_gym/arithmetic/products.py index 546e3c8c..2d77035b 100644 --- a/reasoning_gym/arithmetic/products.py +++ b/reasoning_gym/arithmetic/products.py @@ -2,6 +2,8 @@ import random from dataclasses import dataclass from typing import Dict, Optional +from reasoning_gym import utils + from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -101,19 +103,8 @@ class ProductsDataset(ProceduralDataset): return expression, result def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: - """Overwrite this method in derived classes if a single oracle answer is not available.""" oracle_answer = entry["answer"].strip() - reward = 0.0 - if answer is not None and len(answer) > 0: - answer = answer.strip().replace(",", "") - if answer == oracle_answer: - reward = 1.0 - elif oracle_answer in answer: - reward = len(oracle_answer) / len(answer) - else: - reward = 0.01 - - return reward + return utils.compute_reward(answer, oracle_answer) class ProductsCurriculum(BaseCurriculum): diff --git a/reasoning_gym/utils.py b/reasoning_gym/utils.py index c59c06ca..d90356e1 100644 --- a/reasoning_gym/utils.py +++ b/reasoning_gym/utils.py @@ -79,3 +79,28 @@ def is_integer(obj: Any) -> bool: elif isinstance(obj, Fraction): return obj.denominator == 1 return False + + +def 
def compute_reward(answer: Optional[str], oracle_answer: str, allow_commas: bool = True) -> float:
    """Compute the reward for a given answer compared to the oracle answer.

    Both strings are normalised the same way before they are compared:
    surrounding whitespace is stripped and, when ``allow_commas`` is true,
    commas are removed. Normalising both sides means "1,000" matches "1000"
    regardless of which side carries the separator.

    Args:
        answer: Answer provided by the model; may be ``None``.
        oracle_answer: Correct answer to the question.
        allow_commas: Whether commas are ignored in the comparison,
            e.g. "1,000" == "1000".

    Returns:
        Reward value between 0.0 and 1.0:
        0.0 for a missing answer or one that is empty after normalisation,
        1.0 for an exact match,
        ``len(oracle) / len(answer)`` when the oracle is contained in a
        longer answer, and 0.01 for any other non-empty answer.
    """
    if answer is None:
        return 0.0

    def _normalize(text: str) -> str:
        # One normalisation routine for both sides keeps the comparison symmetric.
        text = text.strip()
        return text.replace(",", "") if allow_commas else text

    candidate = _normalize(answer)
    expected = _normalize(oracle_answer)

    # Emptiness is tested after normalisation so a whitespace-only answer is
    # scored like a missing one rather than like a wrong one.
    if not candidate:
        return 0.0
    if candidate == expected:
        return 1.0
    if expected in candidate:
        # Partial credit shrinks as more extra text surrounds the oracle.
        return len(expected) / len(candidate)
    return 0.01