mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
* feat: Add initial server structure with configuration, registry, and middleware * feat: Add chain_sum dataset to experiment registry test * fix: Update test_registry to use DatasetSpec for composite config validation * refactor: Update Pydantic config to use json_schema_extra and ConfigDict * feat: Add Pydantic models for API request/response data * feat: Implement basic experiment management endpoints with tests * feat: Implement composite configuration endpoints for experiments * fix: Add missing DatasetConfigUpdate import in server.py * refactor: Update dataset config update method to properly merge config updates * fix: Correctly retrieve current dataset config in composite endpoint * feat: Add basic CLI structure with experiments and config commands * feat: Add initial CLI tool with basic experiment management commands * refactor: Reorganize CLI package structure and fix import paths * refactor: Implement initial CLI commands for experiment management * feat: Implement HTTP client for Reasoning Gym server in RGC CLI tool * fix: Move print statements inside try block to resolve SyntaxError * fix: Resolve SyntaxError in edit_config function by adding missing except block * feat: Add default app instance in server module for easier uvicorn startup * docs: Add README.md with server and RGC tool documentation * remove unused files * refactor: Remove unsupported type annotation in registry.py * refactor: Move ExperimentRegistry to coaching module and add Experiment class * fix: Add missing CompositeDataset import in test_registry.py * refactor: Implement lazy ASGI app creation for server initialization * feat: Add health check command to RGC CLI for server connection * feat: Add version tracking support to CompositeDataset * feat: Add DatasetVersionManager for tracking dataset versions * feat: Add entry_id metadata and score_answer_with_id method to CompositeDataset * feat: Add entry_id metadata combining version and index * fix: Resolve undefined variable by 
storing version_id before use * test: Add comprehensive unit tests for score_answer_with_id() function * test: Add comprehensive version tracking test for dataset config updates * feat: Validate dataset weights are positive in CompositeDataset initialization * feat: Add weight update and normalization methods to CompositeDataset * refactor: Centralize weight normalization in CompositeDataset and allow zero-weight datasets * feat: Add negative weight validation to CompositeDataset constructor * feat: Add duplicate dataset name check in CompositeDataset and update test * refactor: Move duplicate dataset name check inside dataset iteration loop * refactor: Update CompositeDataset weight management to use config as source of truth * refactor: Move duplicate dataset name check to CompositeConfig.validate() * test: Update composite dataset weight test assertions and validation * feat: Add methods to add and remove datasets in CompositeDataset * refactor: Remove weight normalization and use unnormalized weights directly * refactor: Remove redundant total weight check in update_dataset_weights * feat: Add batch generation and scoring endpoints to server * fix: Import BatchEntry in server.py to resolve undefined name error * refactor: Update ReasoningGymDataset to use server for batch generation and scoring * fix: Add missing List and Dict type imports * feat: Add get_batch() and score_outputs() methods to RGClient * test: Add unit tests for generate_batch and score_outputs endpoints * refactor: Add DatasetVersionManager to Experiment class and CompositeDataset constructor * feat: Add validation for base_index and batch_size in generate_batch endpoint * refactor: Remove unused BatchRequest type from imports * refactor: Convert models to use Pydantic exclusively * test: Update scoring endpoint tests to use correct request model format * refactor: Rename ScoreItem to AnswerItem and update related code * feat: Update scoring endpoint to return ordered ScoringResponse with 
scores and entry_ids * fix: Add missing ScoringResponse import in server.py * move verl ppo sample with server into own file * refactor: Use Pydantic models for get_batch() and score_outputs() in RGClient * refactor: Update client methods to use Pydantic models for type safety * refactor: Use Pydantic models for experiment and dataset config operations * refactor: Clean up duplicate methods and improve error handling in main.py * first bits of rg server use for verl * refactor: Optimize scoring with single HTTP request in _score_output * fix: Correct experiment creation with ExperimentCreate object * grpo tests with server
169 lines
6.3 KiB
Python
169 lines
6.3 KiB
Python
"""FastAPI server implementation for Reasoning Gym."""
|
|
|
|
import logging
|
|
|
|
from fastapi import FastAPI, HTTPException
|
|
|
|
from reasoning_gym.coaching.registry import ExperimentRegistry
|
|
from reasoning_gym.composite import CompositeConfig, DatasetSpec
|
|
|
|
from .config import ServerConfig
|
|
from .middleware import APIKeyMiddleware
|
|
from .models import (
|
|
BatchEntry,
|
|
BatchResponse,
|
|
DatasetConfigUpdate,
|
|
ExperimentCreate,
|
|
ExperimentList,
|
|
ExperimentResponse,
|
|
ScoringRequest,
|
|
ScoringResponse,
|
|
)
|
|
|
|
|
|
def create_app(config: ServerConfig) -> FastAPI:
    """Create and configure the FastAPI application.

    Args:
        config: Server settings. ``config.log_level`` is handed to
            ``logging.basicConfig`` and ``config.api_key`` to ``APIKeyMiddleware``.

    Returns:
        A FastAPI instance exposing the health, experiment, batch, scoring and
        composite-config routes. All routes share the single
        ``ExperimentRegistry`` created inside this call.
    """
    logging.basicConfig(level=config.log_level)
    logger = logging.getLogger(__name__)

    app = FastAPI(title="Reasoning Gym Server")
    app.add_middleware(APIKeyMiddleware, api_key=config.api_key)

    # One registry per app instance; the route closures below capture it.
    registry = ExperimentRegistry()

    def require_experiment(name: str):
        """Return the registered experiment called *name* or raise a 404."""
        experiment = registry.get_experiment(name)
        if not experiment:
            raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
        return experiment

    @app.get("/health")
    async def health_check():
        """Health check endpoint."""
        return {"status": "healthy"}

    @app.post("/experiments", response_model=ExperimentResponse)
    async def create_experiment(experiment: ExperimentCreate):
        """Create a new experiment.

        Registration failures are reported to the client as HTTP 400.
        """
        # Convert the request's {name: {"weight": ..., "config": ...}} mapping
        # into the DatasetSpec list that CompositeConfig takes.
        dataset_specs = [
            DatasetSpec(name=ds_name, weight=spec.get("weight", 1.0), config=spec.get("config", {}))
            for ds_name, spec in experiment.datasets.items()
        ]
        composite_config = CompositeConfig(size=experiment.size, seed=experiment.seed, datasets=dataset_specs)

        try:
            registry.register_experiment(experiment.name, composite_config)
        except Exception as e:
            logger.warning("Failed to register experiment %r: %s", experiment.name, e)
            raise HTTPException(status_code=400, detail=str(e)) from e

        return ExperimentResponse(
            name=experiment.name, size=experiment.size, seed=experiment.seed, datasets=experiment.datasets
        )

    @app.get("/experiments", response_model=ExperimentList)
    async def list_experiments():
        """List all registered experiments."""
        return ExperimentList(experiments=registry.list_experiments())

    @app.delete("/experiments/{name}")
    async def delete_experiment(name: str):
        """Delete an experiment; responds 404 when the registry reports nothing was removed."""
        if not registry.remove_experiment(name):
            raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
        return {"status": "deleted"}

    @app.get("/experiments/{name}/batch", response_model=BatchResponse)
    async def generate_batch(name: str, base_index: int, batch_size: int):
        """Generate a batch of raw entries.

        Returns the entries at indices ``base_index`` .. ``base_index + batch_size - 1``
        of the experiment's dataset. Invalid parameters and any failure while
        building the batch are reported as HTTP 400; an unknown experiment as 404.
        """
        # Parameter validation runs before the experiment lookup, so a bad
        # request yields 400 even when the experiment does not exist.
        if base_index < 0:
            raise HTTPException(status_code=400, detail="base_index must be non-negative")
        if batch_size <= 0:
            raise HTTPException(status_code=400, detail="batch_size must be positive")

        experiment = require_experiment(name)

        try:
            entries = []
            for i in range(base_index, base_index + batch_size):
                entry = experiment.dataset[i]
                # entry_id is "<version_id>.<index>"; the score endpoint receives
                # the same id back and passes it to score_answer_with_id.
                entries.append(
                    BatchEntry(
                        question=entry["question"],
                        entry_id=f"{entry['metadata']['version_id']}.{i}",
                        metadata=entry["metadata"],
                    )
                )
            return BatchResponse(entries=entries)
        except Exception as e:
            logger.warning("Batch generation failed for experiment %r: %s", name, e)
            raise HTTPException(status_code=400, detail=str(e)) from e

    @app.post("/experiments/{name}/score", response_model=ScoringResponse)
    async def score_outputs(name: str, request: ScoringRequest):
        """Score extracted answers.

        ``scores`` and ``entry_ids`` in the response are in the same order as
        ``request.answers``. Any scoring failure is reported as HTTP 400.
        """
        experiment = require_experiment(name)

        try:
            scores = []
            entry_ids = []
            for item in request.answers:
                scores.append(experiment.dataset.score_answer_with_id(item.answer, item.entry_id))
                entry_ids.append(item.entry_id)
            return ScoringResponse(scores=scores, entry_ids=entry_ids)
        except Exception as e:
            logger.warning("Scoring failed for experiment %r: %s", name, e)
            raise HTTPException(status_code=400, detail=str(e)) from e

    @app.get("/experiments/{name}/composite", response_model=ExperimentResponse)
    async def get_composite_config(name: str):
        """Get composite configuration for an experiment."""
        experiment = require_experiment(name)

        # Convert internal config to API response format. Weights come from the
        # experiment config; per-dataset settings are read from the live dataset
        # instance so that updates applied after creation are reflected.
        datasets = {
            ds_spec.name: {
                "weight": ds_spec.weight,
                "config": vars(experiment.dataset.datasets[ds_spec.name].config),
            }
            for ds_spec in experiment.config.datasets
        }

        return ExperimentResponse(
            name=name, size=experiment.config.size, seed=experiment.config.seed, datasets=datasets
        )

    @app.post("/experiments/{name}/composite/{dataset_name}")
    async def update_dataset_config(name: str, dataset_name: str, config_update: DatasetConfigUpdate):
        """Update configuration for a specific dataset in the composite.

        A ``KeyError`` from the update is reported as 404 (dataset not found);
        any other failure as 400.
        """
        experiment = require_experiment(name)

        try:
            experiment.dataset.update_dataset_config(dataset_name, config_update.config)
            return {"status": "updated"}
        except KeyError as e:
            raise HTTPException(status_code=404, detail=f"Dataset '{dataset_name}' not found in experiment") from e
        except Exception as e:
            logger.warning("Config update failed for dataset %r in experiment %r: %s", dataset_name, name, e)
            raise HTTPException(status_code=400, detail=str(e)) from e

    return app
|
|
|
|
|
|
async def app(scope, receive, send):
    """ASGI application that lazily creates the FastAPI app.

    The FastAPI instance is built with ``create_app(ServerConfig())`` on the
    first call and cached on this function object as ``app.server_app``.
    Every call forwards ``scope``, ``receive`` and ``send`` to that instance.
    """
    # The cache lives on the function itself, so importing this module does
    # not construct a ServerConfig; it is only built when the first call arrives.
    if not hasattr(app, "server_app"):
        app.server_app = create_app(ServerConfig())
    await app.server_app(scope, receive, send)
|