From d191e78a2882affa5e296722985703e1c2e59bea Mon Sep 17 00:00:00 2001 From: "Andreas Koepf (aider)" Date: Fri, 24 Jan 2025 09:57:26 +0100 Subject: [PATCH] refactor: Inherit ChainSum from ProceduralDataset base class --- reasoning_gym/arithmetic/chain_sum.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 0b8fdf7d..dc122091 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -1,6 +1,7 @@ import random from dataclasses import dataclass from typing import Optional +from ..dataset import ProceduralDataset @dataclass @@ -27,17 +28,13 @@ class ChainSumConfig: assert 10 ** (self.min_digits - 1) >= 1, "min_digits would result in invalid number range" -class ChainSum: +class ChainSum(ProceduralDataset): """Generates simple arithmetic tasks using only + and - operators""" def __init__(self, config: ChainSumConfig): self.config = config self.config.validate() - # Generate base seed if none provided - self.seed = config.seed if config.seed is not None else random.randint(0, 2**32) - - def __len__(self) -> int: - return self.config.size + super().__init__(seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict: """Generate a single chain sum task @@ -73,18 +70,6 @@ class ChainSum: }, } - def __iter__(self): - """Make the dataset iterable""" - self._current_idx = 0 - return self - - def __next__(self): - """Get next item in iteration""" - if self._current_idx >= self.config.size: - raise StopIteration - item = self[self._current_idx] - self._current_idx += 1 - return item def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]: """Generate a chain sum task