manual testing

This commit is contained in:
Jai Suphavadeeprasit 2026-02-02 15:40:24 -05:00
parent da046d3d3b
commit c1bb4f33f0
5 changed files with 329 additions and 766 deletions

View file

@ -70,6 +70,38 @@ def parse_args() -> argparse.Namespace:
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
"'adafactor' (no momentum, ~8GB GPU)",
)
# === GRPO/PPO Hyperparameters ===
parser.add_argument(
"--kl-coef",
type=float,
default=0.1,
help=(
"KL divergence penalty coefficient (beta). "
"Controls policy deviation from reference. "
"Higher = more conservative, prevents reward hacking. "
"0 = disabled (not recommended)."
),
)
parser.add_argument(
"--clip-eps",
type=float,
default=0.2,
help=(
"PPO-style clipping epsilon. "
"Clips importance ratio to [1-eps, 1+eps]. "
"Prevents destabilizing large policy updates."
),
)
parser.add_argument(
"--no-reference-logprobs",
action="store_true",
help=(
"Disable use of inference logprobs as reference policy. "
"Falls back to REINFORCE-style updates (not recommended)."
),
)
parser.add_argument(
"--device",
type=str,
@ -265,6 +297,11 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
device=args.device,
save_path=args.save_path,
checkpoint_interval=getattr(args, "checkpoint_interval", 3),
# GRPO/PPO hyperparameters
kl_coef=getattr(args, "kl_coef", 0.1),
clip_eps=getattr(args, "clip_eps", 0.2),
use_reference_logprobs=not getattr(args, "no_reference_logprobs", False),
# vLLM settings
vllm_restart_interval=args.vllm_restart_interval,
vllm_port=args.vllm_port,
vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization,

View file

@ -40,6 +40,33 @@ class TrainingConfig(BaseModel):
"'adafactor' (no momentum, ~8GB GPU)"
)
# === GRPO/PPO Hyperparameters ===
kl_coef: float = Field(
0.1,
description=(
"KL divergence penalty coefficient (beta). "
"Controls how much the policy can deviate from the reference (inference-time) policy. "
"Higher values = more conservative updates, prevents reward hacking. "
"Set to 0 to disable KL penalty (not recommended)."
),
)
clip_eps: float = Field(
0.2,
description=(
"PPO-style clipping epsilon. "
"Clips the importance sampling ratio to [1-eps, 1+eps]. "
"Prevents large policy updates that could destabilize training."
),
)
use_reference_logprobs: bool = Field(
True,
description=(
"Whether to use inference logprobs as the reference policy (π_old). "
"When True, implements proper GRPO with importance sampling. "
"When False, falls back to REINFORCE-style updates (not recommended)."
),
)
# === Device & Storage ===
device: str = Field(
"cuda" if torch.cuda.is_available() else "cpu",

View file

@ -4,7 +4,9 @@ Data processing utilities for GRPO trainer.
Handles data retrieval from Atropos API, padding, batching,
and advantage normalization.
Also extracts inference logprobs for alignment validation with training logprobs.
Also extracts inference logprobs for proper GRPO loss computation:
- Inference logprobs serve as π_old (reference policy) for importance sampling
- They are batched and padded to align token-by-token with training labels
"""
import json
@ -23,11 +25,11 @@ def pad_data_to_good_offset(
batch_size: int,
extract_inference_logprobs: bool = True,
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[torch.Tensor],
List[torch.Tensor],
Optional[List[np.ndarray]],
List[torch.Tensor], # token_batches
List[torch.Tensor], # label_batches
List[torch.Tensor], # advantage_batches
List[torch.Tensor], # temperature_batches
Optional[List[torch.Tensor]], # inference_logprob_batches (aligned with labels)
]:
"""
Pad and batch data from the Atropos API.
@ -36,7 +38,7 @@ def pad_data_to_good_offset(
- Pads token sequences to nearest multiple of 64
- Normalizes advantage scores
- Extracts temperature values
- Optionally extracts inference logprobs for alignment validation
- Extracts and pads inference logprobs for proper GRPO loss computation
Args:
data: Raw batch data from Atropos API
@ -44,8 +46,12 @@ def pad_data_to_good_offset(
extract_inference_logprobs: Whether to extract inference logprobs
Returns:
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs)
inference_logprobs is None if extract_inference_logprobs=False or no logprobs in data
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data
Note:
inference_logprob_batches are padded with 0.0 at positions where labels == -100.
This allows token-by-token alignment during GRPO loss computation.
"""
max_token_len = max(
[max([len(x) for x in item["tokens"]]) for item in data["batch"]]
@ -66,7 +72,8 @@ def pad_data_to_good_offset(
advantages = []
lengths = []
temperatures = []
inference_logprobs_list: List[np.ndarray] = []
inference_logprobs_padded: List[np.ndarray] = [] # Padded to match labels shape
has_any_logprobs = False
for item in data["batch"]:
# Normalize advantage scores
@ -84,15 +91,16 @@ def pad_data_to_good_offset(
# Process each sample in the item
for i in range(len(item["tokens"])):
seq_len = len(item["tokens"][i])
lengths.append(
math.ceil((len(item["tokens"][i]) - 1) / good_multiple) * good_multiple
math.ceil((seq_len - 1) / good_multiple) * good_multiple
)
# Create labels with padding
# Create labels with padding (-100 for masked positions)
label_item = np.concatenate([
np.array(item["masks"][i]),
np.full(
max(0, token_setup_len - len(item["tokens"][i])),
max(0, token_setup_len - seq_len),
-100,
dtype=np.int32,
),
@ -102,7 +110,7 @@ def pad_data_to_good_offset(
item["tokens"][i] = np.concatenate([
np.array(item["tokens"][i]),
np.zeros(
max(0, token_setup_len - len(item["tokens"][i])),
max(0, token_setup_len - seq_len),
dtype=np.int32,
),
])
@ -111,13 +119,36 @@ def pad_data_to_good_offset(
labels.append(label_item[1:]) # Shift by 1 for causal
advantages.append(item["scores"][i])
# Extract inference logprobs for alignment validation
# These come from vLLM during rollout generation
# Extract and pad inference logprobs to match labels shape
# Inference logprobs are ONLY for generated tokens (where labels != -100)
# We need to create a padded array that aligns position-by-position
if extract_inference_logprobs and "inference_logprobs" in item:
if i < len(item["inference_logprobs"]):
inference_logprobs_list.append(
np.array(item["inference_logprobs"][i], dtype=np.float32)
)
raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32)
has_any_logprobs = True
# Create padded logprobs array matching label_item shape
# Fill with 0.0 (will be masked out during loss computation)
padded_logprobs = np.zeros(token_setup_len, dtype=np.float32)
# The inference logprobs correspond to generated tokens
# Find positions where labels != -100 (generated positions)
mask_arr = np.array(item["masks"][i])
generated_positions = np.where(mask_arr != -100)[0]
# Fill in inference logprobs at generated positions
n_to_fill = min(len(raw_logprobs), len(generated_positions))
if n_to_fill > 0:
padded_logprobs[generated_positions[:n_to_fill]] = raw_logprobs[:n_to_fill]
# Shift by 1 to match causal label shift
inference_logprobs_padded.append(padded_logprobs[1:])
else:
# No logprobs for this sample, use zeros
inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32))
elif extract_inference_logprobs:
# No inference_logprobs in item, use zeros
inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32))
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
t = 1.0
@ -139,6 +170,7 @@ def pad_data_to_good_offset(
label_batches = []
advantage_batches = []
temperature_batches = []
inference_logprob_batches = []
for i in range(len(input_ids) // batch_size):
start = i * batch_size
@ -158,11 +190,17 @@ def pad_data_to_good_offset(
np.array(temperatures[start:end], dtype=np.float32)
).view(-1, 1, 1)
)
# Batch inference logprobs (same shape as labels)
if extract_inference_logprobs and inference_logprobs_padded:
inference_logprob_batches.append(
torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0))
)
# Return inference logprobs if available
inference_logprobs = inference_logprobs_list if inference_logprobs_list else None
# Return inference logprob batches if we have any real logprobs
final_logprob_batches = inference_logprob_batches if (has_any_logprobs and inference_logprob_batches) else None
return token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs
return token_batches, label_batches, advantage_batches, temperature_batches, final_logprob_batches
def get_data(
@ -171,8 +209,14 @@ def get_data(
atropos_url: str = "http://localhost:8000",
extract_inference_logprobs: bool = True,
) -> Tuple[
List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]],
Optional[List[np.ndarray]],
List[Tuple[
List[torch.Tensor], # token_batches
List[torch.Tensor], # label_batches
List[torch.Tensor], # advantage_batches
List[torch.Tensor], # temperature_batches
Optional[List[torch.Tensor]], # inference_logprob_batches
]],
None, # Legacy return (no longer used)
]:
"""
Fetch and process training data from the Atropos API.
@ -184,15 +228,15 @@ def get_data(
batch_size: Size of each training batch
seq_len: Maximum sequence length (for reference, not used directly)
atropos_url: URL of the Atropos API server
extract_inference_logprobs: Whether to extract inference logprobs for alignment
extract_inference_logprobs: Whether to extract inference logprobs for GRPO loss
Returns:
Tuple of (batches, all_inference_logprobs)
- batches: List of processed batch tuples
- all_inference_logprobs: List of inference logprob arrays for alignment validation
Tuple of (batches, None)
- batches: List of processed batch tuples, each containing:
(token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
- inference_logprob_batches are aligned with labels for proper GRPO loss computation
"""
batches = []
all_inference_logprobs: List[np.ndarray] = []
while True:
data = get_batch(url=atropos_url)
@ -202,18 +246,16 @@ def get_data(
with open("temp.json", "w", encoding="utf-8") as f:
json.dump(data, f)
# Process and accumulate batches
token_batches, label_batches, adv_batches, temp_batches, inf_logprobs = \
# Process and accumulate batches (now includes batched inference logprobs)
token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches = \
pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
batches.append((token_batches, label_batches, adv_batches, temp_batches))
if inf_logprobs:
all_inference_logprobs.extend(inf_logprobs)
# Include inference logprob batches in the tuple
batches.append((token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches))
elif len(batches) > 0:
# Return accumulated batches when no more data
return batches, all_inference_logprobs if all_inference_logprobs else None
return batches, None
else:
# Wait for data
time.sleep(1)

View file

@ -1,647 +0,0 @@
#!/usr/bin/env python3
"""
Multi-model test suite for shared_vllm trainer.
Tests the trainer against diverse models to verify robustness.
Supports both parallel (different GPUs) and sequential execution.
With --auto-env, each model gets its own isolated stack:
- run-api (port 8002 + offset)
- gsm8k environment (with model-specific tokenizer)
- vLLM server (port 9001 + offset)
- trainer
Usage:
# RECOMMENDED: Fully automated parallel test with W&B logging
python -m example_trainer.test_multi_model \
--models qwen3-4b hermes-8b nemotron-14b devstral-24b \
--parallel \
--gpus 0 1 2 3 \
--auto-env \
--use-wandb \
--wandb-project multi-model-test
# Sequential test on one GPU
python -m example_trainer.test_multi_model \
--models qwen3-4b hermes-8b \
--sequential \
--gpu 0 \
--auto-env \
--use-wandb
# Manual mode (you must start run-api and gsm8k_server yourself)
python -m example_trainer.test_multi_model \
--models qwen3-4b \
--sequential \
--gpu 0 \
--atropos-url http://localhost:8002
Port allocation with --auto-env:
Model 0: run-api:8002, vLLM:9001, GPU from --gpus[0]
Model 1: run-api:8003, vLLM:9002, GPU from --gpus[1]
Model 2: run-api:8004, vLLM:9003, GPU from --gpus[2]
...
"""
import argparse
import json
import os
import signal
import subprocess
import sys
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
import threading
@dataclass
class ModelConfig:
    """Configuration for a test model.

    Bundles the model identifier with the vLLM serving parameters and
    test settings used when launching a run for that model.
    """

    # Short key used for CLI selection, directory names, and log prefixes.
    name: str
    # HuggingFace model id passed to vLLM and used as the tokenizer name.
    model_id: str
    # Fraction of GPU memory vLLM may claim (remainder left for training).
    gpu_memory_utilization: float = 0.5
    # Maximum context length configured on the vLLM server.
    max_model_len: int = 4096
    # Serving precision for weights/activations.
    dtype: str = "bfloat16"
    # Default number of training steps for this model's test.
    training_steps: int = 10
    # Free-form human-readable description (memory estimates etc.).
    notes: str = ""
# Registry of test models, keyed by the short name used on the CLI.
# Memory estimates for B200 (183GB):
#   - Model weights (bf16): 2 bytes/param
#   - Gradients: ~same as weights
#   - 8-bit optimizer: ~1 byte/param
#   - KV cache: depends on max_model_len
TEST_MODELS: Dict[str, ModelConfig] = {
    "qwen3-4b": ModelConfig(
        name="qwen3-4b",
        model_id="Qwen/Qwen3-4B-Instruct-2507",
        gpu_memory_utilization=0.4,  # ~73GB for vLLM
        max_model_len=8192,  # Plenty of room on B200
        notes="Small 4B model, good baseline test (~8GB weights)",
    ),
    "hermes-8b": ModelConfig(
        name="hermes-8b",
        model_id="NousResearch/Hermes-3-Llama-3.1-8B",
        gpu_memory_utilization=0.45,  # ~82GB for vLLM
        max_model_len=8192,  # 8K context fits well
        notes="Llama 8B architecture (~16GB weights)",
    ),
    "nemotron-14b": ModelConfig(
        name="nemotron-14b",
        model_id="nvidia/Nemotron-Cascade-14B-Thinking",
        gpu_memory_utilization=0.5,  # ~91GB for vLLM
        max_model_len=32768,  # 32K context for thinking
        notes="14B thinking model (~28GB weights), needs room for long CoT",
    ),
    "devstral-24b": ModelConfig(
        name="devstral-24b",
        model_id="mistralai/Devstral-Small-2-24B-Instruct-2512",
        gpu_memory_utilization=0.55,  # ~100GB for vLLM
        max_model_len=16384,  # 16K context (conservative for 24B)
        notes="Large 24B Mistral (~48GB weights), largest model",
    ),
}
def get_test_dir(base_dir: str, model_name: str, timestamp: str) -> Path:
    """Return the unique output directory for one model test run.

    The path is ``<base_dir>/<model_name>_<timestamp>``; the directory is
    not created here — callers mkdir it when needed.
    """
    run_name = "_".join((model_name, timestamp))
    return Path(base_dir) / run_name
def start_run_api(
    port: int,
    log_path: Path,
) -> subprocess.Popen:
    """Start a run-api instance on a specific port.

    The process's combined stdout/stderr is redirected into *log_path*.

    Fixes vs. the previous version:
    - The parent's copy of the log file handle is now closed after
      spawning (the child keeps its own duplicated descriptor), so a
      file handle is no longer leaked per launched process.
    - ``bufsize=1`` was dropped: line buffering is unsupported for
      binary-mode subprocess streams and only produced a RuntimeWarning.

    Args:
        port: TCP port the run-api server should listen on.
        log_path: File that receives the process's output.

    Returns:
        The running subprocess handle; the caller is responsible for
        terminating it.
    """
    cmd = [sys.executable, "-m", "atroposlib.cli.run_api", "--port", str(port)]
    with open(log_path, "w") as log_file:
        process = subprocess.Popen(
            cmd,
            stdout=log_file,
            stderr=subprocess.STDOUT,
        )
    return process
def wait_for_run_api(port: int, timeout: int = 60) -> bool:
    """Poll a run-api instance until it responds or *timeout* elapses.

    Tries the ``/status`` endpoint first and falls back to the root
    endpoint, since run-api does not expose ``/health``.

    Fixes vs. the previous version:
    - The two bare ``except:`` clauses (which also swallowed
      KeyboardInterrupt/SystemExit) are narrowed to ``OSError``, which
      covers connection-refused, timeouts, and HTTP errors
      (urllib's URLError/HTTPError are OSError subclasses).
    - Uses stdlib ``urllib.request`` instead of the third-party
      ``requests`` dependency; observable behavior is unchanged.

    Args:
        port: Port the run-api server is expected on (localhost).
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True once either endpoint answers with HTTP 200, False if the
        deadline passes first.
    """
    import urllib.request

    deadline = time.time() + timeout
    endpoints = (
        f"http://localhost:{port}/status",
        f"http://localhost:{port}/",
    )
    while time.time() < deadline:
        for url in endpoints:
            try:
                with urllib.request.urlopen(url, timeout=5) as resp:
                    if resp.status == 200:
                        return True
            except OSError:
                # Server not up yet (refused/timeout) or non-2xx reply.
                pass
        time.sleep(2)
    return False
def start_gsm8k_env(
    model_id: str,
    vllm_port: int,
    run_api_port: int,
    log_path: Path,
    atropos_root: Path,
) -> subprocess.Popen:
    """Start a gsm8k environment process for a specific model.

    Fix vs. the previous version: the parent's copy of the log file
    handle is closed after spawning (the child holds its own duplicated
    descriptor), so a file handle is no longer leaked per launch.

    Args:
        model_id: HuggingFace model id; used both as the tokenizer name
            and as the OpenAI-compatible model name.
        vllm_port: Port of the vLLM inference server the env talks to.
        run_api_port: Port of the run-api rollout server.
        log_path: File receiving the env's combined stdout/stderr.
        atropos_root: Repository root; the env script lives under
            ``environments/`` and the process runs from this directory.

    Returns:
        The running subprocess handle.
    """
    gsm8k_script = atropos_root / "environments" / "gsm8k_server.py"
    cmd = [
        sys.executable, "-u", str(gsm8k_script), "serve",
        "--env.rollout_server_url", f"http://localhost:{run_api_port}",
        "--env.tokenizer_name", model_id,
        "--env.use_wandb", "false",
        "--env.total_steps", "10000",
        "--env.batch_size", "64",
        "--env.group_size", "8",
        "--openai.model_name", model_id,
        "--openai.base_url", f"http://localhost:{vllm_port}/v1",
        "--openai.api_key", "x",
        "--openai.server_type", "openai",
    ]
    with open(log_path, "w") as log_file:
        process = subprocess.Popen(
            cmd,
            stdout=log_file,
            stderr=subprocess.STDOUT,
            cwd=str(atropos_root),  # Run from atropos root
        )
    return process
def run_model_test(
    model_config: ModelConfig,
    gpu_id: int,
    atropos_url: str,
    atropos_port: int,
    base_dir: str,
    timestamp: str,
    training_steps: int,
    vllm_port_offset: int = 0,
    auto_env: bool = False,
    use_wandb: bool = False,
    wandb_project: str = "multi-model-test",
) -> Dict:
    """
    Run a complete training test for a single model.

    With ``auto_env=True`` a dedicated run-api and gsm8k environment are
    started for the model first; then the unified vLLM + trainer process
    (run.py) is launched and awaited. All child processes started here
    are torn down in the ``finally`` block.

    Args:
        model_config: Model to test (id, vLLM memory/context settings).
        gpu_id: CUDA device index passed to run.py via ``--device``.
        atropos_url: Atropos API URL; replaced by the per-model run-api
            URL when ``auto_env`` is set.
        atropos_port: NOTE(review): unused in this function — presumably
            kept for caller interface symmetry; confirm.
        base_dir: Base directory under which the per-run test dir is made.
        timestamp: Shared timestamp suffix for this test batch's dirs.
        training_steps: Number of training steps to run.
        vllm_port_offset: Offset added to the base vLLM (9001) and
            run-api (8002) ports so parallel runs don't collide.
        auto_env: Start run-api + gsm8k environment automatically.
        use_wandb: Forward W&B logging flags to the trainer.
        wandb_project: W&B project name used when ``use_wandb`` is set.

    Returns:
        Dict with status, error, timing, ports, and metrics parsed from
        the trainer log (real-time alignment, final GPU memory).
    """
    model_name = model_config.name
    test_dir = get_test_dir(base_dir, model_name, timestamp).resolve()  # Make absolute
    test_dir.mkdir(parents=True, exist_ok=True)
    # Unique paths for this model (all absolute)
    vllm_port = 9001 + vllm_port_offset
    # NOTE(review): bridge_config_path and vllm_log are never used below;
    # presumably kept for reference — confirm before removing.
    bridge_config_path = test_dir / "vllm_bridge_config.json"
    checkpoint_dir = test_dir / "checkpoints"
    log_dir = test_dir / "logs"
    log_dir.mkdir(exist_ok=True)
    vllm_log = log_dir / "vllm.log"
    trainer_log = log_dir / "trainer.log"
    # Each model gets unique ports
    run_api_port = 8002 + vllm_port_offset
    # Result skeleton; fields are filled in as the run progresses.
    result = {
        "model": model_config.model_id,
        "model_name": model_name,
        "gpu": gpu_id,
        "vllm_port": vllm_port,
        "run_api_port": run_api_port,
        "test_dir": str(test_dir),
        "status": "pending",
        "error": None,
        "start_time": None,
        "end_time": None,
        "duration_seconds": None,
        "real_time_alignment": None,
        "final_gpu_memory": None,
    }
    print(f"\n{'='*60}")
    print(f"[{model_name}] Starting test on GPU {gpu_id}")
    print(f"[{model_name}] Model: {model_config.model_id}")
    print(f"[{model_name}] vLLM port: {vllm_port}")
    print(f"[{model_name}] Test dir: {test_dir}")
    print(f"{'='*60}\n")
    result["start_time"] = datetime.now().isoformat()
    start_time = time.time()
    # Child process handles start as None so the finally block can test
    # them safely even if startup fails early.
    env_process = None
    run_api_process = None
    trainer_process = None
    # Get atropos root directory (used for vLLM and gsm8k scripts)
    script_dir = Path(__file__).parent
    atropos_root = script_dir.parent.resolve()
    try:
        # === Start run-api (if auto_env) ===
        if auto_env:
            run_api_log = log_dir / "run_api.log"
            print(f"[{model_name}] Starting run-api on port {run_api_port}...")
            run_api_process = start_run_api(run_api_port, run_api_log)
            if not wait_for_run_api(run_api_port, timeout=60):
                # Check if process died
                if run_api_process.poll() is not None:
                    print(f"[{model_name}] run-api process exited with code {run_api_process.returncode}")
                # Print log contents for debugging
                if run_api_log.exists():
                    print(f"[{model_name}] run-api log contents:")
                    print(run_api_log.read_text()[-2000:])  # Last 2000 chars
                raise RuntimeError(f"run-api failed to start on port {run_api_port}")
            print(f"[{model_name}] ✓ run-api ready on port {run_api_port}")
            # Update atropos_url to use this model's run-api
            atropos_url = f"http://localhost:{run_api_port}"
        # === Start gsm8k Environment (if auto_env) ===
        if auto_env:
            env_log = log_dir / "env.log"
            print(f"[{model_name}] Starting gsm8k environment (tokenizer: {model_config.model_id})...")
            env_process = start_gsm8k_env(
                model_config.model_id, vllm_port, run_api_port, env_log, atropos_root
            )
            time.sleep(10)  # Give it time to initialize and connect
            print(f"[{model_name}] ✓ gsm8k environment started")
        # === Start Unified vLLM + Trainer (run.py) ===
        # Using run.py ensures vLLM is a CHILD of the trainer process,
        # which is required for CUDA IPC with ptrace_scope=1
        run_script = script_dir / "run.py"
        # Don't use CUDA_VISIBLE_DEVICES - use --device instead
        # run.py sets CUDA_VISIBLE_DEVICES internally based on --device
        run_env = os.environ.copy()
        run_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        run_cmd = [
            sys.executable, "-u", str(run_script),
            "--model", model_config.model_id,
            "--device", f"cuda:{gpu_id}",  # This controls GPU selection
            "--vllm-port", str(vllm_port),
            "--gpu-memory-utilization", str(model_config.gpu_memory_utilization),
            "--max-model-len", str(model_config.max_model_len),
            "--dtype", model_config.dtype,
            "--atropos-url", atropos_url,
            "--training-steps", str(training_steps),
            "--optimizer", "adamw_8bit",
            "--save-path", str(checkpoint_dir),
            "--checkpoint-interval", "5",
            "--log-dir", str(log_dir),
        ]
        # Add wandb flags if enabled
        if use_wandb:
            run_cmd.extend(["--use-wandb", "--wandb-project", wandb_project])
        print(f"[{model_name}] Starting unified trainer (vLLM + GRPO) for {training_steps} steps...")
        with open(trainer_log, "w") as tlog:
            trainer_process = subprocess.Popen(
                run_cmd,
                env=run_env,
                stdout=tlog,
                stderr=subprocess.STDOUT,
                cwd=str(atropos_root),  # Run from atropos root
            )
        trainer_process.wait()
        if trainer_process.returncode != 0:
            raise RuntimeError(f"Unified trainer exited with code {trainer_process.returncode}")
        result["status"] = "success"
        print(f"[{model_name}] ✓ Training completed successfully!")
        # Parse trainer log for metrics (best-effort; failures only warn)
        try:
            with open(trainer_log, "r") as f:
                log_content = f.read()
            # Extract real-time alignment
            if "Mean diff:" in log_content:
                import re
                match = re.search(r"Mean diff: ([\d.]+)", log_content)
                if match:
                    result["real_time_alignment"] = float(match.group(1))
            # Extract final GPU memory
            # NOTE(review): `re` is imported only inside the "Mean diff"
            # branch above; a log with "GPU mem:" but no "Mean diff:"
            # raises NameError here (swallowed by the except below) —
            # confirm and hoist the import.
            if "GPU mem:" in log_content:
                matches = re.findall(r"GPU mem: ([\d.]+)GB", log_content)
                if matches:
                    result["final_gpu_memory"] = float(matches[-1])
        except Exception as e:
            print(f"[{model_name}] Warning: Could not parse log: {e}")
    except Exception as e:
        # Any startup/training failure lands here; record and report it.
        result["status"] = "failed"
        result["error"] = str(e)
        print(f"[{model_name}] ✗ Test failed: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Note: vLLM is managed by run.py and cleaned up automatically
        # Cleanup gsm8k environment
        if env_process and env_process.poll() is None:
            print(f"[{model_name}] Terminating gsm8k environment...")
            env_process.terminate()
            try:
                env_process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                env_process.kill()
        # Cleanup run-api
        if run_api_process and run_api_process.poll() is None:
            print(f"[{model_name}] Terminating run-api...")
            run_api_process.terminate()
            try:
                run_api_process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                run_api_process.kill()
    result["end_time"] = datetime.now().isoformat()
    result["duration_seconds"] = time.time() - start_time
    return result
def run_parallel_tests(
    models: List[ModelConfig],
    gpu_ids: List[int],
    atropos_url: str,
    atropos_port: int,
    base_dir: str,
    training_steps: int,
    auto_env: bool = False,
    use_wandb: bool = False,
    wandb_project: str = "multi-model-test",
) -> List[Dict]:
    """Run tests for multiple models in parallel, one thread per model.

    Model *i* runs on ``gpu_ids[i]`` with port offset *i* so the
    per-model stacks do not collide.

    Fix vs. the previous version: results were appended under a lock in
    *completion* order, so the summary and saved JSON were
    nondeterministic. Each thread now writes into its own pre-sized
    slot, which preserves the input order of *models* and removes the
    need for a lock.

    Args:
        models: Model configurations to test.
        gpu_ids: One GPU id per model (same length as *models*).
        atropos_url: Atropos API URL forwarded to each run.
        atropos_port: Atropos API port forwarded to each run.
        base_dir: Base directory for per-run test outputs.
        training_steps: Training steps per model.
        auto_env: Start run-api + gsm8k environment per model.
        use_wandb: Enable W&B logging for each run.
        wandb_project: W&B project name.

    Returns:
        Result dicts in the same order as *models*.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # One slot per model; each thread writes only its own index.
    results: List[Optional[Dict]] = [None] * len(models)
    threads: List[threading.Thread] = []

    def run_and_store(index: int, model: ModelConfig, gpu: int) -> None:
        # Index doubles as the port offset for this model's stack.
        results[index] = run_model_test(
            model, gpu, atropos_url, atropos_port, base_dir, timestamp,
            training_steps, index, auto_env, use_wandb, wandb_project,
        )

    for i, (model, gpu) in enumerate(zip(models, gpu_ids)):
        t = threading.Thread(target=run_and_store, args=(i, model, gpu))
        t.start()
        threads.append(t)
        time.sleep(5)  # Stagger starts slightly
    # Wait for all to complete
    for t in threads:
        t.join()
    return [r for r in results if r is not None]
def run_sequential_tests(
    models: List[ModelConfig],
    gpu_id: int,
    atropos_url: str,
    atropos_port: int,
    base_dir: str,
    training_steps: int,
    auto_env: bool = False,
    use_wandb: bool = False,
    wandb_project: str = "multi-model-test",
) -> List[Dict]:
    """Run tests for multiple models sequentially on one GPU.

    All runs use port offset 0, since only one stack is alive at a time.

    Fixes vs. the previous version:
    - ``run_model_test`` has no ``port_offset`` parameter; the old call
      used ``port_offset=0`` and raised TypeError on every invocation.
      The correct keyword is ``vllm_port_offset``.
    - The 10 s GPU-memory cool-down now only happens *between* runs
      instead of also after the final one.

    Args:
        models: Model configurations to test, in order.
        gpu_id: GPU index shared by all runs.
        atropos_url: Atropos API URL forwarded to each run.
        atropos_port: Atropos API port forwarded to each run.
        base_dir: Base directory for per-run test outputs.
        training_steps: Training steps per model.
        auto_env: Start run-api + gsm8k environment per model.
        use_wandb: Enable W&B logging for each run.
        wandb_project: W&B project name.

    Returns:
        Result dicts in the same order as *models*.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results: List[Dict] = []
    for i, model in enumerate(models):
        if i > 0:
            # Give the GPU time to fully release memory from the prior run
            time.sleep(10)
        result = run_model_test(
            model, gpu_id, atropos_url, atropos_port, base_dir, timestamp,
            training_steps, vllm_port_offset=0, auto_env=auto_env,
            use_wandb=use_wandb, wandb_project=wandb_project,
        )
        results.append(result)
    return results
def print_summary(results: List[Dict]):
    """Print a human-readable summary of all test results.

    Fixes vs. the previous version:
    - The ✓/✗ status icons were both empty strings (characters lost);
      they are restored to match the ✓/✗ markers used in the per-model
      progress output.
    - Metric fields were tested for truthiness, so a legitimate value of
      ``0.0`` (duration, alignment, memory) printed as "N/A". They are
      now compared against ``None`` explicitly.

    Args:
        results: Result dicts as produced by ``run_model_test``.
    """
    print("\n" + "=" * 80)
    print("TEST SUMMARY")
    print("=" * 80)
    for r in results:
        status_icon = "✓" if r["status"] == "success" else "✗"
        duration = f"{r['duration_seconds']:.1f}s" if r['duration_seconds'] is not None else "N/A"
        alignment = f"{r['real_time_alignment']:.4f}" if r['real_time_alignment'] is not None else "N/A"
        memory = f"{r['final_gpu_memory']:.1f}GB" if r['final_gpu_memory'] is not None else "N/A"
        print(f"\n{status_icon} {r['model_name']}")
        print(f"  Model: {r['model']}")
        print(f"  GPU: {r['gpu']}, vLLM port: {r['vllm_port']}, run-api port: {r.get('run_api_port', 'N/A')}")
        print(f"  Status: {r['status']}")
        print(f"  Duration: {duration}")
        print(f"  Real-time alignment: {alignment}")
        print(f"  GPU memory: {memory}")
        if r["error"]:
            print(f"  Error: {r['error']}")
        print(f"  Logs: {r['test_dir']}/logs/")
    # Summary stats
    successes = sum(1 for r in results if r["status"] == "success")
    failures = len(results) - successes
    print(f"\n{'='*80}")
    print(f"TOTAL: {successes} passed, {failures} failed")
    print("=" * 80)
def main():
    """CLI entry point.

    Parses arguments, runs the selected models' tests (parallel across
    GPUs or sequentially on one GPU), prints a summary, writes the
    results to a timestamped JSON file, and exits with status 1 if any
    model test failed.
    """
    parser = argparse.ArgumentParser(
        description="Multi-model test suite for shared_vllm trainer",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Run all models in parallel (one per GPU)
python -m example_trainer.test_multi_model --parallel
# Run specific models
python -m example_trainer.test_multi_model --models hermes-8b qwen3-4b --parallel
# Run sequentially on GPU 0
python -m example_trainer.test_multi_model --sequential --gpu 0
Available models: """ + ", ".join(TEST_MODELS.keys())
    )
    parser.add_argument(
        "--models",
        nargs="+",
        choices=list(TEST_MODELS.keys()),
        default=["qwen3-4b", "hermes-8b"],
        help="Models to test",
    )
    parser.add_argument(
        "--parallel",
        action="store_true",
        help="Run models in parallel on different GPUs",
    )
    parser.add_argument(
        "--sequential",
        action="store_true",
        help="Run models sequentially on one GPU",
    )
    parser.add_argument(
        "--gpus",
        type=int,
        nargs="+",
        default=None,
        help="GPU IDs to use (for parallel mode)",
    )
    parser.add_argument(
        "--gpu",
        type=int,
        default=0,
        help="GPU ID (for sequential mode)",
    )
    parser.add_argument(
        "--atropos-url",
        type=str,
        default="http://localhost:8002",
        help="Atropos API URL",
    )
    parser.add_argument(
        "--atropos-port",
        type=int,
        default=8002,
        help="Atropos API port (for spawning multiple if needed)",
    )
    parser.add_argument(
        "--training-steps",
        type=int,
        default=10,
        help="Number of training steps per model",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./multi_model_tests",
        help="Base directory for test outputs",
    )
    parser.add_argument(
        "--auto-env",
        action="store_true",
        help="Automatically start run-api and gsm8k environment for each model",
    )
    parser.add_argument(
        "--use-wandb",
        action="store_true",
        help="Enable Weights & Biases logging for training runs",
    )
    parser.add_argument(
        "--wandb-project",
        type=str,
        default="multi-model-test",
        help="W&B project name for logging",
    )
    args = parser.parse_args()
    if not args.parallel and not args.sequential:
        args.sequential = True  # Default to sequential
    # Get model configs
    models = [TEST_MODELS[name] for name in args.models]
    print(f"\n{'#'*60}")
    print("# MULTI-MODEL SHARED_VLLM TRAINER TEST SUITE")
    print(f"{'#'*60}")
    print(f"\nModels to test: {[m.name for m in models]}")
    print(f"Mode: {'Parallel' if args.parallel else 'Sequential'}")
    print(f"Training steps per model: {args.training_steps}")
    print(f"Output directory: {args.output_dir}")
    print(f"Atropos URL: {args.atropos_url}")
    # Run tests
    if args.auto_env:
        print(f"Auto-env: Will start gsm8k environment per model")
    if args.parallel:
        # One GPU per model by default; wrap around if fewer GPUs given.
        gpus = args.gpus or list(range(len(models)))
        if len(gpus) < len(models):
            print(f"\nWarning: Not enough GPUs ({len(gpus)}) for models ({len(models)})")
            print("Some models will share GPUs")
            gpus = gpus * (len(models) // len(gpus) + 1)
        print(f"Using GPUs: {gpus[:len(models)]}")
        if args.use_wandb:
            print(f"W&B logging enabled (project: {args.wandb_project})")
        results = run_parallel_tests(
            models, gpus[:len(models)],
            args.atropos_url, args.atropos_port,
            args.output_dir, args.training_steps,
            auto_env=args.auto_env,
            use_wandb=args.use_wandb,
            wandb_project=args.wandb_project,
        )
    else:
        print(f"Using GPU: {args.gpu}")
        if args.use_wandb:
            print(f"W&B logging enabled (project: {args.wandb_project})")
        results = run_sequential_tests(
            models, args.gpu,
            args.atropos_url, args.atropos_port,
            args.output_dir, args.training_steps,
            auto_env=args.auto_env,
            use_wandb=args.use_wandb,
            wandb_project=args.wandb_project,
        )
    # Print summary
    print_summary(results)
    # Save results to JSON
    results_file = Path(args.output_dir) / f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    results_file.parent.mkdir(parents=True, exist_ok=True)
    with open(results_file, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to: {results_file}")
    # Exit with error code if any failed
    if any(r["status"] != "success" for r in results):
        sys.exit(1)


if __name__ == "__main__":
    main()

View file

@ -153,13 +153,22 @@ def compute_grpo_loss(
temperatures: torch.Tensor,
gradient_accumulation_steps: int,
inference_logprobs: Optional[torch.Tensor] = None,
kl_coef: float = 0.1,
clip_eps: float = 0.2,
use_reference_logprobs: bool = True,
) -> Tuple[torch.Tensor, dict]:
"""
Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch.
The GRPO loss encourages the model to:
This implements proper GRPO/PPO with:
- Importance sampling ratio: π(a|s) / π_old(a|s)
- PPO-style clipping to prevent large updates
- KL penalty to prevent reward hacking/policy collapse
The loss encourages the model to:
- Increase probability for tokens with positive advantages
- Decrease probability for tokens with negative advantages
- Stay close to the reference policy (inference-time policy)
Args:
model: The model to compute loss for
@ -168,7 +177,10 @@ def compute_grpo_loss(
advantages: Advantage values [batch, 1]
temperatures: Temperature values [batch, 1, 1]
gradient_accumulation_steps: Number of accumulation steps (for scaling)
inference_logprobs: Optional logprobs from inference for alignment check
inference_logprobs: Logprobs from inference (π_old), aligned with labels [batch, seq_len]
kl_coef: KL penalty coefficient (beta). Higher = more conservative updates
clip_eps: PPO clipping epsilon. Clips ratio to [1-eps, 1+eps]
use_reference_logprobs: If True, use inference_logprobs as reference policy
Returns:
Tuple of (loss tensor, metrics dict)
@ -177,14 +189,14 @@ def compute_grpo_loss(
outputs = model(tokens)
logits = outputs.logits
# Temperature scaling
# Temperature scaling for training
t = temperatures.to(logits.device, logits.dtype)
t = torch.where(t <= 0, torch.ones_like(t), t)
logits = logits / t
scaled_logits = logits / t
# Log probabilities per token
# Log probabilities per token (current policy π)
logp_per_token = -F.cross_entropy(
logits.view(-1, logits.size(-1)),
scaled_logits.view(-1, scaled_logits.size(-1)),
labels.view(-1),
reduction="none",
ignore_index=-100,
@ -192,39 +204,103 @@ def compute_grpo_loss(
# Masking based on labels != -100
mask = (labels != -100).float()
mask_sum = mask.sum(dim=-1).clamp_min(1e-8)
# Compute metrics (no grad needed)
# Expand advantages to match token shape [batch, 1] -> [batch, seq_len]
adv_expanded = advantages.expand_as(logp_per_token).to(logp_per_token.device)
# === GRPO/PPO Loss Computation ===
if use_reference_logprobs and inference_logprobs is not None:
# Move inference logprobs to correct device/dtype
ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)
# Compute importance sampling ratio: π(a|s) / π_old(a|s) = exp(log π - log π_old)
log_ratio = logp_per_token - ref_logprobs
ratio = torch.exp(log_ratio)
# PPO-style clipping
clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
# Surrogate objectives
surr1 = ratio * adv_expanded
surr2 = clipped_ratio * adv_expanded
# Pessimistic bound: min for positive advantages, max for negative
# This is equivalent to: -min(ratio * A, clipped_ratio * A) when A > 0
# -max(ratio * A, clipped_ratio * A) when A < 0
policy_loss_per_token = -torch.where(
adv_expanded >= 0,
torch.min(surr1, surr2),
torch.max(surr1, surr2),
)
# Average over tokens, then over batch
policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
# KL penalty: encourage staying close to reference policy
# KL(π || π_ref) ≈ log(π/π_ref) = log_ratio (when π_ref is the reference)
# We use the approximation: KL ≈ (ratio - 1) - log(ratio)
# But simpler: just penalize squared log-ratio which is symmetric
if kl_coef > 0:
# Approximate KL using (log_ratio)^2 / 2 (Taylor expansion)
# Or just use log_ratio directly as a penalty
kl_per_token = log_ratio.pow(2) # Squared for symmetric penalty
kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
total_loss = (policy_loss + kl_coef * kl_penalty) / gradient_accumulation_steps
else:
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
total_loss = policy_loss / gradient_accumulation_steps
# Compute metrics for logging
with torch.no_grad():
# Fraction of tokens where ratio was clipped
clipped_fraction = ((ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)).float()
clipped_fraction = (clipped_fraction * mask).sum() / mask.sum()
# Mean ratio and KL for monitoring
mean_ratio = (ratio * mask).sum() / mask.sum()
mean_kl = (log_ratio.pow(2) * mask).sum() / mask.sum()
# For backward compatibility: collect training logprobs
raw_logp_per_token = -F.cross_entropy(
outputs.logits.view(-1, outputs.logits.size(-1)),
labels.view(-1),
reduction="none",
ignore_index=-100,
).view(labels.shape)
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
else:
# Fallback: REINFORCE-style (no reference policy)
# This is what the original code did - NOT recommended!
print(" [WARNING] No reference logprobs - using REINFORCE (may cause reward hacking!)")
# Simple policy gradient: -log(π) * A
policy_loss = ((-logp_per_token * mask * adv_expanded).sum(dim=-1) / mask_sum).mean()
total_loss = policy_loss / gradient_accumulation_steps
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
with torch.no_grad():
clipped_fraction = torch.tensor(0.0)
mean_ratio = torch.tensor(1.0)
mean_kl = torch.tensor(0.0)
raw_logp_per_token = -F.cross_entropy(
outputs.logits.view(-1, outputs.logits.size(-1)),
labels.view(-1),
reduction="none",
ignore_index=-100,
).view(labels.shape)
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
# === Compute Additional Metrics ===
with torch.no_grad():
pos = (advantages > 0).float()
neg = (advantages <= 0).float()
mask_float = mask.to(logp_per_token.dtype)
mask_sum = mask_float.sum(dim=-1).clamp_min(1e-8)
avg_logp = (logp_per_token * mask_float).sum(dim=-1) / mask_sum
pos_logp = (logp_per_token * pos).mean().item()
neg_logp = (logp_per_token * neg).mean().item()
# For alignment check: compute logprobs WITHOUT temperature scaling
# This allows fair comparison with inference logprobs (which are at temp=1.0)
raw_logp_per_token = -F.cross_entropy(
outputs.logits.view(-1, outputs.logits.size(-1)), # Use original logits, not temp-scaled
labels.view(-1),
reduction="none",
ignore_index=-100,
).view(labels.shape)
# Collect raw training logprobs for masked positions (generated tokens only)
# Keep as PyTorch tensor (supports bfloat16 natively)
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
# GRPO loss: weighted log probabilities by advantages
grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach())
grpo_loss = (
((-grpo_loss_term * mask).sum(-1) / mask.sum(-1))
* advantages.to(logp_per_token.device)
).mean() / gradient_accumulation_steps
# Compute a more interpretable loss metric (advantage-weighted logprobs)
with torch.no_grad():
# Interpretable metric: advantage-weighted average logprob
interpretable_loss = (avg_logp * advantages.squeeze()).mean().item()
metrics = {
@ -233,11 +309,16 @@ def compute_grpo_loss(
"avg_logp": avg_logp,
"pos_count": pos.sum().item(),
"neg_count": neg.sum().item(),
"training_logprobs": training_logprobs_flat, # For alignment check
"interpretable_loss": interpretable_loss, # More meaningful metric
"training_logprobs": training_logprobs_flat,
"interpretable_loss": interpretable_loss,
# GRPO-specific metrics
"kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty,
"mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio,
"mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl,
"clipped_fraction": clipped_fraction.item() if torch.is_tensor(clipped_fraction) else clipped_fraction,
}
return grpo_loss, metrics
return total_loss, metrics
def compute_logprob_alignment(
@ -309,17 +390,16 @@ def run_training_step(
advantage_batches: List[torch.Tensor],
temperature_batches: List[torch.Tensor],
config: TrainingConfig,
inference_logprobs: Optional[List[np.ndarray]] = None,
inference_logprob_batches: Optional[List[torch.Tensor]] = None,
) -> dict:
"""
Run a single training step with gradient accumulation.
Performs:
1. Forward pass through all micro-batches
1. Forward pass through all micro-batches with proper GRPO loss
2. Backward pass with gradient accumulation
3. Gradient clipping
4. Optimizer step
5. (Optional) Logprob alignment check
Args:
model: The model to train
@ -328,8 +408,8 @@ def run_training_step(
label_batches: List of label tensors
advantage_batches: List of advantage tensors
temperature_batches: List of temperature tensors
config: Training configuration
inference_logprobs: Optional logprobs from inference for alignment check
config: Training configuration (includes kl_coef, clip_eps, use_reference_logprobs)
inference_logprob_batches: Batched logprobs from inference (π_old), aligned with labels
Returns:
Dict of training metrics for this step
@ -341,16 +421,32 @@ def run_training_step(
total_neg_logp = 0.0
total_pos = 0.0
total_neg = 0.0
total_kl_penalty = 0.0
total_mean_ratio = 0.0
total_mean_kl = 0.0
total_clipped_fraction = 0.0
grad_norm = 0.0
all_training_logprobs: List[torch.Tensor] = []
# Get GRPO hyperparameters from config
kl_coef = getattr(config, 'kl_coef', 0.1)
clip_eps = getattr(config, 'clip_eps', 0.2)
use_reference_logprobs = getattr(config, 'use_reference_logprobs', True)
# Accumulate gradients over micro-batches
for tokens, labels, advantages, temperatures in zip(
num_batches = len(token_batches) if token_batches else 1
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(zip(
token_batches, label_batches, advantage_batches, temperature_batches
):
)):
tokens = tokens.to(config.device)
labels = labels.to(config.device)
advantages = advantages.to(config.device)
# Get corresponding inference logprobs batch if available
inf_logprobs = None
if inference_logprob_batches is not None and batch_idx < len(inference_logprob_batches):
inf_logprobs = inference_logprob_batches[batch_idx]
loss, metrics = compute_grpo_loss(
model,
@ -359,6 +455,10 @@ def run_training_step(
advantages,
temperatures,
config.gradient_accumulation_steps,
inference_logprobs=inf_logprobs,
kl_coef=kl_coef,
clip_eps=clip_eps,
use_reference_logprobs=use_reference_logprobs,
)
loss.backward()
@ -368,7 +468,13 @@ def run_training_step(
total_pos += metrics["pos_count"]
total_neg += metrics["neg_count"]
# Collect training logprobs for alignment check
# Accumulate GRPO-specific metrics
total_kl_penalty += metrics.get("kl_penalty", 0.0)
total_mean_ratio += metrics.get("mean_ratio", 1.0)
total_mean_kl += metrics.get("mean_kl", 0.0)
total_clipped_fraction += metrics.get("clipped_fraction", 0.0)
# Collect training logprobs for alignment monitoring
if "training_logprobs" in metrics:
all_training_logprobs.append(metrics["training_logprobs"])
@ -380,8 +486,7 @@ def run_training_step(
# Help prevent memory fragmentation
torch.cuda.empty_cache()
# Normalize metrics by count
num_batches = len(token_batches) if token_batches else 1
# Normalize metrics by batch count
if total_pos > 0:
total_pos_logp /= num_batches
if total_neg > 0:
@ -394,18 +499,22 @@ def run_training_step(
"neg_logp": total_neg_logp,
"pos_count": total_pos,
"neg_count": total_neg,
# GRPO-specific metrics (averaged over batches)
"kl_penalty": total_kl_penalty / num_batches,
"mean_ratio": total_mean_ratio / num_batches,
"mean_kl": total_mean_kl / num_batches,
"clipped_fraction": total_clipped_fraction / num_batches,
}
# Compute logprob alignment stats
# NOTE: This comparison is approximate - inference and training logprobs
# come from different batching, so token-by-token alignment isn't possible.
# The real-time test at startup is the reliable alignment check.
if inference_logprobs is not None and all_training_logprobs:
alignment_stats = compute_logprob_alignment(
inference_logprobs, all_training_logprobs, debug=False
)
_logprob_alignment_stats.update(alignment_stats)
result["logprob_alignment"] = alignment_stats
# Compute logprob alignment stats for monitoring
# NOTE: Now that we use proper GRPO, this is less critical
# but still useful for debugging weight sharing issues
if all_training_logprobs:
# Store training logprob stats
train_flat = torch.cat(all_training_logprobs)
if train_flat.numel() > 0:
_logprob_alignment_stats["logprobs/training_mean"] = train_flat.mean().item()
_logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item()
return result
@ -441,19 +550,27 @@ def log_metrics(
if "gpu_memory_gb" in metrics:
timing_str += f", GPU mem: {metrics['gpu_memory_gb']:.2f}GB"
# Show interpretable loss (advantage-weighted logprobs) if available
interp_loss = metrics.get("interpretable_loss")
if interp_loss is not None:
print(f" AdvWeightedLogP: {interp_loss:.4f}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}")
else:
loss_str = (
f"{metrics['loss']:.6f}"
if abs(metrics["loss"]) < 0.01
else f"{metrics['loss']:.4f}"
)
print(f" Loss: {loss_str}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}")
# Primary metrics line: Loss and grad norm
loss_str = (
f"{metrics['loss']:.6f}"
if abs(metrics["loss"]) < 0.01
else f"{metrics['loss']:.4f}"
)
print(f" Loss: {loss_str}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}")
# Show GRPO-specific metrics if available
# GRPO metrics line: KL, ratio, clipping
kl_penalty = metrics.get("kl_penalty", 0)
mean_ratio = metrics.get("mean_ratio", 1.0)
mean_kl = metrics.get("mean_kl", 0)
clipped_frac = metrics.get("clipped_fraction", 0)
if kl_penalty > 0 or mean_kl > 0:
print(
f" GRPO: KL={mean_kl:.4f}, ratio={mean_ratio:.3f}, "
f"clipped={clipped_frac*100:.1f}%"
)
# Advantage distribution
if "pos_count" in metrics or "neg_count" in metrics:
pos_count = metrics.get("pos_count", 0)
neg_count = metrics.get("neg_count", 0)
@ -463,24 +580,6 @@ def log_metrics(
f" Advantages: +{int(pos_count)} / -{int(neg_count)}, "
f"LogP: pos={pos_logp:.3f}, neg={neg_logp:.3f}"
)
# Show logprob alignment stats (important for shared_vllm validation!)
if "logprob_alignment" in metrics:
alignment = metrics["logprob_alignment"]
if "logprobs/diff" in alignment:
diff = alignment["logprobs/diff"]
inf_mean = alignment.get("logprobs/inference_mean", 0)
train_mean = alignment.get("logprobs/training_mean", 0)
# NOTE: This comparison has a fundamental timing issue!
# - inference_logprobs: from vLLM at generation time (possibly stale)
# - training_logprobs: from trainer's current forward pass
# After training starts, weights change, making comparison invalid.
#
# NOTE: This diff is just for monitoring, not validation!
# The real-time test at startup is the reliable alignment check.
# This diff will naturally drift as training progresses (expected).
print(f" LogProb Stats: inf_mean={inf_mean:.4f}, train_mean={train_mean:.4f}")
if use_wandb:
log_dict = {
@ -488,6 +587,11 @@ def log_metrics(
"train/grad_norm": metrics["grad_norm"],
"train/pos_logp": metrics.get("pos_logp", 0),
"train/neg_logp": metrics.get("neg_logp", 0),
# GRPO-specific metrics
"grpo/kl_penalty": kl_penalty,
"grpo/mean_ratio": mean_ratio,
"grpo/mean_kl": mean_kl,
"grpo/clipped_fraction": clipped_frac,
}
# Add timing metrics if present
for key in ["step_time", "sync_time", "data_fetch_time",