reasoning-gym/reasoning_gym/composite.py
2025-02-04 19:17:34 +01:00

110 lines
3.6 KiB
Python

from dataclasses import dataclass
from random import Random
from typing import Any, Dict, List, Optional
import yaml
from .dataset import ProceduralDataset
from .factory import create_dataset, register_dataset
@dataclass
class DatasetSpec:
"""Specification for a single dataset within the composite"""
name: str
weight: float
config: dict
def validate(self):
"""Validate dataset specification"""
assert self.name, "Dataset name cannot be empty"
assert self.weight > 0, "Weight must be positive"
assert isinstance(self.config, dict), "Config must be a dictionary"
@dataclass
class CompositeConfig:
"""Configuration for CompositeDataset"""
size: int = 500
seed: Optional[int] = None
datasets: List[DatasetSpec] = None
def validate(self):
"""Validate configuration parameters"""
assert self.size > 0, "size must be positive"
assert self.datasets, "Must specify at least one dataset"
assert len(self.datasets) > 0, "Must specify at least one dataset"
# Validate each dataset spec
for ds in self.datasets:
ds.validate()
@classmethod
def from_yaml(cls, yaml_path: str) -> "CompositeConfig":
"""Load configuration from YAML file"""
with open(yaml_path, "r") as f:
data = yaml.safe_load(f)
# Convert dataset specs to DatasetSpec objects
if "datasets" in data:
data["datasets"] = [DatasetSpec(**ds) for ds in data["datasets"]]
return cls(**data)
class CompositeDataset(ProceduralDataset):
"""A dataset that combines multiple datasets with weighted sampling"""
def __init__(self, config: CompositeConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
# Initialize sub-datasets with incremented seeds
self.datasets = {}
self.weights = []
total_weight = 0.0
for i, ds_spec in enumerate(config.datasets):
# Create dataset with derived seed
ds_config = ds_spec.config.copy()
if "seed" not in ds_config:
ds_config["seed"] = self.seed + i + 1
if "size" not in ds_config:
ds_config["size"] = self.size
self.datasets[ds_spec.name] = create_dataset(ds_spec.name, **ds_config)
total_weight += ds_spec.weight
self.weights.append(ds_spec.weight)
# Normalize weights
self.weights = [w / total_weight for w in self.weights]
self.dataset_names = [ds.name for ds in config.datasets]
def __getitem__(self, idx: int) -> dict:
"""Generate a single dataset item by sampling from sub-datasets"""
# Create deterministic RNG for this index
rng = Random(self.seed + idx)
# Sample dataset according to weights
dataset_idx = rng.choices(range(len(self.dataset_names)), weights=self.weights, k=1)[0]
dataset_name = self.dataset_names[dataset_idx]
dataset = self.datasets[dataset_name]
# Get item from selected dataset
item = dataset[idx]
# Add source dataset info to metadata
item["metadata"]["source_dataset"] = dataset_name
item["metadata"]["source_index"] = idx
return item
def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
"""Forward scoring to appropriate dataset"""
dataset_name = entry["metadata"]["source_dataset"]
return self.datasets[dataset_name].score_answer(answer, entry)
# Register the dataset
register_dataset("composite", CompositeDataset, CompositeConfig)