# atropos/example_trainer/vllm_weight_bridge.py
# Jai Suphavadeeprasit b0d35be8a4 IPC updates
# 2026-02-13 11:26:25 -05:00
#
# 890 lines
# 33 KiB
# Python
"""
vLLM Weight Bridge - Integration between trainer and vLLM inference.
This module provides two modes for coordinating weight updates:
LOCAL MODE (num_inference_nodes=0):
- Trainer and vLLM run as separate processes on the same machine
- Communication via HTTP to vLLM's /bridge/* endpoints
- No NCCL process groups needed
- Simpler setup, suitable for single-machine training
DISTRIBUTED MODE (num_inference_nodes>0):
- Trainer and vLLM join the same NCCL process group
- Direct tensor sharing via shared GPU memory
- Lower latency, but requires coordinated setup
Architecture (Local Mode):
┌─────────────────┐ ┌─────────────────┐
│ Trainer Process │ HTTP │ vLLM Process │
│ (training) │────────▶│ (inference) │
└─────────────────┘ └─────────────────┘
Architecture (Distributed Mode):
┌─────────────────────────────────────────┐
│ Shared GPU Memory (NCCL) │
│ Model weights owned by vLLM process │
└─────────────────────────────────────────┘
▲ ▲
│ forward pass │ optimizer.step()
┌───────┴───────┐ ┌───────┴───────┐
│ vLLM Process │ │Trainer Process│
└───────────────┘ └───────────────┘
"""
from __future__ import annotations
import json
import os
import socket
import time
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
# =============================================================================
# Process Group Initialization Helpers
# =============================================================================
def init_process_group(
    backend: Optional[str] = None,
    init_method: Optional[str] = None,
    timeout: Optional[timedelta] = None,
    world_size: int = -1,
    rank: int = -1,
    store: Optional[Any] = None,
    group_name: str = "",
    pg_options: Optional[Any] = None,
) -> dist.ProcessGroup:
    """
    Initialize a custom process group for weight synchronization.

    This is based on torch.distributed internals but allows creating a named
    group that coexists with the default process group (used by vLLM internally).

    Args:
        backend: "nccl" for GPU, "gloo" for CPU
        init_method: Rendezvous URL (e.g., "tcp://host:port" or "env://")
        timeout: How long to wait for other ranks
        world_size: Total number of processes in the group
        rank: This process's rank in the group
        store: Optional torch.distributed Store object
        group_name: Name for this process group (must match across all ranks)
        pg_options: Backend-specific options

    Returns:
        A ProcessGroup object for collective operations
    """
    # NOTE: these are private torch.distributed internals; they have been
    # stable across recent releases but may change between torch versions.
    from torch.distributed.distributed_c10d import (
        _new_process_group_helper,
        _world,
        Backend,
        default_pg_timeout,
        PrefixStore,
        rendezvous,
    )
    assert (store is None) or (
        init_method is None
    ), "Cannot specify both init_method and store."
    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"
    if backend:
        backend = Backend(backend)
    else:
        backend = Backend("undefined")
    if timeout is None:
        timeout = default_pg_timeout
    # Rendezvous with other processes
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)
        # Use a PrefixStore to avoid key collisions with other groups
        store = PrefixStore(group_name, store)
    # PyTorch 2.6+ renamed pg_options to backend_options.
    # BUGFIX: compare numeric (major, minor) tuples instead of version
    # strings -- lexicographic comparison breaks for torch 2.10+
    # ("2.10..." < "2.6" as strings).
    _version_fields = str(torch.__version__).split(".")
    _torch_version = (
        int(_version_fields[0]),
        int(_version_fields[1].split("+")[0]),
    )
    pg_options_param_name = (
        "backend_options" if _torch_version >= (2, 6) else "pg_options"
    )
    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )
    # Register the group's rank mapping so collectives can translate ranks.
    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
    return pg
def broadcast_object_list(
    object_list: List[Any],
    src: int,
    group: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> None:
    """
    Broadcast a list of picklable objects from src rank to all other ranks.

    This is a simplified version of torch.distributed.broadcast_object_list
    that works correctly with custom process groups.

    BUGFIX: the source rank now serializes ``object_list`` before
    broadcasting. Previously both the size tensor and the data tensor were
    created with ``torch.empty`` on *every* rank, so the src rank broadcast
    uninitialized memory instead of the actual objects.

    Args:
        object_list: List of objects to broadcast (modified in-place on receivers)
        src: Source rank that has the data
        group: Process group to use
        device: Device for intermediate tensors (defaults to cuda if available)
    """
    current_device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): assumes ``src`` is expressed as a rank within ``group``
    # (the groups built by init_process_group map group rank i -> i).
    is_src = dist.get_rank(group=group) == src
    if is_src:
        # Serialize each object. ``_object_to_tensor`` is a private torch
        # helper returning (uint8 data tensor, int64 size tensor); its
        # signature gained the ``group`` argument in recent torch versions.
        tensor_list, size_list = zip(
            *[dist._object_to_tensor(obj, current_device, group) for obj in object_list]
        )
        object_sizes_tensor = torch.cat(size_list).to(current_device)
    else:
        object_sizes_tensor = torch.empty(
            len(object_list), dtype=torch.long, device=current_device
        )
    # Broadcast object sizes first so receivers can allocate the data buffer.
    dist.broadcast(object_sizes_tensor, src=src, group=group)
    # Broadcast serialized objects
    if is_src:
        object_tensor = torch.cat(tensor_list).to(current_device)
    else:
        object_tensor = torch.empty(
            int(torch.sum(object_sizes_tensor).item()),
            dtype=torch.uint8,
            device=current_device,
        )
    dist.broadcast(object_tensor, src=src, group=group)
    if is_src:
        # Source already holds the original objects; nothing to deserialize.
        return
    # Deserialize on receiving ranks
    offset = 0
    for i, obj_size in enumerate(object_sizes_tensor):
        obj_view = object_tensor[offset : offset + obj_size]
        obj_view = obj_view.type(torch.uint8)
        offset += obj_size
        object_list[i] = dist._tensor_to_object(obj_view, obj_size, group)
# =============================================================================
# Environment and URL Helpers
# =============================================================================
def get_inference_urls(num_inference_nodes: int) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[List[str]]]:
    """
    Get rendezvous URLs for connecting to inference nodes.

    In SLURM environments, parses SLURM_JOB_NODELIST to find inference
    servers. For local testing (num_inference_nodes == 0), returns
    localhost URLs.

    Args:
        num_inference_nodes: Number of inference nodes (from config)

    Returns:
        Tuple of (master_server, master_gloo_server, master_inference_server,
        nodelist). All None if inference nodes are not configured or the
        SLURM nodelist cannot be resolved.
    """
    # Negative values mean "not configured".
    if num_inference_nodes < 0:
        return None, None, None, None
    # Single-node local mode: everything rendezvouses on localhost.
    if num_inference_nodes == 0:
        return (
            "localhost:26756",
            "localhost:26757",
            "localhost:26758",
            ["localhost"],
        )
    # Multi-node SLURM environment: expand the compact nodelist via scontrol.
    raw_hosts = os.popen(
        f'scontrol show hostnames {os.environ.get("SLURM_JOB_NODELIST", "")}'
    ).read()
    hosts = [h for h in raw_hosts.split("\n") if h]
    if not hosts:
        return None, None, None, None
    head = hosts[0]
    # Inference nodes are the last N nodes of the allocation.
    inference_hosts = hosts[-num_inference_nodes:]
    return (
        f"{head}:26756",
        f"{head}:26757",
        f"{inference_hosts[0]}:26758",
        inference_hosts,
    )
def get_local_hostname() -> Optional[List[str]]:
    """
    Get the local hostname(s) for rank determination.

    Scans /etc/hosts for a non-loopback entry matching this machine's IP or
    hostname and returns that line's fields; falls back to
    [ip, hostname] when no entry matches or /etc/hosts is absent.
    """
    ip_addr = socket.gethostbyname(socket.gethostname())
    host = socket.gethostname()
    try:
        with open("/etc/hosts", "r") as hosts_file:
            for raw_line in hosts_file:
                entry = raw_line.strip()
                # Skip blanks and comment lines.
                if not entry or entry.startswith("#"):
                    continue
                fields = entry.split()
                if len(fields) < 2:
                    continue
                # Only consider lines that refer to this machine.
                if fields[0] != ip_addr and host not in fields:
                    continue
                # Loopback addresses are useless for cross-node rendezvous.
                if fields[0].startswith("127."):
                    continue
                return fields
    except FileNotFoundError:
        pass
    return [ip_addr, host]
