reasoning-gym-server & cli tool (#154)

* feat: Add initial server structure with configuration, registry, and middleware

* feat: Add chain_sum dataset to experiment registry test

* fix: Update test_registry to use DatasetSpec for composite config validation

* refactor: Update Pydantic config to use json_schema_extra and ConfigDict

* feat: Add Pydantic models for API request/response data

* feat: Implement basic experiment management endpoints with tests

* feat: Implement composite configuration endpoints for experiments

* fix: Add missing DatasetConfigUpdate import in server.py

* refactor: Update dataset config update method to properly merge config updates

* fix: Correctly retrieve current dataset config in composite endpoint

* feat: Add basic CLI structure with experiments and config commands

* feat: Add initial CLI tool with basic experiment management commands

* refactor: Reorganize CLI package structure and fix import paths

* refactor: Implement initial CLI commands for experiment management

* feat: Implement HTTP client for Reasoning Gym server in RGC CLI tool

* fix: Move print statements inside try block to resolve SyntaxError

* fix: Resolve SyntaxError in edit_config function by adding missing except block

* feat: Add default app instance in server module for easier uvicorn startup

* docs: Add README.md with server and RGC tool documentation

* remove unused files

* refactor: Remove unsupported type annotation in registry.py

* refactor: Move ExperimentRegistry to coaching module and add Experiment class

* fix: Add missing CompositeDataset import in test_registry.py

* refactor: Implement lazy ASGI app creation for server initialization

* feat: Add health check command to RGC CLI for server connection

* feat: Add version tracking support to CompositeDataset

* feat: Add DatasetVersionManager for tracking dataset versions

* feat: Add entry_id metadata and score_answer_with_id method to CompositeDataset

* feat: Add entry_id metadata combining version and index

* fix: Resolve undefined variable by storing version_id before use

* test: Add comprehensive unit tests for score_answer_with_id() function

* test: Add comprehensive version tracking test for dataset config updates

* feat: Validate dataset weights are positive in CompositeDataset initialization

* feat: Add weight update and normalization methods to CompositeDataset

* refactor: Centralize weight normalization in CompositeDataset and allow zero-weight datasets

* feat: Add negative weight validation to CompositeDataset constructor

* feat: Add duplicate dataset name check in CompositeDataset and update test

* refactor: Move duplicate dataset name check inside dataset iteration loop

* refactor: Update CompositeDataset weight management to use config as source of truth

* refactor: Move duplicate dataset name check to CompositeConfig.validate()

* test: Update composite dataset weight test assertions and validation

* feat: Add methods to add and remove datasets in CompositeDataset

* refactor: Remove weight normalization and use unnormalized weights directly

* refactor: Remove redundant total weight check in update_dataset_weights

* feat: Add batch generation and scoring endpoints to server

* fix: Import BatchEntry in server.py to resolve undefined name error

* refactor: Update ReasoningGymDataset to use server for batch generation and scoring

* fix: Add missing List and Dict type imports

* feat: Add get_batch() and score_outputs() methods to RGClient

* test: Add unit tests for generate_batch and score_outputs endpoints

* refactor: Add DatasetVersionManager to Experiment class and CompositeDataset constructor

* feat: Add validation for base_index and batch_size in generate_batch endpoint

* refactor: Remove unused BatchRequest type from imports

* refactor: Convert models to use Pydantic exclusively

* test: Update scoring endpoint tests to use correct request model format

* refactor: Rename ScoreItem to AnswerItem and update related code

* feat: Update scoring endpoint to return ordered ScoringResponse with scores and entry_ids

* fix: Add missing ScoringResponse import in server.py

* move verl ppo sample with server into own file

* refactor: Use Pydantic models for get_batch() and score_outputs() in RGClient

* refactor: Update client methods to use Pydantic models for type safety

* refactor: Use Pydantic models for experiment and dataset config operations

* refactor: Clean up duplicate methods and improve error handling in main.py

* first bits of rg server use for verl

* refactor: Optimize scoring with single HTTP request in _score_output

* fix: Correct experiment creation with ExperimentCreate object

* grpo tests with server
This commit is contained in:
Andreas Köpf 2025-02-19 22:41:33 +01:00 committed by GitHub
parent bec6aefd11
commit e2702092f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 1968 additions and 22 deletions

View file

@ -4,6 +4,7 @@ import pytest
import yaml
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
from reasoning_gym.version_manager import DatasetVersionManager
def create_test_config(tmp_path):
@ -85,13 +86,165 @@ def test_composite_dataset_weights():
seed=42,
datasets=[
DatasetSpec("chain_sum", 2.0, {"min_terms": 2}),
DatasetSpec("chain_sum", 3.0, {"min_terms": 3}),
DatasetSpec("products", 3.0, {"min_terms": 2}),
],
)
dataset = CompositeDataset(config)
assert abs(dataset.weights[0] - 0.4) < 1e-6
assert abs(dataset.weights[1] - 0.6) < 1e-6
assert abs(dataset.weights[0] - 2.0) < 1e-6
assert abs(dataset.weights[1] - 3.0) < 1e-6
# Test weight updates
dataset.update_dataset_weight("chain_sum", 1.0)
print(dataset.weights)
assert abs(dataset.weights[0] - 1.0) < 1e-6
assert abs(dataset.weights[1] - 3.0) < 1e-6
# Test invalid weight
with pytest.raises(ValueError, match="Weight must be non-negative"):
dataset.update_dataset_weight("chain_sum", -1.0)
# Test invalid dataset name
with pytest.raises(KeyError):
dataset.update_dataset_weight("invalid_dataset", 1.0)
# Test zero total weight
dataset.update_dataset_weight("chain_sum", 0.0)
with pytest.raises(ValueError, match="Total of weights must be greater than zero"):
dataset.update_dataset_weight("products", 0.0)
_ = dataset[0] # access item with all weights 0
# Test duplicate dataset names
with pytest.raises(ValueError, match="Duplicate dataset names"):
CompositeConfig(
size=1000,
seed=42,
datasets=[
DatasetSpec("chain_sum", 1.0, {"min_terms": 2}),
DatasetSpec("chain_sum", 1.0, {"min_terms": 3}),
],
).validate()
def test_version_tracking_with_config_updates():
"""Test that version tracking works correctly when updating dataset configs"""
# Create composite dataset with version manager
version_manager = DatasetVersionManager()
config = CompositeConfig(
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
)
dataset = CompositeDataset(config, version_manager=version_manager)
# Get an entry and its id from initial version
entry_1 = dataset[0]
entry_id_1 = entry_1["metadata"]["entry_id"]
answer_1 = entry_1["answer"]
# Update dataset config
dataset.update_dataset_config("chain_sum", {"min_terms": 3, "max_terms": 5})
# Get new entry after config update
entry_2 = dataset[0]
entry_id_2 = entry_2["metadata"]["entry_id"]
answer_2 = entry_2["answer"]
# Verify entries have different version IDs
version_1 = int(entry_id_1.split(".")[0])
version_2 = int(entry_id_2.split(".")[0])
assert version_1 != version_2, "New config should create new version"
# Verify original answer still works with original version
score_1 = dataset.score_answer_with_id(answer_1, entry_id_1)
assert score_1 == 1.0, "Original answer should still work with original version"
# Verify new answer works with new version
score_2 = dataset.score_answer_with_id(answer_2, entry_id_2)
assert score_2 == 1.0, "New answer should work with new version"
# Verify original answer fails with new version
score_3 = dataset.score_answer_with_id(answer_1, entry_id_2)
assert score_3 < 1.0, "Original answer should not work with new version"
def test_score_answer_with_id():
"""Test scoring answers using entry_id"""
# Create composite dataset with version manager
version_manager = DatasetVersionManager()
config = CompositeConfig(
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
)
dataset = CompositeDataset(config, version_manager=version_manager)
# Get an entry and its id
entry = dataset[0]
entry_id = entry["metadata"]["entry_id"]
# Test successful scoring
answer = entry["answer"]
score = dataset.score_answer_with_id(answer, entry_id)
assert score == 1.0 # Correct answer should get full score
# Test wrong answer
wrong_answer = "wrong"
score = dataset.score_answer_with_id(wrong_answer, entry_id)
assert score < 1.0 # Wrong answer should get lower score
# Test invalid entry_id format
with pytest.raises(ValueError, match="Invalid entry_id format"):
dataset.score_answer_with_id(answer, "invalid")
# Test non-existent version
with pytest.raises(KeyError, match="Version .* not found"):
dataset.score_answer_with_id(answer, "999.0")
# Test without version manager
dataset_no_vm = CompositeDataset(config)
with pytest.raises(RuntimeError, match="Version manager required"):
dataset_no_vm.score_answer_with_id(answer, entry_id)
def test_add_remove_dataset():
"""Test adding and removing datasets from composite"""
config = CompositeConfig(
size=1000,
seed=42,
datasets=[
DatasetSpec("chain_sum", 1.0, {"min_terms": 2}),
],
)
dataset = CompositeDataset(config)
# Test adding new dataset
new_spec = DatasetSpec("products", 2.0, {"min_terms": 2})
dataset.add_dataset(new_spec)
assert len(dataset.datasets) == 2
assert "products" in dataset.datasets
assert len(dataset.config.datasets) == 2
assert dataset.dataset_names[0] == "chain_sum"
assert dataset.dataset_names[1] == "products"
assert abs(dataset.weights[0] - 1.0) < 1e-6 # chain_sum weight
assert abs(dataset.weights[1] - 2.0) < 1e-6 # products weight
# Test duplicate name
with pytest.raises(ValueError, match="already exists"):
dataset.add_dataset(new_spec)
# Test removing dataset
dataset.remove_dataset("products")
assert len(dataset.datasets) == 1
assert "products" not in dataset.datasets
assert len(dataset.config.datasets) == 1
# Test removing non-existent dataset
with pytest.raises(KeyError):
dataset.remove_dataset("nonexistent")
# Test removing last dataset
with pytest.raises(ValueError, match="Cannot remove last dataset"):
dataset.remove_dataset("chain_sum")
def test_yaml_loading(tmp_path):