diff --git a/reasoning_gym/dataset.py b/reasoning_gym/dataset.py index 6837e40f..bbfe3895 100644 --- a/reasoning_gym/dataset.py +++ b/reasoning_gym/dataset.py @@ -9,8 +9,9 @@ from typing import Any, Dict, Iterator, Optional class ProceduralDataset(ABC, Sized, Iterable[Dict[str, Any]]): """Abstract base class for procedural dataset generators""" - def __init__(self, seed: Optional[int] = None, size: int = 500): - """Initialize the dataset with optional seed and size""" + def __init__(self, config: Any, seed: Optional[int] = None, size: int = 500): + """Initialize the dataset with config, optional seed and size""" + self.config = config self.size = size self.seed = seed if seed is not None else Random().randint(0, 2**32) diff --git a/reasoning_gym/factory.py b/reasoning_gym/factory.py new file mode 100644 index 00000000..b482d1b1 --- /dev/null +++ b/reasoning_gym/factory.py @@ -0,0 +1,75 @@ +from dataclasses import is_dataclass +from typing import Any, Dict, Type, TypeVar + +from .dataset import ProceduralDataset + +# Type variables for generic type hints +ConfigT = TypeVar('ConfigT') +DatasetT = TypeVar('DatasetT', bound=ProceduralDataset) + +# Global registry of datasets +_DATASETS: Dict[str, tuple[Type[ProceduralDataset], Type]] = {} + +def register_dataset( + name: str, + dataset_cls: Type[DatasetT], + config_cls: Type[ConfigT] +) -> None: + """ + Register a dataset class with its configuration class. + + Args: + name: Unique identifier for the dataset + dataset_cls: Class derived from ProceduralDataset + config_cls: Configuration dataclass for the dataset + + Raises: + ValueError: If name is already registered or invalid types provided + """ + if name in _DATASETS: + raise ValueError(f"Dataset '{name}' is already registered") + + if not issubclass(dataset_cls, ProceduralDataset): + raise ValueError( + f"Dataset class must inherit from ProceduralDataset, got {dataset_cls}" + ) + + if not is_dataclass(config_cls): + raise ValueError( + f"Config class must be a dataclass, got {config_cls}" + ) + + _DATASETS[name] = (dataset_cls, config_cls) + +def create_dataset( + name: str, + config: Any, + seed: int = None, + size: int = 500 +) -> ProceduralDataset: + """ + Create a dataset instance by name with the given configuration. + + Args: + name: Registered dataset name + config: Configuration instance for the dataset + seed: Optional random seed + size: Size of the dataset (default: 500) + + Returns: + Configured dataset instance + + Raises: + ValueError: If dataset not found or config type mismatch + """ + if name not in _DATASETS: + raise ValueError(f"Dataset '{name}' not found") + + dataset_cls, config_cls = _DATASETS[name] + + if not isinstance(config, config_cls): + raise ValueError( + f"Config must be instance of {config_cls.__name__}, got {type(config).__name__}" + ) + + return dataset_cls(config=config, seed=seed, size=size)