diff --git a/reasoning_gym/algorithmic/number_sorting.py b/reasoning_gym/algorithmic/number_sorting.py index c2d9718d..85cfc854 100644 --- a/reasoning_gym/algorithmic/number_sorting.py +++ b/reasoning_gym/algorithmic/number_sorting.py @@ -1,10 +1,11 @@ """Number sorting task generator""" -import random from dataclasses import dataclass from random import Random from typing import List, Optional, Tuple +from ..dataset import ProceduralDataset + @dataclass class NumberSortingConfig: @@ -28,27 +29,13 @@ class NumberSortingConfig: assert self.min_value < self.max_value, "max_value must be > min_value" -class NumberSortingDataset: +class NumberSortingDataset(ProceduralDataset): """Generates number sorting tasks""" def __init__(self, config: NumberSortingConfig): self.config = config self.config.validate() - self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32) - - def __len__(self) -> int: - return self.config.size - - def __iter__(self): - self._current_idx = 0 - return self - - def __next__(self): - if self._current_idx >= self.config.size: - raise StopIteration - item = self[self._current_idx] - self._current_idx += 1 - return item + super().__init__(seed=config.seed, size=config.size) def _format_number(self, num: float, decimals: int) -> str: """Format number with specified decimal places"""