add simple dataset gallery generation script

This commit is contained in:
Andreas Koepf 2025-01-30 22:30:26 +01:00
parent 71ccd41adb
commit 5a88cf2529
6 changed files with 1352 additions and 105 deletions

View file

@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from collections.abc import Iterable, Sized
from copy import deepcopy
from random import Random
from typing import Any, Dict, Iterator, Optional, TypeVar, Type
from typing import Any, Dict, Iterator, Optional, Type, TypeVar
class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
@ -66,15 +66,15 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
return reward
T = TypeVar('T', bound='ProceduralDataset')
T = TypeVar("T", bound="ProceduralDataset")
class ReseedingDataset(Iterable[Dict[str, Any]]):
"""Wrapper that makes any ProceduralDataset infinite by reseeding when reaching the end"""
def __init__(self, dataset: T, chunk_size: int = 500):
"""Initialize with dataset instance and chunk size
Args:
dataset: The ProceduralDataset instance to wrap
chunk_size: Size of each generated chunk before reseeding
@ -82,12 +82,12 @@ class ReseedingDataset(Iterable[Dict[str, Any]]):
self.dataset = dataset
self.dataset_cls: Type[T] = type(dataset)
self.chunk_size = chunk_size
# Start with chunk 0
self._current_chunk = 0
self._current_dataset = self._create_chunk(0)
self._current_idx = 0
def _create_chunk(self, chunk_num: int) -> T:
"""Create a new dataset chunk with unique seed"""
# Create new config with modified seed
@ -95,17 +95,17 @@ class ReseedingDataset(Iterable[Dict[str, Any]]):
if hasattr(new_config, "seed"):
# Derive new seed from chunk number using dataset's seed, wrapping around at 2^32
new_config.seed = (self.dataset.seed + chunk_num) % (2**32)
# Create new dataset instance with chunk config
return self.dataset_cls(new_config)
def __iter__(self) -> Iterator[Dict[str, Any]]:
    """Rewind to the first chunk and return this wrapper as its own iterator.

    Returns:
        The wrapper itself; items are then produced by ``__next__``.
    """
    first_chunk = 0
    # Reset the chunk counter, rebuild the dataset for that chunk, and
    # point the item cursor back at the start of it.
    self._current_chunk = first_chunk
    self._current_dataset = self._create_chunk(first_chunk)
    self._current_idx = 0
    return self
def __next__(self) -> Dict[str, Any]:
"""Get next item, creating new chunk if needed"""
if self._current_idx >= self.chunk_size:
@ -113,11 +113,11 @@ class ReseedingDataset(Iterable[Dict[str, Any]]):
self._current_chunk += 1
self._current_dataset = self._create_chunk(self._current_chunk)
self._current_idx = 0
item = self._current_dataset[self._current_idx]
self._current_idx += 1
return item
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
    """Score an answer by delegating to the wrapped dataset.

    Args:
        answer: The answer to score; may be ``None``.
        entry: The dataset entry the answer is scored against.

    Returns:
        Whatever ``self.dataset.score_answer(answer, entry)`` returns.
    """
    # The wrapper holds no scoring logic of its own; the dataset instance
    # passed to the constructor stays the single source of truth.
    return self.dataset.score_answer(answer, entry)