add simple dataset gallery generation script

Andreas Koepf, 2025-01-30 22:30:26 +01:00
parent 71ccd41adb
commit 5a88cf2529
6 changed files with 1352 additions and 105 deletions

GALLERY.md (new file, 1309 lines)

File diff suppressed because it is too large.

(Python file, deleted, 63 lines)
@@ -1,63 +0,0 @@
-"""Generate a markdown gallery of all available datasets with examples"""
-from pathlib import Path
-import textwrap
-
-from reasoning_gym.factory import DATASETS, create_dataset
-
-
-def generate_gallery() -> str:
-    """Generate markdown content for the gallery"""
-    # Start with header
-    content = ["# Dataset Gallery\n"]
-
-    # Add index
-    content.append("## Available Datasets\n")
-    for name in sorted(DATASETS.keys()):
-        # Create anchor link
-        anchor = name.replace("_", "-")
-        content.append(f"- [{name}](#{anchor})\n")
-    content.append("\n")
-
-    # Add examples for each dataset
-    content.append("## Examples\n")
-    for name in sorted(DATASETS.keys()):
-        dataset = create_dataset(name)
-
-        # Add dataset header
-        content.append(f"### {name}\n")
-
-        # Get dataset class docstring if available
-        if dataset.__class__.__doc__:
-            doc = textwrap.dedent(dataset.__class__.__doc__.strip())
-            content.append(f"{doc}\n")
-
-        content.append("```\n")
-
-        # Show 3 examples
-        for i, item in enumerate(dataset):
-            if i >= 3:
-                break
-            content.append(f"Example {i+1}:\n")
-            content.append(f"Question: {item['question']}\n")
-            content.append(f"Answer: {item['answer']}\n")
-            content.append(f"Metadata: {item['metadata']}\n")
-            content.append("\n")
-
-        content.append("```\n\n")
-
-    return "".join(content)
-
-
-def main():
-    """Generate gallery markdown file"""
-    gallery_path = Path(__file__).parent.parent / "GALLERY.md"
-    gallery_content = generate_gallery()
-
-    with open(gallery_path, "w") as f:
-        f.write(gallery_content)
-
-    print(f"Generated gallery at {gallery_path}")
-
-
-if __name__ == "__main__":
-    main()
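
Both the deleted script and its replacement below revolve around the factory API. As a minimal sketch of that API the way the script uses it (the dataset name is picked arbitrarily from the registry; the item keys are the ones the script formats):

from reasoning_gym.factory import DATASETS, create_dataset

# Registry keys are the dataset names the gallery lists in its index.
name = sorted(DATASETS.keys())[0]
dataset = create_dataset(name)

# Each generated item is a dict with "question", "answer", and "metadata",
# which is exactly what the gallery script renders.
for i, item in enumerate(dataset):
    if i >= 3:
        break
    print(f"Q: {item['question']}  A: {item['answer']}")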

reasoning_gym/dataset.py (path inferred from the test import below)
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Iterable, Sized
 from copy import deepcopy
 from random import Random
-from typing import Any, Dict, Iterator, Optional, TypeVar, Type
+from typing import Any, Dict, Iterator, Optional, Type, TypeVar


 class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
@@ -66,15 +66,15 @@ class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]):
         return reward


-T = TypeVar('T', bound='ProceduralDataset')
+T = TypeVar("T", bound="ProceduralDataset")


 class ReseedingDataset(Iterable[Dict[str, Any]]):
     """Wrapper that makes any ProceduralDataset infinite by reseeding when reaching the end"""

     def __init__(self, dataset: T, chunk_size: int = 500):
         """Initialize with dataset instance and chunk size

         Args:
             dataset: The ProceduralDataset instance to wrap
             chunk_size: Size of each generated chunk before reseeding
@@ -82,12 +82,12 @@ class ReseedingDataset(Iterable[Dict[str, Any]]):
         self.dataset = dataset
         self.dataset_cls: Type[T] = type(dataset)
         self.chunk_size = chunk_size

         # Start with chunk 0
         self._current_chunk = 0
         self._current_dataset = self._create_chunk(0)
         self._current_idx = 0

     def _create_chunk(self, chunk_num: int) -> T:
         """Create a new dataset chunk with unique seed"""
         # Create new config with modified seed
@@ -95,17 +95,17 @@ class ReseedingDataset(Iterable[Dict[str, Any]]):
         if hasattr(new_config, "seed"):
             # Derive new seed from chunk number using dataset's seed, wrapping around at 2^32
             new_config.seed = (self.dataset.seed + chunk_num) % (2**32)

         # Create new dataset instance with chunk config
         return self.dataset_cls(new_config)

     def __iter__(self) -> Iterator[Dict[str, Any]]:
         """Make the dataset iterable"""
         self._current_chunk = 0
         self._current_dataset = self._create_chunk(0)
         self._current_idx = 0
         return self

     def __next__(self) -> Dict[str, Any]:
         """Get next item, creating new chunk if needed"""
         if self._current_idx >= self.chunk_size:
@@ -113,11 +113,11 @@ class ReseedingDataset(Iterable[Dict[str, Any]]):
             self._current_chunk += 1
             self._current_dataset = self._create_chunk(self._current_chunk)
             self._current_idx = 0

         item = self._current_dataset[self._current_idx]
         self._current_idx += 1
         return item

     def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
         """Forward scoring to the wrapped dataset's implementation"""
         return self.dataset.score_answer(answer, entry)
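
The reseeding scheme above derives each chunk's seed as (base seed + chunk index) mod 2^32. A self-contained sketch of why that makes iteration both infinite and replayable, using plain random.Random as a stand-in for a dataset (not repo code):

from random import Random

BASE_SEED = 42

def chunk_rng(chunk_num: int) -> Random:
    # Same derivation as ReseedingDataset._create_chunk: offset the base
    # seed by the chunk index, wrapping around at 2**32.
    return Random((BASE_SEED + chunk_num) % (2**32))

# Re-creating chunk 0 replays its stream exactly; this is what makes a
# fresh iterator over the wrapper deterministic.
a, b = chunk_rng(0), chunk_rng(0)
assert [a.randint(0, 9) for _ in range(5)] == [b.randint(0, 9) for _ in range(5)]

# Chunk 1 gets a different seed, so its stream (almost certainly) differs.
c = chunk_rng(1)
print([c.randint(0, 9) for _ in range(5)])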

__init__.py (dataset package exports)
@@ -1,3 +1,9 @@
 from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset
+from .quantum_lock import QuantumLockConfig, QuantumLockDataset

-__all__ = ["FamilyRelationshipsDataset", "FamilyRelationshipsConfig"]
+__all__ = [
+    "FamilyRelationshipsDataset",
+    "FamilyRelationshipsConfig",
+    "QuantumLockConfig",
+    "QuantumLockDataset",
+]
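
With the new exports in place, QuantumLockDataset should be reachable through the same factory registry the gallery script iterates. A hedged sketch; the registry key "quantum_lock" is an assumption, not shown in this diff:

from reasoning_gym.factory import DATASETS, create_dataset

# "quantum_lock" is an assumed registry key; consult DATASETS for the
# name actually registered for QuantumLockDataset.
if "quantum_lock" in DATASETS:
    ds = create_dataset("quantum_lock")
    item = next(iter(ds))
    print(item["question"])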

scripts/generate_gallery.py (28 changes, Normal file → Executable file)

@@ -2,19 +2,21 @@
 """Generate a markdown gallery of all available datasets with examples"""

 import os
-from pathlib import Path
 import textwrap
+from pathlib import Path

+import reasoning_gym.cognition.figlet_fonts
+import reasoning_gym.cognition.rubiks_cube
 from reasoning_gym.factory import DATASETS, create_dataset


 def generate_gallery() -> str:
     """Generate markdown content for the gallery"""
     # Start with header
     content = ["# Reasoning Gym Dataset Gallery\n"]
     content.append("This gallery shows examples from all available datasets using their default configurations.\n\n")

     # Add index
     content.append("## Available Datasets\n")
     for name in sorted(DATASETS.keys()):
@@ -22,21 +24,21 @@ def generate_gallery() -> str:
         anchor = name.replace("_", "-").lower()
         content.append(f"- [{name}](#{anchor})\n")
     content.append("\n")

     # Add examples for each dataset
     content.append("## Dataset Examples\n")
     for name in sorted(DATASETS.keys()):
         dataset = create_dataset(name)

         # Add dataset header with anchor
         anchor = name.replace("_", "-").lower()
         content.append(f"### {name} {{{anchor}}}\n")

         # Get dataset class docstring if available
         if dataset.__class__.__doc__:
             doc = textwrap.dedent(dataset.__class__.__doc__.strip())
             content.append(f"{doc}\n\n")

         # Show configuration
         content.append("Default configuration:\n")
         content.append("```python\n")
@@ -44,7 +46,7 @@ def generate_gallery() -> str:
             if not key.startswith("_"):
                 content.append(f"{key} = {value}\n")
         content.append("```\n\n")

         # Show examples
         content.append("Example tasks:\n")
         content.append("```\n")
@@ -54,11 +56,11 @@ def generate_gallery() -> str:
             content.append(f"Example {i+1}:\n")
             content.append(f"Question: {item['question']}\n")
             content.append(f"Answer: {item['answer']}\n")
-            if item.get('metadata'):
+            if item.get("metadata"):
                 content.append(f"Metadata: {item['metadata']}\n")
             content.append("\n")
         content.append("```\n\n")

     return "".join(content)


@@ -68,13 +70,13 @@ def main():
     script_dir = Path(__file__).parent
     if not script_dir.exists():
         script_dir.mkdir(parents=True)

     gallery_path = script_dir.parent / "GALLERY.md"
     gallery_content = generate_gallery()

     with open(gallery_path, "w") as f:
         f.write(gallery_content)

     print(f"Generated gallery at {gallery_path}")

(test file for ReseedingDataset)
@@ -1,46 +1,39 @@
 import pytest
-from reasoning_gym.dataset import ReseedingDataset
 from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
+from reasoning_gym.dataset import ReseedingDataset


 def test_reseeding_dataset_iteration():
     """Test that ReseedingDataset provides infinite iteration with consistent chunks"""
     # Create base dataset
     config = BasicArithmeticDatasetConfig(
-        min_terms=2,
-        max_terms=3,
-        min_digits=1,
-        max_digits=2,
-        operators=["+"],
-        allow_parentheses=False,
-        seed=42,
-        size=10
+        min_terms=2, max_terms=3, min_digits=1, max_digits=2, operators=["+"], allow_parentheses=False, seed=42, size=10
     )
     base_dataset = BasicArithmeticDataset(config)

     # Create reseeding dataset with small chunk size
     chunk_size = 3
     infinite_dataset = ReseedingDataset(base_dataset, chunk_size=chunk_size)

     # Get first 10 items
     first_items = []
     for _, item in zip(range(10), infinite_dataset):
         first_items.append(item["question"])

     # Create new iterator and verify first 10 items are identical
     second_items = []
     for _, item in zip(range(10), infinite_dataset):
         second_items.append(item["question"])

     assert first_items == second_items, "Items should be deterministic across iterations"

     # Verify chunks are different
     chunk1 = first_items[:chunk_size]
-    chunk2 = first_items[chunk_size:2*chunk_size]
+    chunk2 = first_items[chunk_size : 2 * chunk_size]
     assert chunk1 != chunk2, "Different chunks should generate different items"

     # Test score_answer forwarding
     test_item = next(iter(infinite_dataset))
     assert infinite_dataset.score_answer("wrong", test_item) == 0.01
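
The zip(range(10), infinite_dataset) loops above bound iteration over an otherwise infinite stream. An equivalent sketch using itertools.islice (illustration only, not repo code):

from itertools import islice

def take(n, iterable):
    # First n items of a possibly infinite iterable; zip(range(n), iterable)
    # in the test above bounds the loop the same way.
    return list(islice(iterable, n))

Either form stops pulling items after n, so neither needs the wrapped dataset to have a length.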