diff --git a/environments/eval_results/evaluate_config.yaml b/environments/eval_results/evaluate_config.yaml
new file mode 100644
index 00000000..4e039643
--- /dev/null
+++ b/environments/eval_results/evaluate_config.yaml
@@ -0,0 +1,49 @@
+env:
+  group_size: 8
+  max_num_workers: -1
+  max_eval_workers: 16
+  max_num_workers_per_node: 24
+  steps_per_eval: 20
+  max_token_length: 4096
+  eval_handling: LIMIT_TRAIN
+  eval_limit_ratio: 0.1
+  inference_weight: 1.0
+  batch_size: 64
+  max_batches_offpolicy: 3
+  tokenizer_name: Qwen/Qwen3-4B-Instruct-2507
+  use_wandb: true
+  rollout_server_url: http://localhost:8000
+  total_steps: 120
+  wandb_name: math-zero-env
+  num_rollouts_to_keep: 32
+  num_rollouts_per_group_for_logging: 1
+  ensure_scores_are_not_same: true
+  data_path_to_save_groups: null
+  data_dir_to_save_evals: ./eval_results
+  min_items_sent_before_logging: 2
+  include_messages: false
+  min_batch_allocation: null
+  worker_timeout: 600.0
+  thinking_mode: false
+  reasoning_effort: null
+  max_reasoning_tokens: null
+  custom_thinking_prompt: null
+  run_evaluation: true
+  mask_too_long_completions: true
+  percent_length_penalty: 0.0
+  start_tok_length: 8192
+openai:
+- timeout: 1200
+  num_max_requests_at_once: 32
+  num_requests_for_eval: 32
+  model_name: tinker://d43e769f-dfd3-5a83-81d2-21ac97a656ad:train:0/sampler_weights/step_78
+  rolling_buffer_length: 1000
+  server_type: openai
+  # SECURITY(review): a live API key was committed on this line; it has been redacted.
+  # Rotate the leaked key and supply it via the TINKER_API_KEY environment variable.
+  api_key: ${TINKER_API_KEY}
+  base_url: https://tinker.thinkingmachines.dev/services/tinker-prod/oai/api/v1
+  n_kwarg_is_ignored: false
+  health_check: true
+slurm: false
+testing: false
diff --git a/example_trainer/README.md b/example_trainer/README.md
index b9a360c0..ebb13006 100644
--- a/example_trainer/README.md
+++ b/example_trainer/README.md
@@ -17,10 +17,12 @@ example_trainer/
 ├── checkpointing.py # Save models & LoRA adapters
 ├── vllm_manager.py # vLLM process management
 ├── trainers.py # Training mode implementations
