mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
feat: Add CompositeDataset for weighted multi-dataset sampling
This commit is contained in:
parent
0561844779
commit
f07b6b7f61
3 changed files with 220 additions and 0 deletions
103
reasoning_gym/composite.py
Normal file
103
reasoning_gym/composite.py
Normal file
|
|
@ -0,0 +1,103 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from random import Random
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from .dataset import ProceduralDataset
|
||||||
|
from .factory import create_dataset
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DatasetSpec:
    """Describes one member dataset of a composite.

    Attributes:
        name: registry name used to instantiate the dataset
        weight: relative (unnormalized) sampling weight; must be positive
        config: keyword arguments forwarded to the dataset factory
    """

    name: str
    weight: float
    config: dict

    def validate(self):
        """Check that the spec is well-formed.

        Raises:
            AssertionError: if the name is empty, the weight is not
                positive, or the config is not a dict.
        """
        assert self.name, "Dataset name cannot be empty"
        assert self.weight > 0, "Weight must be positive"
        assert isinstance(self.config, dict), "Config must be a dictionary"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CompositeConfig:
    """Configuration for CompositeDataset.

    Attributes:
        size: number of items the composite dataset exposes
        seed: optional RNG seed; member-dataset seeds are derived from it
        datasets: specs of the member datasets (must be a non-empty list)
    """

    size: int = 500
    seed: Optional[int] = None
    # None is only a "not provided" placeholder; validate() rejects it.
    # (Annotation fixed to Optional[...] to match the None default.)
    datasets: Optional[List[DatasetSpec]] = None

    def validate(self):
        """Validate configuration parameters.

        Raises:
            AssertionError: if size is not positive, no datasets are
                configured, or any member spec is invalid.
        """
        assert self.size > 0, "size must be positive"
        assert self.datasets, "Must specify at least one dataset"
        assert len(self.datasets) > 0, "Must specify at least one dataset"

        # Validate each dataset spec
        for ds in self.datasets:
            ds.validate()

    @classmethod
    def from_yaml(cls, yaml_path: str) -> 'CompositeConfig':
        """Load configuration from a YAML file.

        The YAML document's top-level mapping must match this dataclass;
        entries under ``datasets`` are converted to DatasetSpec objects.
        Uses yaml.safe_load, so only plain YAML types are accepted.
        """
        with open(yaml_path, 'r') as f:
            data = yaml.safe_load(f)

        # Convert dataset specs to DatasetSpec objects
        if 'datasets' in data:
            data['datasets'] = [DatasetSpec(**ds) for ds in data['datasets']]

        return cls(**data)
|
||||||
|
|
||||||
|
|
||||||
|
class CompositeDataset(ProceduralDataset):
    """Combines several registered datasets, sampling among them by
    normalized per-dataset weights."""

    def __init__(self, config: CompositeConfig):
        """Instantiate every member dataset described by ``config``.

        Members without an explicit 'seed'/'size' in their config inherit
        a seed derived from the composite seed and the composite size.
        """
        super().__init__(config=config, seed=config.seed, size=config.size)

        # Member datasets keyed by registry name, plus their raw weights.
        # NOTE(review): duplicate spec names collide in this dict — the
        # last spec with a given name wins; verify against callers.
        self.datasets = {}
        self.weights = []
        total_weight = 0.0

        for offset, spec in enumerate(config.datasets):
            # Copy so the spec's own config dict is never mutated.
            sub_cfg = spec.config.copy()
            # Give each member a distinct deterministic seed unless one
            # was provided explicitly.
            if 'seed' not in sub_cfg:
                sub_cfg['seed'] = self.seed + offset + 1
            if 'size' not in sub_cfg:
                sub_cfg['size'] = self.size

            self.datasets[spec.name] = create_dataset(spec.name, **sub_cfg)
            total_weight += spec.weight
            self.weights.append(spec.weight)

        # Normalize weights so they sum to 1.
        self.weights = [w / total_weight for w in self.weights]
        self.dataset_names = [spec.name for spec in config.datasets]

    def __getitem__(self, idx: int) -> dict:
        """Return the item at ``idx``, produced by a weighted-random member.

        The member choice is a deterministic function of (seed, idx), so
        repeated access yields the same item regardless of access order.
        """
        rng = Random(self.seed + idx)

        # Weighted draw of the member dataset for this index.
        chosen = rng.choices(range(len(self.dataset_names)), weights=self.weights, k=1)[0]
        source_name = self.dataset_names[chosen]
        source = self.datasets[source_name]

        # Delegate item generation to the selected member at the same index.
        item = source[idx]

        # Record provenance so score_answer() can route back to the member.
        item['metadata']['source_dataset'] = source_name
        item['metadata']['source_index'] = idx

        return item

    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
        """Delegate scoring to the member dataset that produced ``entry``."""
        source_name = entry['metadata']['source_dataset']
        return self.datasets[source_name].score_answer(answer, entry)
