diff --git a/examples/veRL/launch_on_2gpu_server.sh b/examples/veRL/launch_on_2gpu_server.sh new file mode 100755 index 00000000..4f2efc46 --- /dev/null +++ b/examples/veRL/launch_on_2gpu_server.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +export N_GPUS=2 +export BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct +export ROLLOUT_TP_SIZE=2 +export EXPERIMENT_NAME=chain_sum_llama +export VLLM_ATTENTION_BACKEND=XFORMERS + +bash ./train_grpo_server.sh diff --git a/examples/veRL/main_ppo_custom_reward_server.py b/examples/veRL/main_ppo_custom_reward_server.py new file mode 100644 index 00000000..0f20be1d --- /dev/null +++ b/examples/veRL/main_ppo_custom_reward_server.py @@ -0,0 +1,344 @@ +# This example is an adapted version of Bytedance's code: +# https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/verl/trainer/main_ppo.py +import os +from typing import Dict, List, Optional + +import hydra +import ray +import torch +import verl.utils.torch_functional as verl_F +from omegaconf import OmegaConf, open_dict +from torch.utils.data import DataLoader, Dataset +from transformers import PreTrainedTokenizer +from verl import DataProto +from verl.trainer.ppo.ray_trainer import RayPPOTrainer +from verl.utils.dataset.rl_dataset import collate_fn +from verl.utils.model import compute_position_id_with_mask + +import reasoning_gym +import reasoning_gym.utils +from reasoning_gym.utils import extract_answer +from tools.server.models import AnswerItem, BatchEntry, ExperimentCreate + + +class ReasoningGymDataset(Dataset): + def __init__( + self, + tokenizer: PreTrainedTokenizer, + dataset_name: str, + seed: int, + size: int, + developer_prompt: Optional[str] = None, + developer_role: str = "system", + max_prompt_length: int = 2048, + truncation: str = "error", ## ['left', 'right', 'error'] + return_raw_chat: bool = False, + server_url: str = "http://localhost:8000", + api_key: Optional[str] = None, + batch_size: int = 32, + ): + from tools.cli.rgc.client import RGClient + + self.tokenizer = tokenizer + self.dataset_name = dataset_name + self.developer_prompt = developer_prompt + self.developer_role = developer_role + self.max_prompt_length = max_prompt_length + self.truncation = truncation + self.return_raw_chat = return_raw_chat + self.size = size + self.batch_size = batch_size + + # Initialize client and create experiment if needed + self.client = RGClient(base_url=server_url, api_key=api_key) + + # Check if experiment exists, create if not + experiments = self.client.list_experiments() + if dataset_name not in experiments.experiments: + config = ExperimentCreate( + name=dataset_name, + size=size, + seed=seed, + datasets={dataset_name: {"weight": 1.0, "config": {"seed": seed, "size": size}}}, + ) + self.client.create_experiment(dataset_name, config) + + # Cache for batches + self._batch_cache: dict[int, List[BatchEntry]] = {} + + def __len__(self) -> int: + return self.size + + def _get_batch(self, batch_idx: int) -> List[BatchEntry]: + """Fetch or retrieve cached batch""" + if batch_idx not in self._batch_cache: + base_index = batch_idx * self.batch_size + response = self.client.get_batch(self.dataset_name, base_index=base_index, batch_size=self.batch_size) + self._batch_cache[batch_idx] = response.entries + + # # Basic cache management - keep only last N batches + # if len(self._batch_cache) > 10: + # oldest_batch = min(self._batch_cache.keys()) + # del self._batch_cache[oldest_batch] + + return self._batch_cache[batch_idx] + + def __getitem__(self, index): + # Get batch containing this index + 
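+        # Entries are served in fixed-size pages: index i lives in page
+        # i // batch_size at offset i % batch_size. Pages are fetched lazily
+        # from the server and memoized in self._batch_cache.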
+        batch_idx = index // self.batch_size
+
+        batch = self._get_batch(batch_idx)
+        entry = batch[index % self.batch_size]
+
+        # Format chat/prompt
+        chat = []
+        if self.developer_prompt is not None:
+            chat.append({"role": self.developer_role, "content": self.developer_prompt})
+        chat.append({"role": "user", "content": entry.question})
+
+        prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+
+        # Tokenize
+        input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
+            prompt=prompt,
+            tokenizer=self.tokenizer,
+            max_length=self.max_prompt_length,
+            pad_token_id=self.tokenizer.pad_token_id,
+            left_pad=True,
+            truncation=self.truncation,
+        )
+
+        position_ids = compute_position_id_with_mask(attention_mask)
+
+        row_dict = {
+            "data_source": "reasoning_gym/" + self.dataset_name,
+            "input_ids": input_ids[0],
+            "attention_mask": attention_mask[0],
+            "position_ids": position_ids[0],
+            "entry_id": entry.entry_id,
+            "metadata": entry.metadata,
+            "index": index,
+        }
+
+        # Add raw chat if requested
+        if self.return_raw_chat:
+            row_dict["raw_prompt"] = chat
+
+        return row_dict
+
+
+class RayPPOTrainerCustom(RayPPOTrainer):
+    def __init__(
+        self,
+        config,
+        tokenizer,
+        role_worker_mapping: dict,
+        resource_pool_manager,
+        ray_worker_group_cls,
+        dataset_name: str = "chain_sum",
+        dataset_size: int = 10000,
+    ):
+        self.dataset_name = dataset_name
+        self.dataset_size = dataset_size
+
+        developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
+        rg_api_key = os.getenv("REASONING_GYM_API_KEY", "your-secret-key")
+        self.train_dataset = ReasoningGymDataset(
+            tokenizer=tokenizer,
+            dataset_name=self.dataset_name,
+            seed=1,
+            size=self.dataset_size,
+            developer_prompt=developer_prompt,
+            api_key=rg_api_key,
+        )
+
+        self.val_dataset = ReasoningGymDataset(
+            tokenizer=tokenizer,
+            dataset_name=self.dataset_name,
+            seed=2,
+            size=self.dataset_size,
+            developer_prompt=developer_prompt,
+            api_key=rg_api_key,
+        )
+
+        train_reward_fn = lambda data: self._score_output(data, num_examine=0)
+        val_reward_fn = lambda data: self._score_output(data, num_examine=1)
+
+        super().__init__(
+            config,
+            tokenizer,
+            role_worker_mapping,
+            resource_pool_manager,
+            ray_worker_group_cls,
+            train_reward_fn,
+            val_reward_fn,
+        )
+
+    def _score_output(self, data: DataProto, num_examine: int = 0) -> torch.Tensor:
+        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
+
+        # Prepare batch of answers to score
+        answer_items = []
+        valid_response_lengths = []
+        sequences_strs = []
+
+        for i in range(len(data)):
+            data_item = data[i]
+
+            # Get prompt and response
+            prompt_ids = data_item.batch["prompts"]
+            prompt_length = prompt_ids.shape[-1]
+            valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
+            valid_prompt_ids = prompt_ids[-valid_prompt_length:]
+
+            response_ids = data_item.batch["responses"]
+            valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
+            valid_response_ids = response_ids[:valid_response_length]
+            valid_response_lengths.append(valid_response_length)
+
+            # Decode full sequence
+            sequences = torch.cat((valid_prompt_ids, valid_response_ids))
+            sequences_str = self.tokenizer.decode(sequences)
+            sequences_strs.append(sequences_str)
+
+            # Extract answer and prepare scoring item
+            found_answer = extract_answer(sequences_str, tag_name="answer")
+
+            # Use the entry_id carried in the batch; indexing into train_dataset
+            entry_id = data_item.non_tensor_batch["entry_id"]
+            # print(
+            #     "found_answer",
+            #     entry_id,
+            #     
found_answer, + # ) + + answer_items.append(AnswerItem(entry_id=entry_id, answer=found_answer)) + + # Score all answers in one request + response = self.train_dataset.client.score_outputs(self.train_dataset.dataset_name, answer_items) + # print("response", response) + + # Fill reward tensor + for i, (score, valid_response_length) in enumerate(zip(response.scores, valid_response_lengths)): + reward_tensor[i, valid_response_length - 1] = score + + if i < num_examine: + print(f"reward={score}, seq={sequences_strs[i]}") + + return reward_tensor + + def _create_dataloader(self): + self.train_dataloader = DataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.train_batch_size, + shuffle=False, + drop_last=True, + collate_fn=collate_fn, + ) + + self.val_dataloader = DataLoader( + dataset=self.val_dataset, + batch_size=len(self.val_dataset), + shuffle=False, + drop_last=True, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1 + assert len(self.val_dataloader) >= 1 + + print(f"Size of train dataloader: {len(self.train_dataloader)}") + print(f"Size of val dataloader: {len(self.val_dataloader)}") + + # inject total_training_steps to actor/critic optim_config. This is hacky. + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + self.config.critic.optim.total_training_steps = total_training_steps + + +@ray.remote +def main_task(config): + # print initial config + from pprint import pprint + + from verl.utils import hf_tokenizer + from verl.utils.fs import copy_local_path_from_hdfs + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + tokenizer = hf_tokenizer(local_path) + + # define worker classes + if config.actor_rollout_ref.actor.strategy == "fsdp": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + + ray_worker_group_cls = NVMegatronRayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(ActorRolloutRefWorker), + Role.Critic: ray.remote(CriticWorker), + Role.RefPolicy: ray.remote(ActorRolloutRefWorker), + } + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + Role.RefPolicy: global_pool_id, + } + + resource_pool_manager = 
ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + trainer = RayPPOTrainerCustom( + config=config, + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + ) + trainer.init_workers() + trainer.fit() + + +@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) +def main(config): + if not ray.is_initialized(): + # this is for local ray cluster + ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}) + + ray.get(main_task.remote(config)) + + +if __name__ == "__main__": + main() diff --git a/examples/veRL/train_grpo_server.sh b/examples/veRL/train_grpo_server.sh new file mode 100644 index 00000000..34b956ad --- /dev/null +++ b/examples/veRL/train_grpo_server.sh @@ -0,0 +1,39 @@ +#!/bin/bash +set -x + +python3 -u main_ppo_custom_reward_server.py \ + algorithm.adv_estimator=grpo \ + data.train_files=$DATA_DIR/train.parquet \ + data.val_files=$DATA_DIR/test.parquet \ + data.train_batch_size=32 \ + data.val_batch_size=32 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='verl_chain_sum_grpo' \ + trainer.experiment_name=$EXPERIMENT_NAME \ + trainer.n_gpus_per_node=$N_GPUS \ + trainer.nnodes=1 \ + trainer.save_freq=100 \ + trainer.test_freq=100 \ + trainer.total_epochs=15 $@ 2>&1 | tee verl_output.log diff --git a/pyproject.toml b/pyproject.toml index 07aa57e4..4bba76fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,22 @@ license = "Apache-2.0" license-files = ["LICENSE*"] [project.optional-dependencies] -test = ["pytest>=7.0.0", "pytest-cov>=4.0.0"] +test = [ + "pytest>=7.0.0", + "pytest-cov>=4.0.0", + "httpx>=0.27.0" +] +server = [ + "fastapi>=0.109.0", + "uvicorn>=0.27.0", + "pydantic-settings>=2.1.0", +] +cli = [ + "typer>=0.9.0", + "rich>=13.7.0", + "pyyaml>=6.0.1", + "httpx>=0.27.0", +] [project.urls] "Homepage" = "https://github.com/open-thought/reasoning-gym" @@ -40,12 +55,19 @@ test = ["pytest>=7.0.0", "pytest-cov>=4.0.0"] [tool.hatch.build] -packages = ["reasoning_gym"] -include = [ - "reasoning_gym/**/*.py", - "reasoning_gym/**/*.txt", - "reasoning_gym/**/levels/*", +packages = [ + "reasoning_gym", + "tools.cli.rgc" ] +include = [ + "reasoning_gym/**/*.py", + "reasoning_gym/**/*.txt", + "reasoning_gym/**/levels/*", + "tools/cli/rgc/**/*.py" +] + +[project.scripts] +rgc = "tools.cli.rgc.main:main" 
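+# The `rgc` console script is installed with the package; its runtime
+# dependencies come from the [cli] extra (see tools/README.md).
+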
[tool.black] line-length = 120 diff --git a/reasoning_gym/coaching/experiment.py b/reasoning_gym/coaching/experiment.py new file mode 100644 index 00000000..d3a9e00f --- /dev/null +++ b/reasoning_gym/coaching/experiment.py @@ -0,0 +1,36 @@ +"""Experiment class combining dataset, scoreboard and curriculum.""" + +from dataclasses import dataclass +from typing import Optional + +from ..composite import CompositeConfig, CompositeDataset +from ..version_manager import DatasetVersionManager +from .coach import ScoreBoard + + +@dataclass +class Experiment: + """ + An experiment combines a dataset with scoring and curriculum management. + + Attributes: + name: Unique identifier for the experiment + dataset: The composite dataset for generating examples + score_board: Tracks performance metrics + config: The configuration used to create the dataset + version_manager: Manages dataset versions for scoring + """ + + name: str + dataset: CompositeDataset + score_board: ScoreBoard + config: CompositeConfig + version_manager: DatasetVersionManager + + @classmethod + def create(cls, name: str, config: CompositeConfig) -> "Experiment": + """Create a new experiment from a configuration.""" + version_manager = DatasetVersionManager() + dataset = CompositeDataset(config, version_manager=version_manager) + score_board = ScoreBoard() + return cls(name=name, dataset=dataset, score_board=score_board, config=config, version_manager=version_manager) diff --git a/reasoning_gym/coaching/registry.py b/reasoning_gym/coaching/registry.py new file mode 100644 index 00000000..5d7fdd1a --- /dev/null +++ b/reasoning_gym/coaching/registry.py @@ -0,0 +1,34 @@ +"""Registry for managing active experiments.""" + +from typing import Dict, List, Optional + +from ..composite import CompositeConfig +from .experiment import Experiment + + +class ExperimentRegistry: + """Singleton registry for managing active experiments.""" + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._experiments = {} + return cls._instance + + def register_experiment(self, name: str, config: CompositeConfig) -> None: + """Register a new experiment with the given name and configuration.""" + self._experiments[name] = Experiment.create(name, config) + + def get_experiment(self, name: str) -> Optional[Experiment]: + """Get an experiment by name.""" + return self._experiments.get(name) + + def list_experiments(self) -> List[str]: + """List all registered experiment names.""" + return list(self._experiments.keys()) + + def remove_experiment(self, name: str) -> bool: + """Remove an experiment by name. 
Returns True if removed, False if not found.""" + return bool(self._experiments.pop(name, None)) diff --git a/reasoning_gym/composite.py b/reasoning_gym/composite.py index 2050ddd1..b30151fb 100644 --- a/reasoning_gym/composite.py +++ b/reasoning_gym/composite.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, replace from random import Random from typing import Any, Dict, List, Optional @@ -6,6 +6,7 @@ import yaml from .dataset import ProceduralDataset from .factory import create_dataset, register_dataset +from .version_manager import DatasetVersionManager @dataclass @@ -37,6 +38,11 @@ class CompositeConfig: assert self.datasets, "Must specify at least one dataset" assert len(self.datasets) > 0, "Must specify at least one dataset" + # Check for duplicate dataset names + dataset_names = [ds.name for ds in self.datasets] + if len(dataset_names) != len(set(dataset_names)): + raise ValueError("Duplicate dataset names are not allowed in CompositeDataset") + # Validate each dataset spec for ds in self.datasets: ds.validate() @@ -57,13 +63,14 @@ class CompositeConfig: class CompositeDataset(ProceduralDataset): """A dataset that combines multiple datasets with weighted sampling""" - def __init__(self, config: CompositeConfig): + def __init__(self, config: CompositeConfig, version_manager: Optional[DatasetVersionManager] = None): super().__init__(config=config, seed=config.seed, size=config.size) + self.version_manager = version_manager + self.dataset_versions = {} # dataset_name -> version_id # Initialize sub-datasets with incremented seeds self.datasets = {} self.weights = [] - total_weight = 0.0 for i, ds_spec in enumerate(config.datasets): # Create dataset with derived seed @@ -73,12 +80,18 @@ class CompositeDataset(ProceduralDataset): if "size" not in ds_config: ds_config["size"] = self.size - self.datasets[ds_spec.name] = create_dataset(ds_spec.name, **ds_config) - total_weight += ds_spec.weight - self.weights.append(ds_spec.weight) + if ds_spec.weight < 0: + raise ValueError(f"Dataset '{ds_spec.name}' has invalid weight {ds_spec.weight}, must be non-negative") - # Normalize weights - self.weights = [w / total_weight for w in self.weights] + dataset = create_dataset(ds_spec.name, **ds_config) + self.datasets[ds_spec.name] = dataset + + # Register version if tracking enabled + if version_manager is not None: + version_id = version_manager.register_dataset(ds_spec.name, dataset) + self.dataset_versions[ds_spec.name] = version_id + + self.weights.append(ds_spec.weight) # Store unnormalized weights directly self.dataset_names = [ds.name for ds in config.datasets] def __getitem__(self, idx: int) -> dict: @@ -98,6 +111,13 @@ class CompositeDataset(ProceduralDataset): item["metadata"]["source_dataset"] = dataset_name item["metadata"]["source_index"] = idx + # Add version info if tracking enabled + if self.version_manager is not None: + version_id = self.dataset_versions[dataset_name] + item["metadata"]["version_id"] = version_id + # Add entry_id combining version and index + item["metadata"]["entry_id"] = f"{version_id}.{idx}" + return item def update_dataset_config(self, dataset_name: str, config_updates: Dict[str, Any]) -> None: @@ -116,23 +136,151 @@ class CompositeDataset(ProceduralDataset): dataset = self.datasets[dataset_name] - # Create new config with updates - new_config = dataset.config.__class__(**vars(dataset.config)) - for key, value in config_updates.items(): - setattr(new_config, key, value) + # Update the current config + new_config = 
replace(dataset.config, **config_updates) # Validate new config new_config.validate() # Create new dataset instance with updated config dataset_cls = dataset.__class__ - self.datasets[dataset_name] = dataset_cls(new_config) + new_dataset = dataset_cls(new_config) + self.datasets[dataset_name] = new_dataset + + # Register new version if tracking enabled + if self.version_manager is not None: + version_id = self.version_manager.register_dataset(dataset_name, new_dataset) + self.dataset_versions[dataset_name] = version_id + + def update_dataset_weight(self, dataset_name: str, weight: float) -> None: + """Update weight for a specific dataset in the configuration + + Args: + dataset_name: Name of the dataset to update + weight: New weight value + + Raises: + KeyError: If dataset_name not found + ValueError: If weight is negative + """ + if dataset_name not in self.datasets: + raise KeyError(f"Dataset '{dataset_name}' not found") + if weight < 0: + raise ValueError(f"Weight must be non-negative, got {weight}") + + # Update weight in both config and weights list + for i, ds_spec in enumerate(self.config.datasets): + if ds_spec.name == dataset_name: + ds_spec.weight = weight + self.weights[i] = weight + break def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: """Forward scoring to appropriate dataset""" dataset_name = entry["metadata"]["source_dataset"] return self.datasets[dataset_name].score_answer(answer, entry) + def add_dataset(self, dataset_spec: DatasetSpec) -> None: + """Add a new dataset to the composite + + Args: + dataset_spec: Specification for the dataset to add + + Raises: + ValueError: If dataset name already exists + """ + # Validate spec + dataset_spec.validate() + + # Check for duplicate name + if dataset_spec.name in self.datasets: + raise ValueError(f"Dataset '{dataset_spec.name}' already exists in composite") + + # Create dataset with derived seed + ds_config = dataset_spec.config.copy() + if "seed" not in ds_config: + ds_config["seed"] = self.seed + len(self.datasets) + 1 + if "size" not in ds_config: + ds_config["size"] = self.size + + # Create and add dataset + dataset = create_dataset(dataset_spec.name, **ds_config) + self.datasets[dataset_spec.name] = dataset + + # Register version if tracking enabled + if self.version_manager is not None: + version_id = self.version_manager.register_dataset(dataset_spec.name, dataset) + self.dataset_versions[dataset_spec.name] = version_id + + # Add to config and update internal state + self.config.datasets.append(dataset_spec) + self.dataset_names.append(dataset_spec.name) + self.weights.append(dataset_spec.weight) # Use weight directly from spec + + def remove_dataset(self, dataset_name: str) -> None: + """Remove a dataset from the composite + + Args: + dataset_name: Name of the dataset to remove + + Raises: + KeyError: If dataset not found + ValueError: If trying to remove last dataset + """ + if dataset_name not in self.datasets: + raise KeyError(f"Dataset '{dataset_name}' not found") + + if len(self.datasets) <= 1: + raise ValueError("Cannot remove last dataset from composite") + + # Remove from all internal structures + del self.datasets[dataset_name] + if self.version_manager is not None: + del self.dataset_versions[dataset_name] + + # Remove from config + self.config.datasets = [ds for ds in self.config.datasets if ds.name != dataset_name] + + # Update internal state + idx = self.dataset_names.index(dataset_name) + self.dataset_names.pop(idx) + self.weights.pop(idx) + + def score_answer_with_id(self, 
answer: Optional[str], entry_id: str) -> float:
+        """Score an answer using an entry_id to lookup the original entry
+
+        Args:
+            answer: The answer to score
+            entry_id: String in format "version_id.index"
+
+        Returns:
+            Score between 0 and 1
+
+        Raises:
+            RuntimeError: If no version manager is configured
+            ValueError: If entry_id format is invalid
+            KeyError: If version not found in version manager
+        """
+        if self.version_manager is None:
+            raise RuntimeError("Version manager required for scoring with entry_id")
+
+        try:
+            version_id, index = map(int, entry_id.split("."))
+        except ValueError:
+            raise ValueError(f"Invalid entry_id format: {entry_id}, expected 'version_id.index'")
+
+        # Get dataset from version manager
+        dataset_info = self.version_manager.get_dataset(version_id)
+        if dataset_info is None:
+            raise KeyError(f"Version {version_id} not found in version manager")
+
+        dataset_name, dataset = dataset_info
+
+        # Get entry from dataset
+        entry = dataset[index]
+
+        # Score answer using dataset's scoring function
+        return dataset.score_answer(answer, entry)
+
 
 # Register the dataset
 register_dataset("composite", CompositeDataset, CompositeConfig)
diff --git a/reasoning_gym/version_manager.py b/reasoning_gym/version_manager.py
new file mode 100644
index 00000000..dbe19a09
--- /dev/null
+++ b/reasoning_gym/version_manager.py
@@ -0,0 +1,76 @@
+"""Version manager for tracking dataset versions."""
+
+from typing import Any, Dict, Optional, Tuple
+
+from .dataset import ProceduralDataset
+
+
+class DatasetVersionManager:
+    """Manages versioned ProceduralDataset instances and their configurations."""
+
+    def __init__(self):
+        """Initialize the version manager."""
+        self.current_version = 0
+        # version_id -> (dataset_name, dataset_instance)
+        self.datasets: Dict[int, Tuple[str, ProceduralDataset]] = {}
+
+    def register_dataset(self, name: str, dataset: ProceduralDataset) -> int:
+        """
+        Register a new dataset version.
+
+        Args:
+            name: Name/identifier of the dataset type
+            dataset: Instance of ProceduralDataset
+
+        Returns:
+            version_id: Unique identifier for this dataset version
+        """
+        self.current_version += 1
+        self.datasets[self.current_version] = (name, dataset)
+        return self.current_version
+
+    def get_dataset(self, version_id: int) -> Optional[Tuple[str, ProceduralDataset]]:
+        """
+        Retrieve a dataset by its version ID.
+
+        Args:
+            version_id: The version identifier
+
+        Returns:
+            Tuple of (dataset_name, dataset_instance) if found, None otherwise
+        """
+        return self.datasets.get(version_id)
+
+    def get_entry(self, version_id: int, index: int) -> Dict[str, Any]:
+        """
+        Get a specific entry from a versioned dataset.
+
+        Args:
+            version_id: The version identifier
+            index: Index of the entry to retrieve
+
+        Returns:
+            The dataset entry
+
+        Raises:
+            KeyError: If version_id is not found
+        """
+        if version_id not in self.datasets:
+            raise KeyError(f"Dataset version {version_id} not found")
+
+        _, dataset = self.datasets[version_id]
+        return dataset[index]
+
+    def cleanup_old_versions(self, keep_latest: int = 10):
+        """
+        Remove old dataset versions to free memory.
+ + Args: + keep_latest: Number of most recent versions to keep + """ + if len(self.datasets) <= keep_latest: + return + + versions_to_remove = sorted(self.datasets.keys())[:-keep_latest] + for version in versions_to_remove: + del self.datasets[version] diff --git a/tests/test_composite.py b/tests/test_composite.py index cbfec38a..93cc6f0b 100644 --- a/tests/test_composite.py +++ b/tests/test_composite.py @@ -4,6 +4,7 @@ import pytest import yaml from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec +from reasoning_gym.version_manager import DatasetVersionManager def create_test_config(tmp_path): @@ -85,13 +86,165 @@ def test_composite_dataset_weights(): seed=42, datasets=[ DatasetSpec("chain_sum", 2.0, {"min_terms": 2}), - DatasetSpec("chain_sum", 3.0, {"min_terms": 3}), + DatasetSpec("products", 3.0, {"min_terms": 2}), ], ) dataset = CompositeDataset(config) - assert abs(dataset.weights[0] - 0.4) < 1e-6 - assert abs(dataset.weights[1] - 0.6) < 1e-6 + assert abs(dataset.weights[0] - 2.0) < 1e-6 + assert abs(dataset.weights[1] - 3.0) < 1e-6 + + # Test weight updates + dataset.update_dataset_weight("chain_sum", 1.0) + print(dataset.weights) + assert abs(dataset.weights[0] - 1.0) < 1e-6 + assert abs(dataset.weights[1] - 3.0) < 1e-6 + + # Test invalid weight + with pytest.raises(ValueError, match="Weight must be non-negative"): + dataset.update_dataset_weight("chain_sum", -1.0) + + # Test invalid dataset name + with pytest.raises(KeyError): + dataset.update_dataset_weight("invalid_dataset", 1.0) + + # Test zero total weight + dataset.update_dataset_weight("chain_sum", 0.0) + with pytest.raises(ValueError, match="Total of weights must be greater than zero"): + dataset.update_dataset_weight("products", 0.0) + _ = dataset[0] # access item with all weights 0 + + # Test duplicate dataset names + with pytest.raises(ValueError, match="Duplicate dataset names"): + CompositeConfig( + size=1000, + seed=42, + datasets=[ + DatasetSpec("chain_sum", 1.0, {"min_terms": 2}), + DatasetSpec("chain_sum", 1.0, {"min_terms": 3}), + ], + ).validate() + + +def test_version_tracking_with_config_updates(): + """Test that version tracking works correctly when updating dataset configs""" + # Create composite dataset with version manager + version_manager = DatasetVersionManager() + config = CompositeConfig( + size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})] + ) + dataset = CompositeDataset(config, version_manager=version_manager) + + # Get an entry and its id from initial version + entry_1 = dataset[0] + entry_id_1 = entry_1["metadata"]["entry_id"] + answer_1 = entry_1["answer"] + + # Update dataset config + dataset.update_dataset_config("chain_sum", {"min_terms": 3, "max_terms": 5}) + + # Get new entry after config update + entry_2 = dataset[0] + entry_id_2 = entry_2["metadata"]["entry_id"] + answer_2 = entry_2["answer"] + + # Verify entries have different version IDs + version_1 = int(entry_id_1.split(".")[0]) + version_2 = int(entry_id_2.split(".")[0]) + assert version_1 != version_2, "New config should create new version" + + # Verify original answer still works with original version + score_1 = dataset.score_answer_with_id(answer_1, entry_id_1) + assert score_1 == 1.0, "Original answer should still work with original version" + + # Verify new answer works with new version + score_2 = dataset.score_answer_with_id(answer_2, entry_id_2) + assert score_2 == 1.0, "New answer should work with new version" + + # Verify original answer fails with new 
version + score_3 = dataset.score_answer_with_id(answer_1, entry_id_2) + assert score_3 < 1.0, "Original answer should not work with new version" + + +def test_score_answer_with_id(): + """Test scoring answers using entry_id""" + # Create composite dataset with version manager + version_manager = DatasetVersionManager() + config = CompositeConfig( + size=10, seed=42, datasets=[DatasetSpec("chain_sum", 1.0, {"min_terms": 2, "max_terms": 4})] + ) + dataset = CompositeDataset(config, version_manager=version_manager) + + # Get an entry and its id + entry = dataset[0] + entry_id = entry["metadata"]["entry_id"] + + # Test successful scoring + answer = entry["answer"] + score = dataset.score_answer_with_id(answer, entry_id) + assert score == 1.0 # Correct answer should get full score + + # Test wrong answer + wrong_answer = "wrong" + score = dataset.score_answer_with_id(wrong_answer, entry_id) + assert score < 1.0 # Wrong answer should get lower score + + # Test invalid entry_id format + with pytest.raises(ValueError, match="Invalid entry_id format"): + dataset.score_answer_with_id(answer, "invalid") + + # Test non-existent version + with pytest.raises(KeyError, match="Version .* not found"): + dataset.score_answer_with_id(answer, "999.0") + + # Test without version manager + dataset_no_vm = CompositeDataset(config) + with pytest.raises(RuntimeError, match="Version manager required"): + dataset_no_vm.score_answer_with_id(answer, entry_id) + + +def test_add_remove_dataset(): + """Test adding and removing datasets from composite""" + config = CompositeConfig( + size=1000, + seed=42, + datasets=[ + DatasetSpec("chain_sum", 1.0, {"min_terms": 2}), + ], + ) + + dataset = CompositeDataset(config) + + # Test adding new dataset + new_spec = DatasetSpec("products", 2.0, {"min_terms": 2}) + dataset.add_dataset(new_spec) + + assert len(dataset.datasets) == 2 + assert "products" in dataset.datasets + assert len(dataset.config.datasets) == 2 + + assert dataset.dataset_names[0] == "chain_sum" + assert dataset.dataset_names[1] == "products" + assert abs(dataset.weights[0] - 1.0) < 1e-6 # chain_sum weight + assert abs(dataset.weights[1] - 2.0) < 1e-6 # products weight + + # Test duplicate name + with pytest.raises(ValueError, match="already exists"): + dataset.add_dataset(new_spec) + + # Test removing dataset + dataset.remove_dataset("products") + assert len(dataset.datasets) == 1 + assert "products" not in dataset.datasets + assert len(dataset.config.datasets) == 1 + + # Test removing non-existent dataset + with pytest.raises(KeyError): + dataset.remove_dataset("nonexistent") + + # Test removing last dataset + with pytest.raises(ValueError, match="Cannot remove last dataset"): + dataset.remove_dataset("chain_sum") def test_yaml_loading(tmp_path): diff --git a/tools/README.md b/tools/README.md new file mode 100644 index 00000000..f61171c1 --- /dev/null +++ b/tools/README.md @@ -0,0 +1,83 @@ +# Reasoning Gym Tools + +This directory contains additional tools for working with Reasoning Gym: + +## Server + +A FastAPI server that manages reasoning gym experiments, allowing runtime configuration and monitoring. + +### Starting the Server + +1. Install server dependencies: +```bash +pip install -e ".[server]" +``` + +2. Set the API key environment variable: +```bash +export REASONING_GYM_API_KEY=your-secret-key +``` + +3. Start the server: +```bash +uvicorn tools.server.server:app +``` + +The server will be available at http://localhost:8000. You can access the API documentation at http://localhost:8000/docs. 
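+
+To verify the setup, you can call the API directly. A quick smoke test with `curl` (assuming the default host/port and the API key exported above):
+
+```bash
+# Health check (no API key required)
+curl http://localhost:8000/health
+
+# List registered experiments (authenticated)
+curl -H "X-API-Key: $REASONING_GYM_API_KEY" http://localhost:8000/experiments
+```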
+ +## RGC (Reasoning Gym Client) + +A command-line interface for interacting with the Reasoning Gym server. + +### Installation + +```bash +pip install -e ".[cli]" +``` + +### Usage + +First, set the API key to match your server: +```bash +export REASONING_GYM_API_KEY=your-secret-key +``` + +Then you can use the CLI: + +```bash +# List all commands +rgc --help + +# List experiments +rgc experiments list + +# Create a new experiment interactively +rgc experiments create my-experiment + +# Create from config file +rgc experiments create my-experiment -f config.yaml + +# Show experiment details +rgc experiments show my-experiment + +# Edit dataset configuration +rgc config edit my-experiment chain_sum +``` + +### Example Configuration File + +Here's an example `config.yaml` for creating an experiment: + +```yaml +size: 500 +seed: 42 +datasets: + chain_sum: + weight: 1.0 + config: + min_terms: 2 + max_terms: 4 + min_digits: 1 + max_digits: 2 + allow_negation: false +``` diff --git a/tools/cli/__init__.py b/tools/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tools/cli/rgc/__init__.py b/tools/cli/rgc/__init__.py new file mode 100644 index 00000000..7286e761 --- /dev/null +++ b/tools/cli/rgc/__init__.py @@ -0,0 +1,5 @@ +"""Reasoning Gym CLI tool.""" + +from .main import main + +__all__ = ["main"] diff --git a/tools/cli/rgc/client.py b/tools/cli/rgc/client.py new file mode 100644 index 00000000..4808a0fe --- /dev/null +++ b/tools/cli/rgc/client.py @@ -0,0 +1,125 @@ +"""HTTP client for interacting with the Reasoning Gym server.""" + +import os +from typing import List, Optional + +import httpx +from rich.console import Console + +from tools.server.models import ( + AnswerItem, + BatchResponse, + DatasetConfigUpdate, + ExperimentCreate, + ExperimentList, + ExperimentResponse, + ScoringRequest, + ScoringResponse, +) + +console = Console() + +DEFAULT_SERVER = "http://localhost:8000" +API_KEY = os.getenv("REASONING_GYM_API_KEY", "default-key") + + +class RGClient: + """Client for interacting with Reasoning Gym server.""" + + def __init__(self, base_url: str = DEFAULT_SERVER, api_key: str = API_KEY): + """Initialize client with server URL and API key.""" + self.base_url = base_url.rstrip("/") + self.headers = {"X-API-Key": api_key} + + def _url(self, path: str) -> str: + """Construct full URL for given path.""" + return f"{self.base_url}/{path.lstrip('/')}" + + def check_health(self) -> bool: + """Check server health status.""" + try: + response = httpx.get(self._url("/health"), headers=self.headers) + response.raise_for_status() + return response.json()["status"] == "healthy" + except Exception: + return False + + def list_experiments(self) -> ExperimentList: + """List all registered experiments.""" + response = httpx.get(self._url("/experiments"), headers=self.headers) + response.raise_for_status() + return ExperimentList.model_validate(response.json()) + + def create_experiment(self, name: str, config: ExperimentCreate) -> ExperimentResponse: + """Create a new experiment.""" + response = httpx.post( + self._url("/experiments"), + headers=self.headers, + json=config.model_dump(), + ) + response.raise_for_status() + return ExperimentResponse.model_validate(response.json()) + + def delete_experiment(self, name: str) -> None: + """Delete an experiment.""" + response = httpx.delete( + self._url(f"/experiments/{name}"), + headers=self.headers, + ) + response.raise_for_status() + + def get_experiment_config(self, name: str) -> ExperimentResponse: + """Get experiment 
configuration.""" + response = httpx.get( + self._url(f"/experiments/{name}/composite"), + headers=self.headers, + ) + response.raise_for_status() + return ExperimentResponse.model_validate(response.json()) + + def update_dataset_config(self, experiment: str, dataset: str, config: DatasetConfigUpdate) -> None: + """Update dataset configuration.""" + response = httpx.post( + self._url(f"/experiments/{experiment}/composite/{dataset}"), + headers=self.headers, + json=config.model_dump(), + ) + response.raise_for_status() + + def get_batch(self, experiment: str, base_index: int, batch_size: int) -> BatchResponse: + """Get a batch of entries from an experiment. + + Args: + experiment: Name of the experiment + base_index: Starting index for the batch + batch_size: Number of entries to retrieve + + Returns: + BatchResponse containing entries with questions and metadata + """ + response = httpx.get( + self._url(f"/experiments/{experiment}/batch"), + headers=self.headers, + params={"base_index": base_index, "batch_size": batch_size}, + ) + response.raise_for_status() + return BatchResponse.model_validate(response.json()) + + def score_outputs(self, experiment: str, entry_answers: List[AnswerItem]) -> ScoringResponse: + """Score a batch of answers. + + Args: + experiment: Name of the experiment + entry_answers: List of AnswerItems with entry_ids and answers to score + + Returns: + ScoringResponse containing scores and entry_ids + """ + request = ScoringRequest(answers=entry_answers) + response = httpx.post( + self._url(f"/experiments/{experiment}/score"), + headers=self.headers, + json=request.model_dump(), + ) + response.raise_for_status() + return ScoringResponse.model_validate(response.json()) diff --git a/tools/cli/rgc/main.py b/tools/cli/rgc/main.py new file mode 100644 index 00000000..827c413a --- /dev/null +++ b/tools/cli/rgc/main.py @@ -0,0 +1,231 @@ +"""Main entry point for the Reasoning Gym CLI.""" + +import os +from typing import Optional + +import typer +import yaml +from rich.console import Console +from rich.prompt import Confirm, Prompt +from rich.syntax import Syntax +from rich.table import Table + +from tools.server.models import DatasetConfigUpdate, ExperimentCreate + +# Initialize Typer apps +app = typer.Typer( + name="rgc", + help="Reasoning Gym CLI - Manage and monitor reasoning gym experiments", + add_completion=True, +) +experiments_app = typer.Typer(help="Manage experiments") +config_app = typer.Typer(help="Manage configurations") + +app.add_typer(experiments_app, name="experiments") +app.add_typer(config_app, name="config") + + +@app.command("health") +def check_health(): + """Check server connection and health status.""" + try: + if client.check_health(): + console.print("[green]Server is healthy[/]") + else: + console.print("[red]Server is not responding correctly[/]") + raise typer.Exit(1) + except Exception as e: + console.print(f"[red]Error connecting to server: {e}[/]") + raise typer.Exit(1) + + +# Initialize client and console +from .client import RGClient + +client = RGClient() +console = Console() + + +@experiments_app.command("list") +def list_experiments(): + """List all registered experiments with their status.""" + table = Table(title="Registered Experiments") + table.add_column("Name", style="cyan") + table.add_column("Datasets", style="magenta") + table.add_column("Size", style="blue") + table.add_column("Seed", style="green") + + try: + experiments = client.list_experiments() + for exp_name in experiments.experiments: + try: + config = 
client.get_experiment_config(exp_name) + datasets = ", ".join(config.datasets.keys()) + table.add_row(exp_name, datasets, str(config.size), str(config.seed or "")) + except Exception as e: + console.print(f"[yellow]Warning: Could not get config for {exp_name}: {e}[/]") + table.add_row(exp_name, "?", "?", "?") + except Exception as e: + console.print(f"[red]Error listing experiments: {e}[/]") + raise typer.Exit(1) + + console.print(table) + + +@experiments_app.command("create") +def create_experiment( + name: str = typer.Argument(..., help="Name of the experiment"), + config_file: Optional[str] = typer.Option(None, "--file", "-f", help="YAML configuration file"), +): + """Create a new experiment.""" + if config_file: + try: + with open(config_file, "r") as f: + exp_config = yaml.safe_load(f) + config = ExperimentCreate(**exp_config) + response = client.create_experiment(name, config) + console.print(f"[green]Created experiment[/] [cyan]{response.name}[/]") + except Exception as e: + console.print(f"[red]Error creating experiment: {e}[/]") + raise typer.Exit(1) + else: + # Interactive creation + size = Prompt.ask("Dataset size", default="500") + seed = Prompt.ask("Random seed (optional)", default="") + + datasets = {} + while Confirm.ask("Add dataset?"): + ds_name = Prompt.ask("Dataset name") + weight = float(Prompt.ask("Weight", default="1.0")) + + # Get dataset-specific config + console.print("\nEnter dataset configuration:") + config = {} + while Confirm.ask("Add config parameter?"): + key = Prompt.ask("Parameter name") + value = Prompt.ask("Parameter value") + try: + # Try to convert to appropriate type + if value.isdigit(): + value = int(value) + elif value.lower() in ("true", "false"): + value = value.lower() == "true" + elif "." in value and value.replace(".", "").isdigit(): + value = float(value) + except ValueError: + pass + config[key] = value + + datasets[ds_name] = {"weight": weight, "config": config} + + # Create experiment config + exp_config = {"name": name, "size": int(size), "seed": int(seed) if seed else None, "datasets": datasets} + + # Show final config + console.print("\nFinal configuration:") + console.print(Syntax(yaml.dump(exp_config), "yaml")) + + if Confirm.ask("Create experiment with this configuration?"): + try: + config = ExperimentCreate(**exp_config) + response = client.create_experiment(name, config) + console.print(f"[green]Created experiment[/] [cyan]{response.name}[/]") + except Exception as e: + console.print(f"[red]Error creating experiment: {e}[/]") + raise typer.Exit(1) + else: + console.print("[yellow]Experiment creation cancelled[/]") + raise typer.Exit() + + +@experiments_app.command("delete") +def delete_experiment( + name: str = typer.Argument(..., help="Name of the experiment to delete"), + force: bool = typer.Option(False, "--force", "-f", help="Force deletion without confirmation"), +): + """Delete an experiment.""" + if not force and not Confirm.ask(f"Delete experiment [cyan]{name}[/]?"): + raise typer.Exit() + + try: + client.delete_experiment(name) + console.print(f"[green]Deleted experiment[/] [cyan]{name}[/]") + except Exception as e: + console.print(f"[red]Error deleting experiment: {e}[/]") + raise typer.Exit(1) + + +@experiments_app.command("show") +def show_experiment( + name: str = typer.Argument(..., help="Name of the experiment"), +): + """Show experiment details.""" + try: + config = client.get_experiment_config(name) + console.print(Syntax(yaml.dump(config.model_dump()), "yaml")) + except Exception as e: + 
console.print(f"[red]Error getting experiment config: {e}[/]") + raise typer.Exit(1) + + +@config_app.command("edit") +def edit_config( + experiment: str = typer.Argument(..., help="Name of the experiment"), + dataset: str = typer.Argument(..., help="Name of the dataset to edit"), +): + """Interactive configuration editor.""" + try: + exp_config = client.get_experiment_config(experiment) + if dataset not in exp_config.datasets: + console.print(f"[red]Dataset {dataset} not found in experiment[/]") + raise typer.Exit(1) + current_config = exp_config.datasets[dataset]["config"] + + console.print(f"\nCurrent configuration for [cyan]{dataset}[/]:") + console.print(Syntax(yaml.dump(current_config), "yaml")) + + # Interactive editing + new_config = {} + for key, value in current_config.items(): + new_value = Prompt.ask(f"{key}", default=str(value), show_default=True) + + # Try to convert to appropriate type + try: + if isinstance(value, bool): + new_value = new_value.lower() == "true" + elif isinstance(value, int): + new_value = int(new_value) + elif isinstance(value, float): + new_value = float(new_value) + except ValueError: + console.print(f"[yellow]Warning: Could not convert {new_value} to {type(value)}[/]") + + new_config[key] = new_value + + # Show changes + console.print("\nNew configuration:") + console.print(Syntax(yaml.dump(new_config), "yaml")) + + if Confirm.ask("Apply these changes?"): + try: + config_update = DatasetConfigUpdate(config=new_config) + client.update_dataset_config(experiment, dataset, config_update) + console.print("[green]Configuration updated successfully[/]") + except Exception as e: + console.print(f"[red]Error updating configuration: {e}[/]") + raise typer.Exit(1) + else: + console.print("[yellow]Update cancelled[/]") + + except Exception as e: + console.print(f"[red]Error getting experiment configuration: {e}[/]") + raise typer.Exit(1) + + +def main(): + """Entry point for the CLI.""" + app() + + +if __name__ == "__main__": + main() diff --git a/tools/server/__init__.py b/tools/server/__init__.py new file mode 100644 index 00000000..64926c4c --- /dev/null +++ b/tools/server/__init__.py @@ -0,0 +1,8 @@ +""" +Reasoning Gym Server - A FastAPI server for managing reasoning gym experiments. 
+""" + +from .config import ServerConfig +from .server import create_app + +__all__ = ["create_app", "ServerConfig"] diff --git a/tools/server/config.py b/tools/server/config.py new file mode 100644 index 00000000..5957b947 --- /dev/null +++ b/tools/server/config.py @@ -0,0 +1,17 @@ +"""Server configuration using Pydantic settings management.""" + +from pydantic import ConfigDict, Field +from pydantic_settings import BaseSettings + + +class ServerConfig(BaseSettings): + """Configuration settings for the Reasoning Gym server.""" + + host: str = Field(default="localhost", description="Server host address") + port: int = Field(default=8000, description="Server port") + api_key: str = Field( + default=..., description="API key for authentication", json_schema_extra={"env": "REASONING_GYM_API_KEY"} + ) + log_level: str = Field(default="INFO", description="Logging level") + + model_config = ConfigDict(env_prefix="REASONING_GYM_") diff --git a/tools/server/middleware.py b/tools/server/middleware.py new file mode 100644 index 00000000..24920cb6 --- /dev/null +++ b/tools/server/middleware.py @@ -0,0 +1,23 @@ +"""API key middleware for FastAPI.""" + +from fastapi import HTTPException, Request +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.status import HTTP_401_UNAUTHORIZED + + +class APIKeyMiddleware(BaseHTTPMiddleware): + """Middleware to check for valid API key in request headers.""" + + def __init__(self, app, api_key: str): + super().__init__(app) + self.api_key = api_key + + async def dispatch(self, request: Request, call_next): + if request.url.path == "/health": + return await call_next(request) + + api_key = request.headers.get("X-API-Key") + if not api_key or api_key != self.api_key: + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid or missing API key") + + return await call_next(request) diff --git a/tools/server/models.py b/tools/server/models.py new file mode 100644 index 00000000..6c873b08 --- /dev/null +++ b/tools/server/models.py @@ -0,0 +1,75 @@ +"""Pydantic models for API request/response data.""" + +from typing import Any, Dict, List, Optional, Tuple + +from pydantic import BaseModel, Field + + +class ExperimentCreate(BaseModel): + """Request model for creating a new experiment.""" + + name: str = Field(..., description="Unique name for the experiment") + size: int = Field(500, description="Size of the dataset") + seed: Optional[int] = Field(None, description="Random seed for reproducibility") + datasets: Dict[str, Dict[str, Any]] = Field(..., description="Dictionary of datasets configurations") + + +class ExperimentResponse(BaseModel): + """Response model for experiment operations.""" + + name: str = Field(..., description="Name of the experiment") + size: int = Field(..., description="Size of the dataset") + seed: Optional[int] = Field(None, description="Random seed used") + datasets: Dict[str, Dict[str, Any]] = Field(..., description="Current dataset configurations") + + +class ExperimentList(BaseModel): + """Response model for listing experiments.""" + + experiments: List[str] = Field(default_factory=list, description="List of registered experiment names") + + +class DatasetConfigUpdate(BaseModel): + """Request model for updating dataset configuration.""" + + config: Dict[str, Any] = Field(..., description="Configuration parameters to update") + + +class ErrorResponse(BaseModel): + """Response model for error conditions.""" + + detail: str = Field(..., description="Error message") + + +class BatchEntry(BaseModel): + """Single 
entry in a batch""" + + question: str = Field(..., description="The question text") + entry_id: str = Field(..., description="Unique identifier in format '{version}.{index}'") + metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata about the entry") + + +class BatchResponse(BaseModel): + """Response containing a batch of entries""" + + entries: List[BatchEntry] = Field(..., description="List of batch entries") + + +class AnswerItem(BaseModel): + """Single score item containing entry_id and answer""" + + entry_id: str = Field(..., description="Entry identifier to score") + answer: str = Field(..., description="Answer to evaluate") + + +class ScoringRequest(BaseModel): + """Request for scoring model outputs""" + + answers: List[AnswerItem] = Field(..., description="List of entries to score") + + +class ScoringResponse(BaseModel): + """Response containing scores for answers""" + + scores: List[float] = Field(..., description="List of scores in same order as request") + entry_ids: List[str] = Field(..., description="List of entry_ids in same order as request") diff --git a/tools/server/server.py b/tools/server/server.py new file mode 100644 index 00000000..09ded0d9 --- /dev/null +++ b/tools/server/server.py @@ -0,0 +1,169 @@ +"""FastAPI server implementation for Reasoning Gym.""" + +import logging + +from fastapi import FastAPI, HTTPException + +from reasoning_gym.coaching.registry import ExperimentRegistry +from reasoning_gym.composite import CompositeConfig, DatasetSpec + +from .config import ServerConfig +from .middleware import APIKeyMiddleware +from .models import ( + BatchEntry, + BatchResponse, + DatasetConfigUpdate, + ExperimentCreate, + ExperimentList, + ExperimentResponse, + ScoringRequest, + ScoringResponse, +) + + +def create_app(config: ServerConfig) -> FastAPI: + """Create and configure the FastAPI application.""" + + # Configure logging + logging.basicConfig(level=config.log_level) + logger = logging.getLogger(__name__) + + # Create FastAPI app + app = FastAPI(title="Reasoning Gym Server") + + # Add middleware + app.add_middleware(APIKeyMiddleware, api_key=config.api_key) + + # Initialize registry + registry = ExperimentRegistry() + + @app.get("/health") + async def health_check(): + """Health check endpoint.""" + return {"status": "healthy"} + + @app.post("/experiments", response_model=ExperimentResponse) + async def create_experiment(experiment: ExperimentCreate): + """Create a new experiment.""" + # Convert dict format to DatasetSpec list + dataset_specs = [] + for name, spec in experiment.datasets.items(): + dataset_specs.append(DatasetSpec(name=name, weight=spec.get("weight", 1.0), config=spec.get("config", {}))) + + config = CompositeConfig(size=experiment.size, seed=experiment.seed, datasets=dataset_specs) + + try: + registry.register_experiment(experiment.name, config) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + return ExperimentResponse( + name=experiment.name, size=experiment.size, seed=experiment.seed, datasets=experiment.datasets + ) + + @app.get("/experiments", response_model=ExperimentList) + async def list_experiments(): + """List all registered experiments.""" + return ExperimentList(experiments=registry.list_experiments()) + + @app.delete("/experiments/{name}") + async def delete_experiment(name: str): + """Delete an experiment.""" + if not registry.remove_experiment(name): + raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found") + return {"status": "deleted"} + + 
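+    # Batch entries carry an entry_id of the form "{version_id}.{index}", so
+    # answers can later be scored against the exact dataset version that
+    # generated them, even after a config update registers a new version.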
@app.get("/experiments/{name}/batch", response_model=BatchResponse) + async def generate_batch(name: str, base_index: int, batch_size: int): + """Generate a batch of raw entries""" + # Validate parameters + if base_index < 0: + raise HTTPException(status_code=400, detail="base_index must be non-negative") + if batch_size <= 0: + raise HTTPException(status_code=400, detail="batch_size must be positive") + + experiment = registry.get_experiment(name) + if not experiment: + raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found") + + try: + entries = [] + for i in range(base_index, base_index + batch_size): + entry = experiment.dataset[i] + + # Create BatchEntry with minimal required data + batch_entry = BatchEntry( + question=entry["question"], + entry_id=f"{entry['metadata']['version_id']}.{i}", + metadata=entry["metadata"], + ) + entries.append(batch_entry) + + return BatchResponse(entries=entries) + + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + @app.post("/experiments/{name}/score", response_model=ScoringResponse) + async def score_outputs(name: str, request: ScoringRequest): + """Score extracted answers""" + experiment = registry.get_experiment(name) + if not experiment: + raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found") + + try: + scores = [] + entry_ids = [] + for item in request.answers: + score = experiment.dataset.score_answer_with_id(item.answer, item.entry_id) + scores.append(score) + entry_ids.append(item.entry_id) + + return ScoringResponse(scores=scores, entry_ids=entry_ids) + + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + @app.get("/experiments/{name}/composite", response_model=ExperimentResponse) + async def get_composite_config(name: str): + """Get composite configuration for an experiment.""" + experiment = registry.get_experiment(name) + if not experiment: + raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found") + + # Convert internal config to API response format + datasets = {} + for ds_spec in experiment.config.datasets: + dataset = experiment.dataset.datasets[ds_spec.name] + datasets[ds_spec.name] = { + "weight": ds_spec.weight, + "config": vars(dataset.config), # Get current config from dataset instance + } + + return ExperimentResponse( + name=name, size=experiment.config.size, seed=experiment.config.seed, datasets=datasets + ) + + @app.post("/experiments/{name}/composite/{dataset_name}") + async def update_dataset_config(name: str, dataset_name: str, config_update: DatasetConfigUpdate): + """Update configuration for a specific dataset in the composite.""" + experiment = registry.get_experiment(name) + if not experiment: + raise HTTPException(status_code=404, detail=f"Experiment '{name}' not found") + + try: + experiment.dataset.update_dataset_config(dataset_name, config_update.config) + return {"status": "updated"} + except KeyError: + raise HTTPException(status_code=404, detail=f"Dataset '{dataset_name}' not found in experiment") + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + return app + + +async def app(scope, receive, send): + """ASGI application that lazily creates the FastAPI app.""" + if not hasattr(app, "server_app"): + app.server_app = create_app(ServerConfig()) + await app.server_app(scope, receive, send) diff --git a/tools/server/tests/__init__.py b/tools/server/tests/__init__.py new file mode 100644 index 00000000..f634e958 --- /dev/null +++ b/tools/server/tests/__init__.py 
diff --git a/tools/server/tests/__init__.py b/tools/server/tests/__init__.py
new file mode 100644
index 00000000..f634e958
--- /dev/null
+++ b/tools/server/tests/__init__.py
@@ -0,0 +1 @@
+"""Tests for the Reasoning Gym server."""
diff --git a/tools/server/tests/test_config.py b/tools/server/tests/test_config.py
new file mode 100644
index 00000000..2c847522
--- /dev/null
+++ b/tools/server/tests/test_config.py
@@ -0,0 +1,25 @@
+"""Tests for server configuration."""
+
+import pytest
+
+from ..config import ServerConfig
+
+
+def test_default_config(monkeypatch):
+    """Test default configuration values."""
+    # monkeypatch keeps the env change from leaking into other tests
+    monkeypatch.setenv("REASONING_GYM_API_KEY", "test-key")
+    config = ServerConfig()
+
+    assert config.host == "localhost"
+    assert config.port == 8000
+    assert config.api_key == "test-key"
+    assert config.log_level == "INFO"
+
+
+def test_missing_api_key(monkeypatch):
+    """Test that missing API key raises an error."""
+    monkeypatch.delenv("REASONING_GYM_API_KEY", raising=False)
+
+    with pytest.raises(ValueError):
+        ServerConfig()
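The same `ServerConfig` can also be constructed with explicit keyword arguments rather than environment defaults, mirroring the test client fixture in the next file; values here are illustrative:

from tools.server.config import ServerConfig

config = ServerConfig(
    host="0.0.0.0",       # illustrative: listen on all interfaces
    port=9000,
    api_key="my-secret",  # normally sourced from REASONING_GYM_API_KEY
    log_level="DEBUG",
)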
"batch_size": 2}, + headers=headers, + ) + assert response.status_code == 200 + data = response.json() + print(data) + + # Verify batch structure + assert "entries" in data + assert len(data["entries"]) == 2 + + # Verify entry structure + entry = data["entries"][0] + assert "question" in entry + assert "entry_id" in entry + assert "metadata" in entry + + # Test error cases + # Non-existent experiment + response = client.get( + "/experiments/nonexistent/batch", + params={"base_index": 0, "batch_size": 2}, + headers=headers, + ) + assert response.status_code == 404 + + # Invalid parameters + response = client.get( + "/experiments/test_exp/batch", + params={"base_index": -1, "batch_size": 2}, + headers=headers, + ) + assert response.status_code == 400 + + +def test_scoring_endpoint(client): + """Test answer scoring endpoint.""" + headers = {"X-API-Key": "test-key"} + + # Create test experiment + create_data = { + "name": "test_exp", + "size": 10, + "seed": 42, + "datasets": { + "chain_sum": { + "weight": 1.0, + "config": { + "min_terms": 2, + "max_terms": 4, + "min_digits": 1, + "max_digits": 2, + "allow_negation": False, + "size": 10, + "seed": 42, + }, + } + }, + } + + response = client.post("/experiments", json=create_data, headers=headers) + assert response.status_code == 200 + + # Get a batch to get valid entry_ids + response = client.get( + "/experiments/test_exp/batch", + params={"base_index": 0, "batch_size": 2}, + headers=headers, + ) + assert response.status_code == 200 + batch = response.json() + entry_id = batch["entries"][0]["entry_id"] + + # Test scoring with correct answer + response = client.post( + "/experiments/test_exp/score", + json={"answers": [{"entry_id": entry_id, "answer": "4"}]}, # Assuming 2+2=4 is the first question + headers=headers, + ) + assert response.status_code == 200 + result = response.json() + assert "scores" in result + assert "entry_ids" in result + assert len(result["scores"]) == 1 + assert len(result["entry_ids"]) == 1 + assert result["entry_ids"][0] == entry_id + assert isinstance(result["scores"][0], float) + assert 0 <= result["scores"][0] <= 1 + + # Test scoring with wrong answer + response = client.post( + "/experiments/test_exp/score", + json={"answers": [{"entry_id": entry_id, "answer": "wrong"}]}, + headers=headers, + ) + assert response.status_code == 200 + result = response.json() + assert result["scores"][0] < 1.0 + assert result["entry_ids"][0] == entry_id + + # Test error cases + # Invalid entry_id format + response = client.post( + "/experiments/test_exp/score", + json={"answers": [{"entry_id": "invalid_id", "answer": "4"}]}, + headers=headers, + ) + assert response.status_code == 400 + + # Non-existent experiment + response = client.post( + "/experiments/nonexistent/score", + json={"answers": [{"entry_id": entry_id, "answer": "4"}]}, + headers=headers, + ) + assert response.status_code == 404 + + +def test_composite_config_endpoints(client): + """Test composite configuration endpoints.""" + headers = {"X-API-Key": "test-key"} + + # Create an experiment first + create_data = { + "name": "test_exp", + "size": 10, + "seed": 42, + "datasets": { + "chain_sum": { + "weight": 1.0, + "config": { + "min_terms": 2, + "max_terms": 4, + "min_digits": 1, + "max_digits": 2, + "allow_negation": False, + "size": 10, + "seed": 42, + }, + } + }, + } + + response = client.post("/experiments", json=create_data, headers=headers) + assert response.status_code == 200 + + # Get composite config + response = client.get("/experiments/test_exp/composite", 
diff --git a/tools/server/tests/test_registry.py b/tools/server/tests/test_registry.py
new file mode 100644
index 00000000..9e19df03
--- /dev/null
+++ b/tools/server/tests/test_registry.py
@@ -0,0 +1,44 @@
+"""Tests for experiment registry."""
+
+import pytest
+
+from reasoning_gym.arithmetic.chain_sum import ChainSumConfig
+from reasoning_gym.coaching.registry import ExperimentRegistry
+from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
+
+
+def test_singleton():
+    """Test that ExperimentRegistry is a singleton."""
+    registry1 = ExperimentRegistry()
+    registry2 = ExperimentRegistry()
+    assert registry1 is registry2
+
+
+def test_experiment_management():
+    """Test basic experiment management operations."""
+    registry = ExperimentRegistry()
+
+    # Clear any existing experiments
+    for name in registry.list_experiments():
+        registry.remove_experiment(name)
+
+    # Test registration with chain_sum dataset
+    chain_sum_spec = DatasetSpec(name="chain_sum", weight=1.0, config=vars(ChainSumConfig(size=10, seed=42)))
+
+    config = CompositeConfig(size=10, seed=42, datasets=[chain_sum_spec])
+    registry.register_experiment("test_exp", config)
+
+    # Test listing
+    assert "test_exp" in registry.list_experiments()
+
+    # Test retrieval
+    exp = registry.get_experiment("test_exp")
+    assert exp is not None
+    assert exp.name == "test_exp"
+    assert isinstance(exp.dataset, CompositeDataset)
+    assert exp.config == config
+
+    # Test removal
+    assert registry.remove_experiment("test_exp")
+    assert "test_exp" not in registry.list_experiments()
+    assert not registry.remove_experiment("nonexistent")
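For completeness, the registry can also be driven directly without the HTTP layer, following the same pattern as test_experiment_management above; the experiment name is illustrative, and indexing into `experiment.dataset` mirrors what the /batch endpoint does in server.py:

from reasoning_gym.coaching.registry import ExperimentRegistry
from reasoning_gym.composite import CompositeConfig, DatasetSpec

# The registry is a singleton, so this is the same instance the server would use
registry = ExperimentRegistry()

spec = DatasetSpec(name="chain_sum", weight=1.0, config={"size": 10, "seed": 42})
registry.register_experiment("local_exp", CompositeConfig(size=10, seed=42, datasets=[spec]))

experiment = registry.get_experiment("local_exp")
entry = experiment.dataset[0]  # same indexing used by the /batch endpoint
print(entry["question"])

registry.remove_experiment("local_exp")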