add reasoning_gym.create_dataset({name}, ...) global factory function

This commit is contained in:
Andreas Koepf 2025-01-25 00:58:34 +01:00
parent 0d2d8ba6a0
commit 519e411fa5
35 changed files with 133 additions and 598 deletions

View file

@ -3,6 +3,7 @@ Reasoning Gym - A library of procedural dataset generators for training reasonin
""" """
from . import algebra, algorithmic, arithmetic, cognition, data, games, graphs, logic from . import algebra, algorithmic, arithmetic, cognition, data, games, graphs, logic
from .factory import create_dataset, register_dataset
__version__ = "0.1.1" __version__ = "0.1.1"
__all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "graphs", "logic"] __all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "graphs", "logic"]

View file

@ -1,11 +1,9 @@
from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset, polynomial_equations_dataset from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset
from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset, simple_equations_dataset from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset
__all__ = [ __all__ = [
"SimpleEquationsDataset", "SimpleEquationsDataset",
"SimpleEquationsConfig", "SimpleEquationsConfig",
"simple_equations_dataset",
"PolynomialEquationsConfig", "PolynomialEquationsConfig",
"PolynomialEquationsDataset", "PolynomialEquationsDataset",
"polynomial_equations_dataset",
] ]

View file

@ -5,7 +5,7 @@ from typing import Optional, Tuple
from sympy import Eq, Symbol, expand, solve from sympy import Eq, Symbol, expand, solve
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
@dataclass @dataclass
@ -147,31 +147,4 @@ class PolynomialEquationsDataset(ProceduralDataset):
return polynomial_expr return polynomial_expr
def polynomial_equations_dataset( register_dataset("polynomial_equations", PolynomialEquationsDataset, PolynomialEquationsConfig)
min_terms: int = 2,
max_terms: int = 4,
min_value: int = 1,
max_value: int = 100,
min_degree: int = 1,
max_degree: int = 3,
operators: Tuple[str, ...] = ("+", "-"),
seed: Optional[int] = None,
size: int = 500,
) -> PolynomialEquationsDataset:
"""
Factory function for creating a PolynomialEquationsDataset.
Example usage:
dataset = polynomial_equations_dataset(min_degree=2, max_degree=3, ...)
"""
config = PolynomialEquationsConfig(
min_terms=min_terms,
max_terms=max_terms,
min_value=min_value,
max_value=max_value,
min_degree=min_degree,
max_degree=max_degree,
operators=operators,
seed=seed,
size=size,
)
return PolynomialEquationsDataset(config)

View file

@ -6,7 +6,7 @@ from typing import Optional, Tuple
import sympy import sympy
from sympy import Eq, Symbol, solve from sympy import Eq, Symbol, solve
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
@dataclass @dataclass
@ -116,23 +116,4 @@ class SimpleEquationsDataset(ProceduralDataset):
return f"{left_side} = {right_side}", solution_value return f"{left_side} = {right_side}", solution_value
def simple_equations_dataset( register_dataset("simple_equations", SimpleEquationsDataset, SimpleEquationsConfig)
min_terms: int = 2,
max_terms: int = 5,
min_value: int = 1,
max_value: int = 100,
operators: tuple = ("+", "-", "*"),
seed: Optional[int] = None,
size: int = 500,
) -> SimpleEquationsDataset:
"""Create a SimpleEquationsDataset with the given configuration"""
config = SimpleEquationsConfig(
min_terms=min_terms,
max_terms=max_terms,
min_value=min_value,
max_value=max_value,
operators=operators,
seed=seed,
size=size,
)
return SimpleEquationsDataset(config)

View file

@ -6,31 +6,21 @@ Algorithmic tasks for training reasoning capabilities:
- Pattern matching - Pattern matching
""" """
from reasoning_gym.arithmetic.basic_arithmetic import basic_arithmetic_dataset from .base_conversion import BaseConversionConfig, BaseConversionDataset
from reasoning_gym.arithmetic.chain_sum import chain_sum_dataset from .letter_counting import LetterCountingConfig, LetterCountingDataset
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
from .base_conversion import BaseConversionConfig, BaseConversionDataset, base_conversion_dataset from .number_sorting import NumberSortingConfig, NumberSortingDataset
from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset from .word_reversal import WordReversalConfig, WordReversalDataset
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset, number_filtering_dataset
from .number_sorting import NumberSortingConfig, NumberSortingDataset, number_sorting_dataset
from .word_reversal import WordReversalConfig, WordReversalDataset, word_reversal_dataset
__all__ = [ __all__ = [
"basic_arithmetic_dataset",
"BaseConversionConfig", "BaseConversionConfig",
"BaseConversionDataset", "BaseConversionDataset",
"base_conversion_dataset",
"chain_sum_dataset",
"LetterCountingConfig", "LetterCountingConfig",
"LetterCountingDataset", "LetterCountingDataset",
"letter_counting_dataset",
"NumberFilteringConfig", "NumberFilteringConfig",
"NumberFilteringDataset", "NumberFilteringDataset",
"number_filtering_dataset",
"NumberSortingConfig", "NumberSortingConfig",
"NumberSortingDataset", "NumberSortingDataset",
"number_sorting_dataset",
"WordReversalConfig", "WordReversalConfig",
"WordReversalDataset", "WordReversalDataset",
"word_reversal_dataset",
] ]

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Optional, Tuple from typing import Optional, Tuple
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
@dataclass @dataclass
@ -88,21 +88,4 @@ class BaseConversionDataset(ProceduralDataset):
} }
def base_conversion_dataset( register_dataset("base_conversion", BaseConversionDataset, BaseConversionConfig)
min_base: int = 2,
max_base: int = 16,
min_value: int = 0,
max_value: int = 1000,
seed: Optional[int] = None,
size: int = 500,
) -> BaseConversionDataset:
"""Create a BaseConversionDataset with the given configuration."""
config = BaseConversionConfig(
min_base=min_base,
max_base=max_base,
min_value=min_value,
max_value=max_value,
seed=seed,
size=size,
)
return BaseConversionDataset(config)

View file

@ -7,7 +7,7 @@ from typing import List, Optional
from reasoning_gym.data import read_data_file from reasoning_gym.data import read_data_file
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
@dataclass @dataclass
@ -63,17 +63,4 @@ class LetterCountingDataset(ProceduralDataset):
} }
def letter_counting_dataset( register_dataset("letter_counting", LetterCountingDataset, LetterCountingConfig)
min_words: int = 5,
max_words: int = 15,
seed: Optional[int] = None,
size: int = 500,
) -> LetterCountingDataset:
"""Create a LetterCountingDataset with the given configuration."""
config = LetterCountingConfig(
min_words=min_words,
max_words=max_words,
seed=seed,
size=size,
)
return LetterCountingDataset(config)

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
@dataclass @dataclass
@ -98,25 +98,4 @@ class NumberFilteringDataset(ProceduralDataset):
} }
def number_filtering_dataset( register_dataset("number_filtering", NumberFilteringDataset, NumberFilteringConfig)
min_numbers: int = 3,
max_numbers: int = 10,
min_decimals: int = 0,
max_decimals: int = 4,
min_value: float = -100.0,
max_value: float = 100.0,
seed: Optional[int] = None,
size: int = 500,
) -> NumberFilteringDataset:
"""Create a NumberFilteringDataset with the given configuration."""
config = NumberFilteringConfig(
min_numbers=min_numbers,
max_numbers=max_numbers,
min_decimals=min_decimals,
max_decimals=max_decimals,
min_value=min_value,
max_value=max_value,
seed=seed,
size=size,
)
return NumberFilteringDataset(config)

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
@dataclass @dataclass
@ -86,25 +86,4 @@ class NumberSortingDataset(ProceduralDataset):
} }
def number_sorting_dataset( register_dataset("number_sorting", NumberSortingDataset, NumberSortingConfig)
min_numbers: int = 3,
max_numbers: int = 10,
min_decimals: int = 0,
max_decimals: int = 2,
min_value: float = -100.0,
max_value: float = 100.0,
seed: Optional[int] = None,
size: int = 500,
) -> NumberSortingDataset:
"""Create a NumberSortingDataset with the given configuration."""
config = NumberSortingConfig(
min_numbers=min_numbers,
max_numbers=max_numbers,
min_decimals=min_decimals,
max_decimals=max_decimals,
min_value=min_value,
max_value=max_value,
seed=seed,
size=size,
)
return NumberSortingDataset(config)

View file

@ -6,7 +6,7 @@ from random import Random
from typing import List, Optional from typing import List, Optional
from ..data import read_data_file from ..data import read_data_file
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
@dataclass @dataclass
@ -55,17 +55,4 @@ class WordReversalDataset(ProceduralDataset):
} }
def word_reversal_dataset( register_dataset("word_reversal", WordReversalDataset, WordReversalConfig)
min_words: int = 3,
max_words: int = 8,
seed: Optional[int] = None,
size: int = 500,
) -> WordReversalDataset:
"""Create a WordReversalDataset with the given configuration."""
config = WordReversalConfig(
min_words=min_words,
max_words=max_words,
seed=seed,
size=size,
)
return WordReversalDataset(config)

View file

@ -6,17 +6,13 @@ Arithmetic tasks for training reasoning capabilities:
- Leg counting - Leg counting
""" """
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset from .chain_sum import ChainSum, ChainSumConfig
from .fraction_simplification import ( from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
FractionSimplificationConfig, from .gcd import GCDConfig, GCDDataset
FractionSimplificationDataset, from .lcm import LCMConfig, LCMDataset
fraction_simplification_dataset, from .leg_counting import LegCountingConfig, LegCountingDataset
) from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
from .gcd import GCDConfig, GCDDataset, gcd_dataset
from .lcm import LCMConfig, LCMDataset, lcm_dataset
from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset, prime_factorization_dataset
__all__ = [ __all__ = [
"BasicArithmeticDataset", "BasicArithmeticDataset",
@ -24,20 +20,14 @@ __all__ = [
"basic_arithmetic_dataset", "basic_arithmetic_dataset",
"ChainSum", "ChainSum",
"ChainSumConfig", "ChainSumConfig",
"chain_sum_dataset",
"FractionSimplificationConfig", "FractionSimplificationConfig",
"FractionSimplificationDataset", "FractionSimplificationDataset",
"fraction_simplification_dataset",
"GCDConfig", "GCDConfig",
"GCDDataset", "GCDDataset",
"gcd_dataset",
"LCMConfig", "LCMConfig",
"LCMDataset", "LCMDataset",
"lcm_dataset",
"LegCountingConfig", "LegCountingConfig",
"LegCountingDataset", "LegCountingDataset",
"leg_counting_dataset",
"PrimeFactorizationConfig", "PrimeFactorizationConfig",
"PrimeFactorizationDataset", "PrimeFactorizationDataset",
"prime_factorization_dataset",
] ]

View file

@ -2,7 +2,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
@dataclass @dataclass
@ -231,47 +231,5 @@ class BasicArithmeticDataset(ProceduralDataset):
return rng.choice(templates).format(expression) return rng.choice(templates).format(expression)
def basic_arithmetic_dataset( # Register the dataset
min_terms: int = 2, register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig)
max_terms: int = 6,
min_digits: int = 1,
max_digits: int = 4,
operators: list[str] = ("+", "-", "*", "/"),
allow_parentheses: bool = True,
allow_negation: bool = True,
seed: Optional[int] = None,
size: int = 500,
format_style: Literal["simple", "natural"] = "simple",
whitespace: Literal["no_space", "single", "random"] = "single",
) -> BasicArithmeticDataset:
"""Create a BasicArithmeticDataset with the given configuration.
Args:
min_terms: Minimum number of terms in expressions
max_terms: Maximum number of terms in expressions
min_digits: Minimum number of digits in numbers
max_digits: Maximum number of digits in numbers
operators: List of operators to use ("+", "-", "*")
allow_parentheses: Whether to allow parentheses in expressions
allow_negation: Whether to allow negative numbers
seed: Random seed for reproducibility
size: Virtual size of the dataset
format_style: Style of question formatting ("simple" or "natural")
Returns:
BasicArithmeticDataset: Configured dataset instance
"""
config = BasicArithmeticDatasetConfig(
min_terms=min_terms,
max_terms=max_terms,
min_digits=min_digits,
max_digits=max_digits,
operators=operators,
allow_parentheses=allow_parentheses,
allow_negation=allow_negation,
seed=seed,
size=size,
format_style=format_style,
whitespace=whitespace,
)
return BasicArithmeticDataset(config)

View file

@ -2,8 +2,7 @@ import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
from ..factory import register_dataset
@dataclass @dataclass
@ -109,40 +108,5 @@ class ChainSum(ProceduralDataset):
return expression, result return expression, result
def chain_sum_dataset(
min_terms: int = 2,
max_terms: int = 6,
min_digits: int = 1,
max_digits: int = 4,
allow_negation: bool = False,
seed: Optional[int] = None,
size: int = 500,
) -> ChainSum:
"""Create a ChainSum dataset with the given configuration.
Args:
min_terms: Minimum number of terms in expressions
max_terms: Maximum number of terms in expressions
min_digits: Minimum number of digits in numbers
max_digits: Maximum number of digits in numbers
allow_negation: Whether to allow negative numbers
seed: Random seed for reproducibility
size: Virtual size of the dataset
Returns:
ChainSum: Configured dataset instance
"""
config = ChainSumConfig(
min_terms=min_terms,
max_terms=max_terms,
min_digits=min_digits,
max_digits=max_digits,
allow_negation=allow_negation,
seed=seed,
size=size,
)
return ChainSum(config)
# Register the dataset # Register the dataset
register_dataset("chain_sum", ChainSum, ChainSumConfig) register_dataset("chain_sum", ChainSum, ChainSumConfig)

View file

@ -5,7 +5,7 @@ from math import gcd
from random import Random from random import Random
from typing import Optional, Sequence, Tuple from typing import Optional, Sequence, Tuple
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
@dataclass @dataclass
@ -121,23 +121,4 @@ class FractionSimplificationDataset(ProceduralDataset):
} }
def fraction_simplification_dataset( register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)
min_value: int = 1,
max_value: int = 100,
min_factor: int = 2,
max_factor: int = 10,
styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac"),
seed: Optional[int] = None,
size: int = 500,
) -> FractionSimplificationDataset:
"""Create a FractionSimplificationDataset with the given configuration."""
config = FractionSimplificationConfig(
min_value=min_value,
max_value=max_value,
min_factor=min_factor,
max_factor=max_factor,
styles=styles,
seed=seed,
size=size,
)
return FractionSimplificationDataset(config)

View file

@ -6,7 +6,7 @@ from math import gcd
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
@dataclass @dataclass
@ -63,21 +63,4 @@ class GCDDataset(ProceduralDataset):
} }
def gcd_dataset( register_dataset("gcd", GCDDataset, GCDConfig)
min_numbers: int = 2,
max_numbers: int = 2,
min_value: int = 1,
max_value: int = 10_000,
seed: Optional[int] = None,
size: int = 500,
) -> GCDDataset:
"""Create a GCDDataset with the given configuration."""
config = GCDConfig(
min_numbers=min_numbers,
max_numbers=max_numbers,
min_value=min_value,
max_value=max_value,
seed=seed,
size=size,
)
return GCDDataset(config)

View file

@ -6,7 +6,7 @@ from math import lcm
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
@dataclass @dataclass
@ -66,21 +66,4 @@ class LCMDataset(ProceduralDataset):
} }
def lcm_dataset( register_dataset("lcm", LCMDataset, LCMConfig)
min_numbers: int = 2,
max_numbers: int = 2,
min_value: int = 1,
max_value: int = 100,
seed: Optional[int] = None,
size: int = 500,
) -> LCMDataset:
"""Create a LCMDataset with the given configuration."""
config = LCMConfig(
min_numbers=min_numbers,
max_numbers=max_numbers,
min_value=min_value,
max_value=max_value,
seed=seed,
size=size,
)
return LCMDataset(config)

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Dict, Optional from typing import Dict, Optional
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
ANIMALS = { ANIMALS = {
# Animals with 0 legs # Animals with 0 legs
@ -115,19 +115,4 @@ class LegCountingDataset(ProceduralDataset):
} }
def leg_counting_dataset( register_dataset("leg_counting", LegCountingDataset, LegCountingConfig)
min_animals: int = 2,
max_animals: int = 5,
max_instances: int = 3,
seed: Optional[int] = None,
size: int = 500,
) -> LegCountingDataset:
"""Create a LegCountingDataset with the given configuration."""
config = LegCountingConfig(
min_animals=min_animals,
max_animals=max_animals,
max_instances=max_instances,
seed=seed,
size=size,
)
return LegCountingDataset(config)

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
@dataclass @dataclass
@ -66,17 +66,4 @@ class PrimeFactorizationDataset(ProceduralDataset):
} }
def prime_factorization_dataset( register_dataset("prime_factorization", PrimeFactorizationDataset, PrimeFactorizationConfig)
min_value: int = 2,
max_value: int = 1000,
seed: Optional[int] = None,
size: int = 500,
) -> PrimeFactorizationDataset:
"""Create a PrimeFactorizationDataset with the given configuration."""
config = PrimeFactorizationConfig(
min_value=min_value,
max_value=max_value,
seed=seed,
size=size,
)
return PrimeFactorizationDataset(config)

View file

@ -6,14 +6,12 @@ Cognition tasks for training reasoning capabilities:
- Working memory - Working memory
""" """
from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset, color_cube_rotation_dataset from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset
from .number_sequences import NumberSequenceConfig, NumberSequenceDataset, number_sequence_dataset from .number_sequences import NumberSequenceConfig, NumberSequenceDataset
__all__ = [ __all__ = [
"NumberSequenceConfig", "NumberSequenceConfig",
"NumberSequenceDataset", "NumberSequenceDataset",
"number_sequence_dataset",
"ColorCubeRotationConfig", "ColorCubeRotationConfig",
"ColorCubeRotationDataset", "ColorCubeRotationDataset",
"color_cube_rotation_dataset",
] ]

View file

@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import StrEnum from enum import StrEnum
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
class Color(StrEnum): class Color(StrEnum):
@ -189,17 +189,4 @@ class ColorCubeRotationDataset(ProceduralDataset):
return "\n".join(story_parts) return "\n".join(story_parts)
def color_cube_rotation_dataset( register_dataset("color_cube_rotation", ColorCubeRotationDataset, ColorCubeRotationConfig)
min_rotations: int = 1,
max_rotations: int = 3,
seed: Optional[int] = None,
size: int = 500,
) -> ColorCubeRotationDataset:
"""Create a ColorCubeRotationDataset with the given configuration"""
config = ColorCubeRotationConfig(
min_rotations=min_rotations,
max_rotations=max_rotations,
seed=seed,
size=size,
)
return ColorCubeRotationDataset(config)

View file

@ -3,7 +3,7 @@ from enum import StrEnum
from random import Random from random import Random
from typing import List, Optional from typing import List, Optional
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
class Operation(StrEnum): class Operation(StrEnum):
@ -198,23 +198,4 @@ class NumberSequenceDataset(ProceduralDataset):
} }
def number_sequence_dataset( register_dataset("number_sequence", NumberSequenceDataset, NumberSequenceConfig)
min_terms: int = 4,
max_terms: int = 8,
min_value: int = -100,
max_value: int = 100,
max_complexity: int = 3,
seed: Optional[int] = None,
size: int = 500,
) -> NumberSequenceDataset:
"""Create a NumberSequenceDataset with the given configuration."""
config = NumberSequenceConfig(
min_terms=min_terms,
max_terms=max_terms,
min_value=min_value,
max_value=max_value,
max_complexity=max_complexity,
seed=seed,
size=size,
)
return NumberSequenceDataset(config)

View file

@ -1,75 +1,58 @@
from dataclasses import is_dataclass from dataclasses import is_dataclass
from typing import Any, Dict, Type, TypeVar from typing import Dict, Type, TypeVar
from .dataset import ProceduralDataset from .dataset import ProceduralDataset
# Type variables for generic type hints # Type variables for generic type hints
ConfigT = TypeVar('ConfigT') ConfigT = TypeVar("ConfigT")
DatasetT = TypeVar('DatasetT', bound=ProceduralDataset) DatasetT = TypeVar("DatasetT", bound=ProceduralDataset)
# Global registry of datasets # Global registry of datasets
_DATASETS: Dict[str, tuple[Type[ProceduralDataset], Type]] = {} _DATASETS: Dict[str, tuple[Type[ProceduralDataset], Type]] = {}
def register_dataset(
name: str, def register_dataset(name: str, dataset_cls: Type[DatasetT], config_cls: Type[ConfigT]) -> None:
dataset_cls: Type[DatasetT],
config_cls: Type[ConfigT]
) -> None:
""" """
Register a dataset class with its configuration class. Register a dataset class with its configuration class.
Args: Args:
name: Unique identifier for the dataset name: Unique identifier for the dataset
dataset_cls: Class derived from ProceduralDataset dataset_cls: Class derived from ProceduralDataset
config_cls: Configuration dataclass for the dataset config_cls: Configuration dataclass for the dataset
Raises: Raises:
ValueError: If name is already registered or invalid types provided ValueError: If name is already registered or invalid types provided
""" """
if name in _DATASETS: if name in _DATASETS:
raise ValueError(f"Dataset '{name}' is already registered") raise ValueError(f"Dataset '{name}' is already registered")
if not issubclass(dataset_cls, ProceduralDataset): if not issubclass(dataset_cls, ProceduralDataset):
raise ValueError( raise ValueError(f"Dataset class must inherit from ProceduralDataset, got {dataset_cls}")
f"Dataset class must inherit from ProceduralDataset, got {dataset_cls}"
)
if not is_dataclass(config_cls): if not is_dataclass(config_cls):
raise ValueError( raise ValueError(f"Config class must be a dataclass, got {config_cls}")
f"Config class must be a dataclass, got {config_cls}"
)
_DATASETS[name] = (dataset_cls, config_cls) _DATASETS[name] = (dataset_cls, config_cls)
def create_dataset(
name: str, def create_dataset(name: str, **kwargs) -> ProceduralDataset:
config: Any,
seed: int = None,
size: int = 500
) -> ProceduralDataset:
""" """
Create a dataset instance by name with the given configuration. Create a dataset instance by name with the given configuration.
Args: Args:
name: Registered dataset name name: Registered dataset name
config: Configuration instance for the dataset
seed: Optional random seed
size: Size of the dataset (default: 500)
Returns: Returns:
Configured dataset instance Configured dataset instance
Raises: Raises:
ValueError: If dataset not found or config type mismatch ValueError: If dataset not found or config type mismatch
""" """
if name not in _DATASETS: if name not in _DATASETS:
raise ValueError(f"Dataset '{name}' not found") raise ValueError(f"Dataset '{name}' not registered")
dataset_cls, config_cls = _DATASETS[name] dataset_cls, config_cls = _DATASETS[name]
if not isinstance(config, config_cls): conifg = config_cls(**kwargs)
raise ValueError(
f"Config must be instance of {config_cls.__name__}, got {type(config).__name__}" return dataset_cls(config=conifg)
)
return dataset_cls(config=config, seed=seed, size=size)

View file

@ -5,18 +5,15 @@ Game tasks for training reasoning capabilities:
- Strategy games - Strategy games
""" """
from .maze import MazeConfig, MazeDataset, maze_dataset from .maze import MazeConfig, MazeDataset
from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset, mini_sudoku_dataset from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
from .sudoku import SudokuConfig, SudokuDataset, sudoku_dataset from .sudoku import SudokuConfig, SudokuDataset
__all__ = [ __all__ = [
"MiniSudokuConfig", "MiniSudokuConfig",
"MiniSudokuDataset", "MiniSudokuDataset",
"mini_sudoku_dataset",
"SudokuConfig", "SudokuConfig",
"SudokuDataset", "SudokuDataset",
"sudoku_dataset",
"MazeConfig", "MazeConfig",
"MazeDataset", "MazeDataset",
"maze_dataset",
] ]

View file

@ -3,7 +3,7 @@ import string
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
@dataclass @dataclass
@ -47,7 +47,6 @@ class MazeDataset(ProceduralDataset):
num_retries=1000, num_retries=1000,
): ):
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)
self.config = config
# Probability that a cell is a path instead of a wall # Probability that a cell is a path instead of a wall
self.prob_path = prob_path self.prob_path = prob_path
# Number of times to resample a grid to find a suitable maze before giving up # Number of times to resample a grid to find a suitable maze before giving up
@ -184,21 +183,4 @@ class MazeDataset(ProceduralDataset):
return "\n".join("".join(row) for row in grid) return "\n".join("".join(row) for row in grid)
def maze_dataset( register_dataset("maze", MazeDataset, MazeConfig)
min_dist: int = 5,
max_dist: int = 10,
min_grid_size: int = 5,
max_grid_size: int = 10,
seed: Optional[int] = None,
size: int = 50,
) -> MazeDataset:
"""Convenient function to create a MazeDataset."""
config = MazeConfig(
min_dist=min_dist,
max_dist=max_dist,
min_grid_size=min_grid_size,
max_grid_size=max_grid_size,
seed=seed,
size=size,
)
return MazeDataset(config)

View file

@ -1,9 +1,10 @@
"""Mini Sudoku (4x4) puzzle generator""" """Mini Sudoku (4x4) puzzle generator"""
import random
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Set, Tuple from typing import List, Optional, Tuple
from ..factory import ProceduralDataset, register_dataset
@dataclass @dataclass
@ -21,13 +22,11 @@ class MiniSudokuConfig:
assert self.min_empty <= self.max_empty <= 16, "max_empty must be between min_empty and 16" assert self.min_empty <= self.max_empty <= 16, "max_empty must be between min_empty and 16"
class MiniSudokuDataset: class MiniSudokuDataset(ProceduralDataset):
"""Generates 4x4 sudoku puzzles with configurable difficulty""" """Generates 4x4 sudoku puzzles with configurable difficulty"""
def __init__(self, config: MiniSudokuConfig): def __init__(self, config: MiniSudokuConfig):
self.config = config super().__init__(config=config, seed=config.seed, size=config.size)
self.config.validate()
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
def __len__(self) -> int: def __len__(self) -> int:
return self.config.size return self.config.size
@ -149,17 +148,4 @@ class MiniSudokuDataset:
} }
def mini_sudoku_dataset( register_dataset("mini_sudoku", MiniSudokuDataset, MiniSudokuConfig)
min_empty: int = 8,
max_empty: int = 12,
seed: Optional[int] = None,
size: int = 500,
) -> MiniSudokuDataset:
"""Create a MiniSudokuDataset with the given configuration."""
config = MiniSudokuConfig(
min_empty=min_empty,
max_empty=max_empty,
seed=seed,
size=size,
)
return MiniSudokuDataset(config)

View file

@ -1,9 +1,10 @@
"""Sudoku puzzle generator""" """Sudoku puzzle generator"""
import random
from dataclasses import dataclass from dataclasses import dataclass
from random import Random from random import Random
from typing import List, Optional, Set, Tuple from typing import List, Optional, Tuple
from ..factory import ProceduralDataset, register_dataset
@dataclass @dataclass
@ -21,13 +22,11 @@ class SudokuConfig:
assert self.min_empty <= self.max_empty <= 81, "max_empty must be between min_empty and 81" assert self.min_empty <= self.max_empty <= 81, "max_empty must be between min_empty and 81"
class SudokuDataset: class SudokuDataset(ProceduralDataset):
"""Generates sudoku puzzles with configurable difficulty""" """Generates sudoku puzzles with configurable difficulty"""
def __init__(self, config: SudokuConfig): def __init__(self, config: SudokuConfig):
self.config = config super().__init__(config=config, seed=config.seed, size=config.size)
self.config.validate()
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
def __len__(self) -> int: def __len__(self) -> int:
return self.config.size return self.config.size
@ -139,17 +138,4 @@ class SudokuDataset:
} }
def sudoku_dataset( register_dataset("sudoku", SudokuDataset, SudokuConfig)
min_empty: int = 30,
max_empty: int = 50,
seed: Optional[int] = None,
size: int = 500,
) -> SudokuDataset:
"""Create a SudokuDataset with the given configuration."""
config = SudokuConfig(
min_empty=min_empty,
max_empty=max_empty,
seed=seed,
size=size,
)
return SudokuDataset(config)

View file

@ -1,3 +1,3 @@
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset, family_relationships_dataset from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset
__all__ = ["FamilyRelationshipsDataset", "FamilyRelationshipsConfig", "family_relationships_dataset"] __all__ = ["FamilyRelationshipsDataset", "FamilyRelationshipsConfig"]

View file

@ -4,7 +4,7 @@ from enum import StrEnum
from itertools import count from itertools import count
from typing import Dict, List, Optional, Set, Tuple from typing import Dict, List, Optional, Set, Tuple
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
class Gender(StrEnum): class Gender(StrEnum):
@ -310,21 +310,4 @@ class FamilyRelationshipsDataset(ProceduralDataset):
return " ".join(story_parts) return " ".join(story_parts)
def family_relationships_dataset( register_dataset("family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig)
min_family_size: int = 4,
max_family_size: int = 8,
male_names: List[str] = None,
female_names: List[str] = None,
seed: Optional[int] = None,
size: int = 500,
) -> FamilyRelationshipsDataset:
"""Create a FamilyRelationshipsDataset with the given configuration"""
config = FamilyRelationshipsConfig(
min_family_size=min_family_size,
max_family_size=max_family_size,
male_names=male_names,
female_names=female_names,
seed=seed,
size=size,
)
return FamilyRelationshipsDataset(config)

View file

@ -6,13 +6,12 @@ Logic tasks for training reasoning capabilities:
- Syllogisms - Syllogisms
""" """
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset, propositional_logic_dataset from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset
from .syllogisms import SyllogismConfig, SyllogismDataset, Term, syllogism_dataset from .syllogisms import SyllogismConfig, SyllogismDataset, Term
__all__ = [ __all__ = [
"PropositionalLogicConfig", "PropositionalLogicConfig",
"PropositionalLogicDataset", "PropositionalLogicDataset",
"propositional_logic_dataset",
"SyllogismConfig", "SyllogismConfig",
"SyllogismDataset", "SyllogismDataset",
"syllogism_dataset", "syllogism_dataset",

View file

@ -5,6 +5,8 @@ from enum import StrEnum
from random import Random from random import Random
from typing import Any, List, Optional, Set from typing import Any, List, Optional, Set
from ..factory import ProceduralDataset, register_dataset
class Operator(StrEnum): class Operator(StrEnum):
"""Basic logical operators""" """Basic logical operators"""
@ -70,13 +72,11 @@ class Expression:
return f"({self.left} {self.operator.value} {self.right})" return f"({self.left} {self.operator.value} {self.right})"
class PropositionalLogicDataset: class PropositionalLogicDataset(ProceduralDataset):
"""Generates propositional logic reasoning tasks""" """Generates propositional logic reasoning tasks"""
def __init__(self, config: PropositionalLogicConfig): def __init__(self, config: PropositionalLogicConfig):
self.config = config super().__init__(config=config, seed=config.seed, size=config.size)
self.config.validate()
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
def __len__(self) -> int: def __len__(self) -> int:
return self.config.size return self.config.size
@ -199,23 +199,4 @@ class PropositionalLogicDataset:
return 1 + self._measure_complexity(expression.left) + self._measure_complexity(expression.right) return 1 + self._measure_complexity(expression.left) + self._measure_complexity(expression.right)
def propositional_logic_dataset( register_dataset("propositional_logic", PropositionalLogicDataset, PropositionalLogicConfig)
min_vars: int = 2,
max_vars: int = 4,
min_statements: int = 2,
max_statements: int = 4,
max_complexity: int = 3,
seed: Optional[int] = None,
size: int = 500,
) -> PropositionalLogicDataset:
"""Create a PropositionalLogicDataset with the given configuration."""
config = PropositionalLogicConfig(
min_vars=min_vars,
max_vars=max_vars,
min_statements=min_statements,
max_statements=max_statements,
max_complexity=max_complexity,
seed=seed,
size=size,
)
return PropositionalLogicDataset(config)

View file

@ -5,7 +5,7 @@ from enum import Enum
from random import Random from random import Random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset from ..factory import ProceduralDataset, register_dataset
class Quantifier(Enum): class Quantifier(Enum):
@ -256,27 +256,4 @@ class SyllogismDataset(ProceduralDataset):
return self._generate_syllogism(rng) return self._generate_syllogism(rng)
def syllogism_dataset( register_dataset("syllogism", SyllogismDataset, SyllogismConfig)
terms: List[Term] = None,
allow_all: bool = True,
allow_no: bool = True,
allow_some: bool = True,
allow_some_not: bool = True,
include_invalid: bool = True,
invalid_ratio: float = 0.3,
seed: Optional[int] = None,
size: int = 500,
) -> SyllogismDataset:
"""Create a SyllogismDataset with the given configuration."""
config = SyllogismConfig(
terms=terms,
allow_all=allow_all,
allow_no=allow_no,
allow_some=allow_some,
allow_some_not=allow_some_not,
include_invalid=include_invalid,
invalid_ratio=invalid_ratio,
seed=seed,
size=size,
)
return SyllogismDataset(config)

View file

@ -1,10 +1,12 @@
import pytest import pytest
from reasoning_gym.cognition.color_cube_rotation import Color, Cube, Side, color_cube_rotation_dataset from reasoning_gym import create_dataset
from reasoning_gym.cognition.color_cube_rotation import Color, ColorCubeRotationDataset, Cube, Side
def test_color_cube_rotation_generation(): def test_color_cube_rotation_generation():
dataset = color_cube_rotation_dataset(seed=42, size=10) dataset = create_dataset("color_cube_rotation", seed=42, size=10)
assert isinstance(dataset, ColorCubeRotationDataset)
for item in dataset: for item in dataset:
# Check required keys exist # Check required keys exist
@ -33,15 +35,15 @@ def test_color_cube_rotation_generation():
def test_color_cube_rotation_config(): def test_color_cube_rotation_config():
# Test invalid config raises assertion # Test invalid config raises assertion
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
dataset = color_cube_rotation_dataset(min_rotations=0) dataset = create_dataset("color_cube_rotation", min_rotations=0)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
dataset = color_cube_rotation_dataset(max_rotations=1, min_rotations=2) dataset = create_dataset("color_cube_rotation", max_rotations=1, min_rotations=2)
def test_deterministic_generation(): def test_deterministic_generation():
dataset1 = color_cube_rotation_dataset(seed=42, size=5) dataset1 = create_dataset("color_cube_rotation", seed=42, size=5)
dataset2 = color_cube_rotation_dataset(seed=42, size=5) dataset2 = create_dataset("color_cube_rotation", seed=42, size=5)
for i in range(5): for i in range(5):
assert dataset1[i]["question"] == dataset2[i]["question"] assert dataset1[i]["question"] == dataset2[i]["question"]

View file

@ -1,10 +1,12 @@
import pytest import pytest
from reasoning_gym.graphs.family_relationships import Gender, Relationship, family_relationships_dataset from reasoning_gym import create_dataset
from reasoning_gym.graphs.family_relationships import FamilyRelationshipsDataset, Relationship
def test_family_relationships_generation(): def test_family_relationships_generation():
dataset = family_relationships_dataset(seed=42, size=10) dataset = create_dataset("family_relationships", seed=42, size=10)
assert isinstance(dataset, FamilyRelationshipsDataset)
for item in dataset: for item in dataset:
# Check required keys exist # Check required keys exist
@ -32,21 +34,21 @@ def test_family_relationships_generation():
def test_family_relationships_config(): def test_family_relationships_config():
# Test invalid config raises assertion # Test invalid config raises assertion
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
dataset = family_relationships_dataset(min_family_size=2) dataset = create_dataset("family_relationships", min_family_size=2)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
dataset = family_relationships_dataset(max_family_size=3, min_family_size=4) dataset = create_dataset("family_relationships", max_family_size=3, min_family_size=4)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
dataset = family_relationships_dataset(male_names=[]) dataset = create_dataset("family_relationships", male_names=[])
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
dataset = family_relationships_dataset(female_names=[]) dataset = create_dataset("family_relationships", female_names=[])
def test_deterministic_generation(): def test_deterministic_generation():
dataset1 = family_relationships_dataset(seed=42, size=5) dataset1 = create_dataset("family_relationships", seed=42, size=5)
dataset2 = family_relationships_dataset(seed=42, size=5) dataset2 = create_dataset("family_relationships", seed=42, size=5)
for i in range(5): for i in range(5):
assert dataset1[i]["question"] == dataset2[i]["question"] assert dataset1[i]["question"] == dataset2[i]["question"]
@ -54,7 +56,7 @@ def test_deterministic_generation():
def test_relationship_consistency(): def test_relationship_consistency():
dataset = family_relationships_dataset(seed=42, size=10) dataset = create_dataset("family_relationships", seed=42, size=10)
for item in dataset: for item in dataset:
# Check that relationship matches the gender # Check that relationship matches the gender

View file

@ -1,6 +1,7 @@
import pytest import pytest
from reasoning_gym.games.maze import MazeConfig, MazeDataset, maze_dataset from reasoning_gym import create_dataset
from reasoning_gym.games.maze import MazeConfig, MazeDataset
def test_maze_config_validation(): def test_maze_config_validation():
@ -38,7 +39,8 @@ def test_maze_dataset_creation():
def test_maze_dataset_items(): def test_maze_dataset_items():
ds = maze_dataset( ds = create_dataset(
"maze",
min_dist=3, min_dist=3,
max_dist=5, max_dist=5,
min_grid_size=5, min_grid_size=5,
@ -62,7 +64,8 @@ def test_maze_shortest_path_correctness():
""" """
min_dist = 4 min_dist = 4
max_dist = 8 max_dist = 8
ds = maze_dataset( ds = create_dataset(
"maze",
min_dist=min_dist, min_dist=min_dist,
max_dist=max_dist, max_dist=max_dist,
min_grid_size=5, min_grid_size=5,

View file

@ -1,11 +1,8 @@
import pytest import pytest
from sympy import Symbol, sympify from sympy import Symbol, sympify
from reasoning_gym.algebra.polynomial_equations import ( from reasoning_gym import create_dataset
PolynomialEquationsConfig, from reasoning_gym.algebra.polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset
PolynomialEquationsDataset,
polynomial_equations_dataset,
)
def test_polynomial_config_validation(): def test_polynomial_config_validation():
@ -47,7 +44,8 @@ def test_polynomial_equations_dataset_basic():
def test_polynomial_equations_dataset_items(): def test_polynomial_equations_dataset_items():
"""Test that generated items have correct structure""" """Test that generated items have correct structure"""
ds = polynomial_equations_dataset( ds = create_dataset(
"polynomial_equations",
min_terms=2, min_terms=2,
max_terms=3, max_terms=3,
min_value=1, min_value=1,
@ -87,7 +85,8 @@ def test_polynomial_equations_dataset_deterministic():
def test_polynomial_solutions_evaluation(): def test_polynomial_solutions_evaluation():
"""Test that real_solutions satisfy the polynomial equation.""" """Test that real_solutions satisfy the polynomial equation."""
ds = polynomial_equations_dataset( ds = create_dataset(
"polynomial_equations",
min_terms=2, min_terms=2,
max_terms=4, max_terms=4,
min_value=1, min_value=1,