# Mirror of https://github.com/NousResearch/atropos.git
# Synced 2026-04-22 16:48:57 +00:00
"""
|
|
vLLM Weight Bridge - Integration between trainer and vLLM inference.
|
|
|
|
This module provides two modes for coordinating weight updates:
|
|
|
|
LOCAL MODE (num_inference_nodes=0):
|
|
- Trainer and vLLM run as separate processes on the same machine
|
|
- Communication via HTTP to vLLM's /bridge/* endpoints
|
|
- No NCCL process groups needed
|
|
- Simpler setup, suitable for single-machine training
|
|
|
|
DISTRIBUTED MODE (num_inference_nodes>0):
|
|
- Trainer and vLLM join the same NCCL process group
|
|
- Direct tensor sharing via shared GPU memory
|
|
- Lower latency, but requires coordinated setup
|
|
|
|
Architecture (Local Mode):
|
|
┌─────────────────┐ ┌─────────────────┐
|
|
│ Trainer Process │ HTTP │ vLLM Process │
|
|
│ (training) │────────▶│ (inference) │
|
|
└─────────────────┘ └─────────────────┘
|
|
|
|
Architecture (Distributed Mode):
|
|
┌─────────────────────────────────────────┐
|
|
│ Shared GPU Memory (NCCL) │
|
|
│ Model weights owned by vLLM process │
|
|
└─────────────────────────────────────────┘
|
|
▲ ▲
|
|
│ forward pass │ optimizer.step()
|
|
┌───────┴───────┐ ┌───────┴───────┐
|
|
│ vLLM Process │ │Trainer Process│
|
|
└───────────────┘ └───────────────┘
|
|
"""
|
|
|
|
from __future__ import annotations

import json
import os
import socket
import subprocess
import time
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
|
|
|
|
|
|
# =============================================================================
|
|
# Process Group Initialization Helpers
|
|
# =============================================================================
|
|
|
|
|
|
def _torch_version_tuple() -> Tuple[int, int]:
    """Return (major, minor) of the running torch release.

    Ignores any suffix after the numeric part of each component, so
    "2.6.0.dev20240901+cu121" -> (2, 6) and "2.10.0" -> (2, 10).
    """
    nums: List[int] = []
    for part in str(torch.__version__).split(".")[:2]:
        digits = ""
        for ch in part:
            if not ch.isdigit():
                break
            digits += ch
        nums.append(int(digits) if digits else 0)
    while len(nums) < 2:
        nums.append(0)
    return (nums[0], nums[1])


def init_process_group(
    backend: Optional[str] = None,
    init_method: Optional[str] = None,
    timeout: Optional[timedelta] = None,
    world_size: int = -1,
    rank: int = -1,
    store: Optional[Any] = None,
    group_name: str = "",
    pg_options: Optional[Any] = None,
) -> dist.ProcessGroup:
    """
    Initialize a custom process group for weight synchronization.

    This is based on torch.distributed internals but allows creating a named
    group that coexists with the default process group (used by vLLM internally).

    Args:
        backend: "nccl" for GPU, "gloo" for CPU
        init_method: Rendezvous URL (e.g., "tcp://host:port" or "env://")
        timeout: How long to wait for other ranks
        world_size: Total number of processes in the group
        rank: This process's rank in the group
        store: Optional torch.distributed Store object
        group_name: Name for this process group (must match across all ranks)
        pg_options: Backend-specific options

    Returns:
        A ProcessGroup object for collective operations
    """
    # Private torch.distributed internals; imported lazily so merely importing
    # this module does not require them to exist.
    from torch.distributed.distributed_c10d import (
        _new_process_group_helper,
        _world,
        Backend,
        default_pg_timeout,
        PrefixStore,
        rendezvous,
    )

    assert (store is None) or (
        init_method is None
    ), "Cannot specify both init_method and store."

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    backend = Backend(backend) if backend else Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    # Rendezvous with other processes (only needed when no Store was supplied)
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)
        # Use a PrefixStore to avoid key collisions with other groups
        store = PrefixStore(group_name, store)

    # PyTorch 2.6+ renamed pg_options to backend_options.
    # BUG FIX: the previous check compared version *strings*
    # (str(torch.__version__) >= "2.6"), which misclassifies releases such
    # as "2.10.0" (lexicographically before "2.6"); compare numerically.
    pg_options_param_name = (
        "backend_options" if _torch_version_tuple() >= (2, 6) else "pg_options"
    )

    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )

    # Register the group->rank mapping so torch.distributed collectives can
    # translate group-local ranks for this custom group.
    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
    return pg
