mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
changes based on torchtitan
This commit is contained in:
parent
078dd4a333
commit
53b29472b4
7 changed files with 1535 additions and 1977 deletions
37
example_trainer/vllm_patching/__init__.py
Normal file
37
example_trainer/vllm_patching/__init__.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
"""
|
||||
vLLM Patching Module - Enables shared memory weight updates.
|
||||
|
||||
This module patches vLLM's GPUModelRunner to:
|
||||
1. Call share_memory_() on model weights after loading
|
||||
2. Spawn a daemon process that receives NCCL weight updates from trainers
|
||||
3. Enable real-time weight synchronization without restarting vLLM
|
||||
|
||||
Usage:
|
||||
# Import this BEFORE importing vllm
|
||||
from example_trainer.vllm_patching import apply_patches
|
||||
apply_patches()
|
||||
|
||||
# Then import vllm normally
|
||||
from vllm import AsyncLLM
|
||||
"""
|
||||
|
||||
from .patched_gpu_runner import PatchedGPUModelRunner, apply_patches
|
||||
from .weight_updater import weight_updater_process
|
||||
from .distributed_utils import (
|
||||
init_process_group,
|
||||
broadcast_object_list,
|
||||
get_inference_urls,
|
||||
get_json_data,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PatchedGPUModelRunner",
|
||||
"apply_patches",
|
||||
"weight_updater_process",
|
||||
"init_process_group",
|
||||
"broadcast_object_list",
|
||||
"get_inference_urls",
|
||||
"get_json_data",
|
||||
]
|
||||
|
||||
|
||||
328
example_trainer/vllm_patching/distributed_utils.py
Normal file
328
example_trainer/vllm_patching/distributed_utils.py
Normal file
|
|
@ -0,0 +1,328 @@
|
|||
"""
|
||||
Distributed utilities for vLLM weight synchronization.
|
||||
|
||||
Provides process group initialization and communication helpers
|
||||
for coordinating weight updates between trainer and vLLM.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import timedelta
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def init_process_group(
|
||||
backend: Optional[str] = None,
|
||||
init_method: Optional[str] = None,
|
||||
timeout: Optional[timedelta] = None,
|
||||
world_size: int = -1,
|
||||
rank: int = -1,
|
||||
store: Optional[Any] = None,
|
||||
group_name: str = "",
|
||||
pg_options: Optional[Any] = None,
|
||||
) -> dist.ProcessGroup:
|
||||
"""
|
||||
Initialize a custom process group for weight synchronization.
|
||||
|
||||
This creates a named process group that coexists with vLLM's internal
|
||||
process groups, enabling direct tensor communication between trainer
|
||||
and inference processes.
|
||||
|
||||
Args:
|
||||
backend: "nccl" for GPU, "gloo" for CPU
|
||||
init_method: Rendezvous URL (e.g., "tcp://host:port")
|
||||
timeout: How long to wait for other ranks
|
||||
world_size: Total number of processes
|
||||
rank: This process's rank
|
||||
store: Optional torch.distributed Store
|
||||
group_name: Name for this process group (must match across ranks)
|
||||
pg_options: Backend-specific options
|
||||
|
||||
Returns:
|
||||
ProcessGroup for collective operations
|
||||
"""
|
||||
from torch.distributed.distributed_c10d import (
|
||||
_new_process_group_helper,
|
||||
_world,
|
||||
Backend,
|
||||
default_pg_timeout,
|
||||
PrefixStore,
|
||||
rendezvous,
|
||||
)
|
||||
|
||||
assert (store is None) or (init_method is None), \
|
||||
"Cannot specify both init_method and store."
|
||||
|
||||
if store is not None:
|
||||
assert world_size > 0, "world_size must be positive if using store"
|
||||
assert rank >= 0, "rank must be non-negative if using store"
|
||||
elif init_method is None:
|
||||
init_method = "env://"
|
||||
|
||||
if backend:
|
||||
backend = Backend(backend)
|
||||
else:
|
||||
backend = Backend("undefined")
|
||||
|
||||
if timeout is None:
|
||||
timeout = default_pg_timeout
|
||||
|
||||
# Create store via rendezvous if not provided
|
||||
if store is None:
|
||||
rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
|
||||
store, rank, world_size = next(rendezvous_iterator)
|
||||
store.set_timeout(timeout)
|
||||
store = PrefixStore(group_name, store)
|
||||
|
||||
# Handle PyTorch version differences for pg_options parameter
|
||||
pg_options_param_name = (
|
||||
"backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
|
||||
)
|
||||
|
||||
pg, _ = _new_process_group_helper(
|
||||
world_size,
|
||||
rank,
|
||||
[],
|
||||
backend,
|
||||
store,
|
||||
group_name=group_name,
|
||||
**{pg_options_param_name: pg_options},
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
|
||||
|
||||
return pg
|
||||
|
||||
|
||||
def broadcast_object_list(
|
||||
object_list: List[Any],
|
||||
src: Optional[int] = None,
|
||||
group: Optional[dist.ProcessGroup] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
group_src: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Broadcast a list of objects from source rank to all other ranks.
|
||||
|
||||
Modified from torch.distributed.broadcast_object_list to work correctly
|
||||
with custom process groups where rank 0 may not be the default group's rank 0.
|
||||
|
||||
Args:
|
||||
object_list: List of objects to broadcast (modified in-place on receivers)
|
||||
src: Global source rank (deprecated, use group_src)
|
||||
group: Process group to use
|
||||
device: Device for temporary tensors
|
||||
group_src: Source rank within the group
|
||||
"""
|
||||
global_src = group_src if group_src is not None else src
|
||||
current_device = device
|
||||
|
||||
# Broadcast object sizes first
|
||||
object_sizes_tensor = torch.empty(
|
||||
len(object_list), dtype=torch.long, device=current_device
|
||||
)
|
||||
dist.broadcast(object_sizes_tensor, src=global_src, group=group)
|
||||
|
||||
# Broadcast serialized objects
|
||||
object_tensor = torch.empty(
|
||||
torch.sum(object_sizes_tensor).item(),
|
||||
dtype=torch.uint8,
|
||||
device=current_device,
|
||||
)
|
||||
dist.broadcast(object_tensor, src=global_src, group=group)
|
||||
|
||||
# Deserialize objects
|
||||
offset = 0
|
||||
for i, obj_size in enumerate(object_sizes_tensor):
|
||||
obj_view = object_tensor[offset : offset + obj_size]
|
||||
obj_view = obj_view.type(torch.uint8)
|
||||
offset += obj_size
|
||||
object_list[i] = dist.distributed_c10d._tensor_to_object(
|
||||
obj_view, obj_size, group
|
||||
)
|
||||
|
||||
|
||||
def get_inference_urls(num_inference_nodes: int = 0) -> Tuple[Optional[str], ...]:
|
||||
"""
|
||||
Get URLs for inference server communication.
|
||||
|
||||
Parses SLURM environment or uses localhost for single-machine setup.
|
||||
|
||||
Args:
|
||||
num_inference_nodes: Number of dedicated inference nodes.
|
||||
0 = single machine, trainer and vLLM share the node
|
||||
>0 = multi-node, last N nodes are for inference
|
||||
|
||||
Returns:
|
||||
Tuple of (master_addr, master_gloo_addr, master_inference_addr, nodelist)
|
||||
Returns (None, None, None, None) if not in a valid setup.
|
||||
"""
|
||||
if num_inference_nodes > 0:
|
||||
# Multi-node SLURM setup
|
||||
slurm_nodelist = os.environ.get("SLURM_JOB_NODELIST")
|
||||
if not slurm_nodelist:
|
||||
return None, None, None, None
|
||||
|
||||
# Parse SLURM node list
|
||||
nodelist = (
|
||||
os.popen(f'scontrol show hostnames {slurm_nodelist}')
|
||||
.read()
|
||||
.strip()
|
||||
.split("\n")
|
||||
)
|
||||
nodelist = [node for node in nodelist if node]
|
||||
|
||||
# First node is master for process groups
|
||||
master_server = f"{nodelist[0]}:26756"
|
||||
master_gloo_server = f"{nodelist[0]}:26757"
|
||||
|
||||
# Last N nodes are inference nodes
|
||||
inference_nodes = nodelist[-num_inference_nodes:]
|
||||
master_inference_server = f"{inference_nodes[0]}:26758"
|
||||
|
||||
return master_server, master_gloo_server, master_inference_server, inference_nodes
|
||||
|
||||
elif num_inference_nodes == 0:
|
||||
# Single machine setup
|
||||
master_server = "localhost:26756"
|
||||
master_gloo_server = "localhost:26757"
|
||||
master_inference_server = "localhost:26758"
|
||||
nodelist = ["localhost"]
|
||||
|
||||
return master_server, master_gloo_server, master_inference_server, nodelist
|
||||
|
||||
else:
|
||||
return None, None, None, None
|
||||
|
||||
|
||||
def get_hostnames() -> Optional[List[str]]:
|
||||
"""
|
||||
Get the hostnames for this machine.
|
||||
|
||||
Parses /etc/hosts to find all hostnames associated with this machine's IP.
|
||||
|
||||
Returns:
|
||||
List of [ip, hostname1, hostname2, ...] or None if not found.
|
||||
"""
|
||||
my_ip = socket.gethostbyname(socket.gethostname())
|
||||
my_hostname = socket.gethostname()
|
||||
|
||||
try:
|
||||
with open("/etc/hosts", "r") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
parts = line.split()
|
||||
if len(parts) >= 2 and ((parts[0] == my_ip) or (my_hostname in parts)):
|
||||
ip = parts[0]
|
||||
if ip.startswith("127."):
|
||||
continue
|
||||
return parts
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_json_data(log_dir: Optional[str] = None, timeout: int = 300) -> Dict[str, Any]:
|
||||
"""
|
||||
Load the bridge configuration JSON from vLLM.
|
||||
|
||||
Waits for the file to be created by vLLM's weight bridge setup.
|
||||
|
||||
Args:
|
||||
log_dir: Directory containing the JSON file (defaults to LOGDIR env var)
|
||||
timeout: Maximum seconds to wait for file
|
||||
|
||||
Returns:
|
||||
Parsed JSON data with parameter mappings and configuration.
|
||||
|
||||
Raises:
|
||||
ValueError: If LOGDIR not set and log_dir not provided
|
||||
FileNotFoundError: If file not found after timeout
|
||||
"""
|
||||
if log_dir is None:
|
||||
log_dir = os.environ.get("LOGDIR")
|
||||
if log_dir is None:
|
||||
raise ValueError("LOGDIR environment variable not set and log_dir not provided")
|
||||
|
||||
json_path = os.path.join(log_dir, "vllm_bridge_config.json")
|
||||
|
||||
wait_time = 0
|
||||
while not os.path.exists(json_path):
|
||||
if wait_time >= timeout:
|
||||
raise FileNotFoundError(f"Config file not found after {timeout}s: {json_path}")
|
||||
if wait_time % 10 == 0:
|
||||
print(f"[Updater] Waiting for {json_path}...", flush=True)
|
||||
time.sleep(1)
|
||||
wait_time += 1
|
||||
|
||||
# Wait a moment for file to finish writing
|
||||
time.sleep(0.5)
|
||||
|
||||
with open(json_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def get_name_conversions(param_mappings: Dict[str, Any]) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Build reverse mapping from vLLM names to trainer names.
|
||||
|
||||
Args:
|
||||
param_mappings: Dict mapping trainer param names to vLLM info
|
||||
|
||||
Returns:
|
||||
Dict mapping vLLM names to list of trainer names
|
||||
"""
|
||||
name_conversions = defaultdict(list)
|
||||
for name, info in param_mappings.items():
|
||||
vllm_name = info.get("vllm_name", name)
|
||||
name_conversions[vllm_name].append(name)
|
||||
return name_conversions
|
||||
|
||||
|
||||
# Permutation functions for rotary embeddings
|
||||
def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
|
||||
"""
|
||||
Permute weight tensor for sliced rotary embeddings.
|
||||
|
||||
Args:
|
||||
w: Weight tensor of shape [dim1, dim2]
|
||||
n_heads: Number of attention heads
|
||||
|
||||
Returns:
|
||||
Permuted tensor for rotary embedding compatibility
|
||||
"""
|
||||
dim1 = w.shape[0]
|
||||
dim2 = w.shape[1]
|
||||
return (
|
||||
w.view(n_heads, dim1 // n_heads // 2, 2, dim2)
|
||||
.transpose(1, 2)
|
||||
.reshape(dim1, dim2)
|
||||
)
|
||||
|
||||
|
||||
def permute_1d(w: torch.Tensor, n_heads: int) -> torch.Tensor:
|
||||
"""
|
||||
Permute 1D weight tensor (bias) for sliced rotary embeddings.
|
||||
|
||||
Args:
|
||||
w: Weight tensor of shape [dim1]
|
||||
n_heads: Number of attention heads
|
||||
|
||||
Returns:
|
||||
Permuted tensor
|
||||
"""
|
||||
dim1 = w.shape[0]
|
||||
return w.view(n_heads, dim1 // n_heads // 2, 2).transpose(1, 2).reshape(dim1)
|
||||
|
||||
|
||||
425
example_trainer/vllm_patching/weight_updater.py
Normal file
425
example_trainer/vllm_patching/weight_updater.py
Normal file
|
|
@ -0,0 +1,425 @@
|
|||
"""
|
||||
Weight Updater Process - Daemon that receives NCCL weight updates.
|
||||
|
||||
This process runs as a daemon spawned by the patched vLLM GPUModelRunner.
|
||||
It joins NCCL process groups with the trainer and receives weight updates,
|
||||
copying them directly into vLLM's shared memory tensors.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from .distributed_utils import (
|
||||
init_process_group,
|
||||
get_inference_urls,
|
||||
get_hostnames,
|
||||
get_json_data,
|
||||
get_name_conversions,
|
||||
permute,
|
||||
permute_1d,
|
||||
)
|
||||
|
||||
|
||||
def weight_updater_process(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
tp_rank: int,
|
||||
tp_size: int,
|
||||
gpu_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Daemon process that receives weight updates from trainers via NCCL.
|
||||
|
||||
This runs inside a subprocess spawned by PatchedGPUModelRunner. It:
|
||||
1. Joins NCCL/Gloo process groups with the trainer
|
||||
2. Receives weight update broadcasts from rank 0 (trainer)
|
||||
3. Copies updated weights directly into the shared state_dict
|
||||
|
||||
Since state_dict tensors have share_memory_() called on them, the main
|
||||
vLLM process immediately sees the updates for inference.
|
||||
|
||||
Args:
|
||||
state_dict: Model state dict with shared memory tensors
|
||||
num_q_heads: Number of query attention heads (for permutation)
|
||||
num_kv_heads: Number of key/value attention heads
|
||||
tp_rank: Tensor parallel rank of this worker
|
||||
tp_size: Total tensor parallel size
|
||||
gpu_id: GPU device ID for this worker
|
||||
"""
|
||||
# Configuration from environment
|
||||
num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", 0))
|
||||
cuda_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "0")).split(",")
|
||||
debug = int(os.environ.get("WEIGHT_UPDATER_DEBUG", 0))
|
||||
|
||||
# Determine world size based on setup
|
||||
if num_inference_nodes > 0:
|
||||
# Multi-node: 8 GPUs per node
|
||||
world_size = num_inference_nodes * 8
|
||||
ranks_per_node = 8
|
||||
else:
|
||||
# Single node: typically 4 inference GPUs
|
||||
world_size = 4
|
||||
ranks_per_node = 4
|
||||
|
||||
# Get network info
|
||||
hostnames = get_hostnames()
|
||||
master_addr, master_gloo_addr, master_inference_addr, urls = get_inference_urls(
|
||||
num_inference_nodes
|
||||
)
|
||||
|
||||
if master_addr is None:
|
||||
print(f"[Updater] Master address not found, exiting", flush=True)
|
||||
return
|
||||
|
||||
# Set CUDA device
|
||||
torch.cuda.set_device(tp_rank)
|
||||
|
||||
print(
|
||||
f"[Updater] Starting on TP rank {tp_rank}/{tp_size}, "
|
||||
f"q_heads={num_q_heads}, kv_heads={num_kv_heads}, gpu_id={gpu_id}",
|
||||
flush=True,
|
||||
)
|
||||
print(f"[Updater] Master: {master_addr}, world_size={world_size}", flush=True)
|
||||
|
||||
# Determine this worker's rank within the inference group
|
||||
rank = -1
|
||||
if num_inference_nodes == 0:
|
||||
# Single node: skip first N GPUs (used by trainer)
|
||||
rank = int(cuda_devices[gpu_id]) - (8 - ranks_per_node)
|
||||
else:
|
||||
# Multi-node: find which inference node we're on
|
||||
for i, url in enumerate(urls):
|
||||
if hostnames and url in hostnames:
|
||||
rank = ranks_per_node * i + int(cuda_devices[gpu_id])
|
||||
break
|
||||
|
||||
if rank < 0:
|
||||
print(f"[Updater] Could not determine rank, exiting", flush=True)
|
||||
return
|
||||
|
||||
# Load config from vLLM
|
||||
print("[Updater] Loading bridge config...", flush=True)
|
||||
try:
|
||||
json_data = get_json_data()
|
||||
except Exception as e:
|
||||
print(f"[Updater] Failed to load config: {e}", flush=True)
|
||||
return
|
||||
|
||||
param_name_list = sorted(json_data.get("param_mappings", {}).keys())
|
||||
num_training_gpus = json_data.get("dp_shard_degree", 1) * json_data.get("tp_degree", 1)
|
||||
total_group_size = num_training_gpus + world_size
|
||||
|
||||
# Offset rank by training GPUs
|
||||
rank = rank + num_training_gpus
|
||||
|
||||
print(f"[Updater] Total group size: {total_group_size}", flush=True)
|
||||
print(f"[Updater] Training GPUs: {num_training_gpus}", flush=True)
|
||||
print(f"[Updater] My rank: {rank}", flush=True)
|
||||
|
||||
# Initialize process groups
|
||||
print("[Updater] Creating process groups...", flush=True)
|
||||
|
||||
try:
|
||||
# Gloo group for coordination
|
||||
gloo_group = init_process_group(
|
||||
backend="gloo",
|
||||
init_method=f"tcp://{master_addr}",
|
||||
world_size=total_group_size,
|
||||
rank=rank,
|
||||
group_name="gloo_group",
|
||||
)
|
||||
print("[Updater] ✓ Gloo group created", flush=True)
|
||||
|
||||
# NCCL group for tensor transfers
|
||||
nccl_group = init_process_group(
|
||||
backend="nccl",
|
||||
init_method=f"tcp://{master_addr}",
|
||||
world_size=total_group_size,
|
||||
rank=rank,
|
||||
group_name="weight_update_group",
|
||||
)
|
||||
print("[Updater] ✓ NCCL group created", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Updater] Failed to create process groups: {e}", flush=True)
|
||||
return
|
||||
|
||||
# Get device for tensors
|
||||
my_device = next(iter(state_dict.values())).device
|
||||
|
||||
# Write dtype mapping if rank 0
|
||||
if rank == num_training_gpus: # First inference rank
|
||||
_write_dtype_mapping(state_dict, json_data)
|
||||
|
||||
print("[Updater] Entering update loop...", flush=True)
|
||||
|
||||
# Buffers for merged QKV and gate_up projections
|
||||
qkv_buffer = {}
|
||||
gate_up_buffer = {}
|
||||
qkv_bias_buffer = {}
|
||||
w1w3_buffer = {}
|
||||
|
||||
with torch.no_grad():
|
||||
while True:
|
||||
try:
|
||||
# Receive parameter index from trainer (rank 0)
|
||||
obj_indx = torch.zeros(1, dtype=torch.long, device=my_device)
|
||||
dist.broadcast(obj_indx, src=0, group=nccl_group)
|
||||
|
||||
tt_indx = obj_indx.item()
|
||||
|
||||
# -1 signals no update this round (heartbeat)
|
||||
if tt_indx == -1:
|
||||
continue
|
||||
|
||||
# Get parameter info
|
||||
if tt_indx >= len(param_name_list):
|
||||
print(f"[Updater] Invalid index {tt_indx}, skipping", flush=True)
|
||||
continue
|
||||
|
||||
tt_name = param_name_list[tt_indx]
|
||||
param_info = json_data["param_mappings"].get(tt_name, {})
|
||||
vllm_name = param_info.get("vllm_name", tt_name)
|
||||
local_shape = param_info.get("local_shape", [])
|
||||
|
||||
if vllm_name not in state_dict:
|
||||
if debug:
|
||||
print(f"[Updater] {vllm_name} not in state_dict, skipping", flush=True)
|
||||
continue
|
||||
|
||||
target_dtype = state_dict[vllm_name].dtype
|
||||
|
||||
if debug:
|
||||
print(
|
||||
f"[Updater] Receiving {tt_name} -> {vllm_name}, "
|
||||
f"shape={local_shape}, dtype={target_dtype}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Gather tensors from all training ranks
|
||||
tensor_list = [
|
||||
torch.zeros(
|
||||
local_shape if idx < num_training_gpus else [1],
|
||||
dtype=target_dtype,
|
||||
device=my_device,
|
||||
)
|
||||
for idx in range(total_group_size)
|
||||
]
|
||||
|
||||
dist.all_gather(
|
||||
tensor_list,
|
||||
torch.zeros(1, dtype=target_dtype, device=my_device),
|
||||
group=nccl_group,
|
||||
)
|
||||
|
||||
# Only keep training tensors
|
||||
tensor_list = tensor_list[:num_training_gpus]
|
||||
|
||||
# Merge tensors from different parallel configurations
|
||||
tensor = _merge_tensors(
|
||||
tensor_list,
|
||||
json_data,
|
||||
param_info,
|
||||
state_dict[vllm_name],
|
||||
)
|
||||
|
||||
# Apply updates (handling merged QKV, gate_up, etc.)
|
||||
_apply_weight_update(
|
||||
state_dict,
|
||||
vllm_name,
|
||||
tt_name,
|
||||
tensor,
|
||||
param_info,
|
||||
num_q_heads,
|
||||
num_kv_heads,
|
||||
qkv_buffer,
|
||||
gate_up_buffer,
|
||||
qkv_bias_buffer,
|
||||
w1w3_buffer,
|
||||
debug,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Updater] Error in update loop: {e}", flush=True)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def _write_dtype_mapping(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
json_data: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Write dtype mapping file for trainer reference."""
|
||||
try:
|
||||
log_dir = os.environ.get("LOGDIR", ".")
|
||||
name_conversions = get_name_conversions(json_data.get("param_mappings", {}))
|
||||
|
||||
weight_dtypes = {}
|
||||
for name in state_dict.keys():
|
||||
tt_names = name_conversions.get(name, [name])
|
||||
for tt_name in tt_names:
|
||||
weight_dtypes[tt_name] = str(state_dict[name].dtype).split(".")[-1]
|
||||
|
||||
with open(f"{log_dir}/vllm_dtypes.json", "w") as f:
|
||||
json.dump(weight_dtypes, f, indent=2)
|
||||
|
||||
print("[Updater] Wrote dtype mapping", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Updater] Failed to write dtype mapping: {e}", flush=True)
|
||||
|
||||
|
||||
def _merge_tensors(
|
||||
tensor_list: List[torch.Tensor],
|
||||
json_data: Dict[str, Any],
|
||||
param_info: Dict[str, Any],
|
||||
target_tensor: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Merge tensors from distributed training into single tensor.
|
||||
|
||||
Handles FSDP (data parallel) and TP (tensor parallel) sharding.
|
||||
"""
|
||||
dp_shard_degree = json_data.get("dp_shard_degree", 1)
|
||||
tp_degree = json_data.get("tp_degree", 1)
|
||||
tp_shard_dim = param_info.get("tp_shard_dim", 0)
|
||||
|
||||
if dp_shard_degree > 1:
|
||||
# First merge across data parallel dimension
|
||||
tp_tensors = []
|
||||
for i in range(tp_degree):
|
||||
dp_tensors = tensor_list[i::tp_degree]
|
||||
tp_tensors.append(torch.cat(dp_tensors, dim=0))
|
||||
|
||||
# Then merge across tensor parallel dimension if needed
|
||||
if tp_degree > 1:
|
||||
if tp_tensors[0].shape == target_tensor.shape:
|
||||
tensor = tp_tensors[0].contiguous()
|
||||
else:
|
||||
tensor = torch.cat(tp_tensors, dim=tp_shard_dim).contiguous()
|
||||
else:
|
||||
tensor = tp_tensors[0].contiguous()
|
||||
else:
|
||||
# No FSDP, just merge TP shards
|
||||
tensor = torch.cat(tensor_list, dim=tp_shard_dim).contiguous()
|
||||
|
||||
# Cast to target dtype if needed
|
||||
if tensor.dtype != target_tensor.dtype:
|
||||
tensor = tensor.to(target_tensor.dtype)
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def _apply_weight_update(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
vllm_name: str,
|
||||
tt_name: str,
|
||||
tensor: torch.Tensor,
|
||||
param_info: Dict[str, Any],
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
qkv_buffer: Dict[str, torch.Tensor],
|
||||
gate_up_buffer: Dict[str, torch.Tensor],
|
||||
qkv_bias_buffer: Dict[str, torch.Tensor],
|
||||
w1w3_buffer: Dict[str, torch.Tensor],
|
||||
debug: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Apply weight update to state_dict, handling merged projections.
|
||||
|
||||
vLLM often merges QKV projections and gate/up projections into single
|
||||
tensors for efficiency. This handles unpacking and merging correctly.
|
||||
"""
|
||||
needs_permute = param_info.get("needs_permute", False)
|
||||
shape = param_info.get("shape", list(tensor.shape))
|
||||
|
||||
def _debug_diff(name: str, old: torch.Tensor, new: torch.Tensor) -> None:
|
||||
if debug:
|
||||
diff = (new.float() - old.float()).abs()
|
||||
print(
|
||||
f"[WEIGHT DIFF] {name}: mean={diff.mean().item():.6e}, "
|
||||
f"std={diff.std().item():.6e}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Handle merged QKV projection weights
|
||||
if "qkv_proj.weight" in vllm_name:
|
||||
key_val = "q" if ".wq." in tt_name or "q_proj" in tt_name else \
|
||||
"v" if ".wv." in tt_name or "v_proj" in tt_name else "k"
|
||||
|
||||
if key_val == "q" and needs_permute:
|
||||
tensor = permute(tensor, num_q_heads)
|
||||
elif key_val == "k" and needs_permute:
|
||||
tensor = permute(tensor, num_kv_heads)
|
||||
|
||||
qkv_buffer[key_val] = tensor
|
||||
|
||||
if len(qkv_buffer) == 3:
|
||||
merged = torch.cat([qkv_buffer["q"], qkv_buffer["k"], qkv_buffer["v"]], dim=0)
|
||||
_debug_diff(vllm_name, state_dict[vllm_name].data, merged)
|
||||
state_dict[vllm_name].data.copy_(merged.contiguous())
|
||||
qkv_buffer.clear()
|
||||
|
||||
# Handle merged gate/up projection weights
|
||||
elif "gate_up_proj.weight" in vllm_name:
|
||||
key_val = "w1" if ".w1." in tt_name or "gate_proj" in tt_name else "w3"
|
||||
gate_up_buffer[key_val] = tensor
|
||||
|
||||
if len(gate_up_buffer) == 2:
|
||||
merged = torch.cat([gate_up_buffer["w1"], gate_up_buffer["w3"]], dim=0)
|
||||
_debug_diff(vllm_name, state_dict[vllm_name].data, merged)
|
||||
state_dict[vllm_name].data.copy_(merged.contiguous())
|
||||
gate_up_buffer.clear()
|
||||
|
||||
# Handle merged w1/w3 weights (alternative naming)
|
||||
elif "w13_weight" in vllm_name:
|
||||
key_val = "w1" if ".w1" in tt_name else "w3"
|
||||
w1w3_buffer[key_val] = tensor
|
||||
|
||||
if len(w1w3_buffer) == 2:
|
||||
merged = torch.cat([w1w3_buffer["w1"], w1w3_buffer["w3"]], dim=1)
|
||||
_debug_diff(vllm_name, state_dict[vllm_name].data, merged)
|
||||
state_dict[vllm_name].data.copy_(merged.contiguous())
|
||||
w1w3_buffer.clear()
|
||||
|
||||
# Handle merged QKV bias
|
||||
elif "qkv_proj.bias" in vllm_name:
|
||||
key_val = "q" if ".wq." in tt_name else "v" if ".wv." in tt_name else "k"
|
||||
|
||||
if key_val == "q" and needs_permute:
|
||||
tensor = permute_1d(tensor, num_q_heads)
|
||||
elif key_val == "k" and needs_permute:
|
||||
tensor = permute_1d(tensor, num_kv_heads)
|
||||
|
||||
qkv_bias_buffer[key_val] = tensor
|
||||
|
||||
if len(qkv_bias_buffer) == 3:
|
||||
merged = torch.cat([qkv_bias_buffer["q"], qkv_bias_buffer["k"], qkv_bias_buffer["v"]], dim=0)
|
||||
_debug_diff(vllm_name, state_dict[vllm_name].data, merged)
|
||||
state_dict[vllm_name].data.copy_(merged.contiguous())
|
||||
qkv_bias_buffer.clear()
|
||||
|
||||
# Handle regular weights (possibly needing permutation)
|
||||
elif needs_permute:
|
||||
if len(shape) == 2:
|
||||
tensor = permute(tensor, shape[0]).contiguous()
|
||||
elif len(shape) == 1:
|
||||
tensor = permute_1d(tensor, shape[0]).contiguous()
|
||||
|
||||
_debug_diff(vllm_name, state_dict[vllm_name].data, tensor)
|
||||
state_dict[vllm_name].data.copy_(tensor)
|
||||
|
||||
# Simple weight copy
|
||||
else:
|
||||
_debug_diff(vllm_name, state_dict[vllm_name].data, tensor)
|
||||
state_dict[vllm_name].data.copy_(tensor)
|
||||
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue