gsm_symbolic generator changes

This commit is contained in:
Andreas Koepf 2025-02-04 13:47:57 +01:00
parent b84e29a8b6
commit afb95508ef
10 changed files with 9007 additions and 7360 deletions

View file

@ -12,8 +12,7 @@ from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDat
from .chain_sum import ChainSum, ChainSumConfig
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
from .gcd import GCDConfig, GCDDataset
# from .gsm_symbolic.gsm_symbolic_datasets import GSMSymbolicDataset, GSMSymbolicDatasetConfig
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
from .lcm import LCMConfig, LCMDataset
from .leg_counting import LegCountingConfig, LegCountingDataset
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
@ -39,8 +38,8 @@ __all__ = [
"LegCountingDataset",
"PrimeFactorizationConfig",
"PrimeFactorizationDataset",
# "GSMSymbolicDatasetConfig",
# "GSMSymbolicDataset",
"GSMSymbolicDatasetConfig",
"GSMSymbolicDataset",
"TimeIntervalsConfig",
"TimeIntervalsDataset",
]

View file

@ -0,0 +1,6 @@
from .gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
__all__ = [
"GSMSymbolicDatasetConfig",
"GSMSymbolicDataset",
]

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,154 @@
"""GSM Symblic dataset generator"""
from dataclasses import dataclass
from random import Random
from typing import Any, Callable, Optional
from reasoning_gym.factory import ProceduralDataset, register_dataset
tasks_ok = [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
33,
34,
36,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53,
54,
55,
56,
57,
58,
59,
60,
62,
64,
66,
67,
68,
69,
70,
71,
72,
73,
75,
78,
80,
81,
82,
83,
84,
85,
88,
89,
91,
92,
93,
94,
95,
96,
99,
]
tasks_need_fix = [32, 35, 37, 61, 63, 65, 74, 76, 77, 79, 86, 87, 90, 97, 98]
@dataclass
class GSMSymbolicDatasetConfig:
"""Configuration for GSM symbolic task generation"""
difficulty: float = 1.0
seed: Optional[int] = None
size: int = 500
def validate(self) -> None:
"""Validate configuration parameters"""
assert self.size > 0, "size must be positive"
assert 1.0 <= self.difficulty <= 1.0 # currently only difficulty 1.0 is supported
class GSMSymbolicDataset(ProceduralDataset):
def __init__(self, config: GSMSymbolicDatasetConfig):
super().__init__(config, config.seed, config.size)
self._generators: dict[int, Callable[[Random, float], dict[str, Any]]] = None # initially None, lazy loading
self.task_indices = Random(self.seed).choices(tasks_ok, k=self.size)
@property
def generators(self) -> dict[int, Callable[[Random, float], dict[str, Any]]]:
"""Lazy load generators only when first accessed"""
if self._generators is None:
self._generators = self._load_generators()
return self._generators
def _load_generators(self):
"""
Generates mapper from task identifiers (keys) to example generator functions
"""
from . import generators_00_49, generators_50_99
def strip_prefix(s: str, prefix: str) -> str:
return s[len(prefix) :]
prefix = "generate_"
gs = {}
for n in dir(generators_00_49):
if n.startswith(prefix):
gs[int(strip_prefix(n, prefix))] = getattr(generators_00_49, n)
for n in dir(generators_50_99):
if n.startswith(prefix):
gs[int(strip_prefix(n, prefix))] = getattr(generators_50_99, n)
return gs
def __getitem__(self, idx: int) -> dict:
"""Generate a single GSM symbolic dataset"""
rng = Random(self.seed + idx)
generator_idx = self.task_indices[idx]
generator = self.generators[generator_idx]
return generator(rng, self.config.difficulty)
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)

View file

@ -1,59 +0,0 @@
"""GSM Symblic dataset generator"""
from dataclasses import dataclass
from random import Random
from typing import List, Optional
from reasoning_gym.factory import ProceduralDataset, register_dataset
from . import generators
@dataclass
class GSMSymbolicDatasetConfig:
"""Configuration for GSM symbolic task generation"""
seed: Optional[int] = None
size: int = 500
def validate(self) -> None:
"""Validate configuration parameters"""
pass
class GSMSymbolicDataset(ProceduralDataset):
def __init__(self, config, seed=None, size=500):
super().__init__(config, seed, size)
# Initialize as None
self._generators = None
@property
def generators(self):
"""Lazy load generators only when first accessed"""
if self._generators is None:
self._generators = self.get_generators()
return self._generators
def get_generators(self):
"""
Generates mapper from task identifiers (keys) to example generator functions
"""
prefix = "generate_"
return {self.strip_prefix(n, prefix): getattr(generators, n) for n in dir(generators) if n.startswith(prefix)}
def strip_prefix(self, s, prefix):
return s[len(prefix) :]
def __getitem__(self, idx) -> dict:
"""Generate a single GSM symbolic dataset"""
rng = Random(self.seed + idx)
# Stringify the random integer generated from the random number generator
generator_idx = str(rng.randint(0, len(self.generators) - 1))
generator = self.generators[generator_idx]
# Here the res is a dictionary of
res = generator(rng)
return res
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)