mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
76 lines
2.3 KiB
Python
76 lines
2.3 KiB
Python
"""Version manager for tracking dataset versions."""
|
|
|
|
from typing import Any, Optional
|
|
|
|
from .dataset import ProceduralDataset
|
|
|
|
|
|
class DatasetVersionManager:
    """Manages versioned ProceduralDataset instances and their configurations.

    Each call to ``register_dataset`` assigns a new integer version id,
    starting at 1 and increasing by one per registration. Ids are never
    reused: ``cleanup_old_versions`` removes entries but does not reset
    the counter.
    """

    def __init__(self):
        """Initialize the version manager with no registered datasets."""
        # Last version id handed out; 0 means nothing has been registered yet.
        self.current_version = 0
        # version_id -> (dataset_name, dataset_instance)
        self.datasets: dict[int, tuple[str, ProceduralDataset]] = {}

    def register_dataset(self, name: str, dataset: ProceduralDataset) -> int:
        """
        Register a new dataset version.

        Args:
            name: Name/identifier of the dataset type
            dataset: Instance of ProceduralDataset

        Returns:
            version_id: Unique identifier for this dataset version
        """
        self.current_version += 1
        self.datasets[self.current_version] = (name, dataset)
        return self.current_version

    def get_dataset(self, version_id: int) -> Optional[tuple[str, ProceduralDataset]]:
        """
        Retrieve a dataset by its version ID.

        Args:
            version_id: The version identifier

        Returns:
            Tuple of (dataset_name, dataset_instance) if found, None otherwise
        """
        return self.datasets.get(version_id)

    def get_entry(self, version_id: int, index: int) -> dict[str, Any]:
        """
        Get a specific entry from a versioned dataset.

        Args:
            version_id: The version identifier
            index: Index of the entry to retrieve

        Returns:
            The dataset entry, i.e. ``dataset[index]``

        Raises:
            KeyError: If version_id is not found. Any exception raised by the
                dataset's own indexing (``dataset[index]``) propagates unchanged.
        """
        # Single lookup (EAFP) instead of a membership test plus a second lookup.
        try:
            _, dataset = self.datasets[version_id]
        except KeyError:
            raise KeyError(f"Dataset version {version_id} not found") from None
        return dataset[index]

    def cleanup_old_versions(self, keep_latest: int = 10) -> None:
        """
        Remove old dataset versions to free memory.

        The versions with the highest ids (most recently registered) are kept;
        all older ones are deleted.

        Args:
            keep_latest: Number of most recent versions to keep. ``0`` removes
                every registered version.

        Raises:
            ValueError: If keep_latest is negative.
        """
        if keep_latest < 0:
            raise ValueError(f"keep_latest must be non-negative, got {keep_latest}")

        # Compute the count to drop explicitly: slicing with ``[:-keep_latest]``
        # would yield an empty list for keep_latest == 0 and remove nothing.
        excess = len(self.datasets) - keep_latest
        if excess <= 0:
            return

        for version in sorted(self.datasets)[:excess]:
            del self.datasets[version]
|