black formatting

This commit is contained in:
Andreas Koepf 2025-02-03 22:57:24 +01:00
parent 5a7cbe7c24
commit c8fcb6ca02
6 changed files with 4291 additions and 3733 deletions

View file

@ -1,15 +1,18 @@
"""GSM Symblic dataset generator"""
from . import generators
from dataclasses import dataclass
from random import Random
from typing import List, Optional
from reasoning_gym.factory import ProceduralDataset, register_dataset
from . import generators
@dataclass
class GSMSymbolicDatasetConfig:
"""Configuration for GSM symbolic task generation"""
seed: Optional[int] = None
size: int = 500
@ -17,9 +20,10 @@ class GSMSymbolicDatasetConfig:
"""Validate configuration parameters"""
pass
class GSMSymbolicDataset(ProceduralDataset):
def __init__(self, config, seed = None, size = 500):
def __init__(self, config, seed=None, size=500):
super().__init__(config, seed, size)
# Initialize as None
self._generators = None
@ -35,13 +39,11 @@ class GSMSymbolicDataset(ProceduralDataset):
"""
Generates mapper from task identifiers (keys) to example generator functions
"""
prefix = 'generate_'
return {
self.strip_prefix(n, prefix): getattr(generators, n) for n in dir(generators) if n.startswith(prefix)
}
prefix = "generate_"
return {self.strip_prefix(n, prefix): getattr(generators, n) for n in dir(generators) if n.startswith(prefix)}
def strip_prefix(self, s, prefix):
return s[len(prefix):]
return s[len(prefix) :]
def __getitem__(self, idx) -> dict:
"""Generate a single GSM symbolic dataset"""
@ -49,8 +51,9 @@ class GSMSymbolicDataset(ProceduralDataset):
# Stringify the random integer generated from the random number generator
generator_idx = str(rng.randint(0, len(self.generators) - 1))
generator = self.generators[generator_idx]
# Here the res is a dictionary of
# Here the res is a dictionary of
res = generator(rng)
return res
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)