|
|
|
|
|
|
def broadcast_object_list(
    object_list: List[Any],
    src: int,
    group: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> None:
    """
    Broadcast a list of picklable objects from src rank to all other ranks.

    This is a simplified version of torch.distributed.broadcast_object_list
    that works correctly with custom process groups.

    NOTE(review): unlike torch's implementation, this function never
    serializes ``object_list`` on the src rank — both ``object_sizes_tensor``
    and ``object_tensor`` are allocated with ``torch.empty`` and then
    broadcast as-is. It therefore only behaves correctly if the src rank
    runs a different (sending) code path elsewhere; confirm against the
    vLLM-side sender.

    Args:
        object_list: List of objects to broadcast (modified in-place on receivers)
        src: Source rank that has the data
        group: Process group to use
        device: Device for intermediate tensors (defaults to CUDA if available)
    """
    # Intermediate broadcast buffers must live on a device the backend can
    # use (CUDA for NCCL, CPU for Gloo).
    current_device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Broadcast object sizes first, so receivers know how to split the
    # flattened byte stream that follows.
    object_sizes_tensor = torch.empty(
        len(object_list), dtype=torch.long, device=current_device
    )
    dist.broadcast(object_sizes_tensor, src=src, group=group)

    # Broadcast serialized objects as one contiguous uint8 buffer.
    object_tensor = torch.empty(
        torch.sum(object_sizes_tensor).item(),
        dtype=torch.uint8,
        device=current_device,
    )
    dist.broadcast(object_tensor, src=src, group=group)

    # Deserialize on receiving ranks, slicing each object's bytes back out
    # of the flattened buffer using the sizes received above.
    offset = 0
    for i, obj_size in enumerate(object_sizes_tensor):
        obj_view = object_tensor[offset : offset + obj_size]
        # NOTE(review): object_tensor is already uint8, so this cast is a
        # no-op kept for parity with torch's reference implementation.
        obj_view = obj_view.type(torch.uint8)
        offset += obj_size
        # _tensor_to_object is a private torch.distributed helper; its
        # signature has been stable across recent releases but may change.
        object_list[i] = dist._tensor_to_object(obj_view, obj_size, group)
|
|
|
|
|
|
# =============================================================================
|
|
# Environment and URL Helpers
|
|
# =============================================================================
|
|
|
|
|
|
def get_inference_urls(
    num_inference_nodes: int,
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[List[str]]]:
    """
    Get rendezvous URLs for connecting to inference nodes.

    In SLURM environments, parses SLURM_JOB_NODELIST to find inference servers.
    For local testing (num_inference_nodes == 0), returns localhost URLs.

    Args:
        num_inference_nodes: Number of inference nodes (from config)

    Returns:
        Tuple of (master_server, master_gloo_server, master_inference_server,
        nodelist). All None if inference nodes are not configured or the
        SLURM node lookup fails.
    """
    if num_inference_nodes > 0:
        # Multi-node SLURM environment: expand the compact nodelist
        # (e.g. "node[01-04]") into individual hostnames.
        # Uses an argument list (shell=False) instead of the previous
        # os.popen() shell string, so the env-provided value is never
        # interpreted by a shell.
        try:
            result = subprocess.run(
                [
                    "scontrol",
                    "show",
                    "hostnames",
                    os.environ.get("SLURM_JOB_NODELIST", ""),
                ],
                capture_output=True,
                text=True,
            )
            nodelist_raw = result.stdout
        except OSError:
            # scontrol missing / not executable: behave like an empty
            # nodelist, matching the old os.popen() best-effort behavior.
            nodelist_raw = ""

        nodelist = [n for n in nodelist_raw.split("\n") if n]
        if not nodelist:
            return None, None, None, None

        master_server = f"{nodelist[0]}:26756"
        master_gloo_server = f"{nodelist[0]}:26757"
        # Inference nodes are the last N nodes of the allocation
        inference_nodes = nodelist[-num_inference_nodes:]
        master_inference_server = f"{inference_nodes[0]}:26758"

        return master_server, master_gloo_server, master_inference_server, inference_nodes

    if num_inference_nodes == 0:
        # Single-node local mode
        return "localhost:26756", "localhost:26757", "localhost:26758", ["localhost"]

    # Negative value: inference nodes explicitly not configured
    return None, None, None, None
|
|
|
|
|
|
def get_local_hostname() -> Optional[List[str]]:
    """Get the local hostname(s) from /etc/hosts for rank determination."""
    hostname = socket.gethostname()
    ip_addr = socket.gethostbyname(hostname)

    # Scan /etc/hosts for the first non-loopback entry matching this host;
    # absence of the file simply falls through to the socket-derived answer.
    try:
        hosts_lines = Path("/etc/hosts").read_text().splitlines()
    except FileNotFoundError:
        hosts_lines = []

    for raw_line in hosts_lines:
        entry = raw_line.strip()
        if not entry or entry.startswith("#"):
            continue
        fields = entry.split()
        if len(fields) < 2:
            continue
        if fields[0] != ip_addr and hostname not in fields:
            continue
        if fields[0].startswith("127."):
            # Loopback entries are useless for cross-node rank mapping.
            continue
        return fields

    # Fallback: whatever the resolver says about this machine.
    return [ip_addr, hostname]
|
|
|
|
|
|
# =============================================================================
|
|
# Tensor Mapping and Permutation Helpers
|
|
# =============================================================================
|
|
|
|
|
|
def permute_for_rotary(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    """
    Permute weight tensor for sliced rotary embeddings.

    vLLM and some model implementations use different layouts for Q/K projections.
    This converts between them.
    """
    rows, cols = w.shape
    # Split each head's rows into interleaved pairs, swap the pair axis with
    # the per-pair axis, then flatten back to the original 2D shape.
    half_per_head = rows // n_heads // 2
    reordered = w.view(n_heads, half_per_head, 2, cols).transpose(1, 2)
    return reordered.reshape(rows, cols)
|
|
|
|
|
|
def permute_for_rotary_1d(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    """Permute 1D tensor (bias) for sliced rotary embeddings."""
    length = w.shape[0]
    # Same pair-interleave swap as the 2D variant, for per-row bias vectors.
    pairs_per_head = length // n_heads // 2
    return w.view(n_heads, pairs_per_head, 2).transpose(1, 2).reshape(length)
|
|
|
|
|
|
def get_name_conversions(param_mappings: Dict[str, Any]) -> Dict[str, List[str]]:
    """
    Build a mapping from vLLM parameter names to trainer parameter names.

    vLLM may split or combine parameters differently than HuggingFace models.
    This helps translate between naming conventions.
    """
    conversions: Dict[str, List[str]] = {}
    for trainer_name, info in param_mappings.items():
        # A missing "vllm_name" means the two sides agree on the name.
        key = info.get("vllm_name", trainer_name)
        conversions.setdefault(key, []).append(trainer_name)
    return conversions
|
|
|
|
|
|
# =============================================================================
|
|
# Bridge Configuration
|
|
# =============================================================================
|
|
|
|
|
|
@dataclass
class BridgeConfig:
    """Settings that control how the trainer connects to vLLM."""

    # --- Process group ---
    trainer_rank: int = 0
    world_size: int = 1
    init_method: str = "env://"
    num_inference_nodes: int = 0

    # --- Model ---
    model_name: str = ""
    device: str = "cuda"

    # --- Synchronization ---
    timeout_seconds: float = 300.0
    log_dir: Optional[str] = None

    # HTTP endpoint of the vLLM server, used for local-mode sync.
    vllm_api_url: str = "http://localhost:9001"

    # When True, map vLLM's GPU tensors directly via CUDA IPC (same GPU only!).
    use_cuda_ipc: bool = False

    # Number of GPUs visible on this node, queried from torch at construction.
    num_gpus_per_node: int = field(default_factory=lambda: torch.cuda.device_count())

    @property
    def is_local_mode(self) -> bool:
        """
        True when running single-machine: sync happens over HTTP and no
        NCCL process groups are created.
        """
        return self.num_inference_nodes == 0

    @classmethod
    def from_training_config(cls, config: Any) -> "BridgeConfig":
        """Build a BridgeConfig out of a TrainingConfig-like object."""
        # vllm_port / use_cuda_ipc are optional on older configs.
        vllm_port = getattr(config, "vllm_port", 9001)
        return cls(
            trainer_rank=config.trainer_rank,
            world_size=config.world_size,
            init_method=config.init_method,
            num_inference_nodes=config.num_inference_nodes,
            model_name=config.model_name,
            device=config.device,
            log_dir=os.environ.get("LOGDIR"),
            vllm_api_url=f"http://localhost:{vllm_port}",
            use_cuda_ipc=getattr(config, "use_cuda_ipc", False),
        )
|
|
|
|
|
|
# =============================================================================
|
|
# Main Bridge Class
|
|
# =============================================================================
|
|
|
|
|
|
class VLLMWeightBridge:
    """
    Bridge for sharing model weights between trainer and vLLM inference server.

    This class handles:
    1. Joining the distributed process group with vLLM workers
    2. Attaching to vLLM's model weight tensors
    3. Providing a model interface for the trainer to optimize
    4. Synchronizing updates so vLLM sees changes immediately

    Usage:
        bridge = VLLMWeightBridge(config)
        bridge.initialize()
        model = bridge.get_trainable_model()
        optimizer = AdamW(model.parameters(), lr=1e-5)

        for batch in data:
            loss = compute_loss(model, batch)
            loss.backward()
            optimizer.step()
            bridge.notify_update()  # vLLM now uses new weights
    """

    def __init__(self, config: BridgeConfig):
        self.config = config
        self.device = torch.device(config.device)

        # Process groups (initialized in initialize(); stay None in local mode)
        self.nccl_group: Optional[dist.ProcessGroup] = None
        self.gloo_group: Optional[dist.ProcessGroup] = None

        # Parameter mappings (loaded from vLLM's JSON)
        self.param_mappings: Dict[str, Any] = {}
        self.name_conversions: Dict[str, List[str]] = {}

        # Shared tensors (attached in attach_to_vllm_weights() or, for CUDA
        # IPC mode, reconstructed in _load_cuda_ipc_handles())
        self.shared_state_dict: Dict[str, torch.Tensor] = {}

        # Model for training (created lazily in get_trainable_model())
        self._model: Optional[nn.Module] = None

        # Synchronization state
        self._update_count: int = 0
        self._initialized: bool = False

    def initialize(self) -> None:
        """
        Initialize the bridge: join process groups and load parameter mappings.

        In local mode (num_inference_nodes=0), skips NCCL setup and uses HTTP.
        In distributed mode, creates NCCL/Gloo process groups.

        This must be called before any other methods. Idempotent: calling it
        again after a successful run is a no-op.
        """
        if self._initialized:
            return

        print(f"[Bridge] Initializing weight bridge for rank {self.config.trainer_rank}")

        if self.config.is_local_mode:
            self._initialize_local_mode()
        else:
            self._initialize_distributed_mode()

        self._initialized = True

    def _initialize_local_mode(self) -> None:
        """
        Initialize for local single-machine mode.

        In local mode:
        - No NCCL process groups (trainer and vLLM are separate processes)
        - Communication via HTTP to vLLM's bridge endpoints
        - Trainer loads its own model copy, OR uses CUDA IPC for true shared memory
        """
        if self.config.use_cuda_ipc:
            print("[Bridge] Using CUDA IPC MODE (true shared GPU memory)")
        else:
            print("[Bridge] Using LOCAL MODE (HTTP-based sync, no NCCL)")
        print(f"[Bridge] vLLM API URL: {self.config.vllm_api_url}")

        # Verify vLLM server is reachable. Failure is deliberately non-fatal:
        # training can proceed and sync is retried on each notify_update().
        try:
            import requests
            response = requests.get(f"{self.config.vllm_api_url}/health", timeout=5)
            if response.status_code == 200:
                print("[Bridge] vLLM server is reachable")
            else:
                print(f"[Bridge] Warning: vLLM health check returned {response.status_code}")
        except Exception as e:
            print(f"[Bridge] Warning: Could not reach vLLM server: {e}")
            print("[Bridge] Training will continue, but vLLM sync may not work")

        # For CUDA IPC mode, request vLLM to export IPC handles
        if self.config.use_cuda_ipc:
            self._request_cuda_ipc_export()
            self._load_cuda_ipc_handles()

        # Load parameter mappings if available (optional in local mode).
        # NOTE: _load_param_mappings may block up to config.timeout_seconds
        # waiting for vLLM to write the file before raising RuntimeError.
        try:
            self._load_param_mappings()
        except RuntimeError:
            print("[Bridge] Parameter mapping file not found (optional in local mode)")
            self.param_mappings = {}

    def _request_cuda_ipc_export(self) -> None:
        """Request vLLM to export CUDA IPC handles.

        Raises:
            RuntimeError: if the HTTP call fails or returns a non-200 status.
        """
        import requests

        print("[Bridge] Requesting CUDA IPC handles from vLLM...")
        try:
            response = requests.post(
                f"{self.config.vllm_api_url}/bridge/export_cuda_ipc",
                timeout=60
            )
            if response.status_code == 200:
                result = response.json()
                print(f"[Bridge] vLLM exported {result.get('num_parameters', 0)} IPC handles")
            else:
                raise RuntimeError(f"Failed to export IPC handles: {response.status_code}")
        except Exception as e:
            # NOTE(review): this also re-wraps the RuntimeError raised just
            # above, nesting its message; presumably intentional best-effort
            # normalization — confirm.
            raise RuntimeError(f"Could not request CUDA IPC export: {e}")

    def _load_cuda_ipc_handles(self) -> None:
        """
        Load CUDA IPC handles from file and reconstruct shared tensors.

        This is the key to TRUE shared memory - the tensors we create here
        point to the SAME GPU memory that vLLM is using!

        Raises:
            RuntimeError: if the handles file never appears within
                config.timeout_seconds.
        """
        import base64
        import pickle

        log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".")
        ipc_path = Path(log_dir) / "cuda_ipc_handles.json"

        # Wait for file to be created (vLLM writes it after export_cuda_ipc)
        wait_time = 0
        while not ipc_path.exists() and wait_time < self.config.timeout_seconds:
            print(f"[Bridge] Waiting for {ipc_path}...")
            time.sleep(1)
            wait_time += 1

        if not ipc_path.exists():
            raise RuntimeError(f"CUDA IPC handles file not found: {ipc_path}")

        with open(ipc_path, "r") as f:
            data = json.load(f)

        handles_data = data.get("handles", {})

        print(f"[Bridge] Reconstructing {len(handles_data)} shared tensors from IPC handles...")

        reconstructed = 0
        for name, info in handles_data.items():
            try:
                # Decode the IPC handle.
                # SECURITY NOTE: pickle.loads on file contents — only safe
                # because the file is written by our own vLLM process into
                # a trusted log dir; never point log_dir at untrusted data.
                handle_bytes = base64.b64decode(info["ipc_handle"])
                handle = pickle.loads(handle_bytes)

                # Reconstruct the storage from the IPC handle
                # This does NOT allocate new memory - it maps to existing memory!
                device = torch.device(f"cuda:{info['device_index']}")

                # Get dtype (stored as e.g. "torch.bfloat16")
                dtype_str = info["dtype"]
                dtype = getattr(torch, dtype_str.replace("torch.", ""))

                # Reconstruct tensor from IPC handle
                # The storage is shared with vLLM's process.
                # NOTE(review): _new_shared_cuda is a private torch API; on
                # newer torch this lives on torch.UntypedStorage — confirm
                # against the pinned torch version.
                storage = torch.cuda.Storage._new_shared_cuda(*handle)

                # Create tensor view of the shared storage
                tensor = torch.tensor([], dtype=dtype, device=device)
                tensor.set_(
                    storage,
                    info["storage_offset"],
                    info["shape"],
                    info["stride"]
                )

                # Store in shared_state_dict
                self.shared_state_dict[name] = tensor
                reconstructed += 1

            except Exception as e:
                # Best-effort: skip tensors whose handle cannot be mapped.
                print(f"[Bridge] Warning: Could not reconstruct {name}: {e}")
                continue

        print(f"[Bridge] Successfully reconstructed {reconstructed} shared tensors")
        # NOTE(review): this estimate assumes ~4 MB saved per tensor
        # (count * 4 / 1024 GB); looks like a rough heuristic, not a real
        # byte count — confirm or compute from tensor sizes.
        print(f"[Bridge] Memory savings: ~{reconstructed * 4 / 1024:.1f} GB (no model copy needed!)")

    def _initialize_distributed_mode(self) -> None:
        """
        Initialize for distributed multi-node mode.

        Creates NCCL and Gloo process groups for direct tensor sharing.

        Raises:
            RuntimeError: if the inference rendezvous URLs cannot be resolved.
        """
        print("[Bridge] Using DISTRIBUTED MODE (NCCL tensor sharing)")

        # Get rendezvous URLs (derived from SLURM_JOB_NODELIST)
        master_addr, master_gloo_addr, master_inference_addr, nodelist = get_inference_urls(
            self.config.num_inference_nodes
        )

        if master_addr is None:
            raise RuntimeError(
                "Could not determine inference server URLs. "
                "Set NUM_INFERENCE_NODES environment variable or check SLURM_JOB_NODELIST."
            )

        print(f"[Bridge] Master address: {master_addr}")
        print(f"[Bridge] Inference nodes: {nodelist}")

        # Load parameter mappings from vLLM (blocks until the file appears)
        self._load_param_mappings()

        # Calculate total group size (trainers + inference workers)
        num_training_gpus = self._get_num_training_gpus()
        # In distributed mode, each inference node contributes num_gpus_per_node workers
        num_inference_gpus = self.config.num_inference_nodes * self.config.num_gpus_per_node

        total_group_size = num_training_gpus + num_inference_gpus
        trainer_rank_in_group = self.config.trainer_rank

        print(f"[Bridge] Training GPUs: {num_training_gpus}, Inference GPUs: {num_inference_gpus}")
        print(f"[Bridge] Total group size: {total_group_size}, Trainer rank: {trainer_rank_in_group}")

        # Initialize NCCL group for tensor transfers (blocks until all
        # ranks, including vLLM workers, join with the same group_name)
        self.nccl_group = init_process_group(
            backend="nccl",
            init_method=f"tcp://{master_addr}",
            world_size=total_group_size,
            rank=trainer_rank_in_group,
            group_name="weight_update_group",
            timeout=timedelta(seconds=self.config.timeout_seconds),
        )
        print("[Bridge] NCCL process group initialized")

        # Initialize Gloo group for metadata/coordination (CPU-side, used by
        # notify_update in distributed mode)
        self.gloo_group = init_process_group(
            backend="gloo",
            init_method=f"tcp://{master_gloo_addr}",
            world_size=total_group_size,
            rank=trainer_rank_in_group,
            group_name="gloo_group",
            timeout=timedelta(seconds=self.config.timeout_seconds),
        )
        print("[Bridge] Gloo process group initialized")

    def _load_param_mappings(self) -> None:
        """Load parameter name mappings from vLLM's exported JSON.

        Blocks up to config.timeout_seconds waiting for the file.

        Raises:
            RuntimeError: if the mapping file never appears.
        """
        log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".")
        json_path = Path(log_dir) / "vllm_bridge_config.json"

        # Wait for vLLM to write the mapping file
        wait_time = 0
        while not json_path.exists() and wait_time < self.config.timeout_seconds:
            print(f"[Bridge] Waiting for {json_path} to be created...")
            time.sleep(1)
            wait_time += 1

        if not json_path.exists():
            raise RuntimeError(
                f"Parameter mapping file not found at {json_path}. "
                "Make sure vLLM is running and has exported its parameter mappings."
            )

        # Small delay to ensure file is fully written
        # NOTE(review): sleep-based race avoidance — a rename-into-place on
        # the writer side would be more robust; confirm writer behavior.
        time.sleep(1)

        with open(json_path, "r") as f:
            data = json.load(f)

        self.param_mappings = data.get("param_mappings", {})
        self.name_conversions = get_name_conversions(self.param_mappings)

        print(f"[Bridge] Loaded mappings for {len(self.param_mappings)} parameters")

    def _get_num_training_gpus(self) -> int:
        """Get number of training GPUs from param mappings or config.

        NOTE(review): reads "dp_shard_degree"/"tp_degree" as top-level keys
        of param_mappings, alongside per-parameter entries — confirm the
        vLLM exporter actually places them there.
        """
        if self.param_mappings:
            # Try to get from vLLM's exported info
            return self.param_mappings.get("dp_shard_degree", 1) * self.param_mappings.get("tp_degree", 1)
        return self.config.world_size

    def attach_to_vllm_weights(self, vllm_state_dict: Dict[str, torch.Tensor]) -> None:
        """
        Attach to vLLM's weight tensors.

        After this call, self.shared_state_dict contains references to the
        actual tensors that vLLM uses for inference. Modifying these tensors
        will immediately affect vLLM's outputs.

        Args:
            vllm_state_dict: vLLM's model state_dict (actual tensors, not copies)
        """
        self.shared_state_dict = vllm_state_dict
        print(f"[Bridge] Attached to {len(vllm_state_dict)} vLLM weight tensors")

        # Log tensor info for debugging (first five entries only)
        for name, tensor in list(vllm_state_dict.items())[:5]:
            print(f"[Bridge] {name}: {tensor.shape}, {tensor.dtype}, {tensor.device}")
        if len(vllm_state_dict) > 5:
            print(f"[Bridge] ... and {len(vllm_state_dict) - 5} more")

    def get_trainable_model(self) -> nn.Module:
        """
        Get a model whose parameters point to vLLM's shared tensors.

        In CUDA IPC mode: shared_state_dict is populated from IPC handles during init.
        In other modes: must call attach_to_vllm_weights() first.

        This creates a HuggingFace model structure but replaces all parameters
        with references to the shared tensors. When the optimizer updates these
        parameters, it modifies vLLM's weights directly.

        Returns:
            An nn.Module ready for training with shared weights

        Raises:
            RuntimeError: if no shared tensors are available.
        """
        # Cached: the shared-parameter model is built at most once.
        if self._model is not None:
            return self._model

        if not self.shared_state_dict:
            if self.config.use_cuda_ipc:
                raise RuntimeError(
                    "CUDA IPC mode enabled but no shared tensors found. "
                    "Check that vLLM exported IPC handles correctly."
                )
            else:
                raise RuntimeError(
                    "Must call attach_to_vllm_weights() before get_trainable_model()"
                )

        print(f"[Bridge] Creating trainable model for {self.config.model_name}")
        if self.config.use_cuda_ipc:
            print("[Bridge] Using CUDA IPC shared tensors (NO NEW GPU MEMORY!)")

        # Load model config (not weights)
        model_config = AutoConfig.from_pretrained(self.config.model_name)

        # Create model with empty weights (meta device = no memory)
        with torch.device("meta"):
            model = AutoModelForCausalLM.from_config(model_config)

        # Replace each parameter with the shared tensor
        self._replace_parameters_with_shared(model)

        # Move model structure to device (parameters already on device via IPC).
        # NOTE(review): any parameter NOT replaced above is still on the meta
        # device, and .to() cannot materialize meta tensors — this likely
        # fails unless every parameter was matched; confirm.
        model.to(self.device)
        self._model = model

        total_params = sum(p.numel() for p in model.parameters())
        print(f"[Bridge] Trainable model ready with {total_params:,} parameters")

        if self.config.use_cuda_ipc:
            # Verify memory savings
            param_memory_gb = total_params * 2 / 1e9  # bfloat16 = 2 bytes
            print(f"[Bridge] CUDA IPC memory savings: ~{param_memory_gb:.1f} GB (shared with vLLM)")

        return model

    def _replace_parameters_with_shared(self, model: nn.Module) -> None:
        """
        Replace model parameters with references to shared vLLM tensors.

        This is the key operation that makes weight sharing work. After this,
        model.parameters() returns tensors that ARE vLLM's weights.
        """
        replaced_count = 0
        missing_params = []

        for name, param in model.named_parameters():
            # Convert HuggingFace param name to vLLM param name
            vllm_name = self._hf_to_vllm_name(name)

            if vllm_name in self.shared_state_dict:
                shared_tensor = self.shared_state_dict[vllm_name]

                # Create a new Parameter that wraps the shared tensor
                # The key is that we're not copying - we're referencing the same storage
                new_param = nn.Parameter(shared_tensor, requires_grad=True)

                # Replace the parameter in the model
                self._set_parameter(model, name, new_param)
                replaced_count += 1
            else:
                missing_params.append(name)

        print(f"[Bridge] Replaced {replaced_count} parameters with shared tensors")
        if missing_params:
            print(f"[Bridge] Warning: {len(missing_params)} parameters not found in shared state:")
            for p in missing_params[:5]:
                print(f"[Bridge]   {p}")

    def _hf_to_vllm_name(self, hf_name: str) -> str:
        """
        Convert a HuggingFace parameter name to vLLM's naming convention.

        vLLM may merge QKV projections, use different layer naming, etc.
        This handles the translation.
        """
        # Check if we have an explicit mapping (reverse lookup through the
        # vllm_name -> [trainer names] table)
        for vllm_name, hf_names in self.name_conversions.items():
            if hf_name in hf_names:
                return vllm_name

        # Common transformations
        # vLLM often uses: model.layers.N.self_attn.qkv_proj
        # HF uses: model.layers.N.self_attn.q_proj, k_proj, v_proj

        # For now, try the name as-is
        return hf_name

    def _set_parameter(self, model: nn.Module, name: str, new_param: nn.Parameter) -> None:
        """Set a parameter by dotted name path (e.g. "model.layers.0.weight")."""
        parts = name.split(".")
        module = model
        for part in parts[:-1]:
            module = getattr(module, part)
        setattr(module, parts[-1], new_param)

    def broadcast_weights_to_inference(self) -> None:
        """
        Broadcast updated weights from trainer to inference workers.

        Call this after optimizer.step() to push the new weights to all
        vLLM inference processes. They will use the updated weights for
        subsequent requests.

        NOTE(review): this method assumes distributed mode — in local mode
        self.nccl_group is None and the dist.* calls below will fail.
        Also: the all_gather result list is discarded on this side
        (presumably the inference ranks consume their copies), and the
        trainer is hard-coded as src=0 — confirm both against the
        vLLM-worker counterpart.

        Raises:
            RuntimeError: if initialize() has not been called.
        """
        if not self._initialized:
            raise RuntimeError("Bridge not initialized. Call initialize() first.")

        # Deterministic order so every rank walks parameters identically
        param_names = sorted(self.param_mappings.keys())

        with torch.no_grad():
            for idx, param_name in enumerate(param_names):
                # Signal which parameter we're broadcasting
                idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.device)
                dist.broadcast(idx_tensor, src=0, group=self.nccl_group)

                # Get the tensor for this parameter
                vllm_name = self.param_mappings[param_name].get("vllm_name", param_name)
                if vllm_name not in self.shared_state_dict:
                    continue

                tensor = self.shared_state_dict[vllm_name]
                local_shape = self.param_mappings[param_name].get("local_shape", list(tensor.shape))

                # Gather from all training ranks, then broadcast to inference
                # (This handles FSDP/TP sharding if present)
                dist.all_gather(
                    [torch.zeros(local_shape, dtype=tensor.dtype, device=self.device)
                     for _ in range(dist.get_world_size(self.nccl_group))],
                    tensor,
                    group=self.nccl_group,
                )

        self._update_count += 1
        print(f"[Bridge] Broadcast update #{self._update_count} complete")

    def notify_update(self) -> None:
        """
        Notify inference workers that weights have been updated.

        In local mode: sends HTTP request to vLLM's /bridge/notify_update endpoint
        In distributed mode: broadcasts update counter via Gloo
        """
        self._update_count += 1

        if self.config.is_local_mode:
            self._notify_update_http()
        elif self.gloo_group is not None:
            self._notify_update_distributed()

    def _notify_update_http(self) -> None:
        """Notify vLLM via HTTP (local mode). Failures are logged, not raised."""
        try:
            import requests
            response = requests.post(
                f"{self.config.vllm_api_url}/bridge/notify_update",
                json={
                    "update_count": self._update_count,
                    "trainer_rank": self.config.trainer_rank,
                    "timestamp": time.time(),
                },
                timeout=5,
            )
            if response.status_code != 200:
                print(f"[Bridge] Warning: notify_update returned {response.status_code}")
        except Exception as e:
            # Don't fail training if vLLM notification fails
            print(f"[Bridge] Warning: Could not notify vLLM: {e}")

    def _notify_update_distributed(self) -> None:
        """Notify via Gloo broadcast (distributed mode); trainer is src=0."""
        update_tensor = torch.tensor([self._update_count], dtype=torch.long)
        dist.broadcast(update_tensor, src=0, group=self.gloo_group)

    def barrier(self) -> None:
        """Wait for all processes in the group to reach this point.

        No-op in local mode (no NCCL group exists).
        """
        if self.nccl_group is not None:
            dist.barrier(group=self.nccl_group)

    def cleanup(self) -> None:
        """Clean up process groups and resources.

        Safe to call multiple times; resets _initialized so the bridge can
        be re-initialized.
        """
        if self.nccl_group is not None:
            dist.destroy_process_group(self.nccl_group)
            self.nccl_group = None

        if self.gloo_group is not None:
            dist.destroy_process_group(self.gloo_group)
            self.gloo_group = None

        self._initialized = False
        print("[Bridge] Cleaned up")
|
|
|
|
|
|
# =============================================================================
|
|
# Convenience Functions
|
|
# =============================================================================
|
|
|
|
|
|
def create_bridge_from_training_config(config: Any) -> VLLMWeightBridge:
    """
    Build a VLLMWeightBridge from a TrainingConfig and run its setup.

    Args:
        config: TrainingConfig object with bridge settings

    Returns:
        A VLLMWeightBridge that has already been initialize()d
    """
    bridge = VLLMWeightBridge(BridgeConfig.from_training_config(config))
    bridge.initialize()
    return bridge
|
|
|