diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index f6b1d4d1..151b842b 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -348,6 +348,14 @@ async def info(): @app.get("/batch") async def get_batch(): + # Check if trainer has registered first + if not hasattr(app.state, "started"): + return { + "status": "error", + "message": "Trainer not registered. Call /register first.", + "batch": [], + } + if not app.state.started: app.state.started = True diff --git a/environments/eval_environments/gsm8k_eval.py b/environments/eval_environments/gsm8k_eval.py index 8dec68a5..680cfe1c 100644 --- a/environments/eval_environments/gsm8k_eval.py +++ b/environments/eval_environments/gsm8k_eval.py @@ -20,7 +20,7 @@ Supports thinking mode with tags for extended reasoning. import asyncio import random from concurrent.futures import ProcessPoolExecutor -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import wandb from datasets import load_dataset @@ -128,15 +128,16 @@ class GSM8KEvalEnv(BaseEnv): """ name = "gsm8k_eval" + env_config_cls = GSM8KEvalConfig def __init__( self, config: GSM8KEvalConfig, server_configs: List[APIServerConfig], - slurm_job_id: Optional[str] = None, + slurm=False, testing: bool = False, ): - super().__init__(config, server_configs, slurm_job_id, testing) + super().__init__(config, server_configs, slurm, testing) self.config: GSM8KEvalConfig = config self.eval_items: List[Dict] = [] self._dataset_loaded = False @@ -146,10 +147,28 @@ class GSM8KEvalEnv(BaseEnv): def config_cls(cls) -> type: return GSM8KEvalConfig + @classmethod + def config_init(cls) -> Tuple[GSM8KEvalConfig, List[APIServerConfig]]: + """Initialize default configuration for the environment.""" + env_config = GSM8KEvalConfig( + tokenizer_name="Qwen/Qwen2.5-3B-Instruct", + group_size=1, + use_wandb=False, + max_num_workers_per_node=128, + rollout_server_url="http://localhost:8000", + total_steps=1, + wandb_name="gsm8k_eval", + ) + server_configs = [ + APIServerConfig( + model_name="Qwen/Qwen2.5-3B-Instruct", + base_url="http://localhost:9001/v1", + ) + ] + return env_config, server_configs + async def setup(self) -> None: """Initialize the environment and load the dataset.""" - await super().setup() - # Initialize math executor self._math_executor = get_math_executor(self.config.max_math_workers) @@ -165,7 +184,10 @@ class GSM8KEvalEnv(BaseEnv): thinking_prompt = get_default_thinking_prompt( self.config.custom_thinking_prompt ) - print(f" Thinking prompt: {thinking_prompt[:80]}...") + if thinking_prompt: + print(f" Thinking prompt: {thinking_prompt[:80]}...") + else: + print(" Thinking prompt: (using model's native reasoning)") print(f" Loaded {len(self.eval_items)} evaluation items") async def _load_dataset(self) -> None: @@ -351,7 +373,9 @@ class GSM8KEvalEnv(BaseEnv): # Create evaluation tasks async def eval_task(item): - return await self.rollout_and_score_eval(item, self.server_configs[0]) + return await self.rollout_and_score_eval( + item, self.server.servers[0].config + ) tasks = [eval_task(item) for item in self.eval_items] diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index 932df9dc..41a667b6 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -4,11 +4,12 @@ Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero """ import asyncio +import os import random import re import logging from concurrent.futures import ProcessPoolExecutor -from typing 
import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import aiohttp import wandb @@ -30,6 +31,7 @@ from atroposlib.envs.base import ( ScoredDataGroup, ServerBaseline, ) +from atroposlib.envs.server_handling.server_baseline import APIServerConfig prompt_format = ( "A conversation between User and Assistant. The User asks a question, and the Assistant solves it. The Assistant " @@ -125,7 +127,7 @@ class MathEnv(BaseEnv): def __init__( self, config: RSConfig, - server_configs: ServerBaseline, + server_configs: Union[ServerBaseline, List[APIServerConfig]], slurm=True, testing=False, ): @@ -152,26 +154,41 @@ class MathEnv(BaseEnv): print("=" * 60) @classmethod - def config_init(cls) -> Tuple[RSConfig, ServerBaseline]: + def config_init(cls) -> Tuple[RSConfig, List[APIServerConfig]]: + # Allow configuration via environment variables for running multiple instances + model_name = os.environ.get("MATH_ENV_MODEL", "Qwen/Qwen3-4B-Instruct-2507") + rollout_url = os.environ.get("MATH_ENV_ROLLOUT_URL", "http://localhost:8000") + vllm_url = os.environ.get("MATH_ENV_VLLM_URL", "http://localhost:9001/v1") + wandb_name = os.environ.get("MATH_ENV_WANDB_NAME", "math-zero-env") + max_token_length = int(os.environ.get("MATH_ENV_MAX_TOKENS", "32000")) + worker_timeout = float(os.environ.get("MATH_ENV_WORKER_TIMEOUT", "1500")) + env_config = RSConfig( - tokenizer_name="Qwen/Qwen2.5-7B", - group_size=16, + tokenizer_name=model_name, + group_size=8, use_wandb=True, - rollout_server_url="http://localhost:8000", - total_steps=1000, - batch_size=1024, - steps_per_eval=25, - max_token_length=31000, # 22000 // (2 ** i), - wandb_name="math", + rollout_server_url=rollout_url, + total_steps=120, + batch_size=64, + steps_per_eval=20, + max_token_length=max_token_length, + start_tok_length=max_token_length, + wandb_name=wandb_name, eval_handling=EvalHandlingEnum.LIMIT_TRAIN, eval_limit_ratio=0.1, max_num_workers_per_node=24, + worker_timeout=worker_timeout, ) - server_configs = ServerBaseline( - model_name="Qwen/Qwen2.5-7B", - num_requests_for_eval=256, # since evaling only on one... 
- server_type="vllm", - ) + server_configs = [ + APIServerConfig( + model_name=model_name, + base_url=vllm_url, + api_key="x", + num_requests_for_eval=256, + server_type="vllm", + weight=1.0, + ) + ] return env_config, server_configs @@ -352,7 +369,7 @@ class MathEnv(BaseEnv): completion = await managed.completion( prompt=question, n=1, - max_tokens=32765, + max_tokens=self.config.max_token_length, temperature=0.0, split="eval", stop=stop_list, @@ -376,6 +393,10 @@ class MathEnv(BaseEnv): async def evaluate(self, *args, **kwargs): if not self.config.run_evaluation: return + import time + + start_time = time.time() + eval_tasks = [] for item in self.test: eval_tasks.append(self.rollout_and_score_eval(item[0], item[1], item[2])) @@ -385,17 +406,53 @@ class MathEnv(BaseEnv): if subset not in task_lists: task_lists[subset] = list() task_lists[subset].append(score) - # Now get the average + + # Build metrics dictionary for saving + metrics = {} + + # Now get the average per subset for subset, scores in task_lists.items(): - self.eval_metrics.append( - (f"eval/{subset}_percent_correct", sum(scores) / len(scores)) - ) + accuracy = sum(scores) / len(scores) + metrics[f"{subset}_accuracy"] = accuracy + metrics[f"{subset}_total"] = len(scores) + metrics[f"{subset}_correct"] = sum(scores) + self.eval_metrics.append((f"eval/{subset}_percent_correct", accuracy)) + # overall score - scores = [] + all_scores = [] for subset, score in task_lists.items(): - scores.extend(score) - self.eval_metrics.append( - ("eval/overall_percent_correct", sum(scores) / len(scores)) + all_scores.extend(score) + overall_accuracy = sum(all_scores) / len(all_scores) + metrics["overall_accuracy"] = overall_accuracy + metrics["overall_total"] = len(all_scores) + metrics["overall_correct"] = sum(all_scores) + self.eval_metrics.append(("eval/overall_percent_correct", overall_accuracy)) + + end_time = time.time() + + # Print results to console + print("\n" + "=" * 60) + print("Math Zero Evaluation Results") + print("=" * 60) + print( + f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})" + ) + print("\nPer-subset breakdown:") + for subset, scores in sorted(task_lists.items()): + acc = sum(scores) / len(scores) + print(f" {subset}: {acc:.2%} ({sum(scores)}/{len(scores)})") + print("=" * 60 + "\n") + + # Save results to disk + await self.evaluate_log( + metrics=metrics, + task_name="math_zero", + start_time=start_time, + end_time=end_time, + generation_parameters={ + "max_tokens": self.config.max_token_length, + "temperature": 0.0, + }, ) async def collect_trajectories(self, item) -> Tuple[List, List]: diff --git a/example_trainer/README.md b/example_trainer/README.md index aee1831d..9be1cf66 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -1,79 +1,931 @@ -# GRPO Example Trainer +# GRPO Trainer -This directory contains an example script (`grpo.py`) demonstrating how to integrate a custom training loop with the Atropos API for reinforcement learning using the GRPO (Group Relative Policy Optimization) algorithm. +A modular training framework for fine-tuning language models with **Group Relative Policy Optimization (GRPO)**, designed to work with the Atropos environment system. -**Note: Example trainer does not support multimodal training out of the box. 
As other trainers add support for Atropos, we will list them in the main readme, some of which may support multimodal RL - please check the main repo readme for any updates.** +## Module Structure -This example uses `vLLM` for efficient inference during the (simulated) data generation phase and `transformers` for the training phase. +**Note:** The `configs/` directory contains YAML configuration files for the **environment server** (e.g., `math_server_zero.py`), not for the trainer itself. The trainer is configured via CLI arguments documented in the [CLI Reference](#-cli-reference) section. -**Note:** This script is intended as a *reference example* for API integration and basic training setup. It is not optimized for large-scale, efficient training. +``` +example_trainer/ +├── grpo.py # CLI entry point (dispatches to 4 training modes) +├── run.py # Unified launcher for shared_vllm mode (starts vLLM+trainer) +├── config.py # TrainingConfig Pydantic model (all hyperparameters) +├── cli.py # CLI argument parsing (modular, single source of truth) +├── api.py # Atropos API communication (registration, batch fetching) +├── data.py # Data fetching, preprocessing, logprob alignment +├── model.py # Model loading, CUDA IPC, tensor mapping (QKV/Gate fusion) +├── training.py # GRPO loss (importance sampling, KL penalty, clipping) +├── checkpointing.py # Save models & LoRA adapters (handles fused tensor unfusing) +├── vllm_manager.py # vLLM process lifecycle (launch, health, termination) +├── trainers.py # 4 training mode implementations + optimizer selection +├── vllm_api_server.py # Custom vLLM server with /generate endpoint + LoRA +├── vllm_patching/ # CUDA IPC patches for weight sharing + B200 GPU compatibility +│ └── patched_gpu_runner.py +└── configs/ # Environment server configuration examples + ├── math_zero_shared.yaml # Config for math_server_zero.py (shared_vllm mode) + └── math_zero_lora.yaml # Config for math_server_zero.py (lora mode) +``` -### Custom vLLM Server -The `vllm_api_server.py` file in this directory provides a customized vLLM API server implementation based on vLLM's native API. This server exposes enhanced endpoints for token and logprob tracking. The `VLLMServer` class in `atroposlib/envs/server_handling/vllm_server.py` can connect to this server for direct access to vLLM's `/generate` endpoint with full token-level logprobs. +## GRPO Training Loop -## Prerequisites +``` +1. Generate multiple responses to the same prompt +2. Score each response (reward) +3. Compute ADVANTAGE = reward - mean(rewards) +4. Train: increase probability of above-average responses + decrease probability of below-average responses +``` -1. **Python:** Python 3.10 or higher is recommended. -2. **Atropos API Server:** The Atropos API server must be running and accessible (defaults to `http://localhost:8000` in the script). -3. **Python Packages:** You need to install the required Python libraries: - * `torch` (with CUDA support recommended) - * `transformers` - * `vllm` - * `pydantic` - * `numpy` - * `requests` - * `tenacity` - * `wandb` (optional, for logging) +### Key Concepts -## Setup +| Concept | What It Means | +|---------|---------------| +| **Advantage** | How much better/worse than average a response was | +| **Importance Sampling** | Corrects for policy drift during training | +| **KL Penalty** | Prevents the model from changing too drastically from base | +| **Clipping** | Limits update magnitude for stability | -1. 
**Clone the Repository:** Ensure you have the repository containing this example. -2. **Install Dependencies:** `pip install -r requirements.txt` -3. **Ensure Atropos API is Running:** `run-api` in a new window -4. **Run an env:** `python environments/gsm8k_server.py serve --slurm False` -## Configuration +## System Architecture -The training configuration is managed within the `grpo.py` script using the `TrainingConfig` Pydantic model (found near the top of the file). +``` +Data Flow: +1. Environment generates prompts → calls vLLM → scores responses +2. Environment sends trajectories to run-api +3. Trainer fetches batches from run-api +4. Trainer updates model weights +5. Weight synchronization: + - shared_vllm: vLLM sees updates immediately via CUDA IPC (zero-copy) + - lora_only: Trainer pushes adapter to vLLM via HTTP (slow) + - lora_restart: Trainer restarts vLLM with new adapter (fast) + - none (legacy): Trainer saves checkpoint and restarts vLLM +``` -Key parameters you might want to adjust include: +--- -* `model_name`: The Hugging Face model identifier to use for training (e.g., `"gpt2"`, `"Qwen/Qwen2.5-1.5B-Instruct"`). -* `training_steps`: The total number of optimization steps to perform. -* `batch_size` / `gradient_accumulation_steps`: Control the effective batch size. -* `lr`: Learning rate. -* `save_path`: Directory where model checkpoints will be saved. -* `vllm_port`: The port used by the vLLM server instance launched by this script. -* `vllm_restart_interval`: How often (in steps) to save a checkpoint and restart the vLLM server with the new weights. -* `use_wandb`: Set to `True` to enable logging to Weights & Biases. -* `wandb_project`: Your W&B project name (required if `use_wandb=True`). -* `wandb_group`: Optional W&B group name. +## Four Training Modes -**API Endpoints:** The script currently assumes the Atropos API is available at `http://localhost:8000/register` and `http://localhost:8000/batch`. If your API runs elsewhere, you'll need to modify the `register_trainer` and `get_batch` functions accordingly. +| Mode | Description | Memory | Inference Speed | Best For | +|------|-------------|--------|-----------------|----------| +| **shared_vllm** | Single-copy via CUDA IPC | 1x model | ~172 TPS | Same GPU, maximum efficiency | +| **lora_restart** | LoRA + vLLM restarts | 1x + adapter | ~108 TPS | LoRA training with speed | +| **lora_only** | LoRA + HTTP hot-swap | 1x + adapter | ~13 TPS ⚠️ | Debugging only | +| **none** (legacy) | Full model, restart vLLM | 2x model | ~172 TPS | simple setup | -## Running the Example +### ⚠️ IMPORTANT: `lora_only` Performance Warning -Once the prerequisites are met and configuration is set: +The `lora_only` mode requires `--enforce-eager` which **disables CUDA graphs**, resulting in: +- **8x slower inference** (~13 TPS vs ~108 TPS) +- Training that takes **4x longer** (401 min vs 132 min for 120 steps) -1. Navigate to the root directory of the project in your terminal. -2. Run the script: +**Use `lora_restart` instead** - it runs vLLM without `--enforce-eager` for 8x faster inference. - ```bash - python example_trainer/grpo.py - ``` +### Recommendation -## Output +**Use `shared_vllm`** for production training when: +- You have enough GPU memory for the full model +- You want fastest training (no overhead) +- Trainer and vLLM are on the same GPU(s) -* **Logs:** Training progress, loss, logp, and vLLM status will be printed to the console. 
-* **Checkpoints:** Model checkpoints will be saved periodically in the directory specified by `save_path` (default: `./trained_model_checkpoints`). A `final_model` directory will be created upon completion. -* **WandB:** If `use_wandb` is `True`, logs will be sent to Weights & Biases. A link to the run page will be printed in the console. -* `temp.json`: Contains the raw data from the last fetched batch (used for debugging/manual inspection). +**Use `lora_restart`** when: +- You want LoRA's memory efficiency +- You can tolerate ~45s restart overhead every N steps + +**Avoid `lora_only`** unless you're debugging - the 8x inference penalty is severe. + +**Use `none` (legacy)** mode when: +- You want the simplest setup without CUDA IPC or LoRA + +--- + +## Quick Start: LoRA Training (Recommended) + +### Step 1: Install Dependencies +- They are listed in the requirements.txt file that you can see + +### Step 2: Start All Components + +**Terminal 1: API Server** +```bash +run-api --port 8002 +``` + +**Terminal 2: vLLM Server** +```bash +python -m example_trainer.vllm_api_server \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --port 9001 \ + --gpu-memory-utilization 0.5 \ + --max-model-len 4096 \ + --dtype bfloat16 \ + --enable-lora \ + --enforce-eager +``` + +**Terminal 3: Environment** +```bash +# Important: Use server_type=vllm to get logprobs (required for GRPO) +python environments/gsm8k_server.py serve \ + --env.group_size 4 \ + --env.batch_size 16 \ + --env.total_steps 200 \ + --env.steps_per_eval 50 \ + --env.max_num_workers_per_node 8 \ + --env.rollout_server_url "http://localhost:8002" \ + --env.use_wandb true \ + --env.wandb_name "gsm8k-lora-only-env" \ + --openai.api_key "dummy" \ + --openai.base_url "http://localhost:9001/v1" \ + --openai.model_name "NousResearch/Hermes-3-Llama-3.1-8B" \ + --openai.server_type vllm +``` + +**Terminal 4: Trainer** +```bash +python -m example_trainer.grpo \ + --model-name NousResearch/Hermes-3-Llama-3.1-8B \ + --weight-bridge-mode lora_only \ + --vllm-port 9001 \ + --atropos-url "http://localhost:8002" \ + --batch-size 4 \ + --gradient-accumulation-steps 4 \ + --lr 1e-5 \ + --training-steps 30 \ + --kl-coef 0.1 \ + --clip-eps 0.2 \ + --vllm-restart-interval 5 \ + --save-path ./lora_checkpoints \ + --wandb-project "grpo-training" +``` + +### Startup Order ```bash -# Install dependencies -pip install -r example_trainer/requirements.txt +# CRITICAL: Follow this exact order! +# 1. Start API first +run-api --port 8002 -# Run the trainer directly (basic test) -python example_trainer/grpo.py +# 2. Wait 5s, then start vLLM +# Check health: curl http://localhost:9001/health +python -m example_trainer.vllm_api_server --model ... --enable-lora --enforce-eager + +# 3. Wait for vLLM health endpoint to return 200 +while ! curl -s http://localhost:9001/health > /dev/null; do sleep 1; done + +# 4. Start environment (MUST use --openai.server_type vllm for logprobs) +python environments/gsm8k_server.py serve \ + --env.group_size 4 \ + --env.batch_size 16 \ + --env.total_steps 200 \ + --env.steps_per_eval 50 \ + --env.max_num_workers_per_node 8 \ + --env.rollout_server_url "http://localhost:8002" \ + --env.use_wandb true \ + --env.wandb_name "gsm8k-train-env" \ + --openai.base_url "http://localhost:9001/v1" \ + --openai.model_name "NousResearch/Hermes-3-Llama-3.1-8B" \ + --openai.server_type vllm + +# 5. Start trainer (will register with API and begin training) +python -m example_trainer.grpo --weight-bridge-mode lora_only ... 
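+
+# 6. (Optional) Sanity-check registration. Until the trainer has called
+#    /register, the API's /batch endpoint answers with
+#    {"status": "error", "message": "Trainer not registered. Call /register first."}
+#    so an error here before step 5 finishes is expected. After registration,
+#    polling /batch may consume queued training data, so use this only as a
+#    one-off check.
+# curl -s http://localhost:8002/batch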
+``` + +--- + +## Shared vLLM Mode + +Single-copy mode shares GPU memory between vLLM and the trainer - zero model duplication! + +### How It Works + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ SINGLE GPU (CUDA IPC) │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ Model Weights (ONE copy in GPU memory) │ │ +│ │ (accessible via CUDA IPC handles) │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +│ ▲ ▲ │ +│ │ Reads (inference) │ Writes │ +│ ┌────────┴────────┐ ┌───────────┴───────────┐ │ +│ │ vLLM Worker │ │ Trainer Process │ │ +│ │ │ │ (attached via IPC) │ │ +│ └─────────────────┘ └───────────────────────┘ │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +### Running Shared vLLM Mode + +**Terminal 1: API** +```bash +run-api --port 8002 +``` + +**Terminal 2: vLLM with Shared Weights** +```bash +VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=/tmp/grpo_training \ +python -m example_trainer.vllm_api_server \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --port 9001 \ + --gpu-memory-utilization 0.45 \ + --enforce-eager +``` + +**Terminal 3: Environment** +```bash +# Important: Use server_type=vllm to get logprobs (required for GRPO) +python environments/gsm8k_server.py serve \ + --openai.base_url "http://localhost:9001/v1" \ + --openai.model_name "NousResearch/Hermes-3-Llama-3.1-8B" \ + --openai.server_type vllm \ + --env.group_size 4 \ + --env.batch_size 16 \ + --env.total_steps 200 \ + --env.steps_per_eval 50 \ + --env.max_num_workers_per_node 8 \ + --env.rollout_server_url "http://localhost:8002" \ + --env.use_wandb true \ + --env.wandb_name "gsm8k-shared-vllm-env" +``` + +**Terminal 4: Trainer** +```bash +python -m example_trainer.grpo \ + --model-name NousResearch/Hermes-3-Llama-3.1-8B \ + --weight-bridge-mode shared_vllm \ + --vllm-port 9001 \ + --vllm-config-path /tmp/grpo_training/vllm_bridge_config.json \ + --atropos-url "http://localhost:8002" \ + --kl-coef 0.1 \ + --clip-eps 0.2 +``` + +### Or Use the Unified Launcher + +```bash +# Single command starts both vLLM and trainer +VLLM_ENABLE_SHARED_WEIGHTS=1 python -m example_trainer.run \ + --model-name NousResearch/Hermes-3-Llama-3.1-8B \ + --atropos-url "http://localhost:8002" \ + --training-steps 30 +``` + +--- + +## Best Practices & Lessons Learned + + +### 1. Use `--openai.server_type vllm` for Training + +**CRITICAL:** The atropos environment MUST use `server_type=vllm` to get logprobs for proper GRPO training. + +Only `server_type=vllm` calls the `/generate` endpoint which returns token-level logprobs. These logprobs serve as the reference policy (π_old) for importance sampling in GRPO. + +```bash +# CORRECT - gets logprobs for training (REQUIRED!) +--openai.server_type vllm + +# WRONG for training - no logprobs, training will FAIL +--openai.server_type openai +``` + +**What happens without logprobs:** +- The trainer will raise an error: "GRPO requires inference_logprobs for importance sampling!" +- Without the reference policy, GRPO degenerates to vanilla REINFORCE (leads to reward hacking) + +**How logprobs flow through the system:** +1. Environment calls vLLM `/generate` with `logprobs=true` +2. vLLM returns token-level logprobs for each generated token +3. Environment embeds these in trajectory data sent to API +4. Trainer extracts and aligns logprobs with training labels +5. GRPO loss uses logprobs as π_old for importance sampling ratio + +### 2. 
KL Coefficient and Clipping Are Essential + +**CRITICAL:** Without these hyperparameters, training WILL collapse (reward hacking): + +```bash +--kl-coef 0.1 # Prevents policy from drifting too far from reference +--clip-eps 0.2 # Limits importance sampling ratio to [0.8, 1.2] +``` + +**Why these matter:** +- **KL Penalty** (β): Penalizes the policy for deviating from the reference policy (inference-time policy) + - Uses Schulman's unbiased estimator: `exp(-log_ratio) + log_ratio - 1` + - Higher β = more conservative updates + - Set to 0 to disable (NOT recommended - leads to instability) + +- **PPO Clipping** (ε): Clips the importance sampling ratio to `[1-ε, 1+ε]` + - Prevents catastrophically large policy updates + - Takes pessimistic bound (conservative update) + +**Symptoms of missing/misconfigured KL/clipping:** +- Accuracy drops dramatically (e.g., 59% → 7%) +- Loss goes to very negative values (< -10) +- Model outputs become repetitive/degenerate +- `mean_ratio` diverges far from 1.0 +- `mean_kl` explodes (> 1.0) + +**Healthy training metrics:** +- `mean_ratio`: 0.8 - 1.2 (close to 1.0) +- `mean_kl`: 0.01 - 0.1 +- `clipped_fraction`: < 0.3 (< 30% of tokens clipped) + +### 3. Memory Budgeting for Large Models + +| Model Size | GPU Memory | Recommended Settings | +|------------|------------|----------------------| +| 8B | 80GB | `--gpu-memory-utilization 0.5` | +| 14B | 80GB | `--gpu-memory-utilization 0.45`, `--batch-size 2` | +| 24B | 192GB (B200) | `--gpu-memory-utilization 0.30`, `--optimizer adafactor` | + +**🔧 B200/Blackwell GPU Support:** + +The trainer includes automatic patches for NVIDIA B200 (Blackwell architecture) GPUs when using LoRA mode. These patches disable Grid Dependency Control (GDC) in vLLM's Triton kernels, which causes compilation failures on Blackwell GPUs. The patches are applied automatically when: +- `VLLM_ENABLE_SHARED_WEIGHTS=1` is set, or +- `NUM_INFERENCE_NODES` is set (distributed inference path) + +The patching clears the Triton cache and disables GDC to ensure compatibility. No manual intervention required. + +### 4. Optimizer Selection + +The trainer supports multiple optimizer options to trade off between speed, memory, and precision: + +| Optimizer | GPU Memory for States | Speed | Precision | Dependencies | +|-----------|----------------------|-------|-----------|--------------| +| `adamw` | ~32GB (for 8B model) | Fastest | Full FP32 | None | +| `adamw_8bit` (default) | ~8GB | Fast | 8-bit quantized | `bitsandbytes` | +| `adafactor` | ~8GB | Fast | Full (no momentum) | `transformers` | + +**Usage:** +```bash +# 8-bit AdamW (default) - recommended for memory-constrained setups +--optimizer adamw_8bit + +# Standard AdamW - full precision +--optimizer adamw + +# Adafactor - no momentum states, good for large models +--optimizer adafactor +``` + +**Recommendations:** +- **8B models on 80GB:** Use `adamw` (fastest) +- **14B+ models on 80GB:** Use `adamw_8bit` or `adafactor` +- **24B models:** Use `adafactor` with reduced batch size + +**Potential Risks:** +- `adamw_8bit`: Quantization may slightly affect convergence in edge cases; generally safe +- `adafactor`: No momentum can make training slightly less stable; use with larger batch sizes + +--- + +## Tensor Mapping (vLLM ↔ HuggingFace) + +### The Problem + +vLLM fuses certain layers for efficiency, but HuggingFace keeps them separate: + +``` +HuggingFace Model: vLLM Model: +├── q_proj [4096, 4096] ├── qkv_proj [12288, 4096] ← FUSED! 
+├── k_proj [1024, 4096] │ (contains q, k, v concatenated) +├── v_proj [1024, 4096] │ +│ │ +├── gate_proj [14336, 4096] ├── gate_up_proj [28672, 4096] ← FUSED! +├── up_proj [14336, 4096] │ (contains gate and up concatenated) +``` + +### How We Solve It + +The trainer creates **views** into vLLM's fused tensors: + +```python +# vLLM has: qkv_proj.weight [12288, 4096] +# We need: q_proj [4096], k_proj [1024], v_proj [1024] + +# Get sizes from model config +q_size = num_heads * head_dim # e.g., 4096 +k_size = num_kv_heads * head_dim # e.g., 1024 +v_size = num_kv_heads * head_dim # e.g., 1024 + +# Create views (no copy!) +hf_model.q_proj.weight = vllm_qkv[0:4096, :] # First chunk +hf_model.k_proj.weight = vllm_qkv[4096:5120, :] # Second chunk +hf_model.v_proj.weight = vllm_qkv[5120:6144, :] # Third chunk +``` + +### Key Insight: Views Share Memory + +```python +# These point to the SAME GPU memory: +trainer_q_proj.data_ptr() == vllm_qkv_proj.data_ptr() # True! + +# So when optimizer updates trainer weights: +optimizer.step() # Updates trainer_q_proj + +# vLLM sees the change immediately (same memory)! +``` + +### The Config File + +vLLM exports tensor mappings to `vllm_bridge_config.json`: + +```json +{ + "model": "NousResearch/Hermes-3-Llama-3.1-8B", + "param_mappings": { + "model.layers.0.self_attn.qkv_proj.weight": { + "ipc_handle": "base64_encoded_cuda_ipc_handle", + "shape": [6144, 4096], + "dtype": "bfloat16" + } + } +} +``` + +--- + +## ❓ FAQ + +### Q: How do I debug logprob alignment issues? + +**A:** Look for these log messages during training: +``` +[WARNING] ref_logprobs at generated positions avg 0.85 (should be negative!) +[WARNING] This suggests inference_logprobs alignment is wrong +``` + +This means inference logprobs aren't being passed correctly. Debug steps: + +1. **Check environment server type:** + ```bash + # Must be 'vllm', NOT 'openai' + --openai.server_type vllm + ``` + +2. **Verify vLLM returns logprobs:** + ```bash + curl -X POST http://localhost:9001/generate \ + -H "Content-Type: application/json" \ + -d '{"prompt": "Hello", "max_tokens": 5}' + # Response should include "logprobs": [...] + ``` + +3. **Check data.py logs:** + ``` + [Data] ✓ inference_logprobs found in batch (sample len: 128) + ``` + +4. **Monitor alignment metrics in training logs:** + - `alignment/diff_mean` should be close to 0 at step start + - `alignment/diff_abs_mean` < 0.1 = good alignment + - Large values = weights not properly shared or logprobs misaligned + + +## Troubleshooting + +### "Atropos API not reachable" + +```bash +# Start the API server first +run-api --port 8002 +``` + +### "404 Not Found" on /generate + +You're using a vLLM server that doesn't expose `/generate`. Use our custom server: + +```bash +python -m example_trainer.vllm_api_server ... # Has /generate +# NOT: python -m vllm.entrypoints.openai.api_server # Only has /v1/* +``` + +### "Cannot re-initialize CUDA in forked subprocess" + +vLLM v1 engine issue. We disable it by default, but if you see this: + +```bash +VLLM_USE_V1=0 python -m example_trainer.vllm_api_server ... +``` + +### "WARNING: ref_logprobs avg X.XXX (should be negative!)" + +This warning appears during training when inference logprobs alignment is incorrect. Weight updates may not be visible to inference. 
Fix: + +```bash +# Add --enforce-eager to vLLM +python vllm_api_server.py --model $MODEL --enforce-eager +``` + +You may also see related alignment warnings: +``` +[WARNING] This suggests inference_logprobs alignment is wrong +[DEBUG] Logprob gap: ref=X.XXX, train=X.XXX +``` + +### OOM (Out of Memory) + +Reduce memory usage: + +```bash +--gpu-memory-utilization 0.4 # Less vLLM memory +--batch-size 2 # Smaller batches +--gradient-accumulation-steps 8 # Compensate with accumulation +--seq-len 1024 # Shorter sequences +--optimizer adafactor # Uses less memory than AdamW +``` + +### "FlexibleArgumentParser" import error + +vLLM version incompatibility. Our server handles this automatically, but make sure you're using: + +```bash +python -m example_trainer.vllm_api_server # NOT direct vllm commands +``` + + +## 📊 Monitoring Training + +### WandB Logging + +```bash +--use-wandb \ +--wandb-project "my-grpo-training" \ +--wandb-group "hermes-8b-gsm8k" +``` + +--- + +## CLI Reference + +### Essential Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--model-name` or `--model` | (required) | HuggingFace model ID | +| `--weight-bridge-mode` | `none` | `shared_vllm`, `lora_only`, `lora_restart`, or `none` | +| `--training-steps` | 10 | Number of training steps | +| `--checkpoint-interval` | 3 | Save checkpoint every N steps (0 = final only) | +| `--batch-size` | 2 | Micro-batch size | +| `--gradient-accumulation-steps` | 32 | Effective batch = batch × accum | +| `--seq-len` | 2048 | Maximum sequence length | + +### GRPO Hyperparameters + +| Argument | Default | Description | +|----------|---------|-------------| +| `--kl-coef` | 0.1 | KL penalty strength (higher = more conservative) | +| `--clip-eps` | 0.2 | PPO clipping range [1-ε, 1+ε] | +| `--lr` | 1e-5 | Learning rate (NOT --learning-rate) | +| `--no-reference-logprobs` | False | Disable GRPO reference logprobs (falls back to REINFORCE-style updates) | + +### LoRA Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--lora-r` | 16 | LoRA rank (dimension of low-rank matrices) | +| `--lora-alpha` | 32 | LoRA alpha scaling factor | +| `--lora-dropout` | 0.05 | LoRA dropout probability | +| `--lora-target-modules` | None | Module names to apply LoRA (`None` falls back to `q_proj v_proj`) | +| `--lora-layer-indices` | None | Optional layer filter (examples: `20-31`, `0-3,28-31`) | + +### LoRA Layer Index Guide (by Architecture) + +`--lora-layer-indices` is model-dependent. Different models expose different numbers of transformer blocks, so a valid range for one model may be invalid for another. 
+ +| Architecture family | Common config fields | Typical layer list path | Notes | +|---------------------|----------------------|-------------------------|-------| +| LLaMA / Llama-2 / Llama-3 / Mistral | `num_hidden_layers` | `model.layers` | Most common causal-LM layout | +| Qwen / Qwen2 / Qwen2.5 / Qwen3 | `num_hidden_layers` | `model.layers` | Similar layer indexing to LLaMA | +| GPT-2 / GPT-J style | `n_layer` or mapped to `num_hidden_layers` | `transformer.h` | PEFT may use `h` pattern internally | +| Falcon | `num_hidden_layers` | `transformer.h` | Uses `h` block list in model module tree | + +#### Reliable way to check for any model + +Always query the model config before choosing indices: + +```bash +python - <<'PY' +from transformers import AutoConfig + +model_id = "meta-llama/Meta-Llama-3-8B-Instruct" +cfg = AutoConfig.from_pretrained(model_id) +num_layers = getattr(cfg, "num_hidden_layers", None) +if num_layers is None: + num_layers = getattr(cfg, "n_layer", None) + +print(f"model={model_id}") +print(f"num_hidden_layers={num_layers}") +if num_layers is not None: + print(f"valid index range: 0-{num_layers-1}") +PY +``` + +#### Practical presets + +If your model has `N` layers: + +- Full layers: omit `--lora-layer-indices` +- Top 25%: `--lora-layer-indices {int(0.75*N)}-{N-1}` +- Top 50%: `--lora-layer-indices {int(0.5*N)}-{N-1}` +- Last 12 layers: `--lora-layer-indices {N-12}-{N-1}` (if `N >= 12`) + +### vLLM Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--vllm-port` | 9001 | vLLM server port | +| `--vllm-config-path` | auto | Path to bridge config (shared mode) | +| `--gpu-memory-utilization` | 0.45 | vLLM GPU memory fraction | +| `--vllm-gpu` | None | GPU ID for vLLM (None = same as trainer) | +| `--max-model-len` | 4096 | Maximum context length | +| `--dtype` | `bfloat16` | Model dtype: `bfloat16`, `float16`, or `auto` | +| `--vllm-restart-interval` | 3 | Restart vLLM every N steps (legacy/lora_restart) | + +### Atropos API Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--atropos-url` | `http://localhost:8000` | URL of the Atropos API server | + +**Note:** Many examples in this README use `http://localhost:8002` because they start `run-api --port 8002`. 
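+
+Under the hood, `--atropos-url` feeds the helpers in `example_trainer/api.py`: the trainer polls `/info` until the API is reachable, registers via `/register`, then pulls training data from `/batch`. A minimal sketch of that handshake is below (it assumes the remaining `TrainingConfig` fields keep their defaults; the model name is just a placeholder):
+
+```python
+from example_trainer.api import check_atropos_api, get_batch, register_trainer
+from example_trainer.config import TrainingConfig
+
+config = TrainingConfig(
+    model_name="NousResearch/Hermes-3-Llama-3.1-8B",  # placeholder model
+    atropos_url="http://localhost:8002",
+)
+
+if check_atropos_api(config.atropos_url):  # polls GET /info until reachable
+    register_trainer(config)               # POST /register (batch size, seq len, ...)
+    batch = get_batch(config.atropos_url)  # GET /batch; raises if not registered
+    # The queue may still be empty right after startup
+    print("sequences in batch:", len(batch.get("batch") or []))
+```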
+ +### Weights & Biases Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--use-wandb` | False | Enable W&B logging | +| `--wandb-project` | None | W&B project name | +| `--wandb-group` | None | W&B group name (auto-generated if omitted) | + +### Distributed Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--trainer-rank` | 0 | Trainer rank | +| `--world-size` | 1 | World size | +| `--init-method` | `env://` | Distributed init method | +| `--num-inference-nodes` | 0 | Number of inference nodes | + +### Debug & Benchmark Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--debug-loading` | False | Verbose model loading diagnostics | +| `--benchmark` | False | Print benchmark/timing metrics | +| `--log-dir` | `./logs` | Directory for unified launcher logs | + +--- + +## Module Documentation + +| Module | Purpose | +|--------|---------| +| `grpo.py` | CLI entry point, dispatches to training modes (4 modes) | +| `run.py` | Unified launcher for shared_vllm mode (starts vLLM + trainer) | +| `cli.py` | Single source of truth for all CLI arguments (modular builders) | +| `config.py` | `TrainingConfig` Pydantic model with all hyperparameters | +| `api.py` | Communication with Atropos API (registration, batch fetching) | +| `data.py` | Batch preprocessing, padding, logprob extraction and alignment | +| `model.py` | Model loading, CUDA IPC attachment, tensor mapping (QKV/Gate fusion) | +| `training.py` | GRPO loss computation (importance sampling, KL penalty, clipping) | +| `trainers.py` | Mode-specific training loops (4 implementations + optimizer selection) | +| `vllm_api_server.py` | Custom vLLM server with `/generate` endpoint and LoRA support | +| `vllm_manager.py` | vLLM process lifecycle management (launch, health checks, termination) | +| `checkpointing.py` | Save/load checkpoints and adapters (handles fused tensor unfusing) | + +--- + +## Code Execution Flow + +### High-Level Flow (All Modes) + +``` +1. CLI Parsing (cli.py) + ↓ +2. Config Creation (config.py) + ↓ +3. Mode Dispatcher (grpo.py or run.py) + ↓ +4. 
Trainer Function (trainers.py) + ├─ Setup Phase + │ ├─ Initialize W&B (training.py) + │ ├─ Load Model (model.py) + │ ├─ Create Optimizer (trainers.py) + │ ├─ Check Atropos API (api.py) + │ ├─ Register Trainer (api.py) + │ └─ Launch/Connect vLLM (vllm_manager.py or external) + │ + └─ Training Loop + ├─ Fetch Batch (api.py → data.py) + │ ├─ Poll /batch endpoint + │ ├─ Pad sequences (data.py) + │ ├─ Extract inference logprobs (data.py) + │ └─ Normalize advantages (data.py) + │ + ├─ Training Step (training.py) + │ ├─ For each micro-batch: + │ │ ├─ Forward pass (model) + │ │ ├─ Compute GRPO loss (training.py) + │ │ │ ├─ Temperature scaling + │ │ │ ├─ Compute log probabilities + │ │ │ ├─ Importance sampling ratio (using inference logprobs) + │ │ │ ├─ PPO clipping + │ │ │ ├─ Schulman KL penalty + │ │ │ └─ Return loss + metrics + │ │ └─ Backward pass (accumulate gradients) + │ ├─ Clip gradients (norm=1.0) + │ ├─ Optimizer step + │ └─ Zero gradients + │ + ├─ Weight Sync (mode-dependent) + │ ├─ shared_vllm: No sync needed (weights shared via CUDA IPC) + │ ├─ lora_only: HTTP POST to /lora/load + │ ├─ lora_restart: Save adapter + terminate + relaunch vLLM + │ └─ none: Save checkpoint + terminate + relaunch vLLM + │ + ├─ Log Metrics (training.py) + │ ├─ Console output + │ └─ W&B logging (if enabled) + │ + └─ Periodic Checkpoint (checkpointing.py) + ├─ Ensure tensors are contiguous (unfuse views) + ├─ Save state dict + └─ Free GPU memory +``` + +### Mode-Specific Details + +#### shared_vllm Mode + +```python +# Entry: grpo.py → trainers.train_shared_vllm() + +1. Model Loading (model.py): + - Find vllm_bridge_config.json + - Load IPC handles (CUDA memory pointers) + - Create empty model on meta device + - Reconstruct tensors from IPC handles + - Map vLLM fused tensors → HF unfused parameters + * qkv_proj → q_proj, k_proj, v_proj (views) + * gate_up_proj → gate_proj, up_proj (views) + - Initialize remaining meta tensors (buffers, etc.) + +2. Training Loop: + - optimizer.step() directly modifies vLLM's tensors + - No weight synchronization needed! + - Checkpoints: Unfuse views before saving (checkpointing.py) + +3. Tensor Mapping (model.py:_create_vllm_to_hf_mapping): + - Reads actual HF tensor shapes from model.state_dict() + - Creates slice mappings for fused layers + - Example: q_proj = qkv_proj[0:4096, :] +``` + +#### lora_restart Mode + +```python +# Entry: grpo.py → trainers.train_lora_restart() + +1. Model Loading (model.py): + - Load base model with PEFT + - Apply LoRA config to target modules + - Freeze base weights, only LoRA trainable + +2. vLLM Management: + - Launch: _launch_vllm_with_lora() + * NO --enforce-eager flag (CUDA graphs enabled) + * Pre-load initial adapter + - Periodic Restart: + * Save new adapter (checkpointing.py) + * Terminate vLLM aggressively (_terminate_vllm) + - Kill process group + - Kill by port (fuser) + - Kill by process name patterns + - Wait for GPU memory release (critical!) + * Relaunch with new adapter + +3. Performance: + - ~108 TPS (CUDA graphs enabled) + - ~45s restart overhead + - Much faster than lora_only (~8x speedup) +``` + +#### lora_only Mode + +```python +# Entry: grpo.py → trainers.train_lora() + +1. Model Loading: Same as lora_restart + +2. vLLM: External server (must be pre-started) + - MUST use --enforce-eager (disables CUDA graphs) + - MUST use --enable-lora + +3. Weight Sync: _hotswap_lora_adapter() + - Tries /v1/load_lora_adapter (native vLLM) + - Falls back to /lora/load (custom endpoint) + +4. 
Performance: + - ~13 TPS (CUDA graphs disabled) + - No restart overhead + - 8x slower than lora_restart! +``` + +#### none (legacy) Mode + +```python +# Entry: grpo.py → trainers.train_legacy() + +1. Model Loading: Full model (model.py) + +2. vLLM Management: + - Launch: vllm_manager.launch_vllm_server() + - Periodic Restart: + * Save full checkpoint (checkpointing.py) + * Terminate vLLM (vllm_manager.terminate_vllm_process) + * Relaunch with new checkpoint + +3. Use Case: + - Different GPUs for trainer and vLLM + - Simple setup without CUDA IPC or LoRA +``` + +### Data Flow Detail (data.py) + +```python +# api.get_batch() → data.get_data() → data.pad_data_to_good_offset() + +1. Batch Structure from API: + { + "batch": [ + { + "tokens": [[tok1, tok2, ...], ...], # group_size sequences + "masks": [[mask1, mask2, ...], ...], # -100 for prompt, token_id for generated + "scores": [score1, score2, ...], # rewards + "inference_logprobs": [[lp1, lp2, ...], ...], # CRITICAL for GRPO! + "generation_params": {"temperature": 1.0}, + ... + } + ] + } + +2. Preprocessing (pad_data_to_good_offset): + - Normalize advantages (mean=0, std=1 per group) + - Pad sequences to multiple of 64 + - Align inference_logprobs with labels: + * 1.0 for prompt tokens (masked) + * Actual negative logprobs for generated tokens + * Shift by 1 for causal alignment + - Extract temperatures (priority: override > generation_params > 1.0) + - Batch into micro-batches + +3. Output: + - token_batches: [B, seq_len] + - label_batches: [B, seq_len] # -100 for masked + - advantage_batches: [B, 1] + - temperature_batches: [B, 1, 1] + - inference_logprob_batches: [B, seq_len] # aligned with labels! +``` + +### GRPO Loss Computation (training.py) + +```python +# training.compute_grpo_loss() + +1. Forward Pass: + - Get logits from model + - Apply temperature scaling (from data) + - Compute log probabilities per token + +2. Reference Policy (π_old): + - Extract from inference_logprobs (from vLLM at generation time) + - Already aligned with labels by data.py + +3. Importance Sampling: + - log_ratio = log π_new(a|s) - log π_old(a|s) + - ratio = exp(log_ratio) + - Clipped ratio = clip(ratio, 1-ε, 1+ε) + +4. Policy Loss: + - surr1 = ratio * advantage + - surr2 = clipped_ratio * advantage + - policy_loss = -min(surr1, surr2) # pessimistic bound + +5. KL Penalty (Schulman's estimator): + - kl = exp(-log_ratio) + log_ratio - 1 + - Guaranteed non-negative, unbiased + +6. Total Loss: + - loss = policy_loss + β * kl_penalty + - Scaled by 1/gradient_accumulation_steps + +7. Metrics: + - mean_ratio: Average importance sampling ratio + - mean_kl: Average KL divergence + - clipped_fraction: % of tokens clipped + - alignment/* : Token-level logprob alignment (verifies weight sharing) ``` diff --git a/example_trainer/__init__.py b/example_trainer/__init__.py index f0ebdb72..57216902 100644 --- a/example_trainer/__init__.py +++ b/example_trainer/__init__.py @@ -1,7 +1,42 @@ """ -Example trainer implementations of how to implement a trainer for the Atropos library. +GRPO (Group Relative Policy Optimization) Trainer + +A training framework for fine-tuning language models with reinforcement learning, +designed to work with the Atropos environment system. + +Supports three training modes: +- Legacy: Checkpoint-based training with vLLM restarts +- Shared vLLM: Single-copy mode with CUDA IPC (no model duplication!) 
+- LoRA: Adapter-only training with hot-swap capability +- LoRA restart: Adapter training with periodic fast vLLM restarts + +Usage: + # As CLI + python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct --training-steps 100 + + # As library + from example_trainer import ( + TrainingConfig, + train_legacy, + train_shared_vllm, + train_lora, + train_lora_restart, + ) + + config = TrainingConfig(model_name="Qwen/Qwen2.5-3B-Instruct", training_steps=100) + train_legacy(config) """ -from example_trainer.grpo import TrainingConfig, train +from .cli import config_from_args, parse_args +from .config import TrainingConfig +from .trainers import train_legacy, train_lora, train_lora_restart, train_shared_vllm -__all__ = ["TrainingConfig", "train"] +__all__ = [ + "TrainingConfig", + "train_legacy", + "train_shared_vllm", + "train_lora", + "train_lora_restart", + "parse_args", + "config_from_args", +] diff --git a/example_trainer/api.py b/example_trainer/api.py new file mode 100644 index 00000000..21c4288e --- /dev/null +++ b/example_trainer/api.py @@ -0,0 +1,108 @@ +""" +Atropos API communication utilities. + +Handles communication with the Atropos API server for: +- Server health checks +- Trainer registration +- Batch retrieval +""" + +import time as _time + +import requests +from tenacity import retry, stop_after_attempt, wait_exponential + +from .config import TrainingConfig + + +def check_atropos_api( + url: str = "http://localhost:8000", timeout: float = 30.0 +) -> bool: + """ + Check if the Atropos API server is reachable. + + Args: + url: Base URL of the Atropos API server + timeout: Maximum time to wait for the server + + Returns: + True if server is reachable + """ + start = _time.time() + while _time.time() - start < timeout: + try: + response = requests.get(f"{url}/info", timeout=2) + if response.status_code == 200: + print(f"[Trainer] ✓ Atropos API server is reachable at {url}") + return True + except requests.exceptions.ConnectionError: + pass + except Exception as e: + print(f"[Trainer] Waiting for Atropos API at {url}... ({e})") + _time.sleep(1) + + print(f"[Trainer] ⚠ Warning: Atropos API server not reachable at {url}") + return False + + +@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30)) +def register_trainer(config: TrainingConfig): + """ + Register the trainer with the Atropos API. + + Verifies registration succeeded before returning. 
+ """ + url = config.atropos_url + save_checkpoint_interval = ( + config.training_steps + if config.checkpoint_interval <= 0 + else config.checkpoint_interval + ) + response = requests.post( + f"{url}/register", + json={ + # wandb fields are required strings - use empty string if None + "wandb_group": config.wandb_group or "", + "wandb_project": config.wandb_project or "", + "batch_size": config.batch_size * config.gradient_accumulation_steps, + "max_token_len": config.seq_len, + "starting_step": 0, + "checkpoint_dir": config.save_path, + "save_checkpoint_interval": save_checkpoint_interval, + "num_steps": config.training_steps, + }, + timeout=10, + ) + + # Check for HTTP errors + response.raise_for_status() + + # Verify we got a valid response with UUID + data = response.json() + if "uuid" not in data: + raise RuntimeError(f"Registration failed: {data}") + + print(f"[Trainer] ✓ Registered with Atropos API at {url} (uuid: {data['uuid']})") + + +@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30)) +def get_batch(url: str = "http://localhost:8000"): + """ + Get a batch of training data from the Atropos API. + + Args: + url: Base URL of the Atropos API server + + Returns: + Batch data dictionary containing tokens, masks, scores, etc. + + Raises: + RuntimeError: If trainer is not registered or other API error + """ + data = requests.get(f"{url}/batch", timeout=10).json() + + # Check if there was an error (trainer not registered) + if data.get("status") == "error": + raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}") + + return data diff --git a/example_trainer/checkpointing.py b/example_trainer/checkpointing.py new file mode 100644 index 00000000..b5d60bbe --- /dev/null +++ b/example_trainer/checkpointing.py @@ -0,0 +1,157 @@ +""" +Checkpoint saving utilities for GRPO trainer. + +Handles saving model checkpoints for different training modes: +- Full model checkpoints (legacy and shared_vllm modes) +- LoRA adapter checkpoints + +IMPORTANT: For shared_vllm mode, the model parameters are VIEWS into vLLM's +fused tensors (qkv_proj, gate_up_proj). This module handles unfusing them +back to HuggingFace format for safe checkpoint saving. +""" + +import os +import shutil +from typing import Dict + +import torch + + +def _ensure_contiguous_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: + """ + Create a state dict with contiguous tensors for safe saving. + + This is critical for shared_vllm mode where parameters are views into + vLLM's fused tensors. Views may share storage and not be contiguous, + which can cause issues when saving. + + Returns: + State dict with all tensors made contiguous (copied if necessary) + """ + state_dict = {} + for name, param in model.named_parameters(): + # Check if tensor is a view (non-contiguous or shares storage) + if not param.is_contiguous() or param.storage_offset() != 0: + # Make a contiguous copy - this "unfuses" the view + state_dict[name] = param.detach().clone().contiguous() + else: + state_dict[name] = param.detach() + + # Also include buffers + for name, buffer in model.named_buffers(): + if not buffer.is_contiguous() or buffer.storage_offset() != 0: + state_dict[name] = buffer.detach().clone().contiguous() + else: + state_dict[name] = buffer.detach() + + return state_dict + + +def save_checkpoint( + model: torch.nn.Module, + tokenizer, + save_path: str, + step: int, + is_final: bool = False, + safe_mode: bool = True, +) -> str: + """ + Save full model checkpoint. 
+ + Args: + model: Model to save + tokenizer: Tokenizer to save + save_path: Base directory for checkpoints + step: Current training step + is_final: Whether this is the final checkpoint + safe_mode: If True, ensure all tensors are contiguous before saving. + This is important for shared_vllm mode where params are + views into fused vLLM tensors. + + Returns: + Path where checkpoint was saved + """ + if is_final: + checkpoint_path = os.path.join(save_path, "final_model") + else: + checkpoint_path = os.path.join(save_path, f"step_{step}") + + print(f" Saving checkpoint to {checkpoint_path}...") + + if os.path.exists(checkpoint_path): + shutil.rmtree(checkpoint_path) + os.makedirs(checkpoint_path, exist_ok=True) + + if safe_mode: + # For shared_vllm mode: ensure views are properly unfused + print(" [Checkpoint] Using safe mode - ensuring contiguous tensors...") + state_dict = _ensure_contiguous_state_dict(model) + + # Count how many were non-contiguous (views into fused tensors) + view_count = sum( + 1 + for name, param in model.named_parameters() + if not param.is_contiguous() or param.storage_offset() != 0 + ) + if view_count > 0: + print( + f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)" + ) + + # Save state dict manually, then save config separately + torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin")) + model.config.save_pretrained(checkpoint_path) + + # CRITICAL: Clean up the copied state_dict to free ~8GB GPU memory! + del state_dict + import gc + + gc.collect() + torch.cuda.empty_cache() + else: + # Standard save (may have issues with view tensors) + model.save_pretrained(checkpoint_path) + + tokenizer.save_pretrained(checkpoint_path) + + print(" Checkpoint saved.") + return checkpoint_path + + +def save_lora_checkpoint( + model: torch.nn.Module, + save_path: str, + step: int, + is_final: bool = False, +) -> str: + """ + Save LoRA adapter checkpoint. + + Only saves the LoRA adapter weights, not the full model. + This results in much smaller checkpoint files. + + Args: + model: PEFT model with LoRA adapters + save_path: Base directory for checkpoints + step: Current training step + is_final: Whether this is the final checkpoint + + Returns: + Path where adapter was saved + """ + if is_final: + adapter_path = os.path.join(save_path, "final_adapter") + else: + adapter_path = os.path.join(save_path, f"adapter_step_{step}") + + print(f" Saving LoRA adapter to {adapter_path}...") + + if os.path.exists(adapter_path): + shutil.rmtree(adapter_path) + os.makedirs(adapter_path, exist_ok=True) + + # Save only the adapter weights (much smaller than full model) + model.save_pretrained(adapter_path) + + print(" Adapter saved.") + return adapter_path diff --git a/example_trainer/cli.py b/example_trainer/cli.py new file mode 100644 index 00000000..4b602277 --- /dev/null +++ b/example_trainer/cli.py @@ -0,0 +1,438 @@ +""" +Command-line interface for GRPO trainer. + +Provides modular argument group builders and configuration building. +This is the SINGLE SOURCE OF TRUTH for all CLI arguments. +""" + +import argparse +from typing import List, Optional + +import torch + +from .config import TrainingConfig + +# ============================================================================= +# Argument Group Builders (modular, reusable) +# ============================================================================= + + +def _parse_lora_layer_indices(value: str) -> Optional[List[int]]: + """ + Parse LoRA layer indices from comma/range syntax. 
+ + Supported formats: + - "20-31" + - "0,1,2,28,29,30,31" + - "0-3,28-31" + """ + if value is None: + return None + + raw = value.strip() + if not raw: + return None + + indices: List[int] = [] + parts = [part.strip() for part in raw.split(",") if part.strip()] + + try: + for part in parts: + if "-" in part: + start_s, end_s = part.split("-", 1) + start = int(start_s.strip()) + end = int(end_s.strip()) + if start > end: + raise argparse.ArgumentTypeError( + f"Invalid range '{part}': start must be <= end" + ) + indices.extend(range(start, end + 1)) + else: + indices.append(int(part)) + except ValueError as e: + raise argparse.ArgumentTypeError( + f"Invalid --lora-layer-indices value '{value}': {e}" + ) from e + + if not indices: + return None + if any(idx < 0 for idx in indices): + raise argparse.ArgumentTypeError( + f"Invalid --lora-layer-indices '{value}': indices must be >= 0" + ) + + return sorted(set(indices)) + + +def add_model_args(parser: argparse.ArgumentParser) -> None: + """Add model-related arguments.""" + group = parser.add_argument_group("Model") + group.add_argument( + "--model", + "--model-name", + type=str, + required=True, + dest="model_name", + help="HuggingFace model identifier (e.g., 'Qwen/Qwen2.5-1.5B-Instruct')", + ) + + +def add_training_args(parser: argparse.ArgumentParser) -> None: + """Add core training arguments.""" + group = parser.add_argument_group("Training") + group.add_argument( + "--lr", + type=float, + default=1e-5, + help="Learning rate for the optimizer", + ) + group.add_argument( + "--training-steps", + type=int, + default=10, + help="Number of training steps to run", + ) + group.add_argument( + "--batch-size", + type=int, + default=2, + help="Batch size for training", + ) + group.add_argument( + "--seq-len", + type=int, + default=2048, + help="Maximum sequence length", + ) + group.add_argument( + "--gradient-accumulation-steps", + type=int, + default=32, + help="Number of gradient accumulation steps", + ) + group.add_argument( + "--optimizer", + type=str, + choices=["adamw", "adamw_8bit", "adafactor"], + default="adamw_8bit", + help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), " + "'adafactor' (no momentum)", + ) + group.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to train on (cuda/cpu)", + ) + group.add_argument( + "--save-path", + type=str, + default="trained_model_checkpoints", + help="Directory to save model checkpoints", + ) + group.add_argument( + "--checkpoint-interval", + type=int, + default=3, + help="Save checkpoint every N training steps (0 = only save final)", + ) + + +def add_grpo_args(parser: argparse.ArgumentParser) -> None: + """Add GRPO/PPO hyperparameter arguments.""" + group = parser.add_argument_group("GRPO/PPO Hyperparameters") + group.add_argument( + "--kl-coef", + type=float, + default=0.1, + help="KL divergence penalty coefficient (beta). Higher = more conservative.", + ) + group.add_argument( + "--clip-eps", + type=float, + default=0.2, + help="PPO-style clipping epsilon. 
Clips ratio to [1-eps, 1+eps].", + ) + group.add_argument( + "--no-reference-logprobs", + action="store_true", + help="Disable use of inference logprobs as reference policy (not recommended).", + ) + + +def add_vllm_args(parser: argparse.ArgumentParser) -> None: + """Add vLLM server arguments.""" + group = parser.add_argument_group("vLLM Server") + group.add_argument( + "--vllm-port", + type=int, + default=9001, + help="Port for the vLLM server", + ) + group.add_argument( + "--vllm-gpu", + type=int, + default=None, + help="GPU ID for vLLM server. If not set, uses same GPU as trainer.", + ) + group.add_argument( + "--gpu-memory-utilization", + "--vllm-gpu-memory-utilization", + type=float, + default=0.45, + dest="gpu_memory_utilization", + help="GPU memory utilization for vLLM server (0.0-1.0)", + ) + group.add_argument( + "--max-model-len", + type=int, + default=4096, + help="Maximum context length for vLLM", + ) + group.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["bfloat16", "float16", "auto"], + help="Data type for model weights", + ) + group.add_argument( + "--vllm-restart-interval", + type=int, + default=3, + help="Restart vLLM every N training steps (legacy/lora_restart modes)", + ) + + +def add_atropos_args(parser: argparse.ArgumentParser) -> None: + """Add Atropos API arguments.""" + group = parser.add_argument_group("Atropos API") + group.add_argument( + "--atropos-url", + type=str, + default="http://localhost:8000", + help="URL of the Atropos API/environment server", + ) + + +def add_wandb_args(parser: argparse.ArgumentParser) -> None: + """Add Weights & Biases arguments.""" + group = parser.add_argument_group("Weights & Biases") + group.add_argument( + "--use-wandb", + action="store_true", + help="Enable Weights & Biases logging", + ) + group.add_argument( + "--wandb-project", + type=str, + default=None, + help="Wandb project name", + ) + group.add_argument( + "--wandb-group", + type=str, + default=None, + help="Wandb group name", + ) + + +def add_mode_args(parser: argparse.ArgumentParser) -> None: + """Add training mode arguments.""" + group = parser.add_argument_group("Training Mode") + group.add_argument( + "--weight-bridge-mode", + type=str, + choices=["shared_vllm", "lora_only", "lora_restart", "none"], + default="none", + help=( + "Weight sync mode: 'shared_vllm' (CUDA IPC), 'lora_only' (slow, --enforce-eager), " + "'lora_restart' (fast, restarts vLLM), or 'none' (legacy)" + ), + ) + group.add_argument( + "--vllm-config-path", + type=str, + default=None, + help="Explicit path to vllm_bridge_config.json (auto-detected if not provided)", + ) + + +def add_lora_args(parser: argparse.ArgumentParser) -> None: + """Add LoRA-specific arguments.""" + group = parser.add_argument_group("LoRA Configuration") + group.add_argument("--lora-r", type=int, default=16, help="LoRA rank") + group.add_argument( + "--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)" + ) + group.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout") + group.add_argument( + "--lora-target-modules", + type=str, + nargs="+", + default=None, + help="Module names to apply LoRA to (default: q_proj v_proj)", + ) + group.add_argument( + "--lora-layer-indices", + type=_parse_lora_layer_indices, + default=None, + help=( + "Optional layer indices to apply LoRA to, e.g. '20-31' or " + "'0-3,28-31'. If omitted, applies to all matching layers." 
+ ), + ) + + +def add_distributed_args(parser: argparse.ArgumentParser) -> None: + """Add distributed training arguments.""" + group = parser.add_argument_group("Distributed Training") + group.add_argument("--trainer-rank", type=int, default=0, help="Trainer rank") + group.add_argument("--world-size", type=int, default=1, help="World size") + group.add_argument( + "--init-method", type=str, default="env://", help="Distributed init method" + ) + group.add_argument( + "--num-inference-nodes", type=int, default=0, help="Number of inference nodes" + ) + + +def add_debug_args(parser: argparse.ArgumentParser) -> None: + """Add debug/benchmark arguments.""" + group = parser.add_argument_group("Debug & Benchmarking") + group.add_argument( + "--debug-loading", + action="store_true", + help="Enable verbose debug output during model loading", + ) + group.add_argument( + "--benchmark", + action="store_true", + help="Enable benchmark timing output", + ) + group.add_argument( + "--log-dir", + type=str, + default="./logs", + help="Directory for log files", + ) + + +# ============================================================================= +# Parser Builders +# ============================================================================= + + +def create_base_parser(description: str) -> argparse.ArgumentParser: + """Create a base parser with common formatting.""" + return argparse.ArgumentParser( + description=description, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + +def create_full_parser() -> argparse.ArgumentParser: + """ + Create a parser with ALL arguments (for grpo.py multi-mode entry point). + """ + parser = create_base_parser("GRPO Trainer - Multi-mode training") + + add_model_args(parser) + add_training_args(parser) + add_grpo_args(parser) + add_vllm_args(parser) + add_atropos_args(parser) + add_wandb_args(parser) + add_mode_args(parser) + add_lora_args(parser) + add_distributed_args(parser) + add_debug_args(parser) + + return parser + + +def create_unified_parser() -> argparse.ArgumentParser: + """ + Create a parser for run.py (unified shared_vllm mode with integrated vLLM). + """ + parser = create_base_parser( + "Unified GRPO Trainer - Starts vLLM server and trainer in one command" + ) + + add_model_args(parser) + add_training_args(parser) + add_grpo_args(parser) + add_vllm_args(parser) + add_atropos_args(parser) + add_wandb_args(parser) + add_debug_args(parser) + + return parser + + +# ============================================================================= +# Legacy API (backwards compatibility) +# ============================================================================= + + +def parse_args() -> argparse.Namespace: + """ + Parse command-line arguments for the GRPO trainer (grpo.py). + + Returns: + Parsed arguments namespace + """ + parser = create_full_parser() + return parser.parse_args() + + +def config_from_args(args: argparse.Namespace) -> TrainingConfig: + """ + Build a TrainingConfig from parsed CLI arguments. 
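+
+    Optional fields are read with getattr fallbacks, so namespaces produced by the
+    reduced parsers above (which omit some argument groups) still map cleanly onto
+    TrainingConfig defaults.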
+ + Args: + args: Parsed argparse namespace + + Returns: + TrainingConfig instance + """ + return TrainingConfig( + model_name=args.model_name, + lr=args.lr, + training_steps=args.training_steps, + batch_size=args.batch_size, + seq_len=args.seq_len, + gradient_accumulation_steps=args.gradient_accumulation_steps, + optimizer=args.optimizer, + device=args.device, + save_path=args.save_path, + checkpoint_interval=getattr(args, "checkpoint_interval", 3), + # GRPO/PPO hyperparameters + kl_coef=getattr(args, "kl_coef", 0.1), + clip_eps=getattr(args, "clip_eps", 0.2), + use_reference_logprobs=not getattr(args, "no_reference_logprobs", False), + # vLLM settings + vllm_restart_interval=getattr(args, "vllm_restart_interval", 3), + vllm_port=args.vllm_port, + vllm_gpu=getattr(args, "vllm_gpu", None), + vllm_gpu_memory_utilization=getattr(args, "gpu_memory_utilization", 0.45), + max_model_len=getattr(args, "max_model_len", 4096), + dtype=getattr(args, "dtype", "bfloat16"), + use_wandb=args.use_wandb, + wandb_project=args.wandb_project, + wandb_group=getattr(args, "wandb_group", None), + weight_bridge_mode=getattr(args, "weight_bridge_mode", "none"), + trainer_rank=getattr(args, "trainer_rank", 0), + world_size=getattr(args, "world_size", 1), + init_method=getattr(args, "init_method", "env://"), + num_inference_nodes=getattr(args, "num_inference_nodes", 0), + lora_r=getattr(args, "lora_r", 16), + lora_alpha=getattr(args, "lora_alpha", 32), + lora_dropout=getattr(args, "lora_dropout", 0.05), + lora_target_modules=getattr(args, "lora_target_modules", None), + lora_layer_indices=getattr(args, "lora_layer_indices", None), + vllm_config_path=getattr(args, "vllm_config_path", None), + debug_loading=getattr(args, "debug_loading", False), + benchmark=getattr(args, "benchmark", False), + atropos_url=getattr(args, "atropos_url", "http://localhost:8000"), + ) diff --git a/example_trainer/config.py b/example_trainer/config.py new file mode 100644 index 00000000..7acc3494 --- /dev/null +++ b/example_trainer/config.py @@ -0,0 +1,210 @@ +""" +Training configuration for GRPO trainer. + +This module contains the TrainingConfig class which defines all training +parameters, model settings, and operational modes. +""" + +from typing import List, Literal, Optional + +import torch +from pydantic import BaseModel, Field + + +class TrainingConfig(BaseModel): + """ + Training configuration for GRPO trainer. 
+ + Supports three training modes: + - 'none' (legacy): Periodic checkpoint saves + vLLM restarts + - 'shared_vllm': Attach to vLLM's shared memory tensors, update in-place + - 'lora_only': Freeze base model, train LoRA adapters only + """ + + # === Model Configuration === + model_name: str = Field(..., description="Name of the base model to train") + + # === Training Hyperparameters === + lr: float = Field(1e-5, description="Learning rate for the optimizer") + training_steps: int = Field(10, description="Number of training steps") + batch_size: int = Field(2, description="Batch size for training") + seq_len: int = Field(2048, description="Sequence length for training") + gradient_accumulation_steps: int = Field( + 32, description="Number of gradient accumulation steps" + ) + optimizer: Literal["adamw", "adamw_8bit", "adafactor"] = Field( + "adamw_8bit", + description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), " + "'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), " + "'adafactor' (no momentum, ~8GB GPU)", + ) + + # === GRPO/PPO Hyperparameters === + kl_coef: float = Field( + 0.1, + description=( + "KL divergence penalty coefficient (beta). " + "Controls how much the policy can deviate from the reference (inference-time) policy. " + "Higher values = more conservative updates, prevents reward hacking. " + "Set to 0 to disable KL penalty (not recommended)." + ), + ) + clip_eps: float = Field( + 0.2, + description=( + "PPO-style clipping epsilon. " + "Clips the importance sampling ratio to [1-eps, 1+eps]. " + "Prevents large policy updates that could destabilize training." + ), + ) + use_reference_logprobs: bool = Field( + True, + description=( + "Whether to use inference logprobs as the reference policy (π_old). " + "When True, implements proper GRPO with importance sampling. " + "When False, falls back to REINFORCE-style updates (not recommended)." + ), + ) + + # === Device & Storage === + device: str = Field( + "cuda" if torch.cuda.is_available() else "cpu", description="Device to train on" + ) + save_path: str = Field( + "trained_model_checkpoints", description="Base path to save model checkpoints" + ) + checkpoint_interval: int = Field( + 3, + description=( + "Save checkpoint every N training steps. " + "Set to 0 to only save final checkpoint." + ), + ) + + # === vLLM Server Configuration === + vllm_restart_interval: int = Field( + 3, description="Restart vLLM every N training steps (legacy mode)" + ) + vllm_port: int = Field(9001, description="Port for the vLLM server") + vllm_gpu: Optional[int] = Field( + None, + description=( + "GPU ID for vLLM server (lora_restart/legacy modes). " + "If None, uses same GPU as trainer. Set different for parallelism." 
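+            # Illustrative split (not a default): trainer on cuda:0 with
+            # --vllm-gpu 1 keeps optimizer steps and generation on separate devices.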
+ ), + ) + vllm_gpu_memory_utilization: float = Field( + 0.45, description="GPU memory utilization for vLLM server (0.0-1.0)" + ) + max_model_len: int = Field( + 4096, description="Maximum context length for vLLM server" + ) + dtype: str = Field( + "bfloat16", description="Data type for model weights (bfloat16, float16, auto)" + ) + + # === Weights & Biases Configuration === + use_wandb: bool = Field( + False, description="Whether to use Weights & Biases for logging" + ) + wandb_project: Optional[str] = Field(None, description="Wandb project name") + wandb_group: Optional[str] = Field(None, description="Wandb group name") + + # === Training Mode Configuration === + weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = Field( + "none", + description=( + "How to synchronize weights with inference server. " + "'shared_vllm': attach to vLLM's shared memory tensors and update in-place. " + "'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). " + "'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). " + "'none': legacy mode, restart vLLM with new checkpoint files." + ), + ) + + # === Distributed Training Configuration === + trainer_rank: int = Field( + 0, description="Rank of this trainer in the distributed group" + ) + world_size: int = Field(1, description="Total processes in the distributed group") + init_method: str = Field( + "env://", + description=( + "PyTorch distributed init method URL. " + "Use 'env://' to read MASTER_ADDR/MASTER_PORT from environment, " + "or 'tcp://host:port' for explicit rendezvous." + ), + ) + num_inference_nodes: int = Field( + 0, + description=( + "Number of inference nodes (vLLM servers) to coordinate with. " + "0 means single-node local mode." + ), + ) + + # === LoRA Configuration === + lora_r: int = Field(16, description="LoRA rank (dimension of low-rank matrices)") + lora_alpha: int = Field(32, description="LoRA alpha (scaling factor)") + lora_dropout: float = Field(0.05, description="Dropout probability for LoRA layers") + lora_target_modules: Optional[List[str]] = Field( + None, + description=( + "List of module names to apply LoRA to. " + "If None, defaults to ['q_proj', 'v_proj'] for most models." + ), + ) + lora_layer_indices: Optional[List[int]] = Field( + None, + description=( + "Optional list of transformer layer indices to apply LoRA to. " + "If None, applies LoRA to all matching layers." + ), + ) + + # === Single-Copy Mode Configuration === + single_copy: bool = Field( + False, + description=( + "Enable TRUE single-copy mode via CUDA IPC. " + "The trainer attaches to vLLM's model tensors directly, " + "meaning only ONE copy of the model exists in GPU memory. " + "Requires trainer and vLLM to be on the SAME GPU(s). " + "vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1." + ), + ) + vllm_config_path: Optional[str] = Field( + None, + description=( + "Explicit path to vllm_bridge_config.json. " + "If not provided, auto-detects from LOGDIR environment variable, " + "current directory, or /tmp/atropos_bridge. " + "This file is created by vLLM when VLLM_ENABLE_SHARED_WEIGHTS=1 " + "and contains CUDA IPC handles for single-copy mode." + ), + ) + + # === Debug & Benchmark Flags === + debug_loading: bool = Field( + False, + description=( + "Enable verbose debug output during model loading and IPC attachment. " + "Useful for diagnosing single-copy mode issues." 
+        ),
+    )
+    benchmark: bool = Field(
+        False,
+        description=(
+            "Enable benchmark timing output showing step time, sync time, "
+            "data fetch time, and GPU memory usage per step."
+        ),
+    )
+
+    # === Atropos API Configuration ===
+    atropos_url: str = Field(
+        "http://localhost:8000",
+        description=(
+            "URL of the Atropos API server (environment server). "
+            "Default is http://localhost:8000. Change for concurrent tests."
+        ),
+    )
diff --git a/example_trainer/configs/math_zero_lora.yaml b/example_trainer/configs/math_zero_lora.yaml
new file mode 100644
index 00000000..5480c87e
--- /dev/null
+++ b/example_trainer/configs/math_zero_lora.yaml
@@ -0,0 +1,13 @@
+env:
+  tokenizer_name: "Qwen/Qwen3-4B-Instruct-2507"
+  rollout_server_url: "http://localhost:8002"
+  max_token_length: 8192
+  start_tok_length: 8192
+  group_size: 8
+  batch_size: 64
+  total_steps: 120
+  steps_per_eval: 20
+  use_wandb: true
+  wandb_name: "math-zero-lora-env"
+  eval_limit_ratio: 0.1
+  max_num_workers_per_node: 24
diff --git a/example_trainer/configs/math_zero_shared.yaml b/example_trainer/configs/math_zero_shared.yaml
new file mode 100644
index 00000000..e5ee82c7
--- /dev/null
+++ b/example_trainer/configs/math_zero_shared.yaml
@@ -0,0 +1,13 @@
+env:
+  tokenizer_name: "Qwen/Qwen3-4B-Instruct-2507"
+  rollout_server_url: "http://localhost:8001"
+  max_token_length: 8192
+  start_tok_length: 8192
+  group_size: 8
+  batch_size: 64
+  total_steps: 120
+  steps_per_eval: 20
+  use_wandb: true
+  wandb_name: "math-zero-shared-env"
+  eval_limit_ratio: 0.1
+  max_num_workers_per_node: 24
diff --git a/example_trainer/data.py b/example_trainer/data.py
new file mode 100644
index 00000000..bb3ebbdb
--- /dev/null
+++ b/example_trainer/data.py
@@ -0,0 +1,320 @@
+"""
+Data processing utilities for GRPO trainer.
+
+Handles data retrieval from Atropos API, padding, batching,
+and advantage normalization.
+
+Also extracts inference logprobs for proper GRPO loss computation:
+- Inference logprobs serve as π_old (reference policy) for importance sampling
+- They are batched and padded to align token-by-token with training labels
+"""
+
+import math
+import time
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+
+from .api import get_batch
+
+
+def pad_data_to_good_offset(
+    data: dict,
+    batch_size: int,
+    extract_inference_logprobs: bool = True,
+) -> Tuple[
+    List[torch.Tensor],  # token_batches
+    List[torch.Tensor],  # label_batches
+    List[torch.Tensor],  # advantage_batches
+    List[torch.Tensor],  # temperature_batches
+    Optional[List[torch.Tensor]],  # inference_logprob_batches (aligned with labels)
+]:
+    """
+    Pad and batch data from the Atropos API.
+
+    Processes raw batch data into properly padded tensors suitable for training:
+    - Pads token sequences to nearest multiple of 64
+    - Normalizes advantage scores
+    - Extracts temperature values
+    - Extracts and pads inference logprobs for proper GRPO loss computation
+
+    Args:
+        data: Raw batch data from Atropos API
+        batch_size: Size of each training batch
+        extract_inference_logprobs: Whether to extract inference logprobs
+
+    Returns:
+        Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
+        inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data
+
+    Note:
+        inference_logprob_batches are padded with 1.0 (the masked-token placeholder
+        value) at positions where labels == -100.
+        This allows token-by-token alignment during GRPO loss computation.
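+
+    Example (illustrative numbers):
+        With a longest sequence of 130 tokens, (130 - 1) is not a multiple of 64,
+        so max_token_len is rounded up to 192 and token_setup_len becomes 193;
+        after the causal shift, input_ids and labels are both length 192.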
+ """ + max_token_len = max( + [max([len(x) for x in item["tokens"]]) for item in data["batch"]] + ) + + # Pad to nearest multiple of 64 for GPU efficiency + good_multiple = 64 + if (max_token_len - 1) % (good_multiple) != 0: + max_token_len = math.ceil((max_token_len - 1) / (good_multiple)) * good_multiple + token_setup_len = max_token_len + 1 # +1 for causal shift + else: + token_setup_len = max_token_len + max_token_len = max_token_len - 1 # -1 for causal shift + + # Process all items + input_ids = [] + labels = [] + advantages = [] + lengths = [] + temperatures = [] + inference_logprobs_padded: List[np.ndarray] = [] # Padded to match labels shape + has_any_logprobs = False + + for item in data["batch"]: + # Normalize advantage scores + scores = np.array(item["scores"]) + if len(scores) > 1: + scores = scores - scores.mean() + scores = scores / max(scores.std(), 1e-8) + item["scores"] = scores + + # Handle score overrides + if item["overrides"] is not None: + for i in range(len(item["overrides"])): + if item["overrides"][i].get("set_advantage_to_zero", False): + item["scores"][i] = 0 + + # Process each sample in the item + for i in range(len(item["tokens"])): + seq_len = len(item["tokens"][i]) + lengths.append(math.ceil((seq_len - 1) / good_multiple) * good_multiple) + + # Create labels with padding (-100 for masked positions) + label_item = np.concatenate( + [ + np.array(item["masks"][i]), + np.full( + max(0, token_setup_len - seq_len), + -100, + dtype=np.int32, + ), + ] + ) + + # Pad tokens + item["tokens"][i] = np.concatenate( + [ + np.array(item["tokens"][i]), + np.zeros( + max(0, token_setup_len - seq_len), + dtype=np.int32, + ), + ] + ) + + input_ids.append(item["tokens"][i][:-1]) # Remove last for causal + labels.append(label_item[1:]) # Shift by 1 for causal + advantages.append(item["scores"][i]) + + # Extract and pad inference logprobs to match labels shape + # IMPORTANT: inference_logprobs is ALREADY ALIGNED with tokens/masks: + # - 1.0 for prompt tokens (masked positions) + # - actual negative logprobs for generated tokens + # We just need to pad to match the sequence length + if extract_inference_logprobs and "inference_logprobs" in item: + if i < len(item["inference_logprobs"]): + raw_logprobs = np.array( + item["inference_logprobs"][i], dtype=np.float32 + ) + has_any_logprobs = True + + # Create padded logprobs array matching token_setup_len + # Fill with 1.0 (the masked token placeholder value) for padding + padded_logprobs = np.full(token_setup_len, 1.0, dtype=np.float32) + + # Copy raw_logprobs directly - they're already aligned with tokens + n_to_copy = min(len(raw_logprobs), token_setup_len) + padded_logprobs[:n_to_copy] = raw_logprobs[:n_to_copy] + + # Shift by 1 to match causal label shift + inference_logprobs_padded.append(padded_logprobs[1:]) + else: + # No logprobs for this sample, use 1.0 + inference_logprobs_padded.append( + np.full(token_setup_len - 1, 1.0, dtype=np.float32) + ) + elif extract_inference_logprobs: + # No inference_logprobs in item, use 1.0 + inference_logprobs_padded.append( + np.full(token_setup_len - 1, 1.0, dtype=np.float32) + ) + + # Extract temperature (priority: override > generation_params > group_overrides > 1.0) + t = 1.0 + if ( + item.get("overrides") + and i < len(item["overrides"]) + and isinstance(item["overrides"][i], dict) + and ("temperature" in item["overrides"][i]) + ): + t = float(item["overrides"][i]["temperature"]) + elif item.get("generation_params") and ( + "temperature" in item["generation_params"] + ): + t = 
float(item["generation_params"]["temperature"]) + elif item.get("group_overrides") and ( + "temperature" in item["group_overrides"] + ): + t = float(item["group_overrides"]["temperature"]) + temperatures.append(t) + + # Batch the data + token_batches = [] + label_batches = [] + advantage_batches = [] + temperature_batches = [] + inference_logprob_batches = [] + + for start in range(0, len(input_ids), batch_size): + end = min(start + batch_size, len(input_ids)) + + token_batches.append(torch.tensor(np.stack(input_ids[start:end], axis=0))) + label_batches.append(torch.tensor(np.stack(labels[start:end], axis=0))) + advantage_batches.append( + torch.tensor(np.stack(advantages[start:end], axis=0)).view(-1, 1) + ) + temperature_batches.append( + torch.tensor(np.array(temperatures[start:end], dtype=np.float32)).view( + -1, 1, 1 + ) + ) + + # Batch inference logprobs (same shape as labels) + if extract_inference_logprobs and inference_logprobs_padded: + inference_logprob_batches.append( + torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0)) + ) + + # Return inference logprob batches if we have any real logprobs + final_logprob_batches = ( + inference_logprob_batches + if (has_any_logprobs and inference_logprob_batches) + else None + ) + + return ( + token_batches, + label_batches, + advantage_batches, + temperature_batches, + final_logprob_batches, + ) + + +def get_data( + batch_size: int, + seq_len: int, + atropos_url: str = "http://localhost:8000", + extract_inference_logprobs: bool = True, +) -> Tuple[ + List[ + Tuple[ + List[torch.Tensor], # token_batches + List[torch.Tensor], # label_batches + List[torch.Tensor], # advantage_batches + List[torch.Tensor], # temperature_batches + Optional[List[torch.Tensor]], # inference_logprob_batches + ] + ], + None, # Legacy return (no longer used) +]: + """ + Fetch and process training data from the Atropos API. + + Continuously polls the API until data is available, then processes + all available batches. + + Args: + batch_size: Size of each training batch + seq_len: Maximum sequence length (for reference, not used directly) + atropos_url: URL of the Atropos API server + extract_inference_logprobs: Whether to extract inference logprobs for GRPO loss + + Returns: + Tuple of (batches, None) + - batches: List of processed batch tuples, each containing: + (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches) + - inference_logprob_batches are aligned with labels for proper GRPO loss computation + """ + batches = [] + _logged_logprob_warning = False + + while True: + data = get_batch(url=atropos_url) + + if data["batch"] is not None: + # DEBUG: Check if inference_logprobs exists in the data + if not _logged_logprob_warning: + has_logprobs = any( + "inference_logprobs" in item for item in data["batch"] + ) + if has_logprobs: + # Check if they're non-empty + sample_item = next( + ( + item + for item in data["batch"] + if "inference_logprobs" in item + ), + None, + ) + if sample_item and sample_item.get("inference_logprobs"): + sample_lp = ( + sample_item["inference_logprobs"][0] + if sample_item["inference_logprobs"] + else [] + ) + print( + f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})" + ) + else: + print( + " [Data] ⚠ inference_logprobs key exists but is empty!" 
+ ) + else: + print(" [Data] ⚠ NO inference_logprobs in batch data!") + print( + f" [Data] Keys in first item: {list(data['batch'][0].keys())}" + ) + _logged_logprob_warning = True + + # Process and accumulate batches (now includes batched inference logprobs) + ( + token_batches, + label_batches, + adv_batches, + temp_batches, + inf_logprob_batches, + ) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs) + + # Include inference logprob batches in the tuple + batches.append( + ( + token_batches, + label_batches, + adv_batches, + temp_batches, + inf_logprob_batches, + ) + ) + + elif len(batches) > 0: + # Return accumulated batches when no more data + return batches, None + else: + # Wait for data + time.sleep(1) diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index b8ab3534..41eb0063 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -1,590 +1,67 @@ -import atexit -import json -import math -import os -import random -import shutil -import string -import subprocess -import time -from typing import List, Optional, Tuple +#!/usr/bin/env python3 +""" +GRPO (Group Relative Policy Optimization) Trainer. -import numpy as np -import requests -import torch -import torch.nn.functional as F -import wandb # Added for logging -from pydantic import BaseModel, Field -from tenacity import retry, stop_after_attempt, wait_exponential -from torch.optim import AdamW -from transformers import AutoModelForCausalLM, AutoTokenizer +Supports four training modes: +- none (legacy): Periodic checkpoint saves + vLLM restarts +- shared_vllm: Single-copy mode with CUDA IPC weight sharing +- lora_only: LoRA adapter training with HTTP hot-swap (SLOW - needs --enforce-eager) +- lora_restart: LoRA training with vLLM restarts (FAST - CUDA graphs enabled) -# Global variable to keep track of the vLLM process -vllm_process = None +Usage: + # Legacy mode (manages vLLM internally) + python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct + + # Shared vLLM mode (requires external vLLM with VLLM_ENABLE_SHARED_WEIGHTS=1) + python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\ + --weight-bridge-mode shared_vllm + + # LoRA mode with HTTP hot-swap (SLOW - 13 TPS due to --enforce-eager) + python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\ + --weight-bridge-mode lora_only --lora-r 16 --lora-alpha 32 + + # LoRA mode with vLLM restarts (FAST - 170 TPS with CUDA graphs) + python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\ + --weight-bridge-mode lora_restart --lora-r 16 --lora-alpha 32 \\ + --vllm-restart-interval 3 +""" + +from .cli import config_from_args, parse_args +from .trainers import train_legacy, train_lora, train_lora_restart, train_shared_vllm -def cleanup_vllm(): - global vllm_process - if vllm_process: - print("\nTerminating vLLM process...") - vllm_process.terminate() - try: - vllm_process.wait(timeout=5) # Wait a bit for graceful shutdown - print("vLLM process terminated.") - except subprocess.TimeoutExpired: - print("vLLM process did not terminate gracefully, killing.") - vllm_process.kill() - vllm_process.wait() - print("vLLM process killed.") - vllm_process = None +def main(): + """Main entry point for GRPO trainer.""" + args = parse_args() + config = config_from_args(args) + print("\n" + "=" * 60) + print("GRPO TRAINER") + print("=" * 60) + print(f"Model: {config.model_name}") + print(f"Mode: {config.weight_bridge_mode}") + print(f"Training steps: {config.training_steps}") + print(f"GRPO: 
kl_coef={config.kl_coef}, clip_eps={config.clip_eps}") + print(f"{'='*60}\n") -# Register the cleanup function to be called on script exit -atexit.register(cleanup_vllm) + if config.weight_bridge_mode == "shared_vllm": + # Single-copy mode: attach to vLLM's weights, update in-place + train_shared_vllm(config) + elif config.weight_bridge_mode == "lora_only": + # LoRA mode: freeze base model, train adapters only (HTTP hot-swap) + # WARNING: This is SLOW (~13 TPS) because it requires --enforce-eager + train_lora(config) -class TrainingConfig(BaseModel): - """ - Training details, model, etc - """ + elif config.weight_bridge_mode == "lora_restart": + # LoRA mode with vLLM restarts (FAST - uses CUDA graphs) + # Restarts vLLM every vllm_restart_interval steps with new adapter + train_lora_restart(config) - model_name: str = Field(..., description="Name of the base model to train") - lr: float = Field(1e-5, description="Learning rate for the optimizer") - training_steps: int = Field( - 10, description="Number of training steps" - ) # Renamed from epochs - batch_size: int = Field( - 2, description="Batch size for training (will be handled by get_data)" - ) - seq_len: int = Field(2048, description="Sequence length for training") - gradient_accumulation_steps: int = Field( - 32, description="Number of gradient accumulation steps" - ) - device: str = Field( - "cuda" if torch.cuda.is_available() else "cpu", description="Device to train on" - ) - save_path: str = Field( - "trained_model_checkpoints", description="Base path to save model checkpoints" - ) - vllm_restart_interval: int = Field( - 3, description="Restart vLLM every N training steps" - ) - vllm_port: int = Field(9001, description="Port for the vLLM server") - - # Wandb configuration - use_wandb: bool = Field( - False, description="Whether to use Weights & Biases for logging" - ) - wandb_project: Optional[str] = Field(None, description="Wandb project name") - wandb_group: Optional[str] = Field(None, description="Wandb group name") - - -@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15)) -def register_trainer(config: TrainingConfig): - """ - Register the trainer with the Atropos API - """ - requests.post( - "http://localhost:8000/register", - json={ - "wandb_group": config.wandb_group, - "wandb_project": config.wandb_project, - "batch_size": config.batch_size * config.gradient_accumulation_steps, - "max_token_len": config.seq_len, - "starting_step": 0, - "checkpoint_dir": config.save_path, - "save_checkpoint_interval": config.training_steps, - "num_steps": config.training_steps, - }, - timeout=10, - ) - - -@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15)) -def get_batch(): - data = requests.get("http://localhost:8000/batch", timeout=10).json() - return data - - -def pad_data_to_good_offset(data, batch_size: int): - max_token_len = max( - [max([len(x) for x in item["tokens"]]) for item in data["batch"]] - ) - # usually 64 is a good choice to ensure nonweird scaling behavior on GPUS - # so we pad to the nearest multiple of 64 - good_multiple = 64 - if (max_token_len - 1) % (good_multiple) != 0: - max_token_len = math.ceil((max_token_len - 1) / (good_multiple)) * good_multiple - token_setup_len = ( - max_token_len + 1 - ) # add 1 so we can make it causal at the proper length else: - token_setup_len = max_token_len - max_token_len = ( - max_token_len - 1 - ) # since it's causal we need to remove the last bit... 
- # pad all tokens to max_token_len and add to lists - input_ids = list() - labels = list() - advantages = list() - lengths = list() - temperatures = list() - for item in data["batch"]: - scores = item["scores"] - scores = np.array(scores) - # check if we have more than 1 score... - if len(scores) > 1: - scores = scores - scores.mean() - scores = scores / max(scores.std(), 1e-8) - item["scores"] = scores - if item["overrides"] is not None: - for i in range(len(item["overrides"])): - if item["overrides"][i].get("set_advantage_to_zero", False): - item["scores"][i] = 0 - for i in range(len(item["tokens"])): - lengths.append( - math.ceil((len(item["tokens"][i]) - 1) / (good_multiple)) - * good_multiple - ) - label_item = np.concatenate( - [ - np.array(item["masks"][i]), - np.full( - max(0, token_setup_len - len(item["tokens"][i])), - -100, - dtype=np.int32, - ), - ] - ) - item["tokens"][i] = np.concatenate( - [ - np.array(item["tokens"][i]), - np.zeros( - max(0, token_setup_len - len(item["tokens"][i])), dtype=np.int32 - ), - ] - ) - input_ids.append(item["tokens"][i][:-1]) - labels.append(label_item[1:]) - advantages.append(item["scores"][i]) - # per-sample override -> group generation_params -> group_overrides - > 1.0 - # need to update docs since this lets you set the temperature for each sample from the override - t = 1.0 - if ( - item.get("overrides") - and i < len(item["overrides"]) - and isinstance(item["overrides"][i], dict) - and ("temperature" in item["overrides"][i]) - ): - t = float(item["overrides"][i]["temperature"]) - elif item.get("generation_params") and ( - "temperature" in item["generation_params"] - ): - t = float(item["generation_params"]["temperature"]) - elif item.get("group_overrides") and ( - "temperature" in item["group_overrides"] - ): - t = float(item["group_overrides"]["temperature"]) - temperatures.append(t) - # combine all lists into tensors - token_batches = [] - label_batches = [] - advantage_batches = [] - temperature_batches = [] - for i in range(len(input_ids) // batch_size): - token_batches.append( - torch.tensor( - np.stack(input_ids[i * batch_size : (i + 1) * batch_size], axis=0) - ) - ) - label_batches.append( - torch.tensor( - np.stack(labels[i * batch_size : (i + 1) * batch_size], axis=0) - ) - ) - advantage_batches.append( - torch.tensor( - np.stack(advantages[i * batch_size : (i + 1) * batch_size], axis=0) - ).view(-1, 1) - ) - # Temperatures: one per sample, shaped for broadcasting to [B, 1, 1] - temperature_batches.append( - torch.tensor( - np.array( - temperatures[i * batch_size : (i + 1) * batch_size], - dtype=np.float32, - ) - ).view(-1, 1, 1) - ) - - return token_batches, label_batches, advantage_batches, temperature_batches + # Legacy mode: periodic checkpoint saves + vLLM restarts + train_legacy(config) -def get_data( - batch_size: int, seq_len: int -) -> List[ - Tuple[ - List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor] - ] -]: - """ - getting data from the api - """ - batches = [] - while True: - data = get_batch() - if data["batch"] is not None: - # Save the batch - with open("temp.json", "w", encoding="utf-8") as f: - json.dump(data, f) - # In case the inference runs ahead of the training, we loop until we don't have any more data - batches.append(pad_data_to_good_offset(data, batch_size)) - elif len(batches) > 0: - # Return the batches - return batches - else: - time.sleep(1) - - -def train(config: TrainingConfig): - """ - Setups and runs GRPO training, restarting vLLM periodically, with wandb logging. 
- """ - global vllm_process # Declare intention to modify the global variable - - # --- Wandb Setup --- - if config.use_wandb: - if not config.wandb_project: - print("Warning: wandb_project not set, disabling wandb.") - config.use_wandb = False - else: - if not config.wandb_group: - # Set group to random 8 character string - config.wandb_group = "".join( - random.choices(string.ascii_letters + string.digits, k=8) - ) - try: - wandb.init( - project=config.wandb_project, - group=config.wandb_group, - config=config.dict(), # Log config parameters - ) - print( - f"Wandb logging enabled. Run: {wandb.run.name} (Project: {config.wandb_project}) " - ) - except Exception as e: - print(f"Error initializing wandb: {e}. Disabling wandb.") - config.use_wandb = False - # --- End Wandb Setup --- - - # Initialize model and tokenizer - tokenizer = AutoTokenizer.from_pretrained(config.model_name) - model = AutoModelForCausalLM.from_pretrained( - config.model_name, torch_dtype=torch.bfloat16 - ) - - model.to(config.device) - model.gradient_checkpointing_enable() - model.train() - - # Setup optimizer - optimizer = AdamW(model.parameters(), lr=config.lr) - - print( - f"Starting training for {config.training_steps} steps on device: {config.device}" - ) - print( - f"vLLM will be restarted every {config.vllm_restart_interval} steps on port {config.vllm_port}" - ) - - os.makedirs(config.save_path, exist_ok=True) # Ensure base save directory exists - register_trainer(config) - - # Init vllm - vllm_command = [ - "python", - "-m", - "vllm.entrypoints.openai.api_server", - "--model", - config.model_name, - "--port", - str(config.vllm_port), - "--dtype", - "auto", - "--gpu-memory-utilization", - "0.45", - "--disable-log-requests", - ] - print(f" Launching vLLM server: {' '.join(vllm_command)}") - try: - vllm_process = subprocess.Popen(vllm_command) - print(f" vLLM server launched with PID: {vllm_process.pid}") - # Check immediate errors - try: - stdout, stderr = vllm_process.communicate(timeout=2) - if vllm_process.returncode is not None and vllm_process.returncode != 0: - print(f" Error starting vLLM: {stderr.decode()}") - vllm_process = None - # Maybe raise error or just warn? - print(" WARNING: Failed to start vLLM server after checkpoint.") - except subprocess.TimeoutExpired: - print(" vLLM process started (check logs for details).") - except FileNotFoundError: - print( - "\n *** ERROR: 'python -m vllm...' command not found. Make sure vLLM is installed and accessible. 
***\n" - ) - # Potentially stop training or just disable further vLLM restarts - print(" Disabling further vLLM restarts.") - config.vllm_restart_interval = ( - config.training_steps + 1 - ) # Prevent further restarts - except Exception as e: - print(f"\n *** ERROR: Failed to launch vLLM: {e} ***\n") - print(" Disabling further vLLM restarts.") - config.vllm_restart_interval = ( - config.training_steps + 1 - ) # Prevent further restarts - - batches = list() - for step in range(config.training_steps): - total_loss = 0 - print(f"Step {step+1}/{config.training_steps}") - total_pos_logp = 0 - total_neg_logp = 0 - total_logp = 0 - total_pos = 0 - total_neg = 0 - if len(batches) == 0: - batches = get_data(config.batch_size, config.seq_len) - token_batches, label_batches, advantage_batches, temperature_batches = ( - batches.pop(0) - ) - # Terminate existing vLLM process if running - if ( - step + 1 - ) % config.vllm_restart_interval == 0 or step == config.training_steps - 1: # Also restart/save on last step - # Terminate existing vLLM process if running - if vllm_process: - print(" Terminating existing vLLM process...") - vllm_process.terminate() - try: - vllm_process.wait(timeout=5) - except subprocess.TimeoutExpired: - print( - " Existing vLLM process did not terminate gracefully, killing." - ) - vllm_process.kill() - vllm_process.wait() - vllm_process = None - for tokens, labels, advantages, temperatures in zip( - token_batches, label_batches, advantage_batches, temperature_batches - ): - - tokens, labels, advantages = ( - tokens.to(config.device), - labels.to(config.device), - advantages.to(config.device), - ) - - # Forward pass - # User specified that tokens/labels are already prepared by get_data - outputs = model(tokens) # Assuming model just needs tokens - logits = outputs.logits # Assuming this is the structure - # temp scaled logits before cross entropy (clamp to prevent zero division or just ignore 0 temps?) 
- t = temperatures.to(logits.device, logits.dtype) - t = torch.where(t <= 0, torch.ones_like(t), t) - logits = logits / t - - # Calculate GRPO loss (reverting to user's previous logic) - # User stated ignore_index is -100 and tokens/labels are aligned by get_data - # Assuming logits correspond directly to labels indices (no shift needed here) - logp_per_token = -F.cross_entropy( - logits.view(-1, logits.size(-1)), # Flatten logits - labels.view(-1), # Flatten labels - reduction="none", - ignore_index=-100, # User specified ignore index - ).view( - labels.shape - ) # Reshape back to (batch, seq_len) - - # Masking based on labels != -100 - mask = (labels != -100).float() - with torch.no_grad(): - pos = (advantages > 0).float() - neg = (advantages <= 0).float() - mask = mask.to(logp_per_token.dtype) - mask_sum = mask.sum(dim=-1).clamp_min(1e-8) - avg_logp = (logp_per_token * mask).sum(dim=-1) / mask_sum - pos_logp = (logp_per_token * pos).mean().item() - neg_logp = (logp_per_token * neg).mean().item() - total_pos_logp += pos_logp - total_neg_logp += neg_logp - total_logp += avg_logp - total_pos += pos.sum().item() - total_neg += neg.sum().item() - - grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach()) - grpo_loss = ( - ((-grpo_loss_term * mask).sum(-1) / mask.sum(-1)) - * advantages.to(logp_per_token.device) - ).mean() / config.gradient_accumulation_steps - grpo_loss.backward() - total_loss += grpo_loss.item() - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) - optimizer.step() - optimizer.zero_grad() - if total_pos > 0: - total_pos_logp /= total_pos - if total_neg > 0: - total_neg_logp /= total_neg - # --- Wandb Logging --- - if config.use_wandb: - wandb.log( - { - "train/loss": total_loss, - "train/learning_rate": optimizer.param_groups[0]["lr"], - "train/grad_norm": grad_norm.item(), - "train/pos_logp": total_pos_logp, - "train/neg_logp": total_neg_logp, - "train/logp": total_logp, - }, - step=step + 1, - ) - # --- End Wandb Logging --- - - print(f" Step Loss: {grpo_loss.item():.4f}") - - # --- vLLM Restart Logic (Moved AFTER optimizer step) --- - # Note: There are much better ways of updating the policy, this is just a very simple example - if ( - step + 1 - ) % config.vllm_restart_interval == 0 or step == config.training_steps - 1: # Also restart/save on last step - checkpoint_path = os.path.join( - config.save_path, f"step_{step+1}" - ) # Save as step+1 since it's after step completion - print(f" Saving checkpoint to {checkpoint_path}...") - # Ensure fresh directory for saving - if os.path.exists(checkpoint_path): - shutil.rmtree(checkpoint_path) # Remove old checkpoint if it exists - os.makedirs(checkpoint_path, exist_ok=True) - model.save_pretrained(checkpoint_path) - tokenizer.save_pretrained(checkpoint_path) - print(" Checkpoint saved.") - - # Terminate existing vLLM process if running - if vllm_process: - print(" Terminating existing vLLM process...") - vllm_process.terminate() - try: - vllm_process.wait(timeout=5) - except subprocess.TimeoutExpired: - print( - " Existing vLLM process did not terminate gracefully, killing." - ) - vllm_process.kill() - vllm_process.wait() - vllm_process = None - - # Launch new vLLM process (only if not the very last step, maybe? depends on use case) - # Let's still launch it on the last step for consistency, cleanup will handle it. 
- vllm_command = [ - "python", - "-m", - "vllm.entrypoints.openai.api_server", - "--model", - os.path.join(config.save_path, f"step_{step+1}"), - "--port", - str(config.vllm_port), - "--dtype", - "auto", - "--gpu-memory-utilization", - "0.45", - "--disable-log-requests", - "--served-model-name", - config.model_name, - ] - print(f" Launching vLLM server: {' '.join(vllm_command)}") - torch.cuda.empty_cache() - try: - vllm_process = subprocess.Popen(vllm_command) - print(f" vLLM server launched with PID: {vllm_process.pid}") - # Check immediate errors - try: - stdout, stderr = vllm_process.communicate(timeout=2) - if ( - vllm_process.returncode is not None - and vllm_process.returncode != 0 - ): - print(f" Error starting vLLM: {stderr.decode()}") - vllm_process = None - # Maybe raise error or just warn? - print( - " WARNING: Failed to start vLLM server after checkpoint." - ) - except subprocess.TimeoutExpired: - print(" vLLM process started (check logs for details).") - except FileNotFoundError: - print( - "\n *** ERROR: 'python -m vllm...' command not found. ", - "Make sure vLLM is installed and accessible. ***\n", - ) - # Potentially stop training or just disable further vLLM restarts - print(" Disabling further vLLM restarts.") - config.vllm_restart_interval = ( - config.training_steps + 1 - ) # Prevent further restarts - except Exception as e: - print(f"\n *** ERROR: Failed to launch vLLM: {e} ***\n") - print(" Disabling further vLLM restarts.") - config.vllm_restart_interval = ( - config.training_steps + 1 - ) # Prevent further restarts - # --- End vLLM Restart Logic --- - - # Basic check if vLLM process terminated unexpectedly (outside interval check) - if vllm_process and vllm_process.poll() is not None: - print( - f"\n *** WARNING: vLLM process terminated unexpectedly (return code: {vllm_process.returncode}). ", - "Check vLLM logs. ***\n", - ) - stderr_output = ( - vllm_process.stderr.read().decode() - if vllm_process.stderr - else "No stderr" - ) - print(f"vLLM stderr: {stderr_output}") - vllm_process = None # Reset so it relaunches next interval - - print("Training finished.") - # --- Wandb Finish --- - if config.use_wandb: - wandb.finish() - # --- End Wandb Finish --- - # Final cleanup (vLLM termination) is handled by atexit - - # --- Placeholder for final model save --- - final_save_path = os.path.join(config.save_path, "final_model") - print(f"Saving final model to {final_save_path}") - if os.path.exists(final_save_path): - shutil.rmtree(final_save_path) - os.makedirs(final_save_path, exist_ok=True) - model.save_pretrained(final_save_path) - tokenizer.save_pretrained(final_save_path) - print("Final model saved.") - - -# Example usage (optional, can be run from another script) if __name__ == "__main__": - # Example: Create a config and run training - # Replace "gpt2" with your desired model - training_config = TrainingConfig( - model_name="Qwen/Qwen2.5-1.5B-Instruct", - training_steps=20, # Use steps - vllm_restart_interval=3, # Example interval - use_wandb=True, # Set to True to enable logging - wandb_project="grpo-trainer-example", # Replace with your project name - ) - - # --- End Mock --- - - train(training_config) + main() diff --git a/example_trainer/model.py b/example_trainer/model.py new file mode 100644 index 00000000..315adc20 --- /dev/null +++ b/example_trainer/model.py @@ -0,0 +1,792 @@ +""" +Model loading and shared memory utilities for GRPO trainer. 
+ +Handles: +- Standard model loading (legacy mode) +- LoRA model loading and wrapping +- Single-copy mode: Attaching to vLLM's shared tensors via CUDA IPC +""" + +import base64 +import json +import os +from typing import Dict, Optional, Tuple + +import torch +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +from .config import TrainingConfig + +# Import PEFT for LoRA training +try: + from peft import LoraConfig, TaskType, get_peft_model + + PEFT_AVAILABLE = True +except ImportError: + PEFT_AVAILABLE = False + + +def _get_attention_implementation() -> str: + """ + Determine the best attention implementation to use. + Priority: + 1. Flash Attention 2 (if flash_attn library is available and works) + 2. SDPA (PyTorch's scaled dot-product attention) + + Returns: + Tuple of (attn_implementation string, human-readable name) + """ + try: + import flash_attn # noqa: F401 + + return "flash_attention_2" + except ImportError: + return "sdpa" + + +def _load_model_with_attention( + model_name_or_config, + torch_dtype=torch.bfloat16, + from_config: bool = False, +) -> torch.nn.Module: + """ + Load a model with the best available attention implementation. + + Args: + model_name_or_config: Either a model name (str) or a model config object + torch_dtype: Data type for model weights + from_config: If True, use from_config (for meta device loading - no weights) + If False, use from_pretrained (downloads and loads weights) + + Returns: + Loaded model with appropriate attention implementation + """ + # Select the loader function based on mode + # from_config: creates empty shell (meta device), from_pretrained: loads weights + loader = ( + AutoModelForCausalLM.from_config + if from_config + else AutoModelForCausalLM.from_pretrained + ) + + # Try attention implementations in order of preference + for attn_impl in ["flash_attention_2", "sdpa"]: + # Skip flash_attention_2 if not available + if ( + attn_impl == "flash_attention_2" + and _get_attention_implementation() != "flash_attention_2" + ): + continue + + try: + model = loader( + model_name_or_config, + torch_dtype=torch_dtype, + attn_implementation=attn_impl, + ) + print(f"[Setup] Using {attn_impl.replace('_', ' ').title()}") + return model + except Exception as e: + if attn_impl == "flash_attention_2": + print(f"[Setup] Flash Attention 2 failed ({e}), trying SDPA...") + continue + raise + + # Should never reach here, but just in case + raise RuntimeError("Failed to load model with any attention implementation") + + +def load_model_and_tokenizer( + config: TrainingConfig, + single_copy: bool = False, +) -> Tuple[torch.nn.Module, AutoTokenizer]: + """ + Load or attach to model based on weight_bridge_mode. + + Args: + config: Training configuration + single_copy: If True, attach to vLLM's shared tensors via CUDA IPC + + Returns: + Tuple of (model, tokenizer) + """ + tokenizer = AutoTokenizer.from_pretrained(config.model_name) + + # Single-copy mode: attach to vLLM's shared tensors via CUDA IPC + if single_copy or config.weight_bridge_mode == "shared_vllm": + config_path = _find_vllm_config(config) + model = _attach_to_vllm_shared_tensors(config, config_path) + + if model is not None: + print("[Setup] ✓ Single-copy mode active - using vLLM's tensors directly!") + # Enable gradient checkpointing to save memory (was missing before!) + _setup_gradient_checkpointing(model, config) + model.train() + return model, tokenizer + else: + raise RuntimeError( + "[Setup] Single-copy mode FAILED to attach to vLLM's tensors.\n" + "Check:\n" + " 1. 
vLLM running with VLLM_ENABLE_SHARED_WEIGHTS=1\n"
+                " 2. vllm_bridge_config.json exists with ipc_handles\n"
+                " 3. Trainer is on SAME GPUs as vLLM"
+            )
+
+    elif config.weight_bridge_mode in ("lora_only", "lora_restart"):
+        # Both lora_only and lora_restart use PEFT LoRA adapters
+        model = _load_model_with_lora(config)
+
+    else:
+        # Legacy mode: load full model
+        print("[Setup] Loading model for legacy mode...")
+        model = _load_model_with_attention(config.model_name)
+        model.to(config.device)
+
+    # Enable gradient checkpointing
+    _setup_gradient_checkpointing(model, config)
+    model.train()
+
+    return model, tokenizer
+
+
+def _find_vllm_config(config: TrainingConfig) -> str:
+    """Find the vllm_bridge_config.json file."""
+    # Check explicit path first
+    if config.vllm_config_path and os.path.exists(config.vllm_config_path):
+        print(f"[Setup] Using explicit vLLM config path: {config.vllm_config_path}")
+        return config.vllm_config_path
+
+    # Auto-detect from common locations
+    possible_paths = [
+        os.environ.get("LOGDIR", "."),
+        ".",
+        "/tmp/atropos_bridge",
+        os.path.dirname(os.path.abspath(__file__)),
+    ]
+    # Look through the possible locations for the bridge config
+    for log_dir in possible_paths:
+        candidate = os.path.join(log_dir, "vllm_bridge_config.json")
+        if os.path.exists(candidate):
+            print(f"[Setup] Found vLLM config at: {candidate}")
+            return candidate
+
+    checked = [os.path.join(p, "vllm_bridge_config.json") for p in possible_paths]
+    raise RuntimeError(
+        f"[Setup] Could not find vllm_bridge_config.json\n"
+        f"Checked: {checked}\n"
+        f"Tip: Use --vllm-config-path to specify the path explicitly\n"
+        f"Make sure vLLM is running with VLLM_ENABLE_SHARED_WEIGHTS=1 and LOGDIR set"
+    )
+
+
+def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
+    """
+    Load base model and wrap with LoRA adapters.
+
+    Args:
+        config: Training configuration with LoRA settings
+
+    Returns:
+        PEFT model with LoRA adapters applied
+    """
+    if not PEFT_AVAILABLE:  # LoRA training requires the PEFT library in every mode
+        raise RuntimeError("PEFT library not available. Install with: pip install peft")
+
+    print("[Setup] Loading base model for LoRA mode...")
+    base_model = _load_model_with_attention(config.model_name)
+    base_model.to(config.device)
+
+    # Determine target modules
+    target_modules = config.lora_target_modules
+    if target_modules is None:
+        target_modules = ["q_proj", "v_proj"]
+    layer_indices = config.lora_layer_indices
+
+    if layer_indices is not None:
+        num_hidden_layers = getattr(base_model.config, "num_hidden_layers", None)
+        if num_hidden_layers is None:
+            raise RuntimeError(
+                "Model config does not expose num_hidden_layers; cannot validate "
+                "--lora-layer-indices for this architecture."
+ ) + invalid = [idx for idx in layer_indices if idx >= num_hidden_layers] + if invalid: + raise ValueError( + f"Invalid --lora-layer-indices {invalid} for model with " + f"{num_hidden_layers} layers (valid range: 0-{num_hidden_layers - 1})" + ) + + print(f"Applying LoRA: r={config.lora_r}, alpha={config.lora_alpha}") + print(f"Target modules: {target_modules}") + if layer_indices is not None: + print( + f"Applying LoRA only to layers: {layer_indices} " + f"(total {len(layer_indices)})" + ) + + lora_kwargs = dict( + task_type=TaskType.CAUSAL_LM, + r=config.lora_r, + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + target_modules=target_modules, + bias="none", + ) + if layer_indices is not None: + lora_kwargs["layers_to_transform"] = layer_indices + lora_config = LoraConfig(**lora_kwargs) + + model = get_peft_model(base_model, lora_config) + model.print_trainable_parameters() + + return model + + +def _setup_gradient_checkpointing( + model: torch.nn.Module, config: TrainingConfig +) -> None: + """Configure gradient checkpointing for the model.""" + # Disable KV cache - incompatible with gradient checkpointing + model.config.use_cache = False + + if config.weight_bridge_mode in ("lora_only", "lora_restart"): + # PEFT models need special handling - enable_input_require_grads is CRITICAL + # Without this, the LoRA parameters won't receive gradients! + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) + else: + model.gradient_checkpointing_enable() + + +def _attach_to_vllm_shared_tensors( + config: TrainingConfig, + bridge_config_path: str, +) -> Optional[torch.nn.Module]: + """ + Attach to vLLM's shared tensors via CUDA IPC (true single-copy mode). + + This creates a model whose parameters point to the SAME GPU memory as vLLM, + meaning only ONE copy of the model exists in GPU memory. 
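+    Because the attached parameters alias vLLM's storage, optimizer updates are
+    immediately visible to the inference server without an explicit weight sync;
+    this requires the trainer to run on the same GPU(s) as vLLM, since the CUDA IPC
+    handles are device-local.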
+
+    Args:
+        config: Training configuration
+        bridge_config_path: Path to vllm_bridge_config.json
+
+    Returns:
+        Model with parameters pointing to vLLM's tensors, or None if not possible
+    """
+    print(f"[Setup] Reading bridge config from: {bridge_config_path}")
+    # Load the bridge config we just located
+    try:
+        with open(bridge_config_path, "r") as f:
+            bridge_config = json.load(f)
+        print(f"[Setup] Bridge config keys: {list(bridge_config.keys())}")
+    except Exception as e:
+        print(f"[Setup] Could not read bridge config: {e}")
+        return None
+
+    single_copy_enabled = bridge_config.get("single_copy_enabled", False)
+    print(f"[Setup] single_copy_enabled in config: {single_copy_enabled}")
+    # If single-copy is not enabled, bail out: vLLM was likely started without shared weights
+    if not single_copy_enabled:
+        print("[Setup] Single-copy mode not available (single_copy_enabled=False)")
+        print("[Setup] Make sure vLLM was started with VLLM_ENABLE_SHARED_WEIGHTS=1")
+        return None
+    # Get IPC handles from bridge config - memory pointers to shared weight tensors
+    ipc_handles_raw = bridge_config.get("ipc_handles", {})
+    print(f"[Setup] IPC handles count: {len(ipc_handles_raw)}")
+    if not ipc_handles_raw:
+        print("[Setup] No IPC handles found in bridge config")
+        return None
+
+    # Deserialize base64-encoded bytes
+    ipc_handles = _deserialize_ipc_handles(ipc_handles_raw)
+
+    print(f"[Setup] Attaching to vLLM's shared tensors ({len(ipc_handles)} tensors)...")
+    print("[Setup] TRUE SINGLE-COPY MODE - No additional model memory!")
+
+    # Load model config (not weights) to get the architecture: it holds no
+    # tensors, just the schematics - the blueprint for the house, not the house itself
+    model_config = AutoConfig.from_pretrained(config.model_name)
+
+    # Create empty model on meta device (no memory allocation)
+    # Try Flash Attention 2 first (matches vLLM better), fall back to SDPA
+    with torch.device("meta"):
+        model = _load_model_with_attention(model_config, from_config=True)
+
+    param_names = list(model.state_dict().keys())
+    print(f"[Setup] Model architecture has {len(param_names)} parameters", flush=True)
+
+    # Initialize CUDA on devices used by vLLM
+    device_indices = _initialize_cuda_devices(ipc_handles)
+
+    # Create mapping from HF names to vLLM tensors
+    vllm_to_hf_mapping = _create_vllm_to_hf_mapping(
+        model, ipc_handles, debug=config.debug_loading
+    )
+
+    # Reconstruct tensors and build state dict
+    hf_state_dict, attached_count, fused_count = _reconstruct_shared_tensors(
+        ipc_handles, vllm_to_hf_mapping, config
+    )
+
+    print(
+        f"[Setup] Attached {attached_count} tensors ({fused_count} from fused layers)"
+    )
+
+    if attached_count == 0:
+        print("[Setup] Could not attach any tensors, falling back to regular loading")
+        return None
+
+    # Validate mapping coverage
+    _validate_mapping_coverage(model, hf_state_dict, attached_count)
+
+    # Load state dict into model
+    model.load_state_dict(hf_state_dict, strict=False, assign=True)
+
+    # Initialize remaining meta tensors
+    device = f"cuda:{list(device_indices)[0]}" if device_indices else "cuda:0"
+    _initialize_meta_tensors(model, device, config)
+
+    # Final validation - ensure nothing is left on meta device
+    _validate_no_meta_tensors(model)
+
+    print("[Setup] ✓ All tensors successfully initialized on CUDA")
+    return model
+
+
+def _deserialize_ipc_handles(handles_raw: dict) -> dict:
+    """Deserialize base64-encoded bytes in IPC handles."""
+
+    def deserialize(handles):
+        result = {}
+        for k, v in handles.items():
+            if 
isinstance(v, dict): + if "_bytes_b64_" in v: + result[k] = base64.b64decode(v["_bytes_b64_"]) + else: + result[k] = deserialize(v) + else: + result[k] = v + return result + + return deserialize(handles_raw) + + +def _initialize_cuda_devices(ipc_handles: dict) -> set: + """Initialize CUDA context on devices used by IPC handles.""" + device_indices = set() + for name, info in ipc_handles.items(): + if "device_index" in info: + device_indices.add(info["device_index"]) + + print(f"[Setup] IPC handles span devices: {sorted(device_indices)}", flush=True) + + for dev_idx in sorted(device_indices): + print(f"[Setup] Initializing CUDA on device {dev_idx}...", flush=True) + torch.cuda.set_device(dev_idx) + torch.cuda.synchronize(dev_idx) + print(f"[Setup] ✓ Device {dev_idx} ready", flush=True) + + return device_indices + + +def _reconstruct_shared_tensors( + ipc_handles: dict, + vllm_to_hf_mapping: dict, + config: TrainingConfig, +) -> Tuple[dict, int, int]: + """Reconstruct tensors from IPC handles and build state dict.""" + hf_state_dict = {} + vllm_tensor_cache: Dict[str, torch.Tensor] = {} + attached_count = 0 + fused_count = 0 + + def reconstruct_vllm_tensor(vllm_name: str) -> Optional[torch.Tensor]: + if vllm_name in vllm_tensor_cache: + return vllm_tensor_cache[vllm_name] + + if vllm_name not in ipc_handles: + return None + + ipc_info = ipc_handles[vllm_name] + if "ipc_handle_b64" not in ipc_info: + return None + + try: + device_index = ipc_info["device_index"] + ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"]) + storage_size = ipc_info["storage_size"] + storage_offset_orig = ipc_info["storage_offset_orig"] + ref_counter_handle = base64.b64decode(ipc_info["ref_counter_handle_b64"]) + ref_counter_offset = ipc_info["ref_counter_offset"] + event_handle = base64.b64decode(ipc_info["event_handle_b64"]) + event_sync_required = ipc_info["event_sync_required"] + + share_tuple = ( + device_index, + ipc_handle, + storage_size, + storage_offset_orig, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ) + + storage = torch.UntypedStorage._new_shared_cuda(*share_tuple) + dtype = getattr(torch, ipc_info["dtype"].replace("torch.", "")) + tensor = torch.tensor([], dtype=dtype, device=f"cuda:{device_index}") + tensor.set_( + storage, + storage_offset=ipc_info["tensor_storage_offset"], + size=ipc_info["shape"], + stride=ipc_info["stride"], + ) + + vllm_tensor_cache[vllm_name] = tensor + return tensor + + except Exception as e: + print(f"[Setup] Failed to reconstruct {vllm_name}: {e}", flush=True) + return None + + for hf_name, mapping_info in vllm_to_hf_mapping.items(): + try: + if isinstance(mapping_info, dict): + # Fused mapping - slice the source tensor + vllm_name = mapping_info["source"] + slice_start, slice_end = mapping_info["slice"] + slice_dim = mapping_info["dim"] + + full_tensor = reconstruct_vllm_tensor(vllm_name) + if full_tensor is None: + continue + + # Create VIEW (not copy) into the fused tensor + if slice_dim == 0: + tensor = full_tensor[slice_start:slice_end] + else: + tensor = full_tensor.narrow( + slice_dim, slice_start, slice_end - slice_start + ) + + tensor.requires_grad_(True) + hf_state_dict[hf_name] = tensor + fused_count += 1 + attached_count += 1 + + else: + # Direct mapping + vllm_name = mapping_info + tensor = reconstruct_vllm_tensor(vllm_name) + if tensor is None: + continue + + tensor.requires_grad_(True) + hf_state_dict[hf_name] = tensor + attached_count += 1 + + except Exception as e: + print(f"[Setup] Failed to attach {hf_name}: 
{e}", flush=True) + + return hf_state_dict, attached_count, fused_count + + +def _validate_mapping_coverage( + model: torch.nn.Module, + hf_state_dict: dict, + attached_count: int, +) -> None: + """Validate that enough parameters were mapped.""" + hf_param_count = len(list(model.named_parameters())) + mapping_coverage = attached_count / hf_param_count if hf_param_count > 0 else 0 + + # Note: attached_count may be > param_count because state_dict includes buffers + # while named_parameters only counts trainable params + print( + f"[Setup] Mapping coverage: {attached_count} tensors for {hf_param_count} parameters " + f"(>100% is OK - includes buffers)" + ) + + if mapping_coverage < 0.90: + unmapped_params = set(model.state_dict().keys()) - set(hf_state_dict.keys()) + warning_msg = ( + f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n" + ) + warning_msg += f"Unmapped parameters ({len(unmapped_params)}):\n" + for name in list(unmapped_params)[:20]: + warning_msg += f" - {name}\n" + print(warning_msg) + + if mapping_coverage < 0.50: + raise RuntimeError( + f"[Setup] CRITICAL: Only {mapping_coverage:.1%} of parameters mapped!" + ) + else: + print(f"[Setup] ✓ Good mapping coverage ({mapping_coverage:.1%})") + + +def _initialize_meta_tensors( + model: torch.nn.Module, + device: str, + config: TrainingConfig, +) -> None: + """Initialize any remaining meta tensors after loading.""" + meta_params = [ + name for name, p in model.named_parameters() if p.device.type == "meta" + ] + meta_buffers = [ + name for name, b in model.named_buffers() if b.device.type == "meta" + ] + + if config.debug_loading: + print( + f"\n[DIAGNOSTIC] Meta params: {len(meta_params)}, Meta buffers: {len(meta_buffers)}" + ) + + def get_parent_and_name(model, full_name): + parts = full_name.split(".") + parent = model + for part in parts[:-1]: + parent = getattr(parent, part) + return parent, parts[-1] + + meta_count = 0 + + # Initialize meta parameters + for name in meta_params: + param = dict(model.named_parameters()).get(name) + if param is None: + continue + + try: + new_data = torch.zeros(param.shape, dtype=param.dtype, device=device) + new_param = torch.nn.Parameter(new_data, requires_grad=param.requires_grad) + parent, attr_name = get_parent_and_name(model, name) + setattr(parent, attr_name, new_param) + meta_count += 1 + except Exception as e: + if config.debug_loading: + print(f"[DIAGNOSTIC] FAILED to initialize {name}: {e}") + + # Initialize meta buffers + for name in meta_buffers: + buffer = dict(model.named_buffers()).get(name) + if buffer is None: + continue + + try: + if "inv_freq" in name: + dim = buffer.shape[0] * 2 + # Get rope_theta from model config (default 10000.0 for LLaMA, but Qwen3 uses 5000000!) 
+ rope_theta = getattr(model.config, "rope_theta", 10000.0) + inv_freq = 1.0 / ( + rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim) + ) + new_buffer = inv_freq.to(dtype=buffer.dtype, device=device) + print(f"[Setup] Initialized {name} with rope_theta={rope_theta}") + else: + new_buffer = torch.zeros( + buffer.shape, dtype=buffer.dtype, device=device + ) + + parent, attr_name = get_parent_and_name(model, name) + parent.register_buffer(attr_name, new_buffer) + meta_count += 1 + except Exception as e: + if config.debug_loading: + print(f"[DIAGNOSTIC] FAILED to initialize buffer {name}: {e}") + + print(f"\n[Setup] Initialized {meta_count} remaining meta tensors") + + +def _validate_no_meta_tensors(model: torch.nn.Module) -> None: + """Ensure no parameters or buffers are still on meta device.""" + final_meta_params = [ + name for name, p in model.named_parameters() if p.device.type == "meta" + ] + final_meta_buffers = [ + name for name, b in model.named_buffers() if b.device.type == "meta" + ] + + if final_meta_params or final_meta_buffers: + error_msg = "[Setup] CRITICAL ERROR: Some tensors are still on meta device!\n" + error_msg += "The model would produce GARBAGE output.\n\n" + + if final_meta_params: + error_msg += f"Meta parameters ({len(final_meta_params)}):\n" + for name in final_meta_params[:20]: + error_msg += f" - {name}\n" + + if final_meta_buffers: + error_msg += f"\nMeta buffers ({len(final_meta_buffers)}):\n" + for name in final_meta_buffers[:20]: + error_msg += f" - {name}\n" + + raise RuntimeError(error_msg) + + +def _create_vllm_to_hf_mapping( + model: torch.nn.Module, + ipc_handles: dict, + debug: bool = False, +) -> dict: + """ + Create mapping from HuggingFace parameter names to vLLM tensor names. + + Handles fused layers: + - qkv_proj (vLLM) = q_proj + k_proj + v_proj (HF) + - gate_up_proj (vLLM) = gate_proj + up_proj (HF) + + Uses actual tensor shapes from HF model to determine slice sizes, + rather than calculating from config (which can be wrong for some models). 
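+
+    For a fused layer the mapping value is a slice descriptor rather than a plain name,
+    e.g. (tensor name and bounds illustrative only):
+        {"source": "model.layers.0.self_attn.qkv_proj.weight",
+         "slice": (0, q_size), "dim": 0, "type": "qkv_fusion"}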
+ """ + hf_state_dict = model.state_dict() + print("Here is the HF state dict so that we can get a better view ") + hf_params = set(hf_state_dict.keys()) + vllm_params = set(ipc_handles.keys()) + + # Get model config for fallback dimension calculations + model_config = model.config + hidden_size = getattr(model_config, "hidden_size", 4096) + num_attention_heads = getattr(model_config, "num_attention_heads", 32) + num_key_value_heads = getattr( + model_config, "num_key_value_heads", num_attention_heads + ) + intermediate_size = getattr(model_config, "intermediate_size", hidden_size * 4) + + # Try to get head_dim from config (some models like Qwen3 have this) + head_dim = getattr(model_config, "head_dim", None) + if head_dim is None: + head_dim = hidden_size // num_attention_heads + + # Determine QKV sizes from ACTUAL HF model tensor shapes + # Look for a q_proj weight in the model to get the actual size + q_size = None + k_size = None + v_size = None + + for name, param in hf_state_dict.items(): + if "q_proj.weight" in name and q_size is None: + q_size = param.shape[0] # Output dimension + elif "k_proj.weight" in name and k_size is None: + k_size = param.shape[0] + elif "v_proj.weight" in name and v_size is None: + v_size = param.shape[0] + if q_size and k_size and v_size: + break + + # Fallback to calculated values if not found + if q_size is None: + q_size = num_attention_heads * head_dim + if k_size is None: + k_size = num_key_value_heads * head_dim + if v_size is None: + v_size = num_key_value_heads * head_dim + + # Also get gate/up sizes from actual HF model + gate_size = None + up_size = None + + for name, param in hf_state_dict.items(): + if "gate_proj.weight" in name and gate_size is None: + gate_size = param.shape[0] + elif "up_proj.weight" in name and up_size is None: + up_size = param.shape[0] + if gate_size and up_size: + break + + # Fallback + if gate_size is None: + gate_size = intermediate_size + if up_size is None: + up_size = intermediate_size + + # Always print sizes for debugging weight sharing issues + print( + f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, " + f"kv_heads={num_key_value_heads}, head_dim={head_dim}" + ) + print(f"[Mapping] QKV sizes from HF model: q={q_size}, k={k_size}, v={v_size}") + print(f"[Mapping] Gate/Up sizes from HF model: gate={gate_size}, up={up_size}") + + mapping = {} + + def find_vllm_name(hf_name: str) -> Optional[str]: + if hf_name in vllm_params: + return hf_name + if not hf_name.startswith("model."): + candidate = f"model.{hf_name}" + if candidate in vllm_params: + return candidate + if hf_name.startswith("model."): + candidate = hf_name[6:] + if candidate in vllm_params: + return candidate + return None + + def find_fused_source(hf_name: str, fused_suffix: str) -> Optional[str]: + for unfused in ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"]: + if unfused in hf_name: + fused_name = hf_name.replace(unfused, fused_suffix) + found = find_vllm_name(fused_name) + if found: + return found + return None + + for hf_name in hf_params: + # Try direct match first + vllm_name = find_vllm_name(hf_name) + if vllm_name: + mapping[hf_name] = vllm_name + continue + + # Check for QKV fusion + if any(x in hf_name for x in ["q_proj", "k_proj", "v_proj"]): + fused_name = find_fused_source(hf_name, "qkv_proj") + if fused_name: + if "q_proj" in hf_name: + start, end = 0, q_size + elif "k_proj" in hf_name: + start, end = q_size, q_size + k_size + else: + start, end = q_size + k_size, q_size + k_size + v_size + + 
mapping[hf_name] = { + "source": fused_name, + "slice": (start, end), + "dim": 0, + "type": "qkv_fusion", + } + continue + + # Check for Gate/Up fusion + if any(x in hf_name for x in ["gate_proj", "up_proj"]): + fused_name = find_fused_source(hf_name, "gate_up_proj") + if fused_name: + if "gate_proj" in hf_name: + start, end = 0, gate_size + else: + start, end = gate_size, gate_size + up_size + + mapping[hf_name] = { + "source": fused_name, + "slice": (start, end), + "dim": 0, + "type": "gate_up_fusion", + } + continue + + if debug: + direct = sum(1 for v in mapping.values() if isinstance(v, str)) + fused = sum(1 for v in mapping.values() if isinstance(v, dict)) + print( + f"[Mapping] Total: {len(mapping)} mapped ({direct} direct, {fused} fused)" + ) + + return mapping diff --git a/example_trainer/requirements.txt b/example_trainer/requirements.txt index 403d558c..0313a941 100644 --- a/example_trainer/requirements.txt +++ b/example_trainer/requirements.txt @@ -3,3 +3,6 @@ torch transformers datasets accelerate +peft +requests +wandb diff --git a/example_trainer/run.py b/example_trainer/run.py new file mode 100644 index 00000000..ff60d0ba --- /dev/null +++ b/example_trainer/run.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python3 +""" +Unified GRPO trainer with integrated vLLM server (shared_vllm mode). + +Combines vLLM server startup and trainer into a single command: + python example_trainer/run.py --model Qwen/Qwen3-4B-Instruct --training-steps 20 + +This script: +1. Starts vLLM server with shared weights enabled +2. Waits for vLLM to be ready and bridge config to be created +3. Starts the GRPO trainer in shared_vllm mode +4. Handles cleanup on exit + +For other modes (legacy, LoRA), use grpo.py instead. +""" + +import atexit +import os +import signal +import subprocess +import sys +import time +from pathlib import Path + +import requests + +from .cli import create_unified_parser +from .config import TrainingConfig +from .trainers import train_shared_vllm + + +def wait_for_vllm(port: int, timeout: int = 300) -> bool: + """Wait for vLLM server to be ready.""" + print(f"[Run] Waiting for vLLM server on port {port}...") + start = time.time() + + while time.time() - start < timeout: + try: + response = requests.get(f"http://localhost:{port}/health", timeout=5) + if response.status_code == 200: + print(f"[Run] ✓ vLLM server is ready (took {time.time() - start:.1f}s)") + return True + except requests.exceptions.ConnectionError: + pass + except Exception as e: + print(f"[Run] Health check error: {e}") + + time.sleep(2) + + print(f"[Run] ✗ vLLM server failed to start within {timeout}s") + return False + + +def wait_for_bridge_config(config_path: str, timeout: int = 60) -> bool: + """Wait for vLLM bridge config to be created.""" + print(f"[Run] Waiting for bridge config at {config_path}...") + start = time.time() + + while time.time() - start < timeout: + if os.path.exists(config_path): + try: + import json + + with open(config_path, "r") as f: + config = json.load(f) + if config.get("ipc_handles") and len(config["ipc_handles"]) > 0: + print( + f"[Run] ✓ Bridge config ready with {len(config['ipc_handles'])} IPC handles" + ) + return True + except Exception: + pass + time.sleep(1) + + print(f"[Run] ✗ Bridge config not created within {timeout}s") + return False + + +def main(): + # Parse args using shared CLI module + parser = create_unified_parser() + args = parser.parse_args() + + # Create log directory + log_dir = getattr(args, "log_dir", "./logs") + os.makedirs(log_dir, exist_ok=True) + + # Bridge 
config path + bridge_config_path = "./vllm_bridge_config.json" + + # Clean up old bridge config + if os.path.exists(bridge_config_path): + os.remove(bridge_config_path) + print("[Run] Removed old bridge config") + + # === Print Configuration === + print("\n" + "=" * 60) + print("STARTING UNIFIED GRPO TRAINER (shared_vllm mode)") + print("=" * 60) + print(f"Model: {args.model_name}") + print(f"vLLM port: {args.vllm_port}") + print(f"GPU memory utilization: {args.gpu_memory_utilization}") + print(f"Training steps: {args.training_steps}") + print(f"Optimizer: {args.optimizer}") + print(f"GRPO: kl_coef={args.kl_coef}, clip_eps={args.clip_eps}") + print("=" * 60 + "\n") + + # Get the path to vllm_api_server.py + script_dir = Path(__file__).parent + vllm_server_script = script_dir / "vllm_api_server.py" + + if not vllm_server_script.exists(): + print(f"[Run] ✗ vLLM server script not found at {vllm_server_script}") + sys.exit(1) + + # Extract device index from args.device + device_index = "0" + if ":" in args.device: + device_index = args.device.split(":")[1] + + # Build vLLM environment + vllm_env = os.environ.copy() + vllm_env["VLLM_ENABLE_SHARED_WEIGHTS"] = "1" + vllm_env["VLLM_BRIDGE_CONFIG_PATH"] = bridge_config_path + vllm_env["CUDA_VISIBLE_DEVICES"] = device_index + vllm_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + vllm_env["VLLM_USE_V1"] = "0" # v0 engine required for shared weights patches + vllm_env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" # Required for CUDA + + # Build vLLM command + vllm_cmd = [ + sys.executable, + "-u", + str(vllm_server_script), + "--model", + args.model_name, + "--port", + str(args.vllm_port), + "--dtype", + args.dtype, + "--gpu-memory-utilization", + str(args.gpu_memory_utilization), + "--max-model-len", + str(args.max_model_len), + "--enforce-eager", # Required for shared weights + ] + + vllm_log_path = os.path.join(log_dir, "vllm.log") + print(f"[Run] Starting vLLM server (log: {vllm_log_path})...") + + vllm_log = open(vllm_log_path, "w") + vllm_process = subprocess.Popen( + vllm_cmd, + env=vllm_env, + stdout=vllm_log, + stderr=subprocess.STDOUT, + ) + + # Register cleanup + def cleanup(): + print("\n[Run] Cleaning up...") + if vllm_process.poll() is None: + print("[Run] Terminating vLLM server...") + vllm_process.terminate() + try: + vllm_process.wait(timeout=10) + except subprocess.TimeoutExpired: + vllm_process.kill() + vllm_log.close() + print("[Run] Cleanup complete.") + + atexit.register(cleanup) + signal.signal(signal.SIGINT, lambda s, f: sys.exit(0)) + signal.signal(signal.SIGTERM, lambda s, f: sys.exit(0)) + + # Wait for vLLM to be ready + if not wait_for_vllm(args.vllm_port, timeout=500): + print("[Run] ✗ vLLM server failed to start. Check logs at:", vllm_log_path) + sys.exit(1) + + # Wait for bridge config + if not wait_for_bridge_config(bridge_config_path, timeout=60): + print("[Run] ✗ Bridge config not created. 
Check vLLM logs.") + sys.exit(1) + + # === Start Trainer === + print("\n[Run] Starting GRPO trainer...") + + # Build config - override some fields for shared_vllm mode + config = TrainingConfig( + model_name=args.model_name, + lr=args.lr, + training_steps=args.training_steps, + batch_size=args.batch_size, + seq_len=args.seq_len, + gradient_accumulation_steps=args.gradient_accumulation_steps, + optimizer=args.optimizer, + device="cuda:0", # Always 0 since we set CUDA_VISIBLE_DEVICES + save_path=args.save_path, + checkpoint_interval=args.checkpoint_interval, + # GRPO hyperparameters + kl_coef=args.kl_coef, + clip_eps=args.clip_eps, + use_reference_logprobs=not getattr(args, "no_reference_logprobs", False), + # vLLM settings + vllm_port=args.vllm_port, + vllm_gpu_memory_utilization=args.gpu_memory_utilization, + vllm_config_path=bridge_config_path, + # Mode settings + weight_bridge_mode="shared_vllm", # Always shared_vllm for run.py + atropos_url=args.atropos_url, + # Logging + use_wandb=args.use_wandb, + wandb_project=args.wandb_project, + benchmark=True, # Always show timing info for run.py + debug_loading=getattr(args, "debug_loading", False), + ) + + try: + train_shared_vllm(config) + print("\n[Run] ✓ Training completed successfully!") + except Exception as e: + print(f"\n[Run] ✗ Training failed: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/example_trainer/run_gsm8k_lora_matrix.sh b/example_trainer/run_gsm8k_lora_matrix.sh new file mode 100755 index 00000000..3f110c13 --- /dev/null +++ b/example_trainer/run_gsm8k_lora_matrix.sh @@ -0,0 +1,497 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Runs three GSM8K test trainings with separate infra/ports: +# 1) shared_vllm +# 2) lora_only (+ layer filtering support) +# 3) lora_restart (+ layer filtering support) +# +# Usage: +# chmod +x example_trainer/run_gsm8k_lora_matrix.sh +# ./example_trainer/run_gsm8k_lora_matrix.sh +# +# Optional environment overrides: +# MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B" +# TRAINING_STEPS=10 +# LORA_LAYER_INDICES="20-31" +# WANDB_PROJECT="gsm8k-grpo-smoke" +# WANDB_GROUP="gsm8k-$(date +%Y%m%d-%H%M%S)" +# START_API_PORT=8002 +# START_VLLM_PORT=9001 +# PYTHON_BIN=python3 +# OUTPUT_BASE_DIR="$PWD" # logs/saves base (defaults to launch directory) +# SHARED_GPU=0 +# LORA_ONLY_TRAINER_GPU=1 +# LORA_ONLY_VLLM_GPU=2 +# LORA_RESTART_TRAINER_GPU=3 +# LORA_RESTART_VLLM_GPU=4 +# DRY_RUN=1 # print commands only, do not execute +# PARALLEL=1 # run all three modes concurrently +# MODE=all # one of: all, shared_vllm, lora_only, lora_restart + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." 
&& pwd)" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SCRIPT_PATH="${SCRIPT_DIR}/$(basename "${BASH_SOURCE[0]}")" +LAUNCH_DIR="$PWD" +cd "$ROOT_DIR" + +PYTHON_BIN="${PYTHON_BIN:-python3}" +MODEL_NAME="${MODEL_NAME:-NousResearch/Hermes-3-Llama-3.1-8B}" +TRAINING_STEPS="${TRAINING_STEPS:-10}" +BATCH_SIZE="${BATCH_SIZE:-4}" +GRAD_ACCUM="${GRAD_ACCUM:-4}" +LR="${LR:-1e-5}" +KL_COEF="${KL_COEF:-0.1}" +CLIP_EPS="${CLIP_EPS:-0.2}" +GPU_MEMORY_UTILIZATION="${GPU_MEMORY_UTILIZATION:-0.45}" +MAX_MODEL_LEN="${MAX_MODEL_LEN:-4096}" +DTYPE="${DTYPE:-bfloat16}" +LORA_R="${LORA_R:-16}" +LORA_ALPHA="${LORA_ALPHA:-32}" +LORA_DROPOUT="${LORA_DROPOUT:-0.05}" +LORA_TARGET_MODULES="${LORA_TARGET_MODULES:-q_proj v_proj}" +LORA_LAYER_INDICES="${LORA_LAYER_INDICES:-}" +WANDB_PROJECT="${WANDB_PROJECT:-gsm8k-grpo-smoke}" +WANDB_GROUP="${WANDB_GROUP:-gsm8k-$(date +%Y%m%d-%H%M%S)}" +START_API_PORT="${START_API_PORT:-8002}" +START_VLLM_PORT="${START_VLLM_PORT:-9001}" +OUTPUT_BASE_DIR="${OUTPUT_BASE_DIR:-$LAUNCH_DIR}" + +# GPU pinning (one process per GPU preference) +SHARED_GPU="${SHARED_GPU:-0}" +LORA_ONLY_TRAINER_GPU="${LORA_ONLY_TRAINER_GPU:-1}" +LORA_ONLY_VLLM_GPU="${LORA_ONLY_VLLM_GPU:-2}" +LORA_RESTART_TRAINER_GPU="${LORA_RESTART_TRAINER_GPU:-3}" +LORA_RESTART_VLLM_GPU="${LORA_RESTART_VLLM_GPU:-4}" +DRY_RUN="${DRY_RUN:-0}" +ENV_TOTAL_STEPS="${ENV_TOTAL_STEPS:-200}" +ENV_BATCH_SIZE="${ENV_BATCH_SIZE:-16}" +ENV_MAX_WORKERS_PER_NODE="${ENV_MAX_WORKERS_PER_NODE:-8}" +ENV_STEPS_PER_EVAL="${ENV_STEPS_PER_EVAL:-50}" +PARALLEL="${PARALLEL:-0}" +MODE="${MODE:-all}" + +SHARED_API_PORT="$START_API_PORT" +SHARED_VLLM_PORT="$START_VLLM_PORT" +LORA_ONLY_API_PORT="$((START_API_PORT + 1))" +LORA_ONLY_VLLM_PORT="$((START_VLLM_PORT + 1))" +LORA_RESTART_API_PORT="$((START_API_PORT + 2))" +LORA_RESTART_VLLM_PORT="$((START_VLLM_PORT + 2))" + +run_pids=() +run_ports=() + +log() { + echo "[$(date '+%H:%M:%S')] $*" +} + +kill_port() { + local port="$1" + if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] skip port cleanup for :${port}" + return 0 + fi + if lsof -i ":${port}" -sTCP:LISTEN >/dev/null 2>&1; then + lsof -ti ":${port}" | xargs -r kill -9 || true + fi +} + +wait_for_http() { + local url="$1" + local timeout="${2:-180}" + local name="${3:-endpoint}" + local start + start="$(date +%s)" + while true; do + if curl -fsS "$url" >/dev/null 2>&1; then + log "Ready: ${name} (${url})" + return 0 + fi + if (( "$(date +%s)" - start > timeout )); then + log "Timeout waiting for ${name}: ${url}" + return 1 + fi + sleep 2 + done +} + +start_process() { + local name="$1" + local logfile="$2" + shift 2 + if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] start ${name} (log: ${logfile})" + printf ' ' + printf '%q ' "$@" + printf '\n' + return 0 + fi + log "Starting ${name} (log: ${logfile})" + "$@" >"$logfile" 2>&1 & + local pid=$! + run_pids+=("$pid") + log "${name} PID=${pid}" +} + +cleanup_run() { + log "Cleaning up run processes..." 
+ if (( ${#run_pids[@]} > 0 )); then + for pid in "${run_pids[@]}"; do + kill "$pid" >/dev/null 2>&1 || true + done + sleep 1 + for pid in "${run_pids[@]}"; do + kill -9 "$pid" >/dev/null 2>&1 || true + done + fi + if (( ${#run_ports[@]} > 0 )); then + for port in "${run_ports[@]}"; do + kill_port "$port" + done + fi + run_pids=() + run_ports=() +} + +add_lora_layer_flag() { + if [[ -n "$LORA_LAYER_INDICES" ]]; then + echo "--lora-layer-indices" "$LORA_LAYER_INDICES" + fi +} + +common_trainer_flags() { + echo \ + --model-name "$MODEL_NAME" \ + --training-steps "$TRAINING_STEPS" \ + --batch-size "$BATCH_SIZE" \ + --gradient-accumulation-steps "$GRAD_ACCUM" \ + --lr "$LR" \ + --kl-coef "$KL_COEF" \ + --clip-eps "$CLIP_EPS" \ + --use-wandb \ + --wandb-project "$WANDB_PROJECT" \ + --wandb-group "$WANDB_GROUP" +} + +start_gsm8k_env() { + local api_port="$1" + local vllm_port="$2" + local env_wandb_name="$3" + local logfile="$4" + start_process "gsm8k_env" "$logfile" \ + "$PYTHON_BIN" environments/gsm8k_server.py serve \ + --env.group_size 4 \ + --env.batch_size "$ENV_BATCH_SIZE" \ + --env.total_steps "$ENV_TOTAL_STEPS" \ + --env.steps_per_eval "$ENV_STEPS_PER_EVAL" \ + --env.max_num_workers_per_node "$ENV_MAX_WORKERS_PER_NODE" \ + --env.max_token_length "$MAX_MODEL_LEN" \ + --env.rollout_server_url "http://localhost:${api_port}" \ + --env.use_wandb true \ + --env.wandb_name "$env_wandb_name" \ + --openai.api_key "dummy" \ + --openai.base_url "http://localhost:${vllm_port}/v1" \ + --openai.model_name "$MODEL_NAME" \ + --openai.server_type vllm +} + +start_gsm8k_env_shared() { + local vllm_port="$1" + local logfile="$2" + local api_port="$SHARED_API_PORT" + start_gsm8k_env "$api_port" "$vllm_port" "gsm8k-shared-vllm-env" "$logfile" +} + +run_shared_vllm() { + log "========== RUN: shared_vllm ==========" + local api_port="$SHARED_API_PORT" + local vllm_port="$SHARED_VLLM_PORT" + local mode_dir="${OUTPUT_BASE_DIR}/logs/gsm8k_shared_vllm" + local save_dir="${OUTPUT_BASE_DIR}/saves/gsm8k_shared_vllm" + local bridge_dir="${mode_dir}/bridge" + mkdir -p "$mode_dir" + mkdir -p "$save_dir" + mkdir -p "$bridge_dir" + + run_ports+=("$api_port" "$vllm_port") + kill_port "$api_port" + kill_port "$vllm_port" + + start_process "run_api" "$mode_dir/run_api.log" run-api --port "$api_port" + if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] wait for http://localhost:${api_port}/info" + else + wait_for_http "http://localhost:${api_port}/info" 60 "run-api" + fi + + start_process "vllm_shared" "$mode_dir/vllm.log" \ + env CUDA_VISIBLE_DEVICES="$SHARED_GPU" VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR="$bridge_dir" \ + "$PYTHON_BIN" -m example_trainer.vllm_api_server \ + --model "$MODEL_NAME" \ + --port "$vllm_port" \ + --gpu-memory-utilization "$GPU_MEMORY_UTILIZATION" \ + --max-model-len "$MAX_MODEL_LEN" \ + --dtype "$DTYPE" \ + --enforce-eager + if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] wait for http://localhost:${vllm_port}/health" + else + wait_for_http "http://localhost:${vllm_port}/health" 300 "shared vLLM" + fi + + start_gsm8k_env_shared "$vllm_port" "$mode_dir/env.log" + + log "Starting trainer: shared_vllm" + if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] trainer command (shared_vllm):" + printf ' ' + printf '%q ' env CUDA_VISIBLE_DEVICES="$SHARED_GPU" "$PYTHON_BIN" -m example_trainer.grpo \ + $(common_trainer_flags) \ + --weight-bridge-mode shared_vllm \ + --device cuda:0 \ + --save-path "$save_dir" \ + --vllm-port "$vllm_port" \ + --vllm-config-path "${bridge_dir}/vllm_bridge_config.json" \ + 
--atropos-url "http://localhost:${api_port}" + printf '\n' + log "[DRY RUN] trainer log path: $mode_dir/trainer.log" + else + env CUDA_VISIBLE_DEVICES="$SHARED_GPU" "$PYTHON_BIN" -m example_trainer.grpo \ + $(common_trainer_flags) \ + --weight-bridge-mode shared_vllm \ + --device cuda:0 \ + --save-path "$save_dir" \ + --vllm-port "$vllm_port" \ + --vllm-config-path "${bridge_dir}/vllm_bridge_config.json" \ + --atropos-url "http://localhost:${api_port}" | tee "$mode_dir/trainer.log" + fi + + cleanup_run +} + +run_lora_only() { + log "========== RUN: lora_only ==========" + local api_port="$LORA_ONLY_API_PORT" + local vllm_port="$LORA_ONLY_VLLM_PORT" + local mode_dir="${OUTPUT_BASE_DIR}/logs/gsm8k_lora_only" + local save_dir="${OUTPUT_BASE_DIR}/saves/gsm8k_lora_only" + mkdir -p "$mode_dir" + mkdir -p "$save_dir" + + run_ports+=("$api_port" "$vllm_port") + kill_port "$api_port" + kill_port "$vllm_port" + + start_process "run_api" "$mode_dir/run_api.log" run-api --port "$api_port" + if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] wait for http://localhost:${api_port}/info" + else + wait_for_http "http://localhost:${api_port}/info" 60 "run-api" + fi + + start_process "vllm_lora_only" "$mode_dir/vllm.log" \ + env CUDA_VISIBLE_DEVICES="$LORA_ONLY_VLLM_GPU" \ + "$PYTHON_BIN" -m example_trainer.vllm_api_server \ + --model "$MODEL_NAME" \ + --port "$vllm_port" \ + --gpu-memory-utilization "$GPU_MEMORY_UTILIZATION" \ + --max-model-len "$MAX_MODEL_LEN" \ + --dtype "$DTYPE" \ + --enable-lora \ + --enforce-eager + if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] wait for http://localhost:${vllm_port}/health" + else + wait_for_http "http://localhost:${vllm_port}/health" 300 "lora_only vLLM" + fi + + start_gsm8k_env "$api_port" "$vllm_port" "gsm8k-lora-only-env" "$mode_dir/env.log" + + log "Starting trainer: lora_only" + if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] trainer command (lora_only):" + printf ' ' + printf '%q ' env CUDA_VISIBLE_DEVICES="$LORA_ONLY_TRAINER_GPU" "$PYTHON_BIN" -m example_trainer.grpo \ + $(common_trainer_flags) \ + --weight-bridge-mode lora_only \ + --device cuda:0 \ + --save-path "$save_dir" \ + --vllm-port "$vllm_port" \ + --atropos-url "http://localhost:${api_port}" \ + --lora-r "$LORA_R" \ + --lora-alpha "$LORA_ALPHA" \ + --lora-dropout "$LORA_DROPOUT" \ + --lora-target-modules $LORA_TARGET_MODULES \ + $(add_lora_layer_flag) + printf '\n' + log "[DRY RUN] trainer log path: $mode_dir/trainer.log" + else + env CUDA_VISIBLE_DEVICES="$LORA_ONLY_TRAINER_GPU" "$PYTHON_BIN" -m example_trainer.grpo \ + $(common_trainer_flags) \ + --weight-bridge-mode lora_only \ + --device cuda:0 \ + --save-path "$save_dir" \ + --vllm-port "$vllm_port" \ + --atropos-url "http://localhost:${api_port}" \ + --lora-r "$LORA_R" \ + --lora-alpha "$LORA_ALPHA" \ + --lora-dropout "$LORA_DROPOUT" \ + --lora-target-modules $LORA_TARGET_MODULES \ + $(add_lora_layer_flag) | tee "$mode_dir/trainer.log" + fi + + cleanup_run +} + +run_lora_restart() { + log "========== RUN: lora_restart ==========" + local api_port="$LORA_RESTART_API_PORT" + local vllm_port="$LORA_RESTART_VLLM_PORT" + local mode_dir="${OUTPUT_BASE_DIR}/logs/gsm8k_lora_restart" + local save_dir="${OUTPUT_BASE_DIR}/saves/gsm8k_lora_restart" + mkdir -p "$mode_dir" + mkdir -p "$save_dir" + + run_ports+=("$api_port" "$vllm_port") + kill_port "$api_port" + kill_port "$vllm_port" + + start_process "run_api" "$mode_dir/run_api.log" run-api --port "$api_port" + if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] wait for 
http://localhost:${api_port}/info" + else + wait_for_http "http://localhost:${api_port}/info" 60 "run-api" + fi + + log "Starting trainer: lora_restart (it launches its own vLLM)" + if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] trainer command (lora_restart):" + printf ' ' + printf '%q ' env CUDA_VISIBLE_DEVICES="$LORA_RESTART_TRAINER_GPU" "$PYTHON_BIN" -m example_trainer.grpo \ + $(common_trainer_flags) \ + --weight-bridge-mode lora_restart \ + --device cuda:0 \ + --save-path "$save_dir" \ + --vllm-port "$vllm_port" \ + --vllm-gpu "$LORA_RESTART_VLLM_GPU" \ + --vllm-restart-interval 3 \ + --atropos-url "http://localhost:${api_port}" \ + --lora-r "$LORA_R" \ + --lora-alpha "$LORA_ALPHA" \ + --lora-dropout "$LORA_DROPOUT" \ + --lora-target-modules $LORA_TARGET_MODULES \ + $(add_lora_layer_flag) + printf '\n' + log "[DRY RUN] then wait for http://localhost:${vllm_port}/health" + log "[DRY RUN] then start GSM8K env pointed at http://localhost:${vllm_port}/v1 and rollout server http://localhost:${api_port}" + log "[DRY RUN] trainer log path: $mode_dir/trainer.log" + else + env CUDA_VISIBLE_DEVICES="$LORA_RESTART_TRAINER_GPU" "$PYTHON_BIN" -m example_trainer.grpo \ + $(common_trainer_flags) \ + --weight-bridge-mode lora_restart \ + --device cuda:0 \ + --save-path "$save_dir" \ + --vllm-port "$vllm_port" \ + --vllm-gpu "$LORA_RESTART_VLLM_GPU" \ + --vllm-restart-interval 3 \ + --atropos-url "http://localhost:${api_port}" \ + --lora-r "$LORA_R" \ + --lora-alpha "$LORA_ALPHA" \ + --lora-dropout "$LORA_DROPOUT" \ + --lora-target-modules $LORA_TARGET_MODULES \ + $(add_lora_layer_flag) >"$mode_dir/trainer.log" 2>&1 & + trainer_pid=$! + run_pids+=("$trainer_pid") + + wait_for_http "http://localhost:${vllm_port}/health" 420 "lora_restart vLLM" + start_gsm8k_env "$api_port" "$vllm_port" "gsm8k-lora-restart-env" "$mode_dir/env.log" + + wait "$trainer_pid" + cat "$mode_dir/trainer.log" + fi + + cleanup_run +} + +trap cleanup_run EXIT INT TERM + +log "Model: $MODEL_NAME" +log "W&B project/group: $WANDB_PROJECT / $WANDB_GROUP" +log "Dry run mode: $DRY_RUN" +log "Output base directory (logs + saves): $OUTPUT_BASE_DIR" +log "Port plan:" +log " shared_vllm: run-api=${SHARED_API_PORT}, vllm=${SHARED_VLLM_PORT}" +log " lora_only: run-api=${LORA_ONLY_API_PORT}, vllm=${LORA_ONLY_VLLM_PORT}" +log " lora_restart: run-api=${LORA_RESTART_API_PORT}, vllm=${LORA_RESTART_VLLM_PORT}" +log "GPU plan:" +log " shared_vllm: trainer+vllm on GPU ${SHARED_GPU} (required for shared weights)" +log " lora_only: trainer GPU ${LORA_ONLY_TRAINER_GPU}, vllm GPU ${LORA_ONLY_VLLM_GPU}" +log " lora_restart: trainer GPU ${LORA_RESTART_TRAINER_GPU}, vllm GPU ${LORA_RESTART_VLLM_GPU}" +if [[ -n "$LORA_LAYER_INDICES" ]]; then + log "LoRA layer indices: $LORA_LAYER_INDICES" +else + log "LoRA layer indices: all matching layers" +fi +log "Mode selector: $MODE" +log "Parallel mode: $PARALLEL" + +if [[ "$MODE" != "all" ]]; then + case "$MODE" in + shared_vllm) run_shared_vllm ;; + lora_only) run_lora_only ;; + lora_restart) run_lora_restart ;; + *) + log "Invalid MODE='$MODE' (expected: all|shared_vllm|lora_only|lora_restart)" + exit 2 + ;; + esac + log "Mode '$MODE' completed." 
+ exit 0 +fi + +if [[ "$PARALLEL" == "1" ]]; then + if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] parallel launcher commands:" + for m in shared_vllm lora_only lora_restart; do + local_log="${OUTPUT_BASE_DIR}/logs/gsm8k_${m}/orchestrator.log" + printf ' ' + printf '%q ' env MODE="$m" PARALLEL=0 "$SCRIPT_PATH" + printf '> %q 2>&1 &\n' "$local_log" + done + log "[DRY RUN] parent waits for all child mode runners." + else + log "Launching all modes in parallel..." + parallel_pids=() + parallel_modes=(shared_vllm lora_only lora_restart) + for m in "${parallel_modes[@]}"; do + mode_log_dir="${OUTPUT_BASE_DIR}/logs/gsm8k_${m}" + mkdir -p "$mode_log_dir" + mode_orch_log="${mode_log_dir}/orchestrator.log" + log "Starting mode runner: ${m} (log: ${mode_orch_log})" + env MODE="$m" PARALLEL=0 "$SCRIPT_PATH" >"$mode_orch_log" 2>&1 & + parallel_pids+=("$!") + done + + fail_count=0 + for i in "${!parallel_pids[@]}"; do + pid="${parallel_pids[$i]}" + mode="${parallel_modes[$i]}" + if wait "$pid"; then + log "Mode '${mode}' finished successfully." + else + log "Mode '${mode}' failed. See ${OUTPUT_BASE_DIR}/logs/gsm8k_${mode}/orchestrator.log" + fail_count=$((fail_count + 1)) + fi + done + if (( fail_count > 0 )); then + log "Parallel run finished with ${fail_count} failed mode(s)." + exit 1 + fi + fi +else + run_shared_vllm + run_lora_only + run_lora_restart +fi + +log "All GSM8K mode runs completed." diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py new file mode 100644 index 00000000..ce80adc9 --- /dev/null +++ b/example_trainer/trainers.py @@ -0,0 +1,989 @@ +""" +Training mode implementations for GRPO trainer. + +Contains the four main training modes: +- train_legacy: Checkpoint-based training with vLLM restarts +- train_shared_vllm: Single-copy mode with CUDA IPC +- train_lora: LoRA adapter training with HTTP hot-swap +- train_lora_restart: LoRA training with vLLM restarts (FAST mode) +""" + +import os +import subprocess +import sys +import time +from typing import Iterable, Optional + +import requests +import torch +from torch.optim import AdamW + +from .api import check_atropos_api, register_trainer + + +def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer: + """ + Create optimizer based on config.optimizer setting. 
+ + Options: + - 'adamw': Standard AdamW (full precision, ~32GB GPU for 8B model) + - 'adamw_8bit': 8-bit AdamW from bitsandbytes (~8GB GPU, requires bitsandbytes) + - 'adafactor': Adafactor without momentum (~8GB GPU, no extra dependencies) + """ + return create_optimizer_for_params(model.parameters(), config) + + +def create_optimizer_for_params( + params: Iterable[torch.nn.Parameter], config +) -> torch.optim.Optimizer: + """Create optimizer for a specific parameter iterable.""" + params = list(params) + + if config.optimizer == "adamw_8bit": + try: + import bitsandbytes as bnb + + optimizer = bnb.optim.AdamW8bit(params, lr=config.lr) + print("[Setup] Using 8-bit AdamW (saves ~24GB optimizer memory)") + return optimizer + except ImportError: + print("[Setup] WARNING: bitsandbytes not installed, falling back to AdamW") + print("[Setup] Install with: pip install bitsandbytes") + + if config.optimizer == "adafactor": + try: + from transformers.optimization import Adafactor + + optimizer = Adafactor( + params, + lr=config.lr, + scale_parameter=False, + relative_step=False, + ) + print("[Setup] Using Adafactor (no momentum, saves ~24GB)") + return optimizer + except ImportError: + print("[Setup] WARNING: transformers Adafactor not available, using AdamW") + + # Default: standard AdamW + optimizer = AdamW(params, lr=config.lr) + print("[Setup] Using standard AdamW (requires ~32GB for optimizer states)") + return optimizer + + +from .checkpointing import save_checkpoint, save_lora_checkpoint # noqa: E402 +from .config import TrainingConfig # noqa: E402 +from .data import get_data # noqa: E402 +from .model import PEFT_AVAILABLE, load_model_and_tokenizer # noqa: E402 +from .training import ( # noqa: E402 + finalize_training, + log_metrics, + run_training_step, + setup_wandb, +) +from .vllm_manager import ( # noqa: E402 + check_vllm_health, + check_vllm_process_health, + launch_vllm_server, + set_vllm_process, + terminate_vllm_process, +) + + +def train_legacy(config: TrainingConfig): + """ + Legacy GRPO training with periodic vLLM restarts. + + This mode: + 1. Trains model on trainer GPU + 2. Saves checkpoints periodically + 3. 
Restarts vLLM to load new weights + + Use for: + - Simple setup + - When trainer and vLLM on different GPUs + """ + training_start_time = time.time() + + # === Setup === + use_wandb = setup_wandb(config) + model, tokenizer = load_model_and_tokenizer(config) + optimizer = create_optimizer(model, config) + + print("\n" + "=" * 60) + print("LEGACY MODE (checkpoint + vLLM restart)") + print("=" * 60) + print(f"Training for {config.training_steps} steps on {config.device}") + print(f"vLLM restart interval: every {config.vllm_restart_interval} steps") + print(f"Save path: {config.save_path}") + print("=" * 60 + "\n") + + os.makedirs(config.save_path, exist_ok=True) + + # Check Atropos API + if not check_atropos_api(url=config.atropos_url, timeout=30): + raise RuntimeError(f"Atropos API not reachable at {config.atropos_url}") + register_trainer(config) + + # Launch initial vLLM server + vllm_proc = launch_vllm_server(config, config.model_name) + set_vllm_process(vllm_proc) + + # === Benchmark tracking === + benchmark_stats = { + "step_times": [], + "sync_times": [], + "data_fetch_times": [], + "gpu_memories": [], + } + + # === Training Loop === + batches = [] + for step in range(config.training_steps): + print(f"\nStep {step+1}/{config.training_steps}") + + # Fetch data (with inference logprobs for proper GRPO) + data_fetch_start = time.time() + if len(batches) == 0: + batches, _ = get_data( + config.batch_size, + config.seq_len, + config.atropos_url, + extract_inference_logprobs=True, + ) + batch_data = batches.pop(0) + token_batches, label_batches, advantage_batches, temperature_batches = ( + batch_data[:4] + ) + inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None + data_fetch_time = time.time() - data_fetch_start + benchmark_stats["data_fetch_times"].append(data_fetch_time) + + # Check if we should sync (save checkpoint + restart vLLM) + should_sync = ( + step + 1 + ) % config.vllm_restart_interval == 0 or step == config.training_steps - 1 + if should_sync: + terminate_vllm_process() + + # Training step (with proper GRPO using inference logprobs) + step_start = time.time() + metrics = run_training_step( + model, + optimizer, + token_batches, + label_batches, + advantage_batches, + temperature_batches, + config, + inference_logprob_batches=inference_logprob_batches, + ) + step_time = time.time() - step_start + benchmark_stats["step_times"].append(step_time) + + # GPU memory tracking + gpu_mem_gb = ( + torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 + ) + gpu_mem_reserved_gb = ( + torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 + ) + benchmark_stats["gpu_memories"].append(gpu_mem_gb) + + # Sync (checkpoint + restart) + sync_time = 0 + if should_sync: + sync_start = time.time() + checkpoint_path = save_checkpoint( + model, tokenizer, config.save_path, step + 1 + ) + torch.cuda.empty_cache() + vllm_proc = launch_vllm_server(config, checkpoint_path) + set_vllm_process(vllm_proc) + sync_time = time.time() - sync_start + benchmark_stats["sync_times"].append(sync_time) + + # Update metrics + metrics.update( + { + "step_time": step_time, + "sync_time": sync_time, + "data_fetch_time": data_fetch_time, + "gpu_memory_gb": gpu_mem_gb, + "gpu_memory_reserved_gb": gpu_mem_reserved_gb, + } + ) + + log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark) + check_vllm_process_health() + + # === Cleanup === + save_checkpoint( + model, tokenizer, config.save_path, config.training_steps, is_final=True + ) + finalize_training( + 
use_wandb, + training_start_time, + "legacy", + config.training_steps, + benchmark_stats, + config.benchmark, + ) + + +def train_shared_vllm(config: TrainingConfig): + """ + GRPO training with shared vLLM weights (single-copy mode). + + This mode: + 1. Attaches to vLLM's weight tensors via CUDA IPC + 2. optimizer.step() modifies vLLM's weights in-place + 3. vLLM immediately uses updated weights (no restart!) + + Requirements: + - vLLM running with VLLM_ENABLE_SHARED_WEIGHTS=1 + - Trainer on same GPU(s) as vLLM + """ + training_start_time = time.time() + + # === Setup === + use_wandb = setup_wandb(config) + + print("\n" + "=" * 60) + print("SINGLE-COPY MODE (CUDA IPC)") + print(">>> Trainer uses vLLM's tensors directly!") + print("=" * 60) + print(f"Model: {config.model_name}") + print(f"Save path: {config.save_path}") + print("=" * 60 + "\n") + + # Attach to vLLM's shared tensors + print("[1/2] Attaching to vLLM's shared tensors...") + model, tokenizer = load_model_and_tokenizer(config, single_copy=True) + + if model is None: + raise RuntimeError( + "Single-copy mode failed. Make sure:\n" + "1. vLLM is running with VLLM_ENABLE_SHARED_WEIGHTS=1\n" + "2. Trainer is on the SAME GPUs as vLLM\n" + "3. vllm_bridge_config.json exists with IPC handles" + ) + + optimizer = create_optimizer(model, config) + + # === Real-time weight sharing verification === + print("\n[Weight Sharing Verification]") + + os.makedirs(config.save_path, exist_ok=True) + + # Check Atropos API + print(f"\n[Setup] Connecting to Atropos API at {config.atropos_url}...") + if not check_atropos_api(url=config.atropos_url, timeout=30): + raise RuntimeError(f"Atropos API not reachable at {config.atropos_url}") + register_trainer(config) + + # === Benchmark tracking === + benchmark_stats = { + "step_times": [], + "sync_times": [], + "data_fetch_times": [], + "gpu_memories": [], + } + + # === Training Loop === + batches = [] + for step in range(config.training_steps): + print(f"\nStep {step+1}/{config.training_steps}") + + # Fetch data (with inference logprobs for proper GRPO loss) + data_fetch_start = time.time() + if len(batches) == 0: + batches, _ = get_data( + config.batch_size, + config.seq_len, + config.atropos_url, + extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs + ) + batch_data = batches.pop(0) + token_batches, label_batches, advantage_batches, temperature_batches = ( + batch_data[:4] + ) + inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None + data_fetch_time = time.time() - data_fetch_start + benchmark_stats["data_fetch_times"].append(data_fetch_time) + + # Training step with proper GRPO (importance sampling + KL penalty) + step_start = time.time() + metrics = run_training_step( + model, + optimizer, + token_batches, + label_batches, + advantage_batches, + temperature_batches, + config, + inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation + ) + step_time = time.time() - step_start + benchmark_stats["step_times"].append(step_time) + + # GPU memory tracking + gpu_mem_gb = ( + torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 + ) + gpu_mem_reserved_gb = ( + torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 + ) + benchmark_stats["gpu_memories"].append(gpu_mem_gb) + + # In single-copy mode, weights are updated in-place (no sync needed!) 
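+        # The trainer's parameters are views of vLLM's CUDA tensors (attached over IPC at
+        # startup), so optimizer.step() has already written the new values into the memory
+        # vLLM serves from - there is nothing to checkpoint or restart here.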
+ sync_time = 0.0 + print(f" [SINGLE-COPY] Weights updated in-place - step {step+1}") + benchmark_stats["sync_times"].append(sync_time) + + # Update metrics + metrics.update( + { + "step_time": step_time, + "sync_time": sync_time, + "data_fetch_time": data_fetch_time, + "gpu_memory_gb": gpu_mem_gb, + "gpu_memory_reserved_gb": gpu_mem_reserved_gb, + } + ) + + log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark) + + # Periodic checkpoint (for recovery, not for vLLM sync) + if ( + config.checkpoint_interval > 0 + and (step + 1) % config.checkpoint_interval == 0 + ): + save_checkpoint(model, tokenizer, config.save_path, step + 1) + + # === Cleanup === + save_checkpoint( + model, tokenizer, config.save_path, config.training_steps, is_final=True + ) + finalize_training( + use_wandb, + training_start_time, + "shared_vllm", + config.training_steps, + benchmark_stats, + config.benchmark, + ) + + +def train_lora(config: TrainingConfig): + """ + GRPO training with LoRA adapters. + + This mode: + 1. Freezes base model, trains only LoRA adapter weights + 2. Saves lightweight adapter checkpoints + 3. Hot-swaps adapters in vLLM via API + + Benefits: + - Much faster training (fewer parameters) + - Smaller checkpoints + - Adapters can be hot-swapped without restart + + Requirements: + - External vLLM server running with --enable-lora + """ + if not PEFT_AVAILABLE: + raise RuntimeError( + "PEFT library required for LoRA mode. Install with: pip install peft" + ) + + training_start_time = time.time() + + # === Setup === + use_wandb = setup_wandb(config) + + print("\n" + "=" * 60) + print("LORA MODE (adapter-only training)") + print("=" * 60) + print(f"Base model: {config.model_name}") + print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}") + print(f"Save path: {config.save_path}") + print(f"vLLM port: {config.vllm_port}") + print("=" * 60 + "\n") + + # Check external vLLM server + print("[1/3] Checking external vLLM server...") + if not check_vllm_health(config.vllm_port): + print(f"\nERROR: vLLM server not running on port {config.vllm_port}") + print("\nLoRA mode requires an external vLLM server. 
Start it first:") + print( + f" python example_trainer/vllm_api_server.py --model {config.model_name} " + f"--port {config.vllm_port} --enable-lora --enforce-eager" + ) + raise RuntimeError(f"External vLLM server required on port {config.vllm_port}") + print(f"vLLM server healthy on port {config.vllm_port}") + + # Load model with LoRA adapters + print("[2/3] Loading model with LoRA adapters...") + model, tokenizer = load_model_and_tokenizer(config) + + # Only optimize LoRA parameters + trainable_params = [p for p in model.parameters() if p.requires_grad] + optimizer = create_optimizer_for_params(trainable_params, config) + + print(f"[3/3] Starting training for {config.training_steps} steps") + print("-" * 60) + + os.makedirs(config.save_path, exist_ok=True) + + # Check Atropos API + if not check_atropos_api(url=config.atropos_url, timeout=30): + raise RuntimeError(f"Atropos API not reachable at {config.atropos_url}") + register_trainer(config) + + # === Benchmark tracking === + benchmark_stats = { + "step_times": [], + "sync_times": [], + "data_fetch_times": [], + "gpu_memories": [], + } + + # === Training Loop === + batches = [] + for step in range(config.training_steps): + print(f"\nStep {step+1}/{config.training_steps}") + + # Fetch data (with inference logprobs for proper GRPO) + data_fetch_start = time.time() + if len(batches) == 0: + batches, _ = get_data( + config.batch_size, + config.seq_len, + config.atropos_url, + extract_inference_logprobs=True, + ) + batch_data = batches.pop(0) + token_batches, label_batches, advantage_batches, temperature_batches = ( + batch_data[:4] + ) + inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None + data_fetch_time = time.time() - data_fetch_start + benchmark_stats["data_fetch_times"].append(data_fetch_time) + + # Training step with proper GRPO + step_start = time.time() + metrics = run_training_step( + model, + optimizer, + token_batches, + label_batches, + advantage_batches, + temperature_batches, + config, + inference_logprob_batches=inference_logprob_batches, + ) + step_time = time.time() - step_start + benchmark_stats["step_times"].append(step_time) + + # GPU memory tracking + gpu_mem_gb = ( + torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 + ) + gpu_mem_reserved_gb = ( + torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 + ) + benchmark_stats["gpu_memories"].append(gpu_mem_gb) + + # Periodic adapter save + hot-swap + sync_time = 0 + should_sync = (step + 1) % config.vllm_restart_interval == 0 + if should_sync: + sync_start = time.time() + adapter_path = save_lora_checkpoint(model, config.save_path, step + 1) + _hotswap_lora_adapter(config.vllm_port, adapter_path, f"step_{step + 1}") + sync_time = time.time() - sync_start + benchmark_stats["sync_times"].append(sync_time) + + # Update metrics + metrics.update( + { + "step_time": step_time, + "sync_time": sync_time, + "data_fetch_time": data_fetch_time, + "gpu_memory_gb": gpu_mem_gb, + "gpu_memory_reserved_gb": gpu_mem_reserved_gb, + } + ) + + log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark) + + # === Cleanup === + final_sync_start = time.time() + final_adapter_path = save_lora_checkpoint( + model, config.save_path, config.training_steps, is_final=True + ) + _hotswap_lora_adapter(config.vllm_port, final_adapter_path, "final") + final_sync_time = time.time() - final_sync_start + benchmark_stats["sync_times"].append(final_sync_time) + + finalize_training( + use_wandb, + training_start_time, + "lora_only", + 
config.training_steps, + benchmark_stats, + config.benchmark, + ) + + # Save tokenizer + tokenizer_path = os.path.join(config.save_path, "tokenizer") + tokenizer.save_pretrained(tokenizer_path) + print(f"Tokenizer saved to {tokenizer_path}") + + +def _hotswap_lora_adapter( + port: int, + adapter_path: str, + adapter_name: Optional[str] = None, +) -> bool: + """ + Request vLLM to hot-swap to a new LoRA adapter. + + Tries: + 1. Native vLLM endpoint: /v1/load_lora_adapter + 2. Custom endpoint: /lora/load + """ + base_url = f"http://localhost:{port}" + name = adapter_name or os.path.basename(adapter_path) + + # Try native vLLM endpoint first + try: + response = requests.post( + f"{base_url}/v1/load_lora_adapter", + json={"lora_name": name, "lora_path": adapter_path}, + timeout=30, + ) + if response.status_code == 200: + print(f" [LORA] ✓ Hot-swapped adapter: {name}") + return True + except Exception: + pass + + # Try custom endpoint + try: + response = requests.post( + f"{base_url}/lora/load", + json={"adapter_path": adapter_path, "adapter_name": name}, + timeout=30, + ) + if response.status_code == 200: + print(f" [LORA] ✓ Hot-swapped adapter via custom API: {name}") + return True + else: + print(f" [LORA] ✗ Hot-swap failed: {response.text}") + return False + except Exception as e: + print(f" [LORA] ✗ Hot-swap request failed: {e}") + return False + + +def train_lora_restart(config: TrainingConfig): + """ + GRPO training with LoRA adapters using vLLM restarts (FAST mode). + + This mode: + 1. Freezes base model, trains only LoRA adapter weights + 2. Runs vLLM WITHOUT --enforce-eager (keeps some CUDA optimizations) + 3. Restarts vLLM every N steps with the new adapter pre-loaded + + Performance comparison (Qwen3-4B @ 8k context): + - lora_only (--enforce-eager): ~13 TPS (SLOW - CUDA graphs disabled) + - lora_restart (no --enforce-eager): ~108 TPS (8x FASTER) + - base model (no LoRA): ~172 TPS (baseline) + + The restart overhead (~45s) is much less than the 8x inference slowdown. + + Requirements: + - No external vLLM needed - this mode manages vLLM internally + - Requires PEFT library for LoRA + """ + if not PEFT_AVAILABLE: + raise RuntimeError( + "PEFT library required for LoRA mode. 
Install with: pip install peft" + ) + + training_start_time = time.time() + + # === Setup === + use_wandb = setup_wandb(config) + + print("\n" + "=" * 60) + print("LORA RESTART MODE (fast inference with CUDA graphs)") + print("=" * 60) + print(f"Base model: {config.model_name}") + print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}") + print(f"Save path: {config.save_path}") + print(f"vLLM port: {config.vllm_port}") + print(f"Restart interval: every {config.vllm_restart_interval} steps") + print("=" * 60) + print("NOTE: This mode restarts vLLM without --enforce-eager for faster inference.") + print(" Expected: ~108 TPS (vs ~13 TPS with --enforce-eager = 8x speedup)") + print("=" * 60 + "\n") + + # Load model with LoRA adapters for training + print("[1/4] Loading model with LoRA adapters...") + model, tokenizer = load_model_and_tokenizer(config) + + # Only optimize LoRA parameters + trainable_params = [p for p in model.parameters() if p.requires_grad] + optimizer = create_optimizer_for_params(trainable_params, config) + + os.makedirs(config.save_path, exist_ok=True) + + # Save initial adapter + print("[2/4] Saving initial LoRA adapter...") + initial_adapter_path = save_lora_checkpoint(model, config.save_path, 0) + current_adapter_path = initial_adapter_path + + # Launch vLLM with the initial adapter + print("[3/4] Launching vLLM with CUDA graphs (no --enforce-eager)...") + vllm_proc = _launch_vllm_with_lora(config, current_adapter_path) + if vllm_proc is None: + raise RuntimeError("Failed to launch vLLM") + + print(f"[4/4] Starting training for {config.training_steps} steps") + print("-" * 60) + + # Check Atropos API + if not check_atropos_api(url=config.atropos_url, timeout=30): + _terminate_vllm(vllm_proc, config.vllm_port) + raise RuntimeError(f"Atropos API not reachable at {config.atropos_url}") + register_trainer(config) + + # === Benchmark tracking === + benchmark_stats = { + "step_times": [], + "sync_times": [], + "data_fetch_times": [], + "gpu_memories": [], + "restart_times": [], + } + + # === Training Loop === + batches = [] + for step in range(config.training_steps): + print(f"\nStep {step+1}/{config.training_steps}") + + # Fetch data (with inference logprobs for proper GRPO) + data_fetch_start = time.time() + if len(batches) == 0: + batches, _ = get_data( + config.batch_size, + config.seq_len, + config.atropos_url, + extract_inference_logprobs=True, + ) + batch_data = batches.pop(0) + token_batches, label_batches, advantage_batches, temperature_batches = ( + batch_data[:4] + ) + inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None + data_fetch_time = time.time() - data_fetch_start + benchmark_stats["data_fetch_times"].append(data_fetch_time) + + # Training step with proper GRPO + step_start = time.time() + metrics = run_training_step( + model, + optimizer, + token_batches, + label_batches, + advantage_batches, + temperature_batches, + config, + inference_logprob_batches=inference_logprob_batches, + ) + step_time = time.time() - step_start + benchmark_stats["step_times"].append(step_time) + + # GPU memory tracking + gpu_mem_gb = ( + torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 + ) + gpu_mem_reserved_gb = ( + torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 + ) + benchmark_stats["gpu_memories"].append(gpu_mem_gb) + + # Periodic adapter save + vLLM restart + sync_time = 0 + should_sync = (step + 1) % config.vllm_restart_interval == 0 + if should_sync and (step + 1) < config.training_steps: # Don't 
restart on last step + sync_start = time.time() + + # Save new adapter + current_adapter_path = save_lora_checkpoint(model, config.save_path, step + 1) + + # Restart vLLM with new adapter + print(" [RESTART] Restarting vLLM with new adapter...") + _terminate_vllm(vllm_proc, config.vllm_port) + vllm_proc = _launch_vllm_with_lora(config, current_adapter_path) + if vllm_proc is None: + raise RuntimeError("Failed to restart vLLM") + + sync_time = time.time() - sync_start + benchmark_stats["sync_times"].append(sync_time) + benchmark_stats["restart_times"].append(sync_time) + print(f" [RESTART] vLLM restarted in {sync_time:.1f}s") + + # Update metrics + metrics.update( + { + "step_time": step_time, + "sync_time": sync_time, + "data_fetch_time": data_fetch_time, + "gpu_memory_gb": gpu_mem_gb, + "gpu_memory_reserved_gb": gpu_mem_reserved_gb, + } + ) + + log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark) + + # === Cleanup === + print("\nSaving final adapter...") + final_sync_start = time.time() + final_adapter_path = save_lora_checkpoint( + model, config.save_path, config.training_steps, is_final=True + ) + final_sync_time = time.time() - final_sync_start + benchmark_stats["sync_times"].append(final_sync_time) + + # Terminate vLLM + _terminate_vllm(vllm_proc, config.vllm_port) + + finalize_training( + use_wandb, + training_start_time, + "lora_restart", + config.training_steps, + benchmark_stats, + config.benchmark, + ) + + # Save tokenizer + tokenizer_path = os.path.join(config.save_path, "tokenizer") + tokenizer.save_pretrained(tokenizer_path) + print(f"Tokenizer saved to {tokenizer_path}") + print(f"Final adapter saved to {final_adapter_path}") + + +# Global counter for vLLM restarts (for unique log files) +_vllm_restart_counter = 0 + + +def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optional[subprocess.Popen]: + """ + Launch vLLM with a LoRA adapter (no --enforce-eager for faster inference). + + Unlike lora_only mode, this does NOT use --enforce-eager, so we get + ~108 TPS instead of ~13 TPS (8x faster). + """ + global _vllm_restart_counter + from .vllm_manager import kill_process_on_port, wait_for_vllm_ready + + # Kill any existing process on the port + print(f" Cleaning up port {config.vllm_port}...") + kill_process_on_port(config.vllm_port) + + # Clear CUDA cache before starting new vLLM + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Wait for port and GPU memory to be fully released + time.sleep(5) + + # Find the vllm_api_server.py script + script_dir = os.path.dirname(os.path.abspath(__file__)) + server_script = os.path.join(script_dir, "vllm_api_server.py") + + # Build command - NO --enforce-eager for faster inference (~108 TPS vs ~13 TPS) + cmd = [ + sys.executable, server_script, + "--model", config.model_name, + "--port", str(config.vllm_port), + "--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization), + "--max-model-len", str(config.max_model_len), + "--enable-lora", + "--max-lora-rank", str(max(config.lora_r * 2, 32)), + # Note: NOT adding --enforce-eager - this gives us ~8x faster inference! + # Without --enforce-eager, vLLM can use more optimizations. 
+ ] + + # Set environment for GPU selection + env = os.environ.copy() + if config.vllm_gpu is not None: + env["CUDA_VISIBLE_DEVICES"] = str(config.vllm_gpu) + print(f" GPU: {config.vllm_gpu} (via CUDA_VISIBLE_DEVICES)") + else: + print(" GPU: Same as trainer (inherited CUDA_VISIBLE_DEVICES)") + + print(f" Launching: {' '.join(cmd)}") + print(f" Adapter: {adapter_path}") + + # Log vLLM output to file for debugging (unique file per restart) + vllm_log_path = os.path.join(config.save_path, f"vllm_restart_{_vllm_restart_counter}.log") + _vllm_restart_counter += 1 + print(f" vLLM log: {vllm_log_path}") + + try: + vllm_log_file = open(vllm_log_path, "w") + # Start in new session so we can kill entire process group later + proc = subprocess.Popen( + cmd, env=env, stdout=vllm_log_file, stderr=subprocess.STDOUT, + start_new_session=True # Creates new process group for easy cleanup + ) + print(f" vLLM PID: {proc.pid} (process group: {os.getpgid(proc.pid)})") + print(" NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)...") + + # Wait for server to be ready (longer timeout for CUDA graph compilation) + if not wait_for_vllm_ready(config.vllm_port, timeout=300): + print(" ERROR: vLLM failed to start after 300s") + print(f" Check log: {vllm_log_path}") + # Print last 30 lines of the log + try: + with open(vllm_log_path, 'r') as f: + lines = f.readlines() + print(" Last 30 lines of vLLM log:") + for line in lines[-30:]: + print(f" {line.rstrip()}") + except Exception as e: + print(f" Could not read log: {e}") + proc.terminate() + return None + + # Load the LoRA adapter + print(" Loading LoRA adapter...") + try: + resp = requests.post( + f"http://localhost:{config.vllm_port}/lora/load", + json={"adapter_path": adapter_path, "adapter_name": "training_adapter"}, + timeout=60, + ) + if resp.status_code == 200: + print(" ✓ Adapter loaded successfully") + else: + print(f" WARNING: Adapter load returned {resp.status_code}: {resp.text}") + except Exception as e: + print(f" WARNING: Could not load adapter: {e}") + # Continue anyway - base model inference still works + + return proc + + except Exception as e: + print(f" ERROR: {e}") + return None + + +def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None: + """Terminate a vLLM process and release GPU resources.""" + import signal + import subprocess as sp + + print(f" Terminating vLLM on port {port}...") + + # Get current GPU device + gpu_id = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0] + + # Phase 1: Kill the process group if we have a handle (kills all children too) + main_pid = None + if proc is not None: + main_pid = proc.pid + print(f" Killing process group (PID: {main_pid})...") + try: + # Kill entire process group - this gets all child processes + os.killpg(os.getpgid(main_pid), signal.SIGKILL) + except (ProcessLookupError, PermissionError): + pass + try: + proc.kill() + proc.wait(timeout=5) + except Exception as e: + print(f" Warning: {e}") + + # Phase 2: Kill by port (catches anything still running) + from .vllm_manager import kill_process_on_port + kill_process_on_port(port) + time.sleep(2) + + # Phase 3: Aggressively kill ALL vLLM-related processes + print(" Killing all vLLM-related processes...") + kill_commands = [ + f"fuser -k {port}/tcp", + "pkill -9 -f 'vllm.*EngineCore'", + "pkill -9 -f 'vllm_api_server'", + "pkill -9 -f 'from vllm'", + "pkill -9 -f 'multiprocessing.spawn'", + "pkill -9 -f 'ray::IDLE'", # Ray workers if any + ] + for cmd in kill_commands: + try: + sp.run(cmd, 
shell=True, capture_output=True, timeout=5) + except Exception: + pass + + # Phase 4: Use nvidia-smi to find and kill GPU processes (nuclear option) + print(f" Checking for zombie GPU processes on GPU {gpu_id}...") + try: + result = sp.run( + f"nvidia-smi --query-compute-apps=pid,used_memory --format=csv,noheader,nounits -i {gpu_id}", + shell=True, capture_output=True, text=True, timeout=10 + ) + if result.stdout.strip(): + print(f" Found GPU processes:\n{result.stdout}") + for line in result.stdout.strip().split('\n'): + if line.strip(): + parts = line.split(',') + if len(parts) >= 1: + pid = parts[0].strip() + # Don't kill the current Python process (trainer) + if pid and pid != str(os.getpid()) and pid != str(main_pid): + print(f" Killing zombie GPU process: {pid}") + try: + sp.run(f"kill -9 {pid}", shell=True, timeout=5) + except Exception: + pass + except Exception as e: + print(f" Warning: nvidia-smi check failed: {e}") + + # Phase 5: Wait for GPU memory release - CRITICAL + # The CUDA driver needs time to actually free memory after process death + print(" Waiting for GPU memory release...") + for i in range(12): # 60 seconds total (longer wait) + time.sleep(5) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + free_mem = torch.cuda.mem_get_info()[0] / 1e9 + total_mem = torch.cuda.mem_get_info()[1] / 1e9 + print(f" [{(i+1)*5}s] GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)") + # If we have enough memory (>50% free), break early + if free_mem > total_mem * 0.5: + print(f" ✓ Sufficient memory available ({free_mem:.1f} GB)") + break + + # Final cleanup + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + free_mem = torch.cuda.mem_get_info()[0] / 1e9 + total_mem = torch.cuda.mem_get_info()[1] / 1e9 + print(f" Final GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)") + + if free_mem < total_mem * 0.3: + print(" WARNING: Low GPU memory! May fail to restart vLLM.") + print(" Consider reducing --vllm-gpu-memory-utilization") + + print(" vLLM terminated") + + diff --git a/example_trainer/training.py b/example_trainer/training.py new file mode 100644 index 00000000..f44cfe84 --- /dev/null +++ b/example_trainer/training.py @@ -0,0 +1,626 @@ +""" +Training utilities for GRPO trainer. + +Contains loss computation, training step logic, and metric logging. + +Includes logprob alignment tracking to verify that training logprobs match +inference logprobs at initialization (validates shared_vllm mode is working). +""" + +import random +import string +import time +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +import wandb + +from .config import TrainingConfig + +# Global storage for logprob alignment stats +_logprob_alignment_stats: Dict[str, float] = {} + + +def setup_wandb(config: TrainingConfig) -> bool: + """ + Initialize Weights & Biases logging if enabled. + + Args: + config: Training configuration + + Returns: + True if wandb is active, False otherwise + """ + if not config.use_wandb: + return False + + if not config.wandb_project: + print("Warning: wandb_project not set, disabling wandb.") + return False + + # Generate random group name if not provided + if not config.wandb_group: + config.wandb_group = "".join( + random.choices(string.ascii_letters + string.digits, k=8) + ) + + try: + wandb.init( + project=config.wandb_project, + group=config.wandb_group, + config=config.dict(), + ) + print( + f"Wandb logging enabled. 
Run: {wandb.run.name} "
+            f"(Project: {config.wandb_project})"
+        )
+        return True
+    except Exception as e:
+        print(f"Error initializing wandb: {e}. Disabling wandb.")
+        return False
+
+
+def compute_grpo_loss(
+    model: torch.nn.Module,
+    tokens: torch.Tensor,
+    labels: torch.Tensor,
+    advantages: torch.Tensor,
+    temperatures: torch.Tensor,
+    gradient_accumulation_steps: int,
+    inference_logprobs: Optional[torch.Tensor] = None,
+    kl_coef: float = 0.1,
+    clip_eps: float = 0.2,
+    use_reference_logprobs: bool = True,
+) -> Tuple[torch.Tensor, dict]:
+    """
+    Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch.
+
+    This implements proper GRPO/PPO with:
+    - Importance sampling ratio: policy(a|s) / policy_old(a|s)
+    - PPO-style clipping to prevent large updates
+    - KL penalty to prevent reward hacking/policy collapse
+
+    The loss encourages the model to:
+    - Increase probability for tokens with positive advantages
+    - Decrease probability for tokens with negative advantages
+    - Stay close to the reference policy (inference-time policy)
+
+    Args:
+        model: The model to compute loss for
+        tokens: Input token IDs [batch, seq_len]
+        labels: Target labels [batch, seq_len], -100 for masked positions
+        advantages: Advantage values [batch, 1]
+        temperatures: Temperature values [batch, 1, 1]
+        gradient_accumulation_steps: Number of accumulation steps (for scaling)
+        inference_logprobs: Logprobs from inference (π_old), aligned with labels [batch, seq_len]
+        kl_coef: KL penalty coefficient (beta). Higher = more conservative updates
+        clip_eps: PPO clipping epsilon. Clips ratio to [1-eps, 1+eps]
+        use_reference_logprobs: If True, use inference_logprobs as reference policy
+
+    Returns:
+        Tuple of (loss tensor, metrics dict)
+    """
+    # Forward pass
+    outputs = model(tokens)
+    logits = outputs.logits
+
+    # Scale logits by the sampling temperature so training logprobs match the
+    # distribution the tokens were sampled from; otherwise the importance ratio is off
+    t = temperatures.to(logits.device, logits.dtype)
+    t = torch.where(t <= 0, torch.ones_like(t), t)
+    scaled_logits = logits / t
+
+    # Log probabilities per token (current policy π)
+    logp_per_token = -F.cross_entropy(
+        scaled_logits.view(-1, scaled_logits.size(-1)),
+        labels.view(-1),
+        reduction="none",
+        ignore_index=-100,
+    ).view(labels.shape)
+
+    # Masking based on labels != -100
+    mask = (labels != -100).float()
+    mask_sum = mask.sum(dim=-1).clamp_min(1e-8)
+
+    # Expand advantages to match token shape [batch, 1] -> [batch, seq_len]
+    adv_expanded = advantages.expand_as(logp_per_token).to(logp_per_token.device)
+
+    # Track logprobs for alignment verification
+    inference_logprobs_flat = None
+    logprob_diff_mean = 0.0
+    logprob_diff_abs_mean = 0.0
+    logprob_diff_max = 0.0
+
+    # === GRPO/PPO Loss Computation ===
+    if use_reference_logprobs and inference_logprobs is not None:
+        # Move inference logprobs to correct device/dtype
+        ref_logprobs = inference_logprobs.to(
+            logp_per_token.device, logp_per_token.dtype
+        )
+
+        # NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated
+        with torch.no_grad():
+            # Only look at generated positions (where mask == 1)
+            ref_at_generated = (ref_logprobs * mask).sum() / mask.sum()
+            train_at_generated = (logp_per_token * mask).sum() / mask.sum()
+
+            # Extract logprobs at generated positions for alignment tracking
+            inference_logprobs_flat = ref_logprobs[mask.bool()].detach()
+            training_at_mask = logp_per_token[mask.bool()].detach()
+
+            # Token-level difference: THE key metric for alignment verification
+            # If weights are truly shared, this
should be ~0 at step start + token_diff = training_at_mask - inference_logprobs_flat + logprob_diff_mean = token_diff.mean().item() + logprob_diff_abs_mean = token_diff.abs().mean().item() + logprob_diff_max = token_diff.abs().max().item() + + # Check if ref logprobs are negative (as they should be for generated tokens) + # If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used + if ref_at_generated > 0.5: + print( + f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)" + ) + print( + " [WARNING] This suggests inference_logprobs alignment is wrong" + ) + elif abs(ref_at_generated - train_at_generated) > 2.0: + print( + f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}" + ) + + # Compute importance sampling ratio: policy(a|s) / policy_old(a|s) = exp(log policy - log policy_old) + log_ratio = logp_per_token - ref_logprobs + ratio = torch.exp(log_ratio) + + # PPO-style clipping + clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) + + # Surrogate objectives + surr1 = ratio * adv_expanded + surr2 = clipped_ratio * adv_expanded + + # Pessimistic bound: min for positive advantages, max for negative + # This is equivalent to: -min(ratio * A, clipped_ratio * A) when A > 0 + # -max(ratio * A, clipped_ratio * A) when A < 0 + policy_loss_per_token = -torch.where( + adv_expanded >= 0, + torch.min(surr1, surr2), + torch.max(surr1, surr2), + ) + + # Average over tokens, then over batch + policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean() + + # KL penalty: encourage staying close to reference policy + # Using Schulman's unbiased KL estimator from the DeepSeek GRPO paper (Equation 4): + # This estimator is guaranteed to be non-negative (unlike squared log-ratio). + if kl_coef > 0: + # Schulman's unbiased KL estimator: (π_ref/π) - log(π_ref/π) - 1 + # = exp(-log_ratio) + log_ratio - 1 + kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0 + kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean() + total_loss = ( + policy_loss + kl_coef * kl_penalty + ) / gradient_accumulation_steps + else: + kl_penalty = torch.tensor(0.0, device=logp_per_token.device) + total_loss = policy_loss / gradient_accumulation_steps + + # Compute metrics for logging + with torch.no_grad(): + # Fraction of tokens where ratio was clipped + clipped_fraction = ( + (ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps) + ).float() + clipped_fraction = (clipped_fraction * mask).sum() / mask.sum() + + # Mean ratio and KL for monitoring (using Schulman's estimator) + mean_ratio = (ratio * mask).sum() / mask.sum() + # Schulman KL: exp(-log_ratio) + log_ratio - 1 + schulman_kl = torch.exp(-log_ratio) + log_ratio - 1.0 + mean_kl = (schulman_kl * mask).sum() / mask.sum() + + # For backward compatibility: collect training logprobs + raw_logp_per_token = -F.cross_entropy( + outputs.logits.view(-1, outputs.logits.size(-1)), + labels.view(-1), + reduction="none", + ignore_index=-100, + ).view(labels.shape) + training_logprobs_flat = raw_logp_per_token[mask.bool()].detach() + else: + # Fail loudly + raise ValueError( + "GRPO requires inference_logprobs for importance sampling!\n" + "\n" + "This error means the environment isn't providing logprobs. To fix:\n" + " 1. Use --openai.server_type vllm (not 'openai')\n" + " 2. Ensure vLLM is returning logprobs in /generate response\n" + " 3. 
Check that gsm8k_server is configured correctly\n" + "\n" + "Without inference logprobs, training will cause reward hacking.\n" + "If you REALLY want vanilla REINFORCE (not recommended), set use_reference_logprobs=False" + ) + + # === Compute Additional Metrics === + with torch.no_grad(): + pos = (advantages > 0).float() + neg = (advantages <= 0).float() + mask_float = mask.to(logp_per_token.dtype) + avg_logp = (logp_per_token * mask_float).sum(dim=-1) / mask_sum + pos_logp = (logp_per_token * pos).mean().item() + neg_logp = (logp_per_token * neg).mean().item() + + # Interpretable metric: advantage-weighted average logprob + interpretable_loss = (avg_logp * advantages.squeeze()).mean().item() + + metrics = { + "pos_logp": pos_logp, + "neg_logp": neg_logp, + "avg_logp": avg_logp, + "pos_count": pos.sum().item(), + "neg_count": neg.sum().item(), + "training_logprobs": training_logprobs_flat, + "inference_logprobs": inference_logprobs_flat, + "interpretable_loss": interpretable_loss, + # GRPO-specific metrics + "kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty, + "mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio, + "mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl, + "clipped_fraction": ( + clipped_fraction.item() + if torch.is_tensor(clipped_fraction) + else clipped_fraction + ), + # Token-level alignment metrics (key for verifying weight sharing) + "logprob_diff_mean": logprob_diff_mean, + "logprob_diff_abs_mean": logprob_diff_abs_mean, + "logprob_diff_max": logprob_diff_max, + } + + return total_loss, metrics + + +def run_training_step( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + token_batches: List[torch.Tensor], + label_batches: List[torch.Tensor], + advantage_batches: List[torch.Tensor], + temperature_batches: List[torch.Tensor], + config: TrainingConfig, + inference_logprob_batches: Optional[List[torch.Tensor]] = None, +) -> dict: + """ + Run a single training step with gradient accumulation. + + Performs: + 1. Forward pass through all micro-batches with proper GRPO loss + 2. Backward pass with gradient accumulation + 3. Gradient clipping + 4. 
Optimizer step + + Args: + model: The model to train + optimizer: The optimizer + token_batches: List of token tensors (micro-batches) + label_batches: List of label tensors + advantage_batches: List of advantage tensors + temperature_batches: List of temperature tensors + config: Training configuration (includes kl_coef, clip_eps, use_reference_logprobs) + inference_logprob_batches: Batched logprobs from inference (π_old), aligned with labels + + Returns: + Dict of training metrics for this step + """ + total_loss = 0.0 + total_pos_logp = 0.0 + total_neg_logp = 0.0 + total_pos = 0.0 + total_neg = 0.0 + total_kl_penalty = 0.0 + total_mean_ratio = 0.0 + total_mean_kl = 0.0 + total_clipped_fraction = 0.0 + total_logprob_diff_mean = 0.0 + total_logprob_diff_abs_mean = 0.0 + total_logprob_diff_max = 0.0 + grad_norm = 0.0 + all_training_logprobs: List[torch.Tensor] = [] + all_inference_logprobs: List[torch.Tensor] = [] + + # Get GRPO hyperparameters from config + kl_coef = getattr(config, "kl_coef", 0.1) + clip_eps = getattr(config, "clip_eps", 0.2) + use_reference_logprobs = getattr(config, "use_reference_logprobs", True) + + # Accumulate gradients over micro-batches + num_batches = len(token_batches) if token_batches else 1 + + for batch_idx, (tokens, labels, advantages, temperatures) in enumerate( + zip(token_batches, label_batches, advantage_batches, temperature_batches) + ): + tokens = tokens.to(config.device) + labels = labels.to(config.device) + advantages = advantages.to(config.device) + + # Get corresponding inference logprobs batch if available + inf_logprobs = None + if inference_logprob_batches is not None and batch_idx < len( + inference_logprob_batches + ): + inf_logprobs = inference_logprob_batches[batch_idx] + + loss, metrics = compute_grpo_loss( + model, + tokens, + labels, + advantages, + temperatures, + config.gradient_accumulation_steps, + inference_logprobs=inf_logprobs, + kl_coef=kl_coef, + clip_eps=clip_eps, + use_reference_logprobs=use_reference_logprobs, + ) + + loss.backward() + total_loss += loss.item() + total_pos_logp += metrics["pos_logp"] + total_neg_logp += metrics["neg_logp"] + total_pos += metrics["pos_count"] + total_neg += metrics["neg_count"] + + # Accumulate GRPO-specific metrics + total_kl_penalty += metrics.get("kl_penalty", 0.0) + total_mean_ratio += metrics.get("mean_ratio", 1.0) + total_mean_kl += metrics.get("mean_kl", 0.0) + total_clipped_fraction += metrics.get("clipped_fraction", 0.0) + + # Accumulate token-level alignment metrics + total_logprob_diff_mean += metrics.get("logprob_diff_mean", 0.0) + total_logprob_diff_abs_mean += metrics.get("logprob_diff_abs_mean", 0.0) + total_logprob_diff_max = max( + total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0) + ) + + # Collect logprobs for alignment monitoring + if "training_logprobs" in metrics and metrics["training_logprobs"] is not None: + all_training_logprobs.append(metrics["training_logprobs"]) + if ( + "inference_logprobs" in metrics + and metrics["inference_logprobs"] is not None + ): + all_inference_logprobs.append(metrics["inference_logprobs"]) + + # Gradient clipping and optimizer step + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + optimizer.zero_grad() + + # Help prevent memory fragmentation + torch.cuda.empty_cache() + + # Normalize metrics by batch count + if total_pos > 0: + total_pos_logp /= num_batches + if total_neg > 0: + total_neg_logp /= num_batches + + result = { + "loss": total_loss, + "grad_norm": grad_norm.item() if 
hasattr(grad_norm, "item") else grad_norm, + "pos_logp": total_pos_logp, + "neg_logp": total_neg_logp, + "pos_count": total_pos, + "neg_count": total_neg, + # GRPO-specific metrics (averaged over batches) + "kl_penalty": total_kl_penalty / num_batches, + "mean_ratio": total_mean_ratio / num_batches, + "mean_kl": total_mean_kl / num_batches, + "clipped_fraction": total_clipped_fraction / num_batches, + } + + # Compute logprob alignment stats for monitoring + # This proves weight sharing is working: inference & training logprobs should converge + if all_training_logprobs: + train_flat = torch.cat(all_training_logprobs) + if train_flat.numel() > 0: + _logprob_alignment_stats["logprobs/training_mean"] = ( + train_flat.mean().item() + ) + _logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item() + + if all_inference_logprobs: + inf_flat = torch.cat(all_inference_logprobs) + if inf_flat.numel() > 0: + _logprob_alignment_stats["logprobs/inference_mean"] = inf_flat.mean().item() + _logprob_alignment_stats["logprobs/inference_std"] = inf_flat.std().item() + + # Token-level alignment metrics - THE key metric for verifying weight sharing + # diff_abs_mean close to 0 = weights are truly shared + _logprob_alignment_stats["alignment/diff_mean"] = ( + total_logprob_diff_mean / num_batches + ) + _logprob_alignment_stats["alignment/diff_abs_mean"] = ( + total_logprob_diff_abs_mean / num_batches + ) + _logprob_alignment_stats["alignment/diff_max"] = total_logprob_diff_max + + return result + + +def log_metrics( + metrics: dict, + step: int, + use_wandb: bool, + extra_metrics: Optional[dict] = None, + benchmark: bool = False, +) -> None: + """ + Log training metrics to console and optionally wandb. + + Args: + metrics: Dict of metrics from training step + step: Current step number + use_wandb: Whether to log to wandb + extra_metrics: Optional additional metrics to log + benchmark: Whether to show timing/benchmark info + """ + # Build timing string (only if benchmark enabled) + timing_str = "" + if benchmark: + if "step_time" in metrics: + timing_str += f", Step time: {metrics['step_time']:.2f}s" + if "sync_time" in metrics and metrics["sync_time"] > 0: + timing_str += f", Sync time: {metrics['sync_time']:.2f}s" + if "data_fetch_time" in metrics: + timing_str += f", Data fetch: {metrics['data_fetch_time']:.2f}s" + if "gpu_memory_gb" in metrics: + timing_str += f", GPU mem: {metrics['gpu_memory_gb']:.2f}GB" + + # Primary metrics line: Loss and grad norm + loss_str = ( + f"{metrics['loss']:.6f}" + if abs(metrics["loss"]) < 0.01 + else f"{metrics['loss']:.4f}" + ) + print(f" Loss: {loss_str}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}") + + # GRPO metrics line: KL, ratio, clipping + kl_penalty = metrics.get("kl_penalty", 0) + mean_ratio = metrics.get("mean_ratio", 1.0) + mean_kl = metrics.get("mean_kl", 0) + clipped_frac = metrics.get("clipped_fraction", 0) + + if kl_penalty > 0 or mean_kl > 0: + print( + f" GRPO: KL={mean_kl:.4f}, ratio={mean_ratio:.3f}, " + f"clipped={clipped_frac*100:.1f}%" + ) + + # Advantage distribution + if "pos_count" in metrics or "neg_count" in metrics: + pos_count = metrics.get("pos_count", 0) + neg_count = metrics.get("neg_count", 0) + pos_logp = metrics.get("pos_logp", 0) + neg_logp = metrics.get("neg_logp", 0) + print( + f" Advantages: +{int(pos_count)} / -{int(neg_count)}, " + f"LogP: pos={pos_logp:.3f}, neg={neg_logp:.3f}" + ) + + if use_wandb: + log_dict = { + "train/loss": metrics["loss"], + "train/grad_norm": metrics["grad_norm"], + "train/pos_logp": 
metrics.get("pos_logp", 0), + "train/neg_logp": metrics.get("neg_logp", 0), + # GRPO-specific metrics + "grpo/kl_penalty": kl_penalty, + "grpo/mean_ratio": mean_ratio, + "grpo/mean_kl": mean_kl, + "grpo/clipped_fraction": clipped_frac, + } + # Add timing metrics if present + for key in [ + "step_time", + "sync_time", + "data_fetch_time", + "gpu_memory_gb", + "gpu_memory_reserved_gb", + ]: + if key in metrics: + log_dict[f"train/{key}"] = metrics[key] + + # Add logprob alignment stats + if _logprob_alignment_stats: + log_dict.update(_logprob_alignment_stats) + + if extra_metrics: + log_dict.update(extra_metrics) + wandb.log(log_dict, step=step) + + +def finalize_training( + use_wandb: bool, + training_start_time: Optional[float] = None, + mode: str = "unknown", + total_steps: int = 0, + benchmark_stats: Optional[dict] = None, + benchmark: bool = False, +) -> None: + """ + Clean up after training and log benchmark summary. + + Args: + use_wandb: Whether wandb is enabled + training_start_time: Start time of training + mode: Training mode name + total_steps: Total steps completed + benchmark_stats: Dict with lists of per-step metrics + benchmark: Whether to print benchmark summary to console + """ + print("\nTraining finished.") + + if benchmark_stats is None: + benchmark_stats = {} + + if training_start_time is not None: + total_time = time.time() - training_start_time + peak_gpu_mem_gb = ( + torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0 + ) + + # Calculate averages from collected stats + step_times = benchmark_stats.get("step_times", []) + sync_times = benchmark_stats.get("sync_times", []) + data_fetch_times = benchmark_stats.get("data_fetch_times", []) + gpu_memories = benchmark_stats.get("gpu_memories", []) + + avg_step_time = sum(step_times) / len(step_times) if step_times else 0 + total_step_time = sum(step_times) + avg_sync_time = sum(sync_times) / len(sync_times) if sync_times else 0 + total_sync_time = sum(sync_times) + avg_data_fetch = ( + sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0 + ) + total_data_fetch = sum(data_fetch_times) + avg_gpu_mem = sum(gpu_memories) / len(gpu_memories) if gpu_memories else 0 + + if benchmark: + print(f"\n{'='*70}") + print(f"BENCHMARK SUMMARY ({mode})") + print(f"{'='*70}") + print( + f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)" + ) + print(f" Total steps: {total_steps}") + print(" ") + print(" TIMING BREAKDOWN:") + print(f" Avg step time: {avg_step_time:.2f}s") + print(f" Total step time: {total_step_time:.2f}s") + print( + f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)" + ) + print(f" Total sync time: {total_sync_time:.2f}s") + print(f" Avg data fetch time: {avg_data_fetch:.2f}s") + print(f" Total data fetch time: {total_data_fetch:.2f}s") + print(" ") + print(" MEMORY:") + print(f" Peak GPU memory: {peak_gpu_mem_gb:.2f} GB") + print(f" Avg GPU memory: {avg_gpu_mem:.2f} GB") + print(f"{'='*70}\n") + + if use_wandb: + wandb.summary["benchmark/total_time_seconds"] = total_time + wandb.summary["benchmark/total_time_minutes"] = total_time / 60 + wandb.summary["benchmark/mode"] = mode + wandb.summary["benchmark/total_steps"] = total_steps + wandb.summary["benchmark/avg_step_time_seconds"] = avg_step_time + wandb.summary["benchmark/peak_gpu_memory_gb"] = peak_gpu_mem_gb + wandb.summary["benchmark/avg_gpu_memory_gb"] = avg_gpu_mem + wandb.finish() + elif use_wandb: + wandb.finish() diff --git a/example_trainer/vllm_api_server.py 
b/example_trainer/vllm_api_server.py index 920c71db..2846f14f 100644 --- a/example_trainer/vllm_api_server.py +++ b/example_trainer/vllm_api_server.py @@ -1,45 +1,246 @@ -# Based on https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/api_server.py -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +#!/usr/bin/env python3 """ -NOTE: This API server is used only for demonstrating usage of AsyncEngine -and simple performance benchmarks. It is not intended for production use. -For production use, we recommend using our OpenAI compatible server. -We are also not going to accept PRs modifying this file, please -change `vllm/entrypoints/openai/api_server.py` instead. +Custom vLLM API server with CUDA IPC shared memory support. + +This server extends the standard vLLM API with: +- Single-copy mode: Exports CUDA IPC handles so trainer can share vLLM's tensors +- LoRA hot-swap without server restart +- Bridge endpoints for coordination + +ARCHITECTURE (Single-Copy Mode): + When VLLM_ENABLE_SHARED_WEIGHTS=1: + 1. vLLM's GPUModelRunner is patched BEFORE loading + 2. Patched runner exports CUDA IPC handles to vllm_bridge_config.json + 3. Trainer reads IPC handles and attaches to the SAME tensors + 4. optimizer.step() updates weights in-place - vLLM sees changes immediately! + + ┌─────────────────────────────────────────────────────────────────────────┐ + │ SINGLE GPU (True Shared Memory) │ + │ ┌─────────────────────────────────────────────────────────────────┐ │ + │ │ Model Weights (ONE copy!) │ │ + │ │ (accessible via CUDA IPC handles) │ │ + │ └─────────────────────────────────────────────────────────────────┘ │ + │ ▲ ▲ │ + │ │ Reads (inference) │ Writes (train) │ + │ ┌────────┴────────┐ ┌───────────┴───────────┐ │ + │ │ vLLM Worker │ │ Trainer Process │ │ + │ │ │ │ (attached via IPC) │ │ + │ └─────────────────┘ └───────────────────────┘ │ + └─────────────────────────────────────────────────────────────────────────┘ + +CRITICAL: Patches must be applied BEFORE importing vLLM! """ +# ============================================================================= +# STEP 0: Standard library imports ONLY (no vLLM yet!) 
+# ============================================================================= import asyncio import json +import multiprocessing +import os import ssl +import threading from argparse import Namespace from collections.abc import AsyncGenerator -from typing import Any +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, List, Optional -import vllm.envs as envs -from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse, Response, StreamingResponse -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.entrypoints.launcher import serve_http -from vllm.entrypoints.utils import with_cancellation -from vllm.logger import init_logger -from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.usage.usage_lib import UsageContext -from vllm.utils import random_uuid -from vllm.v1.engine.async_llm import AsyncLLMEngine +# Default to v0 engine to avoid CUDA fork issues with v1 engine +# Users can override with VLLM_USE_V1=1 if needed +os.environ.setdefault("VLLM_USE_V1", "0") +# Set spawn method for multiprocessing (required for CUDA) +os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") try: - from vllm.utils.argparse_utils import FlexibleArgumentParser - from vllm.utils.system_utils import set_ulimit + multiprocessing.set_start_method("spawn", force=True) +except RuntimeError: + pass # Already set + +# ============================================================================= +# STEP 1: Apply patches BEFORE any vLLM imports! +# ============================================================================= + + +def _apply_patches_early() -> bool: + """ + Apply vLLM patches if shared weights are enabled. + + This MUST be called before any vLLM imports! + Returns True if patches were applied. + """ + enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1" + num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", "-1")) + + if not enable_shared and num_inference_nodes < 0: + print("[vLLM Server] Shared weights not enabled, skipping patches") + return False + + print("[vLLM Server] VLLM_ENABLE_SHARED_WEIGHTS=1, applying patches...") + + try: + # Try relative import first (when run as module) + from .vllm_patching import apply_patches + except ImportError: + # Fall back to absolute import (when run as script) + try: + import sys + from pathlib import Path + + # Add parent directory to path so we can import vllm_patching + script_dir = Path(__file__).parent + if str(script_dir) not in sys.path: + sys.path.insert(0, str(script_dir)) + from vllm_patching import apply_patches + except ImportError as e: + print(f"[vLLM Server] Could not import vllm_patching: {e}") + print("[vLLM Server] Shared memory weight updates will not be available") + return False + + try: + success = apply_patches() + if success: + print("[vLLM Server] ✓ vLLM patches applied successfully!") + else: + print("[vLLM Server] ✗ Failed to apply patches") + return success + except Exception as e: + print(f"[vLLM Server] Error applying patches: {e}") + import traceback + + traceback.print_exc() + return False + + +# Apply patches NOW, before any vLLM imports below! 
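+# PATCHES_APPLIED is evaluated here at module import time, i.e. before the
+# STEP 2 vLLM imports below bind references to the unpatched classes; its
+# value is later reported by the /bridge/info and /bridge/debug endpoints.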
+PATCHES_APPLIED = _apply_patches_early() + + +# ============================================================================= +# STEP 2: Now safe to import vLLM (patches are already in place) +# ============================================================================= + +import torch # noqa: E402 +import vllm.envs as envs # noqa: E402 +from fastapi import FastAPI, HTTPException, Request # noqa: E402 +from fastapi.responses import JSONResponse, Response, StreamingResponse # noqa: E402 +from pydantic import BaseModel # noqa: E402 +from vllm.engine.arg_utils import AsyncEngineArgs # noqa: E402 +from vllm.entrypoints.launcher import serve_http # noqa: E402 +from vllm.entrypoints.utils import with_cancellation # noqa: E402 +from vllm.logger import init_logger # noqa: E402 +from vllm.sampling_params import RequestOutputKind, SamplingParams # noqa: E402 +from vllm.usage.usage_lib import UsageContext # noqa: E402 +from vllm.utils import random_uuid # noqa: E402 +from vllm.v1.engine.async_llm import AsyncLLM # noqa: E402 + +# Handle vLLM version differences - FlexibleArgumentParser was removed/renamed +try: + from vllm.utils import FlexibleArgumentParser except ImportError: - from vllm.utils import FlexibleArgumentParser, set_ulimit -from vllm.outputs import RequestOutput # noqa: F401 -from vllm.version import __version__ as VLLM_VERSION + # Create a compatible ArgumentParser that handles 'deprecated' kwarg + # (Python 3.10 doesn't support 'deprecated' in BooleanOptionalAction) + import argparse + + class FlexibleArgumentParser(argparse.ArgumentParser): + """ArgumentParser that strips unsupported kwargs for Python < 3.13.""" + + def add_argument(self, *args, **kwargs): + # Remove 'deprecated' kwarg if present (not supported before Python 3.13) + kwargs.pop("deprecated", None) + return super().add_argument(*args, **kwargs) + + +# set_ulimit might not exist in all vLLM versions +try: + from vllm.utils import set_ulimit +except ImportError: + + def set_ulimit() -> None: + """No-op fallback for set_ulimit.""" + pass + + +from vllm.outputs import RequestOutput # noqa: F401, E402 +from vllm.version import __version__ as VLLM_VERSION # noqa: E402 + +# Try to import LoRARequest for adapter support +try: + from vllm.lora.request import LoRARequest # noqa: E402 + + LORA_AVAILABLE = True +except ImportError: + LORA_AVAILABLE = False + LoRARequest = None # type: ignore logger = init_logger("vllm.entrypoints.api_server") app = FastAPI() -engine = None +engine: Optional[AsyncLLM] = None + + +@dataclass +class BridgeState: + """State for shared memory and LoRA.""" + + update_count: int = 0 + last_update_time: float = 0.0 + lock: threading.Lock = field(default_factory=threading.Lock) + + # LoRA state + active_lora_path: Optional[str] = None + active_lora_name: Optional[str] = None + active_lora_id: int = 0 # vLLM requires unique integer ID per adapter + lora_load_count: int = 0 + + +bridge_state = BridgeState() + + +def _get_lora_request() -> Optional["LoRARequest"]: + """Get the current LoRA request if an adapter is active.""" + if not LORA_AVAILABLE: + return None + if bridge_state.active_lora_path is None: + return None + + return LoRARequest( + lora_name=bridge_state.active_lora_name or "default_adapter", + lora_int_id=bridge_state.active_lora_id, + lora_path=bridge_state.active_lora_path, + ) + + +# ============================================================================= +# Pydantic Models for API +# ============================================================================= + + +class 
BridgeInfoResponse(BaseModel): + enabled: bool + update_count: int + last_update_time: float + model_name: str + device: str + + +class LoraLoadRequest(BaseModel): + adapter_path: str + adapter_name: Optional[str] = None + + +class LoraStatusResponse(BaseModel): + lora_available: bool + active_adapter_path: Optional[str] + active_adapter_name: Optional[str] + active_adapter_id: Optional[int] + load_count: int + available_adapters: List[str] + + +# ============================================================================= +# Health Endpoints +# ============================================================================= @app.get("/health") @@ -50,31 +251,40 @@ async def health() -> Response: @app.get("/health_generate") async def health_generate() -> Response: - """ - Check the health of the inference server by sending a special request to generate one token. - """ - assert engine is not None + """Health check that verifies model can generate.""" + if engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + sampling_params = SamplingParams() request_id = random_uuid() - results_generator = engine.generate( - {"prompt_token_ids": [0]}, sampling_params, request_id - ) + try: - async for request_output in results_generator: - final_output = request_output # type: RequestOutput # noqa: F841 + results_generator = engine.generate( + {"prompt_token_ids": [0]}, sampling_params, request_id + ) + async for _ in results_generator: + pass + return Response(status_code=200) except asyncio.CancelledError: return Response(status_code=499) - return Response(status_code=200) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +# ============================================================================= +# Generation Endpoints +# ============================================================================= @app.post("/generate") async def generate(request: Request) -> Response: - """Generate completion for the request. + """ + Generate completion for the request. - The request should be a JSON object with the following fields: - - prompt: the prompt to use for the generation. - - stream: whether to stream the results or not. - - other fields: the sampling parameters (See `SamplingParams` for details). 
+ The request should be a JSON object with: + - prompt: the prompt to use for generation + - stream: whether to stream results + - other fields: sampling parameters """ request_dict = await request.json() return await _generate(request_dict, raw_request=request) @@ -82,16 +292,23 @@ async def generate(request: Request) -> Response: @with_cancellation async def _generate(request_dict: dict, raw_request: Request) -> Response: + """Internal generate handler.""" + if engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + prompt = request_dict.pop("prompt") stream = request_dict.pop("stream", False) request_dict["output_kind"] = RequestOutputKind.FINAL_ONLY sampling_params = SamplingParams(**request_dict) request_id = random_uuid() - assert engine is not None - results_generator = engine.generate(prompt, sampling_params, request_id) + # Get active LoRA adapter if any + lora_request = _get_lora_request() + + results_generator = engine.generate( + prompt, sampling_params, request_id, lora_request=lora_request + ) - # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: async for request_output in results_generator: prompt = request_output.prompt @@ -103,11 +320,10 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: if stream: return StreamingResponse(stream_results()) - # Non-streaming case final_output = None try: async for request_output in results_generator: - final_output = request_output # type: RequestOutput + final_output = request_output except asyncio.CancelledError: return Response(status_code=499) @@ -115,10 +331,11 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: prompt = final_output.prompt or engine.tokenizer.decode( final_output.prompt_token_ids ) - assert prompt is not None + text_outputs = [output.text for output in final_output.outputs] finish_reasons = [output.finish_reason for output in final_output.outputs] ret = {"text": text_outputs, "prompt": prompt, "finish_reasons": finish_reasons} + if sampling_params.logprobs is not None: output_logprobs = [ [ @@ -127,63 +344,424 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: ] for x in final_output.outputs ] - prompt_token_ids = final_output.prompt_token_ids - output_token_ids = [x.token_ids for x in final_output.outputs] ret["logprobs"] = output_logprobs - ret["prompt_token_ids"] = prompt_token_ids - ret["token_ids"] = output_token_ids + ret["prompt_token_ids"] = final_output.prompt_token_ids + ret["token_ids"] = [x.token_ids for x in final_output.outputs] + return JSONResponse(ret) -def build_app(args: Namespace) -> FastAPI: - global app # noqa: F824 +# ============================================================================= +# Bridge Endpoints (Weight Synchronization) +# ============================================================================= + +@app.get("/bridge/info") +async def bridge_info() -> JSONResponse: + """Get bridge status and configuration.""" + if engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + + model_name = ( + str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown" + ) + + return JSONResponse( + { + "enabled": PATCHES_APPLIED, + "shared_weights": PATCHES_APPLIED, + "update_count": bridge_state.update_count, + "last_update_time": bridge_state.last_update_time, + "model_name": model_name, + "device": "cuda" if torch.cuda.is_available() else "cpu", + } + ) + + +@app.get("/bridge/state_dict_info") +async 
def bridge_state_dict_info() -> JSONResponse: + """Get model parameter information.""" + if engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + + # Basic model info + try: + model_config = engine.model_config + return JSONResponse( + { + "model": str(model_config.model), + "dtype": str(model_config.dtype), + "shared_weights_enabled": PATCHES_APPLIED, + } + ) + except Exception as e: + return JSONResponse({"error": str(e)}) + + +# ============================================================================= +# Pause/Resume Endpoints +# ============================================================================= + + +@app.post("/bridge/pause") +async def bridge_pause() -> JSONResponse: + """Pause generation to allow weight updates.""" + if engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + + try: + # vLLM v1 supports pause/resume + if hasattr(engine, "_pause_cond"): + async with engine._pause_cond: + engine._paused = True + logger.info("Engine paused") + return JSONResponse({"status": "paused"}) + else: + return JSONResponse({"status": "not_supported"}) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/bridge/resume") +async def bridge_resume() -> JSONResponse: + """Resume generation after weight updates.""" + if engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + + try: + if hasattr(engine, "_pause_cond"): + async with engine._pause_cond: + engine._paused = False + engine._pause_cond.notify_all() + logger.info("Engine resumed") + return JSONResponse({"status": "resumed"}) + else: + return JSONResponse({"status": "not_supported"}) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/bridge/is_paused") +async def bridge_is_paused() -> JSONResponse: + """Check if engine is paused.""" + if engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + + paused = getattr(engine, "_paused", False) + return JSONResponse({"paused": paused}) + + +# ============================================================================= +# Sleep/Wake Endpoints (GPU memory management) +# ============================================================================= + + +@app.post("/bridge/sleep") +async def bridge_sleep() -> JSONResponse: + """Put engine to sleep to free GPU memory.""" + if engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + + try: + await engine.sleep() + logger.info("Engine sleeping") + return JSONResponse({"status": "sleeping"}) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/bridge/wake_up") +async def bridge_wake_up() -> JSONResponse: + """Wake engine and reload model.""" + if engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + + try: + await engine.wake_up() + logger.info("Engine woken up") + return JSONResponse({"status": "awake"}) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/bridge/is_sleeping") +async def bridge_is_sleeping() -> JSONResponse: + """Check if engine is sleeping.""" + if engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + + sleeping = await engine.is_sleeping() + return JSONResponse({"sleeping": sleeping}) + + +# ============================================================================= +# Debug Endpoints +# 
============================================================================= + + +@app.get("/bridge/debug") +async def bridge_debug() -> JSONResponse: + """Debug endpoint to inspect engine state.""" + debug_info = { + "engine_type": type(engine).__name__ if engine else None, + "vllm_version": VLLM_VERSION, + "patches_applied": PATCHES_APPLIED, + "shared_weights_env": os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0"), + "num_inference_nodes": os.environ.get("NUM_INFERENCE_NODES", "unset"), + "logdir": os.environ.get("LOGDIR", "unset"), + } + + if engine is not None: + try: + debug_info["model_config"] = { + "model": str(engine.model_config.model), + "dtype": str(engine.model_config.dtype), + } + except Exception: + pass + + return JSONResponse(debug_info) + + +@app.get("/bridge/list_endpoints") +async def list_endpoints() -> JSONResponse: + """List all available endpoints.""" + endpoints = [] + for route in app.routes: + if hasattr(route, "path") and hasattr(route, "methods"): + endpoints.append( + { + "path": route.path, + "methods": list(route.methods), + } + ) + return JSONResponse({"endpoints": endpoints}) + + +# ============================================================================= +# LoRA Endpoints +# ============================================================================= + + +@app.get("/lora/status") +async def lora_status() -> LoraStatusResponse: + """Get LoRA adapter status.""" + log_dir = os.environ.get("LOGDIR", ".") + available = [] + + if os.path.exists(log_dir): + for item in os.listdir(log_dir): + item_path = os.path.join(log_dir, item) + if os.path.isdir(item_path) and os.path.exists( + os.path.join(item_path, "adapter_config.json") + ): + available.append(item) + + return LoraStatusResponse( + lora_available=LORA_AVAILABLE, + active_adapter_path=bridge_state.active_lora_path, + active_adapter_name=bridge_state.active_lora_name, + active_adapter_id=( + bridge_state.active_lora_id if bridge_state.active_lora_path else None + ), + load_count=bridge_state.lora_load_count, + available_adapters=available, + ) + + +@app.post("/lora/load") +async def lora_load(request: LoraLoadRequest) -> JSONResponse: + """Load a LoRA adapter.""" + if not os.path.exists(request.adapter_path): + raise HTTPException( + status_code=404, detail=f"Adapter not found: {request.adapter_path}" + ) + + # Read adapter config to validate and log details + adapter_config_path = os.path.join(request.adapter_path, "adapter_config.json") + adapter_info = {} + + if os.path.exists(adapter_config_path): + try: + with open(adapter_config_path, "r") as f: + adapter_config = json.load(f) + adapter_info = { + "r": adapter_config.get("r"), + "lora_alpha": adapter_config.get("lora_alpha"), + "target_modules": adapter_config.get("target_modules"), + "base_model": adapter_config.get("base_model_name_or_path"), + } + logger.info(f"LoRA adapter config: {adapter_info}") + except Exception as e: + logger.warning(f"Could not read adapter_config.json: {e}") + else: + logger.warning(f"No adapter_config.json found at {adapter_config_path}") + + with bridge_state.lock: + bridge_state.active_lora_path = request.adapter_path + bridge_state.active_lora_name = ( + request.adapter_name or f"adapter_{bridge_state.lora_load_count}" + ) + bridge_state.active_lora_id = ( + bridge_state.lora_load_count + 1 + ) # vLLM needs unique int ID + bridge_state.lora_load_count += 1 + + logger.info( + f"LoRA adapter loaded: {request.adapter_path} (id={bridge_state.active_lora_id})" + ) + + return JSONResponse( + { + "status": "ok", + 
"adapter_path": request.adapter_path, + "adapter_name": bridge_state.active_lora_name, + "adapter_id": bridge_state.active_lora_id, + "load_count": bridge_state.lora_load_count, + "adapter_config": adapter_info, + } + ) + + +@app.post("/lora/unload") +async def lora_unload() -> JSONResponse: + """Unload current LoRA adapter.""" + with bridge_state.lock: + prev_path = bridge_state.active_lora_path + prev_name = bridge_state.active_lora_name + bridge_state.active_lora_path = None + bridge_state.active_lora_name = None + bridge_state.active_lora_id = 0 + + logger.info(f"LoRA adapter unloaded: {prev_path} ({prev_name})") + return JSONResponse( + { + "status": "ok", + "previous_adapter": prev_path, + "previous_name": prev_name, + } + ) + + +# ============================================================================= +# Server Setup +# ============================================================================= + + +def build_app(args: Namespace) -> FastAPI: + """Build the FastAPI application.""" app.root_path = args.root_path return app -async def init_app( - args: Namespace, - llm_engine: AsyncLLMEngine | None = None, -) -> FastAPI: +async def init_app(args: Namespace, llm_engine: AsyncLLM | None = None) -> FastAPI: + """Initialize the application and vLLM engine.""" app = build_app(args) global engine - engine_args = AsyncEngineArgs.from_cli_args(args) engine = ( llm_engine if llm_engine is not None - else AsyncLLMEngine.from_engine_args( + else AsyncLLM.from_engine_args( engine_args, usage_context=UsageContext.API_SERVER ) ) app.state.engine_client = engine + + # Export basic state dict info for trainers (the patched runner exports detailed info) + _export_state_dict_info(args) + return app +def _export_state_dict_info(args: Namespace) -> None: + """Export basic model info to JSON for trainer (backup if patches don't run).""" + # Allow explicit config path via env var, otherwise use LOGDIR + config_path = os.environ.get("VLLM_BRIDGE_CONFIG_PATH") + if config_path: + json_path = Path(config_path) + json_path.parent.mkdir(parents=True, exist_ok=True) + else: + log_dir = os.environ.get("LOGDIR", ".") + Path(log_dir).mkdir(parents=True, exist_ok=True) + json_path = Path(log_dir) / "vllm_bridge_config.json" + + # Only write basic info if the file doesn't exist or is empty + # The patched runner will write complete info with param_mappings + try: + if json_path.exists(): + with open(json_path, "r") as f: + existing = json.load(f) + if ( + existing.get("param_mappings") + and len(existing["param_mappings"]) > 0 + ): + logger.info("Config already has param_mappings, not overwriting") + return + + info = { + "model": getattr(args, "model", "unknown"), + "dtype": getattr(args, "dtype", "auto"), + "tp_degree": getattr(args, "tensor_parallel_size", 1), + "dp_shard_degree": 1, + "param_mappings": {}, + "shared_weights_enabled": PATCHES_APPLIED, + } + + with open(json_path, "w") as f: + json.dump(info, f, indent=2) + + logger.info(f"Exported basic state dict info to {json_path}") + except Exception as e: + logger.warning(f"Failed to export state dict info: {e}") + + async def run_server( - args: Namespace, llm_engine: AsyncLLMEngine | None = None, **uvicorn_kwargs: Any + args: Namespace, llm_engine: AsyncLLM | None = None, **uvicorn_kwargs: Any ) -> None: + """Run the vLLM API server.""" logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) + if PATCHES_APPLIED: + logger.info("=" * 60) + logger.info("SHARED MEMORY MODE ENABLED") + logger.info("Weight updates from 
trainer will be reflected immediately!") + logger.info("=" * 60) + set_ulimit() app = await init_app(args, llm_engine) - assert engine is not None + + if engine is None: + raise RuntimeError("No engine initialized") + + # Log available endpoints + logger.info("=" * 60) + logger.info("Streamlined vLLM Server - Training-Focused API") + logger.info("Available endpoints:") + logger.info(" POST /generate - Generate with logprobs (primary endpoint)") + logger.info(" GET /health - Health check") + logger.info(" GET /bridge/info - Bridge status") + logger.info(" POST /bridge/pause - Pause generation") + logger.info(" POST /bridge/resume - Resume generation") + logger.info(" GET /lora/status - LoRA adapter status") + logger.info(" POST /lora/load - Load LoRA adapter") + logger.info(" POST /lora/unload - Unload LoRA adapter") + logger.info("=" * 60) shutdown_task = await serve_http( app, sock=None, - enable_ssl_refresh=args.enable_ssl_refresh, + enable_ssl_refresh=getattr(args, "enable_ssl_refresh", False), host=args.host, port=args.port, - log_level=args.log_level, + log_level=getattr(args, "log_level", "info"), timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, - ssl_ca_certs=args.ssl_ca_certs, - ssl_cert_reqs=args.ssl_cert_reqs, + ssl_keyfile=getattr(args, "ssl_keyfile", None), + ssl_certfile=getattr(args, "ssl_certfile", None), + ssl_ca_certs=getattr(args, "ssl_ca_certs", None), + ssl_cert_reqs=getattr(args, "ssl_cert_reqs", ssl.CERT_NONE), **uvicorn_kwargs, ) @@ -193,32 +771,17 @@ async def run_server( if __name__ == "__main__": parser = FlexibleArgumentParser() parser.add_argument("--host", type=str, default=None) - parser.add_argument("--port", type=parser.check_port, default=8000) + parser.add_argument("--port", type=int, default=9001) parser.add_argument("--ssl-keyfile", type=str, default=None) parser.add_argument("--ssl-certfile", type=str, default=None) - parser.add_argument( - "--ssl-ca-certs", type=str, default=None, help="The CA certificates file" - ) - parser.add_argument( - "--enable-ssl-refresh", - action="store_true", - default=False, - help="Refresh SSL Context when SSL certificate files change", - ) - parser.add_argument( - "--ssl-cert-reqs", - type=int, - default=int(ssl.CERT_NONE), - help="Whether client certificate is required (see stdlib ssl module's)", - ) - parser.add_argument( - "--root-path", - type=str, - default=None, - help="FastAPI root_path when app is behind a path based routing proxy", - ) - parser.add_argument("--log-level", type=str, default="debug") - parser = AsyncEngineArgs.add_cli_args(parser) - args = parser.parse_args() + parser.add_argument("--ssl-ca-certs", type=str, default=None) + parser.add_argument("--enable-ssl-refresh", action="store_true", default=False) + parser.add_argument("--ssl-cert-reqs", type=int, default=int(ssl.CERT_NONE)) + parser.add_argument("--root-path", type=str, default=None) + parser.add_argument("--log-level", type=str, default="info") + # Add vLLM engine args + parser = AsyncEngineArgs.add_cli_args(parser) + + args = parser.parse_args() asyncio.run(run_server(args)) diff --git a/example_trainer/vllm_manager.py b/example_trainer/vllm_manager.py new file mode 100644 index 00000000..7f9f20fe --- /dev/null +++ b/example_trainer/vllm_manager.py @@ -0,0 +1,315 @@ +""" +vLLM process management for GRPO trainer. + +Handles launching, monitoring, and terminating vLLM server processes +for legacy mode training. 
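+
+A minimal usage sketch (assumes the package is importable as example_trainer
+and that TrainingConfig provides vllm_port and model_name; the checkpoint
+path is illustrative):
+
+    from example_trainer.config import TrainingConfig
+    from example_trainer.vllm_manager import (
+        launch_vllm_server,
+        terminate_vllm_process,
+        wait_for_vllm_ready,
+    )
+
+    config = TrainingConfig()  # or however the trainer builds its config
+    proc = launch_vllm_server(config, model_path="/checkpoints/step_100")
+    if proc is not None and wait_for_vllm_ready(config.vllm_port):
+        ...  # run rollouts / training against the server
+    terminate_vllm_process()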
+""" + +import atexit +import os +import signal +import socket +import subprocess +import time +from typing import Optional + +import requests + +from .config import TrainingConfig + +# Global variable to keep track of the vLLM process +_vllm_process: Optional[subprocess.Popen] = None + + +def is_port_in_use(port: int) -> bool: + """Check if a port is already in use.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 + + +def kill_process_on_port(port: int, timeout: float = 5.0) -> bool: + """ + Kill any process using the specified port. + + Returns True if no process was running or if it was successfully killed. + """ + if not is_port_in_use(port): + return True + + print(f" Port {port} is in use, attempting to kill existing process...") + + try: + # Try to find and kill the process using lsof (Linux/Mac) + result = subprocess.run( + ["lsof", "-t", "-i", f":{port}"], capture_output=True, text=True, timeout=5 + ) + if result.stdout.strip(): + pids = result.stdout.strip().split("\n") + print(f" Killing {len(pids)} processes on port {port}...") + for pid in pids: + try: + os.kill(int(pid), signal.SIGTERM) + except (ProcessLookupError, ValueError): + pass + + # Wait for port to be free + start = time.time() + while time.time() - start < timeout: + if not is_port_in_use(port): + print(f" Port {port} is now free") + return True + time.sleep(0.5) + + # Force kill if still running + killed_count = 0 + for pid in pids: + try: + os.kill(int(pid), signal.SIGKILL) + killed_count += 1 + except (ProcessLookupError, ValueError): + pass + if killed_count > 0: + print(f" Force killed {killed_count} stubborn processes") + + time.sleep(1) + return not is_port_in_use(port) + except FileNotFoundError: + # lsof not available, try fuser (Linux) + try: + subprocess.run(["fuser", "-k", f"{port}/tcp"], timeout=5) + time.sleep(1) + return not is_port_in_use(port) + except (FileNotFoundError, subprocess.TimeoutExpired): + pass + except subprocess.TimeoutExpired: + pass + + print(f" WARNING: Could not kill process on port {port}") + return False + + +def cleanup_vllm(): + """Cleanup function to terminate vLLM on exit.""" + global _vllm_process + if _vllm_process: + print("\nTerminating vLLM process...") + _vllm_process.terminate() + try: + _vllm_process.wait(timeout=5) + print("vLLM process terminated.") + except subprocess.TimeoutExpired: + print("vLLM process did not terminate gracefully, killing.") + _vllm_process.kill() + _vllm_process.wait() + print("vLLM process killed.") + _vllm_process = None + + +# Register cleanup on module load +atexit.register(cleanup_vllm) + + +def launch_vllm_server( + config: TrainingConfig, + model_path: str, +) -> Optional[subprocess.Popen]: + """ + Launch a vLLM server process using our custom vllm_api_server.py. + + Uses the custom server instead of standard vLLM because: + - Streamlined API: Only /generate endpoint (provides logprobs) + - Weight bridge support: /bridge/* endpoints for shared memory mode + - LoRA hot-swap: /lora/* endpoints for adapter loading/unloading + + Args: + config: Training configuration + model_path: Path to model checkpoint + + Returns: + Popen process object, or None if launch failed + """ + global _vllm_process + + # Check if port is in use and try to kill existing process + if is_port_in_use(config.vllm_port): + print(f" WARNING: Port {config.vllm_port} is already in use!") + if not kill_process_on_port(config.vllm_port): + print( + f" ERROR: Could not free port {config.vllm_port}. 
Please manually kill the process." + ) + print(f" Try: lsof -i :{config.vllm_port} | grep LISTEN") + print(f" Or: pkill -f 'vllm.*{config.vllm_port}'") + return None + print(f" Successfully freed port {config.vllm_port}") + + # Use our custom vllm_api_server.py + script_dir = os.path.dirname(os.path.abspath(__file__)) + custom_server_path = os.path.join(script_dir, "vllm_api_server.py") + + vllm_command = [ + "python", + custom_server_path, + "--model", + model_path, + "--port", + str(config.vllm_port), + "--gpu-memory-utilization", + str(config.vllm_gpu_memory_utilization), + ] + + # Add served-model-name if using checkpoint path + if model_path != config.model_name: + vllm_command.extend(["--served-model-name", config.model_name]) + + print(f" Launching vLLM: {' '.join(vllm_command)}") + + try: + proc = subprocess.Popen(vllm_command) + print(f" vLLM launched with PID: {proc.pid}") + + # Check for immediate startup errors + try: + proc.communicate(timeout=2) + if proc.returncode is not None and proc.returncode != 0: + print(" WARNING: vLLM failed to start") + return None + except subprocess.TimeoutExpired: + print(" vLLM process started (check logs for details)") + + _vllm_process = proc + return proc + + except FileNotFoundError: + print(" ERROR: vLLM not found. Is it installed?") + return None + except Exception as e: + print(f" ERROR launching vLLM: {e}") + return None + + +def terminate_vllm_process() -> None: + """Terminate the running vLLM process if any.""" + global _vllm_process + + if _vllm_process is None: + return + + print(" Terminating vLLM process...") + _vllm_process.terminate() + try: + _vllm_process.wait(timeout=5) + except subprocess.TimeoutExpired: + print(" vLLM did not terminate gracefully, killing...") + _vllm_process.kill() + _vllm_process.wait() + _vllm_process = None + + +def check_vllm_process_health() -> None: + """Check if vLLM process terminated unexpectedly.""" + global _vllm_process + + if _vllm_process is not None and _vllm_process.poll() is not None: + print( + f" WARNING: vLLM terminated unexpectedly (code: {_vllm_process.returncode})" + ) + _vllm_process = None + + +def get_vllm_process() -> Optional[subprocess.Popen]: + """Get the current vLLM process.""" + return _vllm_process + + +def set_vllm_process(proc: Optional[subprocess.Popen]) -> None: + """Set the vLLM process (for external management).""" + global _vllm_process + _vllm_process = proc + + +def check_vllm_health(port: int) -> bool: + """ + Check if vLLM server is healthy and responding. + + Args: + port: Port the vLLM server is running on + + Returns: + True if server is healthy + """ + try: + response = requests.get(f"http://localhost:{port}/health", timeout=5) + return response.status_code == 200 + except Exception: + return False + + +def wait_for_vllm_ready(port: int, timeout: float = 120.0) -> bool: + """ + Wait for vLLM server to be ready. + + Args: + port: Port the vLLM server is running on + timeout: Maximum time to wait in seconds + + Returns: + True if server is ready, False if timeout + """ + print(f" Waiting for vLLM to be ready (port {port})...") + start_time = time.time() + + while time.time() - start_time < timeout: + if check_vllm_health(port): + print(" vLLM is ready!") + return True + time.sleep(2) + + print(f" WARNING: vLLM not ready after {timeout}s") + return False + + +def hotswap_lora_adapter( + adapter_name: str, + adapter_path: str, + port: int, +) -> bool: + """ + Hot-swap a LoRA adapter on a running vLLM server. 
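+
+    A sketch of the request this sends (the adapter name and path are
+    illustrative):
+
+        POST http://localhost:<port>/v1/load_lora_adapter
+        {"lora_name": "step_100", "lora_path": "/checkpoints/step_100/adapter"}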
+ + Uses the vLLM /v1/load_lora_adapter endpoint to load a new adapter + without restarting the server. + + Args: + adapter_name: Name to identify the adapter + adapter_path: Path to the adapter checkpoint + port: vLLM server port + + Returns: + True if hot-swap succeeded + """ + try: + # Use vLLM's native LoRA loading endpoint + response = requests.post( + f"http://localhost:{port}/v1/load_lora_adapter", + json={ + "lora_name": adapter_name, + "lora_path": adapter_path, + }, + timeout=30, + ) + + if response.status_code == 200: + print(f" [LORA] ✓ Hot-swapped adapter: {adapter_name} ({adapter_path})") + return True + else: + print( + f" [LORA] ✗ Hot-swap failed: {response.status_code} - {response.text}" + ) + return False + + except requests.exceptions.ConnectionError: + print(f" [LORA] ✗ Cannot connect to vLLM at port {port}") + return False + except Exception as e: + print(f" [LORA] ✗ Error during hot-swap: {e}") + return False diff --git a/example_trainer/vllm_patching/__init__.py b/example_trainer/vllm_patching/__init__.py new file mode 100644 index 00000000..4a5bb2f4 --- /dev/null +++ b/example_trainer/vllm_patching/__init__.py @@ -0,0 +1,37 @@ +""" +vLLM Patching Module - Enables CUDA IPC shared memory for single-copy training. + +This module patches vLLM's GPUModelRunner to: +1. Call share_memory_() on model weights after loading +2. Export CUDA IPC handles to vllm_bridge_config.json +3. Enable the trainer to attach to vLLM's tensors directly + +The result: ONE copy of model weights in GPU memory, shared between +vLLM (inference) and the trainer (gradient updates). + +Usage: + # Set environment BEFORE importing + import os + os.environ["VLLM_ENABLE_SHARED_WEIGHTS"] = "1" + + # Import and apply patches BEFORE importing vllm + from example_trainer.vllm_patching import apply_patches + apply_patches() + + # Then import vllm normally + from vllm import AsyncLLM +""" + +from .patched_gpu_runner import ( + PatchedGPUModelRunner, + apply_patches, + get_patched_runner, + is_patched, +) + +__all__ = [ + "PatchedGPUModelRunner", + "apply_patches", + "get_patched_runner", + "is_patched", +] diff --git a/example_trainer/vllm_patching/patched_gpu_runner.py b/example_trainer/vllm_patching/patched_gpu_runner.py new file mode 100644 index 00000000..cd3768da --- /dev/null +++ b/example_trainer/vllm_patching/patched_gpu_runner.py @@ -0,0 +1,452 @@ +""" +Patched GPU Model Runner - Enables CUDA IPC for single-copy training. + +This patches vLLM's GPUModelRunner to: +1. Call share_memory_() on model weights after loading +2. Export CUDA IPC handles to vllm_bridge_config.json + +The key insight is that CUDA IPC handles allow the trainer process to +attach to the EXACT SAME GPU memory that vLLM uses. This means: +- ONE copy of model weights in GPU memory +- Trainer's optimizer.step() updates vLLM's weights directly +- No synchronization needed - vLLM immediately sees new weights + +CRITICAL: This module must be imported and apply_patches() called BEFORE +any vLLM imports. The patches MUST happen before vLLM caches module references. +""" + +from __future__ import annotations + +import os +import shutil +import sys + +# Flag to track if patches have been applied +_PATCHES_APPLIED = False +_PATCHED_RUNNER_CLASS = None + + +def _patch_lora_triton_for_blackwell() -> bool: + """ + Patch vLLM's LoRA Triton kernels to disable GDC (Grid Dependency Control). + + GDC is a Blackwell-specific feature that causes Triton compilation to fail + on B200 GPUs. This patches the kernel_utils.py to disable GDC. 
+ + Returns True if patch was applied successfully. + """ + try: + import vllm + + vllm_path = vllm.__path__[0] + kernel_utils_path = f"{vllm_path}/lora/ops/triton_ops/kernel_utils.py" + + # Check if file exists + if not os.path.exists(kernel_utils_path): + print("[vLLM Patch] LoRA kernel_utils.py not found, skipping GDC patch") + return False + + with open(kernel_utils_path, "r") as f: + content = f.read() + + # Check if already patched + if "PATCHED FOR B200" in content: + print("[vLLM Patch] LoRA GDC already patched for B200") + return True + + modified = False + + # Patch USE_GDC = True -> False + if "USE_GDC = True" in content: + content = content.replace( + "USE_GDC = True", + "USE_GDC = False # PATCHED FOR B200 - GDC causes Triton compilation failure", + ) + modified = True + + # Patch USE_GDC: tl.constexpr = True -> False + if "USE_GDC: tl.constexpr = True" in content: + content = content.replace( + "USE_GDC: tl.constexpr = True", + "USE_GDC: tl.constexpr = False # PATCHED FOR B200", + ) + modified = True + + # Patch the gdc_wait call itself + if "tl.extra.cuda.gdc_wait()" in content: + content = content.replace( + "tl.extra.cuda.gdc_wait()", + "pass # tl.extra.cuda.gdc_wait() PATCHED FOR B200 - disabled", + ) + modified = True + + if modified: + with open(kernel_utils_path, "w") as f: + f.write(content) + print(f"[vLLM Patch] ✓ Patched LoRA Triton GDC in {kernel_utils_path}") + + # Clear Triton cache to force recompilation + triton_cache = os.path.expanduser("~/.triton/cache") + if os.path.exists(triton_cache): + try: + shutil.rmtree(triton_cache) + print("[vLLM Patch] ✓ Cleared Triton cache") + except Exception as e: + print(f"[vLLM Patch] Warning: Could not clear Triton cache: {e}") + + return True + else: + print("[vLLM Patch] No GDC patterns found to patch") + return False + + except Exception as e: + print(f"[vLLM Patch] Warning: Could not patch LoRA GDC: {e}") + return False + + +def apply_patches() -> bool: + """ + Apply patches to vLLM's GPUModelRunner in ALL locations. + + This must be called BEFORE importing vLLM's engine classes. + Safe to call multiple times (idempotent). + + Also patches LoRA Triton kernels to disable GDC for B200 compatibility. + + Returns True if patches were applied successfully. + + Usage: + # CRITICAL: Import and call BEFORE any vLLM imports! 
+ import os + os.environ["VLLM_ENABLE_SHARED_WEIGHTS"] = "1" + + from example_trainer.vllm_patching import apply_patches + apply_patches() + + # Now import vLLM + from vllm import AsyncLLM # Uses patched runner + """ + global _PATCHES_APPLIED, _PATCHED_RUNNER_CLASS + + if _PATCHES_APPLIED: + return True + + # First, patch LoRA Triton for B200 compatibility + _patch_lora_triton_for_blackwell() + + try: + # Import the source module and get original class + import vllm.v1.worker.gpu_model_runner as gpu_model_runner_module + from vllm.v1.worker.gpu_model_runner import GPUModelRunner as OriginalRunner + + # Create the patched class + PatchedRunner = _create_patched_runner(OriginalRunner) + _PATCHED_RUNNER_CLASS = PatchedRunner + + # ================================================================= + # PATCH 1: Replace in source module + # ================================================================= + gpu_model_runner_module.GPUModelRunner = PatchedRunner + print("[vLLM Patch] ✓ Patched vllm.v1.worker.gpu_model_runner.GPUModelRunner") + + # ================================================================= + # PATCH 2: Replace in gpu_worker module (main usage location) + # ================================================================= + try: + import vllm.v1.worker.gpu_worker as gpu_worker_module + + gpu_worker_module.GPUModelRunner = PatchedRunner + print("[vLLM Patch] ✓ Patched vllm.v1.worker.gpu_worker.GPUModelRunner") + except ImportError: + pass + + # ================================================================= + # PATCH 3: Update sys.modules entry for source module + # ================================================================= + # This ensures new imports get the patched version + if "vllm.v1.worker.gpu_model_runner" in sys.modules: + sys.modules["vllm.v1.worker.gpu_model_runner"].GPUModelRunner = ( + PatchedRunner + ) + + # ================================================================= + # PATCH 4: Patch GPUWorker if already imported + # ================================================================= + try: + if "vllm.v1.worker.gpu_worker" in sys.modules: + worker_module = sys.modules["vllm.v1.worker.gpu_worker"] + if hasattr(worker_module, "GPUWorker"): + # Update any class-level references + worker_module.GPUModelRunner = PatchedRunner + except Exception: + pass + + _PATCHES_APPLIED = True + print("[vLLM Patch] ✓ GPUModelRunner patched for shared memory updates") + return True + + except ImportError as e: + print(f"[vLLM Patch] Warning: Could not apply patches: {e}") + print("[vLLM Patch] This may be due to vLLM version incompatibility") + print("[vLLM Patch] Shared memory updates will not be available") + return False + except Exception as e: + print(f"[vLLM Patch] Error applying patches: {e}") + import traceback + + traceback.print_exc() + return False + + +def _create_patched_runner(BaseRunner: type) -> type: + """ + Create a patched GPUModelRunner class. + + Returns a new class that inherits from the original and adds + CUDA IPC export functionality for single-copy training. + """ + + class PatchedGPUModelRunner(BaseRunner): + """ + Patched GPUModelRunner that enables CUDA IPC for single-copy training. + + After loading the model, this: + 1. Calls share_memory_() on all parameters + 2. Exports CUDA IPC handles to vllm_bridge_config.json + + The trainer reads these IPC handles and attaches to the SAME + GPU memory, so optimizer.step() updates weights that vLLM + immediately sees for inference. 
+ """ + + _shared_memory_setup_done = False + + def load_model(self, *args, **kwargs) -> None: + """Load model and set up shared memory + update daemon.""" + print("[vLLM Patch] PatchedGPUModelRunner.load_model() called!") + + # Call original load_model + super().load_model(*args, **kwargs) + + print("[vLLM Patch] Model loaded, checking shared weights setup...") + + # Check if shared memory updates are enabled + enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1" + num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", "-1")) + + print( + f"[vLLM Patch] VLLM_ENABLE_SHARED_WEIGHTS={enable_shared}, NUM_INFERENCE_NODES={num_inference_nodes}" + ) + + if not enable_shared and num_inference_nodes < 0: + print( + "[vLLM Patch] Shared weights disabled (set VLLM_ENABLE_SHARED_WEIGHTS=1 to enable)" + ) + return + + if self._shared_memory_setup_done: + print("[vLLM Patch] Shared memory already set up, skipping") + return + + print("[vLLM Patch] Setting up shared memory weight updates...", flush=True) + + try: + self._setup_shared_memory() + PatchedGPUModelRunner._shared_memory_setup_done = True + print("[vLLM Patch] ✓ Shared memory setup complete!", flush=True) + print( + "[vLLM Patch] ✓ IPC handles exported - trainer can now attach!", + flush=True, + ) + except Exception as e: + print(f"[vLLM Patch] ERROR in _setup_shared_memory: {e}", flush=True) + import traceback + + traceback.print_exc() + return + + def _setup_shared_memory(self) -> None: + """Move model tensors to shared memory and export param info.""" + import json + from pathlib import Path + + print("[vLLM Patch] _setup_shared_memory() starting...") + + # Get state dict + state_dict = self.model.state_dict() + print(f"[vLLM Patch] Model has {len(state_dict)} parameters") + + # Make entire model shareable via share_memory_() on each tensor + shared_count = 0 + for key, val in state_dict.items(): + try: + if val.is_cuda: + val.share_memory_() + shared_count += 1 + except Exception as e: + print(f"[vLLM Patch] Warning: Could not share {key}: {e}") + + print(f"[vLLM Patch] Called share_memory_() on {shared_count} CUDA tensors") + + # Also try calling share_memory() on the model itself + try: + self.model.share_memory() + print("[vLLM Patch] Called model.share_memory()") + except Exception as e: + print(f"[vLLM Patch] Note: model.share_memory() not available: {e}") + + # Export parameter info to JSON for trainer + # Allow explicit config path via env var, otherwise use LOGDIR + config_path = os.environ.get("VLLM_BRIDGE_CONFIG_PATH") + if config_path: + json_path = Path(config_path) + json_path.parent.mkdir(parents=True, exist_ok=True) + else: + log_dir = os.environ.get("LOGDIR", ".") + Path(log_dir).mkdir(parents=True, exist_ok=True) + json_path = Path(log_dir) / "vllm_bridge_config.json" + + param_mappings = {} + param_names = [] + ipc_handles = {} + + for name, tensor in state_dict.items(): + param_mappings[name] = { + "vllm_name": name, + "shape": list(tensor.shape), + "dtype": str(tensor.dtype), + "device": str(tensor.device), + } + param_names.append(name) + + # Export CUDA IPC handles for true single-copy mode + if tensor.is_cuda: + try: + import base64 + + storage = tensor.untyped_storage() + share_data = storage._share_cuda_() + + # share_data is a tuple of 8 items - we need ALL of them: + # [0] = device index (int) + # [1] = cudaIpcMemHandle_t (bytes) + # [2] = storage size (int) + # [3] = storage offset in original (int) + # [4] = ref counter handle (bytes - filename) + # [5] = ref counter offset (int) + 
# [6] = event handle (bytes) + # [7] = event sync required (bool) + + ipc_handles[name] = { + "device_index": share_data[0], + "ipc_handle_b64": base64.b64encode(share_data[1]).decode( + "ascii" + ), + "storage_size": share_data[2], + "storage_offset_orig": share_data[3], + "ref_counter_handle_b64": base64.b64encode( + share_data[4] + ).decode("ascii"), + "ref_counter_offset": share_data[5], + "event_handle_b64": base64.b64encode(share_data[6]).decode( + "ascii" + ), + "event_sync_required": share_data[7], + # Tensor metadata for reconstruction + "tensor_storage_offset": tensor.storage_offset(), + "shape": list(tensor.shape), + "stride": list(tensor.stride()), + "dtype": str(tensor.dtype), + } + except Exception as e: + print( + f"[vLLM Patch] Could not get IPC handle for {name}: {e}", + flush=True, + ) + import traceback + + traceback.print_exc() + + print( + f"[vLLM Patch] Exported {len(ipc_handles)} IPC handles for single-copy mode", + flush=True, + ) + + # Get model info + model_name = "unknown" + tp_degree = 1 + try: + model_name = str(self.model_config.model) + tp_degree = self.parallel_config.tensor_parallel_size + except Exception as e: + print(f"[vLLM Patch] Warning: Could not get model config: {e}") + + import base64 + + # Convert bytes to base64 for JSON serialization + def serialize_ipc_handles(handles): + result = {} + for k, v in handles.items(): + if isinstance(v, bytes): + result[k] = {"_bytes_b64_": base64.b64encode(v).decode("ascii")} + elif isinstance(v, dict): + result[k] = serialize_ipc_handles(v) + else: + result[k] = v + return result + + serialized_ipc_handles = ( + serialize_ipc_handles(ipc_handles) if ipc_handles else {} + ) + + info = { + "model": model_name, + "tp_degree": tp_degree, + "dp_shard_degree": 1, + "param_mappings": param_mappings, + "param_names": sorted(param_names), + "ipc_handles": serialized_ipc_handles, + "shared_weights_enabled": True, + "num_params": len(param_names), + "single_copy_enabled": len(ipc_handles) > 0, + } + + try: + with open(json_path, "w") as f: + json.dump(info, f, indent=2) + print( + f"[vLLM Patch] ✓ Exported {len(param_mappings)} params to {json_path}" + ) + except Exception as e: + print(f"[vLLM Patch] ERROR: Failed to export params: {e}") + import traceback + + traceback.print_exc() + + # Set proper class name + PatchedGPUModelRunner.__name__ = "PatchedGPUModelRunner" + PatchedGPUModelRunner.__qualname__ = "PatchedGPUModelRunner" + + return PatchedGPUModelRunner + + +def get_patched_runner() -> type | None: + """Get the patched runner class if patches have been applied.""" + return _PATCHED_RUNNER_CLASS + + +def is_patched() -> bool: + """Check if patches have been applied.""" + return _PATCHES_APPLIED + + +# Placeholder class for type checking +class PatchedGPUModelRunner: + """ + Placeholder class for type checking. + + The actual patched class is created dynamically by _create_patched_runner() + to properly inherit from vLLM's GPUModelRunner. + """ + + pass
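For context on the consuming side of this export, here is a minimal sketch of how a trainer process might attach to the tensors described in vllm_bridge_config.json. It assumes a single-GPU (TP=1) setup on the same machine and relies on torch.multiprocessing.reductions.rebuild_cuda_tensor, a private PyTorch helper whose exact signature can vary between releases; the JSON field names are the ones written by _setup_shared_memory above, while the function name attach_shared_params and everything else is illustrative.

import base64
import json

import torch
from torch.multiprocessing.reductions import rebuild_cuda_tensor


def attach_shared_params(config_path: str) -> dict:
    """Return {param_name: tensor} views onto vLLM's own GPU memory."""
    with open(config_path) as f:
        bridge = json.load(f)

    params = {}
    for name, h in bridge.get("ipc_handles", {}).items():
        # Field names mirror the entries written by _setup_shared_memory above.
        dtype = getattr(torch, h["dtype"].replace("torch.", ""))
        params[name] = rebuild_cuda_tensor(
            torch.Tensor,
            tensor_size=torch.Size(h["shape"]),
            tensor_stride=tuple(h["stride"]),
            tensor_offset=h["tensor_storage_offset"],
            storage_cls=torch.UntypedStorage,
            dtype=dtype,
            storage_device=h["device_index"],
            storage_handle=base64.b64decode(h["ipc_handle_b64"]),
            storage_size_bytes=h["storage_size"],
            storage_offset_bytes=h["storage_offset_orig"],
            requires_grad=False,
            ref_counter_handle=base64.b64decode(h["ref_counter_handle_b64"]),
            ref_counter_offset=h["ref_counter_offset"],
            event_handle=base64.b64decode(h["event_handle_b64"]),
            event_sync_required=h["event_sync_required"],
        )
    return params


# In-place updates through these tensors are immediately visible to vLLM,
# since both processes map the same device memory.

Because rebuild_cuda_tensor is internal API, a production trainer would want to pin the PyTorch version and fail loudly if the import or call signature changes.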
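The process-management helpers earlier in this change are meant to be driven from the trainer's main loop. A rough usage sketch follows; the import path example_trainer.vllm_manager, the default-constructed TrainingConfig, and the adapter names and paths are assumptions for illustration, while vllm_port and model_name are fields the module itself reads.

from example_trainer.config import TrainingConfig          # assumed location
from example_trainer.vllm_manager import (                 # assumed module name
    hotswap_lora_adapter,
    launch_vllm_server,
    terminate_vllm_process,
    wait_for_vllm_ready,
)

config = TrainingConfig()  # assumed to provide vllm_port, model_name, etc.

proc = launch_vllm_server(config, model_path=config.model_name)
if proc is not None and wait_for_vllm_ready(config.vllm_port, timeout=300):
    # After writing a checkpoint, swap the adapter in place of a restart.
    hotswap_lora_adapter(
        adapter_name="step_0010",                           # illustrative name
        adapter_path="/tmp/checkpoints/step_0010/adapter",  # illustrative path
        port=config.vllm_port,
    )

terminate_vllm_process()

Calling terminate_vllm_process() explicitly is optional, since cleanup_vllm() is registered with atexit, but it makes shutdown ordering deterministic when the trainer also needs to release the GPU for other work.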