# =============================================================================
# Tensor Mapping and Permutation Helpers
# =============================================================================
def permute_for_rotary(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    """
    Permute weight tensor for sliced rotary embeddings.

    vLLM and some model implementations use different layouts for Q/K
    projections; this re-interleaves the per-head rows to convert between
    them. Output has the same shape as the input.
    """
    rows, cols = w.shape[0], w.shape[1]
    # Rows per head, split into two interleaved halves.
    half = rows // n_heads // 2
    grouped = w.view(n_heads, half, 2, cols)
    return grouped.transpose(1, 2).reshape(rows, cols)
def permute_for_rotary_1d(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    """Permute a 1D tensor (bias) for sliced rotary embeddings."""
    length = w.shape[0]
    # Same re-interleaving as permute_for_rotary, without a column axis.
    grouped = w.view(n_heads, length // n_heads // 2, 2)
    return grouped.transpose(1, 2).reshape(length)
def get_name_conversions(param_mappings: Dict[str, Any]) -> Dict[str, List[str]]:
    """
    Build a mapping from vLLM parameter names to trainer parameter names.

    vLLM may split or combine parameters differently than HuggingFace
    models, so one vLLM name can map to several trainer names. Entries
    without an explicit "vllm_name" map to their own key.
    """
    conversions: Dict[str, List[str]] = {}
    for trainer_name, info in param_mappings.items():
        vllm_key = info.get("vllm_name", trainer_name)
        conversions.setdefault(vllm_key, []).append(trainer_name)
    return conversions
# =============================================================================
# Bridge Configuration
# =============================================================================
@dataclass
class BridgeConfig:
    """Configuration for the vLLM weight bridge.

    Mode selection: ``num_inference_nodes == 0`` means local mode (HTTP
    sync, optionally CUDA IPC); any positive value means distributed mode
    (NCCL process groups). See ``is_local_mode``.
    """
    # Process group settings
    trainer_rank: int = 0  # this trainer's rank within the shared group
    world_size: int = 1  # total processes in the trainer's own group
    init_method: str = "env://"  # rendezvous URL for process-group setup
    num_inference_nodes: int = 0  # 0 => local mode, >0 => distributed mode
    # Model settings
    model_name: str = ""  # HuggingFace model id/path used by get_trainable_model()
    device: str = "cuda"
    # Synchronization settings
    timeout_seconds: float = 300.0  # also used as file-wait timeout for vLLM exports
    log_dir: Optional[str] = None  # where vLLM writes its JSON export files
    # vLLM server URL for HTTP-based sync (local mode)
    vllm_api_url: str = "http://localhost:9001"
    # CUDA IPC mode: share GPU memory directly with vLLM (same GPU only!)
    use_cuda_ipc: bool = False
    # Derived from environment at construction time
    num_gpus_per_node: int = field(default_factory=lambda: torch.cuda.device_count())
    @property
    def is_local_mode(self) -> bool:
        """
        Local mode: single machine, no NCCL process groups needed.
        Communication happens via HTTP to vLLM server.
        """
        return self.num_inference_nodes == 0
    @classmethod
    def from_training_config(cls, config: Any) -> "BridgeConfig":
        """Create BridgeConfig from a TrainingConfig object.

        Pulls log_dir from the LOGDIR environment variable and builds the
        vLLM URL from config.vllm_port (default 9001) when present.
        """
        return cls(
            trainer_rank=config.trainer_rank,
            world_size=config.world_size,
            init_method=config.init_method,
            num_inference_nodes=config.num_inference_nodes,
            model_name=config.model_name,
            device=config.device,
            log_dir=os.environ.get("LOGDIR"),
            vllm_api_url=f"http://localhost:{getattr(config, 'vllm_port', 9001)}",
            use_cuda_ipc=getattr(config, 'use_cuda_ipc', False),
        )
# =============================================================================
# Main Bridge Class
# =============================================================================
class VLLMWeightBridge:
    """
    Bridge for sharing model weights between trainer and vLLM inference server.

    This class handles:
    1. Joining the distributed process group with vLLM workers
    2. Attaching to vLLM's model weight tensors
    3. Providing a model interface for the trainer to optimize
    4. Synchronizing updates so vLLM sees changes immediately

    Usage:
        bridge = VLLMWeightBridge(config)
        bridge.initialize()
        model = bridge.get_trainable_model()
        optimizer = AdamW(model.parameters(), lr=1e-5)
        for batch in data:
            loss = compute_loss(model, batch)
            loss.backward()
            optimizer.step()
            bridge.notify_update()  # vLLM now uses new weights
    """
    def __init__(self, config: BridgeConfig) -> None:
        """Store configuration; heavy setup is deferred to initialize()."""
        self.config = config
        self.device = torch.device(config.device)
        # Process groups (initialized in initialize(); remain None in local mode)
        self.nccl_group: Optional[dist.ProcessGroup] = None
        self.gloo_group: Optional[dist.ProcessGroup] = None
        # Parameter mappings (loaded from vLLM's JSON)
        self.param_mappings: Dict[str, Any] = {}
        self.name_conversions: Dict[str, List[str]] = {}
        # Shared tensors (attached in attach_to_vllm_weights() or via CUDA IPC)
        self.shared_state_dict: Dict[str, torch.Tensor] = {}
        # Model for training (created lazily in get_trainable_model())
        self._model: Optional[nn.Module] = None
        # Synchronization state
        self._update_count: int = 0
        self._initialized: bool = False
    def initialize(self) -> None:
        """
        Initialize the bridge: join process groups and load parameter mappings.

        In local mode (num_inference_nodes=0), skips NCCL setup and uses HTTP.
        In distributed mode, creates NCCL/Gloo process groups.
        This must be called before any other methods. Idempotent: repeat
        calls are no-ops once initialization succeeded.
        """
        if self._initialized:
            return
        print(f"[Bridge] Initializing weight bridge for rank {self.config.trainer_rank}")
        if self.config.is_local_mode:
            self._initialize_local_mode()
        else:
            self._initialize_distributed_mode()
        self._initialized = True
    def _initialize_local_mode(self) -> None:
        """
        Initialize for local single-machine mode.

        In local mode:
        - No NCCL process groups (trainer and vLLM are separate processes)
        - Communication via HTTP to vLLM's bridge endpoints
        - Trainer loads its own model copy, OR uses CUDA IPC for true shared memory

        A failed vLLM health check only warns; training proceeds without sync.
        """
        if self.config.use_cuda_ipc:
            print("[Bridge] Using CUDA IPC MODE (true shared GPU memory)")
        else:
            print("[Bridge] Using LOCAL MODE (HTTP-based sync, no NCCL)")
        print(f"[Bridge] vLLM API URL: {self.config.vllm_api_url}")
        # Verify vLLM server is reachable (best-effort; never fatal)
        try:
            import requests
            response = requests.get(f"{self.config.vllm_api_url}/health", timeout=5)
            if response.status_code == 200:
                print("[Bridge] vLLM server is reachable")
            else:
                print(f"[Bridge] Warning: vLLM health check returned {response.status_code}")
        except Exception as e:
            print(f"[Bridge] Warning: Could not reach vLLM server: {e}")
            print("[Bridge] Training will continue, but vLLM sync may not work")
        # For CUDA IPC mode, request vLLM to export IPC handles and map them
        if self.config.use_cuda_ipc:
            self._request_cuda_ipc_export()
            self._load_cuda_ipc_handles()
        # Load parameter mappings if available (optional in local mode)
        try:
            self._load_param_mappings()
        except RuntimeError:
            print("[Bridge] Parameter mapping file not found (optional in local mode)")
            self.param_mappings = {}
    def _request_cuda_ipc_export(self) -> None:
        """Request vLLM to export CUDA IPC handles.

        Raises:
            RuntimeError: if the HTTP request fails or returns non-200.
        """
        import requests
        print("[Bridge] Requesting CUDA IPC handles from vLLM...")
        try:
            response = requests.post(
                f"{self.config.vllm_api_url}/bridge/export_cuda_ipc",
                timeout=60
            )
            if response.status_code == 200:
                result = response.json()
                print(f"[Bridge] vLLM exported {result.get('num_parameters', 0)} IPC handles")
            else:
                raise RuntimeError(f"Failed to export IPC handles: {response.status_code}")
        except Exception as e:
            # Wraps both network failures and the non-200 RuntimeError above.
            raise RuntimeError(f"Could not request CUDA IPC export: {e}")
    def _load_cuda_ipc_handles(self) -> None:
        """
        Load CUDA IPC handles from file and reconstruct shared tensors.

        This is the key to TRUE shared memory - the tensors we create here
        point to the SAME GPU memory that vLLM is using!

        Waits up to config.timeout_seconds for vLLM to write the handles
        file; per-tensor reconstruction failures are logged and skipped.

        Raises:
            RuntimeError: if the handles file never appears.
        """
        import base64
        import pickle
        log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".")
        ipc_path = Path(log_dir) / "cuda_ipc_handles.json"
        # Wait for file to be created (polled once per second)
        wait_time = 0
        while not ipc_path.exists() and wait_time < self.config.timeout_seconds:
            print(f"[Bridge] Waiting for {ipc_path}...")
            time.sleep(1)
            wait_time += 1
        if not ipc_path.exists():
            raise RuntimeError(f"CUDA IPC handles file not found: {ipc_path}")
        with open(ipc_path, "r") as f:
            data = json.load(f)
        handles_data = data.get("handles", {})
        print(f"[Bridge] Reconstructing {len(handles_data)} shared tensors from IPC handles...")
        reconstructed = 0
        for name, info in handles_data.items():
            try:
                # Decode the IPC handle.
                # SECURITY NOTE: pickle.loads on file contents -- safe only
                # because the file is written locally by our own vLLM
                # process; never point log_dir at untrusted data.
                handle_bytes = base64.b64decode(info["ipc_handle"])
                handle = pickle.loads(handle_bytes)
                # Reconstruct the storage from the IPC handle
                # This does NOT allocate new memory - it maps to existing memory!
                device = torch.device(f"cuda:{info['device_index']}")
                # Get dtype (stored as e.g. "torch.bfloat16" in the JSON)
                dtype_str = info["dtype"]
                dtype = getattr(torch, dtype_str.replace("torch.", ""))
                # Reconstruct tensor from IPC handle
                # The storage is shared with vLLM's process.
                # NOTE(review): _new_shared_cuda is a private PyTorch API;
                # its argument layout may change across torch versions.
                storage = torch.cuda.Storage._new_shared_cuda(*handle)
                # Create tensor view of the shared storage
                tensor = torch.tensor([], dtype=dtype, device=device)
                tensor.set_(
                    storage,
                    info["storage_offset"],
                    info["shape"],
                    info["stride"]
                )
                # Store in shared_state_dict
                self.shared_state_dict[name] = tensor
                reconstructed += 1
            except Exception as e:
                # Best-effort: a tensor that fails to map is skipped, not fatal.
                print(f"[Bridge] Warning: Could not reconstruct {name}: {e}")
                continue
        print(f"[Bridge] Successfully reconstructed {reconstructed} shared tensors")
        # NOTE(review): rough estimate -- assumes ~4 MB per tensor
        # (count * 4 / 1024 GB); confirm against the actual model size.
        print(f"[Bridge] Memory savings: ~{reconstructed * 4 / 1024:.1f} GB (no model copy needed!)")
    def _initialize_distributed_mode(self) -> None:
        """
        Initialize for distributed multi-node mode.

        Creates NCCL (tensor transfer) and Gloo (metadata/coordination)
        process groups spanning trainer and inference workers.

        Raises:
            RuntimeError: if rendezvous URLs cannot be determined.
        """
        print("[Bridge] Using DISTRIBUTED MODE (NCCL tensor sharing)")
        # Get rendezvous URLs
        master_addr, master_gloo_addr, master_inference_addr, nodelist = get_inference_urls(
            self.config.num_inference_nodes
        )
        if master_addr is None:
            raise RuntimeError(
                "Could not determine inference server URLs. "
                "Set NUM_INFERENCE_NODES environment variable or check SLURM_JOB_NODELIST."
            )
        print(f"[Bridge] Master address: {master_addr}")
        print(f"[Bridge] Inference nodes: {nodelist}")
        # Load parameter mappings from vLLM (blocks until the file appears)
        self._load_param_mappings()
        # Calculate total group size (trainers + inference workers)
        num_training_gpus = self._get_num_training_gpus()
        # In distributed mode, each inference node contributes num_gpus_per_node workers
        num_inference_gpus = self.config.num_inference_nodes * self.config.num_gpus_per_node
        total_group_size = num_training_gpus + num_inference_gpus
        trainer_rank_in_group = self.config.trainer_rank
        print(f"[Bridge] Training GPUs: {num_training_gpus}, Inference GPUs: {num_inference_gpus}")
        print(f"[Bridge] Total group size: {total_group_size}, Trainer rank: {trainer_rank_in_group}")
        # Initialize NCCL group for tensor transfers
        self.nccl_group = init_process_group(
            backend="nccl",
            init_method=f"tcp://{master_addr}",
            world_size=total_group_size,
            rank=trainer_rank_in_group,
            group_name="weight_update_group",
            timeout=timedelta(seconds=self.config.timeout_seconds),
        )
        print("[Bridge] NCCL process group initialized")
        # Initialize Gloo group for metadata/coordination (CPU-side)
        self.gloo_group = init_process_group(
            backend="gloo",
            init_method=f"tcp://{master_gloo_addr}",
            world_size=total_group_size,
            rank=trainer_rank_in_group,
            group_name="gloo_group",
            timeout=timedelta(seconds=self.config.timeout_seconds),
        )
        print("[Bridge] Gloo process group initialized")
    def _load_param_mappings(self) -> None:
        """Load parameter name mappings from vLLM's exported JSON.

        Polls for the file up to config.timeout_seconds, then populates
        self.param_mappings and self.name_conversions.

        Raises:
            RuntimeError: if the mapping file never appears.
        """
        log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".")
        json_path = Path(log_dir) / "vllm_bridge_config.json"
        # Wait for vLLM to write the mapping file
        wait_time = 0
        while not json_path.exists() and wait_time < self.config.timeout_seconds:
            print(f"[Bridge] Waiting for {json_path} to be created...")
            time.sleep(1)
            wait_time += 1
        if not json_path.exists():
            raise RuntimeError(
                f"Parameter mapping file not found at {json_path}. "
                "Make sure vLLM is running and has exported its parameter mappings."
            )
        # Small delay to ensure file is fully written
        time.sleep(1)
        with open(json_path, "r") as f:
            data = json.load(f)
        self.param_mappings = data.get("param_mappings", {})
        self.name_conversions = get_name_conversions(self.param_mappings)
        print(f"[Bridge] Loaded mappings for {len(self.param_mappings)} parameters")
    def _get_num_training_gpus(self) -> int:
        """Get number of training GPUs from param mappings or config.

        NOTE(review): reads "dp_shard_degree"/"tp_degree" as top-level keys
        of param_mappings, which elsewhere maps parameter names to info
        dicts -- confirm the vLLM export actually includes these keys.
        Falls back to config.world_size when no mappings are loaded.
        """
        if self.param_mappings:
            # Try to get from vLLM's exported info
            return self.param_mappings.get("dp_shard_degree", 1) * self.param_mappings.get("tp_degree", 1)
        return self.config.world_size
    def attach_to_vllm_weights(self, vllm_state_dict: Dict[str, torch.Tensor]) -> None:
        """
        Attach to vLLM's weight tensors.

        After this call, self.shared_state_dict contains references to the
        actual tensors that vLLM uses for inference. Modifying these tensors
        will immediately affect vLLM's outputs.

        Args:
            vllm_state_dict: vLLM's model state_dict (actual tensors, not copies)
        """
        self.shared_state_dict = vllm_state_dict
        print(f"[Bridge] Attached to {len(vllm_state_dict)} vLLM weight tensors")
        # Log tensor info for debugging (first five entries only)
        for name, tensor in list(vllm_state_dict.items())[:5]:
            print(f"[Bridge] {name}: {tensor.shape}, {tensor.dtype}, {tensor.device}")
        if len(vllm_state_dict) > 5:
            print(f"[Bridge] ... and {len(vllm_state_dict) - 5} more")
    def get_trainable_model(self) -> nn.Module:
        """
        Get a model whose parameters point to vLLM's shared tensors.

        In CUDA IPC mode: shared_state_dict is populated from IPC handles during init.
        In other modes: must call attach_to_vllm_weights() first.

        This creates a HuggingFace model structure but replaces all parameters
        with references to the shared tensors. When the optimizer updates these
        parameters, it modifies vLLM's weights directly. The model is cached;
        repeat calls return the same instance.

        Returns:
            An nn.Module ready for training with shared weights

        Raises:
            RuntimeError: if no shared tensors are available yet.
        """
        if self._model is not None:
            return self._model
        if not self.shared_state_dict:
            if self.config.use_cuda_ipc:
                raise RuntimeError(
                    "CUDA IPC mode enabled but no shared tensors found. "
                    "Check that vLLM exported IPC handles correctly."
                )
            else:
                raise RuntimeError(
                    "Must call attach_to_vllm_weights() before get_trainable_model()"
                )
        print(f"[Bridge] Creating trainable model for {self.config.model_name}")
        if self.config.use_cuda_ipc:
            print("[Bridge] Using CUDA IPC shared tensors (NO NEW GPU MEMORY!)")
        # Load model config (not weights)
        model_config = AutoConfig.from_pretrained(self.config.model_name)
        # Create model with empty weights (meta device = no memory)
        with torch.device("meta"):
            model = AutoModelForCausalLM.from_config(model_config)
        # Replace each parameter with the shared tensor
        self._replace_parameters_with_shared(model)
        # Move model structure to device (parameters already on device via IPC)
        # NOTE(review): any parameter NOT replaced above is still on the meta
        # device, and .to() cannot copy out of meta -- confirm all parameters
        # get replaced for the models used here.
        model.to(self.device)
        self._model = model
        total_params = sum(p.numel() for p in model.parameters())
        print(f"[Bridge] Trainable model ready with {total_params:,} parameters")
        if self.config.use_cuda_ipc:
            # Verify memory savings
            param_memory_gb = total_params * 2 / 1e9  # bfloat16 = 2 bytes
            print(f"[Bridge] CUDA IPC memory savings: ~{param_memory_gb:.1f} GB (shared with vLLM)")
        return model
    def _replace_parameters_with_shared(self, model: nn.Module) -> None:
        """
        Replace model parameters with references to shared vLLM tensors.

        This is the key operation that makes weight sharing work. After this,
        model.parameters() returns tensors that ARE vLLM's weights.
        Parameters with no matching shared tensor are left untouched and
        reported (first five listed).
        """
        replaced_count = 0
        missing_params = []
        for name, param in model.named_parameters():
            # Convert HuggingFace param name to vLLM param name
            vllm_name = self._hf_to_vllm_name(name)
            if vllm_name in self.shared_state_dict:
                shared_tensor = self.shared_state_dict[vllm_name]
                # Create a new Parameter that wraps the shared tensor
                # The key is that we're not copying - we're referencing the same storage
                new_param = nn.Parameter(shared_tensor, requires_grad=True)
                # Replace the parameter in the model
                self._set_parameter(model, name, new_param)
                replaced_count += 1
            else:
                missing_params.append(name)
        print(f"[Bridge] Replaced {replaced_count} parameters with shared tensors")
        if missing_params:
            print(f"[Bridge] Warning: {len(missing_params)} parameters not found in shared state:")
            for p in missing_params[:5]:
                print(f"[Bridge] {p}")
    def _hf_to_vllm_name(self, hf_name: str) -> str:
        """
        Convert a HuggingFace parameter name to vLLM's naming convention.

        vLLM may merge QKV projections, use different layer naming, etc.
        This handles the translation via the loaded name_conversions table;
        unknown names pass through unchanged.
        """
        # Check if we have an explicit mapping
        # (linear scan over conversions; fine for one-time model setup)
        for vllm_name, hf_names in self.name_conversions.items():
            if hf_name in hf_names:
                return vllm_name
        # Common transformations
        # vLLM often uses: model.layers.N.self_attn.qkv_proj
        # HF uses: model.layers.N.self_attn.q_proj, k_proj, v_proj
        # For now, try the name as-is
        return hf_name
    def _set_parameter(self, model: nn.Module, name: str, new_param: nn.Parameter) -> None:
        """Set a parameter by dotted name path.

        The final setattr goes through nn.Module.__setattr__, which
        registers the Parameter on the owning submodule.
        """
        parts = name.split(".")
        module = model
        for part in parts[:-1]:
            module = getattr(module, part)
        setattr(module, parts[-1], new_param)
    def broadcast_weights_to_inference(self) -> None:
        """
        Broadcast updated weights from trainer to inference workers.

        Call this after optimizer.step() to push the new weights to all
        vLLM inference processes. They will use the updated weights for
        subsequent requests.

        NOTE(review): intended for distributed mode -- in local mode
        self.nccl_group is None, so the collectives below would target the
        default process group; confirm callers only use this distributed.
        Parameters are walked in sorted-name order so all ranks agree.

        Raises:
            RuntimeError: if initialize() has not been called.
        """
        if not self._initialized:
            raise RuntimeError("Bridge not initialized. Call initialize() first.")
        param_names = sorted(self.param_mappings.keys())
        with torch.no_grad():
            for idx, param_name in enumerate(param_names):
                # Signal which parameter we're broadcasting
                idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.device)
                dist.broadcast(idx_tensor, src=0, group=self.nccl_group)
                # Get the tensor for this parameter
                vllm_name = self.param_mappings[param_name].get("vllm_name", param_name)
                if vllm_name not in self.shared_state_dict:
                    continue
                tensor = self.shared_state_dict[vllm_name]
                local_shape = self.param_mappings[param_name].get("local_shape", list(tensor.shape))
                # Gather from all training ranks, then broadcast to inference
                # (This handles FSDP/TP sharding if present)
                # NOTE(review): the gathered buffers are discarded on this
                # side -- presumably the inference ranks consume them;
                # verify against the vLLM worker implementation.
                dist.all_gather(
                    [torch.zeros(local_shape, dtype=tensor.dtype, device=self.device)
                     for _ in range(dist.get_world_size(self.nccl_group))],
                    tensor,
                    group=self.nccl_group,
                )
        self._update_count += 1
        print(f"[Bridge] Broadcast update #{self._update_count} complete")
    def notify_update(self) -> None:
        """
        Notify inference workers that weights have been updated.

        In local mode: sends HTTP request to vLLM's /bridge/notify_update endpoint
        In distributed mode: broadcasts update counter via Gloo
        """
        self._update_count += 1
        if self.config.is_local_mode:
            self._notify_update_http()
        elif self.gloo_group is not None:
            self._notify_update_distributed()
    def _notify_update_http(self) -> None:
        """Notify vLLM via HTTP (local mode). Failures only warn."""
        try:
            import requests
            response = requests.post(
                f"{self.config.vllm_api_url}/bridge/notify_update",
                json={
                    "update_count": self._update_count,
                    "trainer_rank": self.config.trainer_rank,
                    "timestamp": time.time(),
                },
                timeout=5,
            )
            if response.status_code != 200:
                print(f"[Bridge] Warning: notify_update returned {response.status_code}")
        except Exception as e:
            # Don't fail training if vLLM notification fails
            print(f"[Bridge] Warning: Could not notify vLLM: {e}")
    def _notify_update_distributed(self) -> None:
        """Notify via Gloo broadcast (distributed mode); trainer is src=0."""
        update_tensor = torch.tensor([self._update_count], dtype=torch.long)
        dist.broadcast(update_tensor, src=0, group=self.gloo_group)
    def barrier(self) -> None:
        """Wait for all processes in the group to reach this point.

        No-op in local mode (no NCCL group exists).
        """
        if self.nccl_group is not None:
            dist.barrier(group=self.nccl_group)
    def cleanup(self) -> None:
        """Clean up process groups and resources.

        Safe to call multiple times; resets the bridge to uninitialized.
        """
        if self.nccl_group is not None:
            dist.destroy_process_group(self.nccl_group)
            self.nccl_group = None
        if self.gloo_group is not None:
            dist.destroy_process_group(self.gloo_group)
            self.gloo_group = None
        self._initialized = False
        print("[Bridge] Cleaned up")
# =============================================================================
# Convenience Functions
# =============================================================================
def create_bridge_from_training_config(config: Any) -> VLLMWeightBridge:
    """
    Create and initialize a VLLMWeightBridge from a TrainingConfig.

    Args:
        config: TrainingConfig object with bridge settings

    Returns:
        Initialized VLLMWeightBridge ready for use
    """
    bridge = VLLMWeightBridge(BridgeConfig.from_training_config(config))
    bridge.initialize()
    return bridge