diff --git a/reasoning_gym/algorithmic/word_reversal.py b/reasoning_gym/algorithmic/word_reversal.py index 7fa10332..fe0aa1f9 100644 --- a/reasoning_gym/algorithmic/word_reversal.py +++ b/reasoning_gym/algorithmic/word_reversal.py @@ -5,7 +5,8 @@ from dataclasses import dataclass from random import Random from typing import List, Optional -from reasoning_gym.data import read_data_file +from ..data import read_data_file +from ..dataset import ProceduralDataset @dataclass @@ -23,33 +24,19 @@ class WordReversalConfig: assert self.max_words >= self.min_words, "max_words must be >= min_words" -class WordReversalDataset: +class WordReversalDataset(ProceduralDataset): """Generates word reversal tasks from text spans""" def __init__(self, config: WordReversalConfig): self.config = config self.config.validate() - self.seed = config.seed if config.seed is not None else Random().randint(0, 2**32) + super().__init__(seed=config.seed, size=config.size) # Load and preprocess text text = read_data_file("in_the_year_2889.txt") # Extract words and clean them to contain only alphanumeric characters self.words = [word for word in re.findall(r"\b\w+\b", text) if word.isalnum()] - def __len__(self) -> int: - return self.config.size - - def __iter__(self): - self._current_idx = 0 - return self - - def __next__(self): - if self._current_idx >= self.config.size: - raise StopIteration - item = self[self._current_idx] - self._current_idx += 1 - return item - def __getitem__(self, idx: int) -> dict: """Generate a single word reversal task""" rng = Random(self.seed + idx)