diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 26758a74..cde9a543 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from random import Random from typing import Any, Literal, Optional +from ..dataset import ProceduralDataset @dataclass @@ -30,17 +31,13 @@ class ArithmeticDatasetConfig: assert op in ["+", "-", "*"], f"unsupported operator: {op}" -class ArithmeticDataset: +class ArithmeticDataset(ProceduralDataset): """Dataset that generates arithmetic tasks with configurable complexity""" def __init__(self, config: ArithmeticDatasetConfig): self.config = config self.config.validate() - # Generate base seed if none provided - self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32) - - def __len__(self) -> int: - return self.config.size + super().__init__(seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict[str, Any]: """Generate a single arithmetic task @@ -148,18 +145,6 @@ class ArithmeticDataset: expression = " ".join(expression_parts) return expression, result - def __iter__(self): - """Make the dataset iterable""" - self._current_idx = 0 - return self - - def __next__(self): - """Get next item in iteration""" - if self._current_idx >= self.config.size: - raise StopIteration - item = self[self._current_idx] - self._current_idx += 1 - return item def _format_question(self, rng: Random, expression: str) -> str: """Format the expression according to config style"""