Mirror of https://github.com/open-thought/reasoning-gym.git, synced 2026-04-19 12:58:07 +00:00.
* feat: Add initial server structure with configuration, registry, and middleware * feat: Add chain_sum dataset to experiment registry test * fix: Update test_registry to use DatasetSpec for composite config validation * refactor: Update Pydantic config to use json_schema_extra and ConfigDict * feat: Add Pydantic models for API request/response data * feat: Implement basic experiment management endpoints with tests * feat: Implement composite configuration endpoints for experiments * fix: Add missing DatasetConfigUpdate import in server.py * refactor: Update dataset config update method to properly merge config updates * fix: Correctly retrieve current dataset config in composite endpoint * feat: Add basic CLI structure with experiments and config commands * feat: Add initial CLI tool with basic experiment management commands * refactor: Reorganize CLI package structure and fix import paths * refactor: Implement initial CLI commands for experiment management * feat: Implement HTTP client for Reasoning Gym server in RGC CLI tool * fix: Move print statements inside try block to resolve SyntaxError * fix: Resolve SyntaxError in edit_config function by adding missing except block * feat: Add default app instance in server module for easier uvicorn startup * docs: Add README.md with server and RGC tool documentation * remove unused files * refactor: Remove unsupported type annotation in registry.py * refactor: Move ExperimentRegistry to coaching module and add Experiment class * fix: Add missing CompositeDataset import in test_registry.py * refactor: Implement lazy ASGI app creation for server initialization * feat: Add health check command to RGC CLI for server connection * feat: Add version tracking support to CompositeDataset * feat: Add DatasetVersionManager for tracking dataset versions * feat: Add entry_id metadata and score_answer_with_id method to CompositeDataset * feat: Add entry_id metadata combining version and index * fix: Resolve undefined variable by 
storing version_id before use * test: Add comprehensive unit tests for score_answer_with_id() function * test: Add comprehensive version tracking test for dataset config updates * feat: Validate dataset weights are positive in CompositeDataset initialization * feat: Add weight update and normalization methods to CompositeDataset * refactor: Centralize weight normalization in CompositeDataset and allow zero-weight datasets * feat: Add negative weight validation to CompositeDataset constructor * feat: Add duplicate dataset name check in CompositeDataset and update test * refactor: Move duplicate dataset name check inside dataset iteration loop * refactor: Update CompositeDataset weight management to use config as source of truth * refactor: Move duplicate dataset name check to CompositeConfig.validate() * test: Update composite dataset weight test assertions and validation * feat: Add methods to add and remove datasets in CompositeDataset * refactor: Remove weight normalization and use unnormalized weights directly * refactor: Remove redundant total weight check in update_dataset_weights * feat: Add batch generation and scoring endpoints to server * fix: Import BatchEntry in server.py to resolve undefined name error * refactor: Update ReasoningGymDataset to use server for batch generation and scoring * fix: Add missing List and Dict type imports * feat: Add get_batch() and score_outputs() methods to RGClient * test: Add unit tests for generate_batch and score_outputs endpoints * refactor: Add DatasetVersionManager to Experiment class and CompositeDataset constructor * feat: Add validation for base_index and batch_size in generate_batch endpoint * refactor: Remove unused BatchRequest type from imports * refactor: Convert models to use Pydantic exclusively * test: Update scoring endpoint tests to use correct request model format * refactor: Rename ScoreItem to AnswerItem and update related code * feat: Update scoring endpoint to return ordered ScoringResponse with 
scores and entry_ids * fix: Add missing ScoringResponse import in server.py * move verl ppo sample with server into own file * refactor: Use Pydantic models for get_batch() and score_outputs() in RGClient * refactor: Update client methods to use Pydantic models for type safety * refactor: Use Pydantic models for experiment and dataset config operations * refactor: Clean up duplicate methods and improve error handling in main.py * first bits of rg server use for verl * refactor: Optimize scoring with single HTTP request in _score_output * fix: Correct experiment creation with ExperimentCreate object * grpo tests with server
286 lines
10 KiB
Python
286 lines
10 KiB
Python
from dataclasses import dataclass, replace
|
|
from random import Random
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import yaml
|
|
|
|
from .dataset import ProceduralDataset
|
|
from .factory import create_dataset, register_dataset
|
|
from .version_manager import DatasetVersionManager
|
|
|
|
|
|
@dataclass
class DatasetSpec:
    """Describes one member dataset of a composite: its registry name, its
    sampling weight, and the keyword config used to construct it."""

    name: str
    weight: float
    config: dict

    def validate(self):
        """Sanity-check this spec.

        Raises:
            AssertionError: If the name is empty, the weight is not strictly
                positive, or the config is not a dictionary.
        """
        # Run the checks in a fixed order so the first failing field reports.
        for ok, message in (
            (bool(self.name), "Dataset name cannot be empty"),
            (self.weight > 0, "Weight must be positive"),
            (isinstance(self.config, dict), "Config must be a dictionary"),
        ):
            assert ok, message
@dataclass
class CompositeConfig:
    """Configuration for CompositeDataset.

    Attributes:
        size: Virtual size of the composite dataset (number of indexable items).
        seed: Optional RNG seed; sub-dataset seeds are derived from it.
        datasets: Specs of the member datasets to combine. Required in
            practice (validate() rejects None/empty); defaults to None only
            so the dataclass can be constructed field-by-field.
    """

    size: int = 500
    seed: Optional[int] = None
    # Annotation fixed: the default is None, so the type is Optional.
    datasets: Optional[List[DatasetSpec]] = None

    def validate(self):
        """Validate configuration parameters.

        Raises:
            AssertionError: If size is non-positive, or datasets is None/empty.
            ValueError: If two dataset specs share the same name, or a spec
                fails its own validation.
        """
        assert self.size > 0, "size must be positive"
        # A falsy value covers both None and an empty list, so the former
        # redundant len() check has been dropped.
        assert self.datasets, "Must specify at least one dataset"

        # Duplicate names would make the name -> dataset mapping ambiguous.
        dataset_names = [ds.name for ds in self.datasets]
        if len(dataset_names) != len(set(dataset_names)):
            raise ValueError("Duplicate dataset names are not allowed in CompositeDataset")

        # Delegate per-dataset field validation to each spec.
        for ds in self.datasets:
            ds.validate()

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "CompositeConfig":
        """Load a configuration from a YAML file.

        Args:
            yaml_path: Path to the YAML file.

        Returns:
            A CompositeConfig whose "datasets" entries have been parsed into
            DatasetSpec objects.
        """
        with open(yaml_path, "r") as f:
            data = yaml.safe_load(f)

        # Promote plain YAML dicts to typed DatasetSpec instances.
        if "datasets" in data:
            data["datasets"] = [DatasetSpec(**ds) for ds in data["datasets"]]

        return cls(**data)
