reasoning-gym/reasoning_gym/arithmetic/gsm_symbolic/gsm_symbolic.py
joesharratt1229 d9638df79c
updated algorithmics dataset (#269)
* updated algorithmic datasets
* added changes to symbolic and power
* updated power function test
2025-03-05 23:32:53 +01:00

177 lines
3.7 KiB
Python

"""GSM Symbolic dataset generator"""
import re
from dataclasses import dataclass
from random import Random
from typing import Any, Callable, Optional
from reasoning_gym.factory import ProceduralDataset, register_dataset
# Task ids set aside as needing a fix; they are excluded from ``tasks_ok``.
tasks_need_fix = [32, 35, 37, 61, 63, 65, 74, 76, 77, 79, 86, 87, 90, 97, 98]

# Every remaining task id in 0..99, in ascending order. This is the pool the
# dataset samples its per-item generator ids from.
tasks_ok = [task_id for task_id in range(100) if task_id not in tasks_need_fix]
@dataclass
class GSMSymbolicDatasetConfig:
    """Settings controlling how GSM symbolic examples are generated."""

    # Difficulty passed to the example generators; validate() accepts only 1.0.
    difficulty: float = 1.0
    # Seed handed to the dataset; may be None.
    seed: Optional[int] = None
    # Number of examples in the dataset; validate() requires it to be positive.
    size: int = 500

    def validate(self) -> None:
        """Check the configuration values.

        Raises:
            AssertionError: if ``size`` is not positive or ``difficulty`` lies
                outside the supported range.
        """
        assert self.size > 0, "size must be positive"
        # Both bounds coincide for now: currently only difficulty 1.0 is supported.
        lowest_difficulty = highest_difficulty = 1.0
        assert lowest_difficulty <= self.difficulty <= highest_difficulty
class GSMSymbolicDataset(ProceduralDataset):
    """Dataset of GSM symbolic problems produced by numbered generator functions.

    At construction time one generator id is drawn from ``tasks_ok`` for every
    index of the dataset, using a ``Random`` seeded with ``self.seed``. The
    generator functions themselves are imported lazily on first use.
    """

    # Prefix shared by the generator functions in the generator modules
    # (e.g. ``generate_17``); the remainder of the name is the integer task id.
    _GENERATOR_PREFIX = "generate_"

    # First integer or decimal number in a string. The optional minus sign sits
    # *before* the word boundary: with ``\b-?\d+`` a sign at the start of the
    # text or after whitespace/punctuation can never match (no word boundary
    # exists between two non-word characters), so "-5" would be read as 5.
    _NUMBER_PATTERN = re.compile(r"-?\b\d+(?:\.\d+)?\b")

    def __init__(self, config: GSMSymbolicDatasetConfig):
        """Initialise the base dataset and pre-draw the generator id for each index.

        Args:
            config: dataset configuration; its ``seed`` and ``size`` are
                forwarded to ``ProceduralDataset.__init__``.
        """
        super().__init__(config, config.seed, config.size)
        # Filled in by the ``generators`` property on first access (lazy loading).
        self._generators: Optional[dict[int, Callable[[Random, float], dict[str, Any]]]] = None
        # NOTE(review): ``self.seed`` / ``self.size`` are presumably set by the
        # base-class constructor - confirm against ProceduralDataset.
        self.task_indices = Random(self.seed).choices(tasks_ok, k=self.size)

    @property
    def generators(self) -> dict[int, Callable[[Random, float], dict[str, Any]]]:
        """Mapping from task id to generator function, loaded on first access."""
        if self._generators is None:
            self._generators = self._load_generators()
        return self._generators

    def _load_generators(self) -> dict[int, Callable[[Random, float], dict[str, Any]]]:
        """Build the mapping from task id to example generator function.

        Collects every attribute named ``generate_<n>`` from the two sibling
        generator modules and keys it by ``int(<n>)``.

        Raises:
            ValueError: if a ``generate_``-prefixed name has a non-integer suffix.
        """
        # Imported here rather than at module level so the generator modules
        # are only loaded when generators are first needed.
        from . import generators_00_49, generators_50_99

        prefix = self._GENERATOR_PREFIX
        loaded: dict[int, Callable[[Random, float], dict[str, Any]]] = {}
        # Modules are visited in this order, so on a duplicate id the function
        # from the later module wins.
        for module in (generators_00_49, generators_50_99):
            for name in dir(module):
                if name.startswith(prefix):
                    loaded[int(name.removeprefix(prefix))] = getattr(module, name)
        return loaded

    def __getitem__(self, idx: int) -> dict:
        """Generate the example at position ``idx``.

        The per-item ``Random`` is seeded with ``self.seed + idx`` and passed,
        together with the configured difficulty, to the generator that was
        drawn for this index.

        Returns:
            The dict produced by the generator, with an answer-format
            instruction appended to its ``"question"`` entry.
        """
        rng = Random(self.seed + idx)
        generator = self.generators[self.task_indices[idx]]
        example = generator(rng, self.config.difficulty)
        example["question"] += " Give the result as your final answer. Do not include units."
        return example

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score ``answer`` against the expected value stored in ``entry["answer"]``.

        The first number found in ``answer`` (optionally negative, optionally
        with a decimal part) is compared to the expected value as a float.

        Returns:
            1.0 if the extracted number equals the expected answer, 0.01 if a
            number was found but differs, and 0.0 if ``answer`` is ``None``,
            contains no number, or either value cannot be parsed.
        """
        if answer is None:
            return 0.0
        try:
            match = self._NUMBER_PATTERN.search(answer)
            if match is None:
                return 0.0
            answer_value = float(match.group(0))
            expected_answer = float(entry["answer"])
        except (KeyError, TypeError, ValueError, OverflowError):
            # Non-string answer, missing/unparseable expected value, etc.:
            # scoring stays best-effort and simply yields no reward.
            return 0.0
        return 1.0 if answer_value == expected_answer else 0.01
# Register the dataset class together with its config class under the name "gsm_symbolic".
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)