diff --git a/example_trainer/README.md b/example_trainer/README.md
index 3f1b9ba7..6836621e 100644
--- a/example_trainer/README.md
+++ b/example_trainer/README.md
@@ -330,7 +330,6 @@ The trainer supports multiple optimizer options to trade off between speed, memo
 | `adamw` | ~32GB (for 8B model) | Fastest | Full FP32 | None |
 | `adamw_8bit` (default) | ~8GB | Fast | 8-bit quantized | `bitsandbytes` |
 | `adafactor` | ~8GB | Fast | Full (no momentum) | `transformers` |
-| `adamw_cpu` | ~0GB (on CPU) | ~2x slower | Full FP32 | None |
 
 **Usage:**
 ```bash
@@ -342,21 +341,16 @@ The trainer supports multiple optimizer options to trade off between speed, memo
 
 # Adafactor - no momentum states, good for large models
 --optimizer adafactor
-
-# CPU offload - experimental, use when nothing else fits
---optimizer adamw_cpu
 ```
 
 **Recommendations:**
 - **8B models on 80GB:** Use `adamw` (fastest)
 - **14B+ models on 80GB:** Use `adamw_8bit` or `adafactor`
 - **24B models:** Use `adafactor` with reduced batch size
-- **adamw_cpu:** Experimental - not well tested, ~2x slower due to CPU↔GPU transfers
 
 **Potential Risks:**
 - `adamw_8bit`: Quantization may slightly affect convergence in edge cases; generally safe
 - `adafactor`: No momentum can make training slightly less stable; use with larger batch sizes
-- `adamw_cpu`: Significantly slower; only use when you have no other option
 
 ---
 
diff --git a/example_trainer/cli.py b/example_trainer/cli.py
index 847e4202..a8fd009e 100644
--- a/example_trainer/cli.py
+++ b/example_trainer/cli.py
@@ -65,10 +65,10 @@ def add_training_args(parser: argparse.ArgumentParser) -> None:
     group.add_argument(
         "--optimizer",
         type=str,
-        choices=["adamw", "adamw_8bit", "adamw_cpu", "adafactor"],
+        choices=["adamw", "adamw_8bit", "adafactor"],
         default="adamw_8bit",
         help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), "
-        "'adamw_cpu' (CPU offload), 'adafactor' (no momentum)",
+        "'adafactor' (no momentum)",
     )
     group.add_argument(
         "--device",
diff --git a/example_trainer/config.py b/example_trainer/config.py
index c43524aa..61e6d802 100644
--- a/example_trainer/config.py
+++ b/example_trainer/config.py
@@ -32,11 +32,10 @@ class TrainingConfig(BaseModel):
     gradient_accumulation_steps: int = Field(
         32, description="Number of gradient accumulation steps"
     )
-    optimizer: Literal["adamw", "adamw_8bit", "adamw_cpu", "adafactor"] = Field(
+    optimizer: Literal["adamw", "adamw_8bit", "adafactor"] = Field(
         "adamw_8bit",
         description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), "
         "'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), "
-        "'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
         "'adafactor' (no momentum, ~8GB GPU)",
     )
 
diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py
index 9f04a811..fdc4e5d4 100644
--- a/example_trainer/trainers.py
+++ b/example_trainer/trainers.py
@@ -20,85 +20,13 @@ from torch.optim import AdamW
 from .api import check_atropos_api, register_trainer
 
 
-class CPUOffloadAdamW(torch.optim.Optimizer):
-    """
-    AdamW with optimizer states offloaded to CPU.
-
-    Full precision (no quantization), but states stay on CPU RAM instead of GPU.
-    Trade-off: Slower (~2x) but uses ~0GB GPU memory for optimizer states.
-    """
-
-    def __init__(
-        self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
-    ):
-        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
-        super().__init__(params, defaults)
-
-    def _init_state(self, p):
-        """Lazily initialize state on CPU."""
-        state = self.state[p]
-        if len(state) == 0:
-            state["step"] = 0
-            # Store on CPU in FP32
-            state["exp_avg"] = torch.zeros_like(p, device="cpu", dtype=torch.float32)
-            state["exp_avg_sq"] = torch.zeros_like(p, device="cpu", dtype=torch.float32)
-        return state
-
-    @torch.no_grad()
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        for group in self.param_groups:
-            beta1, beta2 = group["betas"]
-
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-
-                grad = p.grad
-                state = self._init_state(p)
-
-                state["step"] += 1
-
-                # Move states to GPU for computation
-                exp_avg = state["exp_avg"].to(p.device)
-                exp_avg_sq = state["exp_avg_sq"].to(p.device)
-
-                # AdamW update
-                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
-                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-
-                # Bias correction
-                bias_correction1 = 1 - beta1 ** state["step"]
-                bias_correction2 = 1 - beta2 ** state["step"]
-                step_size = group["lr"] / bias_correction1
-
-                # Update weights
-                denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group["eps"])
-                p.addcdiv_(exp_avg, denom, value=-step_size)
-
-                # Weight decay
-                if group["weight_decay"] != 0:
-                    p.add_(p, alpha=-group["lr"] * group["weight_decay"])
-
-                # Move states back to CPU (non-blocking for better perf)
-                state["exp_avg"].copy_(exp_avg.cpu())
-                state["exp_avg_sq"].copy_(exp_avg_sq.cpu())
-
-        return loss
-
-
 def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
     """
     Create optimizer based on config.optimizer setting.
 
     Options:
-    - 'adamw': Standard AdamW (full precision, ~32GB GPU for 4B model)
+    - 'adamw': Standard AdamW (full precision, ~32GB GPU for 8B model)
     - 'adamw_8bit': 8-bit AdamW from bitsandbytes (~8GB GPU, requires bitsandbytes)
-    - 'adamw_cpu': AdamW with CPU offload (~0GB GPU, slower but full precision)
     - 'adafactor': Adafactor without momentum (~8GB GPU, no extra dependencies)
     """
     if config.optimizer == "adamw_8bit":
@@ -112,16 +40,6 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
             print("[Setup] WARNING: bitsandbytes not installed, falling back to AdamW")
             print("[Setup] Install with: pip install bitsandbytes")
 
-    if config.optimizer == "adamw_cpu":
-        optimizer = CPUOffloadAdamW(model.parameters(), lr=config.lr)
-        print(
-            "[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)"
-        )
-        print(
-            "[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization"
-        )
-        return optimizer
-
     if config.optimizer == "adafactor":
         try:
             from transformers.optimization import Adafactor