diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py
index 8dd0c5fd..d0cc0dac 100644
--- a/example_trainer/trainers.py
+++ b/example_trainer/trainers.py
@@ -16,6 +16,137 @@
 import torch
 from torch.optim import AdamW
 from .api import check_atropos_api, register_trainer
+
+
+class CPUOffloadAdamW(torch.optim.Optimizer):
+    """
+    AdamW with optimizer states offloaded to CPU.
+
+    Full precision (no quantization), but states stay in CPU RAM instead of GPU memory.
+    Trade-off: slower (~2x), but uses ~0GB of GPU memory for optimizer states.
+    """
+    def __init__(self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+    def _init_state(self, p):
+        """Lazily initialize state on CPU."""
+        state = self.state[p]
+        if len(state) == 0:
+            state['step'] = 0
+            # Store moment estimates on CPU in FP32
+            state['exp_avg'] = torch.zeros_like(p, device='cpu', dtype=torch.float32)
+            state['exp_avg_sq'] = torch.zeros_like(p, device='cpu', dtype=torch.float32)
+        return state
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                state = self._init_state(p)
+
+                state['step'] += 1
+
+                # Move states to GPU for the update computation
+                exp_avg = state['exp_avg'].to(p.device)
+                exp_avg_sq = state['exp_avg_sq'].to(p.device)
+
+                # AdamW moment updates
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+                # Bias correction
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
+                step_size = group['lr'] / bias_correction1
+
+                # Update weights
+                denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(group['eps'])
+                p.addcdiv_(exp_avg, denom, value=-step_size)
+
+                # Decoupled weight decay: p -= lr * wd * p
+                if group['weight_decay'] != 0:
+                    p.add_(p, alpha=-group['lr'] * group['weight_decay'])
+
+                # Copy updated states back to their CPU buffers (a blocking copy)
+                state['exp_avg'].copy_(exp_avg)
+                state['exp_avg_sq'].copy_(exp_avg_sq)
+
+        return loss
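+
+
+# Usage sketch (illustrative only: the model, shapes, and config value below
+# are hypothetical, not part of this trainer). Optimizer states live in CPU
+# RAM, so GPU memory holds only parameters, gradients, and activations:
+#
+#     model = torch.nn.Linear(4096, 4096).cuda()
+#     opt = CPUOffloadAdamW(model.parameters(), lr=1e-5)
+#     loss = model(torch.randn(8, 4096, device="cuda")).sum()
+#     loss.backward()
+#     opt.step()
+#     opt.zero_grad(set_to_none=True)
+#
+# The same optimizer can be selected via create_optimizer() below by setting
+# config.optimizer = "adamw_cpu".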
+
+
+def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
+    """
+    Create an optimizer based on the config.optimizer setting.
+
+    Options:
+    - 'adamw': Standard AdamW (full precision, ~32GB GPU for a 4B model)
+    - 'adamw_8bit': 8-bit AdamW from bitsandbytes (~8GB GPU, requires bitsandbytes)
+    - 'adamw_cpu': AdamW with CPU offload (~0GB GPU, slower but full precision)
+    - 'adafactor': Adafactor without momentum (~8GB GPU, no extra dependencies)
+    """
+    if config.optimizer == "adamw_8bit":
+        try:
+            import bitsandbytes as bnb
+            optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr)
+            print("[Setup] Using 8-bit AdamW (saves ~24GB of optimizer memory)")
+            return optimizer
+        except ImportError:
+            print("[Setup] WARNING: bitsandbytes not installed, falling back to AdamW")
+            print("[Setup] Install with: pip install bitsandbytes")
+
+    if config.optimizer == "adamw_cpu":
+        optimizer = CPUOffloadAdamW(model.parameters(), lr=config.lr)
+        print("[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)")
+        print("[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization")
+        return optimizer
+
+    if config.optimizer == "adafactor":
+        try:
+            from transformers.optimization import Adafactor
+            optimizer = Adafactor(
+                model.parameters(),
+                lr=config.lr,
+                scale_parameter=False,
+                relative_step=False,
+            )
+            print("[Setup] Using Adafactor (no momentum, saves ~24GB)")
+            return optimizer
+        except ImportError:
+            print("[Setup] WARNING: transformers Adafactor not available, using AdamW")
+
+    # Default: standard AdamW
+    optimizer = AdamW(model.parameters(), lr=config.lr)
+    print("[Setup] Using standard AdamW (requires ~32GB for optimizer states)")
+    return optimizer
+
+
 from .checkpointing import save_checkpoint, save_lora_checkpoint
 from .config import TrainingConfig
 from .data import get_data
@@ -53,14 +184,7 @@ def train_legacy(config: TrainingConfig):
     # === Setup ===
     use_wandb = setup_wandb(config)
     model, tokenizer = load_model_and_tokenizer(config)
-
-    # Use 8-bit Adam to save ~16GB of optimizer state memory
-    try:
-        import bitsandbytes as bnb
-        optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr)
-        print("[Setup] Using 8-bit AdamW (saves ~16GB memory)")
-    except ImportError:
-        optimizer = AdamW(model.parameters(), lr=config.lr)
+    optimizer = create_optimizer(model, config)
 
     print(f"\n{'='*60}")
     print("LEGACY MODE (checkpoint + vLLM restart)")
@@ -190,15 +314,7 @@ def train_shared_vllm(config: TrainingConfig):
             "3. vllm_bridge_config.json exists with IPC handles"
         )
 
-    # Use 8-bit Adam to save ~16GB of optimizer state memory
-    try:
-        import bitsandbytes as bnb
-        optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr)
-        print("[Setup] Using 8-bit AdamW (saves ~16GB memory)")
-    except ImportError:
-        print("[Setup] bitsandbytes not installed, using standard AdamW")
-        print("[Setup] TIP: Install with 'pip install bitsandbytes' to save ~16GB memory")
-        optimizer = AdamW(model.parameters(), lr=config.lr)
+    optimizer = create_optimizer(model, config)
 
     # === Real-time weight sharing verification ===
     print("\n[Weight Sharing Verification]")