diff --git a/reasoning_gym/algebra/polynomial_equations.py b/reasoning_gym/algebra/polynomial_equations.py index 60935013..6c47c39b 100644 --- a/reasoning_gym/algebra/polynomial_equations.py +++ b/reasoning_gym/algebra/polynomial_equations.py @@ -1,9 +1,8 @@ import random import string from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Optional, Tuple -import sympy from sympy import Eq, Symbol, expand, solve from ..dataset import ProceduralDataset @@ -28,7 +27,7 @@ class PolynomialEquationsConfig: seed: Optional[int] = None size: int = 500 - def validate(self): + def validate(self) -> None: """Validate configuration parameters.""" assert self.min_terms > 0, "min_terms must be positive." assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms." @@ -53,15 +52,13 @@ class PolynomialEquationsDataset(ProceduralDataset): """ def __init__(self, config: PolynomialEquationsConfig): - config.validate() - self.config = config self._prompt_templates = [ "Find the real value(s) of {variable} in the equation: {polynomial_expanded} = 0", "Solve for real {variable}: {polynomial_expanded} = 0", "Determine the real value(s) of {variable} tha satisfies: {polynomial_expanded} = 0", "Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0", ] - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict: """ diff --git a/reasoning_gym/algebra/simple_equations.py b/reasoning_gym/algebra/simple_equations.py index e83613fe..40e49808 100644 --- a/reasoning_gym/algebra/simple_equations.py +++ b/reasoning_gym/algebra/simple_equations.py @@ -21,7 +21,7 @@ class SimpleEquationsConfig: seed: Optional[int] = None size: int = 500 - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert self.min_terms > 0, "min_terms must be positive" assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms" @@ -35,14 +35,12 @@ class SimpleEquationsDataset(ProceduralDataset): """Generates simple equations with one variable to solve""" def __init__(self, config: SimpleEquationsConfig): - self.config = config - self.config.validate() self._prompt_templates = [ "Find the value of {variable} in the equation: {equation}", "Solve for {variable}: {equation}", "Determine the value of {variable} that satisfies: {equation}", ] - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict: """Generate a single equation task diff --git a/reasoning_gym/algorithmic/base_conversion.py b/reasoning_gym/algorithmic/base_conversion.py index e3b54550..3a73d243 100644 --- a/reasoning_gym/algorithmic/base_conversion.py +++ b/reasoning_gym/algorithmic/base_conversion.py @@ -18,7 +18,7 @@ class BaseConversionConfig: seed: Optional[int] = None size: int = 500 # Virtual dataset size - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert 2 <= self.min_base <= 36, "min_base must be between 2 and 36" assert self.min_base <= self.max_base <= 36, "max_base must be between min_base and 36" @@ -30,9 +30,7 @@ class BaseConversionDataset(ProceduralDataset): """Generates base conversion tasks""" def __init__(self, config: BaseConversionConfig): - self.config = config - self.config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) def _format_base_name(self, base: int) -> str: """Get human-readable name for common bases""" diff --git a/reasoning_gym/algorithmic/letter_counting.py b/reasoning_gym/algorithmic/letter_counting.py index 465b9b2c..06ae7b59 100644 --- a/reasoning_gym/algorithmic/letter_counting.py +++ b/reasoning_gym/algorithmic/letter_counting.py @@ -19,7 +19,7 @@ class LetterCountingConfig: seed: Optional[int] = None size: int = 500 # Virtual dataset size - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert self.min_words > 0, "min_words must be positive" assert self.max_words >= self.min_words, "max_words must be >= min_words" @@ -29,9 +29,7 @@ class LetterCountingDataset(ProceduralDataset): """Generates letter counting tasks from text spans""" def __init__(self, config: LetterCountingConfig): - self.config = config - self.config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) # Load and preprocess text text = read_data_file("in_the_year_2889.txt") diff --git a/reasoning_gym/algorithmic/number_filtering.py b/reasoning_gym/algorithmic/number_filtering.py index e6ffaea3..2efc8368 100644 --- a/reasoning_gym/algorithmic/number_filtering.py +++ b/reasoning_gym/algorithmic/number_filtering.py @@ -20,7 +20,7 @@ class NumberFilteringConfig: seed: Optional[int] = None size: int = 500 # Virtual dataset size - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert self.min_numbers > 0, "min_numbers must be positive" assert self.max_numbers >= self.min_numbers, "max_numbers must be >= min_numbers" @@ -33,9 +33,7 @@ class NumberFilteringDataset(ProceduralDataset): """Generates number filtering tasks""" def __init__(self, config: NumberFilteringConfig): - self.config = config - self.config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) def _format_number(self, num: float, decimals: int) -> str: """Format a number with specified decimal places""" diff --git a/reasoning_gym/algorithmic/number_sorting.py b/reasoning_gym/algorithmic/number_sorting.py index 85cfc854..362a4e7d 100644 --- a/reasoning_gym/algorithmic/number_sorting.py +++ b/reasoning_gym/algorithmic/number_sorting.py @@ -20,7 +20,7 @@ class NumberSortingConfig: seed: Optional[int] = None size: int = 500 # Virtual dataset size - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert self.min_numbers > 0, "min_numbers must be positive" assert self.min_numbers <= self.max_numbers, "max_numbers must be >= min_numbers" @@ -33,9 +33,7 @@ class NumberSortingDataset(ProceduralDataset): """Generates number sorting tasks""" def __init__(self, config: NumberSortingConfig): - self.config = config - self.config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) def _format_number(self, num: float, decimals: int) -> str: """Format number with specified decimal places""" diff --git a/reasoning_gym/algorithmic/word_reversal.py b/reasoning_gym/algorithmic/word_reversal.py index fe0aa1f9..4919b835 100644 --- a/reasoning_gym/algorithmic/word_reversal.py +++ b/reasoning_gym/algorithmic/word_reversal.py @@ -18,7 +18,7 @@ class WordReversalConfig: seed: Optional[int] = None size: int = 500 # Virtual dataset size - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert self.min_words > 0, "min_words must be positive" assert self.max_words >= self.min_words, "max_words must be >= min_words" @@ -28,9 +28,7 @@ class WordReversalDataset(ProceduralDataset): """Generates word reversal tasks from text spans""" def __init__(self, config: WordReversalConfig): - self.config = config - self.config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) # Load and preprocess text text = read_data_file("in_the_year_2889.txt") diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index f0b54398..5a47fe7f 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -21,7 +21,7 @@ class BasicArithmeticDatasetConfig: format_style: Literal["simple", "natural"] = "simple" whitespace: Literal["no_space", "single", "random"] = "single" # Whitespace style between terms - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert self.min_terms > 0, "min_terms must be positive" assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms" @@ -63,9 +63,7 @@ class BasicArithmeticDataset(ProceduralDataset): """Dataset that generates basic arithmetic tasks with configurable complexity""" def __init__(self, config: BasicArithmeticDatasetConfig): - self.config = config - self.config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict[str, Any]: """Generate a single arithmetic task diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 9ea7ffc2..7453b68f 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -18,7 +18,7 @@ class ChainSumConfig: seed: Optional[int] = None size: int = 500 - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert self.min_terms > 0, "min_terms must be positive" assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms" @@ -34,9 +34,7 @@ class ChainSum(ProceduralDataset): """Generates simple arithmetic tasks using only + and - operators""" def __init__(self, config: ChainSumConfig): - self.config = config - self.config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict: """Generate a single chain sum task @@ -145,5 +143,6 @@ def chain_sum_dataset( ) return ChainSum(config) + # Register the dataset register_dataset("chain_sum", ChainSum, ChainSumConfig) diff --git a/reasoning_gym/arithmetic/fraction_simplification.py b/reasoning_gym/arithmetic/fraction_simplification.py index 21007fea..7424af9f 100644 --- a/reasoning_gym/arithmetic/fraction_simplification.py +++ b/reasoning_gym/arithmetic/fraction_simplification.py @@ -20,7 +20,7 @@ class FractionSimplificationConfig: seed: Optional[int] = None size: int = 500 # Virtual dataset size - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert self.min_value > 0, "min_value must be positive" assert self.max_value > self.min_value, "max_value must be > min_value" @@ -37,9 +37,7 @@ class FractionSimplificationDataset(ProceduralDataset): """Generates fraction simplification tasks""" def __init__(self, config: FractionSimplificationConfig): - self.config = config - self.config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) def _generate_fraction(self, rng: Random) -> Tuple[int, int, int, int]: """Generate a random fraction and its simplified form. diff --git a/reasoning_gym/arithmetic/gcd.py b/reasoning_gym/arithmetic/gcd.py index c67e9cc8..d24d86cf 100644 --- a/reasoning_gym/arithmetic/gcd.py +++ b/reasoning_gym/arithmetic/gcd.py @@ -20,7 +20,7 @@ class GCDConfig: seed: Optional[int] = None size: int = 500 # Virtual dataset size - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert self.min_numbers >= 2, "min_numbers must be at least 2" assert self.max_numbers >= self.min_numbers, "max_numbers must be >= min_numbers" @@ -32,9 +32,7 @@ class GCDDataset(ProceduralDataset): """Generates Greatest Common Divisor (GCD) tasks""" def __init__(self, config: GCDConfig): - self.config = config - self.config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]: """Generate a list of random positive integers and their GCD. diff --git a/reasoning_gym/arithmetic/lcm.py b/reasoning_gym/arithmetic/lcm.py index ad0983d4..a643f406 100644 --- a/reasoning_gym/arithmetic/lcm.py +++ b/reasoning_gym/arithmetic/lcm.py @@ -20,7 +20,7 @@ class LCMConfig: seed: Optional[int] = None size: int = 500 # Virtual dataset size - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert self.min_numbers >= 2, "min_numbers must be at least 2" assert self.max_numbers >= self.min_numbers, "max_numbers must be >= min_numbers" @@ -32,9 +32,7 @@ class LCMDataset(ProceduralDataset): """Generates Least Common Multiple (LCM) tasks""" def __init__(self, config: LCMConfig): - self.config = config - self.config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) def _generate_numbers(self, rng: Random) -> Tuple[List[int], int]: """Generate a list of random positive integers and their LCM. diff --git a/reasoning_gym/arithmetic/leg_counting.py b/reasoning_gym/arithmetic/leg_counting.py index a1308b0c..54640190 100644 --- a/reasoning_gym/arithmetic/leg_counting.py +++ b/reasoning_gym/arithmetic/leg_counting.py @@ -65,7 +65,7 @@ class LegCountingConfig: seed: Optional[int] = None size: int = 500 # Virtual dataset size - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert self.min_animals > 0, "min_animals must be positive" assert self.max_animals >= self.min_animals, "max_animals must be >= min_animals" @@ -76,9 +76,7 @@ class LegCountingDataset(ProceduralDataset): """Generates leg counting arithmetic tasks""" def __init__(self, config: LegCountingConfig): - self.config = config - self.config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) def _generate_animals(self, rng: Random) -> Dict[str, int]: """Generate a random set of animals and their counts""" diff --git a/reasoning_gym/arithmetic/prime_factorization.py b/reasoning_gym/arithmetic/prime_factorization.py index ab228d7f..d3416ba0 100644 --- a/reasoning_gym/arithmetic/prime_factorization.py +++ b/reasoning_gym/arithmetic/prime_factorization.py @@ -16,7 +16,7 @@ class PrimeFactorizationConfig: seed: Optional[int] = None size: int = 500 # Virtual dataset size - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert self.min_value >= 2, "min_value must be >= 2" assert self.max_value >= self.min_value, "max_value must be >= min_value" @@ -26,9 +26,7 @@ class PrimeFactorizationDataset(ProceduralDataset): """Generates prime factorization tasks""" def __init__(self, config: PrimeFactorizationConfig): - self.config = config - self.config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) def _prime_factors(self, n: int) -> List[int]: """Compute prime factors of a number""" diff --git a/reasoning_gym/cognition/color_cube_rotation.py b/reasoning_gym/cognition/color_cube_rotation.py index b0b65fae..92357756 100644 --- a/reasoning_gym/cognition/color_cube_rotation.py +++ b/reasoning_gym/cognition/color_cube_rotation.py @@ -95,7 +95,7 @@ class ColorCubeRotationConfig: seed: Optional[int] = None size: int = 500 - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert self.min_rotations > 0, "min_rotations must be positive" assert self.max_rotations >= self.min_rotations, "max_rotations must be >= min_rotations" @@ -105,9 +105,7 @@ class ColorCubeRotationDataset(ProceduralDataset): """Generates color cube rotation reasoning tasks""" def __init__(self, config: ColorCubeRotationConfig): - self.config = config - self.config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict: rng = random.Random(self.seed + idx) diff --git a/reasoning_gym/cognition/number_sequences.py b/reasoning_gym/cognition/number_sequences.py index bc5770df..b09a070d 100644 --- a/reasoning_gym/cognition/number_sequences.py +++ b/reasoning_gym/cognition/number_sequences.py @@ -31,7 +31,7 @@ class NumberSequenceConfig: seed: Optional[int] = None size: int = 500 # Virtual dataset size - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert self.min_terms >= 4, "need at least 4 terms to establish pattern" assert self.max_terms >= self.min_terms @@ -155,9 +155,7 @@ class NumberSequenceDataset(ProceduralDataset): """Generates number sequence completion tasks with dynamic pattern generation""" def __init__(self, config: NumberSequenceConfig): - self.config = config - self.config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict: """Generate a sequence task with a newly generated pattern""" diff --git a/reasoning_gym/dataset.py b/reasoning_gym/dataset.py index bbfe3895..3527e289 100644 --- a/reasoning_gym/dataset.py +++ b/reasoning_gym/dataset.py @@ -11,6 +11,9 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]): def __init__(self, config: Any, seed: Optional[int] = None, size: int = 500): """Initialize the dataset with config, optional seed and size""" + if hasattr(config, "validate") and callable(config.validate): + config.validate() + self.config = config self.size = size self.seed = seed if seed is not None else Random().randint(0, 2**32) diff --git a/reasoning_gym/games/maze.py b/reasoning_gym/games/maze.py index 439c2809..18a40fb8 100644 --- a/reasoning_gym/games/maze.py +++ b/reasoning_gym/games/maze.py @@ -26,7 +26,7 @@ class MazeConfig: seed: Optional[int] = None size: int = 50 - def validate(self): + def validate(self) -> None: """Validate configuration parameters.""" assert self.min_dist >= 1, "min_dist must be >= 1" assert self.max_dist >= self.min_dist, "max_dist must be >= min_dist" @@ -46,8 +46,7 @@ class MazeDataset(ProceduralDataset): prob_path=0.7, num_retries=1000, ): - config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) self.config = config # Probability that a cell is a path instead of a wall self.prob_path = prob_path diff --git a/reasoning_gym/graphs/family_relationships.py b/reasoning_gym/graphs/family_relationships.py index db30d69f..61b96b9f 100644 --- a/reasoning_gym/graphs/family_relationships.py +++ b/reasoning_gym/graphs/family_relationships.py @@ -166,7 +166,7 @@ class FamilyRelationshipsConfig: if self.female_names is None: self.female_names = default_female_names - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert self.min_family_size >= 3, "min_family_size must be at least 3" assert self.max_family_size >= self.min_family_size, "max_family_size must be >= min_family_size" @@ -178,14 +178,12 @@ class FamilyRelationshipsDataset(ProceduralDataset): """Generates family relationship reasoning tasks""" def __init__(self, config: FamilyRelationshipsConfig): - self.config = config - self.config.validate() self._templates = [ "What is {person1} to {person2}?", "How is {person1} related to {person2}?", "What relation is {person1} to {person2}?", ] - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict: rng = random.Random(self.seed + idx) diff --git a/reasoning_gym/logic/syllogisms.py b/reasoning_gym/logic/syllogisms.py index 4a072e64..ad25feac 100644 --- a/reasoning_gym/logic/syllogisms.py +++ b/reasoning_gym/logic/syllogisms.py @@ -45,7 +45,7 @@ class SyllogismConfig: seed: Optional[int] = None size: int = 500 - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert any( [self.allow_all, self.allow_no, self.allow_some, self.allow_some_not] @@ -100,11 +100,8 @@ class SyllogismDataset(ProceduralDataset): ] def __init__(self, config: SyllogismConfig): - self.config = config - if self.config.terms is None: - self.config.terms = self.DEFAULT_TERMS - self.config.validate() - super().__init__(seed=config.seed, size=config.size) + super().__init__(config=config, seed=config.seed, size=config.size) + self.terms = self.DEFAULT_TERMS if config.terms is None else config.terms def _get_allowed_quantifiers(self) -> List[Quantifier]: """Get list of allowed quantifiers based on config""" @@ -212,7 +209,7 @@ class SyllogismDataset(ProceduralDataset): def _generate_syllogism(self, rng: Random) -> dict: """Generate a single syllogism problem""" # Select three different terms - terms = rng.sample(self.config.terms, 3) + terms = rng.sample(self.terms, 3) quantifiers = self._get_allowed_quantifiers() # Generate premises and conclusion