cleanup 3

Jai Suphavadeeprasit 2026-02-13 12:39:37 -05:00
parent 39d307b440
commit 43cc71e070
4 changed files with 4 additions and 93 deletions


@@ -20,85 +20,13 @@ from torch.optim import AdamW
from .api import check_atropos_api, register_trainer
class CPUOffloadAdamW(torch.optim.Optimizer):
    """
    AdamW with optimizer states offloaded to CPU.
    Full precision (no quantization), but states stay on CPU RAM instead of GPU.
    Trade-off: Slower (~2x) but uses ~0GB GPU memory for optimizer states.
    """

    def __init__(
        self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
    ):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def _init_state(self, p):
        """Lazily initialize state on CPU."""
        state = self.state[p]
        if len(state) == 0:
            state["step"] = 0
            # Store on CPU in FP32
            state["exp_avg"] = torch.zeros_like(p, device="cpu", dtype=torch.float32)
            state["exp_avg_sq"] = torch.zeros_like(p, device="cpu", dtype=torch.float32)
        return state

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self._init_state(p)
                state["step"] += 1
                # Move states to GPU for computation
                exp_avg = state["exp_avg"].to(p.device)
                exp_avg_sq = state["exp_avg_sq"].to(p.device)
                # AdamW update
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # Bias correction
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = group["lr"] / bias_correction1
                # Update weights
                denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group["eps"])
                p.addcdiv_(exp_avg, denom, value=-step_size)
                # Weight decay
                if group["weight_decay"] != 0:
                    p.add_(p, alpha=-group["lr"] * group["weight_decay"])
                # Move states back to CPU (non-blocking for better perf)
                state["exp_avg"].copy_(exp_avg.cpu())
                state["exp_avg_sq"].copy_(exp_avg_sq.cpu())
        return loss


def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
    """
    Create optimizer based on config.optimizer setting.
    Options:
    - 'adamw': Standard AdamW (full precision, ~32GB GPU for 4B model)
    - 'adamw': Standard AdamW (full precision, ~32GB GPU for 8B model)
    - 'adamw_8bit': 8-bit AdamW from bitsandbytes (~8GB GPU, requires bitsandbytes)
    - 'adamw_cpu': AdamW with CPU offload (~0GB GPU, slower but full precision)
    - 'adafactor': Adafactor without momentum (~8GB GPU, no extra dependencies)
    """
    if config.optimizer == "adamw_8bit":
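
For context on what this cleanup removes: CPUOffloadAdamW behaved as a drop-in torch optimizer whose FP32 moment tensors live in CPU RAM and are shuttled to the GPU only inside step(). A minimal usage sketch of the removed "adamw_cpu" path (model, dataloader, and compute_loss are hypothetical placeholders, not code from this repo):

    # Hypothetical illustration of the removed CPU-offload optimizer
    optimizer = CPUOffloadAdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
    for batch in dataloader:
        loss = compute_loss(model, batch)  # forward pass on GPU
        loss.backward()                    # gradients stay on GPU
        optimizer.step()                   # moments copied CPU -> GPU -> CPU here
        optimizer.zero_grad()
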
@@ -112,16 +40,6 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
            print("[Setup] WARNING: bitsandbytes not installed, falling back to AdamW")
            print("[Setup] Install with: pip install bitsandbytes")

    if config.optimizer == "adamw_cpu":
        optimizer = CPUOffloadAdamW(model.parameters(), lr=config.lr)
        print(
            "[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)"
        )
        print(
            "[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization"
        )
        return optimizer

    if config.optimizer == "adafactor":
        try:
            from transformers.optimization import Adafactor
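
The diff view is truncated just after the Adafactor import. For reference only, a sketch of how such a branch is commonly completed with transformers' Adafactor in its no-momentum, fixed-learning-rate configuration; the arguments below are assumptions, not the repo's actual code:

    # Sketch only: plausible continuation of the 'adafactor' branch
    optimizer = Adafactor(
        model.parameters(),
        lr=config.lr,
        scale_parameter=False,  # use the fixed lr from config
        relative_step=False,    # disable Adafactor's internal lr schedule
        warmup_init=False,
        beta1=None,             # no momentum -> no exp_avg state to store
    )
    print("[Setup] Using Adafactor (no momentum state, low optimizer memory)")
    return optimizer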