diff --git a/example_trainer/vllm_patching/__init__.py b/example_trainer/vllm_patching/__init__.py index 7a4ab38b..e7fb9128 100644 --- a/example_trainer/vllm_patching/__init__.py +++ b/example_trainer/vllm_patching/__init__.py @@ -1,13 +1,20 @@ """ -vLLM Patching Module - Enables shared memory weight updates. +vLLM Patching Module - Enables CUDA IPC shared memory for single-copy training. This module patches vLLM's GPUModelRunner to: 1. Call share_memory_() on model weights after loading -2. Spawn a daemon process that receives NCCL weight updates from trainers -3. Enable real-time weight synchronization without restarting vLLM +2. Export CUDA IPC handles to vllm_bridge_config.json +3. Enable the trainer to attach to vLLM's tensors directly + +The result: ONE copy of model weights in GPU memory, shared between +vLLM (inference) and the trainer (gradient updates). Usage: - # Import this BEFORE importing vllm + # Set environment BEFORE importing + import os + os.environ["VLLM_ENABLE_SHARED_WEIGHTS"] = "1" + + # Import and apply patches BEFORE importing vllm from example_trainer.vllm_patching import apply_patches apply_patches() @@ -21,24 +28,10 @@ from .patched_gpu_runner import ( get_patched_runner, is_patched, ) -from .weight_updater import weight_updater_process -from .distributed_utils import ( - init_process_group, - broadcast_object_list, - get_inference_urls, - get_json_data, -) __all__ = [ "PatchedGPUModelRunner", "apply_patches", "get_patched_runner", "is_patched", - "weight_updater_process", - "init_process_group", - "broadcast_object_list", - "get_inference_urls", - "get_json_data", ] - - diff --git a/example_trainer/vllm_patching/distributed_utils.py b/example_trainer/vllm_patching/distributed_utils.py deleted file mode 100644 index 41cfca26..00000000 --- a/example_trainer/vllm_patching/distributed_utils.py +++ /dev/null @@ -1,328 +0,0 @@ -""" -Distributed utilities for vLLM weight synchronization. - -Provides process group initialization and communication helpers -for coordinating weight updates between trainer and vLLM. -""" - -from __future__ import annotations - -import json -import os -import socket -import time -from collections import defaultdict -from datetime import timedelta -from typing import Any, Dict, List, Optional, Tuple - -import torch -import torch.distributed as dist - - -def init_process_group( - backend: Optional[str] = None, - init_method: Optional[str] = None, - timeout: Optional[timedelta] = None, - world_size: int = -1, - rank: int = -1, - store: Optional[Any] = None, - group_name: str = "", - pg_options: Optional[Any] = None, -) -> dist.ProcessGroup: - """ - Initialize a custom process group for weight synchronization. - - This creates a named process group that coexists with vLLM's internal - process groups, enabling direct tensor communication between trainer - and inference processes. - - Args: - backend: "nccl" for GPU, "gloo" for CPU - init_method: Rendezvous URL (e.g., "tcp://host:port") - timeout: How long to wait for other ranks - world_size: Total number of processes - rank: This process's rank - store: Optional torch.distributed Store - group_name: Name for this process group (must match across ranks) - pg_options: Backend-specific options - - Returns: - ProcessGroup for collective operations - """ - from torch.distributed.distributed_c10d import ( - _new_process_group_helper, - _world, - Backend, - default_pg_timeout, - PrefixStore, - rendezvous, - ) - - assert (store is None) or (init_method is None), \ - "Cannot specify both init_method and store." - - if store is not None: - assert world_size > 0, "world_size must be positive if using store" - assert rank >= 0, "rank must be non-negative if using store" - elif init_method is None: - init_method = "env://" - - if backend: - backend = Backend(backend) - else: - backend = Backend("undefined") - - if timeout is None: - timeout = default_pg_timeout - - # Create store via rendezvous if not provided - if store is None: - rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) - store, rank, world_size = next(rendezvous_iterator) - store.set_timeout(timeout) - store = PrefixStore(group_name, store) - - # Handle PyTorch version differences for pg_options parameter - pg_options_param_name = ( - "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" - ) - - pg, _ = _new_process_group_helper( - world_size, - rank, - [], - backend, - store, - group_name=group_name, - **{pg_options_param_name: pg_options}, - timeout=timeout, - ) - - _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} - - return pg - - -def broadcast_object_list( - object_list: List[Any], - src: Optional[int] = None, - group: Optional[dist.ProcessGroup] = None, - device: Optional[torch.device] = None, - group_src: Optional[int] = None, -) -> None: - """ - Broadcast a list of objects from source rank to all other ranks. - - Modified from torch.distributed.broadcast_object_list to work correctly - with custom process groups where rank 0 may not be the default group's rank 0. - - Args: - object_list: List of objects to broadcast (modified in-place on receivers) - src: Global source rank (deprecated, use group_src) - group: Process group to use - device: Device for temporary tensors - group_src: Source rank within the group - """ - global_src = group_src if group_src is not None else src - current_device = device - - # Broadcast object sizes first - object_sizes_tensor = torch.empty( - len(object_list), dtype=torch.long, device=current_device - ) - dist.broadcast(object_sizes_tensor, src=global_src, group=group) - - # Broadcast serialized objects - object_tensor = torch.empty( - torch.sum(object_sizes_tensor).item(), - dtype=torch.uint8, - device=current_device, - ) - dist.broadcast(object_tensor, src=global_src, group=group) - - # Deserialize objects - offset = 0 - for i, obj_size in enumerate(object_sizes_tensor): - obj_view = object_tensor[offset : offset + obj_size] - obj_view = obj_view.type(torch.uint8) - offset += obj_size - object_list[i] = dist.distributed_c10d._tensor_to_object( - obj_view, obj_size, group - ) - - -def get_inference_urls(num_inference_nodes: int = 0) -> Tuple[Optional[str], ...]: - """ - Get URLs for inference server communication. - - Parses SLURM environment or uses localhost for single-machine setup. - - Args: - num_inference_nodes: Number of dedicated inference nodes. - 0 = single machine, trainer and vLLM share the node - >0 = multi-node, last N nodes are for inference - - Returns: - Tuple of (master_addr, master_gloo_addr, master_inference_addr, nodelist) - Returns (None, None, None, None) if not in a valid setup. - """ - if num_inference_nodes > 0: - # Multi-node SLURM setup - slurm_nodelist = os.environ.get("SLURM_JOB_NODELIST") - if not slurm_nodelist: - return None, None, None, None - - # Parse SLURM node list - nodelist = ( - os.popen(f'scontrol show hostnames {slurm_nodelist}') - .read() - .strip() - .split("\n") - ) - nodelist = [node for node in nodelist if node] - - # First node is master for process groups - master_server = f"{nodelist[0]}:26756" - master_gloo_server = f"{nodelist[0]}:26757" - - # Last N nodes are inference nodes - inference_nodes = nodelist[-num_inference_nodes:] - master_inference_server = f"{inference_nodes[0]}:26758" - - return master_server, master_gloo_server, master_inference_server, inference_nodes - - elif num_inference_nodes == 0: - # Single machine setup - master_server = "localhost:26756" - master_gloo_server = "localhost:26757" - master_inference_server = "localhost:26758" - nodelist = ["localhost"] - - return master_server, master_gloo_server, master_inference_server, nodelist - - else: - return None, None, None, None - - -def get_hostnames() -> Optional[List[str]]: - """ - Get the hostnames for this machine. - - Parses /etc/hosts to find all hostnames associated with this machine's IP. - - Returns: - List of [ip, hostname1, hostname2, ...] or None if not found. - """ - my_ip = socket.gethostbyname(socket.gethostname()) - my_hostname = socket.gethostname() - - try: - with open("/etc/hosts", "r") as f: - for line in f: - line = line.strip() - if line and not line.startswith("#"): - parts = line.split() - if len(parts) >= 2 and ((parts[0] == my_ip) or (my_hostname in parts)): - ip = parts[0] - if ip.startswith("127."): - continue - return parts - except Exception: - pass - - return None - - -def get_json_data(log_dir: Optional[str] = None, timeout: int = 300) -> Dict[str, Any]: - """ - Load the bridge configuration JSON from vLLM. - - Waits for the file to be created by vLLM's weight bridge setup. - - Args: - log_dir: Directory containing the JSON file (defaults to LOGDIR env var) - timeout: Maximum seconds to wait for file - - Returns: - Parsed JSON data with parameter mappings and configuration. - - Raises: - ValueError: If LOGDIR not set and log_dir not provided - FileNotFoundError: If file not found after timeout - """ - if log_dir is None: - log_dir = os.environ.get("LOGDIR") - if log_dir is None: - raise ValueError("LOGDIR environment variable not set and log_dir not provided") - - json_path = os.path.join(log_dir, "vllm_bridge_config.json") - - wait_time = 0 - while not os.path.exists(json_path): - if wait_time >= timeout: - raise FileNotFoundError(f"Config file not found after {timeout}s: {json_path}") - if wait_time % 10 == 0: - print(f"[Updater] Waiting for {json_path}...", flush=True) - time.sleep(1) - wait_time += 1 - - # Wait a moment for file to finish writing - time.sleep(0.5) - - with open(json_path, "r") as f: - return json.load(f) - - -def get_name_conversions(param_mappings: Dict[str, Any]) -> Dict[str, List[str]]: - """ - Build reverse mapping from vLLM names to trainer names. - - Args: - param_mappings: Dict mapping trainer param names to vLLM info - - Returns: - Dict mapping vLLM names to list of trainer names - """ - name_conversions = defaultdict(list) - for name, info in param_mappings.items(): - vllm_name = info.get("vllm_name", name) - name_conversions[vllm_name].append(name) - return name_conversions - - -# Permutation functions for rotary embeddings -def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor: - """ - Permute weight tensor for sliced rotary embeddings. - - Args: - w: Weight tensor of shape [dim1, dim2] - n_heads: Number of attention heads - - Returns: - Permuted tensor for rotary embedding compatibility - """ - dim1 = w.shape[0] - dim2 = w.shape[1] - return ( - w.view(n_heads, dim1 // n_heads // 2, 2, dim2) - .transpose(1, 2) - .reshape(dim1, dim2) - ) - - -def permute_1d(w: torch.Tensor, n_heads: int) -> torch.Tensor: - """ - Permute 1D weight tensor (bias) for sliced rotary embeddings. - - Args: - w: Weight tensor of shape [dim1] - n_heads: Number of attention heads - - Returns: - Permuted tensor - """ - dim1 = w.shape[0] - return w.view(n_heads, dim1 // n_heads // 2, 2).transpose(1, 2).reshape(dim1) - - diff --git a/example_trainer/vllm_patching/patched_gpu_runner.py b/example_trainer/vllm_patching/patched_gpu_runner.py index d4f373c3..70b513d7 100644 --- a/example_trainer/vllm_patching/patched_gpu_runner.py +++ b/example_trainer/vllm_patching/patched_gpu_runner.py @@ -1,13 +1,15 @@ """ -Patched GPU Model Runner - Enables shared memory weight updates. +Patched GPU Model Runner - Enables CUDA IPC for single-copy training. This patches vLLM's GPUModelRunner to: 1. Call share_memory_() on model weights after loading -2. Spawn a daemon process that receives NCCL weight updates from trainers +2. Export CUDA IPC handles to vllm_bridge_config.json -The key insight is that share_memory_() makes tensors accessible from -multiple processes. The daemon receives updates via NCCL and copies them -directly into the shared tensors, which vLLM reads for inference. +The key insight is that CUDA IPC handles allow the trainer process to +attach to the EXACT SAME GPU memory that vLLM uses. This means: +- ONE copy of model weights in GPU memory +- Trainer's optimizer.step() updates vLLM's weights directly +- No synchronization needed - vLLM immediately sees new weights CRITICAL: This module must be imported and apply_patches() called BEFORE any vLLM imports. The patches MUST happen before vLLM caches module references. @@ -119,28 +121,24 @@ def _create_patched_runner(BaseRunner: type) -> type: Create a patched GPUModelRunner class. Returns a new class that inherits from the original and adds - shared memory + daemon functionality. + CUDA IPC export functionality for single-copy training. """ import torch - import torch.multiprocessing as mp - from .weight_updater import weight_updater_process class PatchedGPUModelRunner(BaseRunner): """ - Patched GPUModelRunner that enables shared memory weight updates. + Patched GPUModelRunner that enables CUDA IPC for single-copy training. After loading the model, this: - 1. Calls share_memory_() on all parameters to make them accessible - from other processes - 2. Spawns a daemon process that joins NCCL groups with the trainer - and receives weight updates - - The daemon copies updates directly into the shared tensors, so - vLLM immediately sees the new weights for inference. + 1. Calls share_memory_() on all parameters + 2. Exports CUDA IPC handles to vllm_bridge_config.json + + The trainer reads these IPC handles and attaches to the SAME + GPU memory, so optimizer.step() updates weights that vLLM + immediately sees for inference. """ _shared_memory_setup_done = False - weight_updater_process = None def load_model(self, *args, **kwargs) -> None: """Load model and set up shared memory + update daemon.""" @@ -171,27 +169,12 @@ def _create_patched_runner(BaseRunner: type) -> type: self._setup_shared_memory() PatchedGPUModelRunner._shared_memory_setup_done = True print("[vLLM Patch] ✓ Shared memory setup complete!", flush=True) + print("[vLLM Patch] ✓ IPC handles exported - trainer can now attach!", flush=True) except Exception as e: print(f"[vLLM Patch] ERROR in _setup_shared_memory: {e}", flush=True) import traceback traceback.print_exc() return - - # Spawn weight updater daemon (optional - can be skipped for HTTP-only mode) - skip_daemon = os.environ.get("VLLM_SKIP_WEIGHT_DAEMON", "0") == "1" - if skip_daemon: - print("[vLLM Patch] Skipping weight updater daemon (VLLM_SKIP_WEIGHT_DAEMON=1)", flush=True) - return - - try: - print("[vLLM Patch] Spawning weight updater daemon...", flush=True) - self._spawn_weight_updater() - print("[vLLM Patch] ✓ Weight updater daemon spawned!", flush=True) - except Exception as e: - print(f"[vLLM Patch] ERROR spawning weight updater: {e}", flush=True) - import traceback - traceback.print_exc() - print("[vLLM Patch] Continuing without daemon (HTTP-only mode)", flush=True) def _setup_shared_memory(self) -> None: """Move model tensors to shared memory and export param info.""" @@ -326,70 +309,6 @@ def _create_patched_runner(BaseRunner: type) -> type: import traceback traceback.print_exc() - def _spawn_weight_updater(self) -> None: - """Start the weight updater as a background thread. - - Note: We use threading instead of multiprocessing because vLLM's - worker processes are daemons, and daemons cannot spawn child processes. - """ - import threading - - print("[vLLM Patch] _spawn_weight_updater() called", flush=True) - - try: - from vllm.distributed import get_tensor_model_parallel_rank - print("[vLLM Patch] Imported get_tensor_model_parallel_rank", flush=True) - except ImportError as e: - print(f"[vLLM Patch] Could not import get_tensor_model_parallel_rank: {e}", flush=True) - get_tensor_model_parallel_rank = lambda: 0 - - # Get model configuration - state_dict = self.model.state_dict() - print(f"[vLLM Patch] Got state_dict with {len(state_dict)} params", flush=True) - - # Get attention head counts - hf_config = self.model_config.hf_text_config - num_heads = getattr(hf_config, "num_attention_heads", 0) - num_kv_heads = self.model_config.get_total_num_kv_heads() - print(f"[vLLM Patch] num_heads={num_heads}, num_kv_heads={num_kv_heads}", flush=True) - - # Get parallel configuration - tp_rank = get_tensor_model_parallel_rank() - print(f"[vLLM Patch] tp_rank={tp_rank}", flush=True) - - # Get GPU ID - gpu_id = 0 - try: - if hasattr(self, 'device'): - if hasattr(self.device, 'index'): - gpu_id = self.device.index or 0 - elif isinstance(self.device, int): - gpu_id = self.device - except Exception: - gpu_id = tp_rank - - print(f"[vLLM Patch] Starting weight updater thread: tp_rank={tp_rank}, gpu={gpu_id}", flush=True) - - # Start as a daemon thread (threads CAN be started from daemon processes) - self.weight_updater_thread = threading.Thread( - target=weight_updater_process, - args=( - state_dict, - num_heads, - num_kv_heads, - tp_rank, - self.parallel_config.tensor_parallel_size, - gpu_id, - ), - daemon=True, - name=f"WeightUpdater_TP{tp_rank}", - ) - - print("[vLLM Patch] Starting thread...", flush=True) - self.weight_updater_thread.start() - - print(f"[vLLM Patch] ✓ Weight updater thread started (name: {self.weight_updater_thread.name})", flush=True) - # Set proper class name PatchedGPUModelRunner.__name__ = "PatchedGPUModelRunner" PatchedGPUModelRunner.__qualname__ = "PatchedGPUModelRunner" diff --git a/example_trainer/vllm_patching/weight_updater.py b/example_trainer/vllm_patching/weight_updater.py deleted file mode 100644 index 40b446a4..00000000 --- a/example_trainer/vllm_patching/weight_updater.py +++ /dev/null @@ -1,239 +0,0 @@ -""" -Weight Updater Process - Daemon that receives NCCL weight updates. - -This process runs as a daemon spawned by the patched vLLM GPUModelRunner. -It joins NCCL process groups with the trainer and receives weight updates, -copying them directly into vLLM's shared memory tensors. -""" - -from __future__ import annotations - -import os -import time -from typing import Dict - -import torch -import torch.distributed as dist - -from .distributed_utils import ( - init_process_group, - get_inference_urls, - get_hostnames, -) - - -def weight_updater_process( - state_dict: Dict[str, torch.Tensor], - num_q_heads: int, - num_kv_heads: int, - tp_rank: int, - tp_size: int, - gpu_id: int, -) -> None: - """ - Daemon process that receives weight updates from trainers via NCCL. - - This runs inside a subprocess spawned by PatchedGPUModelRunner. It: - 1. Joins NCCL/Gloo process groups with the trainer - 2. Receives weight update broadcasts from rank 0 (trainer) - 3. Copies updated weights directly into the shared state_dict - - Since state_dict tensors have share_memory_() called on them, the main - vLLM process immediately sees the updates for inference. - - Args: - state_dict: Model state dict with shared memory tensors - num_q_heads: Number of query attention heads (for permutation) - num_kv_heads: Number of key/value attention heads - tp_rank: Tensor parallel rank of this worker - tp_size: Total tensor parallel size - gpu_id: GPU device ID for this worker - """ - # Configuration from environment - num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", 0)) - debug = int(os.environ.get("WEIGHT_UPDATER_DEBUG", 0)) - - # Get network info - master_addr, master_gloo_addr, master_inference_addr, urls = get_inference_urls( - num_inference_nodes - ) - - if master_addr is None: - print(f"[Updater] Master address not found, exiting", flush=True) - return - - # Set CUDA device - torch.cuda.set_device(tp_rank) - - print( - f"[Updater] Starting on TP rank {tp_rank}/{tp_size}, " - f"q_heads={num_q_heads}, kv_heads={num_kv_heads}, gpu_id={gpu_id}", - flush=True, - ) - - # For single-node mode (num_inference_nodes=0): - # - Trainer is rank 0 - # - Inference daemon is rank 1 (or tp_rank + 1 for multi-GPU) - # Total world size = 1 trainer + 1 inference = 2 - # - # For multi-node mode: - # - More complex, based on SLURM node allocation - - if num_inference_nodes == 0: - # Single node: simple setup - # World = [trainer (rank 0), inference daemon (rank 1)] - num_training_ranks = 1 - num_inference_ranks = 1 - world_size = num_training_ranks + num_inference_ranks - rank = num_training_ranks + tp_rank # Daemon is rank 1 - else: - # Multi-node: 8 GPUs per node - hostnames = get_hostnames() - cuda_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "0")).split(",") - ranks_per_node = 8 - world_size = num_inference_nodes * ranks_per_node - - rank = -1 - for i, url in enumerate(urls or []): - if hostnames and url in hostnames: - rank = ranks_per_node * i + int(cuda_devices[gpu_id]) - break - - if rank < 0: - print(f"[Updater] Could not determine rank for multi-node, exiting", flush=True) - return - - print(f"[Updater] Master: {master_addr}, world_size={world_size}, my_rank={rank}", flush=True) - - # Use state_dict keys as parameter list (we already have the model!) - param_name_list = sorted(state_dict.keys()) - print(f"[Updater] Model has {len(param_name_list)} parameters", flush=True) - - # Use the world_size and rank we already calculated - total_group_size = world_size - - # For single-node mode, trainer is rank 0 - num_training_ranks = 1 if num_inference_nodes == 0 else 1 - - print(f"[Updater] Total group size: {total_group_size}", flush=True) - print(f"[Updater] Training ranks: {num_training_ranks}", flush=True) - print(f"[Updater] My rank: {rank}", flush=True) - - # Initialize process groups - print("[Updater] Creating process groups...", flush=True) - - try: - # Gloo group for coordination - gloo_group = init_process_group( - backend="gloo", - init_method=f"tcp://{master_addr}", - world_size=total_group_size, - rank=rank, - group_name="gloo_group", - ) - print("[Updater] ✓ Gloo group created", flush=True) - - # NCCL group for tensor transfers - nccl_group = init_process_group( - backend="nccl", - init_method=f"tcp://{master_addr}", - world_size=total_group_size, - rank=rank, - group_name="weight_update_group", - ) - print("[Updater] ✓ NCCL group created", flush=True) - - # Barrier synchronization to confirm both sides are ready - print("[Updater] Waiting for trainer to be ready...", flush=True) - dist.barrier(group=gloo_group) - print("[Updater] ✓ Trainer is ready, starting update loop", flush=True) - - except Exception as e: - print(f"[Updater] Failed to create process groups: {e}", flush=True) - import traceback - traceback.print_exc() - return - - # Get device for tensors - my_device = next(iter(state_dict.values())).device - - # Build param info dict from state_dict - param_info_dict = {} - for name, tensor in state_dict.items(): - param_info_dict[name] = { - "shape": list(tensor.shape), - "dtype": tensor.dtype, - } - - print("[Updater] Entering update loop...", flush=True) - print(f"[Updater] Waiting for weight updates from trainer (rank 0)...", flush=True) - - update_count = 0 - - with torch.no_grad(): - while True: - try: - # Receive parameter index from trainer (rank 0) - obj_indx = torch.zeros(1, dtype=torch.long, device=my_device) - dist.broadcast(obj_indx, src=0, group=nccl_group) - - tt_indx = obj_indx.item() - - # -1 signals heartbeat (no update) - if tt_indx == -1: - continue - - # -2 signals shutdown - if tt_indx == -2: - print("[Updater] Received shutdown signal", flush=True) - break - - # Get parameter info - if tt_indx < 0 or tt_indx >= len(param_name_list): - if debug: - print(f"[Updater] Invalid index {tt_indx}, skipping", flush=True) - continue - - param_name = param_name_list[tt_indx] - - if param_name not in state_dict: - if debug: - print(f"[Updater] {param_name} not in state_dict, skipping", flush=True) - continue - - target_tensor = state_dict[param_name] - target_shape = list(target_tensor.shape) - target_dtype = target_tensor.dtype - - # Receive the tensor from trainer - # Trainer sends via broadcast, we receive - received_tensor = torch.zeros(target_shape, dtype=target_dtype, device=my_device) - dist.broadcast(received_tensor, src=0, group=nccl_group) - - # Copy to shared memory - state_dict[param_name].data.copy_(received_tensor) - - update_count += 1 - if debug or (update_count % 50 == 0): - print(f"[Updater] Updated {param_name} (#{update_count})", flush=True) - - except torch.distributed.DistBackendError as e: - # NCCL communication failure - likely trainer crashed - error_str = str(e) - if "Broken pipe" in error_str or "Connection reset" in error_str: - print("[Updater] Trainer disconnected (broken pipe). Exiting.", flush=True) - break - else: - print(f"[Updater] NCCL error: {e}", flush=True) - import traceback - traceback.print_exc() - time.sleep(1) - except Exception as e: - print(f"[Updater] Error in update loop: {e}", flush=True) - import traceback - traceback.print_exc() - time.sleep(1) - - -# Note: Advanced multi-GPU tensor parallelism support removed for simplicity. -# For single-node mode, we use direct tensor broadcast which is sufficient.