reasoning-gym/tools/server/server.py
Andreas Köpf e2702092f4
reasoning-gym-server & cli tool (#154)
* feat: Add initial server structure with configuration, registry, and middleware

* feat: Add chain_sum dataset to experiment registry test

* fix: Update test_registry to use DatasetSpec for composite config validation

* refactor: Update Pydantic config to use json_schema_extra and ConfigDict

* feat: Add Pydantic models for API request/response data

* feat: Implement basic experiment management endpoints with tests

* feat: Implement composite configuration endpoints for experiments

* fix: Add missing DatasetConfigUpdate import in server.py

* refactor: Update dataset config update method to properly merge config updates

* fix: Correctly retrieve current dataset config in composite endpoint

* feat: Add basic CLI structure with experiments and config commands

* feat: Add initial CLI tool with basic experiment management commands

* refactor: Reorganize CLI package structure and fix import paths

* refactor: Implement initial CLI commands for experiment management

* feat: Implement HTTP client for Reasoning Gym server in RGC CLI tool

* fix: Move print statements inside try block to resolve SyntaxError

* fix: Resolve SyntaxError in edit_config function by adding missing except block

* feat: Add default app instance in server module for easier uvicorn startup

* docs: Add README.md with server and RGC tool documentation

* remove unused files

* refactor: Remove unsupported type annotation in registry.py

* refactor: Move ExperimentRegistry to coaching module and add Experiment class

* fix: Add missing CompositeDataset import in test_registry.py

* refactor: Implement lazy ASGI app creation for server initialization

* feat: Add health check command to RGC CLI for server connection

* feat: Add version tracking support to CompositeDataset

* feat: Add DatasetVersionManager for tracking dataset versions

* feat: Add entry_id metadata and score_answer_with_id method to CompositeDataset

* feat: Add entry_id metadata combining version and index

* fix: Resolve undefined variable by storing version_id before use

* test: Add comprehensive unit tests for score_answer_with_id() function

* test: Add comprehensive version tracking test for dataset config updates

* feat: Validate dataset weights are positive in CompositeDataset initialization

* feat: Add weight update and normalization methods to CompositeDataset

* refactor: Centralize weight normalization in CompositeDataset and allow zero-weight datasets

* feat: Add negative weight validation to CompositeDataset constructor

* feat: Add duplicate dataset name check in CompositeDataset and update test

* refactor: Move duplicate dataset name check inside dataset iteration loop

* refactor: Update CompositeDataset weight management to use config as source of truth

* refactor: Move duplicate dataset name check to CompositeConfig.validate()

* test: Update composite dataset weight test assertions and validation

* feat: Add methods to add and remove datasets in CompositeDataset

* refactor: Remove weight normalization and use unnormalized weights directly

* refactor: Remove redundant total weight check in update_dataset_weights

* feat: Add batch generation and scoring endpoints to server

* fix: Import BatchEntry in server.py to resolve undefined name error

* refactor: Update ReasoningGymDataset to use server for batch generation and scoring

* fix: Add missing List and Dict type imports

* feat: Add get_batch() and score_outputs() methods to RGClient

* test: Add unit tests for generate_batch and score_outputs endpoints

* refactor: Add DatasetVersionManager to Experiment class and CompositeDataset constructor

* feat: Add validation for base_index and batch_size in generate_batch endpoint

* refactor: Remove unused BatchRequest type from imports

* refactor: Convert models to use Pydantic exclusively

* test: Update scoring endpoint tests to use correct request model format

* refactor: Rename ScoreItem to AnswerItem and update related code

* feat: Update scoring endpoint to return ordered ScoringResponse with scores and entry_ids

* fix: Add missing ScoringResponse import in server.py

* move verl ppo sample with server into own file

* refactor: Use Pydantic models for get_batch() and score_outputs() in RGClient

* refactor: Update client methods to use Pydantic models for type safety

* refactor: Use Pydantic models for experiment and dataset config operations

* refactor: Clean up duplicate methods and improve error handling in main.py

* first bits of rg server use for verl

* refactor: Optimize scoring with single HTTP request in _score_output

* fix: Correct experiment creation with ExperimentCreate object

* grpo tests with server
2025-02-19 22:41:33 +01:00

169 lines
6.3 KiB
Python

"""FastAPI server implementation for Reasoning Gym."""
import logging
from fastapi import FastAPI, HTTPException
from reasoning_gym.coaching.registry import ExperimentRegistry
from reasoning_gym.composite import CompositeConfig, DatasetSpec
from .config import ServerConfig
from .middleware import APIKeyMiddleware
from .models import (
BatchEntry,
BatchResponse,
DatasetConfigUpdate,
ExperimentCreate,
ExperimentList,
ExperimentResponse,
ScoringRequest,
ScoringResponse,
)
def create_app(config: ServerConfig) -> FastAPI:
    """Create and configure the FastAPI application.

    Sets up logging from ``config.log_level``, installs ``APIKeyMiddleware``
    with ``config.api_key``, creates a fresh ``ExperimentRegistry`` and
    registers all HTTP routes as closures over that registry.

    Args:
        config: Server configuration; ``log_level`` and ``api_key`` are read.

    Returns:
        The configured FastAPI application.
    """
    # Configure logging
    logging.basicConfig(level=config.log_level)
    logger = logging.getLogger(__name__)

    # Create FastAPI app
    app = FastAPI(title="Reasoning Gym Server")

    # Add middleware
    app.add_middleware(APIKeyMiddleware, api_key=config.api_key)

    # Initialize registry (one per application instance, shared by all routes)
    registry = ExperimentRegistry()

    def _get_experiment_or_404(name: str):
        """Return the registered experiment called *name* or raise a 404."""
        experiment = registry.get_experiment(name)
        if not experiment:
            raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
        return experiment

    @app.get("/health")
    async def health_check():
        """Health check endpoint."""
        return {"status": "healthy"}

    @app.post("/experiments", response_model=ExperimentResponse)
    async def create_experiment(experiment: ExperimentCreate):
        """Create a new experiment.

        Raises:
            HTTPException: 400 if the registry rejects the experiment.
        """
        # Convert dict format to DatasetSpec list; weight defaults to 1.0 and
        # config to an empty dict when a spec omits them.
        dataset_specs = [
            DatasetSpec(name=name, weight=spec.get("weight", 1.0), config=spec.get("config", {}))
            for name, spec in experiment.datasets.items()
        ]
        config = CompositeConfig(size=experiment.size, seed=experiment.seed, datasets=dataset_specs)

        try:
            registry.register_experiment(experiment.name, config)
        except Exception as e:
            # Keep the traceback in the server log; the client only sees str(e).
            logger.warning("Failed to register experiment %r", experiment.name, exc_info=True)
            raise HTTPException(status_code=400, detail=str(e)) from e

        return ExperimentResponse(
            name=experiment.name, size=experiment.size, seed=experiment.seed, datasets=experiment.datasets
        )

    @app.get("/experiments", response_model=ExperimentList)
    async def list_experiments():
        """List all registered experiments."""
        return ExperimentList(experiments=registry.list_experiments())

    @app.delete("/experiments/{name}")
    async def delete_experiment(name: str):
        """Delete an experiment.

        Raises:
            HTTPException: 404 if the registry reports nothing was removed.
        """
        if not registry.remove_experiment(name):
            raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found")
        return {"status": "deleted"}

    @app.get("/experiments/{name}/batch", response_model=BatchResponse)
    async def generate_batch(name: str, base_index: int, batch_size: int):
        """Generate a batch of raw entries.

        Returns entries for indices ``base_index`` .. ``base_index + batch_size - 1``.

        Raises:
            HTTPException: 400 for a negative ``base_index``, a non-positive
                ``batch_size`` or any failure while building the entries;
                404 if the experiment does not exist.
        """
        # Validate parameters (before the lookup, so bad parameters win over 404)
        if base_index < 0:
            raise HTTPException(status_code=400, detail="base_index must be non-negative")
        if batch_size <= 0:
            raise HTTPException(status_code=400, detail="batch_size must be positive")

        experiment = _get_experiment_or_404(name)

        try:
            entries = []
            for i in range(base_index, base_index + batch_size):
                entry = experiment.dataset[i]
                # Create BatchEntry with minimal required data; the entry id is
                # "<version_id>.<index>".
                entries.append(
                    BatchEntry(
                        question=entry["question"],
                        entry_id=f"{entry['metadata']['version_id']}.{i}",
                        metadata=entry["metadata"],
                    )
                )
            return BatchResponse(entries=entries)
        except Exception as e:
            logger.warning(
                "Batch generation failed for experiment %r (base_index=%d, batch_size=%d)",
                name,
                base_index,
                batch_size,
                exc_info=True,
            )
            raise HTTPException(status_code=400, detail=str(e)) from e

    @app.post("/experiments/{name}/score", response_model=ScoringResponse)
    async def score_outputs(name: str, request: ScoringRequest):
        """Score extracted answers.

        Scores and entry ids are returned in the same order as
        ``request.answers``.

        Raises:
            HTTPException: 404 if the experiment does not exist; 400 if
                scoring any answer fails.
        """
        experiment = _get_experiment_or_404(name)

        try:
            scores = [
                experiment.dataset.score_answer_with_id(item.answer, item.entry_id) for item in request.answers
            ]
            entry_ids = [item.entry_id for item in request.answers]
            return ScoringResponse(scores=scores, entry_ids=entry_ids)
        except Exception as e:
            logger.warning("Scoring failed for experiment %r", name, exc_info=True)
            raise HTTPException(status_code=400, detail=str(e)) from e

    @app.get("/experiments/{name}/composite", response_model=ExperimentResponse)
    async def get_composite_config(name: str):
        """Get composite configuration for an experiment.

        Raises:
            HTTPException: 404 if the experiment does not exist.
        """
        experiment = _get_experiment_or_404(name)

        # Convert internal config to API response format. The per-dataset
        # config is read from the live dataset instance, not from the spec.
        datasets = {
            ds_spec.name: {
                "weight": ds_spec.weight,
                "config": vars(experiment.dataset.datasets[ds_spec.name].config),
            }
            for ds_spec in experiment.config.datasets
        }

        return ExperimentResponse(
            name=name, size=experiment.config.size, seed=experiment.config.seed, datasets=datasets
        )

    @app.post("/experiments/{name}/composite/{dataset_name}")
    async def update_dataset_config(name: str, dataset_name: str, config_update: DatasetConfigUpdate):
        """Update configuration for a specific dataset in the composite.

        Raises:
            HTTPException: 404 if the experiment does not exist or the update
                raises ``KeyError``; 400 for any other failure.
        """
        experiment = _get_experiment_or_404(name)

        try:
            experiment.dataset.update_dataset_config(dataset_name, config_update.config)
            return {"status": "updated"}
        except KeyError as e:
            raise HTTPException(status_code=404, detail=f"Dataset '{dataset_name}' not found in experiment") from e
        except Exception as e:
            logger.warning(
                "Config update failed for dataset %r in experiment %r", dataset_name, name, exc_info=True
            )
            raise HTTPException(status_code=400, detail=str(e)) from e

    return app
async def app(scope, receive, send):
    """ASGI entry point that builds the FastAPI app on first use.

    The application created by ``create_app(ServerConfig())`` is cached as
    the ``server_app`` attribute of this function and reused for every
    later call.
    """
    try:
        app.server_app
    except AttributeError:
        # First call: nothing cached yet, so build and store the real app.
        app.server_app = create_app(ServerConfig())
    await app.server_app(scope, receive, send)