diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index fe7119d5..9dc6b23f 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -631,10 +631,23 @@ def _attach_to_vllm_shared_tensors( if buffer_count > 0: print(f"[Setup] ✓ Moved {buffer_count} buffers to {device}") - # Also move any remaining CPU parameters (shouldn't be many) + # Check for unmatched parameters still on CPU + # These indicate a name mapping issue - don't try to move large ones + unmatched_cpu = [] for name, param in model.named_parameters(): if param.device.type == 'cpu': - param.data = param.to(device) + unmatched_cpu.append((name, param.numel())) + + if unmatched_cpu: + total_unmatched = sum(n for _, n in unmatched_cpu) + print(f"[Setup] WARNING: {len(unmatched_cpu)} parameters ({total_unmatched:,} elements) not matched!") + print(f"[Setup] First few unmatched: {[n for n, _ in unmatched_cpu[:5]]}") + print(f"[Setup] This may cause issues - check name mapping between vLLM and HuggingFace") + + # Only move small parameters (< 1M elements) - skip large unmatched ones + for name, param in model.named_parameters(): + if param.device.type == 'cpu' and param.numel() < 1_000_000: + param.data = param.to(device) return model @@ -649,6 +662,14 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dic hf_params = set(model.state_dict().keys()) vllm_params = set(ipc_handles.keys()) + print(f"[Setup] HF model has {len(hf_params)} params, vLLM exported {len(vllm_params)} tensors") + + # Debug: show sample names + hf_sample = sorted(list(hf_params))[:3] + vllm_sample = sorted(list(vllm_params))[:3] + print(f"[Setup] Sample HF names: {hf_sample}") + print(f"[Setup] Sample vLLM names: {vllm_sample}") + mapping = {} for hf_name in hf_params: @@ -669,6 +690,20 @@ def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dic vllm_name = hf_name[6:] if vllm_name in vllm_params: mapping[hf_name] = vllm_name + continue + + # For lm_head, check if it's tied to embed_tokens + if hf_name == "lm_head.weight" and "model.embed_tokens.weight" in vllm_params: + mapping[hf_name] = "model.embed_tokens.weight" + continue + + print(f"[Setup] Mapped {len(mapping)} HF params to vLLM tensors") + + # Show what's NOT mapped + unmapped = hf_params - set(mapping.keys()) + if unmapped: + unmapped_sample = sorted(list(unmapped))[:5] + print(f"[Setup] Unmapped HF params ({len(unmapped)} total): {unmapped_sample}...") return mapping