add reasoning_gym.create_dataset({name}, ...) global factory function

This commit is contained in:
Andreas Koepf 2025-01-25 00:58:34 +01:00
parent 0d2d8ba6a0
commit 519e411fa5
35 changed files with 133 additions and 598 deletions

View file

@ -2,7 +2,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Literal, Optional
from ..dataset import ProceduralDataset
from ..factory import ProceduralDataset, register_dataset
@dataclass
@ -231,47 +231,5 @@ class BasicArithmeticDataset(ProceduralDataset):
return rng.choice(templates).format(expression)
def basic_arithmetic_dataset(
min_terms: int = 2,
max_terms: int = 6,
min_digits: int = 1,
max_digits: int = 4,
operators: list[str] = ("+", "-", "*", "/"),
allow_parentheses: bool = True,
allow_negation: bool = True,
seed: Optional[int] = None,
size: int = 500,
format_style: Literal["simple", "natural"] = "simple",
whitespace: Literal["no_space", "single", "random"] = "single",
) -> BasicArithmeticDataset:
"""Create a BasicArithmeticDataset with the given configuration.
Args:
min_terms: Minimum number of terms in expressions
max_terms: Maximum number of terms in expressions
min_digits: Minimum number of digits in numbers
max_digits: Maximum number of digits in numbers
operators: List of operators to use ("+", "-", "*")
allow_parentheses: Whether to allow parentheses in expressions
allow_negation: Whether to allow negative numbers
seed: Random seed for reproducibility
size: Virtual size of the dataset
format_style: Style of question formatting ("simple" or "natural")
Returns:
BasicArithmeticDataset: Configured dataset instance
"""
config = BasicArithmeticDatasetConfig(
min_terms=min_terms,
max_terms=max_terms,
min_digits=min_digits,
max_digits=max_digits,
operators=operators,
allow_parentheses=allow_parentheses,
allow_negation=allow_negation,
seed=seed,
size=size,
format_style=format_style,
whitespace=whitespace,
)
return BasicArithmeticDataset(config)
# Register the dataset
register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig)