mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
add reasoning_gym.create_dataset({name}, ...) global factory function
This commit is contained in:
parent
0d2d8ba6a0
commit
519e411fa5
35 changed files with 133 additions and 598 deletions
|
|
@ -2,7 +2,7 @@ from dataclasses import dataclass
|
|||
from random import Random
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from ..dataset import ProceduralDataset
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -231,47 +231,5 @@ class BasicArithmeticDataset(ProceduralDataset):
|
|||
return rng.choice(templates).format(expression)
|
||||
|
||||
|
||||
def basic_arithmetic_dataset(
    min_terms: int = 2,
    max_terms: int = 6,
    min_digits: int = 1,
    max_digits: int = 4,
    operators: list[str] = ("+", "-", "*", "/"),
    allow_parentheses: bool = True,
    allow_negation: bool = True,
    seed: Optional[int] = None,
    size: int = 500,
    format_style: Literal["simple", "natural"] = "simple",
    whitespace: Literal["no_space", "single", "random"] = "single",
) -> BasicArithmeticDataset:
    """Create a BasicArithmeticDataset with the given configuration.

    Every argument is forwarded unchanged to ``BasicArithmeticDatasetConfig``
    under the same keyword name.

    Args:
        min_terms: Minimum number of terms in expressions
        max_terms: Maximum number of terms in expressions
        min_digits: Minimum number of digits in numbers
        max_digits: Maximum number of digits in numbers
        operators: Operators to use; defaults to ("+", "-", "*", "/").
            The default is a tuple, so no mutable default is shared
            between calls.
        allow_parentheses: Whether to allow parentheses in expressions
        allow_negation: Whether to allow negative numbers
        seed: Random seed for reproducibility
        size: Virtual size of the dataset
        format_style: Style of question formatting ("simple" or "natural")
        whitespace: Whitespace style, one of "no_space", "single" or "random"

    Returns:
        BasicArithmeticDataset: Configured dataset instance
    """
    config = BasicArithmeticDatasetConfig(
        min_terms=min_terms,
        max_terms=max_terms,
        min_digits=min_digits,
        max_digits=max_digits,
        operators=operators,
        allow_parentheses=allow_parentheses,
        allow_negation=allow_negation,
        seed=seed,
        size=size,
        format_style=format_style,
        whitespace=whitespace,
    )
    return BasicArithmeticDataset(config)
|
||||
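# Register the dataset class and its config under the name "basic_arithmetic"
# (presumably so the factory can construct it by name -- confirm in ..factory)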
|
||||
register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue