mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00
major refactor

commit 6833d4d820, parent 119721ef3d
13 changed files with 3268 additions and 3423 deletions
example_trainer/checkpointing.py (new file, +90 lines)
@@ -0,0 +1,90 @@
"""
Checkpoint saving utilities for GRPO trainer.

Handles saving model checkpoints for different training modes:
- Full model checkpoints (legacy and shared_vllm modes)
- LoRA adapter checkpoints
"""

import os
import shutil

import torch


def save_checkpoint(
    model: torch.nn.Module,
    tokenizer,
    save_path: str,
    step: int,
    is_final: bool = False,
) -> str:
    """
    Save full model checkpoint.

    Args:
        model: Model to save
        tokenizer: Tokenizer to save
        save_path: Base directory for checkpoints
        step: Current training step
        is_final: Whether this is the final checkpoint

    Returns:
        Path where checkpoint was saved
    """
    if is_final:
        checkpoint_path = os.path.join(save_path, "final_model")
    else:
        checkpoint_path = os.path.join(save_path, f"step_{step}")

    print(f" Saving checkpoint to {checkpoint_path}...")

    if os.path.exists(checkpoint_path):
        shutil.rmtree(checkpoint_path)
    os.makedirs(checkpoint_path, exist_ok=True)

    model.save_pretrained(checkpoint_path)
    tokenizer.save_pretrained(checkpoint_path)

    print(" Checkpoint saved.")
    return checkpoint_path


def save_lora_checkpoint(
    model: torch.nn.Module,
    save_path: str,
    step: int,
    is_final: bool = False,
) -> str:
    """
    Save LoRA adapter checkpoint.

    Only saves the LoRA adapter weights, not the full model.
    This results in much smaller checkpoint files.

    Args:
        model: PEFT model with LoRA adapters
        save_path: Base directory for checkpoints
        step: Current training step
        is_final: Whether this is the final checkpoint

    Returns:
        Path where adapter was saved
    """
    if is_final:
        adapter_path = os.path.join(save_path, "final_adapter")
    else:
        adapter_path = os.path.join(save_path, f"adapter_step_{step}")

    print(f" Saving LoRA adapter to {adapter_path}...")

    if os.path.exists(adapter_path):
        shutil.rmtree(adapter_path)
    os.makedirs(adapter_path, exist_ok=True)

    # Save only the adapter weights (much smaller than full model)
    model.save_pretrained(adapter_path)

    print(" Adapter saved.")
    return adapter_path
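
A minimal usage sketch of the two new helpers (not part of this commit): it assumes the file is importable as example_trainer.checkpointing, and uses gpt2 as a hypothetical stand-in model; the real trainer would pass its own model, tokenizer, and save_path. Both functions overwrite any existing directory at the target path before writing.

# Usage sketch, assuming the module path and a gpt2 stand-in model.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

from example_trainer.checkpointing import save_checkpoint, save_lora_checkpoint

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Full checkpoint: written to checkpoints/step_100, or to
# checkpoints/final_model when is_final=True.
save_checkpoint(model, tokenizer, save_path="checkpoints", step=100)
save_checkpoint(model, tokenizer, save_path="checkpoints", step=100, is_final=True)

# LoRA checkpoint: wrap the base model with adapters, then save only the
# adapter weights to checkpoints/adapter_step_100.
lora_model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM", r=8))
save_lora_checkpoint(lora_model, save_path="checkpoints", step=100)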