mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00
ditching lora nccl
This commit is contained in:
parent
28bf3d9d60
commit
9ba6c0e7bb
7 changed files with 10 additions and 1296 deletions
|
|
@ -5,7 +5,6 @@ Contains the four main training modes:
|
|||
- train_legacy: Checkpoint-based training with vLLM restarts
|
||||
- train_shared_vllm: Single-copy mode with CUDA IPC
|
||||
- train_lora: LoRA adapter training with HTTP hot-swap
|
||||
- train_lora_nccl: LoRA adapter training with NCCL direct transfer (torchtitan-style)
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -659,286 +658,3 @@ def _hotswap_lora_adapter(
|
|||
return False
|
||||
|
||||
|
||||
def train_lora_nccl(config: TrainingConfig):
|
||||
"""
|
||||
GRPO training with LoRA adapters using NCCL direct weight transfer.
|
||||
|
||||
This mode (inspired by torchtitan):
|
||||
1. Freezes base model, trains only LoRA adapter weights
|
||||
2. Uses NCCL to broadcast weights directly to vLLM (zero disk I/O)
|
||||
3. Weight updates are immediate - no HTTP API calls
|
||||
|
||||
Benefits over train_lora():
|
||||
- Much faster weight sync (NCCL vs HTTP+disk)
|
||||
- Lower latency for on-policy training
|
||||
- No checkpoint files during training
|
||||
|
||||
Requirements:
|
||||
- External vLLM server running with NCCL receiver enabled
|
||||
- Trainer and vLLM must be in the same NCCL process group
|
||||
"""
|
||||
if not PEFT_AVAILABLE:
|
||||
raise RuntimeError(
|
||||
"PEFT library required for LoRA mode. Install with: pip install peft"
|
||||
)
|
||||
|
||||
training_start_time = time.time()
|
||||
|
||||
# === Setup ===
|
||||
use_wandb = setup_wandb(config)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("LORA NCCL MODE (torchtitan-style direct weight transfer)")
|
||||
print("=" * 60)
|
||||
print(f"Base model: {config.model_name}")
|
||||
print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}")
|
||||
print(f"Save path: {config.save_path}")
|
||||
print(f"vLLM port: {config.vllm_port}")
|
||||
print(f"NCCL init: {config.nccl_init_method}")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
# Check external vLLM server
|
||||
print("[1/5] Checking external vLLM server...")
|
||||
if not check_vllm_health(config.vllm_port):
|
||||
print(f"\nERROR: vLLM server not running on port {config.vllm_port}")
|
||||
print("\nLoRA NCCL mode requires an external vLLM server. Start it first:")
|
||||
print(
|
||||
f" python example_trainer/vllm_api_server.py "
|
||||
f"--model {config.model_name} --port {config.vllm_port} --enable-lora --enforce-eager"
|
||||
)
|
||||
raise RuntimeError(f"External vLLM server required on port {config.vllm_port}")
|
||||
print(f"vLLM server healthy on port {config.vllm_port}")
|
||||
|
||||
# Load model with LoRA adapters
|
||||
print("[2/5] Loading model with LoRA adapters...")
|
||||
model, tokenizer = load_model_and_tokenizer(config)
|
||||
|
||||
# Only optimize LoRA parameters
|
||||
trainable_params = [p for p in model.parameters() if p.requires_grad]
|
||||
optimizer = AdamW(trainable_params, lr=config.lr)
|
||||
|
||||
# Import NCCL bridge components
|
||||
from .nccl_weight_bridge import (
|
||||
NCCLBridgeConfig,
|
||||
NCCLWeightBridge,
|
||||
create_trainer_param_to_vllm_mapping,
|
||||
export_bridge_config,
|
||||
get_lora_params,
|
||||
)
|
||||
|
||||
# Pre-register params to get metadata for vLLM
|
||||
lora_params = get_lora_params(model)
|
||||
param_names = sorted(lora_params.keys())
|
||||
param_shapes = {name: list(p.shape) for name, p in lora_params.items()}
|
||||
param_dtypes = {name: str(p.dtype) for name, p in lora_params.items()}
|
||||
|
||||
param_metadata = {
|
||||
"param_names": param_names,
|
||||
"param_shapes": param_shapes,
|
||||
"param_dtypes": param_dtypes,
|
||||
"num_params": len(param_names),
|
||||
}
|
||||
|
||||
param_mappings = create_trainer_param_to_vllm_mapping(
|
||||
param_names,
|
||||
model_name=config.model_name
|
||||
)
|
||||
|
||||
# Tell vLLM to start its NCCL receiver FIRST (it will join as rank 1)
|
||||
print("[3/5] Starting NCCL receiver on vLLM server...")
|
||||
vllm_base_url = f"http://localhost:{config.vllm_port}"
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{vllm_base_url}/nccl/start_receiver",
|
||||
json={
|
||||
"init_method": config.nccl_init_method,
|
||||
"world_size": config.nccl_world_size,
|
||||
"param_metadata": param_metadata,
|
||||
"param_mappings": param_mappings,
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
resp_data = response.json()
|
||||
if response.status_code != 200 or resp_data.get("status") == "error":
|
||||
raise RuntimeError(f"Failed to start NCCL receiver on vLLM: {resp_data}")
|
||||
print(f" vLLM NCCL receiver started: {resp_data}")
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise RuntimeError(f"Failed to contact vLLM server: {e}")
|
||||
|
||||
# Wait for vLLM to be in "connecting" state
|
||||
import time as time_module
|
||||
print(" Waiting for vLLM NCCL receiver to initialize...")
|
||||
for i in range(10):
|
||||
time_module.sleep(1)
|
||||
try:
|
||||
status_resp = requests.get(f"{vllm_base_url}/nccl/status", timeout=5)
|
||||
status = status_resp.json()
|
||||
print(f" vLLM NCCL status: {status.get('status', 'unknown')}")
|
||||
if status.get("status") == "error":
|
||||
raise RuntimeError(f"vLLM NCCL setup failed: {status.get('error')}")
|
||||
if status.get("status") in ["connecting", "connected"]:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f" Status check error: {e}")
|
||||
|
||||
# Now setup trainer's NCCL bridge (joins as rank 0)
|
||||
print("[4/5] Setting up trainer NCCL weight bridge...")
|
||||
nccl_config = NCCLBridgeConfig(
|
||||
rank=0, # Trainer is always rank 0
|
||||
world_size=config.nccl_world_size,
|
||||
init_method=config.nccl_init_method,
|
||||
)
|
||||
|
||||
bridge = NCCLWeightBridge(nccl_config)
|
||||
if not bridge.setup():
|
||||
# Try to stop vLLM receiver on failure
|
||||
try:
|
||||
requests.post(f"{vllm_base_url}/nccl/stop_receiver", timeout=5)
|
||||
except Exception:
|
||||
pass
|
||||
raise RuntimeError("Failed to setup NCCL bridge")
|
||||
|
||||
# Register parameters with the bridge (we already have the metadata)
|
||||
bridge.param_names = param_names
|
||||
bridge.param_shapes = {name: tuple(shape) for name, shape in param_shapes.items()}
|
||||
bridge.param_dtypes = param_dtypes
|
||||
|
||||
# Export config for debugging/recovery
|
||||
bridge_config_path = os.path.join(config.save_path, "nccl_bridge_config.json")
|
||||
os.makedirs(config.save_path, exist_ok=True)
|
||||
export_bridge_config(
|
||||
bridge_config_path,
|
||||
param_metadata,
|
||||
param_mappings,
|
||||
config.nccl_init_method,
|
||||
config.nccl_world_size,
|
||||
)
|
||||
|
||||
print(f"[5/5] Starting training for {config.training_steps} steps")
|
||||
print("-" * 60)
|
||||
|
||||
# Check Atropos API
|
||||
if not check_atropos_api(url=config.atropos_url, timeout=30):
|
||||
raise RuntimeError(f"Atropos API not reachable at {config.atropos_url}")
|
||||
register_trainer(config)
|
||||
|
||||
# === Benchmark tracking ===
|
||||
benchmark_stats = {
|
||||
"step_times": [],
|
||||
"sync_times": [],
|
||||
"data_fetch_times": [],
|
||||
"gpu_memories": [],
|
||||
}
|
||||
|
||||
# Send initial weights to vLLM
|
||||
print("Sending initial LoRA weights to vLLM...")
|
||||
initial_sync_time = bridge.send_lora_weights(model, step=0)
|
||||
print(f" Initial sync completed in {initial_sync_time:.3f}s")
|
||||
|
||||
# === Training Loop ===
|
||||
batches = []
|
||||
for step in range(config.training_steps):
|
||||
print(f"\nStep {step+1}/{config.training_steps}")
|
||||
|
||||
# Fetch data (with inference logprobs for proper GRPO)
|
||||
data_fetch_start = time.time()
|
||||
if len(batches) == 0:
|
||||
batches, _ = get_data(
|
||||
config.batch_size,
|
||||
config.seq_len,
|
||||
config.atropos_url,
|
||||
extract_inference_logprobs=True,
|
||||
)
|
||||
batch_data = batches.pop(0)
|
||||
token_batches, label_batches, advantage_batches, temperature_batches = (
|
||||
batch_data[:4]
|
||||
)
|
||||
inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
|
||||
data_fetch_time = time.time() - data_fetch_start
|
||||
benchmark_stats["data_fetch_times"].append(data_fetch_time)
|
||||
|
||||
# Training step with proper GRPO
|
||||
step_start = time.time()
|
||||
metrics = run_training_step(
|
||||
model,
|
||||
optimizer,
|
||||
token_batches,
|
||||
label_batches,
|
||||
advantage_batches,
|
||||
temperature_batches,
|
||||
config,
|
||||
inference_logprob_batches=inference_logprob_batches,
|
||||
)
|
||||
step_time = time.time() - step_start
|
||||
benchmark_stats["step_times"].append(step_time)
|
||||
|
||||
# GPU memory tracking
|
||||
gpu_mem_gb = (
|
||||
torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
|
||||
)
|
||||
gpu_mem_reserved_gb = (
|
||||
torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
|
||||
)
|
||||
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
|
||||
|
||||
# NCCL weight sync (every step for on-policy, or periodic)
|
||||
sync_time = 0
|
||||
should_sync = (
|
||||
config.nccl_sync_every_step or
|
||||
(step + 1) % config.vllm_restart_interval == 0
|
||||
)
|
||||
if should_sync:
|
||||
sync_start = time.time()
|
||||
bridge.send_lora_weights(model, step=step + 1)
|
||||
sync_time = time.time() - sync_start
|
||||
benchmark_stats["sync_times"].append(sync_time)
|
||||
print(f" [NCCL] Weights synced in {sync_time:.3f}s")
|
||||
|
||||
# Update metrics
|
||||
metrics.update(
|
||||
{
|
||||
"step_time": step_time,
|
||||
"sync_time": sync_time,
|
||||
"data_fetch_time": data_fetch_time,
|
||||
"gpu_memory_gb": gpu_mem_gb,
|
||||
"gpu_memory_reserved_gb": gpu_mem_reserved_gb,
|
||||
}
|
||||
)
|
||||
|
||||
log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)
|
||||
|
||||
# Periodic checkpoint (for recovery only, not for vLLM sync)
|
||||
if (
|
||||
config.checkpoint_interval > 0
|
||||
and (step + 1) % config.checkpoint_interval == 0
|
||||
):
|
||||
save_lora_checkpoint(model, config.save_path, step + 1)
|
||||
|
||||
# === Cleanup ===
|
||||
# Final sync
|
||||
print("\nSending final weights...")
|
||||
final_sync_time = bridge.send_lora_weights(model, step=config.training_steps)
|
||||
benchmark_stats["sync_times"].append(final_sync_time)
|
||||
|
||||
# Save final checkpoint
|
||||
final_adapter_path = save_lora_checkpoint(
|
||||
model, config.save_path, config.training_steps, is_final=True
|
||||
)
|
||||
|
||||
# Cleanup bridge
|
||||
bridge.cleanup()
|
||||
|
||||
finalize_training(
|
||||
use_wandb,
|
||||
training_start_time,
|
||||
"lora_nccl",
|
||||
config.training_steps,
|
||||
benchmark_stats,
|
||||
config.benchmark,
|
||||
)
|
||||
|
||||
# Save tokenizer
|
||||
tokenizer_path = os.path.join(config.save_path, "tokenizer")
|
||||
tokenizer.save_pretrained(tokenizer_path)
|
||||
print(f"Tokenizer saved to {tokenizer_path}")
|
||||
print(f"Final adapter saved to {final_adapter_path}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue