
feat: Add CompositeDataset for weighted multi-dataset sampling
Andreas Koepf (aider) 2025-02-04 19:06:13 +01:00
parent 0561844779
commit f07b6b7f61
3 changed files with 220 additions and 0 deletions

reasoning_gym/composite.py (new file, 103 lines)

@@ -0,0 +1,103 @@
from dataclasses import dataclass
from random import Random
from typing import List, Dict, Any, Optional

import yaml

from .dataset import ProceduralDataset
from .factory import create_dataset


@dataclass
class DatasetSpec:
    """Specification for a single dataset within the composite"""

    name: str
    weight: float
    config: dict

    def validate(self):
        """Validate dataset specification"""
        assert self.name, "Dataset name cannot be empty"
        assert self.weight > 0, "Weight must be positive"
        assert isinstance(self.config, dict), "Config must be a dictionary"


@dataclass
class CompositeConfig:
    """Configuration for CompositeDataset"""

    size: int = 500
    seed: Optional[int] = None
    datasets: List[DatasetSpec] = None

    def validate(self):
        """Validate configuration parameters"""
        assert self.size > 0, "size must be positive"
        assert self.datasets, "Must specify at least one dataset"
        assert len(self.datasets) > 0, "Must specify at least one dataset"

        # Validate each dataset spec
        for ds in self.datasets:
            ds.validate()

    @classmethod
    def from_yaml(cls, yaml_path: str) -> 'CompositeConfig':
        """Load configuration from YAML file"""
        with open(yaml_path, 'r') as f:
            data = yaml.safe_load(f)
            # Convert dataset specs to DatasetSpec objects
            if 'datasets' in data:
                data['datasets'] = [DatasetSpec(**ds) for ds in data['datasets']]
        return cls(**data)


class CompositeDataset(ProceduralDataset):
    """A dataset that combines multiple datasets with weighted sampling"""

    def __init__(self, config: CompositeConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

        # Initialize sub-datasets with incremented seeds
        self.datasets = {}
        self.weights = []
        total_weight = 0.0

        for i, ds_spec in enumerate(config.datasets):
            # Create dataset with derived seed
            ds_config = ds_spec.config.copy()
            if 'seed' not in ds_config:
                ds_config['seed'] = self.seed + i + 1
            if 'size' not in ds_config:
                ds_config['size'] = self.size

            self.datasets[ds_spec.name] = create_dataset(ds_spec.name, **ds_config)
            total_weight += ds_spec.weight
            self.weights.append(ds_spec.weight)

        # Normalize weights
        self.weights = [w / total_weight for w in self.weights]
        self.dataset_names = [ds.name for ds in config.datasets]

    def __getitem__(self, idx: int) -> dict:
        """Generate a single dataset item by sampling from sub-datasets"""
        # Create deterministic RNG for this index
        rng = Random(self.seed + idx)

        # Sample dataset according to weights
        dataset_idx = rng.choices(range(len(self.dataset_names)), weights=self.weights, k=1)[0]
        dataset_name = self.dataset_names[dataset_idx]
        dataset = self.datasets[dataset_name]

        # Get item from selected dataset
        item = dataset[idx]

        # Add source dataset info to metadata
        item['metadata']['source_dataset'] = dataset_name
        item['metadata']['source_index'] = idx

        return item

    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
        """Forward scoring to appropriate dataset"""
        dataset_name = entry['metadata']['source_dataset']
        return self.datasets[dataset_name].score_answer(answer, entry)
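
For reference, a minimal usage sketch of the class above. It is illustrative only and assumes the chain_sum generator exercised by the tests below is available through the factory:

from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec

config = CompositeConfig(
    size=5,
    seed=42,
    datasets=[
        DatasetSpec(name="chain_sum", weight=1.0, config={"min_terms": 2, "max_terms": 4}),
    ],
)
config.validate()

dataset = CompositeDataset(config)
entry = dataset[0]
print(entry["metadata"]["source_dataset"])  # -> "chain_sum"

# Scoring is forwarded to the sub-dataset that produced the entry;
# None is just a placeholder for a candidate answer string.
score = dataset.score_answer(answer=None, entry=entry)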

reasoning_gym/factory.py (modified)

@@ -2,6 +2,7 @@ from dataclasses import is_dataclass
from typing import Dict, Type, TypeVar
from .dataset import ProceduralDataset
from .composite import CompositeDataset, CompositeConfig

# Type variables for generic type hints
ConfigT = TypeVar("ConfigT")
@@ -10,6 +11,9 @@ DatasetT = TypeVar("DatasetT", bound=ProceduralDataset)
# Global registry of datasets
DATASETS: Dict[str, tuple[Type[ProceduralDataset], Type]] = {}

# Register composite dataset
register_dataset("composite", CompositeDataset, CompositeConfig)

def register_dataset(name: str, dataset_cls: Type[DatasetT], config_cls: Type[ConfigT]) -> None:
    """

tests/test_composite.py (new file, 113 lines)

@@ -0,0 +1,113 @@
import os
import pytest
import tempfile
import yaml

from reasoning_gym.composite import CompositeDataset, CompositeConfig, DatasetSpec


def create_test_config(tmp_path):
    """Create a test YAML config file"""
    config = {
        "size": 100,
        "seed": 42,
        "datasets": [
            {
                "name": "chain_sum",
                "weight": 0.3,
                "config": {
                    "min_terms": 2,
                    "max_terms": 4,
                }
            },
            {
                "name": "leg_counting",
                "weight": 0.7,
                "config": {
                    "min_animals": 1,
                    "max_animals": 3,
                }
            }
        ]
    }

    config_path = os.path.join(tmp_path, "test_config.yaml")
    with open(config_path, 'w') as f:
        yaml.dump(config, f)

    return config_path


def test_composite_config_validation():
    """Test configuration validation"""
    with pytest.raises(AssertionError):
        config = CompositeConfig(size=-1)
        config.validate()

    with pytest.raises(AssertionError):
        config = CompositeConfig(datasets=[])
        config.validate()


def test_composite_dataset_deterministic():
    """Test that dataset generates same items with same seed"""
    config = CompositeConfig(
        size=10,
        seed=42,
        datasets=[
            DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
        ]
    )
    dataset1 = CompositeDataset(config)
    dataset2 = CompositeDataset(config)

    for i in range(len(dataset1)):
        assert dataset1[i] == dataset2[i]


def test_composite_dataset_metadata():
    """Test that metadata includes source dataset information"""
    config = CompositeConfig(
        size=10,
        seed=42,
        datasets=[
            DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
        ]
    )
    dataset = CompositeDataset(config)
    item = dataset[0]

    assert "source_dataset" in item["metadata"]
    assert "source_index" in item["metadata"]
    assert item["metadata"]["source_dataset"] == "chain_sum"
    assert isinstance(item["metadata"]["source_index"], int)


def test_composite_dataset_weights():
    """Test that dataset weights are properly normalized"""
    config = CompositeConfig(
        size=1000,
        seed=42,
        datasets=[
            DatasetSpec("chain_sum", 2.0, {"min_terms": 2}),
            DatasetSpec("chain_sum", 3.0, {"min_terms": 3}),
        ]
    )
    dataset = CompositeDataset(config)

    assert abs(dataset.weights[0] - 0.4) < 1e-6
    assert abs(dataset.weights[1] - 0.6) < 1e-6


def test_yaml_loading(tmp_path):
    """Test loading configuration from YAML"""
    config_path = create_test_config(tmp_path)
    config = CompositeConfig.from_yaml(config_path)

    assert config.size == 100
    assert config.seed == 42
    assert len(config.datasets) == 2
    assert config.datasets[0].name == "chain_sum"
    assert config.datasets[1].name == "leg_counting"
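
For reference, the YAML shape that create_test_config serializes and CompositeConfig.from_yaml expects; a minimal round-trip sketch in which the temp-file handling is illustrative and not part of the commit:

import tempfile

from reasoning_gym.composite import CompositeConfig

YAML_TEXT = """\
size: 100
seed: 42
datasets:
  - name: chain_sum
    weight: 0.3
    config:
      min_terms: 2
      max_terms: 4
  - name: leg_counting
    weight: 0.7
    config:
      min_animals: 1
      max_animals: 3
"""

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(YAML_TEXT)

config = CompositeConfig.from_yaml(f.name)
assert [d.name for d in config.datasets] == ["chain_sum", "leg_counting"]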