|
||||||
|
|
@ -2,6 +2,7 @@ from dataclasses import is_dataclass
|
||||||
from typing import Dict, Type, TypeVar
|
from typing import Dict, Type, TypeVar
|
||||||
|
|
||||||
from .dataset import ProceduralDataset
|
from .dataset import ProceduralDataset
|
||||||
|
from .composite import CompositeDataset, CompositeConfig
|
||||||
|
|
||||||
# Type variables for generic type hints
|
# Type variables for generic type hints
|
||||||
ConfigT = TypeVar("ConfigT")
|
ConfigT = TypeVar("ConfigT")
|
||||||
|
|
@ -10,6 +11,9 @@ DatasetT = TypeVar("DatasetT", bound=ProceduralDataset)
|
||||||
# Global registry of datasets
|
# Global registry of datasets
|
||||||
DATASETS: Dict[str, tuple[Type[ProceduralDataset], Type]] = {}
|
DATASETS: Dict[str, tuple[Type[ProceduralDataset], Type]] = {}
|
||||||
|
|
||||||
|
# Register composite dataset
|
||||||
|
register_dataset("composite", CompositeDataset, CompositeConfig)
|
||||||
|
|
||||||
|
|
||||||
def register_dataset(name: str, dataset_cls: Type[DatasetT], config_cls: Type[ConfigT]) -> None:
|
def register_dataset(name: str, dataset_cls: Type[DatasetT], config_cls: Type[ConfigT]) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
113
tests/test_composite.py
Normal file
113
tests/test_composite.py
Normal file
|
|
@ -0,0 +1,113 @@
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import tempfile
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from reasoning_gym.composite import CompositeDataset, CompositeConfig, DatasetSpec
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_config(tmp_path):
    """Write a two-member composite YAML config under tmp_path; return its path."""
    chain_sum_spec = {
        "name": "chain_sum",
        "weight": 0.3,
        "config": {"min_terms": 2, "max_terms": 4},
    }
    leg_counting_spec = {
        "name": "leg_counting",
        "weight": 0.7,
        "config": {"min_animals": 1, "max_animals": 3},
    }
    config = {
        "size": 100,
        "seed": 42,
        "datasets": [chain_sum_spec, leg_counting_spec],
    }

    config_path = os.path.join(tmp_path, "test_config.yaml")
    with open(config_path, 'w') as f:
        yaml.dump(config, f)
    return config_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_composite_config_validation():
    """Invalid sizes and empty dataset lists must be rejected by validate()."""
    bad_size = CompositeConfig(size=-1)
    with pytest.raises(AssertionError):
        bad_size.validate()

    no_datasets = CompositeConfig(datasets=[])
    with pytest.raises(AssertionError):
        no_datasets.validate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_composite_dataset_deterministic():
    """Two datasets built from an identical config must yield identical items."""
    spec = DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})
    config = CompositeConfig(size=10, seed=42, datasets=[spec])

    first = CompositeDataset(config)
    second = CompositeDataset(config)

    for idx in range(len(first)):
        assert first[idx] == second[idx]
|
||||||
|
|
||||||
|
|
||||||
|
def test_composite_dataset_metadata():
    """Every item must carry provenance metadata naming its source dataset."""
    config = CompositeConfig(
        size=10,
        seed=42,
        datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})],
    )

    item = CompositeDataset(config)[0]
    metadata = item["metadata"]

    assert "source_dataset" in metadata
    assert "source_index" in metadata
    assert metadata["source_dataset"] == "chain_sum"
    assert isinstance(metadata["source_index"], int)
|
||||||
|
|
||||||
|
|
||||||
|
def test_composite_dataset_weights():
    """Raw weights 2:3 must be normalized to 0.4 and 0.6."""
    # NOTE(review): both specs share the name "chain_sum", so the second
    # overwrites the first in CompositeDataset.datasets (dict keyed by
    # name); the weights list still holds one entry per spec.
    config = CompositeConfig(
        size=1000,
        seed=42,
        datasets=[
            DatasetSpec("chain_sum", 2.0, {"min_terms": 2}),
            DatasetSpec("chain_sum", 3.0, {"min_terms": 3}),
        ],
    )

    dataset = CompositeDataset(config)
    for actual, expected in zip(dataset.weights, (0.4, 0.6)):
        assert abs(actual - expected) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_yaml_loading(tmp_path):
    """Round-trip a config through YAML and check every field survives."""
    config_path = create_test_config(tmp_path)
    config = CompositeConfig.from_yaml(config_path)

    assert config.size == 100
    assert config.seed == 42
    assert len(config.datasets) == 2
    assert [spec.name for spec in config.datasets] == ["chain_sum", "leg_counting"]
|
||||||
Loading…
Add table
Add a link
Reference in a new issue