Mirror of https://github.com/open-thought/reasoning-gym.git, last synced 2026-04-19 12:58:07 +00:00.
* feat: Add Coach and ScoreBoard classes for performance tracking and difficulty adjustment * feat: Add GroupedScores class to wrap aggregated scores * refactor: Create ScoreStats class with tuple-based score statistics * feat: Add unit test for Coach with CompositeDataset and multiple datasets * fix: Add difficulty metadata to leg counting dataset * feat: Add clear() method to ScoreBoard to reset all stored data * feat: Add __len__ method to ScoreBoard to return number of scores * feat: Add update_dataset_config method to CompositeDataset * cleanup __init__ & imports
138 lines
4.6 KiB
Python
138 lines
4.6 KiB
Python
from dataclasses import dataclass
|
|
from random import Random
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import yaml
|
|
|
|
from .dataset import ProceduralDataset
|
|
from .factory import create_dataset, register_dataset
|
|
|
|
|
|
@dataclass
class DatasetSpec:
    """Describes one member dataset of a composite.

    Attributes:
        name: registry name used to instantiate the dataset via the factory
        weight: relative sampling weight (normalized later by the composite)
        config: keyword arguments forwarded to the dataset's constructor
    """

    name: str
    weight: float
    config: dict

    def validate(self):
        """Check that this spec is well-formed.

        Raises:
            AssertionError: if the name is empty, the weight is not
                strictly positive, or the config is not a dict.
        """
        assert self.name, "Dataset name cannot be empty"
        assert self.weight > 0, "Weight must be positive"
        assert isinstance(self.config, dict), "Config must be a dictionary"
@dataclass
class CompositeConfig:
    """Configuration for CompositeDataset.

    Attributes:
        size: number of items the composite dataset exposes
        seed: base RNG seed; sub-dataset seeds are derived from it
        datasets: the member DatasetSpec entries (required; defaults to
            None only so the dataclass has no mutable default — validate()
            rejects a missing or empty list)
    """

    size: int = 500
    seed: Optional[int] = None
    # Optional[...] rather than a bare List default: None is the declared
    # sentinel, and a mutable [] default would be shared across instances.
    datasets: Optional[List[DatasetSpec]] = None

    def validate(self):
        """Validate configuration parameters.

        Raises:
            AssertionError: if size is not positive, no datasets are
                specified, or any member spec is invalid.
        """
        assert self.size > 0, "size must be positive"
        # Truthiness covers both None and the empty list, so a separate
        # len() assertion would be redundant.
        assert self.datasets, "Must specify at least one dataset"

        # Validate each dataset spec
        for ds in self.datasets:
            ds.validate()

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "CompositeConfig":
        """Load configuration from YAML file.

        Args:
            yaml_path: path to a YAML document whose top-level keys match
                this dataclass; its "datasets" entry is a list of mappings.

        Returns:
            A populated CompositeConfig instance.
        """
        with open(yaml_path, "r") as f:
            data = yaml.safe_load(f)

        # Convert dataset specs to DatasetSpec objects
        if "datasets" in data:
            data["datasets"] = [DatasetSpec(**ds) for ds in data["datasets"]]

        return cls(**data)
class CompositeDataset(ProceduralDataset):
    """A dataset that combines multiple datasets with weighted sampling"""

    def __init__(self, config: CompositeConfig):
        """Instantiate all member datasets and precompute sampling weights.

        Args:
            config: validated CompositeConfig describing the members.
        """
        super().__init__(config=config, seed=config.seed, size=config.size)

        # Initialize sub-datasets with incremented seeds so members with
        # identical configs still produce different streams.
        # NOTE(review): assumes the base class resolves self.seed to an int
        # even when config.seed is None — confirm in ProceduralDataset.
        self.datasets = {}
        self.weights = []
        total_weight = 0.0

        for i, ds_spec in enumerate(config.datasets):
            # Copy so we never mutate the caller's spec; only fill in
            # seed/size when the spec did not pin them explicitly.
            ds_config = ds_spec.config.copy()
            if "seed" not in ds_config:
                ds_config["seed"] = self.seed + i + 1
            if "size" not in ds_config:
                ds_config["size"] = self.size

            self.datasets[ds_spec.name] = create_dataset(ds_spec.name, **ds_config)
            total_weight += ds_spec.weight
            self.weights.append(ds_spec.weight)

        # Normalize weights to sum to 1.0 for random.choices sampling
        self.weights = [w / total_weight for w in self.weights]
        self.dataset_names = [ds.name for ds in config.datasets]

    def __getitem__(self, idx: int) -> dict:
        """Generate a single dataset item by sampling from sub-datasets"""
        # Create deterministic RNG for this index so repeated access to the
        # same idx always picks the same member dataset.
        rng = Random(self.seed + idx)

        # Sample dataset according to weights
        dataset_idx = rng.choices(range(len(self.dataset_names)), weights=self.weights, k=1)[0]
        dataset_name = self.dataset_names[dataset_idx]
        dataset = self.datasets[dataset_name]

        # Get item from selected dataset
        item = dataset[idx]

        # Add source dataset info to metadata so scoring can be routed back
        # to the member that produced the item.
        item["metadata"]["source_dataset"] = dataset_name
        item["metadata"]["source_index"] = idx

        return item

    def update_dataset_config(self, dataset_name: str, config_updates: Dict[str, Any]) -> None:
        """Update configuration of a specific dataset

        Args:
            dataset_name: Name of the dataset to update
            config_updates: Dictionary of configuration parameters to update

        Raises:
            KeyError: If dataset_name is not found
            AttributeError: If config parameter doesn't exist
        """
        if dataset_name not in self.datasets:
            raise KeyError(f"Dataset '{dataset_name}' not found")

        dataset = self.datasets[dataset_name]

        # Create new config with updates
        new_config = dataset.config.__class__(**vars(dataset.config))
        for key, value in config_updates.items():
            # setattr on a plain (non-slots) dataclass silently creates
            # unknown attributes, so check explicitly to honor the
            # documented AttributeError contract and catch typo'd keys.
            if not hasattr(new_config, key):
                raise AttributeError(
                    f"Config parameter '{key}' does not exist on {new_config.__class__.__name__}"
                )
            setattr(new_config, key, value)

        # Validate new config
        new_config.validate()

        # Create new dataset instance with updated config; the old instance
        # is discarded rather than mutated in place.
        dataset_cls = dataset.__class__
        self.datasets[dataset_name] = dataset_cls(new_config)

    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
        """Forward scoring to appropriate dataset

        Args:
            answer: the model's answer, possibly None
            entry: an item previously produced by __getitem__; its metadata
                carries the source_dataset key used for routing.

        Returns:
            The member dataset's score for this answer.
        """
        dataset_name = entry["metadata"]["source_dataset"]
        return self.datasets[dataset_name].score_answer(answer, entry)
# Register the composite dataset with the factory so it can be created by
# name ("composite") alongside the other procedural datasets.
register_dataset("composite", CompositeDataset, CompositeConfig)