mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
add reasoning_gym.create_dataset({name}, ...) global factory function
This commit is contained in:
parent
0d2d8ba6a0
commit
519e411fa5
35 changed files with 133 additions and 598 deletions
|
|
@ -3,6 +3,7 @@ Reasoning Gym - A library of procedural dataset generators for training reasonin
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from . import algebra, algorithmic, arithmetic, cognition, data, games, graphs, logic
|
from . import algebra, algorithmic, arithmetic, cognition, data, games, graphs, logic
|
||||||
|
from .factory import create_dataset, register_dataset
|
||||||
|
|
||||||
__version__ = "0.1.1"
|
__version__ = "0.1.1"
|
||||||
__all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "graphs", "logic"]
|
__all__ = ["arithmetic", "algorithmic", "algebra", "cognition", "data", "games", "graphs", "logic"]
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,9 @@
|
||||||
from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset, polynomial_equations_dataset
|
from .polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset
|
||||||
from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset, simple_equations_dataset
|
from .simple_equations import SimpleEquationsConfig, SimpleEquationsDataset
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SimpleEquationsDataset",
|
"SimpleEquationsDataset",
|
||||||
"SimpleEquationsConfig",
|
"SimpleEquationsConfig",
|
||||||
"simple_equations_dataset",
|
|
||||||
"PolynomialEquationsConfig",
|
"PolynomialEquationsConfig",
|
||||||
"PolynomialEquationsDataset",
|
"PolynomialEquationsDataset",
|
||||||
"polynomial_equations_dataset",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from typing import Optional, Tuple
|
||||||
|
|
||||||
from sympy import Eq, Symbol, expand, solve
|
from sympy import Eq, Symbol, expand, solve
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -147,31 +147,4 @@ class PolynomialEquationsDataset(ProceduralDataset):
|
||||||
return polynomial_expr
|
return polynomial_expr
|
||||||
|
|
||||||
|
|
||||||
def polynomial_equations_dataset(
|
register_dataset("polynomial_equations", PolynomialEquationsDataset, PolynomialEquationsConfig)
|
||||||
min_terms: int = 2,
|
|
||||||
max_terms: int = 4,
|
|
||||||
min_value: int = 1,
|
|
||||||
max_value: int = 100,
|
|
||||||
min_degree: int = 1,
|
|
||||||
max_degree: int = 3,
|
|
||||||
operators: Tuple[str, ...] = ("+", "-"),
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> PolynomialEquationsDataset:
|
|
||||||
"""
|
|
||||||
Factory function for creating a PolynomialEquationsDataset.
|
|
||||||
Example usage:
|
|
||||||
dataset = polynomial_equations_dataset(min_degree=2, max_degree=3, ...)
|
|
||||||
"""
|
|
||||||
config = PolynomialEquationsConfig(
|
|
||||||
min_terms=min_terms,
|
|
||||||
max_terms=max_terms,
|
|
||||||
min_value=min_value,
|
|
||||||
max_value=max_value,
|
|
||||||
min_degree=min_degree,
|
|
||||||
max_degree=max_degree,
|
|
||||||
operators=operators,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return PolynomialEquationsDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from typing import Optional, Tuple
|
||||||
import sympy
|
import sympy
|
||||||
from sympy import Eq, Symbol, solve
|
from sympy import Eq, Symbol, solve
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -116,23 +116,4 @@ class SimpleEquationsDataset(ProceduralDataset):
|
||||||
return f"{left_side} = {right_side}", solution_value
|
return f"{left_side} = {right_side}", solution_value
|
||||||
|
|
||||||
|
|
||||||
def simple_equations_dataset(
|
register_dataset("simple_equations", SimpleEquationsDataset, SimpleEquationsConfig)
|
||||||
min_terms: int = 2,
|
|
||||||
max_terms: int = 5,
|
|
||||||
min_value: int = 1,
|
|
||||||
max_value: int = 100,
|
|
||||||
operators: tuple = ("+", "-", "*"),
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> SimpleEquationsDataset:
|
|
||||||
"""Create a SimpleEquationsDataset with the given configuration"""
|
|
||||||
config = SimpleEquationsConfig(
|
|
||||||
min_terms=min_terms,
|
|
||||||
max_terms=max_terms,
|
|
||||||
min_value=min_value,
|
|
||||||
max_value=max_value,
|
|
||||||
operators=operators,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return SimpleEquationsDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -6,31 +6,21 @@ Algorithmic tasks for training reasoning capabilities:
|
||||||
- Pattern matching
|
- Pattern matching
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from reasoning_gym.arithmetic.basic_arithmetic import basic_arithmetic_dataset
|
from .base_conversion import BaseConversionConfig, BaseConversionDataset
|
||||||
from reasoning_gym.arithmetic.chain_sum import chain_sum_dataset
|
from .letter_counting import LetterCountingConfig, LetterCountingDataset
|
||||||
|
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
|
||||||
from .base_conversion import BaseConversionConfig, BaseConversionDataset, base_conversion_dataset
|
from .number_sorting import NumberSortingConfig, NumberSortingDataset
|
||||||
from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset
|
from .word_reversal import WordReversalConfig, WordReversalDataset
|
||||||
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset, number_filtering_dataset
|
|
||||||
from .number_sorting import NumberSortingConfig, NumberSortingDataset, number_sorting_dataset
|
|
||||||
from .word_reversal import WordReversalConfig, WordReversalDataset, word_reversal_dataset
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"basic_arithmetic_dataset",
|
|
||||||
"BaseConversionConfig",
|
"BaseConversionConfig",
|
||||||
"BaseConversionDataset",
|
"BaseConversionDataset",
|
||||||
"base_conversion_dataset",
|
|
||||||
"chain_sum_dataset",
|
|
||||||
"LetterCountingConfig",
|
"LetterCountingConfig",
|
||||||
"LetterCountingDataset",
|
"LetterCountingDataset",
|
||||||
"letter_counting_dataset",
|
|
||||||
"NumberFilteringConfig",
|
"NumberFilteringConfig",
|
||||||
"NumberFilteringDataset",
|
"NumberFilteringDataset",
|
||||||
"number_filtering_dataset",
|
|
||||||
"NumberSortingConfig",
|
"NumberSortingConfig",
|
||||||
"NumberSortingDataset",
|
"NumberSortingDataset",
|
||||||
"number_sorting_dataset",
|
|
||||||
"WordReversalConfig",
|
"WordReversalConfig",
|
||||||
"WordReversalDataset",
|
"WordReversalDataset",
|
||||||
"word_reversal_dataset",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -88,21 +88,4 @@ class BaseConversionDataset(ProceduralDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def base_conversion_dataset(
|
register_dataset("base_conversion", BaseConversionDataset, BaseConversionConfig)
|
||||||
min_base: int = 2,
|
|
||||||
max_base: int = 16,
|
|
||||||
min_value: int = 0,
|
|
||||||
max_value: int = 1000,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> BaseConversionDataset:
|
|
||||||
"""Create a BaseConversionDataset with the given configuration."""
|
|
||||||
config = BaseConversionConfig(
|
|
||||||
min_base=min_base,
|
|
||||||
max_base=max_base,
|
|
||||||
min_value=min_value,
|
|
||||||
max_value=max_value,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return BaseConversionDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from typing import List, Optional
|
||||||
|
|
||||||
from reasoning_gym.data import read_data_file
|
from reasoning_gym.data import read_data_file
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -63,17 +63,4 @@ class LetterCountingDataset(ProceduralDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def letter_counting_dataset(
|
register_dataset("letter_counting", LetterCountingDataset, LetterCountingConfig)
|
||||||
min_words: int = 5,
|
|
||||||
max_words: int = 15,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> LetterCountingDataset:
|
|
||||||
"""Create a LetterCountingDataset with the given configuration."""
|
|
||||||
config = LetterCountingConfig(
|
|
||||||
min_words=min_words,
|
|
||||||
max_words=max_words,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return LetterCountingDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -98,25 +98,4 @@ class NumberFilteringDataset(ProceduralDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def number_filtering_dataset(
|
register_dataset("number_filtering", NumberFilteringDataset, NumberFilteringConfig)
|
||||||
min_numbers: int = 3,
|
|
||||||
max_numbers: int = 10,
|
|
||||||
min_decimals: int = 0,
|
|
||||||
max_decimals: int = 4,
|
|
||||||
min_value: float = -100.0,
|
|
||||||
max_value: float = 100.0,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> NumberFilteringDataset:
|
|
||||||
"""Create a NumberFilteringDataset with the given configuration."""
|
|
||||||
config = NumberFilteringConfig(
|
|
||||||
min_numbers=min_numbers,
|
|
||||||
max_numbers=max_numbers,
|
|
||||||
min_decimals=min_decimals,
|
|
||||||
max_decimals=max_decimals,
|
|
||||||
min_value=min_value,
|
|
||||||
max_value=max_value,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return NumberFilteringDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -86,25 +86,4 @@ class NumberSortingDataset(ProceduralDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def number_sorting_dataset(
|
register_dataset("number_sorting", NumberSortingDataset, NumberSortingConfig)
|
||||||
min_numbers: int = 3,
|
|
||||||
max_numbers: int = 10,
|
|
||||||
min_decimals: int = 0,
|
|
||||||
max_decimals: int = 2,
|
|
||||||
min_value: float = -100.0,
|
|
||||||
max_value: float = 100.0,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> NumberSortingDataset:
|
|
||||||
"""Create a NumberSortingDataset with the given configuration."""
|
|
||||||
config = NumberSortingConfig(
|
|
||||||
min_numbers=min_numbers,
|
|
||||||
max_numbers=max_numbers,
|
|
||||||
min_decimals=min_decimals,
|
|
||||||
max_decimals=max_decimals,
|
|
||||||
min_value=min_value,
|
|
||||||
max_value=max_value,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return NumberSortingDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from random import Random
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from ..data import read_data_file
|
from ..data import read_data_file
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -55,17 +55,4 @@ class WordReversalDataset(ProceduralDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def word_reversal_dataset(
|
register_dataset("word_reversal", WordReversalDataset, WordReversalConfig)
|
||||||
min_words: int = 3,
|
|
||||||
max_words: int = 8,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> WordReversalDataset:
|
|
||||||
"""Create a WordReversalDataset with the given configuration."""
|
|
||||||
config = WordReversalConfig(
|
|
||||||
min_words=min_words,
|
|
||||||
max_words=max_words,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return WordReversalDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -6,17 +6,13 @@ Arithmetic tasks for training reasoning capabilities:
|
||||||
- Leg counting
|
- Leg counting
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig, basic_arithmetic_dataset
|
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
|
||||||
from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset
|
from .chain_sum import ChainSum, ChainSumConfig
|
||||||
from .fraction_simplification import (
|
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
|
||||||
FractionSimplificationConfig,
|
from .gcd import GCDConfig, GCDDataset
|
||||||
FractionSimplificationDataset,
|
from .lcm import LCMConfig, LCMDataset
|
||||||
fraction_simplification_dataset,
|
from .leg_counting import LegCountingConfig, LegCountingDataset
|
||||||
)
|
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset
|
||||||
from .gcd import GCDConfig, GCDDataset, gcd_dataset
|
|
||||||
from .lcm import LCMConfig, LCMDataset, lcm_dataset
|
|
||||||
from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset
|
|
||||||
from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset, prime_factorization_dataset
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BasicArithmeticDataset",
|
"BasicArithmeticDataset",
|
||||||
|
|
@ -24,20 +20,14 @@ __all__ = [
|
||||||
"basic_arithmetic_dataset",
|
"basic_arithmetic_dataset",
|
||||||
"ChainSum",
|
"ChainSum",
|
||||||
"ChainSumConfig",
|
"ChainSumConfig",
|
||||||
"chain_sum_dataset",
|
|
||||||
"FractionSimplificationConfig",
|
"FractionSimplificationConfig",
|
||||||
"FractionSimplificationDataset",
|
"FractionSimplificationDataset",
|
||||||
"fraction_simplification_dataset",
|
|
||||||
"GCDConfig",
|
"GCDConfig",
|
||||||
"GCDDataset",
|
"GCDDataset",
|
||||||
"gcd_dataset",
|
|
||||||
"LCMConfig",
|
"LCMConfig",
|
||||||
"LCMDataset",
|
"LCMDataset",
|
||||||
"lcm_dataset",
|
|
||||||
"LegCountingConfig",
|
"LegCountingConfig",
|
||||||
"LegCountingDataset",
|
"LegCountingDataset",
|
||||||
"leg_counting_dataset",
|
|
||||||
"PrimeFactorizationConfig",
|
"PrimeFactorizationConfig",
|
||||||
"PrimeFactorizationDataset",
|
"PrimeFactorizationDataset",
|
||||||
"prime_factorization_dataset",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -231,47 +231,5 @@ class BasicArithmeticDataset(ProceduralDataset):
|
||||||
return rng.choice(templates).format(expression)
|
return rng.choice(templates).format(expression)
|
||||||
|
|
||||||
|
|
||||||
def basic_arithmetic_dataset(
|
# Register the dataset
|
||||||
min_terms: int = 2,
|
register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig)
|
||||||
max_terms: int = 6,
|
|
||||||
min_digits: int = 1,
|
|
||||||
max_digits: int = 4,
|
|
||||||
operators: list[str] = ("+", "-", "*", "/"),
|
|
||||||
allow_parentheses: bool = True,
|
|
||||||
allow_negation: bool = True,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
format_style: Literal["simple", "natural"] = "simple",
|
|
||||||
whitespace: Literal["no_space", "single", "random"] = "single",
|
|
||||||
) -> BasicArithmeticDataset:
|
|
||||||
"""Create a BasicArithmeticDataset with the given configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
min_terms: Minimum number of terms in expressions
|
|
||||||
max_terms: Maximum number of terms in expressions
|
|
||||||
min_digits: Minimum number of digits in numbers
|
|
||||||
max_digits: Maximum number of digits in numbers
|
|
||||||
operators: List of operators to use ("+", "-", "*")
|
|
||||||
allow_parentheses: Whether to allow parentheses in expressions
|
|
||||||
allow_negation: Whether to allow negative numbers
|
|
||||||
seed: Random seed for reproducibility
|
|
||||||
size: Virtual size of the dataset
|
|
||||||
format_style: Style of question formatting ("simple" or "natural")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BasicArithmeticDataset: Configured dataset instance
|
|
||||||
"""
|
|
||||||
config = BasicArithmeticDatasetConfig(
|
|
||||||
min_terms=min_terms,
|
|
||||||
max_terms=max_terms,
|
|
||||||
min_digits=min_digits,
|
|
||||||
max_digits=max_digits,
|
|
||||||
operators=operators,
|
|
||||||
allow_parentheses=allow_parentheses,
|
|
||||||
allow_negation=allow_negation,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
format_style=format_style,
|
|
||||||
whitespace=whitespace,
|
|
||||||
)
|
|
||||||
return BasicArithmeticDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,7 @@ import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
from ..factory import register_dataset
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -109,40 +108,5 @@ class ChainSum(ProceduralDataset):
|
||||||
return expression, result
|
return expression, result
|
||||||
|
|
||||||
|
|
||||||
def chain_sum_dataset(
|
|
||||||
min_terms: int = 2,
|
|
||||||
max_terms: int = 6,
|
|
||||||
min_digits: int = 1,
|
|
||||||
max_digits: int = 4,
|
|
||||||
allow_negation: bool = False,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> ChainSum:
|
|
||||||
"""Create a ChainSum dataset with the given configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
min_terms: Minimum number of terms in expressions
|
|
||||||
max_terms: Maximum number of terms in expressions
|
|
||||||
min_digits: Minimum number of digits in numbers
|
|
||||||
max_digits: Maximum number of digits in numbers
|
|
||||||
allow_negation: Whether to allow negative numbers
|
|
||||||
seed: Random seed for reproducibility
|
|
||||||
size: Virtual size of the dataset
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ChainSum: Configured dataset instance
|
|
||||||
"""
|
|
||||||
config = ChainSumConfig(
|
|
||||||
min_terms=min_terms,
|
|
||||||
max_terms=max_terms,
|
|
||||||
min_digits=min_digits,
|
|
||||||
max_digits=max_digits,
|
|
||||||
allow_negation=allow_negation,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return ChainSum(config)
|
|
||||||
|
|
||||||
|
|
||||||
# Register the dataset
|
# Register the dataset
|
||||||
register_dataset("chain_sum", ChainSum, ChainSumConfig)
|
register_dataset("chain_sum", ChainSum, ChainSumConfig)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from math import gcd
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Optional, Sequence, Tuple
|
from typing import Optional, Sequence, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -121,23 +121,4 @@ class FractionSimplificationDataset(ProceduralDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def fraction_simplification_dataset(
|
register_dataset("fraction_simplification", FractionSimplificationDataset, FractionSimplificationConfig)
|
||||||
min_value: int = 1,
|
|
||||||
max_value: int = 100,
|
|
||||||
min_factor: int = 2,
|
|
||||||
max_factor: int = 10,
|
|
||||||
styles: Sequence[str] = ("plain", "latex_inline", "latex_frac", "latex_dfrac"),
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> FractionSimplificationDataset:
|
|
||||||
"""Create a FractionSimplificationDataset with the given configuration."""
|
|
||||||
config = FractionSimplificationConfig(
|
|
||||||
min_value=min_value,
|
|
||||||
max_value=max_value,
|
|
||||||
min_factor=min_factor,
|
|
||||||
max_factor=max_factor,
|
|
||||||
styles=styles,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return FractionSimplificationDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from math import gcd
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -63,21 +63,4 @@ class GCDDataset(ProceduralDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def gcd_dataset(
|
register_dataset("gcd", GCDDataset, GCDConfig)
|
||||||
min_numbers: int = 2,
|
|
||||||
max_numbers: int = 2,
|
|
||||||
min_value: int = 1,
|
|
||||||
max_value: int = 10_000,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> GCDDataset:
|
|
||||||
"""Create a GCDDataset with the given configuration."""
|
|
||||||
config = GCDConfig(
|
|
||||||
min_numbers=min_numbers,
|
|
||||||
max_numbers=max_numbers,
|
|
||||||
min_value=min_value,
|
|
||||||
max_value=max_value,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return GCDDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from math import lcm
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -66,21 +66,4 @@ class LCMDataset(ProceduralDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def lcm_dataset(
|
register_dataset("lcm", LCMDataset, LCMConfig)
|
||||||
min_numbers: int = 2,
|
|
||||||
max_numbers: int = 2,
|
|
||||||
min_value: int = 1,
|
|
||||||
max_value: int = 100,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> LCMDataset:
|
|
||||||
"""Create a LCMDataset with the given configuration."""
|
|
||||||
config = LCMConfig(
|
|
||||||
min_numbers=min_numbers,
|
|
||||||
max_numbers=max_numbers,
|
|
||||||
min_value=min_value,
|
|
||||||
max_value=max_value,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return LCMDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
ANIMALS = {
|
ANIMALS = {
|
||||||
# Animals with 0 legs
|
# Animals with 0 legs
|
||||||
|
|
@ -115,19 +115,4 @@ class LegCountingDataset(ProceduralDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def leg_counting_dataset(
|
register_dataset("leg_counting", LegCountingDataset, LegCountingConfig)
|
||||||
min_animals: int = 2,
|
|
||||||
max_animals: int = 5,
|
|
||||||
max_instances: int = 3,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> LegCountingDataset:
|
|
||||||
"""Create a LegCountingDataset with the given configuration."""
|
|
||||||
config = LegCountingConfig(
|
|
||||||
min_animals=min_animals,
|
|
||||||
max_animals=max_animals,
|
|
||||||
max_instances=max_instances,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return LegCountingDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -66,17 +66,4 @@ class PrimeFactorizationDataset(ProceduralDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def prime_factorization_dataset(
|
register_dataset("prime_factorization", PrimeFactorizationDataset, PrimeFactorizationConfig)
|
||||||
min_value: int = 2,
|
|
||||||
max_value: int = 1000,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> PrimeFactorizationDataset:
|
|
||||||
"""Create a PrimeFactorizationDataset with the given configuration."""
|
|
||||||
config = PrimeFactorizationConfig(
|
|
||||||
min_value=min_value,
|
|
||||||
max_value=max_value,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return PrimeFactorizationDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -6,14 +6,12 @@ Cognition tasks for training reasoning capabilities:
|
||||||
- Working memory
|
- Working memory
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset, color_cube_rotation_dataset
|
from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset
|
||||||
from .number_sequences import NumberSequenceConfig, NumberSequenceDataset, number_sequence_dataset
|
from .number_sequences import NumberSequenceConfig, NumberSequenceDataset
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"NumberSequenceConfig",
|
"NumberSequenceConfig",
|
||||||
"NumberSequenceDataset",
|
"NumberSequenceDataset",
|
||||||
"number_sequence_dataset",
|
|
||||||
"ColorCubeRotationConfig",
|
"ColorCubeRotationConfig",
|
||||||
"ColorCubeRotationDataset",
|
"ColorCubeRotationDataset",
|
||||||
"color_cube_rotation_dataset",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from dataclasses import dataclass
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
class Color(StrEnum):
|
class Color(StrEnum):
|
||||||
|
|
@ -189,17 +189,4 @@ class ColorCubeRotationDataset(ProceduralDataset):
|
||||||
return "\n".join(story_parts)
|
return "\n".join(story_parts)
|
||||||
|
|
||||||
|
|
||||||
def color_cube_rotation_dataset(
|
register_dataset("color_cube_rotation", ColorCubeRotationDataset, ColorCubeRotationConfig)
|
||||||
min_rotations: int = 1,
|
|
||||||
max_rotations: int = 3,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> ColorCubeRotationDataset:
|
|
||||||
"""Create a ColorCubeRotationDataset with the given configuration"""
|
|
||||||
config = ColorCubeRotationConfig(
|
|
||||||
min_rotations=min_rotations,
|
|
||||||
max_rotations=max_rotations,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return ColorCubeRotationDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from enum import StrEnum
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
class Operation(StrEnum):
|
class Operation(StrEnum):
|
||||||
|
|
@ -198,23 +198,4 @@ class NumberSequenceDataset(ProceduralDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def number_sequence_dataset(
|
register_dataset("number_sequence", NumberSequenceDataset, NumberSequenceConfig)
|
||||||
min_terms: int = 4,
|
|
||||||
max_terms: int = 8,
|
|
||||||
min_value: int = -100,
|
|
||||||
max_value: int = 100,
|
|
||||||
max_complexity: int = 3,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> NumberSequenceDataset:
|
|
||||||
"""Create a NumberSequenceDataset with the given configuration."""
|
|
||||||
config = NumberSequenceConfig(
|
|
||||||
min_terms=min_terms,
|
|
||||||
max_terms=max_terms,
|
|
||||||
min_value=min_value,
|
|
||||||
max_value=max_value,
|
|
||||||
max_complexity=max_complexity,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return NumberSequenceDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -1,75 +1,58 @@
|
||||||
from dataclasses import is_dataclass
|
from dataclasses import is_dataclass
|
||||||
from typing import Any, Dict, Type, TypeVar
|
from typing import Dict, Type, TypeVar
|
||||||
|
|
||||||
from .dataset import ProceduralDataset
|
from .dataset import ProceduralDataset
|
||||||
|
|
||||||
# Type variables for generic type hints
|
# Type variables for generic type hints
|
||||||
ConfigT = TypeVar('ConfigT')
|
ConfigT = TypeVar("ConfigT")
|
||||||
DatasetT = TypeVar('DatasetT', bound=ProceduralDataset)
|
DatasetT = TypeVar("DatasetT", bound=ProceduralDataset)
|
||||||
|
|
||||||
# Global registry of datasets
|
# Global registry of datasets
|
||||||
_DATASETS: Dict[str, tuple[Type[ProceduralDataset], Type]] = {}
|
_DATASETS: Dict[str, tuple[Type[ProceduralDataset], Type]] = {}
|
||||||
|
|
||||||
def register_dataset(
|
|
||||||
name: str,
|
def register_dataset(name: str, dataset_cls: Type[DatasetT], config_cls: Type[ConfigT]) -> None:
|
||||||
dataset_cls: Type[DatasetT],
|
|
||||||
config_cls: Type[ConfigT]
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Register a dataset class with its configuration class.
|
Register a dataset class with its configuration class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Unique identifier for the dataset
|
name: Unique identifier for the dataset
|
||||||
dataset_cls: Class derived from ProceduralDataset
|
dataset_cls: Class derived from ProceduralDataset
|
||||||
config_cls: Configuration dataclass for the dataset
|
config_cls: Configuration dataclass for the dataset
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If name is already registered or invalid types provided
|
ValueError: If name is already registered or invalid types provided
|
||||||
"""
|
"""
|
||||||
if name in _DATASETS:
|
if name in _DATASETS:
|
||||||
raise ValueError(f"Dataset '{name}' is already registered")
|
raise ValueError(f"Dataset '{name}' is already registered")
|
||||||
|
|
||||||
if not issubclass(dataset_cls, ProceduralDataset):
|
if not issubclass(dataset_cls, ProceduralDataset):
|
||||||
raise ValueError(
|
raise ValueError(f"Dataset class must inherit from ProceduralDataset, got {dataset_cls}")
|
||||||
f"Dataset class must inherit from ProceduralDataset, got {dataset_cls}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_dataclass(config_cls):
|
if not is_dataclass(config_cls):
|
||||||
raise ValueError(
|
raise ValueError(f"Config class must be a dataclass, got {config_cls}")
|
||||||
f"Config class must be a dataclass, got {config_cls}"
|
|
||||||
)
|
|
||||||
|
|
||||||
_DATASETS[name] = (dataset_cls, config_cls)
|
_DATASETS[name] = (dataset_cls, config_cls)
|
||||||
|
|
||||||
def create_dataset(
|
|
||||||
name: str,
|
def create_dataset(name: str, **kwargs) -> ProceduralDataset:
|
||||||
config: Any,
|
|
||||||
seed: int = None,
|
|
||||||
size: int = 500
|
|
||||||
) -> ProceduralDataset:
|
|
||||||
"""
|
"""
|
||||||
Create a dataset instance by name with the given configuration.
|
Create a dataset instance by name with the given configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Registered dataset name
|
name: Registered dataset name
|
||||||
config: Configuration instance for the dataset
|
|
||||||
seed: Optional random seed
|
|
||||||
size: Size of the dataset (default: 500)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Configured dataset instance
|
Configured dataset instance
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If dataset not found or config type mismatch
|
ValueError: If dataset not found or config type mismatch
|
||||||
"""
|
"""
|
||||||
if name not in _DATASETS:
|
if name not in _DATASETS:
|
||||||
raise ValueError(f"Dataset '{name}' not found")
|
raise ValueError(f"Dataset '{name}' not registered")
|
||||||
|
|
||||||
dataset_cls, config_cls = _DATASETS[name]
|
dataset_cls, config_cls = _DATASETS[name]
|
||||||
|
|
||||||
if not isinstance(config, config_cls):
|
conifg = config_cls(**kwargs)
|
||||||
raise ValueError(
|
|
||||||
f"Config must be instance of {config_cls.__name__}, got {type(config).__name__}"
|
return dataset_cls(config=conifg)
|
||||||
)
|
|
||||||
|
|
||||||
return dataset_cls(config=config, seed=seed, size=size)
|
|
||||||
|
|
|
||||||
|
|
@ -5,18 +5,15 @@ Game tasks for training reasoning capabilities:
|
||||||
- Strategy games
|
- Strategy games
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .maze import MazeConfig, MazeDataset, maze_dataset
|
from .maze import MazeConfig, MazeDataset
|
||||||
from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset, mini_sudoku_dataset
|
from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
|
||||||
from .sudoku import SudokuConfig, SudokuDataset, sudoku_dataset
|
from .sudoku import SudokuConfig, SudokuDataset
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MiniSudokuConfig",
|
"MiniSudokuConfig",
|
||||||
"MiniSudokuDataset",
|
"MiniSudokuDataset",
|
||||||
"mini_sudoku_dataset",
|
|
||||||
"SudokuConfig",
|
"SudokuConfig",
|
||||||
"SudokuDataset",
|
"SudokuDataset",
|
||||||
"sudoku_dataset",
|
|
||||||
"MazeConfig",
|
"MazeConfig",
|
||||||
"MazeDataset",
|
"MazeDataset",
|
||||||
"maze_dataset",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import string
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -47,7 +47,6 @@ class MazeDataset(ProceduralDataset):
|
||||||
num_retries=1000,
|
num_retries=1000,
|
||||||
):
|
):
|
||||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
self.config = config
|
|
||||||
# Probability that a cell is a path instead of a wall
|
# Probability that a cell is a path instead of a wall
|
||||||
self.prob_path = prob_path
|
self.prob_path = prob_path
|
||||||
# Number of times to resample a grid to find a suitable maze before giving up
|
# Number of times to resample a grid to find a suitable maze before giving up
|
||||||
|
|
@ -184,21 +183,4 @@ class MazeDataset(ProceduralDataset):
|
||||||
return "\n".join("".join(row) for row in grid)
|
return "\n".join("".join(row) for row in grid)
|
||||||
|
|
||||||
|
|
||||||
def maze_dataset(
|
register_dataset("maze", MazeDataset, MazeConfig)
|
||||||
min_dist: int = 5,
|
|
||||||
max_dist: int = 10,
|
|
||||||
min_grid_size: int = 5,
|
|
||||||
max_grid_size: int = 10,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 50,
|
|
||||||
) -> MazeDataset:
|
|
||||||
"""Convenient function to create a MazeDataset."""
|
|
||||||
config = MazeConfig(
|
|
||||||
min_dist=min_dist,
|
|
||||||
max_dist=max_dist,
|
|
||||||
min_grid_size=min_grid_size,
|
|
||||||
max_grid_size=max_grid_size,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return MazeDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
"""Mini Sudoku (4x4) puzzle generator"""
|
"""Mini Sudoku (4x4) puzzle generator"""
|
||||||
|
|
||||||
import random
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Set, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -21,13 +22,11 @@ class MiniSudokuConfig:
|
||||||
assert self.min_empty <= self.max_empty <= 16, "max_empty must be between min_empty and 16"
|
assert self.min_empty <= self.max_empty <= 16, "max_empty must be between min_empty and 16"
|
||||||
|
|
||||||
|
|
||||||
class MiniSudokuDataset:
|
class MiniSudokuDataset(ProceduralDataset):
|
||||||
"""Generates 4x4 sudoku puzzles with configurable difficulty"""
|
"""Generates 4x4 sudoku puzzles with configurable difficulty"""
|
||||||
|
|
||||||
def __init__(self, config: MiniSudokuConfig):
|
def __init__(self, config: MiniSudokuConfig):
|
||||||
self.config = config
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
self.config.validate()
|
|
||||||
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.config.size
|
return self.config.size
|
||||||
|
|
@ -149,17 +148,4 @@ class MiniSudokuDataset:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def mini_sudoku_dataset(
|
register_dataset("mini_sudoku", MiniSudokuDataset, MiniSudokuConfig)
|
||||||
min_empty: int = 8,
|
|
||||||
max_empty: int = 12,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> MiniSudokuDataset:
|
|
||||||
"""Create a MiniSudokuDataset with the given configuration."""
|
|
||||||
config = MiniSudokuConfig(
|
|
||||||
min_empty=min_empty,
|
|
||||||
max_empty=max_empty,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return MiniSudokuDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
"""Sudoku puzzle generator"""
|
"""Sudoku puzzle generator"""
|
||||||
|
|
||||||
import random
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Set, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -21,13 +22,11 @@ class SudokuConfig:
|
||||||
assert self.min_empty <= self.max_empty <= 81, "max_empty must be between min_empty and 81"
|
assert self.min_empty <= self.max_empty <= 81, "max_empty must be between min_empty and 81"
|
||||||
|
|
||||||
|
|
||||||
class SudokuDataset:
|
class SudokuDataset(ProceduralDataset):
|
||||||
"""Generates sudoku puzzles with configurable difficulty"""
|
"""Generates sudoku puzzles with configurable difficulty"""
|
||||||
|
|
||||||
def __init__(self, config: SudokuConfig):
|
def __init__(self, config: SudokuConfig):
|
||||||
self.config = config
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
self.config.validate()
|
|
||||||
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.config.size
|
return self.config.size
|
||||||
|
|
@ -139,17 +138,4 @@ class SudokuDataset:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def sudoku_dataset(
|
register_dataset("sudoku", SudokuDataset, SudokuConfig)
|
||||||
min_empty: int = 30,
|
|
||||||
max_empty: int = 50,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> SudokuDataset:
|
|
||||||
"""Create a SudokuDataset with the given configuration."""
|
|
||||||
config = SudokuConfig(
|
|
||||||
min_empty=min_empty,
|
|
||||||
max_empty=max_empty,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return SudokuDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,3 @@
|
||||||
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset, family_relationships_dataset
|
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset
|
||||||
|
|
||||||
__all__ = ["FamilyRelationshipsDataset", "FamilyRelationshipsConfig", "family_relationships_dataset"]
|
__all__ = ["FamilyRelationshipsDataset", "FamilyRelationshipsConfig"]
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from enum import StrEnum
|
||||||
from itertools import count
|
from itertools import count
|
||||||
from typing import Dict, List, Optional, Set, Tuple
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
class Gender(StrEnum):
|
class Gender(StrEnum):
|
||||||
|
|
@ -310,21 +310,4 @@ class FamilyRelationshipsDataset(ProceduralDataset):
|
||||||
return " ".join(story_parts)
|
return " ".join(story_parts)
|
||||||
|
|
||||||
|
|
||||||
def family_relationships_dataset(
|
register_dataset("family_relationships", FamilyRelationshipsDataset, FamilyRelationshipsConfig)
|
||||||
min_family_size: int = 4,
|
|
||||||
max_family_size: int = 8,
|
|
||||||
male_names: List[str] = None,
|
|
||||||
female_names: List[str] = None,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> FamilyRelationshipsDataset:
|
|
||||||
"""Create a FamilyRelationshipsDataset with the given configuration"""
|
|
||||||
config = FamilyRelationshipsConfig(
|
|
||||||
min_family_size=min_family_size,
|
|
||||||
max_family_size=max_family_size,
|
|
||||||
male_names=male_names,
|
|
||||||
female_names=female_names,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return FamilyRelationshipsDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -6,13 +6,12 @@ Logic tasks for training reasoning capabilities:
|
||||||
- Syllogisms
|
- Syllogisms
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset, propositional_logic_dataset
|
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset
|
||||||
from .syllogisms import SyllogismConfig, SyllogismDataset, Term, syllogism_dataset
|
from .syllogisms import SyllogismConfig, SyllogismDataset, Term
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PropositionalLogicConfig",
|
"PropositionalLogicConfig",
|
||||||
"PropositionalLogicDataset",
|
"PropositionalLogicDataset",
|
||||||
"propositional_logic_dataset",
|
|
||||||
"SyllogismConfig",
|
"SyllogismConfig",
|
||||||
"SyllogismDataset",
|
"SyllogismDataset",
|
||||||
"syllogism_dataset",
|
"syllogism_dataset",
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ from enum import StrEnum
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Any, List, Optional, Set
|
from typing import Any, List, Optional, Set
|
||||||
|
|
||||||
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
class Operator(StrEnum):
|
class Operator(StrEnum):
|
||||||
"""Basic logical operators"""
|
"""Basic logical operators"""
|
||||||
|
|
@ -70,13 +72,11 @@ class Expression:
|
||||||
return f"({self.left} {self.operator.value} {self.right})"
|
return f"({self.left} {self.operator.value} {self.right})"
|
||||||
|
|
||||||
|
|
||||||
class PropositionalLogicDataset:
|
class PropositionalLogicDataset(ProceduralDataset):
|
||||||
"""Generates propositional logic reasoning tasks"""
|
"""Generates propositional logic reasoning tasks"""
|
||||||
|
|
||||||
def __init__(self, config: PropositionalLogicConfig):
|
def __init__(self, config: PropositionalLogicConfig):
|
||||||
self.config = config
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
self.config.validate()
|
|
||||||
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.config.size
|
return self.config.size
|
||||||
|
|
@ -199,23 +199,4 @@ class PropositionalLogicDataset:
|
||||||
return 1 + self._measure_complexity(expression.left) + self._measure_complexity(expression.right)
|
return 1 + self._measure_complexity(expression.left) + self._measure_complexity(expression.right)
|
||||||
|
|
||||||
|
|
||||||
def propositional_logic_dataset(
|
register_dataset("propositional_logic", PropositionalLogicDataset, PropositionalLogicConfig)
|
||||||
min_vars: int = 2,
|
|
||||||
max_vars: int = 4,
|
|
||||||
min_statements: int = 2,
|
|
||||||
max_statements: int = 4,
|
|
||||||
max_complexity: int = 3,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> PropositionalLogicDataset:
|
|
||||||
"""Create a PropositionalLogicDataset with the given configuration."""
|
|
||||||
config = PropositionalLogicConfig(
|
|
||||||
min_vars=min_vars,
|
|
||||||
max_vars=max_vars,
|
|
||||||
min_statements=min_statements,
|
|
||||||
max_statements=max_statements,
|
|
||||||
max_complexity=max_complexity,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return PropositionalLogicDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from enum import Enum
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..dataset import ProceduralDataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
class Quantifier(Enum):
|
class Quantifier(Enum):
|
||||||
|
|
@ -256,27 +256,4 @@ class SyllogismDataset(ProceduralDataset):
|
||||||
return self._generate_syllogism(rng)
|
return self._generate_syllogism(rng)
|
||||||
|
|
||||||
|
|
||||||
def syllogism_dataset(
|
register_dataset("syllogism", SyllogismDataset, SyllogismConfig)
|
||||||
terms: List[Term] = None,
|
|
||||||
allow_all: bool = True,
|
|
||||||
allow_no: bool = True,
|
|
||||||
allow_some: bool = True,
|
|
||||||
allow_some_not: bool = True,
|
|
||||||
include_invalid: bool = True,
|
|
||||||
invalid_ratio: float = 0.3,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
size: int = 500,
|
|
||||||
) -> SyllogismDataset:
|
|
||||||
"""Create a SyllogismDataset with the given configuration."""
|
|
||||||
config = SyllogismConfig(
|
|
||||||
terms=terms,
|
|
||||||
allow_all=allow_all,
|
|
||||||
allow_no=allow_no,
|
|
||||||
allow_some=allow_some,
|
|
||||||
allow_some_not=allow_some_not,
|
|
||||||
include_invalid=include_invalid,
|
|
||||||
invalid_ratio=invalid_ratio,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
return SyllogismDataset(config)
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,12 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.cognition.color_cube_rotation import Color, Cube, Side, color_cube_rotation_dataset
|
from reasoning_gym import create_dataset
|
||||||
|
from reasoning_gym.cognition.color_cube_rotation import Color, ColorCubeRotationDataset, Cube, Side
|
||||||
|
|
||||||
|
|
||||||
def test_color_cube_rotation_generation():
|
def test_color_cube_rotation_generation():
|
||||||
dataset = color_cube_rotation_dataset(seed=42, size=10)
|
dataset = create_dataset("color_cube_rotation", seed=42, size=10)
|
||||||
|
assert isinstance(dataset, ColorCubeRotationDataset)
|
||||||
|
|
||||||
for item in dataset:
|
for item in dataset:
|
||||||
# Check required keys exist
|
# Check required keys exist
|
||||||
|
|
@ -33,15 +35,15 @@ def test_color_cube_rotation_generation():
|
||||||
def test_color_cube_rotation_config():
|
def test_color_cube_rotation_config():
|
||||||
# Test invalid config raises assertion
|
# Test invalid config raises assertion
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
dataset = color_cube_rotation_dataset(min_rotations=0)
|
dataset = create_dataset("color_cube_rotation", min_rotations=0)
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
dataset = color_cube_rotation_dataset(max_rotations=1, min_rotations=2)
|
dataset = create_dataset("color_cube_rotation", max_rotations=1, min_rotations=2)
|
||||||
|
|
||||||
|
|
||||||
def test_deterministic_generation():
|
def test_deterministic_generation():
|
||||||
dataset1 = color_cube_rotation_dataset(seed=42, size=5)
|
dataset1 = create_dataset("color_cube_rotation", seed=42, size=5)
|
||||||
dataset2 = color_cube_rotation_dataset(seed=42, size=5)
|
dataset2 = create_dataset("color_cube_rotation", seed=42, size=5)
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
assert dataset1[i]["question"] == dataset2[i]["question"]
|
assert dataset1[i]["question"] == dataset2[i]["question"]
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,12 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.graphs.family_relationships import Gender, Relationship, family_relationships_dataset
|
from reasoning_gym import create_dataset
|
||||||
|
from reasoning_gym.graphs.family_relationships import FamilyRelationshipsDataset, Relationship
|
||||||
|
|
||||||
|
|
||||||
def test_family_relationships_generation():
|
def test_family_relationships_generation():
|
||||||
dataset = family_relationships_dataset(seed=42, size=10)
|
dataset = create_dataset("family_relationships", seed=42, size=10)
|
||||||
|
assert isinstance(dataset, FamilyRelationshipsDataset)
|
||||||
|
|
||||||
for item in dataset:
|
for item in dataset:
|
||||||
# Check required keys exist
|
# Check required keys exist
|
||||||
|
|
@ -32,21 +34,21 @@ def test_family_relationships_generation():
|
||||||
def test_family_relationships_config():
|
def test_family_relationships_config():
|
||||||
# Test invalid config raises assertion
|
# Test invalid config raises assertion
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
dataset = family_relationships_dataset(min_family_size=2)
|
dataset = create_dataset("family_relationships", min_family_size=2)
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
dataset = family_relationships_dataset(max_family_size=3, min_family_size=4)
|
dataset = create_dataset("family_relationships", max_family_size=3, min_family_size=4)
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
dataset = family_relationships_dataset(male_names=[])
|
dataset = create_dataset("family_relationships", male_names=[])
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
dataset = family_relationships_dataset(female_names=[])
|
dataset = create_dataset("family_relationships", female_names=[])
|
||||||
|
|
||||||
|
|
||||||
def test_deterministic_generation():
|
def test_deterministic_generation():
|
||||||
dataset1 = family_relationships_dataset(seed=42, size=5)
|
dataset1 = create_dataset("family_relationships", seed=42, size=5)
|
||||||
dataset2 = family_relationships_dataset(seed=42, size=5)
|
dataset2 = create_dataset("family_relationships", seed=42, size=5)
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
assert dataset1[i]["question"] == dataset2[i]["question"]
|
assert dataset1[i]["question"] == dataset2[i]["question"]
|
||||||
|
|
@ -54,7 +56,7 @@ def test_deterministic_generation():
|
||||||
|
|
||||||
|
|
||||||
def test_relationship_consistency():
|
def test_relationship_consistency():
|
||||||
dataset = family_relationships_dataset(seed=42, size=10)
|
dataset = create_dataset("family_relationships", seed=42, size=10)
|
||||||
|
|
||||||
for item in dataset:
|
for item in dataset:
|
||||||
# Check that relationship matches the gender
|
# Check that relationship matches the gender
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reasoning_gym.games.maze import MazeConfig, MazeDataset, maze_dataset
|
from reasoning_gym import create_dataset
|
||||||
|
from reasoning_gym.games.maze import MazeConfig, MazeDataset
|
||||||
|
|
||||||
|
|
||||||
def test_maze_config_validation():
|
def test_maze_config_validation():
|
||||||
|
|
@ -38,7 +39,8 @@ def test_maze_dataset_creation():
|
||||||
|
|
||||||
|
|
||||||
def test_maze_dataset_items():
|
def test_maze_dataset_items():
|
||||||
ds = maze_dataset(
|
ds = create_dataset(
|
||||||
|
"maze",
|
||||||
min_dist=3,
|
min_dist=3,
|
||||||
max_dist=5,
|
max_dist=5,
|
||||||
min_grid_size=5,
|
min_grid_size=5,
|
||||||
|
|
@ -62,7 +64,8 @@ def test_maze_shortest_path_correctness():
|
||||||
"""
|
"""
|
||||||
min_dist = 4
|
min_dist = 4
|
||||||
max_dist = 8
|
max_dist = 8
|
||||||
ds = maze_dataset(
|
ds = create_dataset(
|
||||||
|
"maze",
|
||||||
min_dist=min_dist,
|
min_dist=min_dist,
|
||||||
max_dist=max_dist,
|
max_dist=max_dist,
|
||||||
min_grid_size=5,
|
min_grid_size=5,
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,8 @@
|
||||||
import pytest
|
import pytest
|
||||||
from sympy import Symbol, sympify
|
from sympy import Symbol, sympify
|
||||||
|
|
||||||
from reasoning_gym.algebra.polynomial_equations import (
|
from reasoning_gym import create_dataset
|
||||||
PolynomialEquationsConfig,
|
from reasoning_gym.algebra.polynomial_equations import PolynomialEquationsConfig, PolynomialEquationsDataset
|
||||||
PolynomialEquationsDataset,
|
|
||||||
polynomial_equations_dataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_polynomial_config_validation():
|
def test_polynomial_config_validation():
|
||||||
|
|
@ -47,7 +44,8 @@ def test_polynomial_equations_dataset_basic():
|
||||||
|
|
||||||
def test_polynomial_equations_dataset_items():
|
def test_polynomial_equations_dataset_items():
|
||||||
"""Test that generated items have correct structure"""
|
"""Test that generated items have correct structure"""
|
||||||
ds = polynomial_equations_dataset(
|
ds = create_dataset(
|
||||||
|
"polynomial_equations",
|
||||||
min_terms=2,
|
min_terms=2,
|
||||||
max_terms=3,
|
max_terms=3,
|
||||||
min_value=1,
|
min_value=1,
|
||||||
|
|
@ -87,7 +85,8 @@ def test_polynomial_equations_dataset_deterministic():
|
||||||
|
|
||||||
def test_polynomial_solutions_evaluation():
|
def test_polynomial_solutions_evaluation():
|
||||||
"""Test that real_solutions satisfy the polynomial equation."""
|
"""Test that real_solutions satisfy the polynomial equation."""
|
||||||
ds = polynomial_equations_dataset(
|
ds = create_dataset(
|
||||||
|
"polynomial_equations",
|
||||||
min_terms=2,
|
min_terms=2,
|
||||||
max_terms=4,
|
max_terms=4,
|
||||||
min_value=1,
|
min_value=1,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue