mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-22 16:49:06 +00:00
75 lines
2.1 KiB
Python
75 lines
2.1 KiB
Python
from dataclasses import is_dataclass
|
|
from typing import Any, Dict, Type, TypeVar
|
|
|
|
from .dataset import ProceduralDataset
|
|
|
|
# Type variables for generic type hints
|
|
ConfigT = TypeVar('ConfigT')
|
|
DatasetT = TypeVar('DatasetT', bound=ProceduralDataset)
|
|
|
|
# Global registry of datasets
|
|
_DATASETS: Dict[str, tuple[Type[ProceduralDataset], Type]] = {}
|
|
|
|
def register_dataset(
|
|
name: str,
|
|
dataset_cls: Type[DatasetT],
|
|
config_cls: Type[ConfigT]
|
|
) -> None:
|
|
"""
|
|
Register a dataset class with its configuration class.
|
|
|
|
Args:
|
|
name: Unique identifier for the dataset
|
|
dataset_cls: Class derived from ProceduralDataset
|
|
config_cls: Configuration dataclass for the dataset
|
|
|
|
Raises:
|
|
ValueError: If name is already registered or invalid types provided
|
|
"""
|
|
if name in _DATASETS:
|
|
raise ValueError(f"Dataset '{name}' is already registered")
|
|
|
|
if not issubclass(dataset_cls, ProceduralDataset):
|
|
raise ValueError(
|
|
f"Dataset class must inherit from ProceduralDataset, got {dataset_cls}"
|
|
)
|
|
|
|
if not is_dataclass(config_cls):
|
|
raise ValueError(
|
|
f"Config class must be a dataclass, got {config_cls}"
|
|
)
|
|
|
|
_DATASETS[name] = (dataset_cls, config_cls)
|
|
|
|
def create_dataset(
|
|
name: str,
|
|
config: Any,
|
|
seed: int = None,
|
|
size: int = 500
|
|
) -> ProceduralDataset:
|
|
"""
|
|
Create a dataset instance by name with the given configuration.
|
|
|
|
Args:
|
|
name: Registered dataset name
|
|
config: Configuration instance for the dataset
|
|
seed: Optional random seed
|
|
size: Size of the dataset (default: 500)
|
|
|
|
Returns:
|
|
Configured dataset instance
|
|
|
|
Raises:
|
|
ValueError: If dataset not found or config type mismatch
|
|
"""
|
|
if name not in _DATASETS:
|
|
raise ValueError(f"Dataset '{name}' not found")
|
|
|
|
dataset_cls, config_cls = _DATASETS[name]
|
|
|
|
if not isinstance(config, config_cls):
|
|
raise ValueError(
|
|
f"Config must be instance of {config_cls.__name__}, got {type(config).__name__}"
|
|
)
|
|
|
|
return dataset_cls(config=config, seed=seed, size=size)
|