diff --git a/example_trainer/README.md b/example_trainer/README.md index 582588c1..014be539 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -325,7 +325,7 @@ Only `server_type=vllm` calls the `/generate` endpoint which returns token-level The trainer includes automatic patches for NVIDIA B200 (Blackwell architecture) GPUs when using LoRA mode. These patches disable Grid Dependency Control (GDC) in vLLM's Triton kernels, which causes compilation failures on Blackwell GPUs. The patches are applied automatically when: - `VLLM_ENABLE_SHARED_WEIGHTS=1` is set, or -- LoRA mode is used +- `NUM_INFERENCE_NODES` is set (distributed inference path) The patching clears the Triton cache and disables GDC to ensure compatibility. No manual intervention required. @@ -534,7 +534,7 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands ```bash --use-wandb \ --wandb-project "my-grpo-training" \ ---wandb-run-name "hermes-8b-gsm8k" +--wandb-group "hermes-8b-gsm8k" ``` --- @@ -548,6 +548,7 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands | `--model-name` or `--model` | (required) | HuggingFace model ID | | `--weight-bridge-mode` | `none` | `shared_vllm`, `lora_only`, `lora_restart`, or `none` | | `--training-steps` | 10 | Number of training steps | +| `--checkpoint-interval` | 3 | Save checkpoint every N steps (0 = final only) | | `--batch-size` | 2 | Micro-batch size | | `--gradient-accumulation-steps` | 32 | Effective batch = batch × accum | | `--seq-len` | 2048 | Maximum sequence length | @@ -559,6 +560,7 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands | `--kl-coef` | 0.1 | KL penalty strength (higher = more conservative) | | `--clip-eps` | 0.2 | PPO clipping range [1-ε, 1+ε] | | `--lr` | 1e-5 | Learning rate (NOT --learning-rate) | +| `--no-reference-logprobs` | False | Disable GRPO reference logprobs (falls back to REINFORCE-style updates) | ### LoRA Arguments @@ -567,7 +569,7 @@ python -m 
example_trainer.vllm_api_server # NOT direct vllm commands | `--lora-r` | 16 | LoRA rank (dimension of low-rank matrices) | | `--lora-alpha` | 32 | LoRA alpha scaling factor | | `--lora-dropout` | 0.05 | LoRA dropout probability | -| `--lora-target-modules` | None | Module names to apply LoRA (default: `q_proj v_proj`) | +| `--lora-target-modules` | None | Module names to apply LoRA to (`None` falls back to `q_proj v_proj`) | ### vLLM Arguments @@ -581,6 +583,39 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands | `--dtype` | `bfloat16` | Model dtype: `bfloat16`, `float16`, or `auto` | | `--vllm-restart-interval` | 3 | Restart vLLM every N steps (legacy/lora_restart) | +### Atropos API Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--atropos-url` | `http://localhost:8000` | URL of the Atropos API server | + +**Note:** Many examples in this README use `http://localhost:8002` because they start `run-api --port 8002`. + +### Weights & Biases Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--use-wandb` | False | Enable W&B logging | +| `--wandb-project` | None | W&B project name | +| `--wandb-group` | None | W&B group name (auto-generated if omitted) | + +### Distributed Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--trainer-rank` | 0 | Trainer rank | +| `--world-size` | 1 | World size | +| `--init-method` | `env://` | Distributed init method | +| `--num-inference-nodes` | 0 | Number of inference nodes | + +### Debug & Benchmark Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--debug-loading` | False | Verbose model loading diagnostics | +| `--benchmark` | False | Print benchmark/timing metrics | +| `--log-dir` | `./logs` | Directory for unified launcher logs | + --- ## Module Documentation diff --git a/example_trainer/api.py b/example_trainer/api.py index f9073cc2..21c4288e 
100644 --- a/example_trainer/api.py +++ b/example_trainer/api.py @@ -53,6 +53,11 @@ def register_trainer(config: TrainingConfig): Verifies registration succeeded before returning. """ url = config.atropos_url + save_checkpoint_interval = ( + config.training_steps + if config.checkpoint_interval <= 0 + else config.checkpoint_interval + ) response = requests.post( f"{url}/register", json={ @@ -63,7 +68,7 @@ def register_trainer(config: TrainingConfig): "max_token_len": config.seq_len, "starting_step": 0, "checkpoint_dir": config.save_path, - "save_checkpoint_interval": config.training_steps, + "save_checkpoint_interval": save_checkpoint_interval, "num_steps": config.training_steps, }, timeout=10, diff --git a/example_trainer/data.py b/example_trainer/data.py index 1d17ffdd..bb3ebbdb 100644 --- a/example_trainer/data.py +++ b/example_trainer/data.py @@ -9,7 +9,6 @@ Also extracts inference logprobs for proper GRPO loss computation: - They are batched and padded to align token-by-token with training labels """ -import json import math import time from typing import List, Optional, Tuple @@ -180,9 +179,8 @@ def pad_data_to_good_offset( temperature_batches = [] inference_logprob_batches = [] - for i in range(len(input_ids) // batch_size): - start = i * batch_size - end = (i + 1) * batch_size + for start in range(0, len(input_ids), batch_size): + end = min(start + batch_size, len(input_ids)) token_batches.append(torch.tensor(np.stack(input_ids[start:end], axis=0))) label_batches.append(torch.tensor(np.stack(labels[start:end], axis=0))) @@ -294,10 +292,6 @@ def get_data( ) _logged_logprob_warning = True - # Save batch for debugging - with open("temp.json", "w", encoding="utf-8") as f: - json.dump(data, f) - # Process and accumulate batches (now includes batched inference logprobs) ( token_batches, diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py index fdc4e5d4..ce80adc9 100644 --- a/example_trainer/trainers.py +++ b/example_trainer/trainers.py @@ -10,8 
+10,9 @@ Contains the four main training modes: import os import subprocess +import sys import time -from typing import Optional +from typing import Iterable, Optional import requests import torch @@ -29,11 +30,20 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer: - 'adamw_8bit': 8-bit AdamW from bitsandbytes (~8GB GPU, requires bitsandbytes) - 'adafactor': Adafactor without momentum (~8GB GPU, no extra dependencies) """ + return create_optimizer_for_params(model.parameters(), config) + + +def create_optimizer_for_params( + params: Iterable[torch.nn.Parameter], config +) -> torch.optim.Optimizer: + """Create optimizer for a specific parameter iterable.""" + params = list(params) + if config.optimizer == "adamw_8bit": try: import bitsandbytes as bnb - optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr) + optimizer = bnb.optim.AdamW8bit(params, lr=config.lr) print("[Setup] Using 8-bit AdamW (saves ~24GB optimizer memory)") return optimizer except ImportError: @@ -45,7 +55,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer: from transformers.optimization import Adafactor optimizer = Adafactor( - model.parameters(), + params, lr=config.lr, scale_parameter=False, relative_step=False, @@ -56,7 +66,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer: print("[Setup] WARNING: transformers Adafactor not available, using AdamW") # Default: standard AdamW - optimizer = AdamW(model.parameters(), lr=config.lr) + optimizer = AdamW(params, lr=config.lr) print("[Setup] Using standard AdamW (requires ~32GB for optimizer states)") return optimizer @@ -419,7 +429,7 @@ def train_lora(config: TrainingConfig): # Only optimize LoRA parameters trainable_params = [p for p in model.parameters() if p.requires_grad] - optimizer = AdamW(trainable_params, lr=config.lr) + optimizer = create_optimizer_for_params(trainable_params, config) print(f"[3/3] Starting training for 
{config.training_steps} steps") print("-" * 60) @@ -627,7 +637,7 @@ def train_lora_restart(config: TrainingConfig): # Only optimize LoRA parameters trainable_params = [p for p in model.parameters() if p.requires_grad] - optimizer = AdamW(trainable_params, lr=config.lr) + optimizer = create_optimizer_for_params(trainable_params, config) os.makedirs(config.save_path, exist_ok=True) @@ -800,7 +810,7 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona # Build command - NO --enforce-eager for faster inference (~108 TPS vs ~13 TPS) cmd = [ - "python", server_script, + sys.executable, server_script, "--model", config.model_name, "--port", str(config.vllm_port), "--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization),