logprob wandb

This commit is contained in:
Jai Suphavadeeprasit 2026-01-27 13:25:37 -05:00
parent 210726c3d9
commit 947ab19a8e
2 changed files with 74 additions and 6 deletions

View file

@ -4,20 +4,56 @@ Checkpoint saving utilities for GRPO trainer.
Handles saving model checkpoints for different training modes:
- Full model checkpoints (legacy and shared_vllm modes)
- LoRA adapter checkpoints
IMPORTANT: For shared_vllm mode, the model parameters are VIEWS into vLLM's
fused tensors (qkv_proj, gate_up_proj). This module handles unfusing them
back to HuggingFace format for safe checkpoint saving.
"""
import os
import shutil
from typing import Dict
import torch
def _ensure_contiguous_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
"""
Create a state dict with contiguous tensors for safe saving.
This is critical for shared_vllm mode where parameters are views into
vLLM's fused tensors. Views may share storage and not be contiguous,
which can cause issues when saving.
Returns:
State dict with all tensors made contiguous (copied if necessary)
"""
state_dict = {}
for name, param in model.named_parameters():
# Check if tensor is a view (non-contiguous or shares storage)
if not param.is_contiguous() or param.storage_offset() != 0:
# Make a contiguous copy - this "unfuses" the view
state_dict[name] = param.detach().clone().contiguous()
else:
state_dict[name] = param.detach()
# Also include buffers
for name, buffer in model.named_buffers():
if not buffer.is_contiguous() or buffer.storage_offset() != 0:
state_dict[name] = buffer.detach().clone().contiguous()
else:
state_dict[name] = buffer.detach()
return state_dict
def save_checkpoint(
model: torch.nn.Module,
tokenizer,
save_path: str,
step: int,
is_final: bool = False,
safe_mode: bool = True,
) -> str:
"""
Save full model checkpoint.
@ -28,6 +64,9 @@ def save_checkpoint(
save_path: Base directory for checkpoints
step: Current training step
is_final: Whether this is the final checkpoint
safe_mode: If True, ensure all tensors are contiguous before saving.
This is important for shared_vllm mode where params are
views into fused vLLM tensors.
Returns:
Path where checkpoint was saved
@ -43,7 +82,26 @@ def save_checkpoint(
shutil.rmtree(checkpoint_path)
os.makedirs(checkpoint_path, exist_ok=True)
model.save_pretrained(checkpoint_path)
if safe_mode:
# For shared_vllm mode: ensure views are properly unfused
print(" [Checkpoint] Using safe mode - ensuring contiguous tensors...")
state_dict = _ensure_contiguous_state_dict(model)
# Count how many were non-contiguous (views into fused tensors)
view_count = sum(
1 for name, param in model.named_parameters()
if not param.is_contiguous() or param.storage_offset() != 0
)
if view_count > 0:
print(f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)")
# Save state dict manually, then save config separately
torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin"))
model.config.save_pretrained(checkpoint_path)
else:
# Standard save (may have issues with view tensors)
model.save_pretrained(checkpoint_path)
tokenizer.save_pretrained(checkpoint_path)
print(" Checkpoint saved.")