diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index cca23b33..b010e551 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -454,21 +454,46 @@ def _attach_to_vllm_shared_tensors( print("[Setup] Single-copy mode not available (no IPC handles exported)") return None - ipc_handles = bridge_config.get("ipc_handles", {}) - if not ipc_handles: + ipc_handles_raw = bridge_config.get("ipc_handles", {}) + if not ipc_handles_raw: print("[Setup] No IPC handles found in bridge config") return None + # Deserialize base64-encoded bytes back to bytes + import base64 + + def deserialize_ipc_handles(handles): + result = {} + for k, v in handles.items(): + if isinstance(v, dict): + if "_bytes_b64_" in v: + result[k] = base64.b64decode(v["_bytes_b64_"]) + else: + result[k] = deserialize_ipc_handles(v) + else: + result[k] = v + return result + + ipc_handles = deserialize_ipc_handles(ipc_handles_raw) + print(f"[Setup] Attaching to vLLM's shared tensors ({len(ipc_handles)} tensors)...") print("[Setup] TRUE SINGLE-COPY MODE - No additional model memory!") - # Create model architecture (meta device - no memory allocation) + # Load model config (not weights) to get architecture + from transformers import AutoConfig + model_config = AutoConfig.from_pretrained(config.model_name) + + # Create empty model on meta device (no memory allocation) with torch.device('meta'): - model = AutoModelForCausalLM.from_pretrained( - config.model_name, + model = AutoModelForCausalLM.from_config( + model_config, torch_dtype=torch.bfloat16, ) + # Get parameter names from the empty model + param_names = list(model.state_dict().keys()) + print(f"[Setup] Model architecture has {len(param_names)} parameters") + # Map vLLM tensor names to HuggingFace model parameter names hf_state_dict = {} vllm_to_hf_mapping = _create_vllm_to_hf_mapping(model, ipc_handles) @@ -482,7 +507,16 @@ def _attach_to_vllm_shared_tensors( try: # Reconstruct tensor from IPC handle - handle_bytes = bytes.fromhex(ipc_info["handle"]) + # Handle can be bytes (deserialized from base64) or hex string + handle = ipc_info["handle"] + if isinstance(handle, bytes): + handle_bytes = handle + elif isinstance(handle, str): + handle_bytes = bytes.fromhex(handle) + else: + print(f"[Setup] Unknown handle type for {hf_name}: {type(handle)}") + continue + storage_size = ipc_info["storage_size"] device_index = ipc_info["device_index"]