diff --git a/example_trainer/README.md b/example_trainer/README.md
index 0ae8a8c3..6da70c93 100644
--- a/example_trainer/README.md
+++ b/example_trainer/README.md
@@ -199,7 +199,10 @@ tail -f trainer.log
 
 ## Mode 2: Shared vLLM Bridge (In-Place Updates)
 
-This mode uses an HTTP-based notification system. The trainer notifies vLLM after weight updates.
+This mode supports two sub-modes:
+
+1. **HTTP Notification Mode** (default): the trainer notifies vLLM after weight updates
+2. **NCCL Shared Memory Mode** (`--use-shared-memory`): weights are broadcast via NCCL to vLLM's weight-updater daemon
 
 ### Step-by-Step Guide
 
@@ -211,6 +214,8 @@ sleep 5
 ```
 
 **Step 2: Start the vLLM Server with Bridge Support**
+
+For HTTP notification mode:
 ```bash
 cd atropos
 export LOGDIR=/tmp/atropos_bridge
@@ -227,7 +232,27 @@
 sleep 90
 
 curl -s http://localhost:9001/health && echo "vLLM ready!"
 ```
 
+For NCCL shared memory mode (requires patched vLLM):
+```bash
+cd atropos
+export LOGDIR=/tmp/atropos_bridge
+export NUM_INFERENCE_NODES=0
+export VLLM_ENABLE_SHARED_WEIGHTS=1  # Enable shared memory patches
+mkdir -p $LOGDIR
+
+python example_trainer/vllm_api_server.py \
+  --model Qwen/Qwen2.5-3B-Instruct \
+  --port 9001 \
+  --gpu-memory-utilization 0.30 \
+  > vllm.log 2>&1 &
+sleep 90
+
+curl -s http://localhost:9001/health && echo "vLLM ready!"
+```
+
 **Step 3: Start the GRPO Trainer in Shared Mode**
+
+For HTTP notification mode:
 ```bash
 cd atropos
 export LOGDIR=/tmp/atropos_bridge
@@ -249,6 +274,22 @@ python example_trainer/grpo.py \
 sleep 10
 ```
 
+For NCCL shared memory mode (add `--use-shared-memory`):
+```bash
+python example_trainer/grpo.py \
+  --model-name Qwen/Qwen2.5-3B-Instruct \
+  --weight-bridge-mode shared_vllm \
+  --use-shared-memory \
+  --num-inference-nodes 0 \
+  --training-steps 100 \
+  --vllm-port 9001 \
+  --batch-size 2 \
+  --gradient-accumulation-steps 16 \
+  --lr 1e-5 \
+  --save-path checkpoints_shared \
+  > trainer.log 2>&1 &
+```
+
 **Step 4: Start the GSM8k Environment**
 ```bash
 cd atropos
@@ -279,15 +320,26 @@ python environments/gsm8k_server.py serve --slurm False --openai.model_name Qwen
 tail -f trainer.log
 ```
 
-### What Happens (Local Mode - num_inference_nodes=0)
+### What Happens (HTTP Notification Mode)
 
 1. vLLM server starts on port 9001
-2. Trainer initializes bridge in LOCAL MODE (HTTP-based, no NCCL)
+2. Trainer initializes bridge in LOCAL MODE (HTTP-based)
 3. Trainer loads its own model copy and trains normally
 4. After each `optimizer.step()`:
    - `bridge.notify_update()` sends HTTP POST to vLLM
    - Periodic checkpoint saves sync weights to disk
-5. Much simpler than distributed mode!
+5. Simple setup, suitable for debugging
+
+### What Happens (NCCL Shared Memory Mode)
+
+When using `--use-shared-memory` with `VLLM_ENABLE_SHARED_WEIGHTS=1`:
+
+1. vLLM patches GPUModelRunner to call `share_memory_()` on model weights
+2. vLLM spawns a daemon process that joins NCCL groups with the trainer
+3. Trainer broadcasts weights via NCCL after each optimizer step
+4. Daemon copies weights into shared tensors → vLLM uses them immediately
+
+The vLLM worker and the weight-updater daemon thus share a single copy of the weights; the trainer still keeps its own training copy and streams updates into the shared copy over NCCL.
 
 ### What Happens (Distributed Mode - num_inference_nodes>0)
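An editorial aside before the code changes: to confirm which sub-mode a running server is actually in, the `/bridge/info` endpoint added later in this diff reports a `shared_weights` flag indicating whether the NCCL patches were applied. Assuming the server runs on port 9001 as in the steps above:

```bash
# "shared_weights": true  -> NCCL shared memory patches are active
# "shared_weights": false -> HTTP-only notification mode
curl -s http://localhost:9001/bridge/info | python -m json.tool
```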
diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py
index 2a3bede4..7904950a 100644
--- a/example_trainer/grpo.py
+++ b/example_trainer/grpo.py
@@ -146,13 +146,13 @@ class TrainingConfig(BaseModel):
         ),
     )
 
-    # CUDA IPC mode (for shared_vllm mode - true shared GPU memory)
-    use_cuda_ipc: bool = Field(
+    # Shared memory mode (for shared_vllm mode - NCCL weight broadcast)
+    use_shared_memory: bool = Field(
         False,
         description=(
-            "Enable CUDA IPC for true shared GPU memory with vLLM. "
-            "This allows trainer to use vLLM's model weights directly without loading a copy. "
-            "Requires both processes on the SAME GPU. Saves ~8GB for 3B model."
+            "Enable shared memory weight updates via NCCL. "
+            "vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1. "
+            "Weight updates are broadcast to vLLM's daemon process."
         ),
     )
 
@@ -385,19 +385,16 @@ def load_model_and_tokenizer(
     tokenizer = AutoTokenizer.from_pretrained(config.model_name)
 
     if config.weight_bridge_mode == "shared_vllm" and bridge is not None:
-        if config.use_cuda_ipc:
-            # CUDA IPC mode: use vLLM's weights directly (NO NEW MEMORY!)
-            print("[Setup] Using CUDA IPC shared memory mode...")
-            print("[Setup] Trainer will use vLLM's model weights directly!")
-            model = bridge.get_trainable_model()
+        # Shared vLLM mode: load model, weights will be broadcast via NCCL
+        print("[Setup] Loading model for shared vLLM mode...")
+        if config.use_shared_memory:
+            print("[Setup] NCCL shared memory mode - updates broadcast to vLLM daemon")
         else:
-            # Standard shared mode: load own copy, notify via HTTP
-            print("[Setup] Loading model for shared vLLM mode...")
-            model = AutoModelForCausalLM.from_pretrained(
-                config.model_name, torch_dtype=torch.bfloat16
-            )
-            model.to(config.device)
-            bridge.attach_to_vllm_weights(dict(model.named_parameters()))
+            print("[Setup] HTTP notification mode - vLLM notified of updates")
+        model = AutoModelForCausalLM.from_pretrained(
+            config.model_name, torch_dtype=torch.bfloat16
+        )
+        model.to(config.device)
 
     elif config.weight_bridge_mode == "lora_only":
         model = _load_model_with_lora(config)
@@ -411,17 +408,15 @@
 
     # Enable gradient checkpointing (saves memory)
     # For LoRA, use PEFT's method; for others, use standard method
-    # NOTE: Skip for CUDA IPC as the model structure is different
     if config.weight_bridge_mode == "lora_only":
         # PEFT models need gradient_checkpointing enabled on base model
         # and require use_reentrant=False for proper gradient flow
         if hasattr(model, "enable_input_require_grads"):
             model.enable_input_require_grads()
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
-    elif not config.use_cuda_ipc:
-        # Standard gradient checkpointing for non-IPC modes
+    else:
+        # Standard gradient checkpointing
         model.gradient_checkpointing_enable()
-        # CUDA IPC mode: gradient checkpointing may not work with shared tensors
 
     model.train()
 
@@ -1040,14 +1035,14 @@ def train_shared_vllm(config: TrainingConfig):
     use_wandb = setup_wandb(config)
 
     print(f"\n{'='*60}")
-    if config.use_cuda_ipc:
-        print("SHARED VLLM MODE (CUDA IPC - TRUE SHARED MEMORY)")
-        print(">>> NO MODEL COPY - using vLLM's weights directly!")
+    if config.use_shared_memory:
+        print("SHARED VLLM MODE (NCCL BROADCAST)")
+        print(">>> Weights broadcast to vLLM via NCCL!")
     else:
         print("SHARED VLLM MODE (HTTP notifications)")
     print(f"{'='*60}")
     print(f"Model: {config.model_name}")
-    print(f"CUDA IPC: {config.use_cuda_ipc}")
+    print(f"Shared Memory: {config.use_shared_memory}")
     print(f"Distributed: rank={config.trainer_rank}/{config.world_size}")
     print(f"Init method: {config.init_method}")
     print(f"Inference nodes: {config.num_inference_nodes}")
@@ -1113,12 +1108,18 @@
                 gpu_mem_gb = 0
                 gpu_mem_reserved_gb = 0
 
-            # Track notify update time (this is the "sync" for shared mode - should be ~0ms)
+            # Sync weights with vLLM
             sync_start = time.time()
-            bridge.notify_update()
+            if config.use_shared_memory:
+                # NCCL broadcast mode - weights sent directly to vLLM daemon
+                bridge.broadcast_weights(model)
+                print(f" [SHARED] Weights broadcast via NCCL - step {step+1} (sync: {(time.time()-sync_start)*1000:.1f}ms)")
+            else:
+                # HTTP notification mode - just notify
+                bridge.notify_update()
+                print(f" [SHARED] Update notification sent - step {step+1}")
             sync_time = time.time() - sync_start
             benchmark_stats["sync_times"].append(sync_time)
-            print(f" [SHARED] Weights updated in-place - vLLM now using step {step+1} weights (sync: {sync_time*1000:.1f}ms)")
 
             # Add timing metrics
             metrics["step_time"] = step_time
@@ -1488,14 +1489,14 @@ def parse_args() -> argparse.Namespace:
         help="Module names to apply LoRA to (default: q_proj v_proj)",
     )
 
-    # --- CUDA IPC arguments ---
+    # --- Shared memory arguments ---
    parser.add_argument(
-        "--use-cuda-ipc",
+        "--use-shared-memory",
        action="store_true",
        help=(
-            "Enable CUDA IPC for true shared GPU memory with vLLM (shared_vllm mode only). "
-            "Trainer uses vLLM's model weights directly - no copy needed! "
-            "Requires both processes on SAME GPU. Saves ~8GB for 3B model."
+            "Enable NCCL shared memory weight updates (shared_vllm mode only). "
+            "Weights are broadcast to vLLM's daemon via NCCL. "
+            "vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1."
        ),
    )
 
@@ -1528,7 +1529,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
         lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         lora_target_modules=args.lora_target_modules,
-        use_cuda_ipc=args.use_cuda_ipc,
+        use_shared_memory=getattr(args, "use_shared_memory", False),
     )
diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py
index 89d181c3..306a4675 100644
--- a/example_trainer/vllm_api_server.py
+++ b/example_trainer/vllm_api_server.py
@@ -1,29 +1,35 @@
-
 """
-Custom vLLM API server with weight bridge hooks for shared-memory training.
+Custom vLLM API server with shared memory weight updates.
 
-This server extends the standard vLLM API with endpoints for:
-- Shared-weight training: trainers can attach to model weights via NCCL
-- LoRA hot-swap: load new adapters without server restart
-- Weight synchronization: coordinate updates between trainer and inference
+This server extends the standard vLLM API with:
+- Shared-weight training via NCCL (patched GPUModelRunner)
+- LoRA hot-swap without server restart
+- Weight synchronization endpoints
 
-Architecture:
-    ┌─────────────────────────────────────────────────────────┐
-    │ vllm_api_server.py                                      │
-    │  ┌────────────────────────────────────────────────┐     │
-    │  │ FastAPI Application                            │     │
-    │  │  ┌─────────┐  ┌──────────┐  ┌───────────────┐  │     │
-    │  │  │/generate│  │/bridge/* │  │    /lora/*    │  │     │
-    │  │  │ (infer) │  │  (sync)  │  │  (adapters)   │  │     │
-    │  │  └────┬────┘  └────┬─────┘  └───────┬───────┘  │     │
-    │  └───────┼────────────┼────────────────┼──────────┘     │
-    │          │            │                │                │
-    │  ┌───────▼────────────▼────────────────▼──────────┐     │
-    │  │              AsyncLLM                          │     │
-    │  │  - Model weights (shared via NCCL)             │     │
-    │  │  - LoRA adapters (hot-swappable)               │     │
-    │  └────────────────────────────────────────────────┘     │
-    └─────────────────────────────────────────────────────────┘
+ARCHITECTURE:
+    When --enable-shared-weights is set:
+    1. vLLM's GPUModelRunner is patched to call share_memory_() on weights
+    2. A daemon process is spawned that receives NCCL weight updates
+    3. Trainer broadcasts weights -> daemon copies to shared memory -> vLLM uses immediately
+
+    ┌─────────────────────────────────────────────────────────────────────────┐
+    │                   SHARED MEMORY (via share_memory_())                   │
+    │  ┌─────────────────────────────────────────────────────────────────┐   │
+    │  │                          Model Weights                          │   │
+    │  │               (accessible from MULTIPLE processes)              │   │
+    │  └─────────────────────────────────────────────────────────────────┘   │
+    │           ▲                              ▲                             │
+    │           │ Reads                        │ Writes                      │
+    │  ┌────────┴────────┐         ┌───────────┴───────────┐                 │
+    │  │   vLLM Worker   │         │     weight_updater    │                 │
+    │  │   (inference)   │         │     daemon process    │                 │
+    │  └─────────────────┘         └───────────┬───────────┘                 │
+    │                                          │ NCCL                        │
+    │                                          ▼                             │
+    │                              ┌─────────────────────┐                   │
+    │                              │   Trainer Process   │                   │
+    │                              └─────────────────────┘                   │
+    └─────────────────────────────────────────────────────────────────────────┘
 """
 
 import asyncio
@@ -38,9 +44,6 @@
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Dict, List, Optional
 
-import base64
-import pickle
-
 import torch
 import vllm.envs as envs
 from fastapi import FastAPI, HTTPException, Request
@@ -55,44 +58,69 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import random_uuid
 from vllm.v1.engine.async_llm import AsyncLLM
 
-# Import sync LLM for collective_rpc with function support
-try:
-    from vllm import LLM as SyncLLM
-    SYNC_LLM_AVAILABLE = True
-except ImportError:
-    SYNC_LLM_AVAILABLE = False
-    SyncLLM = None
-
 try:
     from vllm.utils.argparse_utils import FlexibleArgumentParser
     from vllm.utils.system_utils import set_ulimit
 except ImportError:
     from vllm.utils import FlexibleArgumentParser, set_ulimit
 
+from vllm.outputs import RequestOutput  # noqa: F401
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger("vllm.entrypoints.api_server")
 
 
+# =============================================================================
+# Apply vLLM Patches for Shared Memory
+# =============================================================================
+
+def _maybe_apply_patches() -> bool:
+    """
+    Apply vLLM patches if shared weights are enabled.
+
+    Returns True if patches were applied.
+    """
+    enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1"
+    num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", -1))
+
+    if not enable_shared and num_inference_nodes < 0:
+        return False
+
+    try:
+        try:
+            from .vllm_patching import apply_patches
+        except ImportError:
+            # When run as a plain script (python example_trainer/vllm_api_server.py),
+            # relative imports are unavailable - fall back to a top-level import.
+            from vllm_patching import apply_patches
+        apply_patches()
+        logger.info("✓ vLLM patches applied for shared memory weights")
+        return True
+    except ImportError as e:
+        logger.warning(f"Could not import vllm_patching: {e}")
+        logger.warning("Shared memory weight updates will not be available")
+        return False
+    except Exception as e:
+        logger.warning(f"Failed to apply patches: {e}")
+        return False
+
+
+# Apply patches at import time, before the engine is constructed
+PATCHES_APPLIED = _maybe_apply_patches()
+
+
 # =============================================================================
 # Global State
 # =============================================================================
 
 app = FastAPI()
 engine: Optional[AsyncLLM] = None
-sync_engine: Optional["SyncLLM"] = None  # For collective_rpc with functions
 
 
 @dataclass
 class BridgeState:
     """State for weight bridge synchronization."""
-
     enabled: bool = False
     update_count: int = 0
     last_update_time: float = 0.0
     rendezvous_info: Dict[str, Any] = field(default_factory=dict)
     lock: threading.Lock = field(default_factory=threading.Lock)
-
+
     # LoRA state
     active_lora_path: Optional[str] = None
     lora_load_count: int = 0
@@ -101,23 +129,12 @@
 bridge_state = BridgeState()
 
 
-def get_engine():
-    """Get the active engine (async or sync)."""
-    if engine is not None:
-        return engine
-    if sync_engine is not None:
-        return sync_engine
-    raise HTTPException(status_code=503, detail="No engine available")
-
-
 # =============================================================================
 # Pydantic Models for API
 # =============================================================================
 
 
 class BridgeInfoResponse(BaseModel):
-    """Response model for bridge info endpoint."""
-
     enabled: bool
     update_count: int
     last_update_time: float
@@ -127,8 +144,6 @@
 
 
 class BridgeInitRequest(BaseModel):
-    """Request model for initializing bridge."""
-
     master_addr: str
     master_port: int
     world_size: int
@@ -136,23 +151,17 @@
 
 
 class WeightUpdateNotification(BaseModel):
-    """Notification that weights have been updated."""
-
     update_count: int
     trainer_rank: int
     timestamp: float
 
 
 class LoraLoadRequest(BaseModel):
-    """Request to load a LoRA adapter."""
-
     adapter_path: str
     adapter_name: Optional[str] = None
 
 
 class LoraStatusResponse(BaseModel):
-    """Response model for LoRA status."""
-
     active_adapter: Optional[str]
     load_count: int
     available_adapters: List[str]
@@ -165,69 +174,46 @@
 
 @app.get("/health")
 async def health() -> Response:
-    """Basic health check - is server alive?"""
+    """Health check."""
     return Response(status_code=200)
 
 
 @app.get("/health_generate")
 async def health_generate() -> Response:
-    """
-    Deep health check - can we actually generate tokens?
-
-    This sends a minimal request through the full inference pipeline
-    to verify the model is loaded and functioning.
-    """
-    sampling_params = SamplingParams(max_tokens=1)
+    """Health check that verifies model can generate."""
+    if engine is None:
+        raise HTTPException(status_code=503, detail="Engine not initialized")
 
-    if engine is not None:
-        # Async engine path
-        request_id = random_uuid()
+    # Keep max_tokens=1: this is a liveness probe, not a real generation
+    sampling_params = SamplingParams(max_tokens=1)
+    request_id = random_uuid()
+
+    try:
         results_generator = engine.generate(
             {"prompt_token_ids": [0]}, sampling_params, request_id
         )
-        try:
-            async for request_output in results_generator:
-                final_output = request_output  # type: RequestOutput  # noqa: F841
-        except asyncio.CancelledError:
-            return Response(status_code=499)
-    elif sync_engine is not None:
-        # Sync engine path (CUDA IPC mode)
-        import concurrent.futures
-        def _sync_health_check():
-            return sync_engine.generate(["Hello"], sampling_params)
-        loop = asyncio.get_event_loop()
-        with concurrent.futures.ThreadPoolExecutor() as pool:
-            await loop.run_in_executor(pool, _sync_health_check)
-    else:
-        return Response(status_code=503)
-
-    return Response(status_code=200)
+        async for _ in results_generator:
+            pass
+        return Response(status_code=200)
+    except asyncio.CancelledError:
+        return Response(status_code=499)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
 
 
 # =============================================================================
-# Generation Endpoint
+# Generation Endpoints
 # =============================================================================
 
 
 @app.post("/generate")
 async def generate(request: Request) -> Response:
     """
-    Generate text completion for a prompt.
-
-    Request JSON fields:
-    - prompt: str - The input text to complete
-    - stream: bool - Whether to stream results (default: False)
-    - max_tokens: int - Maximum tokens to generate
-    - temperature: float - Sampling temperature
-    - top_p: float - Nucleus sampling threshold
-    - logprobs: int - Number of logprobs to return per token
-
-    Returns:
-    - text: List[str] - Generated completions
-    - prompt: str - Echo of input prompt
-    - finish_reasons: List[str] - Why generation stopped
-    - logprobs: List (optional) - Token log probabilities
-    - token_ids: List (optional) - Generated token IDs
+    Generate completion for the request.
+
+    The request should be a JSON object with:
+    - prompt: the prompt to use for generation
+    - stream: whether to stream results
+    - other fields: sampling parameters
     """
     request_dict = await request.json()
     return await _generate(request_dict, raw_request=request)
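As a usage illustration (not part of the diff), a non-streaming request against the rewritten `/generate` handler looks like this; the `prompt`, `stream`, and sampling fields follow the docstring above:

```bash
curl -s http://localhost:9001/generate \
  -H "Content-Type: application/json" \
  -d '{"prompt": "The capital of France is", "max_tokens": 8, "temperature": 0.0}'
# -> {"text": ["..."], "prompt": "The capital of France is", "finish_reasons": ["length"]}
```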
ret["token_ids"] = output_token_ids - - return JSONResponse(ret) - else: - return JSONResponse({"error": "No output generated"}, status_code=500) - else: - raise HTTPException(status_code=503, detail="No engine available") - # ========================================================================= - # Async engine path (standard mode) - streaming and non-streaming - # ========================================================================= - - # Streaming: yield results as theyre generated + results_generator = engine.generate(prompt, sampling_params, request_id) + async def stream_results() -> AsyncGenerator[bytes, None]: async for request_output in results_generator: - prompt_text = request_output.prompt - assert prompt_text is not None - text_outputs = [prompt_text + output.text for output in request_output.outputs] + prompt = request_output.prompt + assert prompt is not None + text_outputs = [prompt + output.text for output in request_output.outputs] ret = {"text": text_outputs} yield (json.dumps(ret) + "\n").encode("utf-8") if stream: return StreamingResponse(stream_results()) - # Non-streaming: wait for full completion final_output = None try: async for request_output in results_generator: - final_output = request_output # type: RequestOutput + final_output = request_output except asyncio.CancelledError: return Response(status_code=499) assert final_output is not None - assert engine is not None - prompt_text = final_output.prompt or engine.tokenizer.decode( + prompt = final_output.prompt or engine.tokenizer.decode( final_output.prompt_token_ids ) - assert prompt_text is not None + text_outputs = [output.text for output in final_output.outputs] finish_reasons = [output.finish_reason for output in final_output.outputs] - ret = {"text": text_outputs, "prompt": prompt_text, "finish_reasons": finish_reasons} + ret = {"text": text_outputs, "prompt": prompt, "finish_reasons": finish_reasons} - # Include logprobs if requested (useful for RL training) - # Format matches what atroposlib's VLLMServer expects if sampling_params.logprobs is not None: - output_logprobs = [] - for x in final_output.outputs: - if x.logprobs: - # Format: [[{token_id: logprob}, ...], ...] per output - seq_logprobs = [ - [{str(key): value.logprob for key, value in logprob.items()}] - for logprob in x.logprobs - ] - else: - seq_logprobs = [] - output_logprobs.append(seq_logprobs) - - prompt_token_ids = final_output.prompt_token_ids - output_token_ids = [list(x.token_ids) for x in final_output.outputs] + output_logprobs = [ + [[{key: value.logprob for key, value in logprob.items()}] + for logprob in x.logprobs] + for x in final_output.outputs + ] ret["logprobs"] = output_logprobs - ret["prompt_token_ids"] = list(prompt_token_ids) if prompt_token_ids else [] - ret["token_ids"] = output_token_ids + ret["prompt_token_ids"] = final_output.prompt_token_ids + ret["token_ids"] = [x.token_ids for x in final_output.outputs] return JSONResponse(ret) # ============================================================================= -# OpenAI-Compatible Completions Endpoint +# Bridge Endpoints (Weight Synchronization) # ============================================================================= -@app.post("/v1/completions") -async def openai_completions(request: Request) -> Response: - """ - OpenAI-compatible completions endpoint. - - This translates OpenAI API format to our internal format. 
- - Request JSON fields (OpenAI format): - - model: str - Model name (ignored, uses loaded model) - - prompt: str or List[str] - The input text(s) to complete - - max_tokens: int - Maximum tokens to generate - - temperature: float - Sampling temperature - - top_p: float - Nucleus sampling threshold - - n: int - Number of completions per prompt - - stream: bool - Whether to stream results - - logprobs: int - Number of logprobs to return - - echo: bool - Whether to echo the prompt - - stop: str or List[str] - Stop sequences - - Returns OpenAI-compatible response format. - """ - request_dict = await request.json() - - # Extract OpenAI-specific fields - prompt = request_dict.get("prompt", "") - model = request_dict.get("model", "") - max_tokens = request_dict.get("max_tokens", 16) - temperature = request_dict.get("temperature", 1.0) - top_p = request_dict.get("top_p", 1.0) - n = request_dict.get("n", 1) - stream = request_dict.get("stream", False) - logprobs_count = request_dict.get("logprobs") - echo = request_dict.get("echo", False) - stop = request_dict.get("stop") - - # Handle prompt as string or list - if isinstance(prompt, list): - # For simplicity, just use the first prompt - # Full implementation would handle batches - prompt = prompt[0] if prompt else "" - - # Build sampling params - sampling_kwargs = { - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - "n": n, - } - - if logprobs_count is not None: - sampling_kwargs["logprobs"] = logprobs_count - - if stop is not None: - if isinstance(stop, str): - stop = [stop] - sampling_kwargs["stop"] = stop - - sampling_params = SamplingParams(**sampling_kwargs) - request_id = random_uuid() - - # Handle both async and sync engines - if engine is not None: - sampling_params.output_kind = RequestOutputKind.FINAL_ONLY - results_generator = engine.generate(prompt, sampling_params, request_id) - - # Non-streaming response - final_output = None - try: - async for request_output in results_generator: - final_output = request_output - except asyncio.CancelledError: - return Response(status_code=499) - elif sync_engine is not None: - # CUDA IPC mode: use sync engine - import concurrent.futures - def _sync_generate(): - return sync_engine.generate([prompt], sampling_params) - loop = asyncio.get_event_loop() - with concurrent.futures.ThreadPoolExecutor() as pool: - outputs = await loop.run_in_executor(pool, _sync_generate) - final_output = outputs[0] if outputs else None - else: - raise HTTPException(status_code=503, detail="No engine available") - - if final_output is None: - return JSONResponse( - {"error": {"message": "No output generated", "type": "server_error"}}, - status_code=500, - ) - - # Build OpenAI-compatible response - choices = [] - for i, output in enumerate(final_output.outputs): - text = output.text - if echo: - text = prompt + text - - choice = { - "text": text, - "index": i, - "logprobs": None, - "finish_reason": output.finish_reason or "stop", - } - - # Add logprobs if requested - if logprobs_count is not None and output.logprobs: - choice["logprobs"] = { - "tokens": [ - list(lp.keys())[0] if lp else "" for lp in output.logprobs - ], - "token_logprobs": [ - list(lp.values())[0].logprob if lp else None - for lp in output.logprobs - ], - "top_logprobs": [ - {k: v.logprob for k, v in lp.items()} if lp else {} - for lp in output.logprobs - ], - "text_offset": [], # Not implemented - } - - choices.append(choice) - - response = { - "id": f"cmpl-{request_id}", - "object": "text_completion", - "created": 
int(asyncio.get_event_loop().time()), - "model": model or "vllm-model", - "choices": choices, - "usage": { - "prompt_tokens": len(final_output.prompt_token_ids) if final_output.prompt_token_ids else 0, - "completion_tokens": sum(len(o.token_ids) for o in final_output.outputs), - "total_tokens": (len(final_output.prompt_token_ids) if final_output.prompt_token_ids else 0) - + sum(len(o.token_ids) for o in final_output.outputs), - }, - } - - return JSONResponse(response) - - -@app.post("/v1/chat/completions") -async def openai_chat_completions(request: Request) -> Response: - """ - OpenAI-compatible chat completions endpoint. - - Request JSON fields: - - model: str - Model name (ignored, uses loaded model) - - messages: List[dict] - Chat messages with 'role' and 'content' - - max_tokens: int - Maximum tokens to generate - - temperature: float - Sampling temperature - - top_p: float - Nucleus sampling threshold - - n: int - Number of completions - - stream: bool - Whether to stream results - - stop: str or List[str] - Stop sequences - - Returns OpenAI-compatible chat completion response. - """ - request_dict = await request.json() - - # Extract fields - messages = request_dict.get("messages", []) - model = request_dict.get("model", "") - max_tokens = request_dict.get("max_tokens", 512) - temperature = request_dict.get("temperature", 1.0) - top_p = request_dict.get("top_p", 1.0) - n = request_dict.get("n", 1) - stream = request_dict.get("stream", False) - stop = request_dict.get("stop") - - # Convert messages to prompt using chat template - active_engine = get_engine() - - # Try to use the tokenizer's chat template - try: - if engine is not None: - tokenizer = engine.tokenizer.tokenizer - else: - tokenizer = sync_engine.get_tokenizer() - if hasattr(tokenizer, "apply_chat_template"): - prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - else: - # Fallback: simple concatenation - prompt = "" - for msg in messages: - role = msg.get("role", "user") - content = msg.get("content", "") - prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n" - prompt += "<|im_start|>assistant\n" - except Exception: - # Simple fallback - prompt = "\n".join( - f"{m.get('role', 'user')}: {m.get('content', '')}" for m in messages - ) - prompt += "\nassistant:" - - # Build sampling params - sampling_kwargs = { - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - "n": n, - } - - if stop is not None: - if isinstance(stop, str): - stop = [stop] - sampling_kwargs["stop"] = stop - - sampling_params = SamplingParams(**sampling_kwargs) - request_id = random_uuid() - - # Handle both async and sync engines - if engine is not None: - sampling_params.output_kind = RequestOutputKind.FINAL_ONLY - results_generator = engine.generate(prompt, sampling_params, request_id) - - # Non-streaming response - final_output = None - try: - async for request_output in results_generator: - final_output = request_output - except asyncio.CancelledError: - return Response(status_code=499) - elif sync_engine is not None: - # CUDA IPC mode: use sync engine - import concurrent.futures - def _sync_generate(): - return sync_engine.generate([prompt], sampling_params) - loop = asyncio.get_event_loop() - with concurrent.futures.ThreadPoolExecutor() as pool: - outputs = await loop.run_in_executor(pool, _sync_generate) - final_output = outputs[0] if outputs else None - else: - raise HTTPException(status_code=503, detail="No engine available") - - if final_output is None: - return 
JSONResponse( - {"error": {"message": "No output generated", "type": "server_error"}}, - status_code=500, - ) - - # Build OpenAI-compatible chat response - choices = [] - for i, output in enumerate(final_output.outputs): - choice = { - "index": i, - "message": { - "role": "assistant", - "content": output.text, - }, - "finish_reason": output.finish_reason or "stop", - } - choices.append(choice) - - prompt_tokens = len(final_output.prompt_token_ids) if final_output.prompt_token_ids else 0 - completion_tokens = sum(len(o.token_ids) for o in final_output.outputs) - - response = { - "id": f"chatcmpl-{request_id}", - "object": "chat.completion", - "created": int(time.time()), - "model": model or "vllm-model", - "choices": choices, - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens, - }, - } - - return JSONResponse(response) - - -@app.get("/v1/models") -async def list_models() -> JSONResponse: - """ - List available models (OpenAI-compatible). - - Returns the currently loaded model. - """ - active_engine = get_engine() - - if engine is not None: - model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown" - elif sync_engine is not None: - model_name = str(sync_engine.llm_engine.model_config.model) if hasattr(sync_engine, "llm_engine") else "unknown" - else: - model_name = "unknown" - +@app.get("/bridge/info") +async def bridge_info() -> JSONResponse: + """Get bridge status and configuration.""" + if engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + + model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown" + return JSONResponse({ - "object": "list", - "data": [ - { - "id": model_name, - "object": "model", - "created": int(time.time()), - "owned_by": "vllm", - "permission": [], - "root": model_name, - "parent": None, - } - ], + "enabled": bridge_state.enabled or PATCHES_APPLIED, + "shared_weights": PATCHES_APPLIED, + "update_count": bridge_state.update_count, + "last_update_time": bridge_state.last_update_time, + "rendezvous_info": bridge_state.rendezvous_info, + "model_name": model_name, + "device": "cuda" if torch.cuda.is_available() else "cpu", }) -@app.get("/v1/models/{model_id}") -async def get_model(model_id: str) -> JSONResponse: - """ - Get model info (OpenAI-compatible). - """ - active_engine = get_engine() - - if engine is not None: - model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown" - elif sync_engine is not None: - model_name = str(sync_engine.llm_engine.model_config.model) if hasattr(sync_engine, "llm_engine") else "unknown" - else: - model_name = "unknown" - - return JSONResponse({ - "id": model_name, - "object": "model", - "created": int(time.time()), - "owned_by": "vllm", - "permission": [], - "root": model_name, - "parent": None, - }) - - -# ============================================================================= -# Bridge Endpoints (for shared-weight training) -# ============================================================================= - - -@app.get("/bridge/info", response_model=BridgeInfoResponse) -async def bridge_info() -> BridgeInfoResponse: - """ - Get bridge status and rendezvous information. - - Trainers call this to discover how to connect to the weight-sharing - process group. Returns connection details and current sync state. 
- """ - active_engine = get_engine() - - if engine is not None: - model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown" - device = "unknown" # Can't easily get device from AsyncLLM - elif sync_engine is not None: - model_name = str(sync_engine.llm_engine.model_config.model) if hasattr(sync_engine, "llm_engine") else "unknown" - device = "cuda" # Sync engine is always on CUDA for IPC - else: - model_name = "unknown" - device = "unknown" - - return BridgeInfoResponse( - enabled=bridge_state.enabled, - update_count=bridge_state.update_count, - last_update_time=bridge_state.last_update_time, - rendezvous_info=bridge_state.rendezvous_info, - model_name=model_name, - device=device, - ) - - @app.post("/bridge/init") async def bridge_init(request: BridgeInitRequest) -> JSONResponse: - """ - Initialize the weight bridge for shared-memory training. - - This sets up the rendezvous information that trainers need to join - the same NCCL process group as this inference server. - - Called once when setting up a training run. - """ + """Initialize the weight bridge.""" with bridge_state.lock: bridge_state.enabled = True bridge_state.rendezvous_info = { @@ -764,625 +307,224 @@ async def bridge_init(request: BridgeInitRequest) -> JSONResponse: "master_port": request.master_port, "world_size": request.world_size, "trainer_ranks": request.trainer_ranks, - "initialized_at": time.time(), } - + logger.info(f"Bridge initialized: {bridge_state.rendezvous_info}") - return JSONResponse({"status": "ok", "rendezvous_info": bridge_state.rendezvous_info}) + + return JSONResponse({ + "status": "ok", + "message": "Weight bridge initialized", + "shared_weights_enabled": PATCHES_APPLIED, + }) @app.post("/bridge/notify_update") async def bridge_notify_update(notification: WeightUpdateNotification) -> JSONResponse: """ - Receive notification that trainer has updated weights. - - After optimizer.step(), the trainer calls this to signal that the - shared weights have been modified. The server can use this to: - - Log the update for debugging - - Invalidate any cached KV states if needed - - Track synchronization for metrics - - In shared-memory mode, the weights are already updated in-place, - so no data transfer happens here - this is just coordination. + Notification that trainer has updated weights. + + In shared memory mode (PATCHES_APPLIED=True), updates are automatic + via the NCCL daemon. This endpoint is for logging/coordination. """ with bridge_state.lock: bridge_state.update_count = notification.update_count bridge_state.last_update_time = notification.timestamp - - logger.info( - f"Weight update #{notification.update_count} from trainer {notification.trainer_rank}" - ) - + + if PATCHES_APPLIED: + logger.debug(f"Weight update #{notification.update_count} (shared memory)") + else: + logger.info(f"Weight update #{notification.update_count} (HTTP notification only)") + return JSONResponse({ "status": "ok", "update_count": bridge_state.update_count, - "server_time": time.time(), + "shared_weights": PATCHES_APPLIED, }) @app.get("/bridge/state_dict_info") async def bridge_state_dict_info() -> JSONResponse: - """ - Get information about the model's state dict for weight attachment. - - Returns parameter names, shapes, and dtypes so trainers can properly - map their tensors to the inference model's parameters. 
- """ - active_engine = get_engine() - + """Get model parameter information.""" + if engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + + # Basic model info try: - # Access the underlying model based on engine type - if sync_engine is not None: - # CUDA IPC mode: can access model directly - model = sync_engine.llm_engine.model_executor.driver_worker.model_runner.model - elif engine is not None: - # Async mode: model is in subprocess, can't access directly - return JSONResponse({ - "status": "unavailable", - "message": "Model state dict not accessible in async mode. Use CUDA IPC mode (--enable-cuda-ipc) for direct access.", - "num_parameters": 0, - "parameters": {}, - }) - else: - raise HTTPException(status_code=503, detail="No engine available") - - state_dict_info = {} - for name, param in model.named_parameters(): - state_dict_info[name] = { - "shape": list(param.shape), - "dtype": str(param.dtype), - "device": str(param.device), - "requires_grad": param.requires_grad, - } - + model_config = engine.model_config return JSONResponse({ - "status": "ok", - "num_parameters": len(state_dict_info), - "total_params": sum(p.numel() for p in model.parameters()), - "parameters": state_dict_info, + "model": str(model_config.model), + "dtype": str(model_config.dtype), + "shared_weights_enabled": PATCHES_APPLIED, }) - except Exception as e: - logger.error(f"Failed to get state dict info: {e}") - raise HTTPException(status_code=500, detail=str(e)) + return JSONResponse({"error": str(e)}) @app.post("/bridge/disable") async def bridge_disable() -> JSONResponse: - """ - Disable the weight bridge. - - Called when training ends or if the trainer disconnects. - """ + """Disable the weight bridge.""" with bridge_state.lock: bridge_state.enabled = False - bridge_state.rendezvous_info = {} - logger.info("Bridge disabled") return JSONResponse({"status": "ok"}) # ============================================================================= -# Weight Update Endpoints (Pause/Resume for Training) +# Pause/Resume Endpoints (for weight updates) # ============================================================================= @app.post("/bridge/pause") async def bridge_pause() -> JSONResponse: - """ - Pause generation to allow weight updates. - - This is vLLM's built-in mechanism for weight updates! - Waits for in-flight requests to finish, then pauses. - - Use this BEFORE updating weights from the trainer. - - NOTE: Only available with AsyncLLM (not CUDA IPC mode). - """ + """Pause generation to allow weight updates.""" if engine is None: - if sync_engine is not None: - return JSONResponse({ - "status": "not_supported", - "message": "Pause/resume not supported in CUDA IPC mode. Weights are shared directly.", - }) - raise HTTPException(status_code=503, detail="No engine available") + raise HTTPException(status_code=503, detail="Engine not initialized") try: - await engine.pause_generation( - wait_for_inflight_requests=True, - clear_cache=True, - ) - logger.info("Generation paused for weight updates") - - return JSONResponse({ - "status": "paused", - "message": "Ready for weight updates. 
Call /bridge/resume when done.", - }) + # vLLM v1 supports pause/resume + if hasattr(engine, '_pause_cond'): + async with engine._pause_cond: + engine._paused = True + logger.info("Engine paused") + return JSONResponse({"status": "paused"}) + else: + return JSONResponse({"status": "not_supported"}) except Exception as e: - logger.error(f"Failed to pause generation: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/bridge/resume") async def bridge_resume() -> JSONResponse: - """ - Resume generation after weight updates. - - Call this AFTER updating weights from the trainer. - - NOTE: Only available with AsyncLLM (not CUDA IPC mode). - """ + """Resume generation after weight updates.""" if engine is None: - if sync_engine is not None: - return JSONResponse({ - "status": "not_supported", - "message": "Pause/resume not supported in CUDA IPC mode. Weights are shared directly.", - }) - raise HTTPException(status_code=503, detail="No engine available") + raise HTTPException(status_code=503, detail="Engine not initialized") try: - await engine.resume_generation() - logger.info("Generation resumed after weight updates") - - return JSONResponse({ - "status": "resumed", - "message": "Generation resumed with updated weights.", - }) + if hasattr(engine, '_pause_cond'): + async with engine._pause_cond: + engine._paused = False + engine._pause_cond.notify_all() + logger.info("Engine resumed") + return JSONResponse({"status": "resumed"}) + else: + return JSONResponse({"status": "not_supported"}) except Exception as e: - logger.error(f"Failed to resume generation: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/bridge/is_paused") async def bridge_is_paused() -> JSONResponse: - """Check if generation is currently paused.""" + """Check if engine is paused.""" if engine is None: - if sync_engine is not None: - return JSONResponse({"paused": False, "mode": "cuda_ipc"}) - raise HTTPException(status_code=503, detail="No engine available") + raise HTTPException(status_code=503, detail="Engine not initialized") - paused = await engine.is_paused() + paused = getattr(engine, '_paused', False) return JSONResponse({"paused": paused}) +# ============================================================================= +# Sleep/Wake Endpoints (GPU memory management) +# ============================================================================= + + @app.post("/bridge/sleep") -async def bridge_sleep(level: int = 1) -> JSONResponse: - """ - Put the engine to sleep to free GPU memory. - - Level 1: Minimal sleep, fast wake up - Higher levels: Deeper sleep, frees more memory - - Use for memory-constrained environments. - - NOTE: Only available with AsyncLLM (not CUDA IPC mode). - """ +async def bridge_sleep() -> JSONResponse: + """Put engine to sleep to free GPU memory.""" if engine is None: - if sync_engine is not None: - return JSONResponse({ - "status": "not_supported", - "message": "Sleep/wake not supported in CUDA IPC mode.", - }) - raise HTTPException(status_code=503, detail="No engine available") + raise HTTPException(status_code=503, detail="Engine not initialized") try: - await engine.sleep(level=level) - logger.info(f"Engine put to sleep (level {level})") - - return JSONResponse({ - "status": "sleeping", - "level": level, - "message": "GPU memory freed. 
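Illustrative only (these endpoints appear above): an external coordinator can bracket a weight push with pause/resume, although in shared memory mode the NCCL daemon path makes this largely unnecessary. The sequence looks like:

```bash
curl -s -X POST http://localhost:9001/bridge/pause
# ... trainer performs its weight update / checkpoint sync here ...
curl -s -X POST http://localhost:9001/bridge/resume
curl -s http://localhost:9001/bridge/is_paused   # -> {"paused": false}
```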
 
+# =============================================================================
+# Sleep/Wake Endpoints (GPU memory management)
+# =============================================================================
+
+
 @app.post("/bridge/sleep")
-async def bridge_sleep(level: int = 1) -> JSONResponse:
-    """
-    Put the engine to sleep to free GPU memory.
-
-    Level 1: Minimal sleep, fast wake up
-    Higher levels: Deeper sleep, frees more memory
-
-    Use for memory-constrained environments.
-
-    NOTE: Only available with AsyncLLM (not CUDA IPC mode).
-    """
+async def bridge_sleep() -> JSONResponse:
+    """Put engine to sleep to free GPU memory."""
     if engine is None:
-        if sync_engine is not None:
-            return JSONResponse({
-                "status": "not_supported",
-                "message": "Sleep/wake not supported in CUDA IPC mode.",
-            })
-        raise HTTPException(status_code=503, detail="No engine available")
+        raise HTTPException(status_code=503, detail="Engine not initialized")
 
     try:
-        await engine.sleep(level=level)
-        logger.info(f"Engine put to sleep (level {level})")
-
-        return JSONResponse({
-            "status": "sleeping",
-            "level": level,
-            "message": "GPU memory freed. Call /bridge/wake_up to resume.",
-        })
+        await engine.sleep()
+        logger.info("Engine sleeping")
+        return JSONResponse({"status": "sleeping"})
     except Exception as e:
-        logger.error(f"Failed to sleep: {e}")
         raise HTTPException(status_code=500, detail=str(e))
 
 
 @app.post("/bridge/wake_up")
 async def bridge_wake_up() -> JSONResponse:
-    """
-    Wake up the engine from sleep.
-
-    Reloads the model into GPU memory.
-
-    NOTE: Only available with AsyncLLM (not CUDA IPC mode).
-    """
+    """Wake engine and reload model."""
     if engine is None:
-        if sync_engine is not None:
-            return JSONResponse({
-                "status": "not_supported",
-                "message": "Sleep/wake not supported in CUDA IPC mode.",
-            })
-        raise HTTPException(status_code=503, detail="No engine available")
+        raise HTTPException(status_code=503, detail="Engine not initialized")
 
     try:
         await engine.wake_up()
         logger.info("Engine woken up")
-
-        return JSONResponse({
-            "status": "awake",
-            "message": "Model reloaded into GPU memory.",
-        })
+        return JSONResponse({"status": "awake"})
     except Exception as e:
-        logger.error(f"Failed to wake up: {e}")
         raise HTTPException(status_code=500, detail=str(e))
 
 
 @app.get("/bridge/is_sleeping")
 async def bridge_is_sleeping() -> JSONResponse:
-    """Check if engine is currently sleeping."""
+    """Check if engine is sleeping."""
     if engine is None:
-        if sync_engine is not None:
-            return JSONResponse({"sleeping": False, "mode": "cuda_ipc"})
-        raise HTTPException(status_code=503, detail="No engine available")
+        raise HTTPException(status_code=503, detail="Engine not initialized")
 
     sleeping = await engine.is_sleeping()
     return JSONResponse({"sleeping": sleeping})
 
 
 # =============================================================================
-# RPC Endpoints (Call Worker Methods)
+# Debug Endpoints
 # =============================================================================
 
 
-class CollectiveRPCRequest(BaseModel):
-    """Request to call a method on all workers."""
-    method: str
-    timeout: Optional[float] = None
-    args: List[Any] = []
-    kwargs: Dict[str, Any] = {}
-
-
-@app.post("/bridge/collective_rpc")
-async def bridge_collective_rpc(request: CollectiveRPCRequest) -> JSONResponse:
-    """
-    Call a method on all workers via collective RPC.
-
-    The method must exist on the worker class.
-    This is an advanced endpoint for custom worker operations.
-
-    Example worker methods:
-    - 'save_model' - Save model weights
-    - 'get_model_info' - Get model information
-
-    Note: For AsyncLLM, the method name is passed as a STRING.
-    For sync LLM (CUDA IPC mode), use /bridge/export_cuda_ipc instead.
-    """
-    if engine is None:
-        if sync_engine is not None:
-            return JSONResponse({
-                "status": "not_supported",
-                "message": "Use /bridge/export_cuda_ipc for sync LLM collective operations.",
-            })
-        raise HTTPException(status_code=503, detail="No engine available")
-
-    try:
-        result = await engine.collective_rpc(
-            method=request.method,
-            timeout=request.timeout,
-            args=tuple(request.args),
-            kwargs=request.kwargs if request.kwargs else None,
-        )
-
-        logger.info(f"collective_rpc({request.method}) completed")
-
-        return JSONResponse({
-            "status": "ok",
-            "method": request.method,
-            "result": result if isinstance(result, (dict, list, str, int, float, bool, type(None))) else str(result),
-        })
-    except Exception as e:
-        logger.error(f"collective_rpc failed: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
-
-
-# =============================================================================
-# CUDA IPC Export (True Shared Memory)
-# =============================================================================
-
-
-def _export_cuda_ipc_handles_fn(worker_self) -> dict:
-    """
-    Worker-side function to export CUDA IPC handles.
-
-    This function runs INSIDE the vLLM worker process where the model lives.
-    The first argument 'worker_self' is the GPU worker instance.
-
-    Returns:
-        Dictionary with IPC handles for all model parameters.
-    """
-    model = worker_self.model_runner.model
-
-    ipc_handles = {}
-    failed_params = []
-
-    for name, param in model.named_parameters():
-        try:
-            if not param.is_cuda:
-                failed_params.append(f"{name}: not on CUDA")
-                continue
-
-            # Get the underlying storage and create IPC handle
-            storage = param.data.storage()
-            handle = storage._share_cuda_()
-
-            # Serialize the handle
-            handle_bytes = pickle.dumps(handle)
-            handle_b64 = base64.b64encode(handle_bytes).decode('ascii')
-
-            ipc_handles[name] = {
-                "ipc_handle": handle_b64,
-                "shape": list(param.shape),
-                "dtype": str(param.dtype),
-                "device_index": param.device.index if param.device.index is not None else 0,
-                "storage_offset": param.storage_offset(),
-                "numel": param.numel(),
-                "stride": list(param.stride()),
-            }
-        except Exception as e:
-            failed_params.append(f"{name}: {str(e)}")
-
-    return {
-        "handles": ipc_handles,
-        "failed": failed_params,
-        "model_class": model.__class__.__name__,
-        "num_params": len(list(model.parameters())),
-    }
-
-
-@app.post("/bridge/export_cuda_ipc")
-async def bridge_export_cuda_ipc() -> JSONResponse:
-    """
-    Export CUDA IPC handles for all model parameters.
-
-    This enables TRUE shared memory between vLLM and the trainer!
-    Both processes can access the SAME GPU tensors.
-
-    Uses sync LLM's collective_rpc which accepts functions.
-
-    REQUIREMENTS:
-    - Both processes must be on the SAME GPU
-    - vLLM must be started with --enable-cuda-ipc flag
-
-    Returns:
-        JSON with path to IPC handles file and parameter count.
-    """
-    global sync_engine
-
-    if sync_engine is None:
-        raise HTTPException(
-            status_code=503,
-            detail=(
-                "Sync LLM not initialized. Start server with --enable-cuda-ipc flag. "
-                "Note: CUDA IPC requires sync LLM which may reduce throughput."
-            )
-        )
-
-    try:
-        # Use sync LLM's collective_rpc with a FUNCTION (not a string!)
-        # This is the key difference from AsyncLLM
-        logger.info("Calling collective_rpc with function to export IPC handles...")
-
-        # Run in thread pool to avoid blocking
-        import concurrent.futures
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            future = executor.submit(
-                sync_engine.collective_rpc,
-                _export_cuda_ipc_handles_fn
-            )
-            results = future.result(timeout=60)
-
-        # collective_rpc returns a list (one result per worker)
-        result = results[0] if results else {}
-        ipc_handles = result.get("handles", {})
-        failed_params = result.get("failed", [])
-
-        if failed_params:
-            logger.warning(f"Could not export {len(failed_params)} parameters: {failed_params[:5]}...")
-
-        if len(ipc_handles) == 0:
-            raise HTTPException(status_code=500, detail="No IPC handles exported")
-
-        # Save to file for trainer to read
-        log_dir = os.environ.get("LOGDIR", ".")
-        ipc_path = Path(log_dir) / "cuda_ipc_handles.json"
-
-        with open(ipc_path, "w") as f:
-            json.dump({
-                "handles": ipc_handles,
-                "model_class": result.get("model_class", "unknown"),
-                "num_params": result.get("num_params", 0),
-                "device_count": torch.cuda.device_count(),
-                "export_time": time.time(),
-            }, f, indent=2)
-
-        logger.info(f"✓ Exported {len(ipc_handles)} CUDA IPC handles to {ipc_path}")
-
-        return JSONResponse({
-            "status": "ok",
-            "num_parameters": len(ipc_handles),
-            "failed_parameters": len(failed_params),
-            "ipc_path": str(ipc_path),
-            "total_elements": sum(info["numel"] for info in ipc_handles.values()),
-            "model_class": result.get("model_class", "unknown"),
-            "message": "IPC handles exported. Trainer can now attach to shared memory.",
-        })
-
-    except concurrent.futures.TimeoutError:
-        raise HTTPException(status_code=504, detail="collective_rpc timed out after 60s")
-    except Exception as e:
-        logger.error(f"Failed to export CUDA IPC handles: {e}")
-        import traceback
-        logger.error(traceback.format_exc())
-        raise HTTPException(status_code=500, detail=str(e))
-
-
-@app.get("/bridge/cuda_ipc_status")
-async def bridge_cuda_ipc_status() -> JSONResponse:
-    """
-    Check CUDA IPC status and whether shared memory is available.
-    """
-    log_dir = os.environ.get("LOGDIR", ".")
-    ipc_path = Path(log_dir) / "cuda_ipc_handles.json"
-
-    status = {
-        "sync_llm_available": SYNC_LLM_AVAILABLE,
-        "sync_engine_initialized": sync_engine is not None,
-        "ipc_handles_exported": ipc_path.exists(),
-        "ipc_path": str(ipc_path) if ipc_path.exists() else None,
-        "cuda_device_count": torch.cuda.device_count(),
-    }
-
-    if ipc_path.exists():
-        try:
-            with open(ipc_path) as f:
-                data = json.load(f)
-            status["num_parameters"] = len(data.get("handles", {}))
-            status["model_class"] = data.get("model_class")
-            status["export_time"] = data.get("export_time")
-        except Exception as e:
-            status["ipc_file_error"] = str(e)
-
-    return JSONResponse(status)
-
-
 @app.get("/bridge/debug")
 async def bridge_debug() -> JSONResponse:
-    """
-    Debug endpoint to inspect engine capabilities.
-
-    Lists available attributes and methods on the engine.
-    """
-    active_engine = get_engine()
-
+    """Debug endpoint to inspect engine state."""
    debug_info = {
-        "engine_type": type(active_engine).__name__,
-        "engine_mode": "async" if engine is not None else "sync_cuda_ipc",
+        "engine_type": type(engine).__name__ if engine else None,
        "vllm_version": VLLM_VERSION,
-        "model_config": {},
-        "available_methods": {},
-        "important_attributes": {},
+        "patches_applied": PATCHES_APPLIED,
+        "shared_weights_env": os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0"),
+        "num_inference_nodes": os.environ.get("NUM_INFERENCE_NODES", "unset"),
    }
 
-    # Get model config
     if engine is not None:
-        debug_info["model_config"] = {
-            "model": str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown",
-            "dtype": str(engine.model_config.dtype) if hasattr(engine, "model_config") else "unknown",
-        }
-    elif sync_engine is not None:
         try:
             debug_info["model_config"] = {
-                "model": str(sync_engine.llm_engine.model_config.model),
-                "dtype": str(sync_engine.llm_engine.model_config.dtype),
+                "model": str(engine.model_config.model),
+                "dtype": str(engine.model_config.dtype),
             }
         except Exception:
-            debug_info["model_config"] = {"model": "unknown", "dtype": "unknown"}
-
-    # Check for important methods
-    important_methods = [
-        "pause_generation", "resume_generation", "is_paused",
-        "sleep", "wake_up", "is_sleeping",
-        "collective_rpc", "add_lora", "remove_lora", "list_loras",
-        "generate", "encode", "abort", "check_health",
-    ]
-
-    for method in important_methods:
-        has_method = hasattr(active_engine, method) and callable(getattr(active_engine, method))
-        debug_info["available_methods"][method] = has_method
-
-    # Check important attributes
-    important_attrs = [
-        "engine_core", "model_config", "vllm_config",
-        "input_processor", "output_processor", "tokenizer",
-        "llm_engine",  # For sync LLM
-    ]
-
-    for attr in important_attrs:
-        if hasattr(active_engine, attr):
-            attr_val = getattr(active_engine, attr)
-            debug_info["important_attributes"][attr] = type(attr_val).__name__
-        else:
-            debug_info["important_attributes"][attr] = None
+            pass
 
     return JSONResponse(debug_info)
 
 
 @app.get("/bridge/list_endpoints")
-async def bridge_list_endpoints() -> JSONResponse:
-    """
-    List all available bridge endpoints with descriptions.
-
-    Use this to discover what capabilities are available.
-    """
-    endpoints = {
-        "health": {
-            "GET /health": "Basic health check",
-            "GET /health_generate": "Deep health check (sends test request)",
-        },
-        "generation": {
-            "POST /generate": "Generate text (vLLM native format)",
-            "POST /v1/completions": "Generate text (OpenAI format)",
-            "POST /v1/chat/completions": "Chat completion (OpenAI format)",
-        },
-        "bridge_control": {
-            "GET /bridge/info": "Get bridge status and rendezvous info",
-            "POST /bridge/init": "Initialize weight bridge for NCCL",
-            "POST /bridge/disable": "Disable weight bridge",
-            "GET /bridge/state_dict_info": "Get model parameter info",
-        },
-        "weight_updates": {
-            "POST /bridge/pause": "⭐ Pause generation for weight updates",
-            "POST /bridge/resume": "⭐ Resume generation after weight updates",
-            "GET /bridge/is_paused": "Check if paused",
-            "POST /bridge/notify_update": "Notify server of weight update",
-        },
-        "memory_management": {
-            "POST /bridge/sleep": "Put engine to sleep (free GPU memory)",
-            "POST /bridge/wake_up": "Wake engine up (reload model)",
-            "GET /bridge/is_sleeping": "Check if sleeping",
-        },
-        "lora_adapters": {
-            "GET /lora/status": "Get LoRA status",
-            "POST /lora/load": "Load LoRA adapter",
-            "POST /lora/unload": "Unload LoRA adapter",
-        },
-        "advanced": {
-            "POST /bridge/collective_rpc": "Call method on workers",
-            "GET /bridge/debug": "Debug engine structure",
-            "GET /bridge/list_endpoints": "This endpoint",
-        },
-    }
-
-    return JSONResponse(endpoints)
+async def list_endpoints() -> JSONResponse:
+    """List all available endpoints."""
+    endpoints = []
+    for route in app.routes:
+        if hasattr(route, "path") and hasattr(route, "methods"):
+            endpoints.append({
+                "path": route.path,
+                "methods": list(route.methods),
+            })
+    return JSONResponse({"endpoints": endpoints})
 
 
 # =============================================================================
-# LoRA Endpoints (for adapter hot-swapping)
+# LoRA Endpoints
 # =============================================================================
 
 
-@app.get("/lora/status", response_model=LoraStatusResponse)
+@app.get("/lora/status")
 async def lora_status() -> LoraStatusResponse:
-    """
-    Get current LoRA adapter status.
-
-    Returns which adapter is active (if any) and lists available adapters
-    in the configured adapter directory.
-    """
-    # List available adapters from save path
-    adapter_dir = os.environ.get("LORA_ADAPTER_DIR", "./adapters")
+    """Get LoRA adapter status."""
+    log_dir = os.environ.get("LOGDIR", ".")
     available = []
-    if os.path.isdir(adapter_dir):
-        for item in os.listdir(adapter_dir):
-            item_path = os.path.join(adapter_dir, item)
-            # Check if it looks like a PEFT adapter
+
+    if os.path.exists(log_dir):
+        for item in os.listdir(log_dir):
+            item_path = os.path.join(log_dir, item)
             if os.path.isdir(item_path) and os.path.exists(
                 os.path.join(item_path, "adapter_config.json")
             ):
                 available.append(item)
-
+
     return LoraStatusResponse(
         active_adapter=bridge_state.active_lora_path,
         load_count=bridge_state.lora_load_count,
@@ -1392,62 +534,32 @@
 
 @app.post("/lora/load")
 async def lora_load(request: LoraLoadRequest) -> JSONResponse:
-    """
-    Hot-swap a LoRA adapter without restarting the server.
-
-    The adapter is loaded from disk and merged with the base model weights.
-    This is much faster than restarting vLLM with a new checkpoint.
-
-    Note: This requires the PEFT library and a compatible vLLM version.
-    """
-    adapter_path = request.adapter_path
-
-    if not os.path.exists(adapter_path):
-        raise HTTPException(status_code=404, detail=f"Adapter not found: {adapter_path}")
-
-    if not os.path.exists(os.path.join(adapter_path, "adapter_config.json")):
-        raise HTTPException(
-            status_code=400, detail=f"Invalid adapter (missing adapter_config.json): {adapter_path}"
-        )
-
-    try:
-        # TODO: Implement actual LoRA loading for vLLM
-        # This depends on vLLM's LoRA support which varies by version
-        # For now, we track the state and log the request
-
-        with bridge_state.lock:
-            bridge_state.active_lora_path = adapter_path
-            bridge_state.lora_load_count += 1
-
-        logger.info(f"LoRA adapter loaded: {adapter_path}")
-
-        return JSONResponse({
-            "status": "ok",
-            "adapter_path": adapter_path,
-            "load_count": bridge_state.lora_load_count,
-            "message": "Adapter registered (actual loading depends on vLLM version)",
-        })
-
-    except Exception as e:
-        logger.error(f"Failed to load LoRA adapter: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
+    """Load a LoRA adapter."""
+    if not os.path.exists(request.adapter_path):
+        raise HTTPException(status_code=404, detail=f"Adapter not found: {request.adapter_path}")
+
+    # NOTE: this still only registers the adapter in bridge state; actual
+    # hot-loading into vLLM depends on the engine's LoRA support.
+    with bridge_state.lock:
+        bridge_state.active_lora_path = request.adapter_path
+        bridge_state.lora_load_count += 1
+
+    logger.info(f"LoRA adapter registered: {request.adapter_path}")
+
+    return JSONResponse({
+        "status": "ok",
+        "adapter_path": request.adapter_path,
+        "load_count": bridge_state.lora_load_count,
+    })
 
 
 @app.post("/lora/unload")
 async def lora_unload() -> JSONResponse:
-    """
-    Unload the current LoRA adapter, reverting to base model weights.
-    """
+    """Unload current LoRA adapter."""
     with bridge_state.lock:
-        prev_adapter = bridge_state.active_lora_path
+        prev = bridge_state.active_lora_path
         bridge_state.active_lora_path = None
-
-    logger.info(f"LoRA adapter unloaded: {prev_adapter}")
-
-    return JSONResponse({
-        "status": "ok",
-        "previous_adapter": prev_adapter,
-    })
+
+    logger.info(f"LoRA adapter unloaded: {prev}")
+    return JSONResponse({"status": "ok", "previous_adapter": prev})
 
 
 # =============================================================================
@@ -1456,214 +568,112 @@ def build_app(args: Namespace) -> FastAPI:
-    """Build the FastAPI application with configured root path."""
-    global app  # noqa: F824
+    """Build the FastAPI application."""
+    global app
     app.root_path = args.root_path
     return app
 
 
-async def init_app(
-    args: Namespace,
-    llm_engine: AsyncLLM | None = None,
-) -> FastAPI:
-    """
-    Initialize the application and vLLM engine.
-
-    Args:
-        args: Parsed command-line arguments
-        llm_engine: Optional pre-created engine (for testing)
-
-    Returns:
-        Configured FastAPI application
-    """
+async def init_app(args: Namespace, llm_engine: AsyncLLM | None = None) -> FastAPI:
+    """Initialize the application and vLLM engine."""
     app = build_app(args)
-
-    global engine, sync_engine
-    use_cuda_ipc = getattr(args, "enable_cuda_ipc", False)
+    global engine
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = (
+        llm_engine if llm_engine is not None
+        else AsyncLLM.from_engine_args(engine_args, usage_context=UsageContext.API_SERVER)
+    )
+    app.state.engine_client = engine
 
-    if use_cuda_ipc:
-        # CUDA IPC MODE: Use sync LLM only (model in same process)
-        # This allows function-based collective_rpc for IPC handle export
-        if not SYNC_LLM_AVAILABLE:
-            raise RuntimeError("CUDA IPC requested but vllm.LLM not available")
-
-        logger.info("=" * 60)
-        logger.info("CUDA IPC MODE: Using sync LLM for true shared memory")
-        logger.info("=" * 60)
-
-        sync_engine = SyncLLM(
-            model=args.model,
-            dtype=getattr(args, "dtype", "auto"),
-            gpu_memory_utilization=getattr(args, "gpu_memory_utilization", 0.9),
-            tensor_parallel_size=getattr(args, "tensor_parallel_size", 1),
-            trust_remote_code=getattr(args, "trust_remote_code", False),
-        )
-        engine = None  # No async engine in CUDA IPC mode
-        logger.info("✓ Sync LLM ready for CUDA IPC")
-
-    else:
-        # STANDARD MODE: Use AsyncLLM (model in subprocess)
-        engine_args = AsyncEngineArgs.from_cli_args(args)
-        engine = (
-            llm_engine
-            if llm_engine is not None
-            else AsyncLLM.from_engine_args(
-                engine_args, usage_context=UsageContext.API_SERVER
-            )
-        )
-        sync_engine = None
-
-    app.state.engine_client = engine or sync_engine
-
     # Export state dict info for trainers
     _export_state_dict_info(args)
-
+
     return app
 
 
 def _export_state_dict_info(args: Namespace) -> None:
-    """
-    Export model parameter mapping to JSON for trainer attachment.
-
-    This writes a file that trainers can read to understand how to
-    map their parameters to the inference model's parameters.
-    """
+    """Export model parameter mapping to JSON for trainer."""
     log_dir = os.environ.get("LOGDIR", ".")
     json_path = Path(log_dir) / "vllm_bridge_config.json"
-
+
     try:
-        # Basic info - actual param mappings added when bridge is initialized
         info = {
             "model": getattr(args, "model", "unknown"),
             "dtype": getattr(args, "dtype", "auto"),
             "tp_degree": getattr(args, "tensor_parallel_size", 1),
-            "dp_shard_degree": 1,  # Data parallel sharding
+            "dp_shard_degree": 1,
             "param_mappings": {},
+            "shared_weights_enabled": PATCHES_APPLIED,
         }
-
+
         with open(json_path, "w") as f:
             json.dump(info, f, indent=2)
-
+
         logger.info(f"Exported state dict info to {json_path}")
-
     except Exception as e:
         logger.warning(f"Failed to export state dict info: {e}")
- - Args: - args: Parsed command-line arguments - llm_engine: Optional pre-created engine (for testing) - - Returns: - Configured FastAPI application - """ +async def init_app(args: Namespace, llm_engine: AsyncLLM | None = None) -> FastAPI: + """Initialize the application and vLLM engine.""" app = build_app(args) - - global engine, sync_engine - use_cuda_ipc = getattr(args, "enable_cuda_ipc", False) + global engine + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = ( + llm_engine if llm_engine is not None + else AsyncLLM.from_engine_args(engine_args, usage_context=UsageContext.API_SERVER) + ) + app.state.engine_client = engine - if use_cuda_ipc: - # CUDA IPC MODE: Use sync LLM only (model in same process) - # This allows function-based collective_rpc for IPC handle export - if not SYNC_LLM_AVAILABLE: - raise RuntimeError("CUDA IPC requested but vllm.LLM not available") - - logger.info("=" * 60) - logger.info("CUDA IPC MODE: Using sync LLM for true shared memory") - logger.info("=" * 60) - - sync_engine = SyncLLM( - model=args.model, - dtype=getattr(args, "dtype", "auto"), - gpu_memory_utilization=getattr(args, "gpu_memory_utilization", 0.9), - tensor_parallel_size=getattr(args, "tensor_parallel_size", 1), - trust_remote_code=getattr(args, "trust_remote_code", False), - ) - engine = None # No async engine in CUDA IPC mode - logger.info("✓ Sync LLM ready for CUDA IPC") - - else: - # STANDARD MODE: Use AsyncLLM (model in subprocess) - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = ( - llm_engine - if llm_engine is not None - else AsyncLLM.from_engine_args( - engine_args, usage_context=UsageContext.API_SERVER - ) - ) - sync_engine = None - - app.state.engine_client = engine or sync_engine - # Export state dict info for trainers _export_state_dict_info(args) - + return app def _export_state_dict_info(args: Namespace) -> None: - """ - Export model parameter mapping to JSON for trainer attachment. - - This writes a file that trainers can read to understand how to - map their parameters to the inference model's parameters. - """ + """Export model parameter mapping to JSON for trainer.""" log_dir = os.environ.get("LOGDIR", ".") json_path = Path(log_dir) / "vllm_bridge_config.json" - + try: - # Basic info - actual param mappings added when bridge is initialized info = { "model": getattr(args, "model", "unknown"), "dtype": getattr(args, "dtype", "auto"), "tp_degree": getattr(args, "tensor_parallel_size", 1), - "dp_shard_degree": 1, # Data parallel sharding + "dp_shard_degree": 1, "param_mappings": {}, + "shared_weights_enabled": PATCHES_APPLIED, } - + with open(json_path, "w") as f: json.dump(info, f, indent=2) - + logger.info(f"Exported state dict info to {json_path}") - except Exception as e: logger.warning(f"Failed to export state dict info: {e}") -async def run_server( - args: Namespace, llm_engine: AsyncLLM | None = None, **uvicorn_kwargs: Any -) -> None: - """ - Run the vLLM API server. - - This is the main entry point that starts the HTTP server and - serves requests until shutdown. 
- """ +async def run_server(args: Namespace, llm_engine: AsyncLLM | None = None, **uvicorn_kwargs: Any) -> None: + """Run the vLLM API server.""" logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - + + if PATCHES_APPLIED: + logger.info("=" * 60) + logger.info("SHARED MEMORY MODE ENABLED") + logger.info("Weight updates from trainer will be reflected immediately!") + logger.info("=" * 60) + set_ulimit() app = await init_app(args, llm_engine) - # Verify at least one engine is initialized - if engine is None and sync_engine is None: + if engine is None: raise RuntimeError("No engine initialized") - - # Log bridge endpoints + + # Log available endpoints logger.info("=" * 60) - logger.info("Bridge endpoints available:") - logger.info("-" * 60) - logger.info("Weight Updates (use these for training!):") - logger.info(" POST /bridge/pause - Pause generation for weight updates") - logger.info(" POST /bridge/resume - Resume after updating weights") - logger.info(" GET /bridge/is_paused - Check pause state") - logger.info("-" * 60) - logger.info("Memory Management:") - logger.info(" POST /bridge/sleep - Free GPU memory") - logger.info(" POST /bridge/wake_up - Reload model") - logger.info("-" * 60) - logger.info("LoRA Adapters:") - logger.info(" GET /lora/status - Get adapter status") - logger.info(" POST /lora/load - Load adapter") - logger.info(" POST /lora/unload - Unload adapter") - logger.info("-" * 60) - logger.info("Debug:") - logger.info(" GET /bridge/debug - Inspect engine") - logger.info(" GET /bridge/list_endpoints - List all endpoints") - logger.info(" POST /bridge/collective_rpc - Call worker methods") + logger.info("Available endpoints:") + logger.info(" POST /generate - Generate completions") + logger.info(" GET /bridge/info - Bridge status") + logger.info(" POST /bridge/pause - Pause generation") + logger.info(" POST /bridge/resume - Resume generation") + logger.info(" GET /lora/status - LoRA adapter status") logger.info("=" * 60) - + shutdown_task = await serve_http( app, sock=None, - enable_ssl_refresh=args.enable_ssl_refresh, + enable_ssl_refresh=getattr(args, 'enable_ssl_refresh', False), host=args.host, port=args.port, - log_level=args.log_level, + log_level=getattr(args, 'log_level', 'info'), timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, - ssl_ca_certs=args.ssl_ca_certs, - ssl_cert_reqs=args.ssl_cert_reqs, + ssl_keyfile=getattr(args, 'ssl_keyfile', None), + ssl_certfile=getattr(args, 'ssl_certfile', None), + ssl_ca_certs=getattr(args, 'ssl_ca_certs', None), + ssl_cert_reqs=getattr(args, 'ssl_cert_reqs', ssl.CERT_NONE), **uvicorn_kwargs, ) - + await shutdown_task -# ============================================================================= -# CLI Entry Point -# ============================================================================= - - if __name__ == "__main__": parser = FlexibleArgumentParser() - - # Server configuration parser.add_argument("--host", type=str, default=None) - parser.add_argument("--port", type=parser.check_port, default=8000) - parser.add_argument("--log-level", type=str, default="debug") - - # SSL configuration + parser.add_argument("--port", type=int, default=9001) parser.add_argument("--ssl-keyfile", type=str, default=None) parser.add_argument("--ssl-certfile", type=str, default=None) - parser.add_argument( - "--ssl-ca-certs", type=str, default=None, help="The CA certificates file" - ) - parser.add_argument( - "--enable-ssl-refresh", - 
action="store_true", - default=False, - help="Refresh SSL Context when SSL certificate files change", - ) - parser.add_argument( - "--ssl-cert-reqs", - type=int, - default=int(ssl.CERT_NONE), - help="Whether client certificate is required (see stdlib ssl module's)", - ) - parser.add_argument( - "--root-path", - type=str, - default=None, - help="FastAPI root_path when app is behind a path based routing proxy", - ) + parser.add_argument("--ssl-ca-certs", type=str, default=None) + parser.add_argument("--enable-ssl-refresh", action="store_true", default=False) + parser.add_argument("--ssl-cert-reqs", type=int, default=int(ssl.CERT_NONE)) + parser.add_argument("--root-path", type=str, default=None) + parser.add_argument("--log-level", type=str, default="info") - # CUDA IPC for true shared memory - parser.add_argument( - "--enable-cuda-ipc", - action="store_true", - default=False, - help=( - "Enable CUDA IPC for true shared memory with trainer. " - "Requires trainer to be on the same GPU. " - "This initializes a sync LLM alongside the async engine." - ), - ) - - # Add vLLM engine arguments + # Add vLLM engine args parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() - asyncio.run(run_server(args)) diff --git a/example_trainer/vllm_patching/__init__.py b/example_trainer/vllm_patching/__init__.py new file mode 100644 index 00000000..83545151 --- /dev/null +++ b/example_trainer/vllm_patching/__init__.py @@ -0,0 +1,37 @@ +""" +vLLM Patching Module - Enables shared memory weight updates. + +This module patches vLLM's GPUModelRunner to: +1. Call share_memory_() on model weights after loading +2. Spawn a daemon process that receives NCCL weight updates from trainers +3. Enable real-time weight synchronization without restarting vLLM + +Usage: + # Import this BEFORE importing vllm + from example_trainer.vllm_patching import apply_patches + apply_patches() + + # Then import vllm normally + from vllm import AsyncLLM +""" + +from .patched_gpu_runner import PatchedGPUModelRunner, apply_patches +from .weight_updater import weight_updater_process +from .distributed_utils import ( + init_process_group, + broadcast_object_list, + get_inference_urls, + get_json_data, +) + +__all__ = [ + "PatchedGPUModelRunner", + "apply_patches", + "weight_updater_process", + "init_process_group", + "broadcast_object_list", + "get_inference_urls", + "get_json_data", +] + + diff --git a/example_trainer/vllm_patching/distributed_utils.py b/example_trainer/vllm_patching/distributed_utils.py new file mode 100644 index 00000000..41cfca26 --- /dev/null +++ b/example_trainer/vllm_patching/distributed_utils.py @@ -0,0 +1,328 @@ +""" +Distributed utilities for vLLM weight synchronization. + +Provides process group initialization and communication helpers +for coordinating weight updates between trainer and vLLM. +""" + +from __future__ import annotations + +import json +import os +import socket +import time +from collections import defaultdict +from datetime import timedelta +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist + + +def init_process_group( + backend: Optional[str] = None, + init_method: Optional[str] = None, + timeout: Optional[timedelta] = None, + world_size: int = -1, + rank: int = -1, + store: Optional[Any] = None, + group_name: str = "", + pg_options: Optional[Any] = None, +) -> dist.ProcessGroup: + """ + Initialize a custom process group for weight synchronization. 
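+
+    Example (illustrative sketch only; the address, world_size, and rank
+    values are placeholders and must match on every participating process):
+
+        pg = init_process_group(
+            backend="nccl",
+            init_method="tcp://localhost:26756",
+            world_size=2,
+            rank=0,
+            group_name="weight_update_group",
+        )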
+ + This creates a named process group that coexists with vLLM's internal + process groups, enabling direct tensor communication between trainer + and inference processes. + + Args: + backend: "nccl" for GPU, "gloo" for CPU + init_method: Rendezvous URL (e.g., "tcp://host:port") + timeout: How long to wait for other ranks + world_size: Total number of processes + rank: This process's rank + store: Optional torch.distributed Store + group_name: Name for this process group (must match across ranks) + pg_options: Backend-specific options + + Returns: + ProcessGroup for collective operations + """ + from torch.distributed.distributed_c10d import ( + _new_process_group_helper, + _world, + Backend, + default_pg_timeout, + PrefixStore, + rendezvous, + ) + + assert (store is None) or (init_method is None), \ + "Cannot specify both init_method and store." + + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + if backend: + backend = Backend(backend) + else: + backend = Backend("undefined") + + if timeout is None: + timeout = default_pg_timeout + + # Create store via rendezvous if not provided + if store is None: + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + store = PrefixStore(group_name, store) + + # Handle PyTorch version differences for pg_options parameter + pg_options_param_name = ( + "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" + ) + + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + ) + + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + + return pg + + +def broadcast_object_list( + object_list: List[Any], + src: Optional[int] = None, + group: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + group_src: Optional[int] = None, +) -> None: + """ + Broadcast a list of objects from source rank to all other ranks. + + Modified from torch.distributed.broadcast_object_list to work correctly + with custom process groups where rank 0 may not be the default group's rank 0. 
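+
+    Receiver-side sketch (the matching send is assumed to happen on the
+    trainer side; ``pg`` is a group created via init_process_group):
+
+        payload = [None]
+        broadcast_object_list(payload, group_src=0, group=pg,
+                              device=torch.device("cuda"))
+        # payload[0] now holds the object broadcast by the source rank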
+ + Args: + object_list: List of objects to broadcast (modified in-place on receivers) + src: Global source rank (deprecated, use group_src) + group: Process group to use + device: Device for temporary tensors + group_src: Source rank within the group + """ + global_src = group_src if group_src is not None else src + current_device = device + + # Broadcast object sizes first + object_sizes_tensor = torch.empty( + len(object_list), dtype=torch.long, device=current_device + ) + dist.broadcast(object_sizes_tensor, src=global_src, group=group) + + # Broadcast serialized objects + object_tensor = torch.empty( + torch.sum(object_sizes_tensor).item(), + dtype=torch.uint8, + device=current_device, + ) + dist.broadcast(object_tensor, src=global_src, group=group) + + # Deserialize objects + offset = 0 + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset : offset + obj_size] + obj_view = obj_view.type(torch.uint8) + offset += obj_size + object_list[i] = dist.distributed_c10d._tensor_to_object( + obj_view, obj_size, group + ) + + +def get_inference_urls(num_inference_nodes: int = 0) -> Tuple[Optional[str], ...]: + """ + Get URLs for inference server communication. + + Parses SLURM environment or uses localhost for single-machine setup. + + Args: + num_inference_nodes: Number of dedicated inference nodes. + 0 = single machine, trainer and vLLM share the node + >0 = multi-node, last N nodes are for inference + + Returns: + Tuple of (master_addr, master_gloo_addr, master_inference_addr, nodelist) + Returns (None, None, None, None) if not in a valid setup. + """ + if num_inference_nodes > 0: + # Multi-node SLURM setup + slurm_nodelist = os.environ.get("SLURM_JOB_NODELIST") + if not slurm_nodelist: + return None, None, None, None + + # Parse SLURM node list + nodelist = ( + os.popen(f'scontrol show hostnames {slurm_nodelist}') + .read() + .strip() + .split("\n") + ) + nodelist = [node for node in nodelist if node] + + # First node is master for process groups + master_server = f"{nodelist[0]}:26756" + master_gloo_server = f"{nodelist[0]}:26757" + + # Last N nodes are inference nodes + inference_nodes = nodelist[-num_inference_nodes:] + master_inference_server = f"{inference_nodes[0]}:26758" + + return master_server, master_gloo_server, master_inference_server, inference_nodes + + elif num_inference_nodes == 0: + # Single machine setup + master_server = "localhost:26756" + master_gloo_server = "localhost:26757" + master_inference_server = "localhost:26758" + nodelist = ["localhost"] + + return master_server, master_gloo_server, master_inference_server, nodelist + + else: + return None, None, None, None + + +def get_hostnames() -> Optional[List[str]]: + """ + Get the hostnames for this machine. + + Parses /etc/hosts to find all hostnames associated with this machine's IP. + + Returns: + List of [ip, hostname1, hostname2, ...] or None if not found. + """ + my_ip = socket.gethostbyname(socket.gethostname()) + my_hostname = socket.gethostname() + + try: + with open("/etc/hosts", "r") as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): + parts = line.split() + if len(parts) >= 2 and ((parts[0] == my_ip) or (my_hostname in parts)): + ip = parts[0] + if ip.startswith("127."): + continue + return parts + except Exception: + pass + + return None + + +def get_json_data(log_dir: Optional[str] = None, timeout: int = 300) -> Dict[str, Any]: + """ + Load the bridge configuration JSON from vLLM. 
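+
+    Example (sketch, assuming LOGDIR points at the directory vLLM logs to):
+
+        cfg = get_json_data(timeout=60)
+        param_names = sorted(cfg["param_mappings"].keys())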
+ + Waits for the file to be created by vLLM's weight bridge setup. + + Args: + log_dir: Directory containing the JSON file (defaults to LOGDIR env var) + timeout: Maximum seconds to wait for file + + Returns: + Parsed JSON data with parameter mappings and configuration. + + Raises: + ValueError: If LOGDIR not set and log_dir not provided + FileNotFoundError: If file not found after timeout + """ + if log_dir is None: + log_dir = os.environ.get("LOGDIR") + if log_dir is None: + raise ValueError("LOGDIR environment variable not set and log_dir not provided") + + json_path = os.path.join(log_dir, "vllm_bridge_config.json") + + wait_time = 0 + while not os.path.exists(json_path): + if wait_time >= timeout: + raise FileNotFoundError(f"Config file not found after {timeout}s: {json_path}") + if wait_time % 10 == 0: + print(f"[Updater] Waiting for {json_path}...", flush=True) + time.sleep(1) + wait_time += 1 + + # Wait a moment for file to finish writing + time.sleep(0.5) + + with open(json_path, "r") as f: + return json.load(f) + + +def get_name_conversions(param_mappings: Dict[str, Any]) -> Dict[str, List[str]]: + """ + Build reverse mapping from vLLM names to trainer names. + + Args: + param_mappings: Dict mapping trainer param names to vLLM info + + Returns: + Dict mapping vLLM names to list of trainer names + """ + name_conversions = defaultdict(list) + for name, info in param_mappings.items(): + vllm_name = info.get("vllm_name", name) + name_conversions[vllm_name].append(name) + return name_conversions + + +# Permutation functions for rotary embeddings +def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor: + """ + Permute weight tensor for sliced rotary embeddings. + + Args: + w: Weight tensor of shape [dim1, dim2] + n_heads: Number of attention heads + + Returns: + Permuted tensor for rotary embedding compatibility + """ + dim1 = w.shape[0] + dim2 = w.shape[1] + return ( + w.view(n_heads, dim1 // n_heads // 2, 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + ) + + +def permute_1d(w: torch.Tensor, n_heads: int) -> torch.Tensor: + """ + Permute 1D weight tensor (bias) for sliced rotary embeddings. + + Args: + w: Weight tensor of shape [dim1] + n_heads: Number of attention heads + + Returns: + Permuted tensor + """ + dim1 = w.shape[0] + return w.view(n_heads, dim1 // n_heads // 2, 2).transpose(1, 2).reshape(dim1) + + diff --git a/example_trainer/vllm_patching/weight_updater.py b/example_trainer/vllm_patching/weight_updater.py new file mode 100644 index 00000000..da5e2bfe --- /dev/null +++ b/example_trainer/vllm_patching/weight_updater.py @@ -0,0 +1,425 @@ +""" +Weight Updater Process - Daemon that receives NCCL weight updates. + +This process runs as a daemon spawned by the patched vLLM GPUModelRunner. +It joins NCCL process groups with the trainer and receives weight updates, +copying them directly into vLLM's shared memory tensors. +""" + +from __future__ import annotations + +import json +import os +import time +from typing import Any, Dict, List, Optional + +import torch +import torch.distributed as dist + +from .distributed_utils import ( + init_process_group, + get_inference_urls, + get_hostnames, + get_json_data, + get_name_conversions, + permute, + permute_1d, +) + + +def weight_updater_process( + state_dict: Dict[str, torch.Tensor], + num_q_heads: int, + num_kv_heads: int, + tp_rank: int, + tp_size: int, + gpu_id: int, +) -> None: + """ + Daemon process that receives weight updates from trainers via NCCL. + + This runs inside a subprocess spawned by PatchedGPUModelRunner. 
It:
+    1. Joins NCCL/Gloo process groups with the trainer
+    2. Receives weight update broadcasts from rank 0 (trainer)
+    3. Copies updated weights directly into the shared state_dict
+
+    Since state_dict tensors have share_memory_() called on them, the main
+    vLLM process immediately sees the updates for inference.
+
+    Args:
+        state_dict: Model state dict with shared memory tensors
+        num_q_heads: Number of query attention heads (for permutation)
+        num_kv_heads: Number of key/value attention heads
+        tp_rank: Tensor parallel rank of this worker
+        tp_size: Total tensor parallel size
+        gpu_id: GPU device ID for this worker
+    """
+    # Configuration from environment
+    num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", 0))
+    cuda_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "0")).split(",")
+    debug = int(os.environ.get("WEIGHT_UPDATER_DEBUG", 0))
+
+    # Determine world size based on setup
+    if num_inference_nodes > 0:
+        # Multi-node: 8 GPUs per node
+        world_size = num_inference_nodes * 8
+        ranks_per_node = 8
+    else:
+        # Single node: typically 4 inference GPUs
+        world_size = 4
+        ranks_per_node = 4
+
+    # Get network info
+    hostnames = get_hostnames()
+    master_addr, master_gloo_addr, master_inference_addr, urls = get_inference_urls(
+        num_inference_nodes
+    )
+
+    if master_addr is None:
+        print("[Updater] Master address not found, exiting", flush=True)
+        return
+
+    # Set CUDA device
+    torch.cuda.set_device(tp_rank)
+
+    print(
+        f"[Updater] Starting on TP rank {tp_rank}/{tp_size}, "
+        f"q_heads={num_q_heads}, kv_heads={num_kv_heads}, gpu_id={gpu_id}",
+        flush=True,
+    )
+    print(f"[Updater] Master: {master_addr}, world_size={world_size}", flush=True)
+
+    # Determine this worker's rank within the inference group
+    rank = -1
+    if num_inference_nodes == 0:
+        # Single node: skip the leading GPUs, which belong to the trainer
+        rank = int(cuda_devices[gpu_id]) - (8 - ranks_per_node)
+    else:
+        # Multi-node: find which inference node we're on
+        for i, url in enumerate(urls):
+            if hostnames and url in hostnames:
+                rank = ranks_per_node * i + int(cuda_devices[gpu_id])
+                break
+
+    if rank < 0:
+        print("[Updater] Could not determine rank, exiting", flush=True)
+        return
+
+    # Load config from vLLM
+    print("[Updater] Loading bridge config...", flush=True)
+    try:
+        json_data = get_json_data()
+    except Exception as e:
+        print(f"[Updater] Failed to load config: {e}", flush=True)
+        return
+
+    param_name_list = sorted(json_data.get("param_mappings", {}).keys())
+    num_training_gpus = json_data.get("dp_shard_degree", 1) * json_data.get("tp_degree", 1)
+    total_group_size = num_training_gpus + world_size
+
+    # Offset rank by training GPUs
+    rank = rank + num_training_gpus
+
+    print(f"[Updater] Total group size: {total_group_size}", flush=True)
+    print(f"[Updater] Training GPUs: {num_training_gpus}", flush=True)
+    print(f"[Updater] My rank: {rank}", flush=True)
+
+    # Initialize process groups
+    print("[Updater] Creating process groups...", flush=True)
+
+    try:
+        # Gloo group for coordination
+        gloo_group = init_process_group(
+            backend="gloo",
+            init_method=f"tcp://{master_addr}",
+            world_size=total_group_size,
+            rank=rank,
+            group_name="gloo_group",
+        )
+        print("[Updater] ✓ Gloo group created", flush=True)
+
+        # NCCL group for tensor transfers
+        nccl_group = init_process_group(
+            backend="nccl",
+            init_method=f"tcp://{master_addr}",
+            world_size=total_group_size,
+            rank=rank,
+            group_name="weight_update_group",
+        )
+        print("[Updater] ✓ NCCL group created", flush=True)
+
+    except Exception as e:
+        print(f"[Updater] Failed to create process groups: {e}", flush=True)
+        return
+
+    # Get device for tensors
+    my_device = next(iter(state_dict.values())).device
+
+    # Write dtype mapping if rank 0
+    if rank == num_training_gpus:  # First inference rank
+        _write_dtype_mapping(state_dict, json_data)
+
+    print("[Updater] Entering update loop...", flush=True)
+
+    # Buffers for merged QKV and gate_up projections
+    qkv_buffer = {}
+    gate_up_buffer = {}
+    qkv_bias_buffer = {}
+    w1w3_buffer = {}
+
+    with torch.no_grad():
+        while True:
+            try:
+                # Receive parameter index from trainer (rank 0)
+                obj_indx = torch.zeros(1, dtype=torch.long, device=my_device)
+                dist.broadcast(obj_indx, src=0, group=nccl_group)
+
+                tt_indx = obj_indx.item()
+
+                # -1 signals no update this round (heartbeat)
+                if tt_indx == -1:
+                    continue
+
+                # Get parameter info
+                if tt_indx >= len(param_name_list):
+                    print(f"[Updater] Invalid index {tt_indx}, skipping", flush=True)
+                    continue
+
+                tt_name = param_name_list[tt_indx]
+                param_info = json_data["param_mappings"].get(tt_name, {})
+                vllm_name = param_info.get("vllm_name", tt_name)
+                local_shape = param_info.get("local_shape", [])
+
+                if vllm_name not in state_dict:
+                    if debug:
+                        print(f"[Updater] {vllm_name} not in state_dict, skipping", flush=True)
+                    continue
+
+                target_dtype = state_dict[vllm_name].dtype
+
+                if debug:
+                    print(
+                        f"[Updater] Receiving {tt_name} -> {vllm_name}, "
+                        f"shape={local_shape}, dtype={target_dtype}",
+                        flush=True,
+                    )
+
+                # Gather tensors from all training ranks
+                tensor_list = [
+                    torch.zeros(
+                        local_shape if idx < num_training_gpus else [1],
+                        dtype=target_dtype,
+                        device=my_device,
+                    )
+                    for idx in range(total_group_size)
+                ]
+
+                dist.all_gather(
+                    tensor_list,
+                    torch.zeros(1, dtype=target_dtype, device=my_device),
+                    group=nccl_group,
+                )
+
+                # Only keep training tensors
+                tensor_list = tensor_list[:num_training_gpus]
+
+                # Merge tensors from different parallel configurations
+                tensor = _merge_tensors(
+                    tensor_list,
+                    json_data,
+                    param_info,
+                    state_dict[vllm_name],
+                )
+
+                # Apply updates (handling merged QKV, gate_up, etc.)
+                _apply_weight_update(
+                    state_dict,
+                    vllm_name,
+                    tt_name,
+                    tensor,
+                    param_info,
+                    num_q_heads,
+                    num_kv_heads,
+                    qkv_buffer,
+                    gate_up_buffer,
+                    qkv_bias_buffer,
+                    w1w3_buffer,
+                    debug,
+                )
+
+            except Exception as e:
+                print(f"[Updater] Error in update loop: {e}", flush=True)
+                import traceback
+                traceback.print_exc()
+                time.sleep(1)
+
+
+def _write_dtype_mapping(
+    state_dict: Dict[str, torch.Tensor],
+    json_data: Dict[str, Any],
+) -> None:
+    """Write dtype mapping file for trainer reference."""
+    try:
+        log_dir = os.environ.get("LOGDIR", ".")
+        name_conversions = get_name_conversions(json_data.get("param_mappings", {}))
+
+        weight_dtypes = {}
+        for name in state_dict.keys():
+            tt_names = name_conversions.get(name, [name])
+            for tt_name in tt_names:
+                weight_dtypes[tt_name] = str(state_dict[name].dtype).split(".")[-1]
+
+        with open(f"{log_dir}/vllm_dtypes.json", "w") as f:
+            json.dump(weight_dtypes, f, indent=2)
+
+        print("[Updater] Wrote dtype mapping", flush=True)
+    except Exception as e:
+        print(f"[Updater] Failed to write dtype mapping: {e}", flush=True)
+
+
+def _merge_tensors(
+    tensor_list: List[torch.Tensor],
+    json_data: Dict[str, Any],
+    param_info: Dict[str, Any],
+    target_tensor: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Merge tensors from distributed training into single tensor.
+
+    Handles FSDP (data parallel) and TP (tensor parallel) sharding.
+    """
+    dp_shard_degree = json_data.get("dp_shard_degree", 1)
+    tp_degree = json_data.get("tp_degree", 1)
+    tp_shard_dim = param_info.get("tp_shard_dim", 0)
+
+    if dp_shard_degree > 1:
+        # First merge across data parallel dimension
+        tp_tensors = []
+        for i in range(tp_degree):
+            dp_tensors = tensor_list[i::tp_degree]
+            tp_tensors.append(torch.cat(dp_tensors, dim=0))
+
+        # Then merge across tensor parallel dimension if needed
+        if tp_degree > 1:
+            if tp_tensors[0].shape == target_tensor.shape:
+                tensor = tp_tensors[0].contiguous()
+            else:
+                tensor = torch.cat(tp_tensors, dim=tp_shard_dim).contiguous()
+        else:
+            tensor = tp_tensors[0].contiguous()
+    else:
+        # No FSDP, just merge TP shards
+        tensor = torch.cat(tensor_list, dim=tp_shard_dim).contiguous()
+
+    # Cast to target dtype if needed
+    if tensor.dtype != target_tensor.dtype:
+        tensor = tensor.to(target_tensor.dtype)
+
+    return tensor
+
+
+def _apply_weight_update(
+    state_dict: Dict[str, torch.Tensor],
+    vllm_name: str,
+    tt_name: str,
+    tensor: torch.Tensor,
+    param_info: Dict[str, Any],
+    num_q_heads: int,
+    num_kv_heads: int,
+    qkv_buffer: Dict[str, torch.Tensor],
+    gate_up_buffer: Dict[str, torch.Tensor],
+    qkv_bias_buffer: Dict[str, torch.Tensor],
+    w1w3_buffer: Dict[str, torch.Tensor],
+    debug: bool,
+) -> None:
+    """
+    Apply weight update to state_dict, handling merged projections.
+
+    vLLM often merges QKV projections and gate/up projections into single
+    tensors for efficiency. This handles unpacking and merging correctly.
+    """
+    needs_permute = param_info.get("needs_permute", False)
+    shape = param_info.get("shape", list(tensor.shape))
+
+    def _debug_diff(name: str, old: torch.Tensor, new: torch.Tensor) -> None:
+        if debug:
+            diff = (new.float() - old.float()).abs()
+            print(
+                f"[WEIGHT DIFF] {name}: mean={diff.mean().item():.6e}, "
+                f"std={diff.std().item():.6e}",
+                flush=True,
+            )
+
+    # Handle merged QKV projection weights
+    if "qkv_proj.weight" in vllm_name:
+        key_val = "q" if ".wq." in tt_name or "q_proj" in tt_name else \
+            "v" if ".wv." in tt_name or "v_proj" in tt_name else "k"
+
+        if key_val == "q" and needs_permute:
+            tensor = permute(tensor, num_q_heads)
+        elif key_val == "k" and needs_permute:
+            tensor = permute(tensor, num_kv_heads)
+
+        qkv_buffer[key_val] = tensor
+
+        if len(qkv_buffer) == 3:
+            merged = torch.cat([qkv_buffer["q"], qkv_buffer["k"], qkv_buffer["v"]], dim=0)
+            _debug_diff(vllm_name, state_dict[vllm_name].data, merged)
+            state_dict[vllm_name].data.copy_(merged.contiguous())
+            qkv_buffer.clear()
+
+    # Handle merged gate/up projection weights
+    elif "gate_up_proj.weight" in vllm_name:
+        key_val = "w1" if ".w1." in tt_name or "gate_proj" in tt_name else "w3"
+        gate_up_buffer[key_val] = tensor
+
+        if len(gate_up_buffer) == 2:
+            merged = torch.cat([gate_up_buffer["w1"], gate_up_buffer["w3"]], dim=0)
+            _debug_diff(vllm_name, state_dict[vllm_name].data, merged)
+            state_dict[vllm_name].data.copy_(merged.contiguous())
+            gate_up_buffer.clear()
+
+    # Handle merged w1/w3 weights (alternative naming)
+    elif "w13_weight" in vllm_name:
+        key_val = "w1" if ".w1" in tt_name else "w3"
+        w1w3_buffer[key_val] = tensor
+
+        if len(w1w3_buffer) == 2:
+            merged = torch.cat([w1w3_buffer["w1"], w1w3_buffer["w3"]], dim=1)
+            _debug_diff(vllm_name, state_dict[vllm_name].data, merged)
+            state_dict[vllm_name].data.copy_(merged.contiguous())
+            w1w3_buffer.clear()
+
+    # Handle merged QKV bias (same name matching as the weight case above)
+    elif "qkv_proj.bias" in vllm_name:
+        key_val = "q" if ".wq." in tt_name or "q_proj" in tt_name else \
+            "v" if ".wv." in tt_name or "v_proj" in tt_name else "k"
+
+        if key_val == "q" and needs_permute:
+            tensor = permute_1d(tensor, num_q_heads)
+        elif key_val == "k" and needs_permute:
+            tensor = permute_1d(tensor, num_kv_heads)
+
+        qkv_bias_buffer[key_val] = tensor
+
+        if len(qkv_bias_buffer) == 3:
+            merged = torch.cat([qkv_bias_buffer["q"], qkv_bias_buffer["k"], qkv_bias_buffer["v"]], dim=0)
+            _debug_diff(vllm_name, state_dict[vllm_name].data, merged)
+            state_dict[vllm_name].data.copy_(merged.contiguous())
+            qkv_bias_buffer.clear()
+
+    # Handle regular weights (possibly needing permutation)
+    elif needs_permute:
+        if len(shape) == 2:
+            tensor = permute(tensor, shape[0]).contiguous()
+        elif len(shape) == 1:
+            tensor = permute_1d(tensor, shape[0]).contiguous()
+
+        _debug_diff(vllm_name, state_dict[vllm_name].data, tensor)
+        state_dict[vllm_name].data.copy_(tensor)
+
+    # Simple weight copy
+    else:
+        _debug_diff(vllm_name, state_dict[vllm_name].data, tensor)
+        state_dict[vllm_name].data.copy_(tensor)
+
+
diff --git a/example_trainer/vllm_weight_bridge.py b/example_trainer/vllm_weight_bridge.py
index 695398d0..6ea2c512 100644
--- a/example_trainer/vllm_weight_bridge.py
+++ b/example_trainer/vllm_weight_bridge.py
@@ -1,35 +1,44 @@
 """
-vLLM Weight Bridge - Integration between trainer and vLLM inference.
+vLLM Weight Bridge - Trainer-side integration for shared memory weight updates.
 
-This module provides two modes for coordinating weight updates:
+This module coordinates weight updates between the trainer and vLLM inference.
 
-LOCAL MODE (num_inference_nodes=0):
-    - Trainer and vLLM run as separate processes on the same machine
-    - Communication via HTTP to vLLM's /bridge/* endpoints
-    - No NCCL process groups needed
-    - Simpler setup, suitable for single-machine training
+ARCHITECTURE:
+    The patched vLLM server (using vllm_patching/) runs a daemon process that:
+    1. Joins NCCL process groups with the trainer
+    2. Receives weight updates via all_gather
+    3. 
Copies updates into vLLM's shared memory tensors + + ┌─────────────────────────────────────────────────────────────────────────┐ + │ SHARED MEMORY (via share_memory_()) │ + │ ┌─────────────────────────────────────────────────────────────────┐ │ + │ │ Model Weights │ │ + │ │ (accessible from MULTIPLE processes) │ │ + │ └─────────────────────────────────────────────────────────────────┘ │ + │ ▲ ▲ │ + │ │ Reads │ Writes │ + │ ┌────────┴────────┐ ┌───────────┴───────────┐ │ + │ │ vLLM Worker │ │ weight_updater │ │ + │ │ (inference) │ │ daemon process │ │ + │ └─────────────────┘ └───────────┬───────────┘ │ + │ │ NCCL │ + │ ▼ │ + │ ┌─────────────────────┐ │ + │ │ Trainer Process │ │ + │ │ (this bridge) │ │ + │ └─────────────────────┘ │ + └─────────────────────────────────────────────────────────────────────────┘ -DISTRIBUTED MODE (num_inference_nodes>0): - - Trainer and vLLM join the same NCCL process group - - Direct tensor sharing via shared GPU memory - - Lower latency, but requires coordinated setup - -Architecture (Local Mode): - ┌─────────────────┐ ┌─────────────────┐ - │ Trainer Process │ HTTP │ vLLM Process │ - │ (training) │────────▶│ (inference) │ - └─────────────────┘ └─────────────────┘ - -Architecture (Distributed Mode): - ┌─────────────────────────────────────────┐ - │ Shared GPU Memory (NCCL) │ - │ Model weights owned by vLLM process │ - └─────────────────────────────────────────┘ - ▲ ▲ - │ forward pass │ optimizer.step() - ┌───────┴───────┐ ┌───────┴───────┐ - │ vLLM Process │ │Trainer Process│ - └───────────────┘ └───────────────┘ +MODES: + LOCAL MODE (num_inference_nodes=0): + - Single machine setup + - Trainer and vLLM share the same node + - NCCL for weight broadcast to vLLM's daemon + + DISTRIBUTED MODE (num_inference_nodes>0): + - Multi-node setup with dedicated inference nodes + - Last N nodes run vLLM inference + - NCCL spans across nodes for weight updates """ from __future__ import annotations @@ -42,16 +51,15 @@ from collections import defaultdict from dataclasses import dataclass, field from datetime import timedelta from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch import torch.distributed as dist from torch import nn -from transformers import AutoConfig, AutoModelForCausalLM # ============================================================================= -# Process Group Initialization Helpers +# Process Group Initialization # ============================================================================= @@ -67,22 +75,8 @@ def init_process_group( ) -> dist.ProcessGroup: """ Initialize a custom process group for weight synchronization. - - This is based on torch.distributed internals but allows creating a named - group that coexists with the default process group (used by vLLM internally). - - Args: - backend: "nccl" for GPU, "gloo" for CPU - init_method: Rendezvous URL (e.g., "tcp://host:port" or "env://") - timeout: How long to wait for other ranks - world_size: Total number of processes in the group - rank: This process's rank in the group - store: Optional torch.distributed Store object - group_name: Name for this process group (must match across all ranks) - pg_options: Backend-specific options - - Returns: - A ProcessGroup object for collective operations + + Creates a named group that coexists with vLLM's internal process groups. 
""" from torch.distributed.distributed_c10d import ( _new_process_group_helper, @@ -93,9 +87,8 @@ def init_process_group( rendezvous, ) - assert (store is None) or ( - init_method is None - ), "Cannot specify both init_method and store." + assert (store is None) or (init_method is None), \ + "Cannot specify both init_method and store." if store is not None: assert world_size > 0, "world_size must be positive if using store" @@ -103,23 +96,16 @@ def init_process_group( elif init_method is None: init_method = "env://" - if backend: - backend = Backend(backend) - else: - backend = Backend("undefined") + backend = Backend(backend) if backend else Backend("undefined") + timeout = timeout or default_pg_timeout - if timeout is None: - timeout = default_pg_timeout - - # Rendezvous with other processes if store is None: rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) store, rank, world_size = next(rendezvous_iterator) store.set_timeout(timeout) - # Use a PrefixStore to avoid key collisions with other groups store = PrefixStore(group_name, store) - # PyTorch 2.6+ renamed pg_options to backend_options + # Handle PyTorch version differences pg_options_param_name = ( "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" ) @@ -139,156 +125,37 @@ def init_process_group( return pg -def broadcast_object_list( - object_list: List[Any], - src: int, - group: dist.ProcessGroup, - device: Optional[torch.device] = None, -) -> None: +def get_inference_urls(num_inference_nodes: int = 0) -> Tuple[Optional[str], ...]: """ - Broadcast a list of picklable objects from src rank to all other ranks. - - This is a simplified version of torch.distributed.broadcast_object_list - that works correctly with custom process groups. - - Args: - object_list: List of objects to broadcast (modified in-place on receivers) - src: Source rank that has the data - group: Process group to use - device: Device for intermediate tensors - """ - current_device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Broadcast object sizes first - object_sizes_tensor = torch.empty( - len(object_list), dtype=torch.long, device=current_device - ) - dist.broadcast(object_sizes_tensor, src=src, group=group) - - # Broadcast serialized objects - object_tensor = torch.empty( - torch.sum(object_sizes_tensor).item(), - dtype=torch.uint8, - device=current_device, - ) - dist.broadcast(object_tensor, src=src, group=group) - - # Deserialize on receiving ranks - offset = 0 - for i, obj_size in enumerate(object_sizes_tensor): - obj_view = object_tensor[offset : offset + obj_size] - obj_view = obj_view.type(torch.uint8) - offset += obj_size - object_list[i] = dist._tensor_to_object(obj_view, obj_size, group) - - -# ============================================================================= -# Environment and URL Helpers -# ============================================================================= - - -def get_inference_urls(num_inference_nodes: int) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[List[str]]]: - """ - Get rendezvous URLs for connecting to inference nodes. - - In SLURM environments, parses SLURM_JOB_NODELIST to find inference servers. - For local testing, returns localhost URLs. - - Args: - num_inference_nodes: Number of inference nodes (from config) - + Get URLs for inference server communication. + Returns: - Tuple of (master_server, master_gloo_server, master_inference_server, nodelist) - All None if inference nodes not configured. 
+ Tuple of (master_addr, master_gloo_addr, master_inference_addr, nodelist) """ if num_inference_nodes > 0: - # Multi-node SLURM environment - nodelist_raw = os.popen( - f'scontrol show hostnames {os.environ.get("SLURM_JOB_NODELIST", "")}' - ).read() - nodelist = [n for n in nodelist_raw.split("\n") if n] - - if not nodelist: + slurm_nodelist = os.environ.get("SLURM_JOB_NODELIST") + if not slurm_nodelist: return None, None, None, None - + + nodelist = ( + os.popen(f'scontrol show hostnames {slurm_nodelist}') + .read().strip().split("\n") + ) + nodelist = [n for n in nodelist if n] + master_server = f"{nodelist[0]}:26756" master_gloo_server = f"{nodelist[0]}:26757" - # Inference nodes are the last N nodes inference_nodes = nodelist[-num_inference_nodes:] master_inference_server = f"{inference_nodes[0]}:26758" - + return master_server, master_gloo_server, master_inference_server, inference_nodes - + elif num_inference_nodes == 0: - # Single-node local mode return "localhost:26756", "localhost:26757", "localhost:26758", ["localhost"] - else: return None, None, None, None -def get_local_hostname() -> Optional[List[str]]: - """Get the local hostname(s) from /etc/hosts for rank determination.""" - my_ip = socket.gethostbyname(socket.gethostname()) - my_hostname = socket.gethostname() - - try: - with open("/etc/hosts", "r") as f: - for line in f: - line = line.strip() - if line and not line.startswith("#"): - parts = line.split() - if len(parts) >= 2 and (parts[0] == my_ip or my_hostname in parts): - ip = parts[0] - if ip.startswith("127."): - continue - return parts - except FileNotFoundError: - pass - - return [my_ip, my_hostname] - - -# ============================================================================= -# Tensor Mapping and Permutation Helpers -# ============================================================================= - - -def permute_for_rotary(w: torch.Tensor, n_heads: int) -> torch.Tensor: - """ - Permute weight tensor for sliced rotary embeddings. - - vLLM and some model implementations use different layouts for Q/K projections. - This converts between them. - """ - dim1, dim2 = w.shape[0], w.shape[1] - return ( - w.view(n_heads, dim1 // n_heads // 2, 2, dim2) - .transpose(1, 2) - .reshape(dim1, dim2) - ) - - -def permute_for_rotary_1d(w: torch.Tensor, n_heads: int) -> torch.Tensor: - """Permute 1D tensor (bias) for sliced rotary embeddings.""" - dim1 = w.shape[0] - return w.view(n_heads, dim1 // n_heads // 2, 2).transpose(1, 2).reshape(dim1) - - -def get_name_conversions(param_mappings: Dict[str, Any]) -> Dict[str, List[str]]: - """ - Build a mapping from vLLM parameter names to trainer parameter names. - - vLLM may split or combine parameters differently than HuggingFace models. - This helps translate between naming conventions. 
- """ - name_conversions = defaultdict(list) - for name, info in param_mappings.items(): - vllm_name = info.get("vllm_name", name) - name_conversions[vllm_name].append(name) - return dict(name_conversions) - - # ============================================================================= # Bridge Configuration # ============================================================================= @@ -297,536 +164,321 @@ def get_name_conversions(param_mappings: Dict[str, Any]) -> Dict[str, List[str]] @dataclass class BridgeConfig: """Configuration for the vLLM weight bridge.""" - + # Process group settings trainer_rank: int = 0 world_size: int = 1 init_method: str = "env://" num_inference_nodes: int = 0 - + # Model settings model_name: str = "" device: str = "cuda" - + # Synchronization settings timeout_seconds: float = 300.0 log_dir: Optional[str] = None - - # vLLM server URL for HTTP-based sync (local mode) + + # vLLM server URL for HTTP-based sync (fallback) vllm_api_url: str = "http://localhost:9001" - # CUDA IPC mode: share GPU memory directly with vLLM (same GPU only!) - use_cuda_ipc: bool = False - # Derived from environment num_gpus_per_node: int = field(default_factory=lambda: torch.cuda.device_count()) - + @property def is_local_mode(self) -> bool: - """ - Local mode: single machine, no NCCL process groups needed. - Communication happens via HTTP to vLLM server. - """ + """Local mode: single machine, uses NCCL to daemon on same node.""" return self.num_inference_nodes == 0 - + + @property + def uses_nccl(self) -> bool: + """Whether NCCL is used for weight synchronization.""" + return self.num_inference_nodes >= 0 + @classmethod def from_training_config(cls, config: Any) -> "BridgeConfig": """Create BridgeConfig from a TrainingConfig object.""" return cls( - trainer_rank=config.trainer_rank, - world_size=config.world_size, - init_method=config.init_method, - num_inference_nodes=config.num_inference_nodes, + trainer_rank=getattr(config, 'trainer_rank', 0), + world_size=getattr(config, 'world_size', 1), + init_method=getattr(config, 'init_method', 'env://'), + num_inference_nodes=getattr(config, 'num_inference_nodes', 0), model_name=config.model_name, device=config.device, log_dir=os.environ.get("LOGDIR"), vllm_api_url=f"http://localhost:{getattr(config, 'vllm_port', 9001)}", - use_cuda_ipc=getattr(config, 'use_cuda_ipc', False), ) # ============================================================================= -# Main Bridge Class +# Weight Bridge Class # ============================================================================= class VLLMWeightBridge: """ - Bridge for sharing model weights between trainer and vLLM inference server. - - This class handles: - 1. Joining the distributed process group with vLLM workers - 2. Attaching to vLLM's model weight tensors - 3. Providing a model interface for the trainer to optimize - 4. Synchronizing updates so vLLM sees changes immediately - + Bridge for synchronizing model weights between trainer and vLLM. + + This class: + 1. Initializes NCCL process groups with vLLM's weight updater daemon + 2. Broadcasts weight updates after each optimizer.step() + 3. 
Ensures vLLM immediately uses updated weights for inference + Usage: bridge = VLLMWeightBridge(config) bridge.initialize() - model = bridge.get_trainable_model() - optimizer = AdamW(model.parameters(), lr=1e-5) - + for batch in data: loss = compute_loss(model, batch) loss.backward() optimizer.step() - bridge.notify_update() # vLLM now uses new weights + bridge.broadcast_weights(model) # vLLM now uses new weights """ - + def __init__(self, config: BridgeConfig): self.config = config self.device = torch.device(config.device) - - # Process groups (initialized in initialize()) + + # Process groups self.nccl_group: Optional[dist.ProcessGroup] = None self.gloo_group: Optional[dist.ProcessGroup] = None - + # Parameter mappings (loaded from vLLM's JSON) self.param_mappings: Dict[str, Any] = {} - self.name_conversions: Dict[str, List[str]] = {} - - # Shared tensors (attached in attach_to_vllm_weights()) - self.shared_state_dict: Dict[str, torch.Tensor] = {} - - # Model for training (created in get_trainable_model()) - self._model: Optional[nn.Module] = None - - # Synchronization state - self._update_count: int = 0 + self.param_name_list: List[str] = [] + + # State self._initialized: bool = False - + self._update_count: int = 0 + + # Derived config + self._num_training_gpus: int = 0 + self._total_group_size: int = 0 + def initialize(self) -> None: """ - Initialize the bridge: join process groups and load parameter mappings. - - In local mode (num_inference_nodes=0), skips NCCL setup and uses HTTP. - In distributed mode, creates NCCL/Gloo process groups. - - This must be called before any other methods. + Initialize the bridge: create process groups and load mappings. + + Must be called before any weight synchronization. """ if self._initialized: return - - print(f"[Bridge] Initializing weight bridge for rank {self.config.trainer_rank}") - - if self.config.is_local_mode: - self._initialize_local_mode() + + print(f"[Bridge] Initializing weight bridge (rank {self.config.trainer_rank})") + + if self.config.uses_nccl: + self._initialize_nccl_mode() else: - self._initialize_distributed_mode() - + self._initialize_http_mode() + self._initialized = True - - def _initialize_local_mode(self) -> None: - """ - Initialize for local single-machine mode. - - In local mode: - - No NCCL process groups (trainer and vLLM are separate processes) - - Communication via HTTP to vLLM's bridge endpoints - - Trainer loads its own model copy, OR uses CUDA IPC for true shared memory - """ - if self.config.use_cuda_ipc: - print("[Bridge] Using CUDA IPC MODE (true shared GPU memory)") - else: - print("[Bridge] Using LOCAL MODE (HTTP-based sync, no NCCL)") + + def _initialize_nccl_mode(self) -> None: + """Initialize NCCL-based weight synchronization.""" + print("[Bridge] Using NCCL mode for weight synchronization") + + # Get rendezvous URLs + master_addr, master_gloo_addr, _, nodelist = get_inference_urls( + self.config.num_inference_nodes + ) + + if master_addr is None: + raise RuntimeError( + "Could not determine inference URLs. " + "Set NUM_INFERENCE_NODES environment variable." 
+            )
+
+        print(f"[Bridge] Master address: {master_addr}")
+        print(f"[Bridge] Inference nodes: {nodelist}")
+
+        # Load parameter mappings from vLLM
+        self._load_param_mappings()
+
+        # Calculate group sizes.
+        # Multi-node: assume 8 GPUs per node (mirrors weight_updater_process).
+        # Single node: by default the last 4 GPUs are reserved for inference.
+        if self.config.num_inference_nodes > 0:
+            self._num_training_gpus = self.config.world_size * 8
+            num_inference_gpus = self.config.num_inference_nodes * 8
+        else:
+            num_inference_gpus = 4
+            self._num_training_gpus = torch.cuda.device_count() - num_inference_gpus
+
+        self._total_group_size = self._num_training_gpus + num_inference_gpus
+
+        print(f"[Bridge] Training GPUs: {self._num_training_gpus}")
+        print(f"[Bridge] Inference GPUs: {num_inference_gpus}")
+        print(f"[Bridge] Total group size: {self._total_group_size}")
+
+        # Create Gloo group (for coordination)
+        print("[Bridge] Creating Gloo process group...")
+        self.gloo_group = init_process_group(
+            backend="gloo",
+            init_method=f"tcp://{master_addr}",
+            world_size=self._total_group_size,
+            rank=self.config.trainer_rank,
+            group_name="gloo_group",
+        )
+        print("[Bridge] ✓ Gloo group created")
+
+        # Create NCCL group (for tensor transfers)
+        print("[Bridge] Creating NCCL process group...")
+        self.nccl_group = init_process_group(
+            backend="nccl",
+            init_method=f"tcp://{master_addr}",
+            world_size=self._total_group_size,
+            rank=self.config.trainer_rank,
+            group_name="weight_update_group",
+        )
+        print("[Bridge] ✓ NCCL group created")
+
+    def _initialize_http_mode(self) -> None:
+        """Initialize HTTP-based weight synchronization (fallback)."""
+        print("[Bridge] Using HTTP mode for weight synchronization")
         print(f"[Bridge] vLLM API URL: {self.config.vllm_api_url}")
-        
+
         # Verify vLLM server is reachable
         try:
             import requests
             response = requests.get(f"{self.config.vllm_api_url}/health", timeout=5)
             if response.status_code == 200:
-                print("[Bridge] vLLM server is reachable")
+                print("[Bridge] ✓ vLLM server is reachable")
             else:
                 print(f"[Bridge] Warning: vLLM health check returned {response.status_code}")
         except Exception as e:
-            print(f"[Bridge] Warning: Could not reach vLLM server: {e}")
-            print("[Bridge] Training will continue, but vLLM sync may not work")
-
-        # For CUDA IPC mode, request vLLM to export IPC handles
-        if self.config.use_cuda_ipc:
-            self._request_cuda_ipc_export()
-            self._load_cuda_ipc_handles()
-
-        # Load parameter mappings if available (optional in local mode)
-        try:
-            self._load_param_mappings()
-        except RuntimeError:
-            print("[Bridge] Parameter mapping file not found (optional in local mode)")
-            self.param_mappings = {}
+            print(f"[Bridge] Warning: Could not reach vLLM: {e}")
 
-    def _request_cuda_ipc_export(self) -> None:
-        """Request vLLM to export CUDA IPC handles."""
-        import requests
-
-        print("[Bridge] Requesting CUDA IPC handles from vLLM...")
-        try:
-            response = requests.post(
-                f"{self.config.vllm_api_url}/bridge/export_cuda_ipc",
-                timeout=60
-            )
-            if response.status_code == 200:
-                result = response.json()
-                print(f"[Bridge] vLLM exported {result.get('num_parameters', 0)} IPC handles")
-            else:
-                raise RuntimeError(f"Failed to export IPC handles: {response.status_code}")
-        except Exception as e:
-            raise RuntimeError(f"Could not request CUDA IPC export: {e}")
-
-    def _load_cuda_ipc_handles(self) -> None:
-        """
-        Load CUDA IPC handles from file and reconstruct shared tensors. 
- - This is the key to TRUE shared memory - the tensors we create here - point to the SAME GPU memory that vLLM is using! - """ - import base64 - import pickle - - log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".") - ipc_path = Path(log_dir) / "cuda_ipc_handles.json" - - # Wait for file to be created - wait_time = 0 - while not ipc_path.exists() and wait_time < self.config.timeout_seconds: - print(f"[Bridge] Waiting for {ipc_path}...") - time.sleep(1) - wait_time += 1 - - if not ipc_path.exists(): - raise RuntimeError(f"CUDA IPC handles file not found: {ipc_path}") - - with open(ipc_path, "r") as f: - data = json.load(f) - - handles_data = data.get("handles", {}) - - print(f"[Bridge] Reconstructing {len(handles_data)} shared tensors from IPC handles...") - - reconstructed = 0 - for name, info in handles_data.items(): - try: - # Decode the IPC handle - handle_bytes = base64.b64decode(info["ipc_handle"]) - handle = pickle.loads(handle_bytes) - - # Reconstruct the storage from the IPC handle - # This does NOT allocate new memory - it maps to existing memory! - device = torch.device(f"cuda:{info['device_index']}") - - # Get dtype - dtype_str = info["dtype"] - dtype = getattr(torch, dtype_str.replace("torch.", "")) - - # Reconstruct tensor from IPC handle - # The storage is shared with vLLM's process - storage = torch.cuda.Storage._new_shared_cuda(*handle) - - # Create tensor view of the shared storage - tensor = torch.tensor([], dtype=dtype, device=device) - tensor.set_( - storage, - info["storage_offset"], - info["shape"], - info["stride"] - ) - - # Store in shared_state_dict - self.shared_state_dict[name] = tensor - reconstructed += 1 - - except Exception as e: - print(f"[Bridge] Warning: Could not reconstruct {name}: {e}") - continue - - print(f"[Bridge] Successfully reconstructed {reconstructed} shared tensors") - print(f"[Bridge] Memory savings: ~{reconstructed * 4 / 1024:.1f} GB (no model copy needed!)") - - def _initialize_distributed_mode(self) -> None: - """ - Initialize for distributed multi-node mode. - - Creates NCCL and Gloo process groups for direct tensor sharing. - """ - print("[Bridge] Using DISTRIBUTED MODE (NCCL tensor sharing)") - - # Get rendezvous URLs - master_addr, master_gloo_addr, master_inference_addr, nodelist = get_inference_urls( - self.config.num_inference_nodes - ) - - if master_addr is None: - raise RuntimeError( - "Could not determine inference server URLs. " - "Set NUM_INFERENCE_NODES environment variable or check SLURM_JOB_NODELIST." 
- ) - - print(f"[Bridge] Master address: {master_addr}") - print(f"[Bridge] Inference nodes: {nodelist}") - - # Load parameter mappings from vLLM - self._load_param_mappings() - - # Calculate total group size (trainers + inference workers) - num_training_gpus = self._get_num_training_gpus() - # In distributed mode, each inference node contributes num_gpus_per_node workers - num_inference_gpus = self.config.num_inference_nodes * self.config.num_gpus_per_node - - total_group_size = num_training_gpus + num_inference_gpus - trainer_rank_in_group = self.config.trainer_rank - - print(f"[Bridge] Training GPUs: {num_training_gpus}, Inference GPUs: {num_inference_gpus}") - print(f"[Bridge] Total group size: {total_group_size}, Trainer rank: {trainer_rank_in_group}") - - # Initialize NCCL group for tensor transfers - self.nccl_group = init_process_group( - backend="nccl", - init_method=f"tcp://{master_addr}", - world_size=total_group_size, - rank=trainer_rank_in_group, - group_name="weight_update_group", - timeout=timedelta(seconds=self.config.timeout_seconds), - ) - print("[Bridge] NCCL process group initialized") - - # Initialize Gloo group for metadata/coordination - self.gloo_group = init_process_group( - backend="gloo", - init_method=f"tcp://{master_gloo_addr}", - world_size=total_group_size, - rank=trainer_rank_in_group, - group_name="gloo_group", - timeout=timedelta(seconds=self.config.timeout_seconds), - ) - print("[Bridge] Gloo process group initialized") - def _load_param_mappings(self) -> None: - """Load parameter name mappings from vLLM's exported JSON.""" + """Load parameter name mappings from vLLM's config file.""" log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".") json_path = Path(log_dir) / "vllm_bridge_config.json" - - # Wait for vLLM to write the mapping file + + # Wait for file wait_time = 0 while not json_path.exists() and wait_time < self.config.timeout_seconds: - print(f"[Bridge] Waiting for {json_path} to be created...") + if wait_time % 10 == 0: + print(f"[Bridge] Waiting for {json_path}...") time.sleep(1) wait_time += 1 - + if not json_path.exists(): - raise RuntimeError( - f"Parameter mapping file not found at {json_path}. " - "Make sure vLLM is running and has exported its parameter mappings." - ) - - # Small delay to ensure file is fully written - time.sleep(1) - + raise RuntimeError(f"Config file not found: {json_path}") + + time.sleep(0.5) # Wait for file to finish writing + with open(json_path, "r") as f: data = json.load(f) - + self.param_mappings = data.get("param_mappings", {}) - self.name_conversions = get_name_conversions(self.param_mappings) - - print(f"[Bridge] Loaded mappings for {len(self.param_mappings)} parameters") - - def _get_num_training_gpus(self) -> int: - """Get number of training GPUs from param mappings or config.""" - if self.param_mappings: - # Try to get from vLLM's exported info - return self.param_mappings.get("dp_shard_degree", 1) * self.param_mappings.get("tp_degree", 1) - return self.config.world_size - - def attach_to_vllm_weights(self, vllm_state_dict: Dict[str, torch.Tensor]) -> None: + self.param_name_list = sorted(self.param_mappings.keys()) + + print(f"[Bridge] Loaded mappings for {len(self.param_name_list)} parameters") + + def broadcast_weights(self, model: nn.Module) -> None: """ - Attach to vLLM's weight tensors. - - After this call, self.shared_state_dict contains references to the - actual tensors that vLLM uses for inference. Modifying these tensors - will immediately affect vLLM's outputs. 
-
+        Broadcast all model weights to vLLM inference workers.
+
+        Call this after optimizer.step() to push updated weights.
+
         Args:
-            vllm_state_dict: vLLM's model state_dict (actual tensors, not copies)
-        """
-        self.shared_state_dict = vllm_state_dict
-        print(f"[Bridge] Attached to {len(vllm_state_dict)} vLLM weight tensors")
-
-        # Log tensor info for debugging
-        for name, tensor in list(vllm_state_dict.items())[:5]:
-            print(f"[Bridge]   {name}: {tensor.shape}, {tensor.dtype}, {tensor.device}")
-        if len(vllm_state_dict) > 5:
-            print(f"[Bridge]   ... and {len(vllm_state_dict) - 5} more")
-
-    def get_trainable_model(self) -> nn.Module:
-        """
-        Get a model whose parameters point to vLLM's shared tensors.
-
-        In CUDA IPC mode: shared_state_dict is populated from IPC handles during init.
-        In other modes: must call attach_to_vllm_weights() first.
-
-        This creates a HuggingFace model structure but replaces all parameters
-        with references to the shared tensors. When the optimizer updates these
-        parameters, it modifies vLLM's weights directly.
-
-        Returns:
-            An nn.Module ready for training with shared weights
-        """
-        if self._model is not None:
-            return self._model
-
-        if not self.shared_state_dict:
-            if self.config.use_cuda_ipc:
-                raise RuntimeError(
-                    "CUDA IPC mode enabled but no shared tensors found. "
-                    "Check that vLLM exported IPC handles correctly."
-                )
-            else:
-                raise RuntimeError(
-                    "Must call attach_to_vllm_weights() before get_trainable_model()"
-                )
-
-        print(f"[Bridge] Creating trainable model for {self.config.model_name}")
-        if self.config.use_cuda_ipc:
-            print("[Bridge] Using CUDA IPC shared tensors (NO NEW GPU MEMORY!)")
-
-        # Load model config (not weights)
-        model_config = AutoConfig.from_pretrained(self.config.model_name)
-
-        # Create model with empty weights (meta device = no memory)
-        with torch.device("meta"):
-            model = AutoModelForCausalLM.from_config(model_config)
-
-        # Replace each parameter with the shared tensor
-        self._replace_parameters_with_shared(model)
-
-        # Move model structure to device (parameters already on device via IPC)
-        model.to(self.device)
-        self._model = model
-
-        total_params = sum(p.numel() for p in model.parameters())
-        print(f"[Bridge] Trainable model ready with {total_params:,} parameters")
-
-        if self.config.use_cuda_ipc:
-            # Verify memory savings
-            param_memory_gb = total_params * 2 / 1e9  # bfloat16 = 2 bytes
-            print(f"[Bridge] CUDA IPC memory savings: ~{param_memory_gb:.1f} GB (shared with vLLM)")
-
-        return model
-
-    def _replace_parameters_with_shared(self, model: nn.Module) -> None:
-        """
-        Replace model parameters with references to shared vLLM tensors.
-
-        This is the key operation that makes weight sharing work. After this,
-        model.parameters() returns tensors that ARE vLLM's weights.
- """ - replaced_count = 0 - missing_params = [] - - for name, param in model.named_parameters(): - # Convert HuggingFace param name to vLLM param name - vllm_name = self._hf_to_vllm_name(name) - - if vllm_name in self.shared_state_dict: - shared_tensor = self.shared_state_dict[vllm_name] - - # Create a new Parameter that wraps the shared tensor - # The key is that we're not copying - we're referencing the same storage - new_param = nn.Parameter(shared_tensor, requires_grad=True) - - # Replace the parameter in the model - self._set_parameter(model, name, new_param) - replaced_count += 1 - else: - missing_params.append(name) - - print(f"[Bridge] Replaced {replaced_count} parameters with shared tensors") - if missing_params: - print(f"[Bridge] Warning: {len(missing_params)} parameters not found in shared state:") - for p in missing_params[:5]: - print(f"[Bridge] {p}") - - def _hf_to_vllm_name(self, hf_name: str) -> str: - """ - Convert a HuggingFace parameter name to vLLM's naming convention. - - vLLM may merge QKV projections, use different layer naming, etc. - This handles the translation. - """ - # Check if we have an explicit mapping - for vllm_name, hf_names in self.name_conversions.items(): - if hf_name in hf_names: - return vllm_name - - # Common transformations - # vLLM often uses: model.layers.N.self_attn.qkv_proj - # HF uses: model.layers.N.self_attn.q_proj, k_proj, v_proj - - # For now, try the name as-is - return hf_name - - def _set_parameter(self, model: nn.Module, name: str, new_param: nn.Parameter) -> None: - """Set a parameter by dotted name path.""" - parts = name.split(".") - module = model - for part in parts[:-1]: - module = getattr(module, part) - setattr(module, parts[-1], new_param) - - def broadcast_weights_to_inference(self) -> None: - """ - Broadcast updated weights from trainer to inference workers. - - Call this after optimizer.step() to push the new weights to all - vLLM inference processes. They will use the updated weights for - subsequent requests. + model: The model whose weights to broadcast """ if not self._initialized: raise RuntimeError("Bridge not initialized. 
Call initialize() first.") - - param_names = sorted(self.param_mappings.keys()) - + + if self.nccl_group is None: + # HTTP mode - just notify + self._notify_update_http() + return + + self._update_count += 1 + start_time = time.time() + + state_dict = dict(model.named_parameters()) + with torch.no_grad(): - for idx, param_name in enumerate(param_names): + for idx, param_name in enumerate(self.param_name_list): # Signal which parameter we're broadcasting idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.device) dist.broadcast(idx_tensor, src=0, group=self.nccl_group) - - # Get the tensor for this parameter - vllm_name = self.param_mappings[param_name].get("vllm_name", param_name) - if vllm_name not in self.shared_state_dict: + + # Get tensor for this parameter + if param_name not in state_dict: continue - - tensor = self.shared_state_dict[vllm_name] - local_shape = self.param_mappings[param_name].get("local_shape", list(tensor.shape)) - - # Gather from all training ranks, then broadcast to inference - # (This handles FSDP/TP sharding if present) - dist.all_gather( - [torch.zeros(local_shape, dtype=tensor.dtype, device=self.device) - for _ in range(dist.get_world_size(self.nccl_group))], - tensor, - group=self.nccl_group, + + tensor = state_dict[param_name].data + local_shape = self.param_mappings[param_name].get( + "local_shape", list(tensor.shape) ) - - self._update_count += 1 - print(f"[Bridge] Broadcast update #{self._update_count} complete") - + + # All-gather to distribute to all ranks (including inference) + tensor_list = [ + torch.zeros(local_shape, dtype=tensor.dtype, device=self.device) + for _ in range(self._total_group_size) + ] + dist.all_gather(tensor_list, tensor, group=self.nccl_group) + + elapsed = time.time() - start_time + print(f"[Bridge] Broadcast update #{self._update_count} ({elapsed:.2f}s)") + + def broadcast_single_param( + self, + model: nn.Module, + param_name: str + ) -> None: + """ + Broadcast a single parameter to vLLM. + + Useful for incremental updates or debugging. + """ + if self.nccl_group is None: + return + + if param_name not in self.param_name_list: + print(f"[Bridge] Warning: {param_name} not in param list") + return + + idx = self.param_name_list.index(param_name) + state_dict = dict(model.named_parameters()) + + if param_name not in state_dict: + return + + with torch.no_grad(): + idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.device) + dist.broadcast(idx_tensor, src=0, group=self.nccl_group) + + tensor = state_dict[param_name].data + local_shape = self.param_mappings[param_name].get( + "local_shape", list(tensor.shape) + ) + + tensor_list = [ + torch.zeros(local_shape, dtype=tensor.dtype, device=self.device) + for _ in range(self._total_group_size) + ] + dist.all_gather(tensor_list, tensor, group=self.nccl_group) + def notify_update(self) -> None: """ - Notify inference workers that weights have been updated. - - In local mode: sends HTTP request to vLLM's /bridge/notify_update endpoint - In distributed mode: broadcasts update counter via Gloo + Notify vLLM that weights have been updated. + + In NCCL mode, this is a no-op (updates are immediate). + In HTTP mode, sends a notification to vLLM. 
""" self._update_count += 1 - - if self.config.is_local_mode: + + if self.nccl_group is None: self._notify_update_http() - elif self.gloo_group is not None: - self._notify_update_distributed() - + def _notify_update_http(self) -> None: - """Notify vLLM via HTTP (local mode).""" + """Notify vLLM via HTTP (fallback mode).""" try: import requests response = requests.post( @@ -841,45 +493,52 @@ class VLLMWeightBridge: if response.status_code != 200: print(f"[Bridge] Warning: notify_update returned {response.status_code}") except Exception as e: - # Don't fail training if vLLM notification fails print(f"[Bridge] Warning: Could not notify vLLM: {e}") - - def _notify_update_distributed(self) -> None: - """Notify via Gloo broadcast (distributed mode).""" - update_tensor = torch.tensor([self._update_count], dtype=torch.long) - dist.broadcast(update_tensor, src=0, group=self.gloo_group) - - def barrier(self) -> None: - """Wait for all processes in the group to reach this point.""" - if self.nccl_group is not None: - dist.barrier(group=self.nccl_group) - + + def send_heartbeat(self) -> None: + """ + Send heartbeat signal to keep inference workers alive. + + In NCCL mode, sends -1 as the parameter index to signal + "no update this round". + """ + if self.nccl_group is None: + return + + with torch.no_grad(): + idx_tensor = torch.tensor([-1], dtype=torch.long, device=self.device) + dist.broadcast(idx_tensor, src=0, group=self.nccl_group) + def cleanup(self) -> None: - """Clean up process groups and resources.""" + """Clean up resources.""" + print("[Bridge] Cleaning up...") + + # Send shutdown signal (optional) if self.nccl_group is not None: - dist.destroy_process_group(self.nccl_group) - self.nccl_group = None - - if self.gloo_group is not None: - dist.destroy_process_group(self.gloo_group) - self.gloo_group = None - + try: + # Send -2 to signal shutdown (if implemented in updater) + with torch.no_grad(): + idx_tensor = torch.tensor([-2], dtype=torch.long, device=self.device) + dist.broadcast(idx_tensor, src=0, group=self.nccl_group) + except Exception: + pass + self._initialized = False - print("[Bridge] Cleaned up") + print("[Bridge] Cleanup complete") # ============================================================================= -# Convenience Functions +# Factory Function # ============================================================================= def create_bridge_from_training_config(config: Any) -> VLLMWeightBridge: """ - Create and initialize a VLLMWeightBridge from a TrainingConfig. - + Create a VLLMWeightBridge from a TrainingConfig object. + Args: - config: TrainingConfig object with bridge settings - + config: TrainingConfig with model and distributed settings + Returns: Initialized VLLMWeightBridge ready for use """ @@ -888,3 +547,49 @@ def create_bridge_from_training_config(config: Any) -> VLLMWeightBridge: bridge.initialize() return bridge + +def export_param_mappings( + model: nn.Module, + model_name: str, + tp_degree: int = 1, + dp_shard_degree: int = 1, + log_dir: Optional[str] = None, +) -> None: + """ + Export parameter mappings to JSON for vLLM to read. + + Call this from the trainer BEFORE starting vLLM. 
+
+    Args:
+        model: The model being trained
+        model_name: HuggingFace model name
+        tp_degree: Tensor parallel degree
+        dp_shard_degree: Data parallel shard degree (FSDP)
+        log_dir: Directory to write the config file to
+    """
+    log_dir = log_dir or os.environ.get("LOGDIR", ".")
+    Path(log_dir).mkdir(parents=True, exist_ok=True)  # Ensure the directory exists
+    json_path = Path(log_dir) / "vllm_bridge_config.json"
+
+    param_mappings = {}
+
+    for name, param in model.named_parameters():
+        param_mappings[name] = {
+            "vllm_name": name,  # May need transformation for some models
+            "shape": list(param.shape),
+            "local_shape": list(param.shape),  # For FSDP, this would be the shard shape
+            "dtype": str(param.dtype),
+            "tp_shard_dim": 0,
+            "needs_permute": False,  # Set True for rotary embedding weights
+        }
+
+    config = {
+        "model": model_name,
+        "tp_degree": tp_degree,
+        "dp_shard_degree": dp_shard_degree,
+        "param_mappings": param_mappings,
+    }
+
+    with open(json_path, "w") as f:
+        json.dump(config, f, indent=2)
+
+    print(f"[Bridge] Exported param mappings to {json_path}")
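
---

**Reviewer note: trainer-side usage.** The new module-level API implies a calling order: export mappings, launch vLLM, initialize the bridge, then broadcast or heartbeat each step, then cleanup. Below is a minimal sketch of that flow, not part of this change. It assumes a single training rank (NCCL rank 0); the import path, the `config` object (a `TrainingConfig` from `grpo.py`), and the toy batch are placeholders:

```python
# Sketch only: wiring of export_param_mappings() + VLLMWeightBridge.
# Import path is assumed; adjust to wherever the bridge module lives.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16
).to("cuda")

# 1. Write vllm_bridge_config.json BEFORE launching vLLM, so its daemon
#    learns the parameter order that broadcast_weights() will use.
export_param_mappings(model, model_name, log_dir="/tmp/atropos_bridge")

# ... launch vLLM here with VLLM_ENABLE_SHARED_WEIGHTS=1 ...

# 2. Create the bridge; it reads the same JSON file during initialize().
bridge = create_bridge_from_training_config(config)  # config: TrainingConfig (placeholder)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
batch = tokenizer("2 + 2 = 4", return_tensors="pt").to("cuda")

for step in range(8):
    loss = model(**batch, labels=batch["input_ids"]).loss  # toy objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # 3. Each step either pushes weights or sends a heartbeat, so the
    #    daemon's blocking broadcast of the parameter index always completes.
    if step % 4 == 0:
        bridge.broadcast_weights(model)
    else:
        bridge.send_heartbeat()

bridge.cleanup()  # broadcasts the -2 shutdown sentinel
```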
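**Reviewer note: daemon side.** The vLLM-side daemon that consumes these broadcasts is not part of this diff. For context, here is a hypothetical sketch of the receive loop the index protocol implies (index >= 0 selects a parameter, -1 is a heartbeat, -2 is shutdown). Every name here (`updater_loop`, `shared_tensors`, `param_name_list`, `total_group_size`) is an assumption, and it assumes an unsharded trainer at rank 0, so `local_shape` equals the full shape:

```python
# Hypothetical daemon-side receive loop matching the trainer's index protocol.
import torch
import torch.distributed as dist

def updater_loop(nccl_group, device, param_name_list, param_mappings,
                 shared_tensors, total_group_size):
    """shared_tensors: name -> tensor whose storage vLLM reads (share_memory_())."""
    while True:
        # Block until the trainer (rank 0) announces the next message.
        idx_tensor = torch.zeros(1, dtype=torch.long, device=device)
        dist.broadcast(idx_tensor, src=0, group=nccl_group)
        idx = int(idx_tensor.item())

        if idx == -2:   # shutdown sentinel, see VLLMWeightBridge.cleanup()
            break
        if idx == -1:   # heartbeat: no update this round
            continue

        name = param_name_list[idx]
        dst = shared_tensors[name]
        local_shape = param_mappings[name]["local_shape"]

        # Participate in the collective; rank 0's slot carries the weights,
        # this rank only contributes a dummy tensor.
        tensor_list = [
            torch.zeros(local_shape, dtype=dst.dtype, device=device)
            for _ in range(total_group_size)
        ]
        dummy = torch.zeros(local_shape, dtype=dst.dtype, device=device)
        dist.all_gather(tensor_list, dummy, group=nccl_group)

        # Copy the trainer's weights into the shared tensor in place,
        # so vLLM sees the update without reloading the model.
        with torch.no_grad():
            dst.copy_(tensor_list[0].view(dst.shape))
```

The in-place `copy_` at the end is what makes the `share_memory_()` patch pay off: vLLM keeps reading the same storage, so no reload or notification round-trip is needed.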