mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
112 lines
4.5 KiB
Python
112 lines
4.5 KiB
Python
"""Experiment class combining dataset, scoreboard and curriculum."""
|
|
|
|
from typing import Any, Literal, Optional
|
|
|
|
from reasoning_gym.coaching.base_curriculum import CurriculumContext
|
|
|
|
from ..composite import CompositeConfig, CompositeDataset, DatasetSpec
|
|
from ..factory import create_curriculum
|
|
from ..version_manager import DatasetVersionManager
|
|
from .curriculum_config import CurriculumExperimentConfig
|
|
from .score_board import ScoreBoard
|
|
|
|
|
|
class Experiment:
    """A named composite dataset paired with a score board.

    The experiment serves dataset entries from ``composite`` and records
    every scored answer on ``score_board``.
    """

    def __init__(self, name: str, composite: CompositeDataset):
        """Store the experiment name and dataset, and start an empty score board.

        Args:
            name: Name of the experiment.
            composite: Composite dataset that entries are served from and
                answers are scored against.
        """
        self.name = name
        self.composite = composite
        self.score_board = ScoreBoard()

    def get_dataset_entry(self, index: int) -> dict:
        """Return the composite dataset entry at ``index``."""
        return self.composite[index]

    def score_answer_with_id(
        self, answer: Optional[str], entry_id: str, conversation: Optional[list[dict]] = None
    ) -> float:
        """Score ``answer`` for the entry identified by ``entry_id`` and record it.

        Args:
            answer: The answer to score (may be ``None``).
            entry_id: Identifier resolved through ``composite.resolve_entry_id``.
            conversation: Optional conversation forwarded to the score board
                together with the score.

        Returns:
            The score produced by the resolved dataset's ``score_answer``.
        """
        dataset, index, dataset_name = self.composite.resolve_entry_id(entry_id)
        entry = dataset[index]
        score = dataset.score_answer(answer, entry)

        # Only the difficulty and source dataset of the entry's metadata are
        # recorded next to the score.
        metadata = entry["metadata"]
        score_board_metadata = {"difficulty": metadata["difficulty"], "source_dataset": metadata["source_dataset"]}
        self.score_board.add_score(dataset_name, score, score_board_metadata, conversation)
        return score

    @classmethod
    def create(cls, name: str, config: CompositeConfig) -> "Experiment":
        """Create a new experiment from a configuration.

        Args:
            name: Name of the experiment.
            config: Composite configuration used to build the dataset.

        Returns:
            A new instance of ``cls`` wrapping the freshly built composite dataset.
        """
        version_manager = DatasetVersionManager()
        composite = CompositeDataset(config, version_manager=version_manager)
        # __init__ names this parameter ``composite``; passing it as
        # ``dataset=`` raised TypeError (unexpected keyword argument).
        return cls(name=name, composite=composite)
|
|
|
|
|
|
class CurriculumExperiment(Experiment):
    """Experiment whose composite dataset is built from per-dataset curricula."""

    def __init__(
        self,
        name: str,
        config: CurriculumExperimentConfig,
        size: int,
        context: Optional[CurriculumContext] = None,
        seed: Optional[int] = None,
    ):
        """Build one curriculum per configured dataset and the resulting composite.

        Args:
            name: Name of the experiment
            config: Configuration specifying datasets and their attribute levels
            size: Number of examples to generate
            context: Optional curriculum context handed to each curriculum's
                ``generate_configuration`` call
            seed: Random seed for reproducibility
        """
        # Curricula are kept per dataset name so update_difficulty can find them.
        self.curricula = {}
        specs = []

        for dataset_name, attr_config in config.curricula.items():
            curriculum = create_curriculum(dataset_name)
            self.curricula[dataset_name] = curriculum

            levels = attr_config.attribute_levels

            # The "*" entry is applied first, to every attribute of the curriculum.
            if "*" in levels:
                wildcard_level = levels["*"]
                for attr_name in curriculum.attributes:
                    curriculum.set_attr_level(attr_name, wildcard_level)

            # Explicitly named attributes come second, so they override "*".
            for attr_name, level in levels.items():
                if attr_name == "*":
                    continue
                curriculum.set_attr_level(attr_name, level)

            # Turn the curriculum state into a dataset config and wrap it in a spec.
            dataset_config = curriculum.generate_configuration(context=context)
            specs.append(
                DatasetSpec(
                    name=dataset_name,
                    weight=attr_config.weight,
                    config=dataset_config.__dict__,
                )
            )

        # Assemble the composite dataset from all specs.
        composite_config = CompositeConfig(size=size, seed=seed, datasets=specs)
        version_manager = DatasetVersionManager()
        composite = CompositeDataset(config=composite_config, version_manager=version_manager)

        super().__init__(name=name, composite=composite)

        # Keep the inputs the experiment was built from.
        self.curriculum_config = config
        self.context = context

    def update_difficulty(self, dataset_name: str, method: Literal["increment", "decrement"]):
        """Raise or lower the global level of one dataset's curriculum.

        Args:
            dataset_name: Key of the curriculum in ``self.curricula``.
            method: ``"increment"`` or ``"decrement"``.

        Raises:
            ValueError: If ``method`` is neither ``"increment"`` nor ``"decrement"``.
        """
        # Validate before touching the curricula mapping.
        if method not in ("increment", "decrement"):
            raise ValueError(f"Invalid method: {method}")

        curriculum = self.curricula[dataset_name]
        if method == "increment":
            curriculum.increment_global_level()
        else:
            curriculum.decrement_global_level()

        # NOTE(review): the value returned by get_global_level() is forwarded
        # unchanged as the config argument - confirm update_dataset_config
        # accepts it in that form.
        config = curriculum.get_global_level()
        self.composite.update_dataset_config(dataset_name, config)
|