class CompositeDataset(ProceduralDataset):
    """A dataset that combines multiple datasets with weighted sampling.

    Each index deterministically selects one sub-dataset (weighted sampling
    seeded by index) and returns that dataset's item annotated with
    provenance metadata.  When a DatasetVersionManager is supplied, every
    sub-dataset creation/recreation is registered so answers can later be
    scored from an entry_id alone (see score_answer_with_id).
    """

    def __init__(self, config: CompositeConfig, version_manager: Optional[DatasetVersionManager] = None):
        """Build one sub-dataset per spec in ``config.datasets``.

        Args:
            config: Composite configuration.
            version_manager: Optional version tracker; when given, each
                sub-dataset is registered and its version id kept in
                ``self.dataset_versions``.

        Raises:
            ValueError: If a dataset spec carries a negative weight.
        """
        super().__init__(config=config, seed=config.seed, size=config.size)
        self.version_manager = version_manager
        self.dataset_versions = {}  # dataset_name -> version_id

        # Initialize sub-datasets with incremented seeds
        self.datasets = {}  # dataset_name -> dataset instance
        self.weights = []  # unnormalized weights, parallel to dataset_names

        for i, ds_spec in enumerate(config.datasets):
            # Create dataset with derived seed/size, unless the spec pins
            # its own "seed"/"size" explicitly.
            ds_config = ds_spec.config.copy()
            if "seed" not in ds_config:
                ds_config["seed"] = self.seed + i + 1
            if "size" not in ds_config:
                ds_config["size"] = self.size

            # Only negative weights are rejected here; zero is accepted
            # (unlike DatasetSpec.validate, which requires weight > 0).
            if ds_spec.weight < 0:
                raise ValueError(f"Dataset '{ds_spec.name}' has invalid weight {ds_spec.weight}, must be non-negative")

            dataset = create_dataset(ds_spec.name, **ds_config)
            self.datasets[ds_spec.name] = dataset

            # Register version if tracking enabled
            if version_manager is not None:
                version_id = version_manager.register_dataset(ds_spec.name, dataset)
                self.dataset_versions[ds_spec.name] = version_id

            self.weights.append(ds_spec.weight)  # Store unnormalized weights directly

        # Name order mirrors config.datasets and indexes into self.weights.
        self.dataset_names = [ds.name for ds in config.datasets]

    def __getitem__(self, idx: int) -> dict:
        """Generate a single dataset item by sampling from sub-datasets.

        Deterministic: the same (seed, idx) pair always selects the same
        sub-dataset and therefore the same item.
        """
        # Create deterministic RNG for this index
        rng = Random(self.seed + idx)

        # Sample dataset according to weights
        dataset_idx = rng.choices(range(len(self.dataset_names)), weights=self.weights, k=1)[0]
        dataset_name = self.dataset_names[dataset_idx]
        dataset = self.datasets[dataset_name]

        # Get item from selected dataset; the composite index is reused as
        # the sub-dataset index.
        item = dataset[idx]

        # Add source dataset info to metadata
        item["metadata"]["source_dataset"] = dataset_name
        item["metadata"]["source_index"] = idx

        # Add version info if tracking enabled
        if self.version_manager is not None:
            version_id = self.dataset_versions[dataset_name]
            item["metadata"]["version_id"] = version_id
            # Add entry_id combining version and index ("<version_id>.<idx>")
            item["metadata"]["entry_id"] = f"{version_id}.{idx}"

        return item

    def update_dataset_config(self, dataset_name: str, config_updates: Dict[str, Any]) -> None:
        """Update configuration of a specific dataset

        The sub-dataset is rebuilt from scratch with the merged config and
        the old instance is replaced.  With tracking enabled a new version
        id is recorded; NOTE(review): presumably the version manager keeps
        the old version alive so previously issued entry_ids still resolve
        — confirm against DatasetVersionManager.

        Args:
            dataset_name: Name of the dataset to update
            config_updates: Dictionary of configuration parameters to update

        Raises:
            KeyError: If dataset_name is not found
            AttributeError: If config parameter doesn't exist
        """
        if dataset_name not in self.datasets:
            raise KeyError(f"Dataset '{dataset_name}' not found")

        dataset = self.datasets[dataset_name]

        # Update the current config (dataclasses.replace -> new config object,
        # unknown keys raise TypeError/AttributeError)
        new_config = replace(dataset.config, **config_updates)

        # Validate new config
        new_config.validate()

        # Create new dataset instance with updated config
        dataset_cls = dataset.__class__
        new_dataset = dataset_cls(new_config)
        self.datasets[dataset_name] = new_dataset

        # Register new version if tracking enabled
        if self.version_manager is not None:
            version_id = self.version_manager.register_dataset(dataset_name, new_dataset)
            self.dataset_versions[dataset_name] = version_id

    def update_dataset_weight(self, dataset_name: str, weight: float) -> None:
        """Update weight for a specific dataset in the configuration

        Keeps the DatasetSpec in ``self.config.datasets`` and the cached
        ``self.weights`` list in sync.

        Args:
            dataset_name: Name of the dataset to update
            weight: New weight value

        Raises:
            KeyError: If dataset_name not found
            ValueError: If weight is negative
        """
        if dataset_name not in self.datasets:
            raise KeyError(f"Dataset '{dataset_name}' not found")
        if weight < 0:
            raise ValueError(f"Weight must be non-negative, got {weight}")

        # Update weight in both config and weights list; config.datasets and
        # self.weights share the same ordering, so one index serves both.
        for i, ds_spec in enumerate(self.config.datasets):
            if ds_spec.name == dataset_name:
                ds_spec.weight = weight
                self.weights[i] = weight
                break

    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
        """Forward scoring to the sub-dataset recorded in the entry's
        ``metadata["source_dataset"]``."""
        dataset_name = entry["metadata"]["source_dataset"]
        return self.datasets[dataset_name].score_answer(answer, entry)

    def add_dataset(self, dataset_spec: DatasetSpec) -> None:
        """Add a new dataset to the composite

        Args:
            dataset_spec: Specification for the dataset to add

        Raises:
            ValueError: If dataset name already exists
            AssertionError: If the spec fails validation (note: validate()
                requires a strictly positive weight, unlike __init__ which
                accepts zero).
        """
        # Validate spec
        dataset_spec.validate()

        # Check for duplicate name
        if dataset_spec.name in self.datasets:
            raise ValueError(f"Dataset '{dataset_spec.name}' already exists in composite")

        # Create dataset with derived seed (same scheme as __init__: offset
        # by the dataset's position)
        ds_config = dataset_spec.config.copy()
        if "seed" not in ds_config:
            ds_config["seed"] = self.seed + len(self.datasets) + 1
        if "size" not in ds_config:
            ds_config["size"] = self.size

        # Create and add dataset
        dataset = create_dataset(dataset_spec.name, **ds_config)
        self.datasets[dataset_spec.name] = dataset

        # Register version if tracking enabled
        if self.version_manager is not None:
            version_id = self.version_manager.register_dataset(dataset_spec.name, dataset)
            self.dataset_versions[dataset_spec.name] = version_id

        # Add to config and update internal state (kept parallel: config
        # spec, name list, weight list)
        self.config.datasets.append(dataset_spec)
        self.dataset_names.append(dataset_spec.name)
        self.weights.append(dataset_spec.weight)  # Use weight directly from spec

    def remove_dataset(self, dataset_name: str) -> None:
        """Remove a dataset from the composite

        Args:
            dataset_name: Name of the dataset to remove

        Raises:
            KeyError: If dataset not found
            ValueError: If trying to remove last dataset
        """
        if dataset_name not in self.datasets:
            raise KeyError(f"Dataset '{dataset_name}' not found")

        # A composite must always have at least one member to sample from.
        if len(self.datasets) <= 1:
            raise ValueError("Cannot remove last dataset from composite")

        # Remove from all internal structures
        del self.datasets[dataset_name]
        # dataset_versions only has entries when tracking is enabled
        if self.version_manager is not None:
            del self.dataset_versions[dataset_name]

        # Remove from config
        self.config.datasets = [ds for ds in self.config.datasets if ds.name != dataset_name]

        # Update internal state (names and weights stay index-aligned)
        idx = self.dataset_names.index(dataset_name)
        self.dataset_names.pop(idx)
        self.weights.pop(idx)

    def score_answer_with_id(self, answer: Optional[str], entry_id: str) -> float:
        """Score an answer using an entry_id to lookup the original entry

        Args:
            answer: The answer to score
            entry_id: String in format "version_id.index" (as produced by
                __getitem__'s metadata)

        Returns:
            Score between 0 and 1

        Raises:
            RuntimeError: If no version manager is attached
            ValueError: If entry_id format is invalid
            KeyError: If version not found in version manager
        """
        if self.version_manager is None:
            raise RuntimeError("Version manager required for scoring with entry_id")

        # Both components must be integers; a malformed string fails int()
        # or tuple unpacking, both raising ValueError.
        try:
            version_id, index = map(int, entry_id.split("."))
        except ValueError:
            raise ValueError(f"Invalid entry_id format: {entry_id}, expected 'version_id.index'")

        # Get dataset from version manager
        dataset_info = self.version_manager.get_dataset(version_id)
        if dataset_info is None:
            raise KeyError(f"Version {version_id} not found in version manager")

        dataset_name, dataset = dataset_info

        # Get entry from dataset (regenerated deterministically by index)
        entry = dataset[index]

        # Score answer using dataset's scoring function
        return dataset.score_answer(answer, entry)
# Register the composite under the name "composite" in the global dataset
# factory so create_dataset("composite", ...) can construct it.
register_dataset("composite", CompositeDataset, CompositeConfig)