mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-27 17:23:19 +00:00
black formatting
This commit is contained in:
parent
5a7cbe7c24
commit
c8fcb6ca02
6 changed files with 4291 additions and 3733 deletions
|
|
@ -1,15 +1,18 @@
|
|||
"""GSM Symblic dataset generator"""
|
||||
|
||||
from . import generators
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import List, Optional
|
||||
|
||||
from reasoning_gym.factory import ProceduralDataset, register_dataset
|
||||
|
||||
from . import generators
|
||||
|
||||
|
||||
@dataclass
|
||||
class GSMSymbolicDatasetConfig:
|
||||
"""Configuration for GSM symbolic task generation"""
|
||||
|
||||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
|
|
@ -17,9 +20,10 @@ class GSMSymbolicDatasetConfig:
|
|||
"""Validate configuration parameters"""
|
||||
pass
|
||||
|
||||
|
||||
class GSMSymbolicDataset(ProceduralDataset):
|
||||
|
||||
def __init__(self, config, seed = None, size = 500):
|
||||
def __init__(self, config, seed=None, size=500):
|
||||
super().__init__(config, seed, size)
|
||||
# Initialize as None
|
||||
self._generators = None
|
||||
|
|
@ -35,13 +39,11 @@ class GSMSymbolicDataset(ProceduralDataset):
|
|||
"""
|
||||
Generates mapper from task identifiers (keys) to example generator functions
|
||||
"""
|
||||
prefix = 'generate_'
|
||||
return {
|
||||
self.strip_prefix(n, prefix): getattr(generators, n) for n in dir(generators) if n.startswith(prefix)
|
||||
}
|
||||
prefix = "generate_"
|
||||
return {self.strip_prefix(n, prefix): getattr(generators, n) for n in dir(generators) if n.startswith(prefix)}
|
||||
|
||||
def strip_prefix(self, s, prefix):
|
||||
return s[len(prefix):]
|
||||
return s[len(prefix) :]
|
||||
|
||||
def __getitem__(self, idx) -> dict:
|
||||
"""Generate a single GSM symbolic dataset"""
|
||||
|
|
@ -49,8 +51,9 @@ class GSMSymbolicDataset(ProceduralDataset):
|
|||
# Stringify the random integer generated from the random number generator
|
||||
generator_idx = str(rng.randint(0, len(self.generators) - 1))
|
||||
generator = self.generators[generator_idx]
|
||||
# Here the res is a dictionary of
|
||||
# Here the res is a dictionary of
|
||||
res = generator(rng)
|
||||
return res
|
||||
|
||||
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)
|
||||
|
||||
|
||||
register_dataset("gsm_symbolic", GSMSymbolicDataset, GSMSymbolicDatasetConfig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue