diff --git a/reasoning_gym/algorithmic/pool_matrix.py b/reasoning_gym/algorithmic/pool_matrix.py index e22f9a92..706d5c42 100644 --- a/reasoning_gym/algorithmic/pool_matrix.py +++ b/reasoning_gym/algorithmic/pool_matrix.py @@ -13,6 +13,7 @@ The stride is equal to the kernel size, meaning there is no overlap between the Your output should be a matrix in the same format as the input matrix. The output matrix is smaller than the input matrix when the kernel size is greater than 1, and its elements may be floating-point numbers. +Give elements in the output matrix correct to 2 decimal places. Perform {pool_type} pooling on the following matrix with a kernel size of {pool_size}: {matrix} @@ -87,7 +88,7 @@ class PoolMatrixDataset(ProceduralDataset): try: oracle_answer = np.loadtxt(entry["answer"].splitlines(), dtype=np.float32) answer = np.loadtxt(answer.splitlines(), dtype=np.float32) - if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer): + if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer, rtol=1e-2): reward = 1.0 elif oracle_answer.shape == answer.shape: reward = 0.1 diff --git a/reasoning_gym/arithmetic/gsm_symbolic/gsm_symbolic.py b/reasoning_gym/arithmetic/gsm_symbolic/gsm_symbolic.py index 13b0b372..b99ef3a0 100644 --- a/reasoning_gym/arithmetic/gsm_symbolic/gsm_symbolic.py +++ b/reasoning_gym/arithmetic/gsm_symbolic/gsm_symbolic.py @@ -1,5 +1,6 @@ """GSM Symblic dataset generator""" +import re from dataclasses import dataclass from random import Random from typing import Any, Callable, Optional @@ -149,8 +150,28 @@ class GSMSymbolicDataset(ProceduralDataset): generator_idx = self.task_indices[idx] generator = self.generators[generator_idx] example = generator(rng, self.config.difficulty) - example["question"] += " Give only the result as your final answer." + example["question"] += " Give the result as your final answer. Do not include units." 
return example + def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: + reward = 0.0 + if answer is None: + return reward + try: + # Extract the first number; the lookbehind (not \b) lets a leading minus sign be captured + match = re.search(r"(?<!\w)-?\d+(?:\.\d+)?\b", answer) + if not match: + return reward + + answer_value = float(match.group(0)) + expected_answer = float(entry["answer"]) + if answer_value == expected_answer: + reward = 1.0 + else: + reward = 0.01 + except Exception: + return reward + return reward + register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig) diff --git a/reasoning_gym/arithmetic/power_function.py b/reasoning_gym/arithmetic/power_function.py index 5d5848c0..321da833 100644 --- a/reasoning_gym/arithmetic/power_function.py +++ b/reasoning_gym/arithmetic/power_function.py @@ -1,6 +1,7 @@ """Computhe the power of a number.""" from dataclasses import dataclass +from decimal import Decimal from math import pow from random import Random from typing import Any, Optional @@ -9,7 +10,8 @@ from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """Your task is to compute an exponentiation of a number. -Compute {base}^{exponent} +Compute {base}^{exponent}. Return your final answer correct to 3 significant figures. +Provide your answer in scientific notation using 'e' notation (e.g., 1.23e+4). 
""" @@ -33,19 +35,26 @@ class PowerFunctionDataset(ProceduralDataset): super().__init__(config=config, seed=config.seed, size=config.size) def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - """Overwrite this method in derived classes if a single oracle answer is not available.""" + """Score the answer by checking if it matches the expected answer to 3 significant figures.""" oracle_answer = entry["answer"] if answer is not None: try: - answer = round(float(answer), 4) - oracle_answer = round(float(oracle_answer), 4) - difference = abs(float(answer) - float(oracle_answer)) - if difference < 1e-4: + user_answer = Decimal(answer) + oracle_value = Decimal(oracle_answer) + + if oracle_value == 0: + return 1.0 if user_answer == 0 else 0.01 + + user_sig_figs = f"{user_answer:.3g}" + oracle_sig_figs = f"{oracle_value:.3g}" + + # Compare the rounded values numerically: equal numbers can format differently (e.g. 8.0 vs 8.00) + if Decimal(user_sig_figs) == Decimal(oracle_sig_figs): return 1.0 - elif difference < 1e-1: - return 0.5 - except Exception: - pass + else: + return 0.01 + except Exception: + return 0.01 return 0.0 def __getitem__(self, idx: int) -> dict: diff --git a/tests/test_gsm_symbolic.py b/tests/test_gsm_symbolic.py index 4351616e..cb8590cc 100644 --- a/tests/test_gsm_symbolic.py +++ b/tests/test_gsm_symbolic.py @@ -90,3 +90,14 @@ def test_gsm_symbolic_generators(): print(f"ok: q={len(question_set)}, a={len(answer_set)}") i += 1 + + +def test_gsm_symbolic_score_answer(): + """Test score answer function""" + config = GSMSymbolicDatasetConfig(size=100, seed=42) + dataset = GSMSymbolicDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + score = dataset.score_answer(item["answer"], item) + assert score == 1.0 diff --git a/tests/test_power_function.py b/tests/test_power_function.py index df2454a4..4a7b8382 100644 --- a/tests/test_power_function.py +++ b/tests/test_power_function.py @@ -59,20 +59,6 @@ def test_power_function_score_function(): config = PowerFunctionConfig(seed=42) 
dataset = PowerFunctionDataset(config) - item = dataset[0] - - # Answer is within 1e-6 of solution - answer = str(item["metadata"]["solution"] - 1e-7) - assert dataset.score_answer(answer, item) == 1.0 - - # Answer is within 1e-1 of solution - answer = str(item["metadata"]["solution"] - 1e-2) - assert dataset.score_answer(answer, item) == 0.5 - - # Answer is far from solution - answer = str(item["metadata"]["solution"] - 1) - assert dataset.score_answer(answer, item) == 0.0 - - # Answer is None - answer = None - assert dataset.score_answer(answer, item) == 0.0 + for item in dataset: + answer = item["answer"] + assert dataset.score_answer(answer, item) == 1.0