mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
feat: Add ReseedingDataset wrapper for infinite procedural datasets
This commit is contained in:
parent
5b35ea51a7
commit
6d59648264
1 changed file with 56 additions and 1 deletion
|
|
@ -2,8 +2,9 @@
|
|||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable, Sized
|
||||
from copy import deepcopy
|
||||
from random import Random
|
||||
from typing import Any, Dict, Iterator, Optional
|
||||
from typing import Any, Dict, Iterator, Optional, TypeVar, Type
|
||||
|
||||
|
||||
class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
|
||||
|
|
@ -63,3 +64,57 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
|
|||
reward = 0.01
|
||||
|
||||
return reward
|
||||
|
||||
|
||||
T = TypeVar("T", bound="ProceduralDataset")


class ReseedingDataset(Iterator[Dict[str, Any]]):
    """Endless iterator over a dataset, rebuilt with a new seed every chunk.

    The wrapper never raises ``StopIteration`` itself.  It serves items
    ``0 .. chunk_size - 1`` from the current dataset instance and then builds
    a fresh instance of the same class from a copy of the wrapped dataset's
    config whose ``seed`` has been shifted by the chunk number.

    The object is its own iterator: ``iter(obj)`` rewinds to chunk 0 and
    returns ``obj``, so two loops over the same wrapper share one cursor.
    """

    def __init__(self, dataset: T, chunk_size: int = 500):
        """Initialize with dataset instance and chunk size.

        Args:
            dataset: The dataset instance to wrap. It must expose a ``config``
                attribute, and its class must be constructible from a config
                object alone.
            chunk_size: Number of items served from one dataset instance
                before it is replaced by a reseeded one. Must be at least 1.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
        """
        # A non-positive chunk size would rebuild the dataset on every single
        # ``next()`` call and only ever return item 0, so reject it up front.
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

        self.dataset = dataset
        self.dataset_cls: Type[T] = type(dataset)
        self.chunk_size = chunk_size

        # Start with chunk 0
        self._reset()

    def _reset(self) -> None:
        """Rewind the cursor to the first item of a freshly built chunk 0."""
        self._current_chunk = 0
        self._current_dataset = self._create_chunk(0)
        self._current_idx = 0

    def _create_chunk(self, chunk_num: int) -> T:
        """Create a new dataset instance whose seed is offset by ``chunk_num``.

        The wrapped dataset's config is deep-copied, so the original config
        object is never modified.  Configs without a ``seed`` attribute are
        passed through unchanged.
        """
        new_config = deepcopy(self.dataset.config)
        if hasattr(new_config, "seed"):
            # A missing (None) seed is treated as 0 so every chunk still gets
            # a well-defined, distinct seed.
            base_seed = new_config.seed if new_config.seed is not None else 0
            # NOTE(review): if the dataset derives per-item randomness from
            # ``seed + index``, consecutive chunk seeds would make neighbouring
            # chunks overlap - confirm against ProceduralDataset.
            new_config.seed = base_seed + chunk_num

        # Create new dataset instance with chunk config
        return self.dataset_cls(new_config)

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Restart from chunk 0 and return this object as the iterator."""
        self._reset()
        return self

    def __next__(self) -> Dict[str, Any]:
        """Return the next item, moving to a reseeded chunk when needed.

        Whether ``chunk_size`` fits within the underlying dataset is not
        checked here; any error raised by indexing the dataset propagates
        to the caller.
        """
        if self._current_idx >= self.chunk_size:
            # Current chunk is used up: build the next one and start at 0.
            self._current_chunk += 1
            self._current_dataset = self._create_chunk(self._current_chunk)
            self._current_idx = 0

        item = self._current_dataset[self._current_idx]
        self._current_idx += 1
        return item
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue