diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 677141bf..b0c3dc5d 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -6,11 +6,12 @@ from typing import Optional, Literal, Any @dataclass class ArithmeticDatasetConfig: """Configuration for arithmetic dataset generation""" + min_terms: int = 2 max_terms: int = 6 min_digits: int = 1 max_digits: int = 4 - operators: list[str] = ("+" , "-", "*") + operators: list[str] = ("+", "-", "*") allow_parentheses: bool = True allow_negation: bool = True seed: Optional[int] = None @@ -30,22 +31,22 @@ class ArithmeticDatasetConfig: class ArithmeticDataset: """Dataset that generates arithmetic tasks with configurable complexity""" - + def __init__(self, config: ArithmeticDatasetConfig): self.config = config self.config.validate() # Generate base seed if none provided self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32) - + def __len__(self) -> int: return self.config.size - + def __getitem__(self, idx: int) -> dict[str, Any]: """Generate a single arithmetic task - + Args: idx: Index of the item to generate - + Returns: dict with keys: - question: str, the formatted arithmetic expression @@ -54,25 +55,21 @@ class ArithmeticDataset: """ # Create deterministic RNG from base seed and idx item_rng = Random(self.seed + idx) - + num_terms = item_rng.randint(self.config.min_terms, self.config.max_terms) num_digits = item_rng.randint(self.config.min_digits, self.config.max_digits) - + if self.config.allow_parentheses: expression, result = self._generate_complex_task(item_rng, num_terms, num_digits) else: expression, result = self._generate_simple_task(item_rng, num_terms, num_digits) - + question = self._format_question(expression) - + return { "question": question, "answer": str(result), - "metadata": { - "num_terms": num_terms, - "num_digits": num_digits, - "expression": expression - } + "metadata": {"num_terms": num_terms, "num_digits": num_digits, "expression": expression}, } def _generate_complex_task(self, rng: Random, num_terms: int, num_digits: int) -> tuple[str, int]: @@ -142,13 +139,11 @@ class ArithmeticDataset: expression = " ".join(expression_parts) return expression, result - - def __iter__(self): """Make the dataset iterable""" self._current_idx = 0 return self - + def __next__(self): """Get next item in iteration""" if self._current_idx >= self.config.size: @@ -162,12 +157,7 @@ class ArithmeticDataset: if self.config.format_style == "simple": return f"{expression} =" else: - templates = [ - "What is {0}?", - "Calculate {0}", - "Solve {0}", - "Evaluate the expression: {0}" - ] + templates = ["What is {0}?", "Calculate {0}", "Solve {0}", "Evaluate the expression: {0}"] # Use deterministic RNG for template selection template_rng = Random(self.seed) return template_rng.choice(templates).format(expression)