feat: Unify arithmetic task generation with configurable dataset class

This commit introduces a new `ArithmeticDataset` class that:
- Combines complex and simple arithmetic task generation approaches
- Provides configurable task generation via `ArithmeticDatasetConfig`
- Supports deterministic task generation
- Implements dataset interface for easy use with HuggingFace datasets
- Adds comprehensive test coverage for the new implementation

Changes include:
- Refactored `basic_arithmetic.py` to use a unified dataset approach
- Added configuration validation and flexible generation options
- Created test suite to validate dataset behavior
- Removed file generation methods in favor of in-memory generation
This commit is contained in:
Andreas Koepf (aider) 2025-01-23 11:30:09 +01:00
parent 8a6364a791
commit 40596262e1
2 changed files with 211 additions and 145 deletions

View file

@ -1,157 +1,155 @@
import json
from pathlib import Path
from dataclasses import dataclass
from random import Random
# more variability
def generate_math_task(
rng: Random, num_terms: int, num_digits: int, op: list[str] = ["+", "-", "*"]
) -> tuple[str, int]:
parts = []
def add_terms(remaining: int):
num_left = rng.randint(1, remaining)
num_right = remaining - num_left
if num_left > 1 and rng.random() > 0.5:
if rng.random() > 0.5:
parts.append("-(")
else:
parts.append("(")
add_terms(num_left)
parts.append(")")
else:
for i in range(num_left):
c = rng.randint(-(10**num_digits) + 1, 10**num_digits - 1)
parts.append(str(c))
if i + 1 < num_left:
parts.append(rng.choice(op))
if num_right > 0:
parts.append(rng.choice(op))
add_terms(num_right)
add_terms(num_terms)
space_parts = []
for p in parts:
while rng.random() < 0.15:
space_parts.append(" ")
space_parts.append(p)
term = " ".join(space_parts)
ground_truth = eval(term)
return term, ground_truth
from typing import Optional, Literal, Any
def generate_task_file():
rng = Random(42)
num_tasks = 100_000
i = 0
output_filename = "math_tasks.jsonl"
file_path = Path(output_filename)
with file_path.open("w", encoding="utf-8") as f:
while i < num_tasks:
num_terms = rng.randint(2, 6)
num_digits = rng.randint(1, 6)
term, ground_truth = generate_math_task(
rng, num_terms=num_terms, num_digits=num_digits
)
if abs(ground_truth) > 10**8 or abs(ground_truth) < 10:
continue
question_templates = [
"{0}",
"{0} =",
"{0} = ?",
"What is {0}?",
"Solve {0}",
]
template = rng.choice(question_templates)
formatted_task = template.format(term)
entry = {
"id": str(i),
"question": formatted_task,
"answer": str(ground_truth),
"num_terms": num_terms,
"num_digits": num_digits,
}
json.dump(entry, f)
f.write("\n")
i += 1
class BasicIntArithmeticTaskConfig:
def __init__(
self,
min_digits: int = 1,
max_digits: int = 5,
min_terms: int = 2,
max_terms: int = 8,
):
self.min_digits = min_digits
self.max_digits = max_digits
self.min_terms = min_terms
self.max_terms = max_terms
self.operators = ["+", "-"]
@dataclass
class ArithmeticDatasetConfig:
"""Configuration for arithmetic dataset generation"""
min_terms: int = 2
max_terms: int = 6
min_digits: int = 1
max_digits: int = 4
operators: list[str] = ("+" , "-", "*")
allow_parentheses: bool = True
allow_negation: bool = True
seed: Optional[int] = None
size: int = 10000 # Virtual dataset size
format_style: Literal["simple", "natural"] = "simple"
def validate(self):
assert self.min_digits > 0
assert self.max_digits >= self.min_digits
assert self.min_terms > 1
assert self.max_terms >= self.min_terms
assert len(self.operators) > 0
"""Validate configuration parameters"""
assert self.min_terms > 0, "min_terms must be positive"
assert self.max_terms >= self.min_terms, "max_terms must be >= min_terms"
assert self.min_digits > 0, "min_digits must be positive"
assert self.max_digits >= self.min_digits, "max_digits must be >= min_digits"
assert len(self.operators) > 0, "must provide at least one operator"
for op in self.operators:
assert op in ["+", "-", "*"], f"unsupported operator: {op}"
def generate_task(rng: Random, cfg: BasicIntArithmeticTaskConfig) -> str:
num_terms = rng.randint(cfg.min_terms, cfg.max_terms)
num_digits = rng.randint(cfg.min_digits, cfg.max_digits)
constants = [rng.randint(0, 10**num_digits) for _ in range(num_terms)]
operators = [rng.choice(cfg.operators) for _ in range(num_terms - 1)]
buffer = []
ground_truth = constants[0]
buffer.append(f"{constants[0]}")
for i, op in enumerate(operators):
c = constants[i + 1]
buffer.append(op)
buffer.append(f"{c}")
if op == "+":
ground_truth += c
elif op == "-":
ground_truth -= c
class ArithmeticDataset:
"""Dataset that generates arithmetic tasks with configurable complexity"""
def __init__(self, config: ArithmeticDatasetConfig):
self.config = config
self.config.validate()
self.rng = Random(config.seed)
def __len__(self) -> int:
return self.config.size
def __getitem__(self, idx: int) -> dict[str, Any]:
"""Generate a single arithmetic task
Args:
idx: Index of the item to generate
Returns:
dict with keys:
- question: str, the formatted arithmetic expression
- answer: str, the ground truth result
- metadata: dict with generation parameters
"""
# Use seed derived from idx for deterministic generation
item_rng = Random(self.rng.randint(0, 2**32) + idx)
num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms)
num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits)
if self.config.allow_parentheses:
expression, result = self._generate_complex_task(item_rng, num_terms, num_digits)
else:
RuntimeError("Unsupported operator")
expression, result = self._generate_simple_task(item_rng, num_terms, num_digits)
question = self._format_question(expression)
return {
"question": question,
"answer": str(result),
"metadata": {
"num_terms": num_terms,
"num_digits": num_digits,
"expression": expression
}
}
buffer.append(f"")
def _generate_complex_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]:
"""Generate a complex arithmetic task with possible parentheses"""
parts = []
question_templates = [
"{0}",
"{0} =",
"{0} = ?",
"What is {0}?",
"Solve {0}",
"Calculate {0}",
# 'evaluate {0}',
# 'do me a favor and calculate {0}',
# 'Give me the result of {0}',
# 'Help me solve this: {0}',
# 'calculator: {0}',
# 'Tell me the result of the following expression {0}',
]
def add_terms(remaining: int):
num_left = rng.randint(1, remaining)
num_right = remaining - num_left
template = rng.choice(question_templates)
task = " ".join(buffer)
formatted_task = template.format(task)
if num_left > 1 and rng.random() > 0.5 and self.config.allow_parentheses:
if rng.random() > 0.5 and self.config.allow_negation:
parts.append("-(")
else:
parts.append("(")
add_terms(num_left)
parts.append(")")
else:
for i in range(num_left):
c = rng.randint(-(10**num_digits) + 1, 10**num_digits - 1)
parts.append(str(c))
if i + 1 < num_left:
parts.append(rng.choice(self.config.operators))
return formatted_task, str(ground_truth), num_terms, num_digits
if num_right > 0:
parts.append(rng.choice(self.config.operators))
add_terms(num_right)
add_terms(num_terms)
# Add random spaces
space_parts = []
for p in parts:
while rng.random() < 0.15:
space_parts.append(" ")
space_parts.append(p)
expression = " ".join(space_parts)
result = eval(expression) # Note: eval is safe here as we control the input
return expression, result
def _generate_simple_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]:
"""Generate a simple linear arithmetic task without parentheses"""
constants = [rng.randint(0, 10**num_digits) for _ in range(num_terms)]
operators = [rng.choice(self.config.operators) for _ in range(num_terms - 1)]
# Build expression and compute result
expression_parts = []
result = constants[0]
expression_parts.append(str(constants[0]))
for i, op in enumerate(operators):
c = constants[i + 1]
expression_parts.append(op)
expression_parts.append(str(c))
if op == "+":
result += c
elif op == "-":
result -= c
elif op == "*":
result *= c
else:
raise RuntimeError(f"Unsupported operator: {op}")
expression = " ".join(expression_parts)
return expression, result
def _format_question(self, expression: str) -> str:
"""Format the expression according to config style"""
if self.config.format_style == "simple":
return f"{expression} ="
else:
templates = [
"What is {0}?",
"Calculate {0}",
"Solve {0}",
"Evaluate the expression: {0}"
]
return self.rng.choice(templates).format(expression)