diff --git a/reasoning_gym/__init__.py b/reasoning_gym/__init__.py index 4bc56162..20afc9f8 100644 --- a/reasoning_gym/__init__.py +++ b/reasoning_gym/__init__.py @@ -3,6 +3,7 @@ Reasoning Gym - A library of procedural dataset generators for training reasonin """ from . import algebra, algorithmic, arithmetic, cognition, data, games, graphs, logic +from .factory import create_dataset, register_dataset __version__ = "0.1.1" __all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "graphs", "logic"] diff --git a/reasoning_gym/algebra/__init__.py b/reasoning_gym/algebra/__init__.py index 84c3d302..69d4b91e 100644 --- a/reasoning_gym/algebra/__init__.py +++ b/reasoning_gym/algebra/__init__.py @@ -1,11 +1,9 @@ -from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset, polynomial_equations_dataset -from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset, simple_equations_dataset +from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset +from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset __all__ = [ "SimpleEquationsDataset", "SimpleEquationsConfig", - "simple_equations_dataset", "PolynomialEquationsConfig", "PolynomialEquationsDataset", - "polynomial_equations_dataset", ] diff --git a/reasoning_gym/algebra/polynomial_equations.py b/reasoning_gym/algebra/polynomial_equations.py index 6c47c39b..ed7e857f 100644 --- a/reasoning_gym/algebra/polynomial_equations.py +++ b/reasoning_gym/algebra/polynomial_equations.py @@ -5,7 +5,7 @@ from typing import Optional, Tuple from sympy import Eq, Symbol, expand, solve -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -147,31 +147,4 @@ class PolynomialEquationsDataset(ProceduralDataset): return polynomial_expr -def polynomial_equations_dataset( - min_terms: int = 2, - max_terms: int = 4, - min_value: int = 1, - max_value: int = 100, - min_degree: int = 1, - max_degree: int = 3, - operators: Tuple[str, ...] = ("+", "-"), - seed: Optional[int] = None, - size: int = 500, -) -> PolynomialEquationsDataset: - """ - Factory function for creating a PolynomialEquationsDataset. - Example usage: - dataset = polynomial_equations_dataset(min_degree=2, max_degree=3, ...) - """ - config = PolynomialEquationsConfig( - min_terms=min_terms, - max_terms=max_terms, - min_value=min_value, - max_value=max_value, - min_degree=min_degree, - max_degree=max_degree, - operators=operators, - seed=seed, - size=size, - ) - return PolynomialEquationsDataset(config) +register_dataset("polynomial_equations", PolynomialEquationsDataset, PolynomialEquationsConfig) diff --git a/reasoning_gym/algebra/simple_equations.py b/reasoning_gym/algebra/simple_equations.py index 40e49808..5a85fcb5 100644 --- a/reasoning_gym/algebra/simple_equations.py +++ b/reasoning_gym/algebra/simple_equations.py @@ -6,7 +6,7 @@ from typing import Optional, Tuple import sympy from sympy import Eq, Symbol, solve -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -116,23 +116,4 @@ class SimpleEquationsDataset(ProceduralDataset): return f"{left_side} = {right_side}", solution_value -def simple_equations_dataset( - min_terms: int = 2, - max_terms: int = 5, - min_value: int = 1, - max_value: int = 100, - operators: tuple = ("+", "-", "*"), - seed: Optional[int] = None, - size: int = 500, -) -> SimpleEquationsDataset: - """Create a SimpleEquationsDataset with the given configuration""" - config = SimpleEquationsConfig( - min_terms=min_terms, - max_terms=max_terms, - min_value=min_value, - max_value=max_value, - operators=operators, - seed=seed, - size=size, - ) - return SimpleEquationsDataset(config) +register_dataset("simple_equations", SimpleEquationsDataset, SimpleEquationsConfig) diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index bf33e5f3..fe8c4a5b 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -6,31 +6,21 @@ Algorithmic tasks for training reasoning capabilities: - Pattern matching """ -from reasoning_gym.arithmetic.basic_arithmetic import basic_arithmetic_dataset -from reasoning_gym.arithmetic.chain_sum import chain_sum_dataset - -from .base_conversion import BaseConversionConfig, BaseConversionDataset, base_conversion_dataset -from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset -from .number_filtering import NumberFilteringConfig, NumberFilteringDataset, number_filtering_dataset -from .number_sorting import NumberSortingConfig, NumberSortingDataset, number_sorting_dataset -from .word_reversal import WordReversalConfig, WordReversalDataset, word_reversal_dataset +from .base_conversion import BaseConversionConfig, BaseConversionDataset +from .letter_counting import LetterCountingConfig, LetterCountingDataset +from .number_filtering import NumberFilteringConfig, NumberFilteringDataset +from .number_sorting import NumberSortingConfig, NumberSortingDataset +from .word_reversal import WordReversalConfig, WordReversalDataset __all__ = [ - "basic_arithmetic_dataset", "BaseConversionConfig", "BaseConversionDataset", - "base_conversion_dataset", - "chain_sum_dataset", "LetterCountingConfig", "LetterCountingDataset", - "letter_counting_dataset", "NumberFilteringConfig", "NumberFilteringDataset", - "number_filtering_dataset", "NumberSortingConfig", "NumberSortingDataset", - "number_sorting_dataset", "WordReversalConfig", "WordReversalDataset", - "word_reversal_dataset", ] diff --git a/reasoning_gym/algorithmic/base_conversion.py b/reasoning_gym/algorithmic/base_conversion.py index 3a73d243..eb0978bd 100644 --- a/reasoning_gym/algorithmic/base_conversion.py +++ b/reasoning_gym/algorithmic/base_conversion.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from random import Random from typing import Optional, Tuple -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -88,21 +88,4 @@ class BaseConversionDataset(ProceduralDataset): } -def base_conversion_dataset( - min_base: int = 2, - max_base: int = 16, - min_value: int = 0, - max_value: int = 1000, - seed: Optional[int] = None, - size: int = 500, -) -> BaseConversionDataset: - """Create a BaseConversionDataset with the given configuration.""" - config = BaseConversionConfig( - min_base=min_base, - max_base=max_base, - min_value=min_value, - max_value=max_value, - seed=seed, - size=size, - ) - return BaseConversionDataset(config) +register_dataset("base_conversion", BaseConversionDataset, BaseConversionConfig) diff --git a/reasoning_gym/algorithmic/letter_counting.py b/reasoning_gym/algorithmic/letter_counting.py index 06ae7b59..1ef33148 100644 --- a/reasoning_gym/algorithmic/letter_counting.py +++ b/reasoning_gym/algorithmic/letter_counting.py @@ -7,7 +7,7 @@ from typing import List, Optional from reasoning_gym.data import read_data_file -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -63,17 +63,4 @@ class LetterCountingDataset(ProceduralDataset): } -def letter_counting_dataset( - min_words: int = 5, - max_words: int = 15, - seed: Optional[int] = None, - size: int = 500, -) -> LetterCountingDataset: - """Create a LetterCountingDataset with the given configuration.""" - config = LetterCountingConfig( - min_words=min_words, - max_words=max_words, - seed=seed, - size=size, - ) - return LetterCountingDataset(config) +register_dataset("letter_counting", LetterCountingDataset, LetterCountingConfig) diff --git a/reasoning_gym/algorithmic/number_filtering.py b/reasoning_gym/algorithmic/number_filtering.py index 2efc8368..cc05992b 100644 --- a/reasoning_gym/algorithmic/number_filtering.py +++ b/reasoning_gym/algorithmic/number_filtering.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from random import Random from typing import List, Optional, Tuple -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -98,25 +98,4 @@ class NumberFilteringDataset(ProceduralDataset): } -def number_filtering_dataset( - min_numbers: int = 3, - max_numbers: int = 10, - min_decimals: int = 0, - max_decimals: int = 4, - min_value: float = -100.0, - max_value: float = 100.0, - seed: Optional[int] = None, - size: int = 500, -) -> NumberFilteringDataset: - """Create a NumberFilteringDataset with the given configuration.""" - config = NumberFilteringConfig( - min_numbers=min_numbers, - max_numbers=max_numbers, - min_decimals=min_decimals, - max_decimals=max_decimals, - min_value=min_value, - max_value=max_value, - seed=seed, - size=size, - ) - return NumberFilteringDataset(config) +register_dataset("number_filtering", NumberFilteringDataset, NumberFilteringConfig) diff --git a/reasoning_gym/algorithmic/number_sorting.py b/reasoning_gym/algorithmic/number_sorting.py index 362a4e7d..d922aa74 100644 --- a/reasoning_gym/algorithmic/number_sorting.py +++ b/reasoning_gym/algorithmic/number_sorting.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from random import Random from typing import List, Optional, Tuple -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -86,25 +86,4 @@ class NumberSortingDataset(ProceduralDataset): } -def number_sorting_dataset( - min_numbers: int = 3, - max_numbers: int = 10, - min_decimals: int = 0, - max_decimals: int = 2, - min_value: float = -100.0, - max_value: float = 100.0, - seed: Optional[int] = None, - size: int = 500, -) -> NumberSortingDataset: - """Create a NumberSortingDataset with the given configuration.""" - config = NumberSortingConfig( - min_numbers=min_numbers, - max_numbers=max_numbers, - min_decimals=min_decimals, - max_decimals=max_decimals, - min_value=min_value, - max_value=max_value, - seed=seed, - size=size, - ) - return NumberSortingDataset(config) +register_dataset("number_sorting", NumberSortingDataset, NumberSortingConfig) diff --git a/reasoning_gym/algorithmic/word_reversal.py b/reasoning_gym/algorithmic/word_reversal.py index 4919b835..b08b459d 100644 --- a/reasoning_gym/algorithmic/word_reversal.py +++ b/reasoning_gym/algorithmic/word_reversal.py @@ -6,7 +6,7 @@ from random import Random from typing import List, Optional from ..data import read_data_file -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -55,17 +55,4 @@ class WordReversalDataset(ProceduralDataset): } -def word_reversal_dataset( - min_words: int = 3, - max_words: int = 8, - seed: Optional[int] = None, - size: int = 500, -) -> WordReversalDataset: - """Create a WordReversalDataset with the given configuration.""" - config = WordReversalConfig( - min_words=min_words, - max_words=max_words, - seed=seed, - size=size, - ) - return WordReversalDataset(config) +register_dataset("word_reversal", WordReversalDataset, WordReversalConfig) diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 2ac85f97..12a6ee89 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -6,17 +6,13 @@ Arithmetic tasks for training reasoning capabilities: - Leg counting """ -from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset -from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset -from .fraction_simplification import ( - FractionSimplificationConfig, - FractionSimplificationDataset, - fraction_simplification_dataset, -) -from .gcd import GCDConfig, GCDDataset, gcd_dataset -from .lcm import LCMConfig, LCMDataset, lcm_dataset -from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset -from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset, prime_factorization_dataset +from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig +from .chain_sum import ChainSum, ChainSumConfig +from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset +from .gcd import GCDConfig, GCDDataset +from .lcm import LCMConfig, LCMDataset +from .leg_counting import LegCountingConfig, LegCountingDataset +from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset __all__ = [ "BasicArithmeticDataset", @@ -24,20 +20,14 @@ __all__ = [ "basic_arithmetic_dataset", "ChainSum", "ChainSumConfig", - "chain_sum_dataset", "FractionSimplificationConfig", "FractionSimplificationDataset", - "fraction_simplification_dataset", "GCDConfig", "GCDDataset", - "gcd_dataset", "LCMConfig", "LCMDataset", - "lcm_dataset", "LegCountingConfig", "LegCountingDataset", - "leg_counting_dataset", "PrimeFactorizationConfig", "PrimeFactorizationDataset", - "prime_factorization_dataset", ] diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 5a47fe7f..9ec096ee 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from random import Random from typing import Any, Literal, Optional -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -231,47 +231,5 @@ class BasicArithmeticDataset(ProceduralDataset): return rng.choice(templates).format(expression) -def basic_arithmetic_dataset( - min_terms: int = 2, - max_terms: int = 6, - min_digits: int = 1, - max_digits: int = 4, - operators: list[str] = ("+", "-", "*", "/"), - allow_parentheses: bool = True, - allow_negation: bool = True, - seed: Optional[int] = None, - size: int = 500, - format_style: Literal["simple", "natural"] = "simple", - whitespace: Literal["no_space", "single", "random"] = "single", -) -> BasicArithmeticDataset: - """Create a BasicArithmeticDataset with the given configuration. - - Args: - min_terms: Minimum number of terms in expressions - max_terms: Maximum number of terms in expressions - min_digits: Minimum number of digits in numbers - max_digits: Maximum number of digits in numbers - operators: List of operators to use ("+", "-", "*") - allow_parentheses: Whether to allow parentheses in expressions - allow_negation: Whether to allow negative numbers - seed: Random seed for reproducibility - size: Virtual size of the dataset - format_style: Style of question formatting ("simple" or "natural") - - Returns: - BasicArithmeticDataset: Configured dataset instance - """ - config = BasicArithmeticDatasetConfig( - min_terms=min_terms, - max_terms=max_terms, - min_digits=min_digits, - max_digits=max_digits, - operators=operators, - allow_parentheses=allow_parentheses, - allow_negation=allow_negation, - seed=seed, - size=size, - format_style=format_style, - whitespace=whitespace, - ) - return BasicArithmeticDataset(config) +# Register the dataset +register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig) diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 7453b68f..18519c93 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -2,8 +2,7 @@ import random from dataclasses import dataclass from typing import Optional -from ..dataset import ProceduralDataset -from ..factory import register_dataset +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -109,40 +108,5 @@ class ChainSum(ProceduralDataset): return expression, result -def chain_sum_dataset( - min_terms: int = 2, - max_terms: int = 6, - min_digits: int = 1, - max_digits: int = 4, - allow_negation: bool = False, - seed: Optional[int] = None, - size: int = 500, -) -> ChainSum: - """Create a ChainSum dataset with the given configuration. - - Args: - min_terms: Minimum number of terms in expressions - max_terms: Maximum number of terms in expressions - min_digits: Minimum number of digits in numbers - max_digits: Maximum number of digits in numbers - allow_negation: Whether to allow negative numbers - seed: Random seed for reproducibility - size: Virtual size of the dataset - - Returns: - ChainSum: Configured dataset instance - """ - config = ChainSumConfig( - min_terms=min_terms, - max_terms=max_terms, - min_digits=min_digits, - max_digits=max_digits, - allow_negation=allow_negation, - seed=seed, - size=size, - ) - return ChainSum(config) - - # Register the dataset register_dataset("chain_sum", ChainSum, ChainSumConfig) diff --git a/reasoning_gym/arithmetic/fraction_simplification.py b/reasoning_gym/arithmetic/fraction_simplification.py index 7424af9f..cfefc422 100644 --- a/reasoning_gym/arithmetic/fraction_simplification.py +++ b/reasoning_gym/arithmetic/fraction_simplification.py @@ -5,7 +5,7 @@ from math import gcd from random import Random from typing import Optional, Sequence, Tuple -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -121,23 +121,4 @@ class FractionSimplificationDataset(ProceduralDataset): } -def fraction_simplification_dataset( - min_value: int = 1, - max_value: int = 100, - min_factor: int = 2, - max_factor: int = 10, - styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac"), - seed: Optional[int] = None, - size: int = 500, -) -> FractionSimplificationDataset: - """Create a FractionSimplificationDataset with the given configuration.""" - config = FractionSimplificationConfig( - min_value=min_value, - max_value=max_value, - min_factor=min_factor, - max_factor=max_factor, - styles=styles, - seed=seed, - size=size, - ) - return FractionSimplificationDataset(config) +register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig) diff --git a/reasoning_gym/arithmetic/gcd.py b/reasoning_gym/arithmetic/gcd.py index d24d86cf..ce30a127 100644 --- a/reasoning_gym/arithmetic/gcd.py +++ b/reasoning_gym/arithmetic/gcd.py @@ -6,7 +6,7 @@ from math import gcd from random import Random from typing import List, Optional, Tuple -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -63,21 +63,4 @@ class GCDDataset(ProceduralDataset): } -def gcd_dataset( - min_numbers: int = 2, - max_numbers: int = 2, - min_value: int = 1, - max_value: int = 10_000, - seed: Optional[int] = None, - size: int = 500, -) -> GCDDataset: - """Create a GCDDataset with the given configuration.""" - config = GCDConfig( - min_numbers=min_numbers, - max_numbers=max_numbers, - min_value=min_value, - max_value=max_value, - seed=seed, - size=size, - ) - return GCDDataset(config) +register_dataset("gcd", GCDDataset, GCDConfig) diff --git a/reasoning_gym/arithmetic/lcm.py b/reasoning_gym/arithmetic/lcm.py index a643f406..19402fd9 100644 --- a/reasoning_gym/arithmetic/lcm.py +++ b/reasoning_gym/arithmetic/lcm.py @@ -6,7 +6,7 @@ from math import lcm from random import Random from typing import List, Optional, Tuple -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -66,21 +66,4 @@ class LCMDataset(ProceduralDataset): } -def lcm_dataset( - min_numbers: int = 2, - max_numbers: int = 2, - min_value: int = 1, - max_value: int = 100, - seed: Optional[int] = None, - size: int = 500, -) -> LCMDataset: - """Create a LCMDataset with the given configuration.""" - config = LCMConfig( - min_numbers=min_numbers, - max_numbers=max_numbers, - min_value=min_value, - max_value=max_value, - seed=seed, - size=size, - ) - return LCMDataset(config) +register_dataset("lcm", LCMDataset, LCMConfig) diff --git a/reasoning_gym/arithmetic/leg_counting.py b/reasoning_gym/arithmetic/leg_counting.py index 54640190..de950631 100644 --- a/reasoning_gym/arithmetic/leg_counting.py +++ b/reasoning_gym/arithmetic/leg_counting.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from random import Random from typing import Dict, Optional -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset ANIMALS = { # Animals with 0 legs @@ -115,19 +115,4 @@ class LegCountingDataset(ProceduralDataset): } -def leg_counting_dataset( - min_animals: int = 2, - max_animals: int = 5, - max_instances: int = 3, - seed: Optional[int] = None, - size: int = 500, -) -> LegCountingDataset: - """Create a LegCountingDataset with the given configuration.""" - config = LegCountingConfig( - min_animals=min_animals, - max_animals=max_animals, - max_instances=max_instances, - seed=seed, - size=size, - ) - return LegCountingDataset(config) +register_dataset("leg_counting", LegCountingDataset, LegCountingConfig) diff --git a/reasoning_gym/arithmetic/prime_factorization.py b/reasoning_gym/arithmetic/prime_factorization.py index d3416ba0..c51f90ee 100644 --- a/reasoning_gym/arithmetic/prime_factorization.py +++ b/reasoning_gym/arithmetic/prime_factorization.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from random import Random from typing import List, Optional, Tuple -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -66,17 +66,4 @@ class PrimeFactorizationDataset(ProceduralDataset): } -def prime_factorization_dataset( - min_value: int = 2, - max_value: int = 1000, - seed: Optional[int] = None, - size: int = 500, -) -> PrimeFactorizationDataset: - """Create a PrimeFactorizationDataset with the given configuration.""" - config = PrimeFactorizationConfig( - min_value=min_value, - max_value=max_value, - seed=seed, - size=size, - ) - return PrimeFactorizationDataset(config) +register_dataset("prime_factorization", PrimeFactorizationDataset, PrimeFactorizationConfig) diff --git a/reasoning_gym/cognition/__init__.py b/reasoning_gym/cognition/__init__.py index 2d3b87a6..f5d43196 100644 --- a/reasoning_gym/cognition/__init__.py +++ b/reasoning_gym/cognition/__init__.py @@ -6,14 +6,12 @@ Cognition tasks for training reasoning capabilities: - Working memory """ -from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset, color_cube_rotation_dataset -from .number_sequences import NumberSequenceConfig, NumberSequenceDataset, number_sequence_dataset +from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset +from .number_sequences import NumberSequenceConfig, NumberSequenceDataset __all__ = [ "NumberSequenceConfig", "NumberSequenceDataset", - "number_sequence_dataset", "ColorCubeRotationConfig", "ColorCubeRotationDataset", - "color_cube_rotation_dataset", ] diff --git a/reasoning_gym/cognition/color_cube_rotation.py b/reasoning_gym/cognition/color_cube_rotation.py index 92357756..42069423 100644 --- a/reasoning_gym/cognition/color_cube_rotation.py +++ b/reasoning_gym/cognition/color_cube_rotation.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import StrEnum from typing import Dict, List, Optional, Tuple -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset class Color(StrEnum): @@ -189,17 +189,4 @@ class ColorCubeRotationDataset(ProceduralDataset): return "\n".join(story_parts) -def color_cube_rotation_dataset( - min_rotations: int = 1, - max_rotations: int = 3, - seed: Optional[int] = None, - size: int = 500, -) -> ColorCubeRotationDataset: - """Create a ColorCubeRotationDataset with the given configuration""" - config = ColorCubeRotationConfig( - min_rotations=min_rotations, - max_rotations=max_rotations, - seed=seed, - size=size, - ) - return ColorCubeRotationDataset(config) +register_dataset("color_cube_rotation", ColorCubeRotationDataset, ColorCubeRotationConfig) diff --git a/reasoning_gym/cognition/number_sequences.py b/reasoning_gym/cognition/number_sequences.py index b09a070d..bac6a18a 100644 --- a/reasoning_gym/cognition/number_sequences.py +++ b/reasoning_gym/cognition/number_sequences.py @@ -3,7 +3,7 @@ from enum import StrEnum from random import Random from typing import List, Optional -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset class Operation(StrEnum): @@ -198,23 +198,4 @@ class NumberSequenceDataset(ProceduralDataset): } -def number_sequence_dataset( - min_terms: int = 4, - max_terms: int = 8, - min_value: int = -100, - max_value: int = 100, - max_complexity: int = 3, - seed: Optional[int] = None, - size: int = 500, -) -> NumberSequenceDataset: - """Create a NumberSequenceDataset with the given configuration.""" - config = NumberSequenceConfig( - min_terms=min_terms, - max_terms=max_terms, - min_value=min_value, - max_value=max_value, - max_complexity=max_complexity, - seed=seed, - size=size, - ) - return NumberSequenceDataset(config) +register_dataset("number_sequence", NumberSequenceDataset, NumberSequenceConfig) diff --git a/reasoning_gym/factory.py b/reasoning_gym/factory.py index b482d1b1..274d644d 100644 --- a/reasoning_gym/factory.py +++ b/reasoning_gym/factory.py @@ -1,75 +1,58 @@ from dataclasses import is_dataclass -from typing import Any, Dict, Type, TypeVar +from typing import Dict, Type, TypeVar from .dataset import ProceduralDataset # Type variables for generic type hints -ConfigT = TypeVar('ConfigT') -DatasetT = TypeVar('DatasetT', bound=ProceduralDataset) +ConfigT = TypeVar("ConfigT") +DatasetT = TypeVar("DatasetT", bound=ProceduralDataset) # Global registry of datasets _DATASETS: Dict[str, tuple[Type[ProceduralDataset], Type]] = {} -def register_dataset( - name: str, - dataset_cls: Type[DatasetT], - config_cls: Type[ConfigT] -) -> None: + +def register_dataset(name: str, dataset_cls: Type[DatasetT], config_cls: Type[ConfigT]) -> None: """ Register a dataset class with its configuration class. - + Args: name: Unique identifier for the dataset dataset_cls: Class derived from ProceduralDataset config_cls: Configuration dataclass for the dataset - + Raises: ValueError: If name is already registered or invalid types provided """ if name in _DATASETS: raise ValueError(f"Dataset '{name}' is already registered") - + if not issubclass(dataset_cls, ProceduralDataset): - raise ValueError( - f"Dataset class must inherit from ProceduralDataset, got {dataset_cls}" - ) - + raise ValueError(f"Dataset class must inherit from ProceduralDataset, got {dataset_cls}") + if not is_dataclass(config_cls): - raise ValueError( - f"Config class must be a dataclass, got {config_cls}" - ) - + raise ValueError(f"Config class must be a dataclass, got {config_cls}") + _DATASETS[name] = (dataset_cls, config_cls) -def create_dataset( - name: str, - config: Any, - seed: int = None, - size: int = 500 -) -> ProceduralDataset: + +def create_dataset(name: str, **kwargs) -> ProceduralDataset: """ Create a dataset instance by name with the given configuration. - + Args: name: Registered dataset name - config: Configuration instance for the dataset - seed: Optional random seed - size: Size of the dataset (default: 500) - + Returns: Configured dataset instance - + Raises: ValueError: If dataset not found or config type mismatch """ if name not in _DATASETS: - raise ValueError(f"Dataset '{name}' not found") - + raise ValueError(f"Dataset '{name}' not registered") + dataset_cls, config_cls = _DATASETS[name] - - if not isinstance(config, config_cls): - raise ValueError( - f"Config must be instance of {config_cls.__name__}, got {type(config).__name__}" - ) - - return dataset_cls(config=config, seed=seed, size=size) + + conifg = config_cls(**kwargs) + + return dataset_cls(config=conifg) diff --git a/reasoning_gym/games/__init__.py b/reasoning_gym/games/__init__.py index 507fc7d9..44a92ec6 100644 --- a/reasoning_gym/games/__init__.py +++ b/reasoning_gym/games/__init__.py @@ -5,18 +5,15 @@ Game tasks for training reasoning capabilities: - Strategy games """ -from .maze import MazeConfig, MazeDataset, maze_dataset -from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset, mini_sudoku_dataset -from .sudoku import SudokuConfig, SudokuDataset, sudoku_dataset +from .maze import MazeConfig, MazeDataset +from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset +from .sudoku import SudokuConfig, SudokuDataset __all__ = [ "MiniSudokuConfig", "MiniSudokuDataset", - "mini_sudoku_dataset", "SudokuConfig", "SudokuDataset", - "sudoku_dataset", "MazeConfig", "MazeDataset", - "maze_dataset", ] diff --git a/reasoning_gym/games/maze.py b/reasoning_gym/games/maze.py index 18a40fb8..2c8cd9bd 100644 --- a/reasoning_gym/games/maze.py +++ b/reasoning_gym/games/maze.py @@ -3,7 +3,7 @@ import string from dataclasses import dataclass from typing import Optional -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -47,7 +47,6 @@ class MazeDataset(ProceduralDataset): num_retries=1000, ): super().__init__(config=config, seed=config.seed, size=config.size) - self.config = config # Probability that a cell is a path instead of a wall self.prob_path = prob_path # Number of times to resample a grid to find a suitable maze before giving up @@ -184,21 +183,4 @@ class MazeDataset(ProceduralDataset): return "\n".join("".join(row) for row in grid) -def maze_dataset( - min_dist: int = 5, - max_dist: int = 10, - min_grid_size: int = 5, - max_grid_size: int = 10, - seed: Optional[int] = None, - size: int = 50, -) -> MazeDataset: - """Convenient function to create a MazeDataset.""" - config = MazeConfig( - min_dist=min_dist, - max_dist=max_dist, - min_grid_size=min_grid_size, - max_grid_size=max_grid_size, - seed=seed, - size=size, - ) - return MazeDataset(config) +register_dataset("maze", MazeDataset, MazeConfig) diff --git a/reasoning_gym/games/mini_sudoku.py b/reasoning_gym/games/mini_sudoku.py index a08c8123..23ddf959 100644 --- a/reasoning_gym/games/mini_sudoku.py +++ b/reasoning_gym/games/mini_sudoku.py @@ -1,9 +1,10 @@ """Mini Sudoku (4x4) puzzle generator""" -import random from dataclasses import dataclass from random import Random -from typing import List, Optional, Set, Tuple +from typing import List, Optional, Tuple + +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -21,13 +22,11 @@ class MiniSudokuConfig: assert self.min_empty <= self.max_empty <= 16, "max_empty must be between min_empty and 16" -class MiniSudokuDataset: +class MiniSudokuDataset(ProceduralDataset): """Generates 4x4 sudoku puzzles with configurable difficulty""" def __init__(self, config: MiniSudokuConfig): - self.config = config - self.config.validate() - self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32) + super().__init__(config=config, seed=config.seed, size=config.size) def __len__(self) -> int: return self.config.size @@ -149,17 +148,4 @@ class MiniSudokuDataset: } -def mini_sudoku_dataset( - min_empty: int = 8, - max_empty: int = 12, - seed: Optional[int] = None, - size: int = 500, -) -> MiniSudokuDataset: - """Create a MiniSudokuDataset with the given configuration.""" - config = MiniSudokuConfig( - min_empty=min_empty, - max_empty=max_empty, - seed=seed, - size=size, - ) - return MiniSudokuDataset(config) +register_dataset("mini_sudoku", MiniSudokuDataset, MiniSudokuConfig) diff --git a/reasoning_gym/games/sudoku.py b/reasoning_gym/games/sudoku.py index a47b7fbf..9268546c 100644 --- a/reasoning_gym/games/sudoku.py +++ b/reasoning_gym/games/sudoku.py @@ -1,9 +1,10 @@ """Sudoku puzzle generator""" -import random from dataclasses import dataclass from random import Random -from typing import List, Optional, Set, Tuple +from typing import List, Optional, Tuple + +from ..factory import ProceduralDataset, register_dataset @dataclass @@ -21,13 +22,11 @@ class SudokuConfig: assert self.min_empty <= self.max_empty <= 81, "max_empty must be between min_empty and 81" -class SudokuDataset: +class SudokuDataset(ProceduralDataset): """Generates sudoku puzzles with configurable difficulty""" def __init__(self, config: SudokuConfig): - self.config = config - self.config.validate() - self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32) + super().__init__(config=config, seed=config.seed, size=config.size) def __len__(self) -> int: return self.config.size @@ -139,17 +138,4 @@ class SudokuDataset: } -def sudoku_dataset( - min_empty: int = 30, - max_empty: int = 50, - seed: Optional[int] = None, - size: int = 500, -) -> SudokuDataset: - """Create a SudokuDataset with the given configuration.""" - config = SudokuConfig( - min_empty=min_empty, - max_empty=max_empty, - seed=seed, - size=size, - ) - return SudokuDataset(config) +register_dataset("sudoku", SudokuDataset, SudokuConfig) diff --git a/reasoning_gym/graphs/__init__.py b/reasoning_gym/graphs/__init__.py index 399e0101..8ede1fe9 100644 --- a/reasoning_gym/graphs/__init__.py +++ b/reasoning_gym/graphs/__init__.py @@ -1,3 +1,3 @@ -from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset, family_relationships_dataset +from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset -__all__ = ["FamilyRelationshipsDataset", "FamilyRelationshipsConfig", "family_relationships_dataset"] +__all__ = ["FamilyRelationshipsDataset", "FamilyRelationshipsConfig"] diff --git a/reasoning_gym/graphs/family_relationships.py b/reasoning_gym/graphs/family_relationships.py index 61b96b9f..e2c10911 100644 --- a/reasoning_gym/graphs/family_relationships.py +++ b/reasoning_gym/graphs/family_relationships.py @@ -4,7 +4,7 @@ from enum import StrEnum from itertools import count from typing import Dict, List, Optional, Set, Tuple -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset class Gender(StrEnum): @@ -310,21 +310,4 @@ class FamilyRelationshipsDataset(ProceduralDataset): return " ".join(story_parts) -def family_relationships_dataset( - min_family_size: int = 4, - max_family_size: int = 8, - male_names: List[str] = None, - female_names: List[str] = None, - seed: Optional[int] = None, - size: int = 500, -) -> FamilyRelationshipsDataset: - """Create a FamilyRelationshipsDataset with the given configuration""" - config = FamilyRelationshipsConfig( - min_family_size=min_family_size, - max_family_size=max_family_size, - male_names=male_names, - female_names=female_names, - seed=seed, - size=size, - ) - return FamilyRelationshipsDataset(config) +register_dataset("family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig) diff --git a/reasoning_gym/logic/__init__.py b/reasoning_gym/logic/__init__.py index 9ad19305..c2c07625 100644 --- a/reasoning_gym/logic/__init__.py +++ b/reasoning_gym/logic/__init__.py @@ -6,13 +6,12 @@ Logic tasks for training reasoning capabilities: - Syllogisms """ -from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset, propositional_logic_dataset -from .syllogisms import SyllogismConfig, SyllogismDataset, Term, syllogism_dataset +from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset +from .syllogisms import SyllogismConfig, SyllogismDataset, Term __all__ = [ "PropositionalLogicConfig", "PropositionalLogicDataset", - "propositional_logic_dataset", "SyllogismConfig", "SyllogismDataset", "syllogism_dataset", diff --git a/reasoning_gym/logic/propositional_logic.py b/reasoning_gym/logic/propositional_logic.py index 565527ad..395c919f 100644 --- a/reasoning_gym/logic/propositional_logic.py +++ b/reasoning_gym/logic/propositional_logic.py @@ -5,6 +5,8 @@ from enum import StrEnum from random import Random from typing import Any, List, Optional, Set +from ..factory import ProceduralDataset, register_dataset + class Operator(StrEnum): """Basic logical operators""" @@ -70,13 +72,11 @@ class Expression: return f"({self.left} {self.operator.value} {self.right})" -class PropositionalLogicDataset: +class PropositionalLogicDataset(ProceduralDataset): """Generates propositional logic reasoning tasks""" def __init__(self, config: PropositionalLogicConfig): - self.config = config - self.config.validate() - self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32) + super().__init__(config=config, seed=config.seed, size=config.size) def __len__(self) -> int: return self.config.size @@ -199,23 +199,4 @@ class PropositionalLogicDataset: return 1 + self._measure_complexity(expression.left) + self._measure_complexity(expression.right) -def propositional_logic_dataset( - min_vars: int = 2, - max_vars: int = 4, - min_statements: int = 2, - max_statements: int = 4, - max_complexity: int = 3, - seed: Optional[int] = None, - size: int = 500, -) -> PropositionalLogicDataset: - """Create a PropositionalLogicDataset with the given configuration.""" - config = PropositionalLogicConfig( - min_vars=min_vars, - max_vars=max_vars, - min_statements=min_statements, - max_statements=max_statements, - max_complexity=max_complexity, - seed=seed, - size=size, - ) - return PropositionalLogicDataset(config) +register_dataset("propositional_logic", PropositionalLogicDataset, PropositionalLogicConfig) diff --git a/reasoning_gym/logic/syllogisms.py b/reasoning_gym/logic/syllogisms.py index ad25feac..4e205189 100644 --- a/reasoning_gym/logic/syllogisms.py +++ b/reasoning_gym/logic/syllogisms.py @@ -5,7 +5,7 @@ from enum import Enum from random import Random from typing import List, Optional, Tuple -from ..dataset import ProceduralDataset +from ..factory import ProceduralDataset, register_dataset class Quantifier(Enum): @@ -256,27 +256,4 @@ class SyllogismDataset(ProceduralDataset): return self._generate_syllogism(rng) -def syllogism_dataset( - terms: List[Term] = None, - allow_all: bool = True, - allow_no: bool = True, - allow_some: bool = True, - allow_some_not: bool = True, - include_invalid: bool = True, - invalid_ratio: float = 0.3, - seed: Optional[int] = None, - size: int = 500, -) -> SyllogismDataset: - """Create a SyllogismDataset with the given configuration.""" - config = SyllogismConfig( - terms=terms, - allow_all=allow_all, - allow_no=allow_no, - allow_some=allow_some, - allow_some_not=allow_some_not, - include_invalid=include_invalid, - invalid_ratio=invalid_ratio, - seed=seed, - size=size, - ) - return SyllogismDataset(config) +register_dataset("syllogism", SyllogismDataset, SyllogismConfig) diff --git a/tests/test_color_cube_rotation.py b/tests/test_color_cube_rotation.py index b0d780d7..a554afdd 100644 --- a/tests/test_color_cube_rotation.py +++ b/tests/test_color_cube_rotation.py @@ -1,10 +1,12 @@ import pytest -from reasoning_gym.cognition.color_cube_rotation import Color, Cube, Side, color_cube_rotation_dataset +from reasoning_gym import create_dataset +from reasoning_gym.cognition.color_cube_rotation import Color, ColorCubeRotationDataset, Cube, Side def test_color_cube_rotation_generation(): - dataset = color_cube_rotation_dataset(seed=42, size=10) + dataset = create_dataset("color_cube_rotation", seed=42, size=10) + assert isinstance(dataset, ColorCubeRotationDataset) for item in dataset: # Check required keys exist @@ -33,15 +35,15 @@ def test_color_cube_rotation_generation(): def test_color_cube_rotation_config(): # Test invalid config raises assertion with pytest.raises(AssertionError): - dataset = color_cube_rotation_dataset(min_rotations=0) + dataset = create_dataset("color_cube_rotation", min_rotations=0) with pytest.raises(AssertionError): - dataset = color_cube_rotation_dataset(max_rotations=1, min_rotations=2) + dataset = create_dataset("color_cube_rotation", max_rotations=1, min_rotations=2) def test_deterministic_generation(): - dataset1 = color_cube_rotation_dataset(seed=42, size=5) - dataset2 = color_cube_rotation_dataset(seed=42, size=5) + dataset1 = create_dataset("color_cube_rotation", seed=42, size=5) + dataset2 = create_dataset("color_cube_rotation", seed=42, size=5) for i in range(5): assert dataset1[i]["question"] == dataset2[i]["question"] diff --git a/tests/test_family_relationships.py b/tests/test_family_relationships.py index 80d2d9f0..98e87216 100644 --- a/tests/test_family_relationships.py +++ b/tests/test_family_relationships.py @@ -1,10 +1,12 @@ import pytest -from reasoning_gym.graphs.family_relationships import Gender, Relationship, family_relationships_dataset +from reasoning_gym import create_dataset +from reasoning_gym.graphs.family_relationships import FamilyRelationshipsDataset, Relationship def test_family_relationships_generation(): - dataset = family_relationships_dataset(seed=42, size=10) + dataset = create_dataset("family_relationships", seed=42, size=10) + assert isinstance(dataset, FamilyRelationshipsDataset) for item in dataset: # Check required keys exist @@ -32,21 +34,21 @@ def test_family_relationships_generation(): def test_family_relationships_config(): # Test invalid config raises assertion with pytest.raises(AssertionError): - dataset = family_relationships_dataset(min_family_size=2) + dataset = create_dataset("family_relationships", min_family_size=2) with pytest.raises(AssertionError): - dataset = family_relationships_dataset(max_family_size=3, min_family_size=4) + dataset = create_dataset("family_relationships", max_family_size=3, min_family_size=4) with pytest.raises(AssertionError): - dataset = family_relationships_dataset(male_names=[]) + dataset = create_dataset("family_relationships", male_names=[]) with pytest.raises(AssertionError): - dataset = family_relationships_dataset(female_names=[]) + dataset = create_dataset("family_relationships", female_names=[]) def test_deterministic_generation(): - dataset1 = family_relationships_dataset(seed=42, size=5) - dataset2 = family_relationships_dataset(seed=42, size=5) + dataset1 = create_dataset("family_relationships", seed=42, size=5) + dataset2 = create_dataset("family_relationships", seed=42, size=5) for i in range(5): assert dataset1[i]["question"] == dataset2[i]["question"] @@ -54,7 +56,7 @@ def test_deterministic_generation(): def test_relationship_consistency(): - dataset = family_relationships_dataset(seed=42, size=10) + dataset = create_dataset("family_relationships", seed=42, size=10) for item in dataset: # Check that relationship matches the gender diff --git a/tests/test_maze.py b/tests/test_maze.py index 0cfda1cc..fda8ed14 100644 --- a/tests/test_maze.py +++ b/tests/test_maze.py @@ -1,6 +1,7 @@ import pytest -from reasoning_gym.games.maze import MazeConfig, MazeDataset, maze_dataset +from reasoning_gym import create_dataset +from reasoning_gym.games.maze import MazeConfig, MazeDataset def test_maze_config_validation(): @@ -38,7 +39,8 @@ def test_maze_dataset_creation(): def test_maze_dataset_items(): - ds = maze_dataset( + ds = create_dataset( + "maze", min_dist=3, max_dist=5, min_grid_size=5, @@ -62,7 +64,8 @@ def test_maze_shortest_path_correctness(): """ min_dist = 4 max_dist = 8 - ds = maze_dataset( + ds = create_dataset( + "maze", min_dist=min_dist, max_dist=max_dist, min_grid_size=5, diff --git a/tests/test_polynomial_equations.py b/tests/test_polynomial_equations.py index 16186092..6e1bb0c0 100644 --- a/tests/test_polynomial_equations.py +++ b/tests/test_polynomial_equations.py @@ -1,11 +1,8 @@ import pytest from sympy import Symbol, sympify -from reasoning_gym.algebra.polynomial_equations import ( - PolynomialEquationsConfig, - PolynomialEquationsDataset, - polynomial_equations_dataset, -) +from reasoning_gym import create_dataset +from reasoning_gym.algebra.polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset def test_polynomial_config_validation(): @@ -47,7 +44,8 @@ def test_polynomial_equations_dataset_basic(): def test_polynomial_equations_dataset_items(): """Test that generated items have correct structure""" - ds = polynomial_equations_dataset( + ds = create_dataset( + "polynomial_equations", min_terms=2, max_terms=3, min_value=1, @@ -87,7 +85,8 @@ def test_polynomial_equations_dataset_deterministic(): def test_polynomial_solutions_evaluation(): """Test that real_solutions satisfy the polynomial equation.""" - ds = polynomial_equations_dataset( + ds = create_dataset( + "polynomial_equations", min_terms=2, max_terms=4, min_value=1,