From 44fd0d4a25cd1db39b8404bf9f46d1617382b66d Mon Sep 17 00:00:00 2001 From: "Andreas Koepf (aider)" Date: Fri, 24 Jan 2025 10:19:11 +0100 Subject: [PATCH] refactor: Inherit `LegCountingDataset` from `ProceduralDataset` --- reasoning_gym/arithmetic/leg_counting.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/reasoning_gym/arithmetic/leg_counting.py b/reasoning_gym/arithmetic/leg_counting.py index 8d8a8cff..7bec5c5e 100644 --- a/reasoning_gym/arithmetic/leg_counting.py +++ b/reasoning_gym/arithmetic/leg_counting.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from random import Random from typing import Dict, Optional +from ..dataset import ProceduralDataset ANIMALS = { # Animals with 0 legs @@ -67,27 +68,13 @@ class LegCountingConfig: assert self.max_instances > 0, "max_instances must be positive" -class LegCountingDataset: +class LegCountingDataset(ProceduralDataset): """Generates leg counting arithmetic tasks""" def __init__(self, config: LegCountingConfig): self.config = config self.config.validate() - self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32) - - def __len__(self) -> int: - return self.config.size - - def __iter__(self): - self._current_idx = 0 - return self - - def __next__(self): - if self._current_idx >= self.config.size: - raise StopIteration - item = self[self._current_idx] - self._current_idx += 1 - return item + super().__init__(seed=config.seed, size=config.size) def _generate_animals(self, rng: Random) -> Dict[str, int]: """Generate a random set of animals and their counts"""