mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
memory enhancements
This commit is contained in:
parent
75c4f5c853
commit
52bd4cb624
1 changed files with 118 additions and 17 deletions
|
|
@ -16,6 +16,122 @@ import torch
|
|||
from torch.optim import AdamW
|
||||
|
||||
from .api import check_atropos_api, register_trainer
|
||||
|
||||
|
||||
class CPUOffloadAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
AdamW with optimizer states offloaded to CPU.
|
||||
|
||||
Full precision (no quantization), but states stay on CPU RAM instead of GPU.
|
||||
Trade-off: Slower (~2x) but uses ~0GB GPU memory for optimizer states.
|
||||
"""
|
||||
def __init__(self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
def _init_state(self, p):
|
||||
"""Lazily initialize state on CPU."""
|
||||
state = self.state[p]
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Store on CPU in FP32
|
||||
state['exp_avg'] = torch.zeros_like(p, device='cpu', dtype=torch.float32)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, device='cpu', dtype=torch.float32)
|
||||
return state
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
beta1, beta2 = group['betas']
|
||||
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
|
||||
grad = p.grad
|
||||
state = self._init_state(p)
|
||||
|
||||
state['step'] += 1
|
||||
|
||||
# Move states to GPU for computation
|
||||
exp_avg = state['exp_avg'].to(p.device)
|
||||
exp_avg_sq = state['exp_avg_sq'].to(p.device)
|
||||
|
||||
# AdamW update
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||
|
||||
# Bias correction
|
||||
bias_correction1 = 1 - beta1 ** state['step']
|
||||
bias_correction2 = 1 - beta2 ** state['step']
|
||||
step_size = group['lr'] / bias_correction1
|
||||
|
||||
# Update weights
|
||||
denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(group['eps'])
|
||||
p.addcdiv_(exp_avg, denom, value=-step_size)
|
||||
|
||||
# Weight decay
|
||||
if group['weight_decay'] != 0:
|
||||
p.add_(p, alpha=-group['lr'] * group['weight_decay'])
|
||||
|
||||
# Move states back to CPU (non-blocking for better perf)
|
||||
state['exp_avg'].copy_(exp_avg.cpu())
|
||||
state['exp_avg_sq'].copy_(exp_avg_sq.cpu())
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
|
||||
"""
|
||||
Create optimizer based on config.optimizer setting.
|
||||
|
||||
Options:
|
||||
- 'adamw': Standard AdamW (full precision, ~32GB GPU for 4B model)
|
||||
- 'adamw_8bit': 8-bit AdamW from bitsandbytes (~8GB GPU, requires bitsandbytes)
|
||||
- 'adamw_cpu': AdamW with CPU offload (~0GB GPU, slower but full precision)
|
||||
- 'adafactor': Adafactor without momentum (~8GB GPU, no extra dependencies)
|
||||
"""
|
||||
if config.optimizer == "adamw_8bit":
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr)
|
||||
print(f"[Setup] Using 8-bit AdamW (saves ~24GB optimizer memory)")
|
||||
return optimizer
|
||||
except ImportError:
|
||||
print("[Setup] WARNING: bitsandbytes not installed, falling back to AdamW")
|
||||
print("[Setup] Install with: pip install bitsandbytes")
|
||||
|
||||
if config.optimizer == "adamw_cpu":
|
||||
optimizer = CPUOffloadAdamW(model.parameters(), lr=config.lr)
|
||||
print(f"[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)")
|
||||
print(f"[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization")
|
||||
return optimizer
|
||||
|
||||
if config.optimizer == "adafactor":
|
||||
try:
|
||||
from transformers.optimization import Adafactor
|
||||
optimizer = Adafactor(
|
||||
model.parameters(),
|
||||
lr=config.lr,
|
||||
scale_parameter=False,
|
||||
relative_step=False,
|
||||
)
|
||||
print(f"[Setup] Using Adafactor (no momentum, saves ~24GB)")
|
||||
return optimizer
|
||||
except ImportError:
|
||||
print("[Setup] WARNING: transformers Adafactor not available, using AdamW")
|
||||
|
||||
# Default: standard AdamW
|
||||
optimizer = AdamW(model.parameters(), lr=config.lr)
|
||||
print(f"[Setup] Using standard AdamW (requires ~32GB for optimizer states)")
|
||||
return optimizer
|
||||
|
||||
|
||||
from .checkpointing import save_checkpoint, save_lora_checkpoint
|
||||
from .config import TrainingConfig
|
||||
from .data import get_data
|
||||
|
|
@ -53,14 +169,7 @@ def train_legacy(config: TrainingConfig):
|
|||
# === Setup ===
|
||||
use_wandb = setup_wandb(config)
|
||||
model, tokenizer = load_model_and_tokenizer(config)
|
||||
|
||||
# Use 8-bit Adam to save ~16GB of optimizer state memory
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr)
|
||||
print("[Setup] Using 8-bit AdamW (saves ~16GB memory)")
|
||||
except ImportError:
|
||||
optimizer = AdamW(model.parameters(), lr=config.lr)
|
||||
optimizer = create_optimizer(model, config)
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("LEGACY MODE (checkpoint + vLLM restart)")
|
||||
|
|
@ -190,15 +299,7 @@ def train_shared_vllm(config: TrainingConfig):
|
|||
"3. vllm_bridge_config.json exists with IPC handles"
|
||||
)
|
||||
|
||||
# Use 8-bit Adam to save ~16GB of optimizer state memory
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr)
|
||||
print("[Setup] Using 8-bit AdamW (saves ~16GB memory)")
|
||||
except ImportError:
|
||||
print("[Setup] bitsandbytes not installed, using standard AdamW")
|
||||
print("[Setup] TIP: Install with 'pip install bitsandbytes' to save ~16GB memory")
|
||||
optimizer = AdamW(model.parameters(), lr=config.lr)
|
||||
optimizer = create_optimizer(model, config)
|
||||
|
||||
# === Real-time weight sharing verification ===
|
||||
print("\n[Weight Sharing Verification]")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue