readme fixes

This commit is contained in:
Jai Suphavadeeprasit 2026-02-17 13:44:48 -05:00
parent 366ea72384
commit fae3f5b09e
4 changed files with 63 additions and 19 deletions

View file

@ -325,7 +325,7 @@ Only `server_type=vllm` calls the `/generate` endpoint which returns token-level
The trainer includes automatic patches for NVIDIA B200 (Blackwell architecture) GPUs when using LoRA mode. These patches disable Grid Dependency Control (GDC) in vLLM's Triton kernels, which causes compilation failures on Blackwell GPUs. The patches are applied automatically when:
- `VLLM_ENABLE_SHARED_WEIGHTS=1` is set, or
- LoRA mode is used, or
- `NUM_INFERENCE_NODES` is set (distributed inference path)
The patching clears the Triton cache and disables GDC to ensure compatibility. No manual intervention is required.
@ -534,7 +534,7 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
```bash
--use-wandb \
--wandb-project "my-grpo-training" \
--wandb-run-name "hermes-8b-gsm8k"
--wandb-group "hermes-8b-gsm8k"
```
---
@ -548,6 +548,7 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
| `--model-name` or `--model` | (required) | HuggingFace model ID |
| `--weight-bridge-mode` | `none` | `shared_vllm`, `lora_only`, `lora_restart`, or `none` |
| `--training-steps` | 10 | Number of training steps |
| `--checkpoint-interval` | 3 | Save checkpoint every N steps (0 = final only) |
| `--batch-size` | 2 | Micro-batch size |
| `--gradient-accumulation-steps` | 32 | Effective batch = batch × accum |
| `--seq-len` | 2048 | Maximum sequence length |
@ -559,6 +560,7 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
| `--kl-coef` | 0.1 | KL penalty strength (higher = more conservative) |
| `--clip-eps` | 0.2 | PPO clipping range [1-ε, 1+ε] |
| `--lr` | 1e-5 | Learning rate (NOT --learning-rate) |
| `--no-reference-logprobs` | False | Disable GRPO reference logprobs (falls back to REINFORCE-style updates) |
### LoRA Arguments
@ -567,7 +569,7 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
| `--lora-r` | 16 | LoRA rank (dimension of low-rank matrices) |
| `--lora-alpha` | 32 | LoRA alpha scaling factor |
| `--lora-dropout` | 0.05 | LoRA dropout probability |
| `--lora-target-modules` | None | Module names to apply LoRA (default: `q_proj v_proj`) |
| `--lora-target-modules` | None | Module names to apply LoRA (`None` falls back to `q_proj v_proj`) |
### vLLM Arguments
@ -581,6 +583,39 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands
| `--dtype` | `bfloat16` | Model dtype: `bfloat16`, `float16`, or `auto` |
| `--vllm-restart-interval` | 3 | Restart vLLM every N steps (legacy/lora_restart) |
### Atropos API Arguments
| Argument | Default | Description |
|----------|---------|-------------|
| `--atropos-url` | `http://localhost:8000` | URL of the Atropos API server |
**Note:** Many examples in this README use `http://localhost:8002` because they start `run-api --port 8002`.
### Weights & Biases Arguments
| Argument | Default | Description |
|----------|---------|-------------|
| `--use-wandb` | False | Enable W&B logging |
| `--wandb-project` | None | W&B project name |
| `--wandb-group` | None | W&B group name (auto-generated if omitted) |
### Distributed Arguments
| Argument | Default | Description |
|----------|---------|-------------|
| `--trainer-rank` | 0 | Trainer rank |
| `--world-size` | 1 | World size |
| `--init-method` | `env://` | Distributed init method |
| `--num-inference-nodes` | 0 | Number of inference nodes |
### Debug & Benchmark Arguments
| Argument | Default | Description |
|----------|---------|-------------|
| `--debug-loading` | False | Verbose model loading diagnostics |
| `--benchmark` | False | Print benchmark/timing metrics |
| `--log-dir` | `./logs` | Directory for unified launcher logs |
---
## Module Documentation

View file

@ -53,6 +53,11 @@ def register_trainer(config: TrainingConfig):
Verifies registration succeeded before returning.
"""
url = config.atropos_url
save_checkpoint_interval = (
config.training_steps
if config.checkpoint_interval <= 0
else config.checkpoint_interval
)
response = requests.post(
f"{url}/register",
json={
@ -63,7 +68,7 @@ def register_trainer(config: TrainingConfig):
"max_token_len": config.seq_len,
"starting_step": 0,
"checkpoint_dir": config.save_path,
"save_checkpoint_interval": config.training_steps,
"save_checkpoint_interval": save_checkpoint_interval,
"num_steps": config.training_steps,
},
timeout=10,

View file

@ -9,7 +9,6 @@ Also extracts inference logprobs for proper GRPO loss computation:
- They are batched and padded to align token-by-token with training labels
"""
import json
import math
import time
from typing import List, Optional, Tuple
@ -180,9 +179,8 @@ def pad_data_to_good_offset(
temperature_batches = []
inference_logprob_batches = []
for i in range(len(input_ids) // batch_size):
start = i * batch_size
end = (i + 1) * batch_size
for start in range(0, len(input_ids), batch_size):
end = min(start + batch_size, len(input_ids))
token_batches.append(torch.tensor(np.stack(input_ids[start:end], axis=0)))
label_batches.append(torch.tensor(np.stack(labels[start:end], axis=0)))
@ -294,10 +292,6 @@ def get_data(
)
_logged_logprob_warning = True
# Save batch for debugging
with open("temp.json", "w", encoding="utf-8") as f:
json.dump(data, f)
# Process and accumulate batches (now includes batched inference logprobs)
(
token_batches,

View file

@ -10,8 +10,9 @@ Contains the four main training modes:
import os
import subprocess
import sys
import time
from typing import Optional
from typing import Iterable, Optional
import requests
import torch
@ -29,11 +30,20 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
- 'adamw_8bit': 8-bit AdamW from bitsandbytes (~8GB GPU, requires bitsandbytes)
- 'adafactor': Adafactor without momentum (~8GB GPU, no extra dependencies)
"""
return create_optimizer_for_params(model.parameters(), config)
def create_optimizer_for_params(
params: Iterable[torch.nn.Parameter], config
) -> torch.optim.Optimizer:
"""Create optimizer for a specific parameter iterable."""
params = list(params)
if config.optimizer == "adamw_8bit":
try:
import bitsandbytes as bnb
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr)
optimizer = bnb.optim.AdamW8bit(params, lr=config.lr)
print("[Setup] Using 8-bit AdamW (saves ~24GB optimizer memory)")
return optimizer
except ImportError:
@ -45,7 +55,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
from transformers.optimization import Adafactor
optimizer = Adafactor(
model.parameters(),
params,
lr=config.lr,
scale_parameter=False,
relative_step=False,
@ -56,7 +66,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer:
print("[Setup] WARNING: transformers Adafactor not available, using AdamW")
# Default: standard AdamW
optimizer = AdamW(model.parameters(), lr=config.lr)
optimizer = AdamW(params, lr=config.lr)
print("[Setup] Using standard AdamW (requires ~32GB for optimizer states)")
return optimizer
@ -419,7 +429,7 @@ def train_lora(config: TrainingConfig):
# Only optimize LoRA parameters
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=config.lr)
optimizer = create_optimizer_for_params(trainable_params, config)
print(f"[3/3] Starting training for {config.training_steps} steps")
print("-" * 60)
@ -627,7 +637,7 @@ def train_lora_restart(config: TrainingConfig):
# Only optimize LoRA parameters
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=config.lr)
optimizer = create_optimizer_for_params(trainable_params, config)
os.makedirs(config.save_path, exist_ok=True)
@ -800,7 +810,7 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
# Build command - NO --enforce-eager for faster inference (~108 TPS vs ~13 TPS)
cmd = [
"python", server_script,
sys.executable, server_script,
"--model", config.model_name,
"--port", str(config.vllm_port),
"--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization),