mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
feat: Add arithmetic_dataset() factory function to basic_arithmetic.py
This commit is contained in:
parent
0aa35e15a3
commit
8d1dac9e62
3 changed files with 104 additions and 20 deletions
|
|
@ -1,6 +1,6 @@
|
|||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Optional, Literal, Any
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -69,7 +69,11 @@ class ArithmeticDataset:
|
|||
return {
|
||||
"question": question,
|
||||
"answer": str(result),
|
||||
"metadata": {"num_terms": num_terms, "num_digits": num_digits, "expression": expression},
|
||||
"metadata": {
|
||||
"num_terms": num_terms,
|
||||
"num_digits": num_digits,
|
||||
"expression": expression,
|
||||
},
|
||||
}
|
||||
|
||||
def _generate_complex_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]:
|
||||
|
|
@ -161,3 +165,47 @@ class ArithmeticDataset:
|
|||
# Use deterministic RNG for template selection
|
||||
template_rng = Random(self.seed)
|
||||
return template_rng.choice(templates).format(expression)
|
||||
|
||||
|
||||
def arithmetic_dataset(
    min_terms: int = 2,
    max_terms: int = 6,
    min_digits: int = 1,
    max_digits: int = 4,
    operators: list[str] = ("+", "-", "*"),
    allow_parentheses: bool = True,
    allow_negation: bool = True,
    seed: Optional[int] = None,
    size: int = 500,
    format_style: Literal["simple", "natural"] = "simple",
) -> ArithmeticDataset:
    """Build an ArithmeticDataset from individual configuration values.

    Convenience factory: every argument is forwarded unchanged, by keyword,
    to ``ArithmeticDatasetConfig``, and the resulting config is wrapped in
    an ``ArithmeticDataset``.

    Args:
        min_terms: Minimum number of terms in expressions.
        max_terms: Maximum number of terms in expressions.
        min_digits: Minimum number of digits in numbers.
        max_digits: Maximum number of digits in numbers.
        operators: Operators to use ("+", "-", "*").
        allow_parentheses: Whether to allow parentheses in expressions.
        allow_negation: Whether to allow negative numbers.
        seed: Random seed for reproducibility.
        size: Virtual size of the dataset.
        format_style: Style of question formatting ("simple" or "natural").

    Returns:
        ArithmeticDataset: Dataset instance built from the given settings.
    """
    # NOTE(review): `operators` is annotated as list[str] but its default is
    # a tuple; the value is passed through to the config as-is.
    settings = {
        "min_terms": min_terms,
        "max_terms": max_terms,
        "min_digits": min_digits,
        "max_digits": max_digits,
        "operators": operators,
        "allow_parentheses": allow_parentheses,
        "allow_negation": allow_negation,
        "seed": seed,
        "size": size,
        "format_style": format_style,
    }
    return ArithmeticDataset(ArithmeticDatasetConfig(**settings))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue