reasoning-gym/reasoning_gym/composite.py
Andreas Köpf e2702092f4
reasoning-gym-server & cli tool (#154)
* feat: Add initial server structure with configuration, registry, and middleware

* feat: Add chain_sum dataset to experiment registry test

* fix: Update test_registry to use DatasetSpec for composite config validation

* refactor: Update Pydantic config to use json_schema_extra and ConfigDict

* feat: Add Pydantic models for API request/response data

* feat: Implement basic experiment management endpoints with tests

* feat: Implement composite configuration endpoints for experiments

* fix: Add missing DatasetConfigUpdate import in server.py

* refactor: Update dataset config update method to properly merge config updates

* fix: Correctly retrieve current dataset config in composite endpoint

* feat: Add basic CLI structure with experiments and config commands

* feat: Add initial CLI tool with basic experiment management commands

* refactor: Reorganize CLI package structure and fix import paths

* refactor: Implement initial CLI commands for experiment management

* feat: Implement HTTP client for Reasoning Gym server in RGC CLI tool

* fix: Move print statements inside try block to resolve SyntaxError

* fix: Resolve SyntaxError in edit_config function by adding missing except block

* feat: Add default app instance in server module for easier uvicorn startup

* docs: Add README.md with server and RGC tool documentation

* remove unused files

* refactor: Remove unsupported type annotation in registry.py

* refactor: Move ExperimentRegistry to coaching module and add Experiment class

* fix: Add missing CompositeDataset import in test_registry.py

* refactor: Implement lazy ASGI app creation for server initialization

* feat: Add health check command to RGC CLI for server connection

* feat: Add version tracking support to CompositeDataset

* feat: Add DatasetVersionManager for tracking dataset versions

* feat: Add entry_id metadata and score_answer_with_id method to CompositeDataset

* feat: Add entry_id metadata combining version and index

* fix: Resolve undefined variable by storing version_id before use

* test: Add comprehensive unit tests for score_answer_with_id() function

* test: Add comprehensive version tracking test for dataset config updates

* feat: Validate dataset weights are positive in CompositeDataset initialization

* feat: Add weight update and normalization methods to CompositeDataset

* refactor: Centralize weight normalization in CompositeDataset and allow zero-weight datasets

* feat: Add negative weight validation to CompositeDataset constructor

* feat: Add duplicate dataset name check in CompositeDataset and update test

* refactor: Move duplicate dataset name check inside dataset iteration loop

* refactor: Update CompositeDataset weight management to use config as source of truth

* refactor: Move duplicate dataset name check to CompositeConfig.validate()

* test: Update composite dataset weight test assertions and validation

* feat: Add methods to add and remove datasets in CompositeDataset

* refactor: Remove weight normalization and use unnormalized weights directly

* refactor: Remove redundant total weight check in update_dataset_weights

* feat: Add batch generation and scoring endpoints to server

* fix: Import BatchEntry in server.py to resolve undefined name error

* refactor: Update ReasoningGymDataset to use server for batch generation and scoring

* fix: Add missing List and Dict type imports

* feat: Add get_batch() and score_outputs() methods to RGClient

* test: Add unit tests for generate_batch and score_outputs endpoints

* refactor: Add DatasetVersionManager to Experiment class and CompositeDataset constructor

* feat: Add validation for base_index and batch_size in generate_batch endpoint

* refactor: Remove unused BatchRequest type from imports

* refactor: Convert models to use Pydantic exclusively

* test: Update scoring endpoint tests to use correct request model format

* refactor: Rename ScoreItem to AnswerItem and update related code

* feat: Update scoring endpoint to return ordered ScoringResponse with scores and entry_ids

* fix: Add missing ScoringResponse import in server.py

* move verl ppo sample with server into own file

* refactor: Use Pydantic models for get_batch() and score_outputs() in RGClient

* refactor: Update client methods to use Pydantic models for type safety

* refactor: Use Pydantic models for experiment and dataset config operations

* refactor: Clean up duplicate methods and improve error handling in main.py

* first bits of rg server use for verl

* refactor: Optimize scoring with single HTTP request in _score_output

* fix: Correct experiment creation with ExperimentCreate object

* grpo tests with server
2025-02-19 22:41:33 +01:00

286 lines
10 KiB
Python

from dataclasses import dataclass, replace
from random import Random
from typing import Any, Dict, List, Optional

import yaml

from .dataset import ProceduralDataset
from .factory import create_dataset, register_dataset
from .version_manager import DatasetVersionManager

@dataclass
class DatasetSpec:
    """Specification for a single dataset within the composite"""

    name: str
    weight: float
    config: dict

    def validate(self):
        """Validate dataset specification"""
        assert self.name, "Dataset name cannot be empty"
        assert self.weight > 0, "Weight must be positive"
        assert isinstance(self.config, dict), "Config must be a dictionary"
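
# Example spec (illustrative; "chain_sum" is a dataset name used elsewhere in
# this repo, the config key shown is an assumption):
#   DatasetSpec(name="chain_sum", weight=1.0, config={"min_terms": 2})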

@dataclass
class CompositeConfig:
    """Configuration for CompositeDataset"""

    size: int = 500
    seed: Optional[int] = None
    datasets: Optional[List[DatasetSpec]] = None

    def validate(self):
        """Validate configuration parameters"""
        assert self.size > 0, "size must be positive"
        assert self.datasets, "Must specify at least one dataset"

        # Check for duplicate dataset names
        dataset_names = [ds.name for ds in self.datasets]
        if len(dataset_names) != len(set(dataset_names)):
            raise ValueError("Duplicate dataset names are not allowed in CompositeDataset")

        # Validate each dataset spec
        for ds in self.datasets:
            ds.validate()

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "CompositeConfig":
        """Load configuration from YAML file"""
        with open(yaml_path, "r") as f:
            data = yaml.safe_load(f)

        # Convert dataset specs to DatasetSpec objects
        if "datasets" in data:
            data["datasets"] = [DatasetSpec(**ds) for ds in data["datasets"]]

        return cls(**data)
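
# Example YAML accepted by CompositeConfig.from_yaml (illustrative; the
# sub-dataset config keys are assumptions):
#
#   size: 1000
#   seed: 42
#   datasets:
#     - name: chain_sum
#       weight: 1.0
#       config:
#         min_terms: 2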

class CompositeDataset(ProceduralDataset):
    """A dataset that combines multiple datasets with weighted sampling"""

    def __init__(self, config: CompositeConfig, version_manager: Optional[DatasetVersionManager] = None):
        super().__init__(config=config, seed=config.seed, size=config.size)
        self.version_manager = version_manager
        self.dataset_versions = {}  # dataset_name -> version_id

        # Initialize sub-datasets with incremented seeds
        self.datasets = {}
        self.weights = []
        for i, ds_spec in enumerate(config.datasets):
            # Create dataset with derived seed
            ds_config = ds_spec.config.copy()
            if "seed" not in ds_config:
                ds_config["seed"] = self.seed + i + 1
            if "size" not in ds_config:
                ds_config["size"] = self.size

            if ds_spec.weight < 0:
                raise ValueError(f"Dataset '{ds_spec.name}' has invalid weight {ds_spec.weight}, must be non-negative")

            dataset = create_dataset(ds_spec.name, **ds_config)
            self.datasets[ds_spec.name] = dataset

            # Register version if tracking enabled
            if version_manager is not None:
                version_id = version_manager.register_dataset(ds_spec.name, dataset)
                self.dataset_versions[ds_spec.name] = version_id

            self.weights.append(ds_spec.weight)  # Store unnormalized weights directly

        self.dataset_names = [ds.name for ds in config.datasets]

    def __getitem__(self, idx: int) -> dict:
        """Generate a single dataset item by sampling from sub-datasets"""
        # Create deterministic RNG for this index
        rng = Random(self.seed + idx)

        # Sample dataset according to weights
        dataset_idx = rng.choices(range(len(self.dataset_names)), weights=self.weights, k=1)[0]
        dataset_name = self.dataset_names[dataset_idx]
        dataset = self.datasets[dataset_name]

        # Get item from selected dataset
        item = dataset[idx]

        # Add source dataset info to metadata
        item["metadata"]["source_dataset"] = dataset_name
        item["metadata"]["source_index"] = idx

        # Add version info if tracking enabled
        if self.version_manager is not None:
            version_id = self.dataset_versions[dataset_name]
            item["metadata"]["version_id"] = version_id
            # Add entry_id combining version and index
            item["metadata"]["entry_id"] = f"{version_id}.{idx}"

        return item
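
    # Illustrative shape of the metadata added above (values are examples,
    # not real output):
    #   {"source_dataset": "chain_sum", "source_index": 17,
    #    "version_id": 3, "entry_id": "3.17"}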

    def update_dataset_config(self, dataset_name: str, config_updates: Dict[str, Any]) -> None:
        """Update configuration of a specific dataset

        Args:
            dataset_name: Name of the dataset to update
            config_updates: Dictionary of configuration parameters to update

        Raises:
            KeyError: If dataset_name is not found
            TypeError: If a config parameter doesn't exist on the dataset's config
        """
        if dataset_name not in self.datasets:
            raise KeyError(f"Dataset '{dataset_name}' not found")

        dataset = self.datasets[dataset_name]

        # Update the current config
        new_config = replace(dataset.config, **config_updates)

        # Validate new config
        new_config.validate()

        # Create new dataset instance with updated config
        dataset_cls = dataset.__class__
        new_dataset = dataset_cls(new_config)
        self.datasets[dataset_name] = new_dataset

        # Register new version if tracking enabled
        if self.version_manager is not None:
            version_id = self.version_manager.register_dataset(dataset_name, new_dataset)
            self.dataset_versions[dataset_name] = version_id
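
    # Illustrative: shrinking a sub-dataset re-instantiates it and, when a
    # version manager is attached, registers a new version id (`size` is a
    # field every sub-dataset config receives in __init__):
    #   composite.update_dataset_config("chain_sum", {"size": 200})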

    def update_dataset_weight(self, dataset_name: str, weight: float) -> None:
        """Update weight for a specific dataset in the configuration

        Args:
            dataset_name: Name of the dataset to update
            weight: New weight value

        Raises:
            KeyError: If dataset_name not found
            ValueError: If weight is negative
        """
        if dataset_name not in self.datasets:
            raise KeyError(f"Dataset '{dataset_name}' not found")
        if weight < 0:
            raise ValueError(f"Weight must be non-negative, got {weight}")

        # Update weight in both config and weights list
        for i, ds_spec in enumerate(self.config.datasets):
            if ds_spec.name == dataset_name:
                ds_spec.weight = weight
                self.weights[i] = weight
                break
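
    # Illustrative: weights are unnormalized, so setting one to 0.0 simply
    # stops that dataset from being sampled without renormalizing the rest:
    #   composite.update_dataset_weight("chain_sum", 0.0)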

    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
        """Forward scoring to appropriate dataset"""
        dataset_name = entry["metadata"]["source_dataset"]
        return self.datasets[dataset_name].score_answer(answer, entry)

    def add_dataset(self, dataset_spec: DatasetSpec) -> None:
        """Add a new dataset to the composite

        Args:
            dataset_spec: Specification for the dataset to add

        Raises:
            ValueError: If dataset name already exists
        """
        # Validate spec
        dataset_spec.validate()

        # Check for duplicate name
        if dataset_spec.name in self.datasets:
            raise ValueError(f"Dataset '{dataset_spec.name}' already exists in composite")

        # Create dataset with derived seed
        ds_config = dataset_spec.config.copy()
        if "seed" not in ds_config:
            ds_config["seed"] = self.seed + len(self.datasets) + 1
        if "size" not in ds_config:
            ds_config["size"] = self.size

        # Create and add dataset
        dataset = create_dataset(dataset_spec.name, **ds_config)
        self.datasets[dataset_spec.name] = dataset

        # Register version if tracking enabled
        if self.version_manager is not None:
            version_id = self.version_manager.register_dataset(dataset_spec.name, dataset)
            self.dataset_versions[dataset_spec.name] = version_id

        # Add to config and update internal state
        self.config.datasets.append(dataset_spec)
        self.dataset_names.append(dataset_spec.name)
        self.weights.append(dataset_spec.weight)  # Use weight directly from spec

    def remove_dataset(self, dataset_name: str) -> None:
        """Remove a dataset from the composite

        Args:
            dataset_name: Name of the dataset to remove

        Raises:
            KeyError: If dataset not found
            ValueError: If trying to remove last dataset
        """
        if dataset_name not in self.datasets:
            raise KeyError(f"Dataset '{dataset_name}' not found")
        if len(self.datasets) <= 1:
            raise ValueError("Cannot remove last dataset from composite")

        # Remove from all internal structures
        del self.datasets[dataset_name]
        if self.version_manager is not None:
            del self.dataset_versions[dataset_name]

        # Remove from config
        self.config.datasets = [ds for ds in self.config.datasets if ds.name != dataset_name]

        # Update internal state
        idx = self.dataset_names.index(dataset_name)
        self.dataset_names.pop(idx)
        self.weights.pop(idx)
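
    # Illustrative round trip ("word_sorting" is a hypothetical second
    # dataset name used only for this example):
    #   composite.add_dataset(DatasetSpec(name="word_sorting", weight=0.5, config={}))
    #   composite.remove_dataset("word_sorting")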

    def score_answer_with_id(self, answer: Optional[str], entry_id: str) -> float:
        """Score an answer using an entry_id to lookup the original entry

        Args:
            answer: The answer to score
            entry_id: String in format "version_id.index"

        Returns:
            Score between 0 and 1

        Raises:
            RuntimeError: If no version manager is attached
            ValueError: If entry_id format is invalid
            KeyError: If version not found in version manager
        """
        if self.version_manager is None:
            raise RuntimeError("Version manager required for scoring with entry_id")

        try:
            version_id, index = map(int, entry_id.split("."))
        except ValueError:
            raise ValueError(f"Invalid entry_id format: {entry_id}, expected 'version_id.index'")

        # Get dataset from version manager
        dataset_info = self.version_manager.get_dataset(version_id)
        if dataset_info is None:
            raise KeyError(f"Version {version_id} not found in version manager")
        dataset_name, dataset = dataset_info

        # Get entry from dataset
        entry = dataset[index]

        # Score answer using dataset's scoring function
        return dataset.score_answer(answer, entry)

# Register the dataset
register_dataset("composite", CompositeDataset, CompositeConfig)
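
if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: build a composite over
    # a single "chain_sum" dataset with version tracking, draw one item, then
    # score its own reference answer via entry_id. Assumes entries expose
    # "question" and "answer" keys as elsewhere in this repo, and that
    # DatasetVersionManager takes no constructor arguments.
    manager = DatasetVersionManager()
    config = CompositeConfig(
        size=10,
        seed=42,
        datasets=[DatasetSpec(name="chain_sum", weight=1.0, config={})],
    )
    composite = CompositeDataset(config, version_manager=manager)

    item = composite[0]
    print("question:", item["question"])
    print("entry_id:", item["metadata"]["entry_id"])  # "<version_id>.<index>"

    # Scoring by entry_id looks up the exact dataset version that produced the item
    score = composite.score_answer_with_id(item["answer"], item["metadata"]["entry_id"])
    print("score:", score)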