""" Distributed utilities for vLLM weight synchronization. Provides process group initialization and communication helpers for coordinating weight updates between trainer and vLLM. """ from __future__ import annotations import json import os import socket import time from collections import defaultdict from datetime import timedelta from typing import Any, Dict, List, Optional, Tuple import torch import torch.distributed as dist def init_process_group( backend: Optional[str] = None, init_method: Optional[str] = None, timeout: Optional[timedelta] = None, world_size: int = -1, rank: int = -1, store: Optional[Any] = None, group_name: str = "", pg_options: Optional[Any] = None, ) -> dist.ProcessGroup: """ Initialize a custom process group for weight synchronization. This creates a named process group that coexists with vLLM's internal process groups, enabling direct tensor communication between trainer and inference processes. Args: backend: "nccl" for GPU, "gloo" for CPU init_method: Rendezvous URL (e.g., "tcp://host:port") timeout: How long to wait for other ranks world_size: Total number of processes rank: This process's rank store: Optional torch.distributed Store group_name: Name for this process group (must match across ranks) pg_options: Backend-specific options Returns: ProcessGroup for collective operations """ from torch.distributed.distributed_c10d import ( _new_process_group_helper, _world, Backend, default_pg_timeout, PrefixStore, rendezvous, ) assert (store is None) or (init_method is None), \ "Cannot specify both init_method and store." if store is not None: assert world_size > 0, "world_size must be positive if using store" assert rank >= 0, "rank must be non-negative if using store" elif init_method is None: init_method = "env://" if backend: backend = Backend(backend) else: backend = Backend("undefined") if timeout is None: timeout = default_pg_timeout # Create store via rendezvous if not provided if store is None: rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) store, rank, world_size = next(rendezvous_iterator) store.set_timeout(timeout) store = PrefixStore(group_name, store) # Handle PyTorch version differences for pg_options parameter pg_options_param_name = ( "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" ) pg, _ = _new_process_group_helper( world_size, rank, [], backend, store, group_name=group_name, **{pg_options_param_name: pg_options}, timeout=timeout, ) _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} return pg def broadcast_object_list( object_list: List[Any], src: Optional[int] = None, group: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, group_src: Optional[int] = None, ) -> None: """ Broadcast a list of objects from source rank to all other ranks. Modified from torch.distributed.broadcast_object_list to work correctly with custom process groups where rank 0 may not be the default group's rank 0. Args: object_list: List of objects to broadcast (modified in-place on receivers) src: Global source rank (deprecated, use group_src) group: Process group to use device: Device for temporary tensors group_src: Source rank within the group """ global_src = group_src if group_src is not None else src current_device = device # Broadcast object sizes first object_sizes_tensor = torch.empty( len(object_list), dtype=torch.long, device=current_device ) dist.broadcast(object_sizes_tensor, src=global_src, group=group) # Broadcast serialized objects object_tensor = torch.empty( torch.sum(object_sizes_tensor).item(), dtype=torch.uint8, device=current_device, ) dist.broadcast(object_tensor, src=global_src, group=group) # Deserialize objects offset = 0 for i, obj_size in enumerate(object_sizes_tensor): obj_view = object_tensor[offset : offset + obj_size] obj_view = obj_view.type(torch.uint8) offset += obj_size object_list[i] = dist.distributed_c10d._tensor_to_object( obj_view, obj_size, group ) def get_inference_urls(num_inference_nodes: int = 0) -> Tuple[Optional[str], ...]: """ Get URLs for inference server communication. Parses SLURM environment or uses localhost for single-machine setup. Args: num_inference_nodes: Number of dedicated inference nodes. 0 = single machine, trainer and vLLM share the node >0 = multi-node, last N nodes are for inference Returns: Tuple of (master_addr, master_gloo_addr, master_inference_addr, nodelist) Returns (None, None, None, None) if not in a valid setup. """ if num_inference_nodes > 0: # Multi-node SLURM setup slurm_nodelist = os.environ.get("SLURM_JOB_NODELIST") if not slurm_nodelist: return None, None, None, None # Parse SLURM node list nodelist = ( os.popen(f'scontrol show hostnames {slurm_nodelist}') .read() .strip() .split("\n") ) nodelist = [node for node in nodelist if node] # First node is master for process groups master_server = f"{nodelist[0]}:26756" master_gloo_server = f"{nodelist[0]}:26757" # Last N nodes are inference nodes inference_nodes = nodelist[-num_inference_nodes:] master_inference_server = f"{inference_nodes[0]}:26758" return master_server, master_gloo_server, master_inference_server, inference_nodes elif num_inference_nodes == 0: # Single machine setup master_server = "localhost:26756" master_gloo_server = "localhost:26757" master_inference_server = "localhost:26758" nodelist = ["localhost"] return master_server, master_gloo_server, master_inference_server, nodelist else: return None, None, None, None def get_hostnames() -> Optional[List[str]]: """ Get the hostnames for this machine. Parses /etc/hosts to find all hostnames associated with this machine's IP. Returns: List of [ip, hostname1, hostname2, ...] or None if not found. """ my_ip = socket.gethostbyname(socket.gethostname()) my_hostname = socket.gethostname() try: with open("/etc/hosts", "r") as f: for line in f: line = line.strip() if line and not line.startswith("#"): parts = line.split() if len(parts) >= 2 and ((parts[0] == my_ip) or (my_hostname in parts)): ip = parts[0] if ip.startswith("127."): continue return parts except Exception: pass return None def get_json_data(log_dir: Optional[str] = None, timeout: int = 300) -> Dict[str, Any]: """ Load the bridge configuration JSON from vLLM. Waits for the file to be created by vLLM's weight bridge setup. Args: log_dir: Directory containing the JSON file (defaults to LOGDIR env var) timeout: Maximum seconds to wait for file Returns: Parsed JSON data with parameter mappings and configuration. Raises: ValueError: If LOGDIR not set and log_dir not provided FileNotFoundError: If file not found after timeout """ if log_dir is None: log_dir = os.environ.get("LOGDIR") if log_dir is None: raise ValueError("LOGDIR environment variable not set and log_dir not provided") json_path = os.path.join(log_dir, "vllm_bridge_config.json") wait_time = 0 while not os.path.exists(json_path): if wait_time >= timeout: raise FileNotFoundError(f"Config file not found after {timeout}s: {json_path}") if wait_time % 10 == 0: print(f"[Updater] Waiting for {json_path}...", flush=True) time.sleep(1) wait_time += 1 # Wait a moment for file to finish writing time.sleep(0.5) with open(json_path, "r") as f: return json.load(f) def get_name_conversions(param_mappings: Dict[str, Any]) -> Dict[str, List[str]]: """ Build reverse mapping from vLLM names to trainer names. Args: param_mappings: Dict mapping trainer param names to vLLM info Returns: Dict mapping vLLM names to list of trainer names """ name_conversions = defaultdict(list) for name, info in param_mappings.items(): vllm_name = info.get("vllm_name", name) name_conversions[vllm_name].append(name) return name_conversions # Permutation functions for rotary embeddings def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor: """ Permute weight tensor for sliced rotary embeddings. Args: w: Weight tensor of shape [dim1, dim2] n_heads: Number of attention heads Returns: Permuted tensor for rotary embedding compatibility """ dim1 = w.shape[0] dim2 = w.shape[1] return ( w.view(n_heads, dim1 // n_heads // 2, 2, dim2) .transpose(1, 2) .reshape(dim1, dim2) ) def permute_1d(w: torch.Tensor, n_heads: int) -> torch.Tensor: """ Permute 1D weight tensor (bias) for sliced rotary embeddings. Args: w: Weight tensor of shape [dim1] n_heads: Number of attention heads Returns: Permuted tensor """ dim1 = w.shape[0] return w.view(n_heads, dim1 // n_heads // 2, 2).transpose(1, 2).reshape(dim1)