diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py
index f9e874ac..3f1b07b0 100644
--- a/example_trainer/grpo.py
+++ b/example_trainer/grpo.py
@@ -484,38 +484,34 @@ def _attach_to_vllm_shared_tensors(
     ipc_handles = deserialize_ipc_handles(ipc_handles_raw)
 
     print(f"[Setup] Attaching to vLLM's shared tensors ({len(ipc_handles)} tensors)...")
-    print("[Setup] TRUE SINGLE-COPY MODE - No additional model memory!")
+    print("[Setup] TRUE SINGLE-COPY MODE - sharing vLLM's GPU memory!")
 
-    # Load model config (not weights) to get architecture
-    from transformers import AutoConfig
-    model_config = AutoConfig.from_pretrained(config.model_name)
-
-    # Create empty model on meta device (no memory allocation)
-    with torch.device('meta'):
-        model = AutoModelForCausalLM.from_config(
-            model_config,
-            torch_dtype=torch.bfloat16,
-        )
-
-    # Get parameter names from the empty model
-    param_names = list(model.state_dict().keys())
-    print(f"[Setup] Model architecture has {len(param_names)} parameters", flush=True)
-
-    # Initialize CUDA before IPC operations
-    # Get the device indices we'll be using
+    # Get device from IPC handles
     device_indices = set()
     for name, info in ipc_handles.items():
         if "device_index" in info:
            device_indices.add(info["device_index"])
-    print(f"[Setup] IPC handles span devices: {sorted(device_indices)}", flush=True)
+    device = f"cuda:{list(device_indices)[0]}" if device_indices else "cuda:0"
+    print(f"[Setup] Target device: {device}", flush=True)
 
-    # Initialize CUDA context on each device
-    for dev_idx in sorted(device_indices):
-        print(f"[Setup] Initializing CUDA on device {dev_idx}...", flush=True)
-        torch.cuda.set_device(dev_idx)
-        torch.cuda.synchronize(dev_idx)
-        print(f"[Setup] ✓ Device {dev_idx} ready", flush=True)
+    # Initialize CUDA context
+    torch.cuda.set_device(int(device.split(':')[1]))
+    torch.cuda.synchronize()
+    print("[Setup] ✓ CUDA initialized", flush=True)
+
+    # APPROACH: Load the full model onto the target GPU, then swap in IPC tensors.
+    # Loading normally ensures buffers (like rotary_emb) are properly initialized.
+    print("[Setup] Loading model structure...", flush=True)
+    model = AutoModelForCausalLM.from_pretrained(
+        config.model_name,
+        torch_dtype=torch.bfloat16,
+        device_map=device,  # Load directly to GPU
+        low_cpu_mem_usage=True,
+    )
+
+    param_names = list(model.state_dict().keys())
+    print(f"[Setup] Model has {len(param_names)} parameters", flush=True)
 
     # Map vLLM tensor names to HuggingFace model parameter names
     hf_state_dict = {}
@@ -603,13 +599,26 @@ def _attach_to_vllm_shared_tensors(
             continue
 
     if attached_count == 0:
-        print("[Setup] Could not attach any tensors, falling back to regular loading")
-        return None
+        print("[Setup] Could not attach any tensors - IPC failed")
+        print("[Setup] Model loaded normally (not sharing memory with vLLM)")
+        return model  # Return the normally-loaded model
 
-    print(f"[Setup] ✓ Attached {attached_count} tensors to vLLM's shared memory")
+    print(f"[Setup] ✓ Swapped {attached_count} tensors to share vLLM's memory")
 
-    # Load state dict into model
-    model.load_state_dict(hf_state_dict, strict=False, assign=True)
+    # Now swap the model's parameters with the IPC tensors,
+    # so the model reads and writes vLLM's memory directly
+    swap_count = 0
+    for name, param in model.named_parameters():
+        if name in hf_state_dict:
+            ipc_tensor = hf_state_dict[name]
+            # Verify shapes match before aliasing vLLM's storage
+            if param.shape == ipc_tensor.shape:
+                param.data = ipc_tensor
+                swap_count += 1
+            else:
+                print(f"[Setup] Shape mismatch for {name}: model={param.shape}, ipc={ipc_tensor.shape}")
+
+    print(f"[Setup] ✓ {swap_count} parameters now share vLLM's GPU memory!")
 
     return model