mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

parent 33505fe981
commit 11f495a381
19 changed files with 708 additions and 452 deletions
@@ -21,75 +21,78 @@ from .api import check_atropos_api, register_trainer
 class CPUOffloadAdamW(torch.optim.Optimizer):
     """
     AdamW with optimizer states offloaded to CPU.

     Full precision (no quantization), but states stay on CPU RAM instead of GPU.
     Trade-off: Slower (~2x) but uses ~0GB GPU memory for optimizer states.
     """
-    def __init__(self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
+
+    def __init__(
+        self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
+    ):
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super().__init__(params, defaults)

     def _init_state(self, p):
         """Lazily initialize state on CPU."""
         state = self.state[p]
         if len(state) == 0:
-            state['step'] = 0
+            state["step"] = 0
             # Store on CPU in FP32
-            state['exp_avg'] = torch.zeros_like(p, device='cpu', dtype=torch.float32)
-            state['exp_avg_sq'] = torch.zeros_like(p, device='cpu', dtype=torch.float32)
+            state["exp_avg"] = torch.zeros_like(p, device="cpu", dtype=torch.float32)
+            state["exp_avg_sq"] = torch.zeros_like(p, device="cpu", dtype=torch.float32)
         return state

     @torch.no_grad()
     def step(self, closure=None):
         loss = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()

         for group in self.param_groups:
-            beta1, beta2 = group['betas']
-
-            for p in group['params']:
+            beta1, beta2 = group["betas"]
+
+            for p in group["params"]:
                 if p.grad is None:
                     continue

                 grad = p.grad
                 state = self._init_state(p)
-                state['step'] += 1
+                state["step"] += 1

                 # Move states to GPU for computation
-                exp_avg = state['exp_avg'].to(p.device)
-                exp_avg_sq = state['exp_avg_sq'].to(p.device)
+                exp_avg = state["exp_avg"].to(p.device)
+                exp_avg_sq = state["exp_avg_sq"].to(p.device)

                 # AdamW update
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                 # Bias correction
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
-                step_size = group['lr'] / bias_correction1
+                bias_correction1 = 1 - beta1 ** state["step"]
+                bias_correction2 = 1 - beta2 ** state["step"]
+                step_size = group["lr"] / bias_correction1

                 # Update weights
-                denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(group['eps'])
+                denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group["eps"])
                 p.addcdiv_(exp_avg, denom, value=-step_size)

                 # Weight decay
-                if group['weight_decay'] != 0:
-                    p.add_(p, alpha=-group['lr'] * group['weight_decay'])
+                if group["weight_decay"] != 0:
+                    p.add_(p, alpha=-group["lr"] * group["weight_decay"])

                 # Move states back to CPU (non-blocking for better perf)
-                state['exp_avg'].copy_(exp_avg.cpu())
-                state['exp_avg_sq'].copy_(exp_avg_sq.cpu())
+                state["exp_avg"].copy_(exp_avg.cpu())
+                state["exp_avg_sq"].copy_(exp_avg_sq.cpu())

         return loss


 def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
     """
     Create optimizer based on config.optimizer setting.

     Options:
     - 'adamw': Standard AdamW (full precision, ~32GB GPU for 4B model)
     - 'adamw_8bit': 8-bit AdamW from bitsandbytes (~8GB GPU, requires bitsandbytes)
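For orientation, here is a minimal usage sketch of the CPUOffloadAdamW class defined in this hunk. The toy linear model and synthetic loss below are illustrative only, not part of the commit:

```python
import torch

# Toy model and loss purely for illustration; any nn.Module works.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(1024, 1024).to(device)
opt = CPUOffloadAdamW(model.parameters(), lr=1e-5, weight_decay=0.01)

x = torch.randn(8, 1024, device=device)
loss = model(x).pow(2).mean()
loss.backward()
opt.step()       # moment buffers live on CPU; copied to GPU only inside step()
opt.zero_grad()
```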
@@ -99,22 +102,28 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
     if config.optimizer == "adamw_8bit":
         try:
             import bitsandbytes as bnb

             optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr)
             print("[Setup] Using 8-bit AdamW (saves ~24GB optimizer memory)")
             return optimizer
         except ImportError:
             print("[Setup] WARNING: bitsandbytes not installed, falling back to AdamW")
             print("[Setup] Install with: pip install bitsandbytes")

     if config.optimizer == "adamw_cpu":
         optimizer = CPUOffloadAdamW(model.parameters(), lr=config.lr)
-        print("[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)")
-        print("[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization")
+        print(
+            "[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)"
+        )
+        print(
+            "[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization"
+        )
         return optimizer

     if config.optimizer == "adafactor":
         try:
             from transformers.optimization import Adafactor

             optimizer = Adafactor(
                 model.parameters(),
                 lr=config.lr,
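The dispatch above keys off a string field on the config. A hedged sketch of how a caller might exercise the adamw_cpu branch; SimpleNamespace stands in for the repo's real TrainingConfig, whose fields may differ:

```python
from types import SimpleNamespace

# Stand-in for example_trainer's TrainingConfig; field names here are assumptions.
config = SimpleNamespace(optimizer="adamw_cpu", lr=1e-5)
optimizer = create_optimizer(model, config)  # -> CPUOffloadAdamW, per the branch above
```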
@@ -125,7 +134,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
             return optimizer
         except ImportError:
             print("[Setup] WARNING: transformers Adafactor not available, using AdamW")

     # Default: standard AdamW
     optimizer = AdamW(model.parameters(), lr=config.lr)
     print("[Setup] Using standard AdamW (requires ~32GB for optimizer states)")
@@ -135,7 +144,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
 from .checkpointing import save_checkpoint, save_lora_checkpoint  # noqa: E402
 from .config import TrainingConfig  # noqa: E402
 from .data import get_data  # noqa: E402
-from .model import load_model_and_tokenizer, PEFT_AVAILABLE  # noqa: E402
+from .model import PEFT_AVAILABLE, load_model_and_tokenizer  # noqa: E402
 from .training import (  # noqa: E402
     finalize_training,
     log_metrics,
@@ -146,8 +155,8 @@ from .vllm_manager import (  # noqa: E402
     check_vllm_health,
     check_vllm_process_health,
     launch_vllm_server,
-    terminate_vllm_process,
     set_vllm_process,
+    terminate_vllm_process,
 )


@@ -171,13 +180,13 @@ def train_legacy(config: TrainingConfig):
     model, tokenizer = load_model_and_tokenizer(config)
     optimizer = create_optimizer(model, config)

-    print("\n" + "="*60)
+    print("\n" + "=" * 60)
     print("LEGACY MODE (checkpoint + vLLM restart)")
-    print("="*60)
+    print("=" * 60)
     print(f"Training for {config.training_steps} steps on {config.device}")
     print(f"vLLM restart interval: every {config.vllm_restart_interval} steps")
     print(f"Save path: {config.save_path}")
-    print("="*60 + "\n")
+    print("=" * 60 + "\n")

     os.makedirs(config.save_path, exist_ok=True)

@@ -206,24 +215,36 @@ def train_legacy(config: TrainingConfig):
         # Fetch data (with inference logprobs for proper GRPO)
         data_fetch_start = time.time()
         if len(batches) == 0:
-            batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
-                                  extract_inference_logprobs=True)
+            batches, _ = get_data(
+                config.batch_size,
+                config.seq_len,
+                config.atropos_url,
+                extract_inference_logprobs=True,
+            )
         batch_data = batches.pop(0)
-        token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
+        token_batches, label_batches, advantage_batches, temperature_batches = (
+            batch_data[:4]
+        )
         inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
         data_fetch_time = time.time() - data_fetch_start
         benchmark_stats["data_fetch_times"].append(data_fetch_time)

         # Check if we should sync (save checkpoint + restart vLLM)
-        should_sync = (step + 1) % config.vllm_restart_interval == 0 or step == config.training_steps - 1
+        should_sync = (
+            step + 1
+        ) % config.vllm_restart_interval == 0 or step == config.training_steps - 1
         if should_sync:
             terminate_vllm_process()

         # Training step (with proper GRPO using inference logprobs)
         step_start = time.time()
         metrics = run_training_step(
-            model, optimizer,
-            token_batches, label_batches, advantage_batches, temperature_batches,
+            model,
+            optimizer,
+            token_batches,
+            label_batches,
+            advantage_batches,
+            temperature_batches,
             config,
             inference_logprob_batches=inference_logprob_batches,
         )
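This hunk threads inference-time logprobs from vLLM into run_training_step so the trainer can form an importance ratio between the current policy and the policy that sampled the rollouts. The repo's actual loss lives in run_training_step (not shown in this diff); the following is only a generic GRPO/PPO-style clipped-ratio sketch of what such logprobs enable:

```python
import torch

def clipped_ratio_loss(trainer_logprobs, inference_logprobs, advantages, clip_eps=0.2):
    """Generic PPO/GRPO-style surrogate; illustrative, not the repo's run_training_step."""
    # ratio = pi_theta(token) / pi_inference(token), computed in log space
    ratio = torch.exp(trainer_logprobs - inference_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # maximize the clipped surrogate => minimize its negation
    return -torch.min(unclipped, clipped).mean()
```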
@@ -231,15 +252,21 @@ def train_legacy(config: TrainingConfig):
         benchmark_stats["step_times"].append(step_time)

         # GPU memory tracking
-        gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
-        gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
+        gpu_mem_gb = (
+            torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
+        )
+        gpu_mem_reserved_gb = (
+            torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
+        )
         benchmark_stats["gpu_memories"].append(gpu_mem_gb)

         # Sync (checkpoint + restart)
         sync_time = 0
         if should_sync:
             sync_start = time.time()
-            checkpoint_path = save_checkpoint(model, tokenizer, config.save_path, step + 1)
+            checkpoint_path = save_checkpoint(
+                model, tokenizer, config.save_path, step + 1
+            )
             torch.cuda.empty_cache()
             vllm_proc = launch_vllm_server(config, checkpoint_path)
             set_vllm_process(vllm_proc)
@@ -247,20 +274,31 @@ def train_legacy(config: TrainingConfig):
         benchmark_stats["sync_times"].append(sync_time)

         # Update metrics
-        metrics.update({
-            "step_time": step_time,
-            "sync_time": sync_time,
-            "data_fetch_time": data_fetch_time,
-            "gpu_memory_gb": gpu_mem_gb,
-            "gpu_memory_reserved_gb": gpu_mem_reserved_gb,
-        })
+        metrics.update(
+            {
+                "step_time": step_time,
+                "sync_time": sync_time,
+                "data_fetch_time": data_fetch_time,
+                "gpu_memory_gb": gpu_mem_gb,
+                "gpu_memory_reserved_gb": gpu_mem_reserved_gb,
+            }
+        )

         log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
         check_vllm_process_health()

     # === Cleanup ===
-    save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
-    finalize_training(use_wandb, training_start_time, "legacy", config.training_steps, benchmark_stats, config.benchmark)
+    save_checkpoint(
+        model, tokenizer, config.save_path, config.training_steps, is_final=True
+    )
+    finalize_training(
+        use_wandb,
+        training_start_time,
+        "legacy",
+        config.training_steps,
+        benchmark_stats,
+        config.benchmark,
+    )


 def train_shared_vllm(config: TrainingConfig):
@@ -281,13 +319,13 @@ def train_shared_vllm(config: TrainingConfig):
     # === Setup ===
     use_wandb = setup_wandb(config)

-    print("\n" + "="*60)
+    print("\n" + "=" * 60)
     print("SINGLE-COPY MODE (CUDA IPC)")
     print(">>> Trainer uses vLLM's tensors directly!")
-    print("="*60)
+    print("=" * 60)
     print(f"Model: {config.model_name}")
     print(f"Save path: {config.save_path}")
-    print("="*60 + "\n")
+    print("=" * 60 + "\n")

     # Attach to vLLM's shared tensors
     print("[1/2] Attaching to vLLM's shared tensors...")
@@ -331,11 +369,15 @@ def train_shared_vllm(config: TrainingConfig):
         data_fetch_start = time.time()
         if len(batches) == 0:
             batches, _ = get_data(
-                config.batch_size, config.seq_len, config.atropos_url,
+                config.batch_size,
+                config.seq_len,
+                config.atropos_url,
                 extract_inference_logprobs=True,  # Enable proper GRPO with reference logprobs
             )
         batch_data = batches.pop(0)
-        token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
+        token_batches, label_batches, advantage_batches, temperature_batches = (
+            batch_data[:4]
+        )
         inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
         data_fetch_time = time.time() - data_fetch_start
         benchmark_stats["data_fetch_times"].append(data_fetch_time)
@@ -343,8 +385,12 @@ def train_shared_vllm(config: TrainingConfig):
         # Training step with proper GRPO (importance sampling + KL penalty)
         step_start = time.time()
         metrics = run_training_step(
-            model, optimizer,
-            token_batches, label_batches, advantage_batches, temperature_batches,
+            model,
+            optimizer,
+            token_batches,
+            label_batches,
+            advantage_batches,
+            temperature_batches,
             config,
             inference_logprob_batches=inference_logprob_batches,  # Pass for GRPO ratio computation
         )
@@ -352,8 +398,12 @@ def train_shared_vllm(config: TrainingConfig):
         benchmark_stats["step_times"].append(step_time)

         # GPU memory tracking
-        gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
-        gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
+        gpu_mem_gb = (
+            torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
+        )
+        gpu_mem_reserved_gb = (
+            torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
+        )
         benchmark_stats["gpu_memories"].append(gpu_mem_gb)

         # In single-copy mode, weights are updated in-place (no sync needed!)
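The "no sync needed" comment rests on a property of CUDA IPC that PyTorch exposes through torch.multiprocessing: a CUDA tensor sent to another process shares the underlying GPU allocation instead of being copied, so in-place updates in one process are visible in the other. A standalone toy demonstration of that property (not repo code):

```python
import torch
import torch.multiprocessing as mp

def _writer(q):
    w = q.get()            # receives a handle to the same GPU memory, not a copy
    with torch.no_grad():
        w.add_(1.0)        # in-place update, visible to the parent process

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    weight = torch.zeros(4, device="cuda")
    q = mp.Queue()
    p = mp.Process(target=_writer, args=(q,))
    p.start()
    q.put(weight)          # shared via CUDA IPC under the hood
    p.join()
    print(weight)          # tensor([1., 1., 1., 1.], device='cuda:0')
```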
@@ -362,23 +412,37 @@ def train_shared_vllm(config: TrainingConfig):
         benchmark_stats["sync_times"].append(sync_time)

         # Update metrics
-        metrics.update({
-            "step_time": step_time,
-            "sync_time": sync_time,
-            "data_fetch_time": data_fetch_time,
-            "gpu_memory_gb": gpu_mem_gb,
-            "gpu_memory_reserved_gb": gpu_mem_reserved_gb,
-        })
+        metrics.update(
+            {
+                "step_time": step_time,
+                "sync_time": sync_time,
+                "data_fetch_time": data_fetch_time,
+                "gpu_memory_gb": gpu_mem_gb,
+                "gpu_memory_reserved_gb": gpu_mem_reserved_gb,
+            }
+        )

         log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)

         # Periodic checkpoint (for recovery, not for vLLM sync)
-        if config.checkpoint_interval > 0 and (step + 1) % config.checkpoint_interval == 0:
+        if (
+            config.checkpoint_interval > 0
+            and (step + 1) % config.checkpoint_interval == 0
+        ):
             save_checkpoint(model, tokenizer, config.save_path, step + 1)

     # === Cleanup ===
-    save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
-    finalize_training(use_wandb, training_start_time, "shared_vllm", config.training_steps, benchmark_stats, config.benchmark)
+    save_checkpoint(
+        model, tokenizer, config.save_path, config.training_steps, is_final=True
+    )
+    finalize_training(
+        use_wandb,
+        training_start_time,
+        "shared_vllm",
+        config.training_steps,
+        benchmark_stats,
+        config.benchmark,
+    )


 def train_lora(config: TrainingConfig):
@@ -399,29 +463,33 @@ def train_lora(config: TrainingConfig):
     - External vLLM server running with --enable-lora
     """
     if not PEFT_AVAILABLE:
-        raise RuntimeError("PEFT library required for LoRA mode. Install with: pip install peft")
+        raise RuntimeError(
+            "PEFT library required for LoRA mode. Install with: pip install peft"
+        )

     training_start_time = time.time()

     # === Setup ===
     use_wandb = setup_wandb(config)

-    print("\n" + "="*60)
+    print("\n" + "=" * 60)
     print("LORA MODE (adapter-only training)")
-    print("="*60)
+    print("=" * 60)
     print(f"Base model: {config.model_name}")
     print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}")
     print(f"Save path: {config.save_path}")
     print(f"vLLM port: {config.vllm_port}")
-    print("="*60 + "\n")
+    print("=" * 60 + "\n")

     # Check external vLLM server
     print("[1/3] Checking external vLLM server...")
     if not check_vllm_health(config.vllm_port):
         print(f"\nERROR: vLLM server not running on port {config.vllm_port}")
         print("\nLoRA mode requires an external vLLM server. Start it first:")
-        print(f" python example_trainer/vllm_api_server.py --model {config.model_name} "
-              f"--port {config.vllm_port} --enable-lora --enforce-eager")
+        print(
+            f" python example_trainer/vllm_api_server.py --model {config.model_name} "
+            f"--port {config.vllm_port} --enable-lora --enforce-eager"
+        )
         raise RuntimeError(f"External vLLM server required on port {config.vllm_port}")
     print(f"vLLM server healthy on port {config.vllm_port}")

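check_vllm_health comes from .vllm_manager and is not shown in this diff. A minimal stand-in sketch of the kind of probe it likely performs; vLLM's OpenAI-compatible server does expose a GET /health endpoint that returns 200 once the engine is ready:

```python
import requests

def vllm_is_healthy(port: int, host: str = "localhost", timeout: float = 2.0) -> bool:
    """Return True if the vLLM server answers its /health endpoint."""
    try:
        resp = requests.get(f"http://{host}:{port}/health", timeout=timeout)
        return resp.status_code == 200
    except requests.RequestException:
        return False
```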
@@ -459,10 +527,16 @@ def train_lora(config: TrainingConfig):
         # Fetch data (with inference logprobs for proper GRPO)
         data_fetch_start = time.time()
         if len(batches) == 0:
-            batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url,
-                                  extract_inference_logprobs=True)
+            batches, _ = get_data(
+                config.batch_size,
+                config.seq_len,
+                config.atropos_url,
+                extract_inference_logprobs=True,
+            )
         batch_data = batches.pop(0)
-        token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4]
+        token_batches, label_batches, advantage_batches, temperature_batches = (
+            batch_data[:4]
+        )
         inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
         data_fetch_time = time.time() - data_fetch_start
         benchmark_stats["data_fetch_times"].append(data_fetch_time)
@@ -470,8 +544,12 @@ def train_lora(config: TrainingConfig):
         # Training step with proper GRPO
         step_start = time.time()
         metrics = run_training_step(
-            model, optimizer,
-            token_batches, label_batches, advantage_batches, temperature_batches,
+            model,
+            optimizer,
+            token_batches,
+            label_batches,
+            advantage_batches,
+            temperature_batches,
             config,
             inference_logprob_batches=inference_logprob_batches,
         )
@@ -479,8 +557,12 @@ def train_lora(config: TrainingConfig):
         benchmark_stats["step_times"].append(step_time)

         # GPU memory tracking
-        gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
-        gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
+        gpu_mem_gb = (
+            torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
+        )
+        gpu_mem_reserved_gb = (
+            torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
+        )
         benchmark_stats["gpu_memories"].append(gpu_mem_gb)

         # Periodic adapter save + hot-swap
@@ -494,24 +576,35 @@ def train_lora(config: TrainingConfig):
         benchmark_stats["sync_times"].append(sync_time)

         # Update metrics
-        metrics.update({
-            "step_time": step_time,
-            "sync_time": sync_time,
-            "data_fetch_time": data_fetch_time,
-            "gpu_memory_gb": gpu_mem_gb,
-            "gpu_memory_reserved_gb": gpu_mem_reserved_gb,
-        })
+        metrics.update(
+            {
+                "step_time": step_time,
+                "sync_time": sync_time,
+                "data_fetch_time": data_fetch_time,
+                "gpu_memory_gb": gpu_mem_gb,
+                "gpu_memory_reserved_gb": gpu_mem_reserved_gb,
+            }
+        )

         log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)

     # === Cleanup ===
     final_sync_start = time.time()
-    final_adapter_path = save_lora_checkpoint(model, config.save_path, config.training_steps, is_final=True)
+    final_adapter_path = save_lora_checkpoint(
+        model, config.save_path, config.training_steps, is_final=True
+    )
     _hotswap_lora_adapter(config.vllm_port, final_adapter_path, "final")
     final_sync_time = time.time() - final_sync_start
     benchmark_stats["sync_times"].append(final_sync_time)

-    finalize_training(use_wandb, training_start_time, "lora_only", config.training_steps, benchmark_stats, config.benchmark)
+    finalize_training(
+        use_wandb,
+        training_start_time,
+        "lora_only",
+        config.training_steps,
+        benchmark_stats,
+        config.benchmark,
+    )

     # Save tokenizer
     tokenizer_path = os.path.join(config.save_path, "tokenizer")
@@ -563,4 +656,3 @@ def _hotswap_lora_adapter(
     except Exception as e:
         print(f" [LORA] ✗ Hot-swap request failed: {e}")
         return False
-
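The body of _hotswap_lora_adapter is mostly outside this diff. As a rough sketch of what a hot-swap against vLLM looks like: vLLM's OpenAI-compatible server supports POST /v1/load_lora_adapter and /v1/unload_lora_adapter when launched with --enable-lora and VLLM_ALLOW_RUNTIME_LORA_UPDATING=True. The helper below is a hedged approximation, not the repo's actual implementation:

```python
import requests

def hotswap_lora(port: int, adapter_path: str, adapter_name: str) -> bool:
    """Swap a freshly saved LoRA adapter into a running vLLM server."""
    base = f"http://localhost:{port}/v1"
    # Drop any previous adapter registered under the same name; ignore failures.
    requests.post(f"{base}/unload_lora_adapter", json={"lora_name": adapter_name})
    resp = requests.post(
        f"{base}/load_lora_adapter",
        json={"lora_name": adapter_name, "lora_path": adapter_path},
    )
    return resp.status_code == 200
```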