From ff8eaf9e3c71bf3a0f7077f4ec70b6b6c97e6bd5 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 29 Dec 2025 20:11:23 -0500 Subject: [PATCH] param locations update --- example_trainer/grpo.py | 4 + .../vllm_patching/patched_gpu_runner.py | 34 +- .../vllm_patching/weight_updater.py | 298 +++--------------- example_trainer/vllm_weight_bridge.py | 65 ++-- 4 files changed, 117 insertions(+), 284 deletions(-) diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index 7904950a..edbc5af0 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -1057,6 +1057,10 @@ def train_shared_vllm(config: TrainingConfig): print("[2/3] Loading model with shared weights...") model, tokenizer = load_model_and_tokenizer(config, bridge=bridge) optimizer = AdamW(model.parameters(), lr=config.lr) + + # For NCCL mode, set param list from trainer's model + if config.use_shared_memory: + bridge.set_param_list_from_model(model) print(f"[3/3] Starting training for {config.training_steps} steps") print("NOTE: vLLM sees weight updates immediately after each step!") diff --git a/example_trainer/vllm_patching/patched_gpu_runner.py b/example_trainer/vllm_patching/patched_gpu_runner.py index bda9b3fe..76fd5dca 100644 --- a/example_trainer/vllm_patching/patched_gpu_runner.py +++ b/example_trainer/vllm_patching/patched_gpu_runner.py @@ -111,19 +111,47 @@ def _create_patched_runner(BaseRunner: type) -> type: traceback.print_exc() def _setup_shared_memory(self) -> None: - """Move model tensors to shared memory.""" + """Move model tensors to shared memory and export param info.""" + import json + from pathlib import Path + # Make entire model shareable self.model.share_memory() # Also share_memory_() on each parameter individually - # (some implementations may need this) state_dict = self.model.state_dict() for key, val in state_dict.items(): if val.is_cuda or val.device.type == 'cuda': - # For CUDA tensors, we need to ensure they're in shared memory val.share_memory_() print(f"[vLLM Patch] Shared {len(state_dict)} tensors in model") + + # Export parameter info to JSON for trainer + log_dir = os.environ.get("LOGDIR", ".") + json_path = Path(log_dir) / "vllm_bridge_config.json" + + param_mappings = {} + for name, tensor in state_dict.items(): + param_mappings[name] = { + "vllm_name": name, + "shape": list(tensor.shape), + "dtype": str(tensor.dtype), + } + + info = { + "model": str(self.model_config.model), + "tp_degree": self.parallel_config.tensor_parallel_size, + "dp_shard_degree": 1, + "param_mappings": param_mappings, + "param_names": sorted(state_dict.keys()), + } + + try: + with open(json_path, "w") as f: + json.dump(info, f, indent=2) + print(f"[vLLM Patch] Exported {len(param_mappings)} params to {json_path}") + except Exception as e: + print(f"[vLLM Patch] Warning: Failed to export params: {e}") def _spawn_weight_updater(self) -> None: """Spawn the daemon process for receiving weight updates.""" diff --git a/example_trainer/vllm_patching/weight_updater.py b/example_trainer/vllm_patching/weight_updater.py index d9a67cde..5d39569e 100644 --- a/example_trainer/vllm_patching/weight_updater.py +++ b/example_trainer/vllm_patching/weight_updater.py @@ -8,10 +8,9 @@ copying them directly into vLLM's shared memory tensors. from __future__ import annotations -import json import os import time -from typing import Any, Dict, List, Optional +from typing import Dict import torch import torch.distributed as dist @@ -20,10 +19,6 @@ from .distributed_utils import ( init_process_group, get_inference_urls, get_hostnames, - get_json_data, - get_name_conversions, - permute, - permute_1d, ) @@ -110,28 +105,18 @@ def weight_updater_process( print(f"[Updater] Master: {master_addr}, world_size={world_size}, my_rank={rank}", flush=True) - # Load config from vLLM (optional - may not exist for simple setups) - print("[Updater] Loading bridge config...", flush=True) - json_data = {} - param_name_list = [] - try: - json_data = get_json_data() - param_name_list = sorted(json_data.get("param_mappings", {}).keys()) - print(f"[Updater] Loaded {len(param_name_list)} parameter mappings", flush=True) - except Exception as e: - print(f"[Updater] No config file found (will receive all params): {e}", flush=True) + # Use state_dict keys as parameter list (we already have the model!) + param_name_list = sorted(state_dict.keys()) + print(f"[Updater] Model has {len(param_name_list)} parameters", flush=True) # Use the world_size and rank we already calculated total_group_size = world_size # For single-node mode, trainer is rank 0 - if num_inference_nodes == 0: - num_training_gpus = 1 - else: - num_training_gpus = json_data.get("dp_shard_degree", 1) * json_data.get("tp_degree", 1) + num_training_ranks = 1 if num_inference_nodes == 0 else 1 print(f"[Updater] Total group size: {total_group_size}", flush=True) - print(f"[Updater] Training ranks: {num_training_gpus}", flush=True) + print(f"[Updater] Training ranks: {num_training_ranks}", flush=True) print(f"[Updater] My rank: {rank}", flush=True) # Initialize process groups @@ -167,17 +152,18 @@ def weight_updater_process( # Get device for tensors my_device = next(iter(state_dict.values())).device - # Write dtype mapping if rank 0 - if rank == num_training_gpus: # First inference rank - _write_dtype_mapping(state_dict, json_data) + # Build param info dict from state_dict + param_info_dict = {} + for name, tensor in state_dict.items(): + param_info_dict[name] = { + "shape": list(tensor.shape), + "dtype": tensor.dtype, + } print("[Updater] Entering update loop...", flush=True) + print(f"[Updater] Waiting for weight updates from trainer (rank 0)...", flush=True) - # Buffers for merged QKV and gate_up projections - qkv_buffer = {} - gate_up_buffer = {} - qkv_bias_buffer = {} - w1w3_buffer = {} + update_count = 0 with torch.no_grad(): while True: @@ -188,76 +174,43 @@ def weight_updater_process( tt_indx = obj_indx.item() - # -1 signals no update this round (heartbeat) + # -1 signals heartbeat (no update) if tt_indx == -1: continue + # -2 signals shutdown + if tt_indx == -2: + print("[Updater] Received shutdown signal", flush=True) + break + # Get parameter info - if tt_indx >= len(param_name_list): - print(f"[Updater] Invalid index {tt_indx}, skipping", flush=True) - continue - - tt_name = param_name_list[tt_indx] - param_info = json_data["param_mappings"].get(tt_name, {}) - vllm_name = param_info.get("vllm_name", tt_name) - local_shape = param_info.get("local_shape", []) - - if vllm_name not in state_dict: + if tt_indx < 0 or tt_indx >= len(param_name_list): if debug: - print(f"[Updater] {vllm_name} not in state_dict, skipping", flush=True) + print(f"[Updater] Invalid index {tt_indx}, skipping", flush=True) continue - target_dtype = state_dict[vllm_name].dtype + param_name = param_name_list[tt_indx] - if debug: - print( - f"[Updater] Receiving {tt_name} -> {vllm_name}, " - f"shape={local_shape}, dtype={target_dtype}", - flush=True, - ) + if param_name not in state_dict: + if debug: + print(f"[Updater] {param_name} not in state_dict, skipping", flush=True) + continue - # Gather tensors from all training ranks - tensor_list = [ - torch.zeros( - local_shape if idx < num_training_gpus else [1], - dtype=target_dtype, - device=my_device, - ) - for idx in range(total_group_size) - ] + target_tensor = state_dict[param_name] + target_shape = list(target_tensor.shape) + target_dtype = target_tensor.dtype - dist.all_gather( - tensor_list, - torch.zeros(1, dtype=target_dtype, device=my_device), - group=nccl_group, - ) + # Receive the tensor from trainer + # Trainer sends via broadcast, we receive + received_tensor = torch.zeros(target_shape, dtype=target_dtype, device=my_device) + dist.broadcast(received_tensor, src=0, group=nccl_group) - # Only keep training tensors - tensor_list = tensor_list[:num_training_gpus] + # Copy to shared memory + state_dict[param_name].data.copy_(received_tensor) - # Merge tensors from different parallel configurations - tensor = _merge_tensors( - tensor_list, - json_data, - param_info, - state_dict[vllm_name], - ) - - # Apply updates (handling merged QKV, gate_up, etc.) - _apply_weight_update( - state_dict, - vllm_name, - tt_name, - tensor, - param_info, - num_q_heads, - num_kv_heads, - qkv_buffer, - gate_up_buffer, - qkv_bias_buffer, - w1w3_buffer, - debug, - ) + update_count += 1 + if debug or (update_count % 50 == 0): + print(f"[Updater] Updated {param_name} (#{update_count})", flush=True) except Exception as e: print(f"[Updater] Error in update loop: {e}", flush=True) @@ -266,172 +219,5 @@ def weight_updater_process( time.sleep(1) -def _write_dtype_mapping( - state_dict: Dict[str, torch.Tensor], - json_data: Dict[str, Any], -) -> None: - """Write dtype mapping file for trainer reference.""" - try: - log_dir = os.environ.get("LOGDIR", ".") - name_conversions = get_name_conversions(json_data.get("param_mappings", {})) - - weight_dtypes = {} - for name in state_dict.keys(): - tt_names = name_conversions.get(name, [name]) - for tt_name in tt_names: - weight_dtypes[tt_name] = str(state_dict[name].dtype).split(".")[-1] - - with open(f"{log_dir}/vllm_dtypes.json", "w") as f: - json.dump(weight_dtypes, f, indent=2) - - print("[Updater] Wrote dtype mapping", flush=True) - except Exception as e: - print(f"[Updater] Failed to write dtype mapping: {e}", flush=True) - - -def _merge_tensors( - tensor_list: List[torch.Tensor], - json_data: Dict[str, Any], - param_info: Dict[str, Any], - target_tensor: torch.Tensor, -) -> torch.Tensor: - """ - Merge tensors from distributed training into single tensor. - - Handles FSDP (data parallel) and TP (tensor parallel) sharding. - """ - dp_shard_degree = json_data.get("dp_shard_degree", 1) - tp_degree = json_data.get("tp_degree", 1) - tp_shard_dim = param_info.get("tp_shard_dim", 0) - - if dp_shard_degree > 1: - # First merge across data parallel dimension - tp_tensors = [] - for i in range(tp_degree): - dp_tensors = tensor_list[i::tp_degree] - tp_tensors.append(torch.cat(dp_tensors, dim=0)) - - # Then merge across tensor parallel dimension if needed - if tp_degree > 1: - if tp_tensors[0].shape == target_tensor.shape: - tensor = tp_tensors[0].contiguous() - else: - tensor = torch.cat(tp_tensors, dim=tp_shard_dim).contiguous() - else: - tensor = tp_tensors[0].contiguous() - else: - # No FSDP, just merge TP shards - tensor = torch.cat(tensor_list, dim=tp_shard_dim).contiguous() - - # Cast to target dtype if needed - if tensor.dtype != target_tensor.dtype: - tensor = tensor.to(target_tensor.dtype) - - return tensor - - -def _apply_weight_update( - state_dict: Dict[str, torch.Tensor], - vllm_name: str, - tt_name: str, - tensor: torch.Tensor, - param_info: Dict[str, Any], - num_q_heads: int, - num_kv_heads: int, - qkv_buffer: Dict[str, torch.Tensor], - gate_up_buffer: Dict[str, torch.Tensor], - qkv_bias_buffer: Dict[str, torch.Tensor], - w1w3_buffer: Dict[str, torch.Tensor], - debug: bool, -) -> None: - """ - Apply weight update to state_dict, handling merged projections. - - vLLM often merges QKV projections and gate/up projections into single - tensors for efficiency. This handles unpacking and merging correctly. - """ - needs_permute = param_info.get("needs_permute", False) - shape = param_info.get("shape", list(tensor.shape)) - - def _debug_diff(name: str, old: torch.Tensor, new: torch.Tensor) -> None: - if debug: - diff = (new.float() - old.float()).abs() - print( - f"[WEIGHT DIFF] {name}: mean={diff.mean().item():.6e}, " - f"std={diff.std().item():.6e}", - flush=True, - ) - - # Handle merged QKV projection weights - if "qkv_proj.weight" in vllm_name: - key_val = "q" if ".wq." in tt_name or "q_proj" in tt_name else \ - "v" if ".wv." in tt_name or "v_proj" in tt_name else "k" - - if key_val == "q" and needs_permute: - tensor = permute(tensor, num_q_heads) - elif key_val == "k" and needs_permute: - tensor = permute(tensor, num_kv_heads) - - qkv_buffer[key_val] = tensor - - if len(qkv_buffer) == 3: - merged = torch.cat([qkv_buffer["q"], qkv_buffer["k"], qkv_buffer["v"]], dim=0) - _debug_diff(vllm_name, state_dict[vllm_name].data, merged) - state_dict[vllm_name].data.copy_(merged.contiguous()) - qkv_buffer.clear() - - # Handle merged gate/up projection weights - elif "gate_up_proj.weight" in vllm_name: - key_val = "w1" if ".w1." in tt_name or "gate_proj" in tt_name else "w3" - gate_up_buffer[key_val] = tensor - - if len(gate_up_buffer) == 2: - merged = torch.cat([gate_up_buffer["w1"], gate_up_buffer["w3"]], dim=0) - _debug_diff(vllm_name, state_dict[vllm_name].data, merged) - state_dict[vllm_name].data.copy_(merged.contiguous()) - gate_up_buffer.clear() - - # Handle merged w1/w3 weights (alternative naming) - elif "w13_weight" in vllm_name: - key_val = "w1" if ".w1" in tt_name else "w3" - w1w3_buffer[key_val] = tensor - - if len(w1w3_buffer) == 2: - merged = torch.cat([w1w3_buffer["w1"], w1w3_buffer["w3"]], dim=1) - _debug_diff(vllm_name, state_dict[vllm_name].data, merged) - state_dict[vllm_name].data.copy_(merged.contiguous()) - w1w3_buffer.clear() - - # Handle merged QKV bias - elif "qkv_proj.bias" in vllm_name: - key_val = "q" if ".wq." in tt_name else "v" if ".wv." in tt_name else "k" - - if key_val == "q" and needs_permute: - tensor = permute_1d(tensor, num_q_heads) - elif key_val == "k" and needs_permute: - tensor = permute_1d(tensor, num_kv_heads) - - qkv_bias_buffer[key_val] = tensor - - if len(qkv_bias_buffer) == 3: - merged = torch.cat([qkv_bias_buffer["q"], qkv_bias_buffer["k"], qkv_bias_buffer["v"]], dim=0) - _debug_diff(vllm_name, state_dict[vllm_name].data, merged) - state_dict[vllm_name].data.copy_(merged.contiguous()) - qkv_bias_buffer.clear() - - # Handle regular weights (possibly needing permutation) - elif needs_permute: - if len(shape) == 2: - tensor = permute(tensor, shape[0]).contiguous() - elif len(shape) == 1: - tensor = permute_1d(tensor, shape[0]).contiguous() - - _debug_diff(vllm_name, state_dict[vllm_name].data, tensor) - state_dict[vllm_name].data.copy_(tensor) - - # Simple weight copy - else: - _debug_diff(vllm_name, state_dict[vllm_name].data, tensor) - state_dict[vllm_name].data.copy_(tensor) - - +# Note: Advanced multi-GPU tensor parallelism support removed for simplicity. +# For single-node mode, we use direct tensor broadcast which is sufficient. diff --git a/example_trainer/vllm_weight_bridge.py b/example_trainer/vllm_weight_bridge.py index e8c494a6..fd3d66eb 100644 --- a/example_trainer/vllm_weight_bridge.py +++ b/example_trainer/vllm_weight_bridge.py @@ -358,26 +358,45 @@ class VLLMWeightBridge: log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".") json_path = Path(log_dir) / "vllm_bridge_config.json" - # Wait for file + # Wait for file (vLLM needs time to load model and export params) wait_time = 0 - while not json_path.exists() and wait_time < self.config.timeout_seconds: + max_wait = min(self.config.timeout_seconds, 120) # Max 2 minutes + while not json_path.exists() and wait_time < max_wait: if wait_time % 10 == 0: - print(f"[Bridge] Waiting for {json_path}...") + print(f"[Bridge] Waiting for {json_path}... ({wait_time}s)") time.sleep(1) wait_time += 1 if not json_path.exists(): - raise RuntimeError(f"Config file not found: {json_path}") + print(f"[Bridge] Warning: Config file not found after {wait_time}s") + print("[Bridge] Will use trainer's model params directly") + self.param_mappings = {} + self.param_name_list = [] + return - time.sleep(0.5) # Wait for file to finish writing + time.sleep(1.0) # Wait for file to finish writing - with open(json_path, "r") as f: - data = json.load(f) + try: + with open(json_path, "r") as f: + data = json.load(f) + + self.param_mappings = data.get("param_mappings", {}) + self.param_name_list = data.get("param_names", sorted(self.param_mappings.keys())) + + print(f"[Bridge] Loaded {len(self.param_name_list)} vLLM parameter names") + except Exception as e: + print(f"[Bridge] Warning: Failed to load config: {e}") + self.param_mappings = {} + self.param_name_list = [] + + def set_param_list_from_model(self, model: nn.Module) -> None: + """ + Set param list from the trainer's model. - self.param_mappings = data.get("param_mappings", {}) - self.param_name_list = sorted(self.param_mappings.keys()) - - print(f"[Bridge] Loaded mappings for {len(self.param_name_list)} parameters") + Call this if vLLM's param names don't match the trainer's. + """ + self.param_name_list = sorted(name for name, _ in model.named_parameters()) + print(f"[Bridge] Using trainer's {len(self.param_name_list)} parameter names") def broadcast_weights(self, model: nn.Module) -> None: """ @@ -400,31 +419,27 @@ class VLLMWeightBridge: start_time = time.time() state_dict = dict(model.named_parameters()) + num_params = 0 with torch.no_grad(): for idx, param_name in enumerate(self.param_name_list): - # Signal which parameter we're broadcasting - idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.device) - dist.broadcast(idx_tensor, src=0, group=self.nccl_group) - # Get tensor for this parameter if param_name not in state_dict: continue tensor = state_dict[param_name].data - local_shape = self.param_mappings[param_name].get( - "local_shape", list(tensor.shape) - ) - # All-gather to distribute to all ranks (including inference) - tensor_list = [ - torch.zeros(local_shape, dtype=tensor.dtype, device=self.device) - for _ in range(self._total_group_size) - ] - dist.all_gather(tensor_list, tensor, group=self.nccl_group) + # Step 1: Broadcast parameter index + idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.device) + dist.broadcast(idx_tensor, src=0, group=self.nccl_group) + + # Step 2: Broadcast the actual tensor + dist.broadcast(tensor.contiguous(), src=0, group=self.nccl_group) + + num_params += 1 elapsed = time.time() - start_time - print(f"[Bridge] Broadcast update #{self._update_count} ({elapsed:.2f}s)") + print(f"[Bridge] Broadcast {num_params} params, update #{self._update_count} ({elapsed:.2f}s)") def broadcast_single_param( self,