diff --git a/reasoning_gym/composite.py b/reasoning_gym/composite.py
new file mode 100644
index 00000000..f9ce7b7f
--- /dev/null
+++ b/reasoning_gym/composite.py
@@ -0,0 +1,108 @@
+from dataclasses import dataclass
+from random import Random
+from typing import List, Dict, Any, Optional
+import yaml
+
+from .dataset import ProceduralDataset
+
+
+@dataclass
+class DatasetSpec:
+    """Specification for a single dataset within the composite"""
+    name: str
+    weight: float
+    config: dict
+
+    def validate(self):
+        """Validate dataset specification"""
+        assert self.name, "Dataset name cannot be empty"
+        assert self.weight > 0, "Weight must be positive"
+        assert isinstance(self.config, dict), "Config must be a dictionary"
+
+
+@dataclass
+class CompositeConfig:
+    """Configuration for CompositeDataset"""
+    size: int = 500
+    seed: Optional[int] = None
+    datasets: Optional[List[DatasetSpec]] = None
+
+    def validate(self):
+        """Validate configuration parameters"""
+        assert self.size > 0, "size must be positive"
+        assert self.datasets, "Must specify at least one dataset"
+
+        # Validate each dataset spec
+        for ds in self.datasets:
+            ds.validate()
+
+    @classmethod
+    def from_yaml(cls, yaml_path: str) -> 'CompositeConfig':
+        """Load configuration from YAML file"""
+        with open(yaml_path, 'r') as f:
+            data = yaml.safe_load(f)
+
+        # Convert dataset specs to DatasetSpec objects
+        if 'datasets' in data:
+            data['datasets'] = [DatasetSpec(**ds) for ds in data['datasets']]
+
+        return cls(**data)
+
+
+class CompositeDataset(ProceduralDataset):
+    """A dataset that combines multiple datasets with weighted sampling"""
+
+    def __init__(self, config: CompositeConfig):
+        super().__init__(config=config, seed=config.seed, size=config.size)
+
+        # Imported lazily to break a circular import: factory.py imports this
+        # module at load time in order to register the composite dataset.
+        from .factory import create_dataset
+
+        # Initialize sub-datasets with incremented seeds
+        self.datasets = {}
+        self.weights = []
+        total_weight = 0.0
+
+        for i, ds_spec in enumerate(config.datasets):
+            # Create dataset with derived seed
+            ds_config = ds_spec.config.copy()
+            if 'seed' not in ds_config:
+                ds_config['seed'] = self.seed + i + 1
+            if 'size' not in ds_config:
+                ds_config['size'] = self.size
+
+            # NOTE(review): duplicate spec names overwrite earlier entries here
+            # while every weight is kept, skewing the weighted sampling below.
+            self.datasets[ds_spec.name] = create_dataset(ds_spec.name, **ds_config)
+            total_weight += ds_spec.weight
+            self.weights.append(ds_spec.weight)
+
+        # Normalize weights
+        self.weights = [w / total_weight for w in self.weights]
+        self.dataset_names = [ds.name for ds in config.datasets]
+
+    def __getitem__(self, idx: int) -> dict:
+        """Generate a single dataset item by sampling from sub-datasets"""
+        # Create deterministic RNG for this index (assumes ProceduralDataset
+        # resolved self.seed to a concrete int when config.seed is None — TODO confirm)
+        rng = Random(self.seed + idx)
+
+        # Sample dataset according to weights
+        dataset_idx = rng.choices(range(len(self.dataset_names)), weights=self.weights, k=1)[0]
+        dataset_name = self.dataset_names[dataset_idx]
+        dataset = self.datasets[dataset_name]
+
+        # Get item from selected dataset
+        item = dataset[idx]
+
+        # Add source dataset info to metadata
+        item['metadata']['source_dataset'] = dataset_name
+        item['metadata']['source_index'] = idx
+
+        return item
+
+    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
+        """Forward scoring to appropriate dataset"""
+        dataset_name = entry['metadata']['source_dataset']
+        return self.datasets[dataset_name].score_answer(answer, entry)
diff --git a/reasoning_gym/factory.py b/reasoning_gym/factory.py
index d00def13..e3ff6df2 100644
--- a/reasoning_gym/factory.py
+++ b/reasoning_gym/factory.py
@@ -2,6 +2,7 @@ from dataclasses import is_dataclass
 from typing import Dict, Type, TypeVar
 
 from .dataset import ProceduralDataset
+from .composite import CompositeDataset, CompositeConfig
 
 # Type variables for generic type hints
 ConfigT = TypeVar("ConfigT")
@@ -10,6 +11,9 @@ DatasetT = TypeVar("DatasetT", bound=ProceduralDataset)
 # Global registry of datasets
 DATASETS: Dict[str, tuple[Type[ProceduralDataset], Type]] = {}
 
+# Register composite dataset
+DATASETS["composite"] = (CompositeDataset, CompositeConfig)  # register_dataset() is only defined below — insert directly to avoid a NameError at import
+
+
 def register_dataset(name: str, dataset_cls: Type[DatasetT], config_cls: Type[ConfigT]) -> None:
     """
diff --git a/tests/test_composite.py b/tests/test_composite.py
new file mode 100644
index 00000000..e96b056b
--- /dev/null
+++ b/tests/test_composite.py
@@ -0,0 +1,113 @@
+import os
+
+import pytest
+import yaml
+
+from reasoning_gym.composite import CompositeDataset, CompositeConfig, DatasetSpec
+
+
+def create_test_config(tmp_path):
+    """Create a test YAML config file"""
+    config = {
+        "size": 100,
+        "seed": 42,
+        "datasets": [
+            {
+                "name": "chain_sum",
+                "weight": 0.3,
+                "config": {
+                    "min_terms": 2,
+                    "max_terms": 4,
+                }
+            },
+            {
+                "name": "leg_counting",
+                "weight": 0.7,
+                "config": {
+                    "min_animals": 1,
+                    "max_animals": 3,
+                }
+            }
+        ]
+    }
+
+    config_path = os.path.join(tmp_path, "test_config.yaml")
+    with open(config_path, 'w') as f:
+        yaml.dump(config, f)
+
+    return config_path
+
+
+def test_composite_config_validation():
+    """Test configuration validation"""
+    with pytest.raises(AssertionError):
+        config = CompositeConfig(size=-1)
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = CompositeConfig(datasets=[])
+        config.validate()
+
+
+def test_composite_dataset_deterministic():
+    """Test that dataset generates same items with same seed"""
+    config = CompositeConfig(
+        size=10,
+        seed=42,
+        datasets=[
+            DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
+        ]
+    )
+
+    dataset1 = CompositeDataset(config)
+    dataset2 = CompositeDataset(config)
+
+    for i in range(len(dataset1)):
+        assert dataset1[i] == dataset2[i]
+
+
+def test_composite_dataset_metadata():
+    """Test that metadata includes source dataset information"""
+    config = CompositeConfig(
+        size=10,
+        seed=42,
+        datasets=[
+            DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
+        ]
+    )
+
+    dataset = CompositeDataset(config)
+    item = dataset[0]
+
+    assert "source_dataset" in item["metadata"]
+    assert "source_index" in item["metadata"]
+    assert item["metadata"]["source_dataset"] == "chain_sum"
+    assert isinstance(item["metadata"]["source_index"], int)
+
+
+def test_composite_dataset_weights():
+    """Test that dataset weights are properly normalized"""
+    config = CompositeConfig(
+        size=1000,
+        seed=42,
+        datasets=[
+            DatasetSpec("chain_sum", 2.0, {"min_terms": 2}),
+            DatasetSpec("chain_sum", 3.0, {"min_terms": 3}),
+        ]
+    )
+
+    dataset = CompositeDataset(config)
+    assert abs(dataset.weights[0] - 0.4) < 1e-6
+    assert abs(dataset.weights[1] - 0.6) < 1e-6
+
+
+def test_yaml_loading(tmp_path):
+    """Test loading configuration from YAML"""
+    config_path = create_test_config(tmp_path)
+    config = CompositeConfig.from_yaml(config_path)
+
+    assert config.size == 100
+    assert config.seed == 42
+    assert len(config.datasets) == 2
+    assert config.datasets[0].name == "chain_sum"
+    assert config.datasets[1].name == "leg_counting"