mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
pass config to ProceduralDataset base
This commit is contained in:
parent
df2b8d2809
commit
e9549f2a63
20 changed files with 45 additions and 80 deletions
|
|
@ -1,9 +1,8 @@
|
|||
import random
|
||||
import string
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import sympy
|
||||
from sympy import Eq, Symbol, expand, solve
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
|
|
@ -28,7 +27,7 @@ class PolynomialEquationsConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters."""
|
||||
assert self.min_terms > 0, "min_terms must be positive."
|
||||
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms."
|
||||
|
|
@ -53,15 +52,13 @@ class PolynomialEquationsDataset(ProceduralDataset):
|
|||
"""
|
||||
|
||||
def __init__(self, config: PolynomialEquationsConfig):
|
||||
config.validate()
|
||||
self.config = config
|
||||
self._prompt_templates = [
|
||||
"Find the real value(s) of {variable} in the equation: {polynomial_expanded} = 0",
|
||||
"Solve for real {variable}: {polynomial_expanded} = 0",
|
||||
"Determine the real value(s) of {variable} tha satisfies: {polynomial_expanded} = 0",
|
||||
"Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0",
|
||||
]
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class SimpleEquationsConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_terms > 0, "min_terms must be positive"
|
||||
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms"
|
||||
|
|
@ -35,14 +35,12 @@ class SimpleEquationsDataset(ProceduralDataset):
|
|||
"""Generates simple equations with one variable to solve"""
|
||||
|
||||
def __init__(self, config: SimpleEquationsConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
self._prompt_templates = [
|
||||
"Find the value of {variable} in the equation: {equation}",
|
||||
"Solve for {variable}: {equation}",
|
||||
"Determine the value of {variable} that satisfies: {equation}",
|
||||
]
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single equation task
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class BaseConversionConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert 2 <= self.min_base <= 36, "min_base must be between 2 and 36"
|
||||
assert self.min_base <= self.max_base <= 36, "max_base must be between min_base and 36"
|
||||
|
|
@ -30,9 +30,7 @@ class BaseConversionDataset(ProceduralDataset):
|
|||
"""Generates base conversion tasks"""
|
||||
|
||||
def __init__(self, config: BaseConversionConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _format_base_name(self, base: int) -> str:
|
||||
"""Get human-readable name for common bases"""
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class LetterCountingConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_words > 0, "min_words must be positive"
|
||||
assert self.max_words >= self.min_words, "max_words must be >= min_words"
|
||||
|
|
@ -29,9 +29,7 @@ class LetterCountingDataset(ProceduralDataset):
|
|||
"""Generates letter counting tasks from text spans"""
|
||||
|
||||
def __init__(self, config: LetterCountingConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
# Load and preprocess text
|
||||
text = read_data_file("in_the_year_2889.txt")
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class NumberFilteringConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_numbers > 0, "min_numbers must be positive"
|
||||
assert self.max_numbers >= self.min_numbers, "max_numbers must be >= min_numbers"
|
||||
|
|
@ -33,9 +33,7 @@ class NumberFilteringDataset(ProceduralDataset):
|
|||
"""Generates number filtering tasks"""
|
||||
|
||||
def __init__(self, config: NumberFilteringConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _format_number(self, num: float, decimals: int) -> str:
|
||||
"""Format a number with specified decimal places"""
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class NumberSortingConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_numbers > 0, "min_numbers must be positive"
|
||||
assert self.min_numbers <= self.max_numbers, "max_numbers must be >= min_numbers"
|
||||
|
|
@ -33,9 +33,7 @@ class NumberSortingDataset(ProceduralDataset):
|
|||
"""Generates number sorting tasks"""
|
||||
|
||||
def __init__(self, config: NumberSortingConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _format_number(self, num: float, decimals: int) -> str:
|
||||
"""Format number with specified decimal places"""
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class WordReversalConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_words > 0, "min_words must be positive"
|
||||
assert self.max_words >= self.min_words, "max_words must be >= min_words"
|
||||
|
|
@ -28,9 +28,7 @@ class WordReversalDataset(ProceduralDataset):
|
|||
"""Generates word reversal tasks from text spans"""
|
||||
|
||||
def __init__(self, config: WordReversalConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
# Load and preprocess text
|
||||
text = read_data_file("in_the_year_2889.txt")
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class BasicArithmeticDatasetConfig:
|
|||
format_style: Literal["simple", "natural"] = "simple"
|
||||
whitespace: Literal["no_space", "single", "random"] = "single" # Whitespace style between terms
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_terms > 0, "min_terms must be positive"
|
||||
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms"
|
||||
|
|
@ -63,9 +63,7 @@ class BasicArithmeticDataset(ProceduralDataset):
|
|||
"""Dataset that generates basic arithmetic tasks with configurable complexity"""
|
||||
|
||||
def __init__(self, config: BasicArithmeticDatasetConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, Any]:
|
||||
"""Generate a single arithmetic task
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class ChainSumConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_terms > 0, "min_terms must be positive"
|
||||
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms"
|
||||
|
|
@ -34,9 +34,7 @@ class ChainSum(ProceduralDataset):
|
|||
"""Generates simple arithmetic tasks using only + and - operators"""
|
||||
|
||||
def __init__(self, config: ChainSumConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single chain sum task
|
||||
|
|
@ -145,5 +143,6 @@ def chain_sum_dataset(
|
|||
)
|
||||
return ChainSum(config)
|
||||
|
||||
|
||||
# Register the dataset
|
||||
register_dataset("chain_sum", ChainSum, ChainSumConfig)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class FractionSimplificationConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_value > 0, "min_value must be positive"
|
||||
assert self.max_value > self.min_value, "max_value must be > min_value"
|
||||
|
|
@ -37,9 +37,7 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
"""Generates fraction simplification tasks"""
|
||||
|
||||
def __init__(self, config: FractionSimplificationConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _generate_fraction(self, rng: Random) -> Tuple[int, int, int, int]:
|
||||
"""Generate a random fraction and its simplified form.
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class GCDConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_numbers >= 2, "min_numbers must be at least 2"
|
||||
assert self.max_numbers >= self.min_numbers, "max_numbers must be >= min_numbers"
|
||||
|
|
@ -32,9 +32,7 @@ class GCDDataset(ProceduralDataset):
|
|||
"""Generates Greatest Common Divisor (GCD) tasks"""
|
||||
|
||||
def __init__(self, config: GCDConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
|
||||
"""Generate a list of random positive integers and their GCD.
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class LCMConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_numbers >= 2, "min_numbers must be at least 2"
|
||||
assert self.max_numbers >= self.min_numbers, "max_numbers must be >= min_numbers"
|
||||
|
|
@ -32,9 +32,7 @@ class LCMDataset(ProceduralDataset):
|
|||
"""Generates Least Common Multiple (LCM) tasks"""
|
||||
|
||||
def __init__(self, config: LCMConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]:
|
||||
"""Generate a list of random positive integers and their LCM.
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ class LegCountingConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_animals > 0, "min_animals must be positive"
|
||||
assert self.max_animals >= self.min_animals, "max_animals must be >= min_animals"
|
||||
|
|
@ -76,9 +76,7 @@ class LegCountingDataset(ProceduralDataset):
|
|||
"""Generates leg counting arithmetic tasks"""
|
||||
|
||||
def __init__(self, config: LegCountingConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _generate_animals(self, rng: Random) -> Dict[str, int]:
|
||||
"""Generate a random set of animals and their counts"""
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class PrimeFactorizationConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_value >= 2, "min_value must be >= 2"
|
||||
assert self.max_value >= self.min_value, "max_value must be >= min_value"
|
||||
|
|
@ -26,9 +26,7 @@ class PrimeFactorizationDataset(ProceduralDataset):
|
|||
"""Generates prime factorization tasks"""
|
||||
|
||||
def __init__(self, config: PrimeFactorizationConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def _prime_factors(self, n: int) -> List[int]:
|
||||
"""Compute prime factors of a number"""
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ class ColorCubeRotationConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_rotations > 0, "min_rotations must be positive"
|
||||
assert self.max_rotations >= self.min_rotations, "max_rotations must be >= min_rotations"
|
||||
|
|
@ -105,9 +105,7 @@ class ColorCubeRotationDataset(ProceduralDataset):
|
|||
"""Generates color cube rotation reasoning tasks"""
|
||||
|
||||
def __init__(self, config: ColorCubeRotationConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
rng = random.Random(self.seed + idx)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class NumberSequenceConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_terms >= 4, "need at least 4 terms to establish pattern"
|
||||
assert self.max_terms >= self.min_terms
|
||||
|
|
@ -155,9 +155,7 @@ class NumberSequenceDataset(ProceduralDataset):
|
|||
"""Generates number sequence completion tasks with dynamic pattern generation"""
|
||||
|
||||
def __init__(self, config: NumberSequenceConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a sequence task with a newly generated pattern"""
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
|
|||
|
||||
def __init__(self, config: Any, seed: Optional[int] = None, size: int = 500):
|
||||
"""Initialize the dataset with config, optional seed and size"""
|
||||
if hasattr(config, "validate") and callable(config.validate):
|
||||
config.validate()
|
||||
|
||||
self.config = config
|
||||
self.size = size
|
||||
self.seed = seed if seed is not None else Random().randint(0, 2**32)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class MazeConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 50
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters."""
|
||||
assert self.min_dist >= 1, "min_dist must be >= 1"
|
||||
assert self.max_dist >= self.min_dist, "max_dist must be >= min_dist"
|
||||
|
|
@ -46,8 +46,7 @@ class MazeDataset(ProceduralDataset):
|
|||
prob_path=0.7,
|
||||
num_retries=1000,
|
||||
):
|
||||
config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
self.config = config
|
||||
# Probability that a cell is a path instead of a wall
|
||||
self.prob_path = prob_path
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ class FamilyRelationshipsConfig:
|
|||
if self.female_names is None:
|
||||
self.female_names = default_female_names
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.min_family_size >= 3, "min_family_size must be at least 3"
|
||||
assert self.max_family_size >= self.min_family_size, "max_family_size must be >= min_family_size"
|
||||
|
|
@ -178,14 +178,12 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
|||
"""Generates family relationship reasoning tasks"""
|
||||
|
||||
def __init__(self, config: FamilyRelationshipsConfig):
|
||||
self.config = config
|
||||
self.config.validate()
|
||||
self._templates = [
|
||||
"What is {person1} to {person2}?",
|
||||
"How is {person1} related to {person2}?",
|
||||
"What relation is {person1} to {person2}?",
|
||||
]
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
rng = random.Random(self.seed + idx)
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class SyllogismConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert any(
|
||||
[self.allow_all, self.allow_no, self.allow_some, self.allow_some_not]
|
||||
|
|
@ -100,11 +100,8 @@ class SyllogismDataset(ProceduralDataset):
|
|||
]
|
||||
|
||||
def __init__(self, config: SyllogismConfig):
|
||||
self.config = config
|
||||
if self.config.terms is None:
|
||||
self.config.terms = self.DEFAULT_TERMS
|
||||
self.config.validate()
|
||||
super().__init__(seed=config.seed, size=config.size)
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
self.terms = self.DEFAULT_TERMS if config.terms is None else config.terms
|
||||
|
||||
def _get_allowed_quantifiers(self) -> List[Quantifier]:
|
||||
"""Get list of allowed quantifiers based on config"""
|
||||
|
|
@ -212,7 +209,7 @@ class SyllogismDataset(ProceduralDataset):
|
|||
def _generate_syllogism(self, rng: Random) -> dict:
|
||||
"""Generate a single syllogism problem"""
|
||||
# Select three different terms
|
||||
terms = rng.sample(self.config.terms, 3)
|
||||
terms = rng.sample(self.terms, 3)
|
||||
quantifiers = self._get_allowed_quantifiers()
|
||||
|
||||
# Generate premises and conclusion
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue