feat: Add arithmetic_dataset() factory function to basic_arithmetic.py

This commit is contained in:
Andreas Koepf (aider) 2025-01-23 12:45:17 +01:00 committed by Andreas Koepf
parent 0aa35e15a3
commit 8d1dac9e62
3 changed files with 104 additions and 20 deletions

View file

@ -1,6 +1,6 @@
from dataclasses import dataclass
from random import Random
from typing import Optional, Literal, Any
from typing import Any, Literal, Optional
@dataclass
@ -69,7 +69,11 @@ class ArithmeticDataset:
return {
"question": question,
"answer": str(result),
"metadata": {"num_terms": num_terms, "num_digits": num_digits, "expression": expression},
"metadata": {
"num_terms": num_terms,
"num_digits": num_digits,
"expression": expression,
},
}
def _generate_complex_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]:
@ -161,3 +165,47 @@ class ArithmeticDataset:
# Use deterministic RNG for template selection
template_rng = Random(self.seed)
return template_rng.choice(templates).format(expression)
def arithmetic_dataset(
min_terms: int = 2,
max_terms: int = 6,
min_digits: int = 1,
max_digits: int = 4,
operators: list[str] = ("+", "-", "*"),
allow_parentheses: bool = True,
allow_negation: bool = True,
seed: Optional[int] = None,
size: int = 500,
format_style: Literal["simple", "natural"] = "simple",
) -> ArithmeticDataset:
"""Create an ArithmeticDataset with the given configuration.
Args:
min_terms: Minimum number of terms in expressions
max_terms: Maximum number of terms in expressions
min_digits: Minimum number of digits in numbers
max_digits: Maximum number of digits in numbers
operators: List of operators to use ("+", "-", "*")
allow_parentheses: Whether to allow parentheses in expressions
allow_negation: Whether to allow negative numbers
seed: Random seed for reproducibility
size: Virtual size of the dataset
format_style: Style of question formatting ("simple" or "natural")
Returns:
ArithmeticDataset: Configured dataset instance
"""
config = ArithmeticDatasetConfig(
min_terms=min_terms,
max_terms=max_terms,
min_digits=min_digits,
max_digits=max_digits,
operators=operators,
allow_parentheses=allow_parentheses,
allow_negation=allow_negation,
seed=seed,
size=size,
format_style=format_style,
)
return ArithmeticDataset(config)