mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
* feat: Add initial server structure with configuration, registry, and middleware * feat: Add chain_sum dataset to experiment registry test * fix: Update test_registry to use DatasetSpec for composite config validation * refactor: Update Pydantic config to use json_schema_extra and ConfigDict * feat: Add Pydantic models for API request/response data * feat: Implement basic experiment management endpoints with tests * feat: Implement composite configuration endpoints for experiments * fix: Add missing DatasetConfigUpdate import in server.py * refactor: Update dataset config update method to properly merge config updates * fix: Correctly retrieve current dataset config in composite endpoint * feat: Add basic CLI structure with experiments and config commands * feat: Add initial CLI tool with basic experiment management commands * refactor: Reorganize CLI package structure and fix import paths * refactor: Implement initial CLI commands for experiment management * feat: Implement HTTP client for Reasoning Gym server in RGC CLI tool * fix: Move print statements inside try block to resolve SyntaxError * fix: Resolve SyntaxError in edit_config function by adding missing except block * feat: Add default app instance in server module for easier uvicorn startup * docs: Add README.md with server and RGC tool documentation * remove unused files * refactor: Remove unsupported type annotation in registry.py * refactor: Move ExperimentRegistry to coaching module and add Experiment class * fix: Add missing CompositeDataset import in test_registry.py * refactor: Implement lazy ASGI app creation for server initialization * feat: Add health check command to RGC CLI for server connection * feat: Add version tracking support to CompositeDataset * feat: Add DatasetVersionManager for tracking dataset versions * feat: Add entry_id metadata and score_answer_with_id method to CompositeDataset * feat: Add entry_id metadata combining version and index * fix: Resolve undefined variable by storing version_id before use * test: Add comprehensive unit tests for score_answer_with_id() function * test: Add comprehensive version tracking test for dataset config updates * feat: Validate dataset weights are positive in CompositeDataset initialization * feat: Add weight update and normalization methods to CompositeDataset * refactor: Centralize weight normalization in CompositeDataset and allow zero-weight datasets * feat: Add negative weight validation to CompositeDataset constructor * feat: Add duplicate dataset name check in CompositeDataset and update test * refactor: Move duplicate dataset name check inside dataset iteration loop * refactor: Update CompositeDataset weight management to use config as source of truth * refactor: Move duplicate dataset name check to CompositeConfig.validate() * test: Update composite dataset weight test assertions and validation * feat: Add methods to add and remove datasets in CompositeDataset * refactor: Remove weight normalization and use unnormalized weights directly * refactor: Remove redundant total weight check in update_dataset_weights * feat: Add batch generation and scoring endpoints to server * fix: Import BatchEntry in server.py to resolve undefined name error * refactor: Update ReasoningGymDataset to use server for batch generation and scoring * fix: Add missing List and Dict type imports * feat: Add get_batch() and score_outputs() methods to RGClient * test: Add unit tests for generate_batch and score_outputs endpoints * refactor: Add DatasetVersionManager to Experiment class and CompositeDataset constructor * feat: Add validation for base_index and batch_size in generate_batch endpoint * refactor: Remove unused BatchRequest type from imports * refactor: Convert models to use Pydantic exclusively * test: Update scoring endpoint tests to use correct request model format * refactor: Rename ScoreItem to AnswerItem and update related code * feat: Update scoring endpoint to return ordered ScoringResponse with scores and entry_ids * fix: Add missing ScoringResponse import in server.py * move verl ppo sample with server into own file * refactor: Use Pydantic models for get_batch() and score_outputs() in RGClient * refactor: Update client methods to use Pydantic models for type safety * refactor: Use Pydantic models for experiment and dataset config operations * refactor: Clean up duplicate methods and improve error handling in main.py * first bits of rg server use for verl * refactor: Optimize scoring with single HTTP request in _score_output * fix: Correct experiment creation with ExperimentCreate object * grpo tests with server
259 lines
8.4 KiB
Python
259 lines
8.4 KiB
Python
import os
|
|
|
|
import pytest
|
|
import yaml
|
|
|
|
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
|
|
from reasoning_gym.version_manager import DatasetVersionManager
|
|
|
|
|
|
def create_test_config(tmp_path):
|
|
"""Create a test YAML config file"""
|
|
config = {
|
|
"size": 100,
|
|
"seed": 42,
|
|
"datasets": [
|
|
{
|
|
"name": "chain_sum",
|
|
"weight": 0.3,
|
|
"config": {
|
|
"min_terms": 2,
|
|
"max_terms": 4,
|
|
},
|
|
},
|
|
{
|
|
"name": "leg_counting",
|
|
"weight": 0.7,
|
|
"config": {
|
|
"min_animals": 1,
|
|
"max_animals": 3,
|
|
},
|
|
},
|
|
],
|
|
}
|
|
|
|
config_path = os.path.join(tmp_path, "test_config.yaml")
|
|
print(config_path)
|
|
with open(config_path, "w") as f:
|
|
yaml.dump(config, f)
|
|
|
|
return config_path
|
|
|
|
|
|
def test_composite_config_validation():
|
|
"""Test configuration validation"""
|
|
with pytest.raises(AssertionError):
|
|
config = CompositeConfig(size=-1)
|
|
config.validate()
|
|
|
|
with pytest.raises(AssertionError):
|
|
config = CompositeConfig(datasets=[])
|
|
config.validate()
|
|
|
|
|
|
def test_composite_dataset_deterministic():
|
|
"""Test that dataset generates same items with same seed"""
|
|
config = CompositeConfig(
|
|
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
|
|
)
|
|
|
|
dataset1 = CompositeDataset(config)
|
|
dataset2 = CompositeDataset(config)
|
|
|
|
for i in range(len(dataset1)):
|
|
assert dataset1[i] == dataset2[i]
|
|
|
|
|
|
def test_composite_dataset_metadata():
|
|
"""Test that metadata includes source dataset information"""
|
|
config = CompositeConfig(
|
|
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
|
|
)
|
|
|
|
dataset = CompositeDataset(config)
|
|
item = dataset[0]
|
|
|
|
assert "source_dataset" in item["metadata"]
|
|
assert "source_index" in item["metadata"]
|
|
assert item["metadata"]["source_dataset"] == "chain_sum"
|
|
assert isinstance(item["metadata"]["source_index"], int)
|
|
|
|
|
|
def test_composite_dataset_weights():
|
|
"""Test that dataset weights are properly normalized"""
|
|
config = CompositeConfig(
|
|
size=1000,
|
|
seed=42,
|
|
datasets=[
|
|
DatasetSpec("chain_sum", 2.0, {"min_terms": 2}),
|
|
DatasetSpec("products", 3.0, {"min_terms": 2}),
|
|
],
|
|
)
|
|
|
|
dataset = CompositeDataset(config)
|
|
assert abs(dataset.weights[0] - 2.0) < 1e-6
|
|
assert abs(dataset.weights[1] - 3.0) < 1e-6
|
|
|
|
# Test weight updates
|
|
dataset.update_dataset_weight("chain_sum", 1.0)
|
|
print(dataset.weights)
|
|
assert abs(dataset.weights[0] - 1.0) < 1e-6
|
|
assert abs(dataset.weights[1] - 3.0) < 1e-6
|
|
|
|
# Test invalid weight
|
|
with pytest.raises(ValueError, match="Weight must be non-negative"):
|
|
dataset.update_dataset_weight("chain_sum", -1.0)
|
|
|
|
# Test invalid dataset name
|
|
with pytest.raises(KeyError):
|
|
dataset.update_dataset_weight("invalid_dataset", 1.0)
|
|
|
|
# Test zero total weight
|
|
dataset.update_dataset_weight("chain_sum", 0.0)
|
|
with pytest.raises(ValueError, match="Total of weights must be greater than zero"):
|
|
dataset.update_dataset_weight("products", 0.0)
|
|
_ = dataset[0] # access item with all weights 0
|
|
|
|
# Test duplicate dataset names
|
|
with pytest.raises(ValueError, match="Duplicate dataset names"):
|
|
CompositeConfig(
|
|
size=1000,
|
|
seed=42,
|
|
datasets=[
|
|
DatasetSpec("chain_sum", 1.0, {"min_terms": 2}),
|
|
DatasetSpec("chain_sum", 1.0, {"min_terms": 3}),
|
|
],
|
|
).validate()
|
|
|
|
|
|
def test_version_tracking_with_config_updates():
|
|
"""Test that version tracking works correctly when updating dataset configs"""
|
|
# Create composite dataset with version manager
|
|
version_manager = DatasetVersionManager()
|
|
config = CompositeConfig(
|
|
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
|
|
)
|
|
dataset = CompositeDataset(config, version_manager=version_manager)
|
|
|
|
# Get an entry and its id from initial version
|
|
entry_1 = dataset[0]
|
|
entry_id_1 = entry_1["metadata"]["entry_id"]
|
|
answer_1 = entry_1["answer"]
|
|
|
|
# Update dataset config
|
|
dataset.update_dataset_config("chain_sum", {"min_terms": 3, "max_terms": 5})
|
|
|
|
# Get new entry after config update
|
|
entry_2 = dataset[0]
|
|
entry_id_2 = entry_2["metadata"]["entry_id"]
|
|
answer_2 = entry_2["answer"]
|
|
|
|
# Verify entries have different version IDs
|
|
version_1 = int(entry_id_1.split(".")[0])
|
|
version_2 = int(entry_id_2.split(".")[0])
|
|
assert version_1 != version_2, "New config should create new version"
|
|
|
|
# Verify original answer still works with original version
|
|
score_1 = dataset.score_answer_with_id(answer_1, entry_id_1)
|
|
assert score_1 == 1.0, "Original answer should still work with original version"
|
|
|
|
# Verify new answer works with new version
|
|
score_2 = dataset.score_answer_with_id(answer_2, entry_id_2)
|
|
assert score_2 == 1.0, "New answer should work with new version"
|
|
|
|
# Verify original answer fails with new version
|
|
score_3 = dataset.score_answer_with_id(answer_1, entry_id_2)
|
|
assert score_3 < 1.0, "Original answer should not work with new version"
|
|
|
|
|
|
def test_score_answer_with_id():
|
|
"""Test scoring answers using entry_id"""
|
|
# Create composite dataset with version manager
|
|
version_manager = DatasetVersionManager()
|
|
config = CompositeConfig(
|
|
size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})]
|
|
)
|
|
dataset = CompositeDataset(config, version_manager=version_manager)
|
|
|
|
# Get an entry and its id
|
|
entry = dataset[0]
|
|
entry_id = entry["metadata"]["entry_id"]
|
|
|
|
# Test successful scoring
|
|
answer = entry["answer"]
|
|
score = dataset.score_answer_with_id(answer, entry_id)
|
|
assert score == 1.0 # Correct answer should get full score
|
|
|
|
# Test wrong answer
|
|
wrong_answer = "wrong"
|
|
score = dataset.score_answer_with_id(wrong_answer, entry_id)
|
|
assert score < 1.0 # Wrong answer should get lower score
|
|
|
|
# Test invalid entry_id format
|
|
with pytest.raises(ValueError, match="Invalid entry_id format"):
|
|
dataset.score_answer_with_id(answer, "invalid")
|
|
|
|
# Test non-existent version
|
|
with pytest.raises(KeyError, match="Version .* not found"):
|
|
dataset.score_answer_with_id(answer, "999.0")
|
|
|
|
# Test without version manager
|
|
dataset_no_vm = CompositeDataset(config)
|
|
with pytest.raises(RuntimeError, match="Version manager required"):
|
|
dataset_no_vm.score_answer_with_id(answer, entry_id)
|
|
|
|
|
|
def test_add_remove_dataset():
|
|
"""Test adding and removing datasets from composite"""
|
|
config = CompositeConfig(
|
|
size=1000,
|
|
seed=42,
|
|
datasets=[
|
|
DatasetSpec("chain_sum", 1.0, {"min_terms": 2}),
|
|
],
|
|
)
|
|
|
|
dataset = CompositeDataset(config)
|
|
|
|
# Test adding new dataset
|
|
new_spec = DatasetSpec("products", 2.0, {"min_terms": 2})
|
|
dataset.add_dataset(new_spec)
|
|
|
|
assert len(dataset.datasets) == 2
|
|
assert "products" in dataset.datasets
|
|
assert len(dataset.config.datasets) == 2
|
|
|
|
assert dataset.dataset_names[0] == "chain_sum"
|
|
assert dataset.dataset_names[1] == "products"
|
|
assert abs(dataset.weights[0] - 1.0) < 1e-6 # chain_sum weight
|
|
assert abs(dataset.weights[1] - 2.0) < 1e-6 # products weight
|
|
|
|
# Test duplicate name
|
|
with pytest.raises(ValueError, match="already exists"):
|
|
dataset.add_dataset(new_spec)
|
|
|
|
# Test removing dataset
|
|
dataset.remove_dataset("products")
|
|
assert len(dataset.datasets) == 1
|
|
assert "products" not in dataset.datasets
|
|
assert len(dataset.config.datasets) == 1
|
|
|
|
# Test removing non-existent dataset
|
|
with pytest.raises(KeyError):
|
|
dataset.remove_dataset("nonexistent")
|
|
|
|
# Test removing last dataset
|
|
with pytest.raises(ValueError, match="Cannot remove last dataset"):
|
|
dataset.remove_dataset("chain_sum")
|
|
|
|
|
|
def test_yaml_loading(tmp_path):
|
|
"""Test loading configuration from YAML"""
|
|
config_path = create_test_config(tmp_path)
|
|
config = CompositeConfig.from_yaml(config_path)
|
|
|
|
assert config.size == 100
|
|
assert config.seed == 42
|
|
assert len(config.datasets) == 2
|
|
assert config.datasets[0].name == "chain_sum"
|
|
assert config.datasets[1].name == "leg_counting"
|