style: Format code with consistent whitespace and remove unnecessary lines

This commit is contained in:
Andreas Koepf 2025-01-23 12:45:15 +01:00 committed by Andreas Koepf (aider)
parent ba493adbe7
commit 0aa35e15a3

View file

@ -6,11 +6,12 @@ from typing import Optional, Literal, Any
@dataclass
class ArithmeticDatasetConfig:
"""Configuration for arithmetic dataset generation"""
min_terms: int = 2
max_terms: int = 6
min_digits: int = 1
max_digits: int = 4
operators: list[str] = ("+" , "-", "*")
operators: list[str] = ("+", "-", "*")
allow_parentheses: bool = True
allow_negation: bool = True
seed: Optional[int] = None
@ -30,22 +31,22 @@ class ArithmeticDatasetConfig:
class ArithmeticDataset:
"""Dataset that generates arithmetic tasks with configurable complexity"""
def __init__(self, config: ArithmeticDatasetConfig):
self.config = config
self.config.validate()
# Generate base seed if none provided
self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32)
def __len__(self) -> int:
return self.config.size
def __getitem__(self, idx: int) -> dict[str, Any]:
"""Generate a single arithmetic task
Args:
idx: Index of the item to generate
Returns:
dict with keys:
- question: str, the formatted arithmetic expression
@ -54,25 +55,21 @@ class ArithmeticDataset:
"""
# Create deterministic RNG from base seed and idx
item_rng = Random(self.seed + idx)
num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms)
num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits)
if self.config.allow_parentheses:
expression, result = self._generate_complex_task(item_rng, num_terms, num_digits)
else:
expression, result = self._generate_simple_task(item_rng, num_terms, num_digits)
question = self._format_question(expression)
return {
"question": question,
"answer": str(result),
"metadata": {
"num_terms": num_terms,
"num_digits": num_digits,
"expression": expression
}
"metadata": {"num_terms": num_terms, "num_digits": num_digits, "expression": expression},
}
def _generate_complex_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]:
@ -142,13 +139,11 @@ class ArithmeticDataset:
expression = " ".join(expression_parts)
return expression, result
def __iter__(self):
"""Make the dataset iterable"""
self._current_idx = 0
return self
def __next__(self):
"""Get next item in iteration"""
if self._current_idx >= self.config.size:
@ -162,12 +157,7 @@ class ArithmeticDataset:
if self.config.format_style == "simple":
return f"{expression} ="
else:
templates = [
"What is {0}?",
"Calculate {0}",
"Solve {0}",
"Evaluate the expression: {0}"
]
templates = ["What is {0}?", "Calculate {0}", "Solve {0}", "Evaluate the expression: {0}"]
# Use deterministic RNG for template selection
template_rng = Random(self.seed)
return template_rng.choice(templates).format(expression)