feat: Add ReseedingDataset wrapper for infinite procedural datasets

This commit is contained in:
Andreas Koepf (aider) 2025-01-30 21:56:43 +01:00
parent 5b35ea51a7
commit 6d59648264

View file

@ -2,8 +2,9 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sized
from copy import deepcopy
from random import Random
from typing import Any, Dict, Iterator, Optional
from typing import Any, Dict, Iterator, Optional, TypeVar, Type
class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
@ -63,3 +64,57 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
reward = 0.01
return reward
T = TypeVar('T', bound='ProceduralDataset')
class ReseedingDataset(Iterable[Dict[str, Any]]):
"""Wrapper that makes any ProceduralDataset infinite by reseeding when reaching the end"""
def __init__(self, dataset: T, chunk_size: int = 500):
"""Initialize with dataset instance and chunk size
Args:
dataset: The ProceduralDataset instance to wrap
chunk_size: Size of each generated chunk before reseeding
"""
self.dataset = dataset
self.dataset_cls: Type[T] = type(dataset)
self.chunk_size = chunk_size
# Start with chunk 0
self._current_chunk = 0
self._current_dataset = self._create_chunk(0)
self._current_idx = 0
def _create_chunk(self, chunk_num: int) -> T:
"""Create a new dataset chunk with unique seed"""
# Create new config with modified seed
new_config = deepcopy(self.dataset.config)
if hasattr(new_config, "seed"):
# Derive new seed from chunk number
base_seed = new_config.seed if new_config.seed is not None else 0
new_config.seed = base_seed + chunk_num
# Create new dataset instance with chunk config
return self.dataset_cls(new_config)
def __iter__(self) -> Iterator[Dict[str, Any]]:
"""Make the dataset iterable"""
self._current_chunk = 0
self._current_dataset = self._create_chunk(0)
self._current_idx = 0
return self
def __next__(self) -> Dict[str, Any]:
"""Get next item, creating new chunk if needed"""
if self._current_idx >= self.chunk_size:
# Move to next chunk
self._current_chunk += 1
self._current_dataset = self._create_chunk(self._current_chunk)
self._current_idx = 0
item = self._current_dataset[self._current_idx]
self._current_idx += 1
return item