from copy import deepcopy
from typing import Any, Dict, Iterator, Type, TypeVar

T = TypeVar("T", bound="ProceduralDataset")


class ReseedingDataset(Iterator[Dict[str, Any]]):
    """Endless iterator over a dataset, reseeding it chunk by chunk.

    The wrapped dataset is never consumed directly. Instead, a fresh instance
    of the same class is built for every chunk from a deep copy of
    ``dataset.config`` whose ``seed`` is offset by the chunk number. After
    ``chunk_size`` items have been served from a chunk, the next chunk is
    created, so iteration never raises ``StopIteration``.

    The dataset class must be constructible from a single config argument and
    must support integer indexing.
    """

    def __init__(self, dataset: T, chunk_size: int = 500):
        """Initialize with a dataset instance and a chunk size.

        Args:
            dataset: The dataset instance to wrap. Its ``config`` attribute is
                used as the template for every chunk.
            chunk_size: Number of items served from a chunk before reseeding.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        self.dataset = dataset
        self.dataset_cls: Type[T] = type(dataset)
        self.chunk_size = chunk_size
        self._reset()

    def _reset(self) -> None:
        """Rewind to the first item of chunk 0."""
        self._current_chunk = 0
        self._current_dataset = self._create_chunk(0)
        self._current_idx = 0

    def _create_chunk(self, chunk_num: int) -> T:
        """Create the dataset instance backing chunk ``chunk_num``.

        Args:
            chunk_num: Zero-based chunk index; it is added to the base seed.

        Returns:
            A new instance of the wrapped dataset's class, built from a
            modified deep copy of the wrapped dataset's config.
        """
        # Deep copy so the wrapped dataset's own config is never mutated.
        new_config = deepcopy(self.dataset.config)

        if hasattr(new_config, "seed"):
            # A missing seed (None) is treated as 0 so chunk seeds stay
            # deterministic: 0, 1, 2, ...
            base_seed = new_config.seed if new_config.seed is not None else 0
            new_config.seed = base_seed + chunk_num

        if hasattr(new_config, "size"):
            # Make the chunk exactly as large as the range __next__ indexes,
            # so a chunk_size above the original size cannot run off the end.
            new_config.size = self.chunk_size

        return self.dataset_cls(new_config)

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Restart iteration from chunk 0 and return ``self``.

        The iteration state lives on this object, so calling ``iter()`` again
        rewinds any iteration already in progress.
        """
        self._reset()
        return self

    def __next__(self) -> Dict[str, Any]:
        """Return the next item, moving to a new chunk when one is exhausted."""
        if self._current_idx >= self.chunk_size:
            # Current chunk is used up: build the next one with a new seed.
            self._current_chunk += 1
            self._current_dataset = self._create_chunk(self._current_chunk)
            self._current_idx = 0

        item = self._current_dataset[self._current_idx]
        self._current_idx += 1
        return item