Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-29 17:35:07 +00:00)
logprob wandb

commit 947ab19a8e, parent 210726c3d9
2 changed files with 74 additions and 6 deletions
@@ -4,20 +4,56 @@ Checkpoint saving utilities for GRPO trainer.
 Handles saving model checkpoints for different training modes:
 - Full model checkpoints (legacy and shared_vllm modes)
 - LoRA adapter checkpoints
+
+IMPORTANT: For shared_vllm mode, the model parameters are VIEWS into vLLM's
+fused tensors (qkv_proj, gate_up_proj). This module handles unfusing them
+back to HuggingFace format for safe checkpoint saving.
 """
 
 import os
 import shutil
+from typing import Dict
 
 import torch
 
 
+def _ensure_contiguous_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
+    """
+    Create a state dict with contiguous tensors for safe saving.
+
+    This is critical for shared_vllm mode where parameters are views into
+    vLLM's fused tensors. Views may share storage and not be contiguous,
+    which can cause issues when saving.
+
+    Returns:
+        State dict with all tensors made contiguous (copied if necessary)
+    """
+    state_dict = {}
+    for name, param in model.named_parameters():
+        # Check if tensor is a view (non-contiguous or shares storage)
+        if not param.is_contiguous() or param.storage_offset() != 0:
+            # Make a contiguous copy - this "unfuses" the view
+            state_dict[name] = param.detach().clone().contiguous()
+        else:
+            state_dict[name] = param.detach()
+
+    # Also include buffers
+    for name, buffer in model.named_buffers():
+        if not buffer.is_contiguous() or buffer.storage_offset() != 0:
+            state_dict[name] = buffer.detach().clone().contiguous()
+        else:
+            state_dict[name] = buffer.detach()
+
+    return state_dict
+
+
 def save_checkpoint(
     model: torch.nn.Module,
     tokenizer,
     save_path: str,
     step: int,
     is_final: bool = False,
+    safe_mode: bool = True,
 ) -> str:
     """
     Save full model checkpoint.
@@ -28,6 +64,9 @@ def save_checkpoint(
         save_path: Base directory for checkpoints
         step: Current training step
         is_final: Whether this is the final checkpoint
+        safe_mode: If True, ensure all tensors are contiguous before saving.
+                   This is important for shared_vllm mode where params are
+                   views into fused vLLM tensors.
 
     Returns:
         Path where checkpoint was saved
@@ -43,7 +82,26 @@ def save_checkpoint(
         shutil.rmtree(checkpoint_path)
     os.makedirs(checkpoint_path, exist_ok=True)
 
-    model.save_pretrained(checkpoint_path)
+    if safe_mode:
+        # For shared_vllm mode: ensure views are properly unfused
+        print(" [Checkpoint] Using safe mode - ensuring contiguous tensors...")
+        state_dict = _ensure_contiguous_state_dict(model)
+
+        # Count how many were non-contiguous (views into fused tensors)
+        view_count = sum(
+            1 for name, param in model.named_parameters()
+            if not param.is_contiguous() or param.storage_offset() != 0
+        )
+        if view_count > 0:
+            print(f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)")
+
+        # Save state dict manually, then save config separately
+        torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin"))
+        model.config.save_pretrained(checkpoint_path)
+    else:
+        # Standard save (may have issues with view tensors)
+        model.save_pretrained(checkpoint_path)
+
     tokenizer.save_pretrained(checkpoint_path)
 
     print(" Checkpoint saved.")
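Why the contiguity check above tests storage_offset() as well as is_contiguous(): a row slice of a fused weight is itself contiguous, so it passes is_contiguous() while still sharing storage with the full fused tensor, and torch.save on such a view can serialize the entire shared storage. A minimal sketch of the failure mode and the unfusing fix, using a toy tensor rather than vLLM's real qkv_proj (the shapes and names here are illustrative only):

    import torch

    # Toy stand-in for a fused qkv_proj weight: three projections packed row-wise.
    fused_qkv = torch.randn(3 * 8, 8)
    q, k, v = fused_qkv[0:8], fused_qkv[8:16], fused_qkv[16:24]

    # Row slices of a contiguous tensor are still contiguous, so
    # is_contiguous() alone would miss them; storage_offset() does not.
    print(k.is_contiguous())   # True
    print(k.storage_offset())  # 64 -> k is a view into fused_qkv's storage

    # The fix used by _ensure_contiguous_state_dict: a detached contiguous
    # clone that owns its own memory.
    k_safe = k.detach().clone().contiguous()
    print(k_safe.storage_offset())                    # 0
    print(k_safe.data_ptr() == fused_qkv.data_ptr())  # False -> unfused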
@@ -90,7 +90,8 @@ def train_legacy(config: TrainingConfig):
         # Fetch data
         data_fetch_start = time.time()
         if len(batches) == 0:
-            batches = get_data(config.batch_size, config.seq_len, config.atropos_url)
+            batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
+                                  extract_inference_logprobs=False)
         token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
         data_fetch_time = time.time() - data_fetch_start
         benchmark_stats["data_fetch_times"].append(data_fetch_time)
@@ -206,26 +207,34 @@ def train_shared_vllm(config: TrainingConfig):
 
     # === Training Loop ===
     batches = []
+    inference_logprobs = None
     for step in range(config.training_steps):
         print(f"\nStep {step+1}/{config.training_steps}")
 
-        # Fetch data
+        # Fetch data (with inference logprobs for alignment check)
         data_fetch_start = time.time()
         if len(batches) == 0:
-            batches = get_data(config.batch_size, config.seq_len, config.atropos_url)
+            batches, inference_logprobs = get_data(
+                config.batch_size, config.seq_len, config.atropos_url,
+                extract_inference_logprobs=True,  # Enable logprob alignment check
+            )
         token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
         data_fetch_time = time.time() - data_fetch_start
         benchmark_stats["data_fetch_times"].append(data_fetch_time)
 
-        # Training step
+        # Training step (with logprob alignment check)
        step_start = time.time()
         metrics = run_training_step(
             model, optimizer,
             token_batches, label_batches, advantage_batches, temperature_batches,
             config,
+            inference_logprobs=inference_logprobs,  # Pass for alignment validation
         )
         step_time = time.time() - step_start
         benchmark_stats["step_times"].append(step_time)
 
+        # Clear inference logprobs after use (will be refreshed with new data)
+        inference_logprobs = None
+
         # GPU memory tracking
         gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
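The hunk above threads inference_logprobs into run_training_step, but the alignment check itself is not shown in this diff. As a rough sketch of what such a check typically computes, assuming per-token logprob tensors and a mask over generated tokens (the function name, shapes, and metric keys below are assumptions, not this repository's code); scalars like these are what the commit title suggests end up in wandb, e.g. via wandb.log(metrics, step=step):

    import torch

    def logprob_alignment_metrics(
        trainer_logprobs: torch.Tensor,    # [batch, seq], recomputed by the trainer
        inference_logprobs: torch.Tensor,  # [batch, seq], recorded at sampling time
        mask: torch.Tensor,                # [batch, seq], 1.0 on generated tokens
    ) -> dict:
        """Quantify trainer/inference logprob divergence on generated tokens."""
        diff = (trainer_logprobs - inference_logprobs) * mask
        n_tokens = mask.sum().clamp(min=1)
        return {
            "logprobs/mean_abs_diff": (diff.abs().sum() / n_tokens).item(),
            "logprobs/max_abs_diff": diff.abs().max().item(),
        }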
@@ -335,7 +344,8 @@ def train_lora(config: TrainingConfig):
         # Fetch data
         data_fetch_start = time.time()
         if len(batches) == 0:
-            batches = get_data(config.batch_size, config.seq_len, config.atropos_url)
+            batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
+                                  extract_inference_logprobs=False)
         token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
         data_fetch_time = time.time() - data_fetch_start
         benchmark_stats["data_fetch_times"].append(data_fetch_time)
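Both this hunk and the train_legacy one unpack a two-element return from get_data, so the helper evidently gained an extract_inference_logprobs flag and a second return value. A sketch of the implied signature, assuming the second element is None when extraction is disabled (get_data's actual definition is not part of this diff):

    from typing import Any, List, Optional, Tuple

    def get_data(
        batch_size: int,
        seq_len: int,
        atropos_url: str,
        extract_inference_logprobs: bool = False,
    ) -> Tuple[List[Any], Optional[Any]]:
        """Fetch rollout batches from the Atropos API.

        Returns (batches, inference_logprobs); the second element is None
        unless extract_inference_logprobs=True, as in train_shared_vllm.
        """
        ...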