add reasoning_gym.create_dataset({name}, ...) global factory function

This commit is contained in:
Andreas Koepf 2025-01-25 00:58:34 +01:00
parent 0d2d8ba6a0
commit 519e411fa5
35 changed files with 133 additions and 598 deletions

View file

@ -2,8 +2,7 @@ import random
from dataclasses import dataclass
from typing import Optional
from ..dataset import ProceduralDataset
from ..factory import register_dataset
from ..factory import ProceduralDataset, register_dataset
@dataclass
@ -109,40 +108,5 @@ class ChainSum(ProceduralDataset):
return expression, result
def chain_sum_dataset(
min_terms: int = 2,
max_terms: int = 6,
min_digits: int = 1,
max_digits: int = 4,
allow_negation: bool = False,
seed: Optional[int] = None,
size: int = 500,
) -> ChainSum:
"""Create a ChainSum dataset with the given configuration.
Args:
min_terms: Minimum number of terms in expressions
max_terms: Maximum number of terms in expressions
min_digits: Minimum number of digits in numbers
max_digits: Maximum number of digits in numbers
allow_negation: Whether to allow negative numbers
seed: Random seed for reproducibility
size: Virtual size of the dataset
Returns:
ChainSum: Configured dataset instance
"""
config = ChainSumConfig(
min_terms=min_terms,
max_terms=max_terms,
min_digits=min_digits,
max_digits=max_digits,
allow_negation=allow_negation,
seed=seed,
size=size,
)
return ChainSum(config)
# Register the dataset
register_dataset("chain_sum", ChainSum, ChainSumConfig)