mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-22 16:49:06 +00:00
add reasoning_gym.create_dataset({name}, ...) global factory function
This commit is contained in:
parent
0d2d8ba6a0
commit
519e411fa5
35 changed files with 133 additions and 598 deletions
|
|
@ -6,17 +6,13 @@ Arithmetic tasks for training reasoning capabilities:
|
|||
- Leg counting
|
||||
"""
|
||||
|
||||
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset
|
||||
from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset
|
||||
from .fraction_simplification import (
|
||||
FractionSimplificationConfig,
|
||||
FractionSimplificationDataset,
|
||||
fraction_simplification_dataset,
|
||||
)
|
||||
from .gcd import GCDConfig, GCDDataset, gcd_dataset
|
||||
from .lcm import LCMConfig, LCMDataset, lcm_dataset
|
||||
from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset
|
||||
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset, prime_factorization_dataset
|
||||
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
|
||||
from .chain_sum import ChainSum, ChainSumConfig
|
||||
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
|
||||
from .gcd import GCDConfig, GCDDataset
|
||||
from .lcm import LCMConfig, LCMDataset
|
||||
from .leg_counting import LegCountingConfig, LegCountingDataset
|
||||
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
|
||||
|
||||
__all__ = [
|
||||
"BasicArithmeticDataset",
|
||||
|
|
@ -24,20 +20,14 @@ __all__ = [
|
|||
"basic_arithmetic_dataset",
|
||||
"ChainSum",
|
||||
"ChainSumConfig",
|
||||
"chain_sum_dataset",
|
||||
"FractionSimplificationConfig",
|
||||
"FractionSimplificationDataset",
|
||||
"fraction_simplification_dataset",
|
||||
"GCDConfig",
|
||||
"GCDDataset",
|
||||
"gcd_dataset",
|
||||
"LCMConfig",
|
||||
"LCMDataset",
|
||||
"lcm_dataset",
|
||||
"LegCountingConfig",
|
||||
"LegCountingDataset",
|
||||
"leg_counting_dataset",
|
||||
"PrimeFactorizationConfig",
|
||||
"PrimeFactorizationDataset",
|
||||
"prime_factorization_dataset",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -231,47 +231,5 @@ class BasicArithmeticDataset(ProceduralDataset):
|
|||
return rng.choice(templates).format(expression)
|
||||
|
||||
|
||||
def basic_arithmetic_dataset(
|
||||
min_terms: int = 2,
|
||||
max_terms: int = 6,
|
||||
min_digits: int = 1,
|
||||
max_digits: int = 4,
|
||||
operators: list[str] = ("+", "-", "*", "/"),
|
||||
allow_parentheses: bool = True,
|
||||
allow_negation: bool = True,
|
||||
seed: Optional[int] = None,
|
||||
size: int = 500,
|
||||
format_style: Literal["simple", "natural"] = "simple",
|
||||
whitespace: Literal["no_space", "single", "random"] = "single",
|
||||
) -> BasicArithmeticDataset:
|
||||
"""Create a BasicArithmeticDataset with the given configuration.
|
||||
|
||||
Args:
|
||||
min_terms: Minimum number of terms in expressions
|
||||
max_terms: Maximum number of terms in expressions
|
||||
min_digits: Minimum number of digits in numbers
|
||||
max_digits: Maximum number of digits in numbers
|
||||
operators: List of operators to use ("+", "-", "*")
|
||||
allow_parentheses: Whether to allow parentheses in expressions
|
||||
allow_negation: Whether to allow negative numbers
|
||||
seed: Random seed for reproducibility
|
||||
size: Virtual size of the dataset
|
||||
format_style: Style of question formatting ("simple" or "natural")
|
||||
|
||||
Returns:
|
||||
BasicArithmeticDataset: Configured dataset instance
|
||||
"""
|
||||
config = BasicArithmeticDatasetConfig(
|
||||
min_terms=min_terms,
|
||||
max_terms=max_terms,
|
||||
min_digits=min_digits,
|
||||
max_digits=max_digits,
|
||||
operators=operators,
|
||||
allow_parentheses=allow_parentheses,
|
||||
allow_negation=allow_negation,
|
||||
seed=seed,
|
||||
size=size,
|
||||
format_style=format_style,
|
||||
whitespace=whitespace,
|
||||
)
|
||||
return BasicArithmeticDataset(config)
|
||||
# Register the dataset
|
||||
register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig)
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@ import random
|
|||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
from ..factory import register_dataset
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -109,40 +108,5 @@ class ChainSum(ProceduralDataset):
|
|||
return expression, result
|
||||
|
||||
|
||||
def chain_sum_dataset(
|
||||
min_terms: int = 2,
|
||||
max_terms: int = 6,
|
||||
min_digits: int = 1,
|
||||
max_digits: int = 4,
|
||||
allow_negation: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
size: int = 500,
|
||||
) -> ChainSum:
|
||||
"""Create a ChainSum dataset with the given configuration.
|
||||
|
||||
Args:
|
||||
min_terms: Minimum number of terms in expressions
|
||||
max_terms: Maximum number of terms in expressions
|
||||
min_digits: Minimum number of digits in numbers
|
||||
max_digits: Maximum number of digits in numbers
|
||||
allow_negation: Whether to allow negative numbers
|
||||
seed: Random seed for reproducibility
|
||||
size: Virtual size of the dataset
|
||||
|
||||
Returns:
|
||||
ChainSum: Configured dataset instance
|
||||
"""
|
||||
config = ChainSumConfig(
|
||||
min_terms=min_terms,
|
||||
max_terms=max_terms,
|
||||
min_digits=min_digits,
|
||||
max_digits=max_digits,
|
||||
allow_negation=allow_negation,
|
||||
seed=seed,
|
||||
size=size,
|
||||
)
|
||||
return ChainSum(config)
|
||||
|
||||
|
||||
# Register the dataset
|
||||
register_dataset("chain_sum", ChainSum, ChainSumConfig)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from math import gcd
|
|||
from random import Random
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -121,23 +121,4 @@ class FractionSimplificationDataset(ProceduralDataset):
|
|||
}
|
||||
|
||||
|
||||
def fraction_simplification_dataset(
|
||||
min_value: int = 1,
|
||||
max_value: int = 100,
|
||||
min_factor: int = 2,
|
||||
max_factor: int = 10,
|
||||
styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac"),
|
||||
seed: Optional[int] = None,
|
||||
size: int = 500,
|
||||
) -> FractionSimplificationDataset:
|
||||
"""Create a FractionSimplificationDataset with the given configuration."""
|
||||
config = FractionSimplificationConfig(
|
||||
min_value=min_value,
|
||||
max_value=max_value,
|
||||
min_factor=min_factor,
|
||||
max_factor=max_factor,
|
||||
styles=styles,
|
||||
seed=seed,
|
||||
size=size,
|
||||
)
|
||||
return FractionSimplificationDataset(config)
|
||||
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from math import gcd
|
|||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -63,21 +63,4 @@ class GCDDataset(ProceduralDataset):
|
|||
}
|
||||
|
||||
|
||||
def gcd_dataset(
|
||||
min_numbers: int = 2,
|
||||
max_numbers: int = 2,
|
||||
min_value: int = 1,
|
||||
max_value: int = 10_000,
|
||||
seed: Optional[int] = None,
|
||||
size: int = 500,
|
||||
) -> GCDDataset:
|
||||
"""Create a GCDDataset with the given configuration."""
|
||||
config = GCDConfig(
|
||||
min_numbers=min_numbers,
|
||||
max_numbers=max_numbers,
|
||||
min_value=min_value,
|
||||
max_value=max_value,
|
||||
seed=seed,
|
||||
size=size,
|
||||
)
|
||||
return GCDDataset(config)
|
||||
register_dataset("gcd", GCDDataset, GCDConfig)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from math import lcm
|
|||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -66,21 +66,4 @@ class LCMDataset(ProceduralDataset):
|
|||
}
|
||||
|
||||
|
||||
def lcm_dataset(
|
||||
min_numbers: int = 2,
|
||||
max_numbers: int = 2,
|
||||
min_value: int = 1,
|
||||
max_value: int = 100,
|
||||
seed: Optional[int] = None,
|
||||
size: int = 500,
|
||||
) -> LCMDataset:
|
||||
"""Create a LCMDataset with the given configuration."""
|
||||
config = LCMConfig(
|
||||
min_numbers=min_numbers,
|
||||
max_numbers=max_numbers,
|
||||
min_value=min_value,
|
||||
max_value=max_value,
|
||||
seed=seed,
|
||||
size=size,
|
||||
)
|
||||
return LCMDataset(config)
|
||||
register_dataset("lcm", LCMDataset, LCMConfig)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
ANIMALS = {
|
||||
# Animals with 0 legs
|
||||
|
|
@ -115,19 +115,4 @@ class LegCountingDataset(ProceduralDataset):
|
|||
}
|
||||
|
||||
|
||||
def leg_counting_dataset(
|
||||
min_animals: int = 2,
|
||||
max_animals: int = 5,
|
||||
max_instances: int = 3,
|
||||
seed: Optional[int] = None,
|
||||
size: int = 500,
|
||||
) -> LegCountingDataset:
|
||||
"""Create a LegCountingDataset with the given configuration."""
|
||||
config = LegCountingConfig(
|
||||
min_animals=min_animals,
|
||||
max_animals=max_animals,
|
||||
max_instances=max_instances,
|
||||
seed=seed,
|
||||
size=size,
|
||||
)
|
||||
return LegCountingDataset(config)
|
||||
register_dataset("leg_counting", LegCountingDataset, LegCountingConfig)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -66,17 +66,4 @@ class PrimeFactorizationDataset(ProceduralDataset):
|
|||
}
|
||||
|
||||
|
||||
def prime_factorization_dataset(
|
||||
min_value: int = 2,
|
||||
max_value: int = 1000,
|
||||
seed: Optional[int] = None,
|
||||
size: int = 500,
|
||||
) -> PrimeFactorizationDataset:
|
||||
"""Create a PrimeFactorizationDataset with the given configuration."""
|
||||
config = PrimeFactorizationConfig(
|
||||
min_value=min_value,
|
||||
max_value=max_value,
|
||||
seed=seed,
|
||||
size=size,
|
||||
)
|
||||
return PrimeFactorizationDataset(config)
|
||||
register_dataset("prime_factorization", PrimeFactorizationDataset, PrimeFactorizationConfig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue