[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2026-02-06 06:46:14 +00:00 committed by Jai Suphavadeeprasit
parent d07ab3e3ce
commit 5cfd1929f1
19 changed files with 708 additions and 452 deletions

@@ -20,11 +20,11 @@ import torch
def _ensure_contiguous_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
    """
    Create a state dict with contiguous tensors for safe saving.

    This is critical for shared_vllm mode where parameters are views into
    vLLM's fused tensors. Views may share storage and not be contiguous,
    which can cause issues when saving.

    Returns:
        State dict with all tensors made contiguous (copied if necessary)
    """
@@ -36,14 +36,14 @@ def _ensure_contiguous_state_dict(model: torch.nn.Module) -> Dict[str, torch.Ten
            state_dict[name] = param.detach().clone().contiguous()
        else:
            state_dict[name] = param.detach()

    # Also include buffers
    for name, buffer in model.named_buffers():
        if not buffer.is_contiguous() or buffer.storage_offset() != 0:
            state_dict[name] = buffer.detach().clone().contiguous()
        else:
            state_dict[name] = buffer.detach()

    return state_dict
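
For context, a minimal standalone sketch of what the contiguity/offset check above guards against, using a toy fused tensor (the names here are illustrative, not part of the diff):

import torch

fused_qkv = torch.randn(12, 8)  # stand-in for a fused qkv weight
q_view = fused_qkv[4:8]         # parameter-style view into the fused buffer

print(q_view.storage_offset())  # 32 -> caught by the storage_offset() != 0 check
print(q_view.untyped_storage().data_ptr() == fused_qkv.untyped_storage().data_ptr())  # True: shares storage

q_safe = q_view.detach().clone().contiguous()
print(q_safe.storage_offset())  # 0 -> independent copy, safe to serialize on its own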
@@ -86,28 +86,32 @@ def save_checkpoint(
        # For shared_vllm mode: ensure views are properly unfused
        print(" [Checkpoint] Using safe mode - ensuring contiguous tensors...")
        state_dict = _ensure_contiguous_state_dict(model)

        # Count how many were non-contiguous (views into fused tensors)
        view_count = sum(
-            1 for name, param in model.named_parameters()
+            1
+            for name, param in model.named_parameters()
            if not param.is_contiguous() or param.storage_offset() != 0
        )
        if view_count > 0:
-            print(f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)")
+            print(
+                f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)"
+            )

        # Save state dict manually, then save config separately
        torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin"))
        model.config.save_pretrained(checkpoint_path)

        # CRITICAL: Clean up the copied state_dict to free ~8GB GPU memory!
        del state_dict
        import gc
        gc.collect()
        torch.cuda.empty_cache()
    else:
        # Standard save (may have issues with view tensors)
        model.save_pretrained(checkpoint_path)

    tokenizer.save_pretrained(checkpoint_path)
    print(" Checkpoint saved.")
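
A rough sketch of how a checkpoint written by the safe-mode branch could be loaded back (hypothetical model class and path; the manual torch.save above leaves a plain state dict next to the saved config):

import os

import torch
from transformers import AutoConfig, AutoModelForCausalLM

checkpoint_path = "/path/to/checkpoint"  # placeholder
config = AutoConfig.from_pretrained(checkpoint_path)
model = AutoModelForCausalLM.from_config(config)

state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"), map_location="cpu")
model.load_state_dict(state_dict, strict=False)  # strict=False tolerates non-persistent buffer keys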
@@ -151,4 +155,3 @@ def save_lora_checkpoint(
    print(" Adapter saved.")
    return adapter_path
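
If the adapter written here follows the usual PEFT layout, re-attaching it might look roughly like this (assumed library and placeholder names, not confirmed by the diff):

from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("base-model-id")  # placeholder base model id
adapter_path = "/path/to/adapter"  # e.g. the path returned by save_lora_checkpoint
model = PeftModel.from_pretrained(base, adapter_path)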