diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 4fc62bad..0c3ee345 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -2,8 +2,6 @@ from dataclasses import dataclass from random import Random from typing import Any, Literal, Optional -from reasoning_gym import utils - from ..factory import ProceduralDataset, register_dataset @@ -234,10 +232,6 @@ class BasicArithmeticDataset(ProceduralDataset): template = rng.choice(templates) return template.format(expression) - def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - oracle_answer = entry["answer"].strip() - return utils.compute_reward(answer, oracle_answer, allow_commas=False) - # Register the dataset register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig) diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 04d79d29..2072983c 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -111,8 +111,7 @@ class ChainSumDataset(ProceduralDataset): return expression, result def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - oracle_answer = entry["answer"].strip() - return utils.compute_reward(answer, oracle_answer) + return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"]) class ChainSumCurriculum(BaseCurriculum): diff --git a/reasoning_gym/arithmetic/products.py b/reasoning_gym/arithmetic/products.py index 606e35eb..8401be91 100644 --- a/reasoning_gym/arithmetic/products.py +++ b/reasoning_gym/arithmetic/products.py @@ -103,8 +103,7 @@ class ProductsDataset(ProceduralDataset): return expression, result def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - oracle_answer = entry["answer"].strip() - return utils.compute_reward(answer, oracle_answer) + return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"]) class ProductsCurriculum(BaseCurriculum): diff --git a/reasoning_gym/utils.py b/reasoning_gym/utils.py index d90356e1..5d8e7b69 100644 --- a/reasoning_gym/utils.py +++ b/reasoning_gym/utils.py @@ -81,7 +81,7 @@ def is_integer(obj: Any) -> bool: return False -def compute_reward(answer: Optional[str], oracle_answer: str, allow_commas: bool = True) -> float: +def compute_decimal_reward(answer: Optional[str], oracle_answer: str, strip_commas: bool = True) -> float: """Compute the reward for a given answer compared to the oracle answer. Args: @@ -94,13 +94,18 @@ def compute_reward(answer: Optional[str], oracle_answer: str, allow_commas: bool """ reward = 0.0 if answer is not None and len(answer) > 0: - answer = answer.strip() - answer = answer.replace(",", "") if allow_commas else answer - if answer == oracle_answer: - reward = 1.0 - elif oracle_answer in answer: + reward = 0.01 + try: + if strip_commas: + answer = answer.replace(",", "") + oracle_answer = oracle_answer.replace(",", "") + + if Decimal(answer) == Decimal(oracle_answer): + reward = 1.0 + except: + pass + + if oracle_answer in answer: reward = len(oracle_answer) / len(answer) - else: - reward = 0.01 return reward