+├── 
nccl_weight_bridge.py # NCCL direct weight transfer (torchtitan-style) ├── vllm_api_server.py # Custom vLLM server (streamlined for training) -├── vllm_patching/ # CUDA IPC patches for weight sharing basically overriding standard vllm for this +├── vllm_patching/ # CUDA IPC patches for weight sharing │ └── patched_gpu_runner.py -└── scripts/ # Helper scripts +└── scripts/ # Helper scripts + ├── run_lora_nccl.sh # LoRA NCCL mode launcher ├── test_lora_mode.sh └── test_single_copy_mode.sh ``` @@ -58,17 +60,25 @@ Data Flow: --- -## Three Training Modes +## Four Training Modes | Mode | Description | Memory | Best For | |------|-------------|--------|----------| | **shared_vllm** | Single-copy via CUDA IPC | 1x model | Same GPU, maximum efficiency | -| **lora_only** | Train adapters, hot-swap | 1x + small adapter | Fast iteration, small checkpoints | +| **lora_only** | Train adapters, HTTP hot-swap | 1x + small adapter | Simple setup, debugging | +| **lora_nccl** | Train adapters, NCCL transfer | 1x + small adapter | Multi-GPU, fastest sync | | **legacy** | Full model, restart vLLM | 2x model | Different GPUs, simple setup | ### Recommendation -**Start with `lora_only`** - it's the easiest to set up and debug. Move to `shared_vllm` for production training when you need maximum efficiency for SINGLE GPU TRAINING RUNS. MULTIPLE GPU TRAINING NOT SUPPORTED . +**Start with `lora_only`** - it's the easiest to set up and debug. + +**Use `lora_nccl`** (torchtitan-style) for production multi-GPU training when you need: +- Fastest weight synchronization (NCCL broadcast vs HTTP + disk I/O) +- True on-policy training (sync after every step) +- Distributed training across nodes + +**Use `shared_vllm`** for single-GPU training when you need maximum efficiency. --- @@ -139,6 +149,89 @@ python -m example_trainer.grpo \ --- +## LoRA NCCL Mode (Torchtitan-Style) + +This mode implements torchtitan-style direct NCCL weight transfer for LoRA training. 
Instead of HTTP hot-swap (disk I/O), weights are broadcast directly via NCCL from the trainer to vLLM. + +### Architecture + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ NCCL Process Group │ +│ ┌─────────────────────┐ ┌─────────────────────────┐ │ +│ │ Trainer (rank 0) │ ──NCCL send──> │ vLLM (rank 1) │ │ +│ │ - Trains LoRA │ │ - Receives weights │ │ +│ │ - broadcast() │ │ - Updates in-place │ │ +│ └─────────────────────┘ └─────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +### Benefits over `lora_only` + +| Aspect | lora_only (HTTP) | lora_nccl (NCCL) | +|--------|------------------|------------------| +| **Sync time** | 100-500ms (disk I/O) | 5-50ms (direct GPU) | +| **Disk usage** | Writes checkpoints | No disk I/O during training | +| **Scalability** | Single vLLM instance | Multiple inference nodes | +| **On-policy** | Periodic sync | True per-step sync | + +### Quick Start + +**Option 1: Use the launch script** +```bash +./example_trainer/scripts/run_lora_nccl.sh +``` + +**Option 2: Manual launch** + +```bash +# Terminal 1: API +run-api --port 8000 + +# Terminal 2: vLLM with LoRA support +CUDA_VISIBLE_DEVICES=1 python example_trainer/vllm_api_server.py \ + --model Qwen/Qwen2.5-3B-Instruct \ + --port 9001 \ + --gpu-memory-utilization 0.45 \ + --enable-lora \ + --max-lora-rank 32 \ + --enforce-eager + +# Terminal 3: Environment +python environments/gsm8k_server.py serve \ + --env.tokenizer_name "Qwen/Qwen2.5-3B-Instruct" \ + --env.rollout_server_url "http://localhost:8000" \ + --openai.base_url "http://localhost:9001/v1" \ + --openai.model_name "Qwen/Qwen2.5-3B-Instruct" \ + --openai.server_type vllm \ + --slurm false + +# Terminal 4: Trainer with NCCL weight bridge +CUDA_VISIBLE_DEVICES=0 python -m example_trainer.grpo \ + --model-name Qwen/Qwen2.5-3B-Instruct \ + --weight-bridge-mode lora_nccl \ + --vllm-port 9001 \ + --atropos-url http://localhost:8000 \ + --lora-r 
16 \ + --lora-alpha 32 \ + --training-steps 50 \ + --nccl-init-method "tcp://localhost:29500" \ + --nccl-world-size 2 \ + --nccl-sync-every-step \ + --benchmark +``` + +### NCCL Configuration Options + +| Flag | Default | Description | +|------|---------|-------------| +| `--nccl-init-method` | `tcp://localhost:29500` | NCCL rendezvous URL | +| `--nccl-world-size` | `2` | Total processes (trainer + vLLM instances) | +| `--nccl-sync-every-step` | `True` | Sync after every step (true on-policy) | +| `--no-nccl-sync-every-step` | - | Sync only at `vllm-restart-interval` | + +--- + ## Shared vLLM Mode (Advanced) Single-copy mode shares GPU memory between vLLM and the trainer - zero model duplication! diff --git a/example_trainer/cli.py b/example_trainer/cli.py index 046e3ec3..1ea2dfde 100644 --- a/example_trainer/cli.py +++ b/example_trainer/cli.py @@ -189,9 +189,9 @@ def add_mode_args(parser: argparse.ArgumentParser) -> None: group.add_argument( "--weight-bridge-mode", type=str, - choices=["shared_vllm", "lora_only", "none"], + choices=["shared_vllm", "lora_only", "lora_nccl", "none"], default="none", - help="Weight sync mode: 'shared_vllm', 'lora_only', or 'none' (legacy)", + help="Weight sync mode: 'shared_vllm', 'lora_only', 'lora_nccl', or 'none' (legacy)", ) group.add_argument( "--vllm-config-path", @@ -218,6 +218,35 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None: ) +def add_nccl_args(parser: argparse.ArgumentParser) -> None: + """Add NCCL weight bridge arguments (for lora_nccl mode).""" + group = parser.add_argument_group("NCCL Weight Bridge (lora_nccl mode)") + group.add_argument( + "--nccl-init-method", + type=str, + default="tcp://localhost:29500", + help="NCCL process group init method (tcp://host:port)", + ) + group.add_argument( + "--nccl-world-size", + type=int, + default=2, + help="Total processes in NCCL group (trainer + vLLM instances)", + ) + group.add_argument( + "--nccl-sync-every-step", + action="store_true", + default=True, + 
help="Sync weights after every step (true on-policy)", + ) + group.add_argument( + "--no-nccl-sync-every-step", + action="store_false", + dest="nccl_sync_every_step", + help="Sync weights only at vllm_restart_interval", + ) + + def add_distributed_args(parser: argparse.ArgumentParser) -> None: """Add distributed training arguments.""" group = parser.add_argument_group("Distributed Training") @@ -279,6 +308,7 @@ def create_full_parser() -> argparse.ArgumentParser: add_wandb_args(parser) add_mode_args(parser) add_lora_args(parser) + add_nccl_args(parser) add_distributed_args(parser) add_debug_args(parser) @@ -367,4 +397,8 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig: debug_loading=getattr(args, "debug_loading", False), benchmark=getattr(args, "benchmark", False), atropos_url=getattr(args, "atropos_url", "http://localhost:8000"), + # NCCL settings (for lora_nccl mode) + nccl_init_method=getattr(args, "nccl_init_method", "tcp://localhost:29500"), + nccl_world_size=getattr(args, "nccl_world_size", 2), + nccl_sync_every_step=getattr(args, "nccl_sync_every_step", True), ) diff --git a/example_trainer/config.py b/example_trainer/config.py index 291dad43..42d852d1 100644 --- a/example_trainer/config.py +++ b/example_trainer/config.py @@ -105,12 +105,13 @@ class TrainingConfig(BaseModel): wandb_group: Optional[str] = Field(None, description="Wandb group name") # === Training Mode Configuration === - weight_bridge_mode: Literal["shared_vllm", "lora_only", "none"] = Field( + weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_nccl", "none"] = Field( "none", description=( "How to synchronize weights with inference server. " "'shared_vllm': attach to vLLM's shared memory tensors and update in-place. " - "'lora_only': keep base model frozen, train/swap LoRA adapters. " + "'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP. " + "'lora_nccl': LoRA training with NCCL direct weight transfer (torchtitan-style). 
" "'none': legacy mode, restart vLLM with new checkpoint files." ), ) @@ -148,6 +149,30 @@ class TrainingConfig(BaseModel): ), ) + # === NCCL Weight Bridge Configuration (for lora_nccl mode) === + nccl_init_method: str = Field( + "tcp://localhost:29500", + description=( + "NCCL process group init method for lora_nccl mode. " + "Format: tcp://host:port" + ), + ) + nccl_world_size: int = Field( + 2, + description=( + "Total number of processes in the NCCL weight bridge group. " + "Typically 2: trainer (rank 0) + vLLM server (rank 1). " + "For multi-GPU vLLM, this would be 1 + num_vllm_gpus." + ), + ) + nccl_sync_every_step: bool = Field( + True, + description=( + "Whether to sync weights after every training step (true on-policy). " + "If False, syncs every vllm_restart_interval steps." + ), + ) + # === Single-Copy Mode Configuration === single_copy: bool = Field( False, diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index bca45cb8..49798c6f 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 """ +GRPO (Group Relative Policy Optimization) Trainer. 
-Supports three training modes: +Supports four training modes: - none (legacy): Periodic checkpoint saves + vLLM restarts - shared_vllm: Single-copy mode with CUDA IPC weight sharing -- lora_only: LoRA adapter training +- lora_only: LoRA adapter training with HTTP hot-swap +- lora_nccl: LoRA adapter training with NCCL direct weight transfer (torchtitan-style) Usage: # Legacy mode (manages vLLM internally) @@ -14,13 +16,18 @@ Usage: python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\ --weight-bridge-mode shared_vllm - # LoRA mode (requires external vLLM with --enable-lora) + # LoRA mode with HTTP hot-swap (requires external vLLM with --enable-lora) python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\ --weight-bridge-mode lora_only --lora-r 16 --lora-alpha 32 + + # LoRA mode with NCCL direct transfer (torchtitan-style, fastest) + python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\ + --weight-bridge-mode lora_nccl --lora-r 16 --lora-alpha 32 \\ + --nccl-init-method tcp://localhost:29500 """ from .cli import config_from_args, parse_args -from .trainers import train_legacy, train_lora, train_shared_vllm +from .trainers import train_legacy, train_lora, train_lora_nccl, train_shared_vllm def main(): @@ -43,8 +50,12 @@ def main(): train_shared_vllm(config) elif config.weight_bridge_mode == "lora_only": - # LoRA mode: freeze base model, train adapters only + # LoRA mode: freeze base model, train adapters only (HTTP hot-swap) train_lora(config) + + elif config.weight_bridge_mode == "lora_nccl": + # LoRA NCCL mode: torchtitan-style direct weight transfer + train_lora_nccl(config) else: # Legacy mode: periodic checkpoint saves + vLLM restarts diff --git a/example_trainer/nccl_weight_bridge.py b/example_trainer/nccl_weight_bridge.py new file mode 100644 index 00000000..9b205be4 --- /dev/null +++ b/example_trainer/nccl_weight_bridge.py @@ -0,0 +1,555 @@ +""" +NCCL Weight Bridge for LoRA Training. 
+ +Implements torchtitan-style direct NCCL weight transfer between trainer and vLLM. +This eliminates disk I/O for weight synchronization. + +Architecture: + ┌─────────────────────────────────────────────────────────────────────┐ + │ NCCL Process Group │ + │ ┌─────────────────────┐ ┌─────────────────────────┐ │ + │ │ Trainer (rank 0) │ ──NCCL send──> │ vLLM (rank 1+) │ │ + │ │ - LoRA weights │ │ - Receives weights │ │ + │ │ - broadcast() │ │ - Updates state_dict │ │ + │ └─────────────────────┘ └─────────────────────────┘ │ + └─────────────────────────────────────────────────────────────────────┘ + +Usage (Trainer side): + bridge = NCCLWeightBridge( + rank=0, + world_size=2, + init_method="tcp://localhost:29500" + ) + bridge.setup() + + # After training step + bridge.send_lora_weights(model) + +Usage (vLLM side): + bridge = NCCLWeightBridge( + rank=1, + world_size=2, + init_method="tcp://localhost:29500" + ) + bridge.setup() + + # In background thread + bridge.receive_and_update_weights(vllm_state_dict, param_mappings) +""" + +import json +import os +import threading +import time +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist + + +@dataclass +class NCCLBridgeConfig: + """Configuration for NCCL weight bridge.""" + + # Process group settings + rank: int = 0 + world_size: int = 2 + init_method: str = "tcp://localhost:29500" + timeout_seconds: int = 300 + + # Weight transfer settings + use_gloo_for_metadata: bool = True # Gloo for small metadata, NCCL for tensors + + # LoRA settings + lora_param_patterns: List[str] = field(default_factory=lambda: [ + "lora_A", "lora_B", "lora_a", "lora_b" + ]) + + +def is_lora_param(name: str, patterns: Optional[List[str]] = None) -> bool: + """Check if a parameter name corresponds to a LoRA weight.""" + if patterns is None: + patterns = ["lora_A", "lora_B", "lora_a", "lora_b", + "lora_a_stacked", 
"lora_b_stacked"] + return any(p in name for p in patterns) + + +def get_lora_params(model: torch.nn.Module) -> Dict[str, torch.Tensor]: + """Extract LoRA parameters from a model.""" + lora_params = {} + for name, param in model.named_parameters(): + if param.requires_grad and is_lora_param(name): + lora_params[name] = param + return lora_params + + +class NCCLWeightBridge: + """ + NCCL-based weight bridge for synchronizing LoRA weights between trainer and vLLM. + + Inspired by torchtitan's sglang_handling.py approach. + """ + + def __init__(self, config: NCCLBridgeConfig): + self.config = config + self.rank = config.rank + self.world_size = config.world_size + + self.nccl_group: Optional[dist.ProcessGroup] = None + self.gloo_group: Optional[dist.ProcessGroup] = None + + self.is_initialized = False + self.update_count = 0 + self.last_update_time = 0.0 + + # Parameter registry (filled during first sync) + self.param_names: List[str] = [] + self.param_shapes: Dict[str, Tuple[int, ...]] = {} + self.param_dtypes: Dict[str, torch.dtype] = {} + + # Receiver state (vLLM side) + self._receiver_thread: Optional[threading.Thread] = None + self._stop_receiver = threading.Event() + self._state_dict_ref: Optional[Dict[str, torch.Tensor]] = None + self._param_mappings: Dict[str, str] = {} + + def setup(self) -> bool: + """ + Initialize NCCL and Gloo process groups. + + Returns: + True if setup successful, False otherwise. 
+ """ + if self.is_initialized: + return True + + try: + # Clean up any existing distributed state + self._cleanup_env_for_new_group() + + timeout = timedelta(seconds=self.config.timeout_seconds) + + # Initialize NCCL group for tensor transfers + print(f"[NCCLBridge] Initializing NCCL group (rank={self.rank}, world={self.world_size})") + self.nccl_group = self._init_process_group( + backend="nccl", + init_method=self.config.init_method, + world_size=self.world_size, + rank=self.rank, + group_name="lora_weight_nccl", + timeout=timeout, + ) + + if self.config.use_gloo_for_metadata: + # Initialize Gloo group for metadata (param names, shapes, etc.) + gloo_port = int(self.config.init_method.split(":")[-1]) + 1 + gloo_init = self.config.init_method.rsplit(":", 1)[0] + f":{gloo_port}" + + print(f"[NCCLBridge] Initializing Gloo group at {gloo_init}") + self.gloo_group = self._init_process_group( + backend="gloo", + init_method=gloo_init, + world_size=self.world_size, + rank=self.rank, + group_name="lora_weight_gloo", + timeout=timeout, + ) + + self.is_initialized = True + print(f"[NCCLBridge] ✓ Initialized successfully (rank {self.rank})") + return True + + except Exception as e: + print(f"[NCCLBridge] ✗ Failed to initialize: {e}") + import traceback + traceback.print_exc() + return False + + def _cleanup_env_for_new_group(self): + """Remove environment variables that interfere with new process groups.""" + # Save and remove torch distributed env vars + env_vars_to_clear = [ + "LOCAL_RANK", "RANK", "WORLD_SIZE", "GROUP_RANK", + "GROUP_WORLD_SIZE", "LOCAL_WORLD_SIZE", "MASTER_ADDR", + "MASTER_PORT" + ] + self._saved_env = {} + for var in env_vars_to_clear: + if var in os.environ: + self._saved_env[var] = os.environ.pop(var) + + def _restore_env(self): + """Restore saved environment variables.""" + for var, value in getattr(self, '_saved_env', {}).items(): + os.environ[var] = value + + def _init_process_group( + self, + backend: str, + init_method: str, + world_size: int, + 
rank: int, + group_name: str, + timeout: timedelta, + ) -> dist.ProcessGroup: + """ + Initialize a new process group without affecting the global state. + + Based on torchtitan's init_process_group implementation. + """ + from torch.distributed.distributed_c10d import ( + _new_process_group_helper, + _world, + Backend, + default_pg_timeout, + PrefixStore, + rendezvous, + ) + + backend_obj = Backend(backend) + + if timeout is None: + timeout = default_pg_timeout + + # Create rendezvous store + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + + # Use a PrefixStore to avoid key collisions + store = PrefixStore(group_name, store) + + # Handle PyTorch version differences + pg_options_param_name = ( + "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" + ) + + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend_obj, + store, + group_name=group_name, + **{pg_options_param_name: None}, + timeout=timeout, + ) + + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + + return pg + + def register_params(self, model: torch.nn.Module) -> Dict[str, Any]: + """ + Register LoRA parameters from the model. + + Returns: + Dictionary with parameter metadata for vLLM side. 
+ """ + lora_params = get_lora_params(model) + + self.param_names = sorted(lora_params.keys()) + self.param_shapes = {name: tuple(p.shape) for name, p in lora_params.items()} + self.param_dtypes = {name: p.dtype for name, p in lora_params.items()} + + metadata = { + "param_names": self.param_names, + "param_shapes": {k: list(v) for k, v in self.param_shapes.items()}, + "param_dtypes": {k: str(v) for k, v in self.param_dtypes.items()}, + "num_params": len(self.param_names), + } + + print(f"[NCCLBridge] Registered {len(self.param_names)} LoRA parameters") + return metadata + + def send_lora_weights( + self, + model: torch.nn.Module, + step: Optional[int] = None, + ) -> float: + """ + Send LoRA weights to vLLM via NCCL broadcast. + + Args: + model: Model with LoRA adapters + step: Optional training step for logging + + Returns: + Time taken for the transfer in seconds. + """ + if not self.is_initialized: + raise RuntimeError("NCCLBridge not initialized. Call setup() first.") + + if self.rank != 0: + raise RuntimeError("send_lora_weights() should only be called from rank 0 (trainer)") + + start_time = time.time() + + # Get LoRA parameters + lora_params = get_lora_params(model) + + if not self.param_names: + self.register_params(model) + + # Send step index first (so receivers know an update is coming) + step_tensor = torch.tensor([step if step is not None else self.update_count], + dtype=torch.long, device="cuda") + dist.broadcast(step_tensor, src=0, group=self.nccl_group) + + # Send each parameter + for name in self.param_names: + param = lora_params[name] + # Ensure contiguous for efficient transfer + param_data = param.detach().contiguous() + dist.broadcast(param_data, src=0, group=self.nccl_group) + + # Send completion signal + done_tensor = torch.tensor([1], dtype=torch.long, device="cuda") + dist.broadcast(done_tensor, src=0, group=self.nccl_group) + + elapsed = time.time() - start_time + self.update_count += 1 + self.last_update_time = time.time() + + if step is 
not None: + print(f"[NCCLBridge] Sent LoRA weights (step {step}) in {elapsed:.3f}s") + + return elapsed + + def start_receiver( + self, + state_dict: Dict[str, torch.Tensor], + param_mappings: Dict[str, str], + on_update: Optional[Callable[[int], None]] = None, + ): + """ + Start background thread to receive weight updates (vLLM side). + + Args: + state_dict: vLLM's model state dict (will be updated in-place) + param_mappings: Mapping from trainer param names to vLLM param names + on_update: Optional callback called after each update with step number + """ + if self.rank == 0: + raise RuntimeError("start_receiver() should not be called from rank 0 (trainer)") + + self._state_dict_ref = state_dict + self._param_mappings = param_mappings + self._stop_receiver.clear() + + def receiver_loop(): + print(f"[NCCLBridge] Receiver thread started (rank {self.rank})") + device = "cuda" + + while not self._stop_receiver.is_set(): + try: + # Wait for step index + step_tensor = torch.zeros(1, dtype=torch.long, device=device) + dist.broadcast(step_tensor, src=0, group=self.nccl_group) + step = step_tensor.item() + + if step < 0: + # Negative step means shutdown signal + print("[NCCLBridge] Received shutdown signal") + break + + # Receive each parameter + for name in self.param_names: + shape = self.param_shapes[name] + dtype_str = self.param_dtypes[name] + dtype = getattr(torch, str(dtype_str).replace("torch.", "")) + + # Create buffer and receive + buffer = torch.zeros(shape, dtype=dtype, device=device) + dist.broadcast(buffer, src=0, group=self.nccl_group) + + # Map to vLLM param name and update + vllm_name = self._param_mappings.get(name, name) + if vllm_name in self._state_dict_ref: + # Reshape if needed for vLLM stacked format + target = self._state_dict_ref[vllm_name] + if buffer.shape != target.shape: + buffer = self._reshape_for_vllm(buffer, vllm_name, target.shape) + + target.data.copy_(buffer) + + # Wait for completion signal + done_tensor = torch.zeros(1, 
dtype=torch.long, device=device) + dist.broadcast(done_tensor, src=0, group=self.nccl_group) + + self.update_count += 1 + self.last_update_time = time.time() + + print(f"[NCCLBridge] Received weight update (step {step})") + + if on_update: + on_update(step) + + except Exception as e: + if not self._stop_receiver.is_set(): + print(f"[NCCLBridge] Receiver error: {e}") + import traceback + traceback.print_exc() + break + + print("[NCCLBridge] Receiver thread exiting") + + self._receiver_thread = threading.Thread(target=receiver_loop, daemon=True) + self._receiver_thread.start() + + def stop_receiver(self): + """Stop the receiver thread.""" + self._stop_receiver.set() + + # Send shutdown signal if we're the trainer + if self.rank == 0 and self.is_initialized: + try: + shutdown_tensor = torch.tensor([-1], dtype=torch.long, device="cuda") + dist.broadcast(shutdown_tensor, src=0, group=self.nccl_group) + except Exception: + pass + + if self._receiver_thread and self._receiver_thread.is_alive(): + self._receiver_thread.join(timeout=5.0) + + def _reshape_for_vllm( + self, + tensor: torch.Tensor, + vllm_name: str, + target_shape: Tuple[int, ...], + ) -> torch.Tensor: + """ + Reshape LoRA tensor from PEFT format to vLLM stacked format. + + vLLM expects: + - Attention LoRA: [1, 1, rank, dim] or [1, 1, dim, rank] + - MoE LoRA: [num_experts, rank, dim] + """ + # Check if this is an attention LoRA (needs [1, 1, ...] 
prefix) + is_attention_lora = any( + proj in vllm_name + for proj in ["qkv_proj", "o_proj", "q_proj", "k_proj", "v_proj"] + ) + + if is_attention_lora and len(tensor.shape) == 2 and len(target_shape) == 4: + # [rank, dim] -> [1, 1, rank, dim] + tensor = tensor.unsqueeze(0).unsqueeze(0) + + if tensor.shape != target_shape: + raise ValueError( + f"Shape mismatch for {vllm_name}: got {tensor.shape}, expected {target_shape}" + ) + + return tensor + + def cleanup(self): + """Clean up process groups and threads.""" + self.stop_receiver() + self._restore_env() + + # Note: We don't destroy the process groups as they may still be in use + # by other parts of the system. They will be cleaned up on process exit. + + self.is_initialized = False + print("[NCCLBridge] Cleaned up") + + +def create_trainer_param_to_vllm_mapping( + trainer_param_names: List[str], + model_name: str = "llama", +) -> Dict[str, str]: + """ + Create mapping from PEFT trainer parameter names to vLLM parameter names. + + PEFT names: base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight + vLLM names: model.layers.0.self_attn.qkv_proj.q_proj.lora_a_stacked + + Args: + trainer_param_names: List of parameter names from the trainer model + model_name: Model architecture name for architecture-specific mappings + + Returns: + Dictionary mapping trainer names to vLLM names + """ + mapping = {} + + for name in trainer_param_names: + if not is_lora_param(name): + continue + + # Remove PEFT prefixes + vllm_name = name + for prefix in ["base_model.model.", "base_model."]: + if vllm_name.startswith(prefix): + vllm_name = vllm_name[len(prefix):] + + # Handle attention projections (qkv fusion) + for proj in ["q_proj", "k_proj", "v_proj"]: + if f".{proj}.lora_" in vllm_name: + # Map to vLLM's fused qkv_proj format + if ".lora_A" in vllm_name or ".lora_a" in vllm_name: + suffix = "lora_a_stacked" + else: + suffix = "lora_b_stacked" + + # self_attn.q_proj.lora_A.weight -> self_attn.qkv_proj.q_proj.lora_a_stacked + 
vllm_name = vllm_name.replace(f".{proj}.lora_A.weight", f".qkv_proj.{proj}.{suffix}") + vllm_name = vllm_name.replace(f".{proj}.lora_B.weight", f".qkv_proj.{proj}.{suffix}") + vllm_name = vllm_name.replace(f".{proj}.lora_a.weight", f".qkv_proj.{proj}.{suffix}") + vllm_name = vllm_name.replace(f".{proj}.lora_b.weight", f".qkv_proj.{proj}.{suffix}") + break + + # Handle o_proj + if ".o_proj.lora_" in vllm_name: + suffix = "lora_a_stacked" if ("lora_A" in vllm_name or "lora_a" in vllm_name) else "lora_b_stacked" + vllm_name = vllm_name.replace(".o_proj.lora_A.weight", f".o_proj.o_proj.{suffix}") + vllm_name = vllm_name.replace(".o_proj.lora_B.weight", f".o_proj.o_proj.{suffix}") + vllm_name = vllm_name.replace(".o_proj.lora_a.weight", f".o_proj.o_proj.{suffix}") + vllm_name = vllm_name.replace(".o_proj.lora_b.weight", f".o_proj.o_proj.{suffix}") + + # Handle MLP projections + for mlp_proj in ["gate_proj", "up_proj", "down_proj"]: + if f".{mlp_proj}.lora_" in vllm_name: + suffix = "lora_a_stacked" if ("lora_A" in vllm_name or "lora_a" in vllm_name) else "lora_b_stacked" + vllm_name = vllm_name.replace(f".{mlp_proj}.lora_A.weight", f".{mlp_proj}.{suffix}") + vllm_name = vllm_name.replace(f".{mlp_proj}.lora_B.weight", f".{mlp_proj}.{suffix}") + vllm_name = vllm_name.replace(f".{mlp_proj}.lora_a.weight", f".{mlp_proj}.{suffix}") + vllm_name = vllm_name.replace(f".{mlp_proj}.lora_b.weight", f".{mlp_proj}.{suffix}") + break + + mapping[name] = vllm_name + + return mapping + + +def export_bridge_config( + config_path: str, + param_metadata: Dict[str, Any], + param_mappings: Dict[str, str], + nccl_init_method: str, + world_size: int = 2, +): + """ + Export bridge configuration to JSON for vLLM to read. 
+ + Args: + config_path: Path to write the config + param_metadata: Parameter metadata from register_params() + param_mappings: Trainer to vLLM parameter name mappings + nccl_init_method: NCCL init method (e.g., "tcp://localhost:29500") + world_size: Total number of processes in the group + """ + config = { + "nccl_enabled": True, + "nccl_init_method": nccl_init_method, + "world_size": world_size, + "param_metadata": param_metadata, + "param_mappings": param_mappings, + } + + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + + print(f"[NCCLBridge] Exported config to {config_path}") diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py index 77e7ec2f..5e0e6371 100644 --- a/example_trainer/trainers.py +++ b/example_trainer/trainers.py @@ -1,10 +1,11 @@ """ Training mode implementations for GRPO trainer. -Contains the three main training modes: +Contains the four main training modes: - train_legacy: Checkpoint-based training with vLLM restarts - train_shared_vllm: Single-copy mode with CUDA IPC -- train_lora: LoRA adapter training with hot-swap +- train_lora: LoRA adapter training with HTTP hot-swap +- train_lora_nccl: LoRA adapter training with NCCL direct transfer (torchtitan-style) """ import os @@ -656,3 +657,228 @@ def _hotswap_lora_adapter( except Exception as e: print(f" [LORA] ✗ Hot-swap request failed: {e}") return False + + +def train_lora_nccl(config: TrainingConfig): + """ + GRPO training with LoRA adapters using NCCL direct weight transfer. + + This mode (inspired by torchtitan): + 1. Freezes base model, trains only LoRA adapter weights + 2. Uses NCCL to broadcast weights directly to vLLM (zero disk I/O) + 3. 
Weight updates are immediate - no HTTP API calls + + Benefits over train_lora(): + - Much faster weight sync (NCCL vs HTTP+disk) + - Lower latency for on-policy training + - No checkpoint files during training + + Requirements: + - External vLLM server running with NCCL receiver enabled + - Trainer and vLLM must be in the same NCCL process group + """ + if not PEFT_AVAILABLE: + raise RuntimeError( + "PEFT library required for LoRA mode. Install with: pip install peft" + ) + + training_start_time = time.time() + + # === Setup === + use_wandb = setup_wandb(config) + + print("\n" + "=" * 60) + print("LORA NCCL MODE (torchtitan-style direct weight transfer)") + print("=" * 60) + print(f"Base model: {config.model_name}") + print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}") + print(f"Save path: {config.save_path}") + print(f"vLLM port: {config.vllm_port}") + print(f"NCCL init: {config.nccl_init_method}") + print("=" * 60 + "\n") + + # Check external vLLM server + print("[1/4] Checking external vLLM server...") + if not check_vllm_health(config.vllm_port): + print(f"\nERROR: vLLM server not running on port {config.vllm_port}") + print("\nLoRA NCCL mode requires an external vLLM server. 
Start it first:") + print( + f" NCCL_LORA_ENABLED=1 python example_trainer/vllm_api_server.py " + f"--model {config.model_name} --port {config.vllm_port} --enable-lora --enforce-eager" + ) + raise RuntimeError(f"External vLLM server required on port {config.vllm_port}") + print(f"vLLM server healthy on port {config.vllm_port}") + + # Load model with LoRA adapters + print("[2/4] Loading model with LoRA adapters...") + model, tokenizer = load_model_and_tokenizer(config) + + # Only optimize LoRA parameters + trainable_params = [p for p in model.parameters() if p.requires_grad] + optimizer = AdamW(trainable_params, lr=config.lr) + + # Setup NCCL bridge + print("[3/4] Setting up NCCL weight bridge...") + from .nccl_weight_bridge import ( + NCCLBridgeConfig, + NCCLWeightBridge, + create_trainer_param_to_vllm_mapping, + export_bridge_config, + ) + + nccl_config = NCCLBridgeConfig( + rank=0, # Trainer is always rank 0 + world_size=config.nccl_world_size, + init_method=config.nccl_init_method, + ) + + bridge = NCCLWeightBridge(nccl_config) + if not bridge.setup(): + raise RuntimeError("Failed to setup NCCL bridge") + + # Register parameters and create mappings + param_metadata = bridge.register_params(model) + param_mappings = create_trainer_param_to_vllm_mapping( + bridge.param_names, + model_name=config.model_name + ) + + # Export config for vLLM + bridge_config_path = os.path.join(config.save_path, "nccl_bridge_config.json") + os.makedirs(config.save_path, exist_ok=True) + export_bridge_config( + bridge_config_path, + param_metadata, + param_mappings, + config.nccl_init_method, + config.nccl_world_size, + ) + + print(f"[4/4] Starting training for {config.training_steps} steps") + print("-" * 60) + + # Check Atropos API + if not check_atropos_api(url=config.atropos_url, timeout=30): + raise RuntimeError(f"Atropos API not reachable at {config.atropos_url}") + register_trainer(config) + + # === Benchmark tracking === + benchmark_stats = { + "step_times": [], + "sync_times": 
[], + "data_fetch_times": [], + "gpu_memories": [], + } + + # Send initial weights to vLLM + print("Sending initial LoRA weights to vLLM...") + initial_sync_time = bridge.send_lora_weights(model, step=0) + print(f" Initial sync completed in {initial_sync_time:.3f}s") + + # === Training Loop === + batches = [] + for step in range(config.training_steps): + print(f"\nStep {step+1}/{config.training_steps}") + + # Fetch data (with inference logprobs for proper GRPO) + data_fetch_start = time.time() + if len(batches) == 0: + batches, _ = get_data( + config.batch_size, + config.seq_len, + config.atropos_url, + extract_inference_logprobs=True, + ) + batch_data = batches.pop(0) + token_batches, label_batches, advantage_batches, temperature_batches = ( + batch_data[:4] + ) + inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None + data_fetch_time = time.time() - data_fetch_start + benchmark_stats["data_fetch_times"].append(data_fetch_time) + + # Training step with proper GRPO + step_start = time.time() + metrics = run_training_step( + model, + optimizer, + token_batches, + label_batches, + advantage_batches, + temperature_batches, + config, + inference_logprob_batches=inference_logprob_batches, + ) + step_time = time.time() - step_start + benchmark_stats["step_times"].append(step_time) + + # GPU memory tracking + gpu_mem_gb = ( + torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 + ) + gpu_mem_reserved_gb = ( + torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 + ) + benchmark_stats["gpu_memories"].append(gpu_mem_gb) + + # NCCL weight sync (every step for on-policy, or periodic) + sync_time = 0 + should_sync = ( + config.nccl_sync_every_step or + (step + 1) % config.vllm_restart_interval == 0 + ) + if should_sync: + sync_start = time.time() + bridge.send_lora_weights(model, step=step + 1) + sync_time = time.time() - sync_start + benchmark_stats["sync_times"].append(sync_time) + print(f" [NCCL] Weights synced in 
{sync_time:.3f}s") + + # Update metrics + metrics.update( + { + "step_time": step_time, + "sync_time": sync_time, + "data_fetch_time": data_fetch_time, + "gpu_memory_gb": gpu_mem_gb, + "gpu_memory_reserved_gb": gpu_mem_reserved_gb, + } + ) + + log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark) + + # Periodic checkpoint (for recovery only, not for vLLM sync) + if ( + config.checkpoint_interval > 0 + and (step + 1) % config.checkpoint_interval == 0 + ): + save_lora_checkpoint(model, config.save_path, step + 1) + + # === Cleanup === + # Final sync + print("\nSending final weights...") + final_sync_time = bridge.send_lora_weights(model, step=config.training_steps) + benchmark_stats["sync_times"].append(final_sync_time) + + # Save final checkpoint + final_adapter_path = save_lora_checkpoint( + model, config.save_path, config.training_steps, is_final=True + ) + + # Cleanup bridge + bridge.cleanup() + + finalize_training( + use_wandb, + training_start_time, + "lora_nccl", + config.training_steps, + benchmark_stats, + config.benchmark, + ) + + # Save tokenizer + tokenizer_path = os.path.join(config.save_path, "tokenizer") + tokenizer.save_pretrained(tokenizer_path) + print(f"Tokenizer saved to {tokenizer_path}") + print(f"Final adapter saved to {final_adapter_path}") diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py index 2846f14f..4336f2c7 100644 --- a/example_trainer/vllm_api_server.py +++ b/example_trainer/vllm_api_server.py @@ -643,6 +643,117 @@ async def lora_unload() -> JSONResponse: ) +# ============================================================================= +# NCCL Weight Receiver (for lora_nccl mode) +# ============================================================================= + +nccl_bridge: Optional[Any] = None # Will hold NCCLWeightBridge instance + + +@app.post("/nccl/start_receiver") +async def nccl_start_receiver(request: Request) -> JSONResponse: + """ + Start NCCL weight receiver (for lora_nccl 
training mode). + + Request JSON: + { + "init_method": "tcp://localhost:29500", + "world_size": 2, + "param_metadata": {...}, + "param_mappings": {...} + } + """ + global nccl_bridge + + if engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + + request_dict = await request.json() + + try: + from .nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge + except ImportError: + try: + from nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge + except ImportError: + raise HTTPException( + status_code=500, + detail="NCCL weight bridge module not available" + ) + + # Get vLLM's state dict for in-place updates + # This is tricky because vLLM's model is encapsulated + # For now, we'll need to use the engine's internal access + state_dict = {} # TODO: Get actual vLLM state dict + + config = NCCLBridgeConfig( + rank=1, # vLLM is rank 1 (trainer is rank 0) + world_size=request_dict.get("world_size", 2), + init_method=request_dict.get("init_method", "tcp://localhost:29500"), + ) + + nccl_bridge = NCCLWeightBridge(config) + + if not nccl_bridge.setup(): + raise HTTPException(status_code=500, detail="Failed to setup NCCL bridge") + + # Set param metadata from trainer + nccl_bridge.param_names = request_dict.get("param_metadata", {}).get("param_names", []) + nccl_bridge.param_shapes = { + k: tuple(v) for k, v in + request_dict.get("param_metadata", {}).get("param_shapes", {}).items() + } + nccl_bridge.param_dtypes = request_dict.get("param_metadata", {}).get("param_dtypes", {}) + + param_mappings = request_dict.get("param_mappings", {}) + + # Start receiver thread + nccl_bridge.start_receiver( + state_dict, + param_mappings, + on_update=lambda step: logger.info(f"NCCL weight update received: step {step}") + ) + + return JSONResponse({ + "status": "ok", + "message": "NCCL receiver started", + "rank": 1, + "world_size": config.world_size, + }) + + +@app.post("/nccl/stop_receiver") +async def nccl_stop_receiver() -> JSONResponse: + 
"""Stop NCCL weight receiver.""" + global nccl_bridge + + if nccl_bridge is None: + return JSONResponse({"status": "ok", "message": "No receiver running"}) + + nccl_bridge.stop_receiver() + nccl_bridge.cleanup() + nccl_bridge = None + + return JSONResponse({"status": "ok", "message": "NCCL receiver stopped"}) + + +@app.get("/nccl/status") +async def nccl_status() -> JSONResponse: + """Get NCCL receiver status.""" + if nccl_bridge is None: + return JSONResponse({ + "active": False, + "update_count": 0, + }) + + return JSONResponse({ + "active": nccl_bridge.is_initialized, + "update_count": nccl_bridge.update_count, + "last_update_time": nccl_bridge.last_update_time, + "num_params": len(nccl_bridge.param_names), + }) + + # ============================================================================= # Server Setup # ============================================================================= @@ -748,6 +859,9 @@ async def run_server( logger.info(" GET /lora/status - LoRA adapter status") logger.info(" POST /lora/load - Load LoRA adapter") logger.info(" POST /lora/unload - Unload LoRA adapter") + logger.info(" POST /nccl/start_receiver - Start NCCL weight receiver (lora_nccl mode)") + logger.info(" POST /nccl/stop_receiver - Stop NCCL weight receiver") + logger.info(" GET /nccl/status - NCCL receiver status") logger.info("=" * 60) shutdown_task = await serve_http(