mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
updated algorithmics dataset (#269)
* updated algorithmic datasets * added changes to symbolic and power * updated power function test
This commit is contained in:
parent
f426db90ec
commit
d9638df79c
5 changed files with 57 additions and 29 deletions
|
|
@ -13,6 +13,7 @@ The stride is equal to the kernel size, meaning there is no overlap between the
|
|||
|
||||
Your output should be a matrix in the same format as the input matrix.
|
||||
The output matrix is smaller than the input matrix when the kernel size is greater than 1, and its elements may be floating-point numbers.
|
||||
Give elements in the output matrix correct to 2 decimal places.
|
||||
|
||||
Perform {pool_type} pooling on the following matrix with a kernel size of {pool_size}:
|
||||
{matrix}
|
||||
|
|
@ -87,7 +88,7 @@ class PoolMatrixDataset(ProceduralDataset):
|
|||
try:
|
||||
oracle_answer = np.loadtxt(entry["answer"].splitlines(), dtype=np.float32)
|
||||
answer = np.loadtxt(answer.splitlines(), dtype=np.float32)
|
||||
if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer):
|
||||
if oracle_answer.shape == answer.shape and np.allclose(oracle_answer, answer, rtol=1e-2):
|
||||
reward = 1.0
|
||||
elif oracle_answer.shape == answer.shape:
|
||||
reward = 0.1
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""GSM Symblic dataset generator"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Any, Callable, Optional
|
||||
|
|
@ -149,8 +150,28 @@ class GSMSymbolicDataset(ProceduralDataset):
|
|||
generator_idx = self.task_indices[idx]
|
||||
generator = self.generators[generator_idx]
|
||||
example = generator(rng, self.config.difficulty)
|
||||
example["question"] += " Give only the result as your final answer."
|
||||
example["question"] += " Give the result as your final answer. Do not include units."
|
||||
return example
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
    """Score a free-form answer against the expected numeric result.

    The first number found in ``answer`` is compared with
    ``float(entry["answer"])``.

    Args:
        answer: The candidate answer text, or ``None`` if none was given.
        entry: Dataset item; ``entry["answer"]`` holds the expected value.

    Returns:
        ``1.0`` if the extracted number equals the expected value,
        ``0.01`` if a number was found but differs, and ``0.0`` if
        ``answer`` is ``None``, contains no number, or cannot be parsed.
    """
    reward = 0.0
    if answer is None:
        return reward
    try:
        # A plain ``\b`` before ``-?`` never matches a leading minus sign
        # (there is no word boundary between start-of-string/whitespace and
        # "-"), which silently turned "-5" into 5. The lookbehind keeps the
        # "not glued to a preceding word character" rule while letting the
        # sign be captured.
        match = re.search(r"(?<!\w)-?\d+(?:\.\d+)?\b", answer)
        if not match:
            return reward

        answer_value = float(match.group(0))
        expected_answer = float(entry["answer"])
        reward = 1.0 if answer_value == expected_answer else 0.01
    except (KeyError, TypeError, ValueError):
        # Malformed entry or non-string answer: treat as unscorable.
        return reward
    return reward
|
||||
|
||||
|
||||
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Computhe the power of a number."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from math import pow
|
||||
from random import Random
|
||||
from typing import Any, Optional
|
||||
|
|
@ -9,7 +10,8 @@ from ..factory import ProceduralDataset, register_dataset
|
|||
|
||||
QUESTION_TEMPLATE = """Your task is to compute an exponentiation of a number.
|
||||
|
||||
Compute {base}^{exponent}
|
||||
Compute {base}^{exponent}. Return your final answer correct to 3 significant figures.
|
||||
Provide your answer in scientific notation using 'e' notation (e.g., 1.23e+4).
|
||||
"""
|
||||
|
||||
|
||||
|
|
@ -33,19 +35,26 @@ class PowerFunctionDataset(ProceduralDataset):
|
|||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
    """Score the answer by checking if it matches the expected answer to 3 significant figures.

    Both values are parsed as ``Decimal``, rounded to three significant
    digits and compared numerically. Comparing the ``'.3g'``-formatted
    strings is not reliable: ``Decimal`` formatting does not pad or
    normalise values that already have three or fewer digits, so
    ``"1.5e+4"`` and ``"15000"`` (or ``"1.5"`` and ``"1.50"``) would
    format differently even though they are the same number.

    Args:
        answer: The candidate answer, or ``None`` if none was given.
        entry: Dataset item; ``entry["answer"]`` holds the expected value.

    Returns:
        ``1.0`` if the values agree to 3 significant figures, ``0.01`` if
        an answer was given but differs or cannot be parsed, and ``0.0``
        if ``answer`` is ``None``.
    """
    from decimal import Context, Decimal

    oracle_answer = entry["answer"]
    if answer is None:
        return 0.0

    # ``Context.plus`` applies the context precision, i.e. rounds the
    # operand to three significant digits.
    three_sig_figs = Context(prec=3)
    try:
        user_value = three_sig_figs.plus(Decimal(answer))
        oracle_value = three_sig_figs.plus(Decimal(oracle_answer))
    except (ArithmeticError, TypeError, ValueError):
        # Unparseable input (decimal.InvalidOperation is an ArithmeticError).
        return 0.01

    return 1.0 if user_value == oracle_value else 0.01
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
|
|
|
|||
|
|
@ -90,3 +90,14 @@ def test_gsm_symbolic_generators():
|
|||
print(f"ok: q={len(question_set)}, a={len(answer_set)}")
|
||||
|
||||
i += 1
|
||||
|
||||
|
||||
def test_gsm_symbolic_score_answer():
    """Check that every item's own answer receives the full score."""
    dataset = GSMSymbolicDataset(GSMSymbolicDatasetConfig(size=100, seed=42))

    for index in range(len(dataset)):
        entry = dataset[index]
        # Scoring the oracle answer against its own entry must yield 1.0.
        assert dataset.score_answer(entry["answer"], entry) == 1.0
|
||||
|
|
|
|||
|
|
@ -59,20 +59,6 @@ def test_power_function_score_function():
|
|||
config = PowerFunctionConfig(seed=42)
|
||||
dataset = PowerFunctionDataset(config)
|
||||
|
||||
item = dataset[0]
|
||||
|
||||
# Answer is within 1e-6 of solution
|
||||
answer = str(item["metadata"]["solution"] - 1e-7)
|
||||
assert dataset.score_answer(answer, item) == 1.0
|
||||
|
||||
# Answer is within 1e-1 of solution
|
||||
answer = str(item["metadata"]["solution"] - 1e-2)
|
||||
assert dataset.score_answer(answer, item) == 0.5
|
||||
|
||||
# Answer is far from solution
|
||||
answer = str(item["metadata"]["solution"] - 1)
|
||||
assert dataset.score_answer(answer, item) == 0.0
|
||||
|
||||
# Answer is None
|
||||
answer = None
|
||||
assert dataset.score_answer(answer, item) == 0.0
|
||||
for item in dataset:
|
||||
answer = item["answer"]
|
||||
assert dataset.score_answer(answer, item) == 1.0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue