add reasoning_gym.create_dataset({name}, ...) global factory function

This commit is contained in:
Andreas Koepf 2025-01-25 00:58:34 +01:00
parent 0d2d8ba6a0
commit 519e411fa5
35 changed files with 133 additions and 598 deletions

View file

@ -6,17 +6,13 @@ Arithmetic tasks for training reasoning capabilities:
- Leg counting
"""
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset
from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset
from .fraction_simplification import (
FractionSimplificationConfig,
FractionSimplificationDataset,
fraction_simplification_dataset,
)
from .gcd import GCDConfig, GCDDataset, gcd_dataset
from .lcm import LCMConfig, LCMDataset, lcm_dataset
from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset, prime_factorization_dataset
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
from .chain_sum import ChainSum, ChainSumConfig
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
from .gcd import GCDConfig, GCDDataset
from .lcm import LCMConfig, LCMDataset
from .leg_counting import LegCountingConfig, LegCountingDataset
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
__all__ = [
"BasicArithmeticDataset",
@ -24,20 +20,14 @@ __all__ = [
"basic_arithmetic_dataset",
"ChainSum",
"ChainSumConfig",
"chain_sum_dataset",
"FractionSimplificationConfig",
"FractionSimplificationDataset",
"fraction_simplification_dataset",
"GCDConfig",
"GCDDataset",
"gcd_dataset",
"LCMConfig",
"LCMDataset",
"lcm_dataset",
"LegCountingConfig",
"LegCountingDataset",
"leg_counting_dataset",
"PrimeFactorizationConfig",
"PrimeFactorizationDataset",
"prime_factorization_dataset",
]

View file

@ -2,7 +2,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Literal, Optional
from ..dataset import ProceduralDataset
from ..factory import ProceduralDataset, register_dataset
@dataclass
@ -231,47 +231,5 @@ class BasicArithmeticDataset(ProceduralDataset):
return rng.choice(templates).format(expression)
def basic_arithmetic_dataset(
    min_terms: int = 2,
    max_terms: int = 6,
    min_digits: int = 1,
    max_digits: int = 4,
    operators: list[str] = ("+", "-", "*", "/"),
    allow_parentheses: bool = True,
    allow_negation: bool = True,
    seed: Optional[int] = None,
    size: int = 500,
    format_style: Literal["simple", "natural"] = "simple",
    whitespace: Literal["no_space", "single", "random"] = "single",
) -> BasicArithmeticDataset:
    """Build a BasicArithmeticDataset from individual keyword settings.

    Each argument is forwarded unchanged to the same-named field of
    BasicArithmeticDatasetConfig; the resulting config is then wrapped in a
    BasicArithmeticDataset.

    Args:
        min_terms: Minimum number of terms in expressions
        max_terms: Maximum number of terms in expressions
        min_digits: Minimum number of digits in numbers
        max_digits: Maximum number of digits in numbers
        operators: Operators to use (defaults to "+", "-", "*", "/")
        allow_parentheses: Whether to allow parentheses in expressions
        allow_negation: Whether to allow negative numbers
        seed: Random seed for reproducibility
        size: Virtual size of the dataset
        format_style: Style of question formatting ("simple" or "natural")
        whitespace: Whitespace mode ("no_space", "single" or "random")

    Returns:
        BasicArithmeticDataset: Dataset instance built from the config
    """
    # Construct the config inline and hand it straight to the dataset class.
    return BasicArithmeticDataset(
        BasicArithmeticDatasetConfig(
            min_terms=min_terms,
            max_terms=max_terms,
            min_digits=min_digits,
            max_digits=max_digits,
            operators=operators,
            allow_parentheses=allow_parentheses,
            allow_negation=allow_negation,
            seed=seed,
            size=size,
            format_style=format_style,
            whitespace=whitespace,
        )
    )
# Register the dataset class together with its config class under the name
# "basic_arithmetic" (presumably so the factory can build it by name —
# confirm against ..factory.register_dataset).
register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig)

View file

@ -2,8 +2,7 @@ import random
from dataclasses import dataclass
from typing import Optional
from ..dataset import ProceduralDataset
from ..factory import register_dataset
from ..factory import ProceduralDataset, register_dataset
@dataclass
@ -109,40 +108,5 @@ class ChainSum(ProceduralDataset):
return expression, result
def chain_sum_dataset(
    min_terms: int = 2,
    max_terms: int = 6,
    min_digits: int = 1,
    max_digits: int = 4,
    allow_negation: bool = False,
    seed: Optional[int] = None,
    size: int = 500,
) -> ChainSum:
    """Build a ChainSum dataset from individual keyword settings.

    Each argument is forwarded unchanged to the same-named field of
    ChainSumConfig; the resulting config is then wrapped in a ChainSum.

    Args:
        min_terms: Minimum number of terms in expressions
        max_terms: Maximum number of terms in expressions
        min_digits: Minimum number of digits in numbers
        max_digits: Maximum number of digits in numbers
        allow_negation: Whether to allow negative numbers
        seed: Random seed for reproducibility
        size: Virtual size of the dataset

    Returns:
        ChainSum: Dataset instance built from the config
    """
    # Collect the settings first, then expand them as keyword arguments.
    options = {
        "min_terms": min_terms,
        "max_terms": max_terms,
        "min_digits": min_digits,
        "max_digits": max_digits,
        "allow_negation": allow_negation,
        "seed": seed,
        "size": size,
    }
    return ChainSum(ChainSumConfig(**options))
# Register the dataset class together with its config class under the name
# "chain_sum" (presumably so the factory can build it by name — confirm
# against ..factory.register_dataset).
register_dataset("chain_sum", ChainSum, ChainSumConfig)

View file

@ -5,7 +5,7 @@ from math import gcd
from random import Random
from typing import Optional, Sequence, Tuple
from ..dataset import ProceduralDataset
from ..factory import ProceduralDataset, register_dataset
@dataclass
@ -121,23 +121,4 @@ class FractionSimplificationDataset(ProceduralDataset):
}
def fraction_simplification_dataset(
    min_value: int = 1,
    max_value: int = 100,
    min_factor: int = 2,
    max_factor: int = 10,
    styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac"),
    seed: Optional[int] = None,
    size: int = 500,
) -> FractionSimplificationDataset:
    """Build a FractionSimplificationDataset from individual keyword settings.

    Each argument is forwarded unchanged to the same-named field of
    FractionSimplificationConfig; the resulting config is then wrapped in a
    FractionSimplificationDataset.
    """
    # Construct the config inline and hand it straight to the dataset class.
    return FractionSimplificationDataset(
        FractionSimplificationConfig(
            min_value=min_value,
            max_value=max_value,
            min_factor=min_factor,
            max_factor=max_factor,
            styles=styles,
            seed=seed,
            size=size,
        )
    )
# Register the dataset class together with its config class under the name
# "fraction_simplification" (presumably so the factory can build it by name —
# confirm against ..factory.register_dataset).
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)

View file

@ -6,7 +6,7 @@ from math import gcd
from random import Random
from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset
from ..factory import ProceduralDataset, register_dataset
@dataclass
@ -63,21 +63,4 @@ class GCDDataset(ProceduralDataset):
}
def gcd_dataset(
    min_numbers: int = 2,
    max_numbers: int = 2,
    min_value: int = 1,
    max_value: int = 10_000,
    seed: Optional[int] = None,
    size: int = 500,
) -> GCDDataset:
    """Build a GCDDataset from individual keyword settings.

    Each argument is forwarded unchanged to the same-named field of GCDConfig;
    the resulting config is then wrapped in a GCDDataset.
    """
    # Collect the settings first, then expand them as keyword arguments.
    options = {
        "min_numbers": min_numbers,
        "max_numbers": max_numbers,
        "min_value": min_value,
        "max_value": max_value,
        "seed": seed,
        "size": size,
    }
    return GCDDataset(GCDConfig(**options))
# Register the dataset class together with its config class under the name
# "gcd" (presumably so the factory can build it by name — confirm against
# ..factory.register_dataset).
register_dataset("gcd", GCDDataset, GCDConfig)

View file

@ -6,7 +6,7 @@ from math import lcm
from random import Random
from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset
from ..factory import ProceduralDataset, register_dataset
@dataclass
@ -66,21 +66,4 @@ class LCMDataset(ProceduralDataset):
}
def lcm_dataset(
    min_numbers: int = 2,
    max_numbers: int = 2,
    min_value: int = 1,
    max_value: int = 100,
    seed: Optional[int] = None,
    size: int = 500,
) -> LCMDataset:
    """Build an LCMDataset from individual keyword settings.

    Each argument is forwarded unchanged to the same-named field of LCMConfig;
    the resulting config is then wrapped in an LCMDataset.
    """
    # Construct the config inline and hand it straight to the dataset class.
    return LCMDataset(
        LCMConfig(
            min_numbers=min_numbers,
            max_numbers=max_numbers,
            min_value=min_value,
            max_value=max_value,
            seed=seed,
            size=size,
        )
    )
# Register the dataset class together with its config class under the name
# "lcm" (presumably so the factory can build it by name — confirm against
# ..factory.register_dataset).
register_dataset("lcm", LCMDataset, LCMConfig)

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from random import Random
from typing import Dict, Optional
from ..dataset import ProceduralDataset
from ..factory import ProceduralDataset, register_dataset
ANIMALS = {
# Animals with 0 legs
@ -115,19 +115,4 @@ class LegCountingDataset(ProceduralDataset):
}
def leg_counting_dataset(
    min_animals: int = 2,
    max_animals: int = 5,
    max_instances: int = 3,
    seed: Optional[int] = None,
    size: int = 500,
) -> LegCountingDataset:
    """Build a LegCountingDataset from individual keyword settings.

    Each argument is forwarded unchanged to the same-named field of
    LegCountingConfig; the resulting config is then wrapped in a
    LegCountingDataset.
    """
    # Collect the settings first, then expand them as keyword arguments.
    options = {
        "min_animals": min_animals,
        "max_animals": max_animals,
        "max_instances": max_instances,
        "seed": seed,
        "size": size,
    }
    return LegCountingDataset(LegCountingConfig(**options))
# Register the dataset class together with its config class under the name
# "leg_counting" (presumably so the factory can build it by name — confirm
# against ..factory.register_dataset).
register_dataset("leg_counting", LegCountingDataset, LegCountingConfig)

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from random import Random
from typing import List, Optional, Tuple
from ..dataset import ProceduralDataset
from ..factory import ProceduralDataset, register_dataset
@dataclass
@ -66,17 +66,4 @@ class PrimeFactorizationDataset(ProceduralDataset):
}
def prime_factorization_dataset(
    min_value: int = 2,
    max_value: int = 1000,
    seed: Optional[int] = None,
    size: int = 500,
) -> PrimeFactorizationDataset:
    """Build a PrimeFactorizationDataset from individual keyword settings.

    Each argument is forwarded unchanged to the same-named field of
    PrimeFactorizationConfig; the resulting config is then wrapped in a
    PrimeFactorizationDataset.
    """
    # Construct the config inline and hand it straight to the dataset class.
    return PrimeFactorizationDataset(
        PrimeFactorizationConfig(
            min_value=min_value,
            max_value=max_value,
            seed=seed,
            size=size,
        )
    )
# Register the dataset class together with its config class under the name
# "prime_factorization" (presumably so the factory can build it by name —
# confirm against ..factory.register_dataset).
register_dataset("prime_factorization", PrimeFactorizationDataset, PrimeFactorizationConfig)