""" vLLM Weight Bridge - Trainer-side integration for shared memory weight updates. This module coordinates weight updates between the trainer and vLLM inference. ARCHITECTURE: The patched vLLM server (using vllm_patching/) runs a daemon process that: 1. Joins NCCL process groups with the trainer 2. Receives weight updates via all_gather 3. Copies updates into vLLM's shared memory tensors ┌─────────────────────────────────────────────────────────────────────────┐ │ SHARED MEMORY (via share_memory_()) │ │ ┌─────────────────────────────────────────────────────────────────┐ │ │ │ Model Weights │ │ │ │ (accessible from MULTIPLE processes) │ │ │ └─────────────────────────────────────────────────────────────────┘ │ │ ▲ ▲ │ │ │ Reads │ Writes │ │ ┌────────┴────────┐ ┌───────────┴───────────┐ │ │ │ vLLM Worker │ │ weight_updater │ │ │ │ (inference) │ │ daemon process │ │ │ └─────────────────┘ └───────────┬───────────┘ │ │ │ NCCL │ │ ▼ │ │ ┌─────────────────────┐ │ │ │ Trainer Process │ │ │ │ (this bridge) │ │ │ └─────────────────────┘ │ └─────────────────────────────────────────────────────────────────────────┘ MODES: LOCAL MODE (num_inference_nodes=0): - Single machine setup - Trainer and vLLM share the same node - NCCL for weight broadcast to vLLM's daemon DISTRIBUTED MODE (num_inference_nodes>0): - Multi-node setup with dedicated inference nodes - Last N nodes run vLLM inference - NCCL spans across nodes for weight updates """ from __future__ import annotations import json import os import socket import time from collections import defaultdict from dataclasses import dataclass, field from datetime import timedelta from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import torch import torch.distributed as dist from torch import nn # ============================================================================= # Process Group Initialization # ============================================================================= def init_process_group( backend: Optional[str] = None, init_method: Optional[str] = None, timeout: Optional[timedelta] = None, world_size: int = -1, rank: int = -1, store: Optional[Any] = None, group_name: str = "", pg_options: Optional[Any] = None, ) -> dist.ProcessGroup: """ Initialize a custom process group for weight synchronization. Creates a named group that coexists with vLLM's internal process groups. """ from torch.distributed.distributed_c10d import ( _new_process_group_helper, _world, Backend, default_pg_timeout, PrefixStore, rendezvous, ) assert (store is None) or (init_method is None), \ "Cannot specify both init_method and store." if store is not None: assert world_size > 0, "world_size must be positive if using store" assert rank >= 0, "rank must be non-negative if using store" elif init_method is None: init_method = "env://" backend = Backend(backend) if backend else Backend("undefined") timeout = timeout or default_pg_timeout if store is None: rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) store, rank, world_size = next(rendezvous_iterator) store.set_timeout(timeout) store = PrefixStore(group_name, store) # Handle PyTorch version differences pg_options_param_name = ( "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" ) pg, _ = _new_process_group_helper( world_size, rank, [], backend, store, group_name=group_name, **{pg_options_param_name: pg_options}, timeout=timeout, ) _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} return pg def get_inference_urls(num_inference_nodes: int = 0) -> Tuple[Optional[str], ...]: """ Get URLs for inference server communication. Returns: Tuple of (master_addr, master_gloo_addr, master_inference_addr, nodelist) """ if num_inference_nodes > 0: slurm_nodelist = os.environ.get("SLURM_JOB_NODELIST") if not slurm_nodelist: return None, None, None, None nodelist = ( os.popen(f'scontrol show hostnames {slurm_nodelist}') .read().strip().split("\n") ) nodelist = [n for n in nodelist if n] master_server = f"{nodelist[0]}:26756" master_gloo_server = f"{nodelist[0]}:26757" inference_nodes = nodelist[-num_inference_nodes:] master_inference_server = f"{inference_nodes[0]}:26758" return master_server, master_gloo_server, master_inference_server, inference_nodes elif num_inference_nodes == 0: return "localhost:26756", "localhost:26757", "localhost:26758", ["localhost"] else: return None, None, None, None # ============================================================================= # Bridge Configuration # ============================================================================= @dataclass class BridgeConfig: """Configuration for the vLLM weight bridge.""" # Process group settings trainer_rank: int = 0 world_size: int = 1 init_method: str = "env://" num_inference_nodes: int = 0 # Model settings model_name: str = "" device: str = "cuda" # Synchronization settings timeout_seconds: float = 300.0 log_dir: Optional[str] = None # vLLM server URL for HTTP-based sync (fallback) vllm_api_url: str = "http://localhost:9001" # Derived from environment num_gpus_per_node: int = field(default_factory=lambda: torch.cuda.device_count()) @property def is_local_mode(self) -> bool: """Local mode: single machine, uses NCCL to daemon on same node.""" return self.num_inference_nodes == 0 @property def uses_nccl(self) -> bool: """Whether NCCL is used for weight synchronization.""" return self.num_inference_nodes >= 0 @classmethod def from_training_config(cls, config: Any) -> "BridgeConfig": """Create BridgeConfig from a TrainingConfig object.""" return cls( trainer_rank=getattr(config, 'trainer_rank', 0), world_size=getattr(config, 'world_size', 1), init_method=getattr(config, 'init_method', 'env://'), num_inference_nodes=getattr(config, 'num_inference_nodes', 0), model_name=config.model_name, device=config.device, log_dir=os.environ.get("LOGDIR"), vllm_api_url=f"http://localhost:{getattr(config, 'vllm_port', 9001)}", ) # ============================================================================= # Weight Bridge Class # ============================================================================= class VLLMWeightBridge: """ Bridge for synchronizing model weights between trainer and vLLM. This class: 1. Initializes NCCL process groups with vLLM's weight updater daemon 2. Broadcasts weight updates after each optimizer.step() 3. Ensures vLLM immediately uses updated weights for inference Usage: bridge = VLLMWeightBridge(config) bridge.initialize() for batch in data: loss = compute_loss(model, batch) loss.backward() optimizer.step() bridge.broadcast_weights(model) # vLLM now uses new weights """ def __init__(self, config: BridgeConfig): self.config = config self.device = torch.device(config.device) # Process groups self.nccl_group: Optional[dist.ProcessGroup] = None self.gloo_group: Optional[dist.ProcessGroup] = None # Parameter mappings (loaded from vLLM's JSON) self.param_mappings: Dict[str, Any] = {} self.param_name_list: List[str] = [] # State self._initialized: bool = False self._update_count: int = 0 # Derived config self._num_training_gpus: int = 0 self._total_group_size: int = 0 def initialize(self) -> None: """ Initialize the bridge: create process groups and load mappings. Must be called before any weight synchronization. """ if self._initialized: return print(f"[Bridge] Initializing weight bridge (rank {self.config.trainer_rank})") if self.config.uses_nccl: self._initialize_nccl_mode() else: self._initialize_http_mode() self._initialized = True def _initialize_nccl_mode(self) -> None: """Initialize NCCL-based weight synchronization.""" print("[Bridge] Using NCCL mode for weight synchronization") # Get rendezvous URLs master_addr, master_gloo_addr, _, nodelist = get_inference_urls( self.config.num_inference_nodes ) if master_addr is None: raise RuntimeError( "Could not determine inference URLs. " "Set NUM_INFERENCE_NODES environment variable." ) print(f"[Bridge] Master address: {master_addr}") print(f"[Bridge] Inference nodes: {nodelist}") # Load parameter mappings from vLLM self._load_param_mappings() # Calculate group sizes self._num_training_gpus = ( self.config.world_size * (1 if self.config.num_inference_nodes == 0 else 8) # Assume 8 GPUs/node ) if self.config.num_inference_nodes == 0: # Single node: some GPUs for training, some for inference num_inference_gpus = 4 # Default: 4 GPUs for inference self._num_training_gpus = torch.cuda.device_count() - num_inference_gpus num_inference_gpus = ( self.config.num_inference_nodes * 8 if self.config.num_inference_nodes > 0 else 4 ) self._total_group_size = self._num_training_gpus + num_inference_gpus print(f"[Bridge] Training GPUs: {self._num_training_gpus}") print(f"[Bridge] Inference GPUs: {num_inference_gpus}") print(f"[Bridge] Total group size: {self._total_group_size}") # Create Gloo group (for coordination) print("[Bridge] Creating Gloo process group...") self.gloo_group = init_process_group( backend="gloo", init_method=f"tcp://{master_addr}", world_size=self._total_group_size, rank=self.config.trainer_rank, group_name="gloo_group", ) print("[Bridge] ✓ Gloo group created") # Create NCCL group (for tensor transfers) print("[Bridge] Creating NCCL process group...") self.nccl_group = init_process_group( backend="nccl", init_method=f"tcp://{master_addr}", world_size=self._total_group_size, rank=self.config.trainer_rank, group_name="weight_update_group", ) print("[Bridge] ✓ NCCL group created") def _initialize_http_mode(self) -> None: """Initialize HTTP-based weight synchronization (fallback).""" print("[Bridge] Using HTTP mode for weight synchronization") print(f"[Bridge] vLLM API URL: {self.config.vllm_api_url}") # Verify vLLM server is reachable try: import requests response = requests.get(f"{self.config.vllm_api_url}/health", timeout=5) if response.status_code == 200: print("[Bridge] ✓ vLLM server is reachable") else: print(f"[Bridge] Warning: vLLM health check returned {response.status_code}") except Exception as e: print(f"[Bridge] Warning: Could not reach vLLM: {e}") def _load_param_mappings(self) -> None: """Load parameter name mappings from vLLM's config file.""" log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".") json_path = Path(log_dir) / "vllm_bridge_config.json" # Wait for file wait_time = 0 while not json_path.exists() and wait_time < self.config.timeout_seconds: if wait_time % 10 == 0: print(f"[Bridge] Waiting for {json_path}...") time.sleep(1) wait_time += 1 if not json_path.exists(): raise RuntimeError(f"Config file not found: {json_path}") time.sleep(0.5) # Wait for file to finish writing with open(json_path, "r") as f: data = json.load(f) self.param_mappings = data.get("param_mappings", {}) self.param_name_list = sorted(self.param_mappings.keys()) print(f"[Bridge] Loaded mappings for {len(self.param_name_list)} parameters") def broadcast_weights(self, model: nn.Module) -> None: """ Broadcast all model weights to vLLM inference workers. Call this after optimizer.step() to push updated weights. Args: model: The model whose weights to broadcast """ if not self._initialized: raise RuntimeError("Bridge not initialized. Call initialize() first.") if self.nccl_group is None: # HTTP mode - just notify self._notify_update_http() return self._update_count += 1 start_time = time.time() state_dict = dict(model.named_parameters()) with torch.no_grad(): for idx, param_name in enumerate(self.param_name_list): # Signal which parameter we're broadcasting idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.device) dist.broadcast(idx_tensor, src=0, group=self.nccl_group) # Get tensor for this parameter if param_name not in state_dict: continue tensor = state_dict[param_name].data local_shape = self.param_mappings[param_name].get( "local_shape", list(tensor.shape) ) # All-gather to distribute to all ranks (including inference) tensor_list = [ torch.zeros(local_shape, dtype=tensor.dtype, device=self.device) for _ in range(self._total_group_size) ] dist.all_gather(tensor_list, tensor, group=self.nccl_group) elapsed = time.time() - start_time print(f"[Bridge] Broadcast update #{self._update_count} ({elapsed:.2f}s)") def broadcast_single_param( self, model: nn.Module, param_name: str ) -> None: """ Broadcast a single parameter to vLLM. Useful for incremental updates or debugging. """ if self.nccl_group is None: return if param_name not in self.param_name_list: print(f"[Bridge] Warning: {param_name} not in param list") return idx = self.param_name_list.index(param_name) state_dict = dict(model.named_parameters()) if param_name not in state_dict: return with torch.no_grad(): idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.device) dist.broadcast(idx_tensor, src=0, group=self.nccl_group) tensor = state_dict[param_name].data local_shape = self.param_mappings[param_name].get( "local_shape", list(tensor.shape) ) tensor_list = [ torch.zeros(local_shape, dtype=tensor.dtype, device=self.device) for _ in range(self._total_group_size) ] dist.all_gather(tensor_list, tensor, group=self.nccl_group) def notify_update(self) -> None: """ Notify vLLM that weights have been updated. In NCCL mode, this is a no-op (updates are immediate). In HTTP mode, sends a notification to vLLM. """ self._update_count += 1 if self.nccl_group is None: self._notify_update_http() def _notify_update_http(self) -> None: """Notify vLLM via HTTP (fallback mode).""" try: import requests response = requests.post( f"{self.config.vllm_api_url}/bridge/notify_update", json={ "update_count": self._update_count, "trainer_rank": self.config.trainer_rank, "timestamp": time.time(), }, timeout=5, ) if response.status_code != 200: print(f"[Bridge] Warning: notify_update returned {response.status_code}") except Exception as e: print(f"[Bridge] Warning: Could not notify vLLM: {e}") def send_heartbeat(self) -> None: """ Send heartbeat signal to keep inference workers alive. In NCCL mode, sends -1 as the parameter index to signal "no update this round". """ if self.nccl_group is None: return with torch.no_grad(): idx_tensor = torch.tensor([-1], dtype=torch.long, device=self.device) dist.broadcast(idx_tensor, src=0, group=self.nccl_group) def cleanup(self) -> None: """Clean up resources.""" print("[Bridge] Cleaning up...") # Send shutdown signal (optional) if self.nccl_group is not None: try: # Send -2 to signal shutdown (if implemented in updater) with torch.no_grad(): idx_tensor = torch.tensor([-2], dtype=torch.long, device=self.device) dist.broadcast(idx_tensor, src=0, group=self.nccl_group) except Exception: pass self._initialized = False print("[Bridge] Cleanup complete") # ============================================================================= # Factory Function # ============================================================================= def create_bridge_from_training_config(config: Any) -> VLLMWeightBridge: """ Create a VLLMWeightBridge from a TrainingConfig object. Args: config: TrainingConfig with model and distributed settings Returns: Initialized VLLMWeightBridge ready for use """ bridge_config = BridgeConfig.from_training_config(config) bridge = VLLMWeightBridge(bridge_config) bridge.initialize() return bridge def export_param_mappings( model: nn.Module, model_name: str, tp_degree: int = 1, dp_shard_degree: int = 1, log_dir: Optional[str] = None, ) -> None: """ Export parameter mappings to JSON for vLLM to read. Call this from the trainer BEFORE starting vLLM. Args: model: The model being trained model_name: HuggingFace model name tp_degree: Tensor parallel degree dp_shard_degree: Data parallel shard degree (FSDP) log_dir: Directory to write config file """ log_dir = log_dir or os.environ.get("LOGDIR", ".") json_path = Path(log_dir) / "vllm_bridge_config.json" param_mappings = {} for name, param in model.named_parameters(): param_mappings[name] = { "vllm_name": name, # May need transformation for some models "shape": list(param.shape), "local_shape": list(param.shape), # For FSDP, this would be shard shape "dtype": str(param.dtype), "tp_shard_dim": 0, "needs_permute": False, # Set True for rotary embedding weights } config = { "model": model_name, "tp_degree": tp_degree, "dp_shard_degree": dp_shard_degree, "param_mappings": param_mappings, } with open(json_path, "w") as f: json.dump(config, f, indent=2) print(f"[Bridge] Exported param mappings to {json_path}")