Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-26 17:13:09 +00:00)

readme fixes

Commit 16ac332880 (parent 396491ab72)
4 changed files with 63 additions and 19 deletions
@@ -10,8 +10,9 @@ Contains the four main training modes:
 import os
 import subprocess
+import sys
 import time
-from typing import Optional
+from typing import Iterable, Optional
 
 import requests
 import torch
 
@@ -29,11 +30,20 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
     - 'adamw_8bit': 8-bit AdamW from bitsandbytes (~8GB GPU, requires bitsandbytes)
     - 'adafactor': Adafactor without momentum (~8GB GPU, no extra dependencies)
     """
+    return create_optimizer_for_params(model.parameters(), config)
+
+
+def create_optimizer_for_params(
+    params: Iterable[torch.nn.Parameter], config
+) -> torch.optim.Optimizer:
+    """Create optimizer for a specific parameter iterable."""
+    params = list(params)
+
     if config.optimizer == "adamw_8bit":
         try:
             import bitsandbytes as bnb
 
-            optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr)
+            optimizer = bnb.optim.AdamW8bit(params, lr=config.lr)
             print("[Setup] Using 8-bit AdamW (saves ~24GB optimizer memory)")
             return optimizer
         except ImportError:
@@ -45,7 +55,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
             from transformers.optimization import Adafactor
 
             optimizer = Adafactor(
-                model.parameters(),
+                params,
                 lr=config.lr,
                 scale_parameter=False,
                 relative_step=False,
@@ -56,7 +66,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
             print("[Setup] WARNING: transformers Adafactor not available, using AdamW")
 
     # Default: standard AdamW
-    optimizer = AdamW(model.parameters(), lr=config.lr)
+    optimizer = AdamW(params, lr=config.lr)
     print("[Setup] Using standard AdamW (requires ~32GB for optimizer states)")
     return optimizer
 
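The hunks above route every optimizer branch through the params argument instead of model.parameters(), so the same config.optimizer choices can now be applied to an arbitrary parameter subset. A minimal usage sketch, assuming the helper is importable from the training module; the import path and the stub config below are illustrative, and the only config fields actually read are optimizer and lr:

from dataclasses import dataclass

import torch

from trainer import create_optimizer_for_params  # hypothetical import path


@dataclass
class StubConfig:  # illustrative stand-in for the real TrainingConfig
    optimizer: str = "adamw_8bit"  # also accepts "adafactor"; anything else uses plain AdamW
    lr: float = 1e-5


model = torch.nn.Linear(16, 16)

# Full-model path: equivalent to create_optimizer(model, config).
opt_full = create_optimizer_for_params(model.parameters(), StubConfig())

# LoRA-style path: only the parameters that still require gradients.
trainable = [p for p in model.parameters() if p.requires_grad]
opt_lora = create_optimizer_for_params(trainable, StubConfig(optimizer="adafactor"))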
@@ -419,7 +429,7 @@ def train_lora(config: TrainingConfig):
 
     # Only optimize LoRA parameters
     trainable_params = [p for p in model.parameters() if p.requires_grad]
-    optimizer = AdamW(trainable_params, lr=config.lr)
+    optimizer = create_optimizer_for_params(trainable_params, config)
 
     print(f"[3/3] Starting training for {config.training_steps} steps")
     print("-" * 60)
@@ -627,7 +637,7 @@ def train_lora_restart(config: TrainingConfig):
 
     # Only optimize LoRA parameters
     trainable_params = [p for p in model.parameters() if p.requires_grad]
-    optimizer = AdamW(trainable_params, lr=config.lr)
+    optimizer = create_optimizer_for_params(trainable_params, config)
 
     os.makedirs(config.save_path, exist_ok=True)
 
@@ -800,7 +810,7 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
 
     # Build command - NO --enforce-eager for faster inference (~108 TPS vs ~13 TPS)
     cmd = [
-        "python", server_script,
+        sys.executable, server_script,
         "--model", config.model_name,
         "--port", str(config.vllm_port),
         "--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization),
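The last hunk also stops hard-coding "python" when spawning the vLLM server: sys.executable launches the child process with the same interpreter (and virtual environment) that is running the trainer. A minimal sketch of that pattern, with a placeholder script name and arguments rather than the real TrainingConfig-driven command:

import subprocess
import sys

# Use the current interpreter, not whatever "python" resolves to on PATH.
cmd = [sys.executable, "vllm_server.py", "--port", "8000"]
proc = subprocess.Popen(cmd)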