diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py
index 9dc6b23f..fea1a666 100644
--- a/example_trainer/grpo.py
+++ b/example_trainer/grpo.py
@@ -518,87 +518,119 @@ def _attach_to_vllm_shared_tensors(
     hf_state_dict = {}
     vllm_to_hf_mapping = _create_vllm_to_hf_mapping(model, ipc_handles)
 
+    # Helper to reconstruct a CUDA tensor from a serialized IPC handle
+    def _create_ipc_tensor(ipc_info):
+        device_index = ipc_info["device_index"]
+        ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"])
+        storage_size = ipc_info["storage_size"]
+        storage_offset_orig = ipc_info["storage_offset_orig"]
+        ref_counter_handle = base64.b64decode(ipc_info["ref_counter_handle_b64"])
+        ref_counter_offset = ipc_info["ref_counter_offset"]
+        event_handle = base64.b64decode(ipc_info["event_handle_b64"])
+        event_sync_required = ipc_info["event_sync_required"]
+
+        # _new_shared_cuda expects the same 8-tuple that _share_cuda_() produced
+        share_tuple = (device_index, ipc_handle, storage_size, storage_offset_orig,
+                       ref_counter_handle, ref_counter_offset, event_handle, event_sync_required)
+
+        storage = torch.UntypedStorage._new_shared_cuda(*share_tuple)
+        dtype = getattr(torch, ipc_info["dtype"].replace("torch.", ""))
+        tensor = torch.tensor([], dtype=dtype, device=f"cuda:{device_index}")
+        tensor.set_(storage, storage_offset=ipc_info["tensor_storage_offset"],
+                    size=ipc_info["shape"], stride=ipc_info["stride"])
+        return tensor
+
+    # Cache fused tensors so each one is recreated from IPC only once
+    fused_tensor_cache = {}
+
+    # Extract fused mappings
+    fused_mappings = vllm_to_hf_mapping.pop('_fused_', {})
+
     attached_count = 0
+
+    # First, handle direct mappings
     for hf_name, vllm_name in vllm_to_hf_mapping.items():
         if vllm_name not in ipc_handles:
             continue
-
-        ipc_info = ipc_handles[vllm_name]
         try:
-            # Reconstruct tensor from IPC handle
-            # We need all 8 items from the original _share_cuda_() call
+            ipc_info = ipc_handles[vllm_name]
             if "ipc_handle_b64" not in ipc_info:
-                print(f"[Setup] Missing ipc_handle_b64 for {hf_name}")
                 continue
 
-            # DEBUG: Only try first tensor to see if IPC works at all
-            if attached_count == 0:
-                print(f"[Setup DEBUG] Attempting first tensor: {hf_name}", flush=True)
-                print(f"[Setup DEBUG] device_index: {ipc_info['device_index']}", flush=True)
-                print(f"[Setup DEBUG] storage_size: {ipc_info['storage_size']}", flush=True)
-                print(f"[Setup DEBUG] shape: {ipc_info['shape']}", flush=True)
-
-            # Decode all the bytes fields from base64
-            device_index = ipc_info["device_index"]
-            ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"])
-            storage_size = ipc_info["storage_size"]
-            storage_offset_orig = ipc_info["storage_offset_orig"]
-            ref_counter_handle = base64.b64decode(ipc_info["ref_counter_handle_b64"])
-            ref_counter_offset = ipc_info["ref_counter_offset"]
-            event_handle = base64.b64decode(ipc_info["event_handle_b64"])
-            event_sync_required = ipc_info["event_sync_required"]
-
-            if attached_count == 0:
-                print(f"[Setup DEBUG] Decoded IPC handle, len={len(ipc_handle)}", flush=True)
-                print(f"[Setup DEBUG] About to call _new_shared_cuda...", flush=True)
-
-            # Reconstruct the 8-tuple that _new_shared_cuda expects
-            share_tuple = (
-                device_index,
-                ipc_handle,
-                storage_size,
-                storage_offset_orig,
-                ref_counter_handle,
-                ref_counter_offset,
-                event_handle,
-                event_sync_required,
-            )
-
-            # Create storage from IPC handle (needs all 8 items)
-            storage = torch.UntypedStorage._new_shared_cuda(*share_tuple)
-
-            if attached_count == 0:
-                print(f"[Setup DEBUG] Storage created! size={storage.size()}", flush=True)
-
-            # Reconstruct tensor
-            dtype = getattr(torch, ipc_info["dtype"].replace("torch.", ""))
-            tensor = torch.tensor([], dtype=dtype, device=f"cuda:{device_index}")
-            tensor.set_(
-                storage,
-                storage_offset=ipc_info["tensor_storage_offset"],
-                size=ipc_info["shape"],
-                stride=ipc_info["stride"],
-            )
-
-            if attached_count == 0:
-                print(f"[Setup DEBUG] Tensor set! shape={tensor.shape}", flush=True)
-
-            # Make tensor require gradients for training
+            tensor = _create_ipc_tensor(ipc_info)
             tensor.requires_grad_(True)
-
             hf_state_dict[hf_name] = tensor
             attached_count += 1
 
             if attached_count == 1:
-                print(f"[Setup DEBUG] ✓ First tensor attached successfully!", flush=True)
+                print(f"[Setup DEBUG] ✓ First tensor attached: {hf_name}", flush=True)
 
         except Exception as e:
             print(f"[Setup] Failed to attach {hf_name}: {e}", flush=True)
-            import traceback
-            traceback.print_exc()
             continue
 
+    print(f"[Setup] Direct attachments: {attached_count}")
+
+    # Now handle fused mappings (QKV, gate_up)
+    fused_count = 0
+    for hf_name, (vllm_name, fuse_type, slice_idx) in fused_mappings.items():
+        if vllm_name not in ipc_handles:
+            continue
+
+        try:
+            # Get or create the fused tensor
+            if vllm_name not in fused_tensor_cache:
+                ipc_info = ipc_handles[vllm_name]
+                if "ipc_handle_b64" not in ipc_info:
+                    continue
+                fused_tensor_cache[vllm_name] = _create_ipc_tensor(ipc_info)
+
+            fused_tensor = fused_tensor_cache[vllm_name]
+
+            # Slice the fused tensor. vLLM fuses along dim 0 (output features).
+            if fuse_type == 'qkv':
+                # QKV is fused as [q_size + k_size + v_size, hidden].
+                # For Qwen: q_size = num_heads * head_dim, k_size = v_size = num_kv_heads * head_dim.
+                # With GQA, Q is larger than K and V, so take the split sizes
+                # from the HF model's expected shapes instead of assuming equal thirds.
+                hf_param = dict(model.named_parameters())[hf_name]
+                expected_size = hf_param.shape[0]
+
+                if slice_idx == 0:  # Q comes first
+                    sliced = fused_tensor[:expected_size]
+                elif slice_idx == 1:  # K starts after Q
+                    q_size = dict(model.named_parameters())[hf_name.replace('.k_proj.', '.q_proj.')].shape[0]
+                    sliced = fused_tensor[q_size:q_size + expected_size]
+                else:  # V starts after Q and K
+                    q_size = dict(model.named_parameters())[hf_name.replace('.v_proj.', '.q_proj.')].shape[0]
+                    k_size = dict(model.named_parameters())[hf_name.replace('.v_proj.', '.k_proj.')].shape[0]
+                    sliced = fused_tensor[q_size + k_size:q_size + k_size + expected_size]
+
+            elif fuse_type == 'gate_up':
+                # gate_up is fused as [gate_size + up_size, hidden]; the two
+                # halves have equal output sizes, so split down the middle.
+                total_size = fused_tensor.shape[0]
+                half_size = total_size // 2
+                if slice_idx == 0:  # gate is first
+                    sliced = fused_tensor[:half_size]
+                else:  # up is second
+                    sliced = fused_tensor[half_size:]
+
+            # The slice shares memory with the fused tensor!
+            sliced.requires_grad_(True)
+            hf_state_dict[hf_name] = sliced
+            fused_count += 1
+
+        except Exception as e:
+            print(f"[Setup] Failed to slice {hf_name} from {vllm_name}: {e}", flush=True)
+            continue
+
+    print(f"[Setup] Fused slices: {fused_count}")
+    attached_count += fused_count
+
     if attached_count == 0:
         print("[Setup] Could not attach any tensors - IPC failed")
         print("[Setup] Model loaded normally (not sharing memory with vLLM)")
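For context, the `ipc_handles` payload consumed above is assumed to be produced in the vLLM worker process via `UntypedStorage._share_cuda_()`, which returns exactly the eight items that `_new_shared_cuda` takes (the field names below are the ones this hunk reads). A minimal producer-side sketch; the `export_ipc_handle` helper is hypothetical and not part of this PR:

    import base64

    import torch


    def export_ipc_handle(tensor: torch.Tensor) -> dict:
        """Serialize a CUDA tensor's storage for cross-process sharing (sketch)."""
        # _share_cuda_() yields the 8-tuple that _new_shared_cuda() consumes
        (device_index, ipc_handle, storage_size, storage_offset_orig,
         ref_counter_handle, ref_counter_offset,
         event_handle, event_sync_required) = tensor.untyped_storage()._share_cuda_()
        return {
            "device_index": device_index,
            "ipc_handle_b64": base64.b64encode(ipc_handle).decode(),
            "storage_size": storage_size,
            "storage_offset_orig": storage_offset_orig,
            "ref_counter_handle_b64": base64.b64encode(ref_counter_handle).decode(),
            "ref_counter_offset": ref_counter_offset,
            "event_handle_b64": base64.b64encode(event_handle).decode(),
            "event_sync_required": event_sync_required,
            # Tensor-level metadata used by tensor.set_() on the consumer side
            "dtype": str(tensor.dtype),  # e.g. "torch.bfloat16"
            "shape": list(tensor.shape),
            "stride": list(tensor.stride()),
            "tensor_storage_offset": tensor.storage_offset(),
        }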
@@ -656,21 +688,18 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dic
     """
     Create mapping from HuggingFace parameter names to vLLM tensor names.
 
-    vLLM uses slightly different naming conventions than HuggingFace.
-    This function creates the bidirectional mapping.
+    vLLM uses fused layers that need special handling:
+    - qkv_proj -> q_proj, k_proj, v_proj (need slicing)
+    - gate_up_proj -> gate_proj, up_proj (need slicing)
     """
     hf_params = set(model.state_dict().keys())
     vllm_params = set(ipc_handles.keys())
 
     print(f"[Setup] HF model has {len(hf_params)} params, vLLM exported {len(vllm_params)} tensors")
 
-    # Debug: show sample names
-    hf_sample = sorted(list(hf_params))[:3]
-    vllm_sample = sorted(list(vllm_params))[:3]
-    print(f"[Setup] Sample HF names: {hf_sample}")
-    print(f"[Setup] Sample vLLM names: {vllm_sample}")
-
     mapping = {}
+    # Track fused tensors that need slicing
+    fused_mappings = {}  # hf_name -> (vllm_name, slice_type, slice_index)
 
     for hf_name in hf_params:
        # Try direct match first
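To make the bookkeeping concrete, here is the shape of the result this function builds for one decoder layer. This is illustrative only: layer index 0 and the Qwen-style module paths are assumptions matching the patterns the loop below looks for, not actual output from this PR.

    # Direct mappings: HF name -> identical vLLM name (plus the tied lm_head case)
    mapping = {
        "model.layers.0.input_layernorm.weight": "model.layers.0.input_layernorm.weight",
        "lm_head.weight": "model.embed_tokens.weight",  # tied to the embeddings
    }

    # Fused mappings: HF name -> (vLLM fused name, slice_type, slice_index)
    mapping["_fused_"] = {
        "model.layers.0.self_attn.q_proj.weight": ("model.layers.0.self_attn.qkv_proj.weight", "qkv", 0),
        "model.layers.0.self_attn.k_proj.weight": ("model.layers.0.self_attn.qkv_proj.weight", "qkv", 1),
        "model.layers.0.self_attn.v_proj.weight": ("model.layers.0.self_attn.qkv_proj.weight", "qkv", 2),
        "model.layers.0.mlp.gate_proj.weight": ("model.layers.0.mlp.gate_up_proj.weight", "gate_up", 0),
        "model.layers.0.mlp.up_proj.weight": ("model.layers.0.mlp.gate_up_proj.weight", "gate_up", 1),
    }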
@@ -678,18 +707,35 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dic
         if hf_name in vllm_params:
             mapping[hf_name] = hf_name
             continue
 
-        # Try common transformations
-        # vLLM often uses 'model.' prefix
-        vllm_name = f"model.{hf_name}" if not hf_name.startswith("model.") else hf_name
-        if vllm_name in vllm_params:
-            mapping[hf_name] = vllm_name
-            continue
-
-        # Remove 'model.' prefix if present
-        if hf_name.startswith("model."):
-            vllm_name = hf_name[6:]
+        # Check for fused QKV projection
+        # HF: model.layers.X.self_attn.q_proj.weight -> vLLM: model.layers.X.self_attn.qkv_proj.weight
+        if '.self_attn.q_proj.' in hf_name:
+            vllm_name = hf_name.replace('.q_proj.', '.qkv_proj.')
             if vllm_name in vllm_params:
-                mapping[hf_name] = vllm_name
+                fused_mappings[hf_name] = (vllm_name, 'qkv', 0)  # Q is first
+                continue
+        if '.self_attn.k_proj.' in hf_name:
+            vllm_name = hf_name.replace('.k_proj.', '.qkv_proj.')
+            if vllm_name in vllm_params:
+                fused_mappings[hf_name] = (vllm_name, 'qkv', 1)  # K is second
+                continue
+        if '.self_attn.v_proj.' in hf_name:
+            vllm_name = hf_name.replace('.v_proj.', '.qkv_proj.')
+            if vllm_name in vllm_params:
+                fused_mappings[hf_name] = (vllm_name, 'qkv', 2)  # V is third
+                continue
+
+        # Check for fused gate_up projection
+        # HF: model.layers.X.mlp.gate_proj.weight -> vLLM: model.layers.X.mlp.gate_up_proj.weight
+        if '.mlp.gate_proj.' in hf_name:
+            vllm_name = hf_name.replace('.gate_proj.', '.gate_up_proj.')
+            if vllm_name in vllm_params:
+                fused_mappings[hf_name] = (vllm_name, 'gate_up', 0)  # gate is first
+                continue
+        if '.mlp.up_proj.' in hf_name:
+            vllm_name = hf_name.replace('.up_proj.', '.gate_up_proj.')
+            if vllm_name in vllm_params:
+                fused_mappings[hf_name] = (vllm_name, 'gate_up', 1)  # up is second
                 continue
 
         # For lm_head, check if it's tied to embed_tokens
@@ -697,13 +743,12 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dic
             mapping[hf_name] = "model.embed_tokens.weight"
             continue
 
-    print(f"[Setup] Mapped {len(mapping)} HF params to vLLM tensors")
+    total_mapped = len(mapping) + len(fused_mappings)
+    print(f"[Setup] Direct mappings: {len(mapping)}, Fused mappings: {len(fused_mappings)}")
+    print(f"[Setup] Total mapped: {total_mapped} / {len(hf_params)}")
 
-    # Show what's NOT mapped
-    unmapped = hf_params - set(mapping.keys())
-    if unmapped:
-        unmapped_sample = sorted(list(unmapped))[:5]
-        print(f"[Setup] Unmapped HF params ({len(unmapped)} total): {unmapped_sample}...")
+    # Store fused mappings for later slicing
+    mapping['_fused_'] = fused_mappings
 
     return mapping
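The slicing arithmetic that consumes these mappings can be sanity-checked in isolation. A standalone sketch with made-up GQA dimensions (16 query heads, 4 KV heads, head_dim 64 are assumptions for illustration, not the trainer's actual config), showing both the Q/K/V offsets and the fact that the slices alias the fused buffer:

    import torch

    # Hypothetical GQA config (illustrative only)
    num_heads, num_kv_heads, head_dim, hidden = 16, 4, 64, 1024
    q_size = num_heads * head_dim      # 1024 output features for Q
    kv_size = num_kv_heads * head_dim  # 256 output features each for K and V

    # vLLM-style fused weight: Q, K, V stacked along dim 0 (output features)
    qkv = torch.randn(q_size + 2 * kv_size, hidden)

    # Slices in the order the mapping records them: Q (idx 0), K (idx 1), V (idx 2)
    q = qkv[:q_size]
    k = qkv[q_size:q_size + kv_size]
    v = qkv[q_size + kv_size:q_size + 2 * kv_size]

    assert q.shape == (q_size, hidden) and k.shape == v.shape == (kv_size, hidden)

    # The slices are views: writing through one mutates the fused buffer,
    # which is what lets the trainer update vLLM's weights in place.
    q.zero_()
    assert qkv[:q_size].abs().sum().item() == 0.0
    assert q.data_ptr() == qkv.data_ptr()  # Q starts at the fused tensor's base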