diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index 0bce6d39..f9e874ac 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -499,7 +499,23 @@ def _attach_to_vllm_shared_tensors( # Get parameter names from the empty model param_names = list(model.state_dict().keys()) - print(f"[Setup] Model architecture has {len(param_names)} parameters") + print(f"[Setup] Model architecture has {len(param_names)} parameters", flush=True) + + # Initialize CUDA before IPC operations + # Get the device indices we'll be using + device_indices = set() + for name, info in ipc_handles.items(): + if "device_index" in info: + device_indices.add(info["device_index"]) + + print(f"[Setup] IPC handles span devices: {sorted(device_indices)}", flush=True) + + # Initialize CUDA context on each device + for dev_idx in sorted(device_indices): + print(f"[Setup] Initializing CUDA on device {dev_idx}...", flush=True) + torch.cuda.set_device(dev_idx) + torch.cuda.synchronize(dev_idx) + print(f"[Setup] ✓ Device {dev_idx} ready", flush=True) # Map vLLM tensor names to HuggingFace model parameter names hf_state_dict = {} @@ -519,6 +535,13 @@ def _attach_to_vllm_shared_tensors( print(f"[Setup] Missing ipc_handle_b64 for {hf_name}") continue + # DEBUG: Only try first tensor to see if IPC works at all + if attached_count == 0: + print(f"[Setup DEBUG] Attempting first tensor: {hf_name}", flush=True) + print(f"[Setup DEBUG] device_index: {ipc_info['device_index']}", flush=True) + print(f"[Setup DEBUG] storage_size: {ipc_info['storage_size']}", flush=True) + print(f"[Setup DEBUG] shape: {ipc_info['shape']}", flush=True) + # Decode all the bytes fields from base64 device_index = ipc_info["device_index"] ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"]) @@ -529,6 +552,10 @@ def _attach_to_vllm_shared_tensors( event_handle = base64.b64decode(ipc_info["event_handle_b64"]) event_sync_required = ipc_info["event_sync_required"] + if attached_count == 0: + print(f"[Setup DEBUG] Decoded IPC handle, len={len(ipc_handle)}", flush=True) + print(f"[Setup DEBUG] About to call _new_shared_cuda...", flush=True) + # Reconstruct the 8-tuple that _new_shared_cuda expects share_tuple = ( device_index, @@ -544,6 +571,9 @@ def _attach_to_vllm_shared_tensors( # Create storage from IPC handle (needs all 8 items) storage = torch.UntypedStorage._new_shared_cuda(*share_tuple) + if attached_count == 0: + print(f"[Setup DEBUG] Storage created! size={storage.size()}", flush=True) + # Reconstruct tensor dtype = getattr(torch, ipc_info["dtype"].replace("torch.", "")) tensor = torch.tensor([], dtype=dtype, device=f"cuda:{device_index}") @@ -554,14 +584,22 @@ def _attach_to_vllm_shared_tensors( stride=ipc_info["stride"], ) + if attached_count == 0: + print(f"[Setup DEBUG] Tensor set! shape={tensor.shape}", flush=True) + # Make tensor require gradients for training tensor.requires_grad_(True) hf_state_dict[hf_name] = tensor attached_count += 1 + if attached_count == 1: + print(f"[Setup DEBUG] ✓ First tensor attached successfully!", flush=True) + except Exception as e: - print(f"[Setup] Failed to attach {hf_name}: {e}") + print(f"[Setup] Failed to attach {hf_name}: {e}", flush=True) + import traceback + traceback.print_exc() continue if attached_count == 0: