atropos/example_trainer/test_multi_model.py
Jai Suphavadeeprasit c00e7461aa testing 5
2026-03-02 11:18:52 -05:00

614 lines
20 KiB
Python

#!/usr/bin/env python3
"""
Multi-model test suite for shared_vllm trainer.
Tests the trainer against diverse models to verify robustness.
Supports both parallel (different GPUs) and sequential execution.
With --auto-env, each model gets its own isolated stack:
- run-api (port 8002 + offset)
- gsm8k environment (with model-specific tokenizer)
- vLLM server (port 9001 + offset)
- trainer
Usage:
# RECOMMENDED: Fully automated parallel test (each model gets isolated stack)
python -m example_trainer.test_multi_model \
--models qwen3-4b hermes-8b nemotron-14b devstral-24b \
--parallel \
--gpus 0 1 2 3 \
--auto-env
# Sequential test on one GPU
python -m example_trainer.test_multi_model \
--models qwen3-4b hermes-8b \
--sequential \
--gpu 0 \
--auto-env
# Manual mode (you must start run-api and gsm8k_server yourself)
# First start: run-api --port 8002 &
# Then start gsm8k for your model
python -m example_trainer.test_multi_model \
--models qwen3-4b \
--sequential \
--gpu 0 \
--atropos-url http://localhost:8002
Port allocation with --auto-env:
Model 0: run-api:8002, vLLM:9001
Model 1: run-api:8003, vLLM:9002
Model 2: run-api:8004, vLLM:9003
...
"""
import argparse
import json
import os
import signal
import subprocess
import sys
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
import threading
@dataclass
class ModelConfig:
"""Configuration for a test model."""
name: str
model_id: str
gpu_memory_utilization: float = 0.5
max_model_len: int = 4096
dtype: str = "bfloat16"
training_steps: int = 10
notes: str = ""
# Define test models
# Memory estimates for B200 (183GB):
# - Model weights (bf16): 2 bytes/param
# - Gradients: ~same as weights
# - 8-bit optimizer: ~1 byte/param
# - KV cache: depends on max_model_len
TEST_MODELS: Dict[str, ModelConfig] = {
"qwen3-4b": ModelConfig(
name="qwen3-4b",
model_id="Qwen/Qwen3-4B-Instruct-2507",
gpu_memory_utilization=0.4, # ~73GB for vLLM
max_model_len=8192, # Plenty of room on B200
notes="Small 4B model, good baseline test (~8GB weights)",
),
"hermes-8b": ModelConfig(
name="hermes-8b",
model_id="NousResearch/Hermes-3-Llama-3.1-8B",
gpu_memory_utilization=0.45, # ~82GB for vLLM
max_model_len=8192, # 8K context fits well
notes="Llama 8B architecture (~16GB weights)",
),
"nemotron-14b": ModelConfig(
name="nemotron-14b",
model_id="nvidia/Nemotron-Cascade-14B-Thinking",
gpu_memory_utilization=0.5, # ~91GB for vLLM
max_model_len=32768, # 32K context for thinking
notes="14B thinking model (~28GB weights), needs room for long CoT",
),
"devstral-24b": ModelConfig(
name="devstral-24b",
model_id="mistralai/Devstral-Small-2-24B-Instruct-2512",
gpu_memory_utilization=0.55, # ~100GB for vLLM
max_model_len=16384, # 16K context (conservative for 24B)
notes="Large 24B Mistral (~48GB weights), largest model",
),
}
def get_test_dir(base_dir: str, model_name: str, timestamp: str) -> Path:
"""Get unique test directory for a model run."""
return Path(base_dir) / f"{model_name}_{timestamp}"
def start_run_api(
port: int,
log_path: Path,
) -> subprocess.Popen:
"""Start a run-api instance on a specific port."""
cmd = [sys.executable, "-m", "atroposlib.cli.run_api", "--port", str(port)]
log_file = open(log_path, "w")
process = subprocess.Popen(
cmd,
stdout=log_file,
stderr=subprocess.STDOUT,
# Don't buffer output
bufsize=1,
)
return process
def wait_for_run_api(port: int, timeout: int = 60) -> bool:
"""Wait for run-api to be ready."""
import requests
start = time.time()
while time.time() - start < timeout:
try:
# run-api uses /status or / endpoint, not /health
resp = requests.get(f"http://localhost:{port}/status", timeout=5)
if resp.status_code == 200:
return True
except:
pass
try:
# Fallback to root endpoint
resp = requests.get(f"http://localhost:{port}/", timeout=5)
if resp.status_code == 200:
return True
except:
pass
time.sleep(2)
return False
def start_gsm8k_env(
model_id: str,
vllm_port: int,
run_api_port: int,
log_path: Path,
atropos_root: Path,
) -> subprocess.Popen:
"""Start a gsm8k environment process for a specific model."""
gsm8k_script = atropos_root / "environments" / "gsm8k_server.py"
cmd = [
sys.executable, "-u", str(gsm8k_script), "serve",
"--env.rollout_server_url", f"http://localhost:{run_api_port}",
"--env.tokenizer_name", model_id,
"--env.use_wandb", "false",
"--env.total_steps", "10000",
"--env.batch_size", "64",
"--env.group_size", "8",
"--openai.model_name", model_id,
"--openai.base_url", f"http://localhost:{vllm_port}/v1",
"--openai.api_key", "x",
"--openai.server_type", "openai",
]
log_file = open(log_path, "w")
process = subprocess.Popen(
cmd,
stdout=log_file,
stderr=subprocess.STDOUT,
cwd=str(atropos_root), # Run from atropos root
)
return process
def run_model_test(
model_config: ModelConfig,
gpu_id: int,
atropos_url: str,
atropos_port: int,
base_dir: str,
timestamp: str,
training_steps: int,
vllm_port_offset: int = 0,
auto_env: bool = False,
) -> Dict:
"""
Run a complete training test for a single model.
Returns dict with test results.
"""
model_name = model_config.name
test_dir = get_test_dir(base_dir, model_name, timestamp).resolve() # Make absolute
test_dir.mkdir(parents=True, exist_ok=True)
# Unique paths for this model (all absolute)
vllm_port = 9001 + vllm_port_offset
bridge_config_path = test_dir / "vllm_bridge_config.json"
checkpoint_dir = test_dir / "checkpoints"
log_dir = test_dir / "logs"
log_dir.mkdir(exist_ok=True)
vllm_log = log_dir / "vllm.log"
trainer_log = log_dir / "trainer.log"
# Each model gets unique ports
run_api_port = 8002 + vllm_port_offset
result = {
"model": model_config.model_id,
"model_name": model_name,
"gpu": gpu_id,
"vllm_port": vllm_port,
"run_api_port": run_api_port,
"test_dir": str(test_dir),
"status": "pending",
"error": None,
"start_time": None,
"end_time": None,
"duration_seconds": None,
"real_time_alignment": None,
"final_gpu_memory": None,
}
print(f"\n{'='*60}")
print(f"[{model_name}] Starting test on GPU {gpu_id}")
print(f"[{model_name}] Model: {model_config.model_id}")
print(f"[{model_name}] vLLM port: {vllm_port}")
print(f"[{model_name}] Test dir: {test_dir}")
print(f"{'='*60}\n")
result["start_time"] = datetime.now().isoformat()
start_time = time.time()
env_process = None
run_api_process = None
trainer_process = None
# Get atropos root directory (used for vLLM and gsm8k scripts)
script_dir = Path(__file__).parent
atropos_root = script_dir.parent.resolve()
try:
# === Start run-api (if auto_env) ===
if auto_env:
run_api_log = log_dir / "run_api.log"
print(f"[{model_name}] Starting run-api on port {run_api_port}...")
run_api_process = start_run_api(run_api_port, run_api_log)
if not wait_for_run_api(run_api_port, timeout=60):
# Check if process died
if run_api_process.poll() is not None:
print(f"[{model_name}] run-api process exited with code {run_api_process.returncode}")
# Print log contents for debugging
if run_api_log.exists():
print(f"[{model_name}] run-api log contents:")
print(run_api_log.read_text()[-2000:]) # Last 2000 chars
raise RuntimeError(f"run-api failed to start on port {run_api_port}")
print(f"[{model_name}] ✓ run-api ready on port {run_api_port}")
# Update atropos_url to use this model's run-api
atropos_url = f"http://localhost:{run_api_port}"
# === Start gsm8k Environment (if auto_env) ===
if auto_env:
env_log = log_dir / "env.log"
print(f"[{model_name}] Starting gsm8k environment (tokenizer: {model_config.model_id})...")
env_process = start_gsm8k_env(
model_config.model_id, vllm_port, run_api_port, env_log, atropos_root
)
time.sleep(10) # Give it time to initialize and connect
print(f"[{model_name}] ✓ gsm8k environment started")
# === Start Unified vLLM + Trainer (run.py) ===
# Using run.py ensures vLLM is a CHILD of the trainer process,
# which is required for CUDA IPC with ptrace_scope=1
run_script = script_dir / "run.py"
run_env = os.environ.copy()
run_env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
run_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
run_cmd = [
sys.executable, "-u", str(run_script),
"--model", model_config.model_id,
"--vllm-port", str(vllm_port),
"--gpu-memory-utilization", str(model_config.gpu_memory_utilization),
"--max-model-len", str(model_config.max_model_len),
"--dtype", model_config.dtype,
"--atropos-url", atropos_url,
"--training-steps", str(training_steps),
"--optimizer", "adamw_8bit",
"--save-path", str(checkpoint_dir),
"--checkpoint-interval", "5",
"--log-dir", str(log_dir),
]
print(f"[{model_name}] Starting unified trainer (vLLM + GRPO) for {training_steps} steps...")
with open(trainer_log, "w") as tlog:
trainer_process = subprocess.Popen(
run_cmd,
env=run_env,
stdout=tlog,
stderr=subprocess.STDOUT,
cwd=str(atropos_root), # Run from atropos root
)
trainer_process.wait()
if trainer_process.returncode != 0:
raise RuntimeError(f"Unified trainer exited with code {trainer_process.returncode}")
result["status"] = "success"
print(f"[{model_name}] ✓ Training completed successfully!")
# Parse trainer log for metrics
try:
with open(trainer_log, "r") as f:
log_content = f.read()
# Extract real-time alignment
if "Mean diff:" in log_content:
import re
match = re.search(r"Mean diff: ([\d.]+)", log_content)
if match:
result["real_time_alignment"] = float(match.group(1))
# Extract final GPU memory
if "GPU mem:" in log_content:
matches = re.findall(r"GPU mem: ([\d.]+)GB", log_content)
if matches:
result["final_gpu_memory"] = float(matches[-1])
except Exception as e:
print(f"[{model_name}] Warning: Could not parse log: {e}")
except Exception as e:
result["status"] = "failed"
result["error"] = str(e)
print(f"[{model_name}] ✗ Test failed: {e}")
import traceback
traceback.print_exc()
finally:
# Note: vLLM is managed by run.py and cleaned up automatically
# Cleanup gsm8k environment
if env_process and env_process.poll() is None:
print(f"[{model_name}] Terminating gsm8k environment...")
env_process.terminate()
try:
env_process.wait(timeout=10)
except subprocess.TimeoutExpired:
env_process.kill()
# Cleanup run-api
if run_api_process and run_api_process.poll() is None:
print(f"[{model_name}] Terminating run-api...")
run_api_process.terminate()
try:
run_api_process.wait(timeout=10)
except subprocess.TimeoutExpired:
run_api_process.kill()
result["end_time"] = datetime.now().isoformat()
result["duration_seconds"] = time.time() - start_time
return result
def run_parallel_tests(
models: List[ModelConfig],
gpu_ids: List[int],
atropos_url: str,
atropos_port: int,
base_dir: str,
training_steps: int,
auto_env: bool = False,
) -> List[Dict]:
"""Run tests for multiple models in parallel."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results = []
threads = []
result_lock = threading.Lock()
def run_and_store(model, gpu, port_offset):
result = run_model_test(
model, gpu, atropos_url, atropos_port, base_dir, timestamp,
training_steps, port_offset, auto_env
)
with result_lock:
results.append(result)
# Start threads
for i, (model, gpu) in enumerate(zip(models, gpu_ids)):
t = threading.Thread(target=run_and_store, args=(model, gpu, i))
t.start()
threads.append(t)
time.sleep(5) # Stagger starts slightly
# Wait for all to complete
for t in threads:
t.join()
return results
def run_sequential_tests(
models: List[ModelConfig],
gpu_id: int,
atropos_url: str,
atropos_port: int,
base_dir: str,
training_steps: int,
auto_env: bool = False,
) -> List[Dict]:
"""Run tests for multiple models sequentially on one GPU."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results = []
for i, model in enumerate(models):
result = run_model_test(
model, gpu_id, atropos_url, atropos_port, base_dir, timestamp,
training_steps, port_offset=0, auto_env=auto_env
)
results.append(result)
# Give GPU time to fully release memory
time.sleep(10)
return results
def print_summary(results: List[Dict]):
"""Print summary of test results."""
print("\n" + "="*80)
print("TEST SUMMARY")
print("="*80)
for r in results:
status_icon = "" if r["status"] == "success" else ""
duration = f"{r['duration_seconds']:.1f}s" if r['duration_seconds'] else "N/A"
alignment = f"{r['real_time_alignment']:.4f}" if r['real_time_alignment'] else "N/A"
memory = f"{r['final_gpu_memory']:.1f}GB" if r['final_gpu_memory'] else "N/A"
print(f"\n{status_icon} {r['model_name']}")
print(f" Model: {r['model']}")
print(f" GPU: {r['gpu']}, vLLM port: {r['vllm_port']}, run-api port: {r.get('run_api_port', 'N/A')}")
print(f" Status: {r['status']}")
print(f" Duration: {duration}")
print(f" Real-time alignment: {alignment}")
print(f" GPU memory: {memory}")
if r["error"]:
print(f" Error: {r['error']}")
print(f" Logs: {r['test_dir']}/logs/")
# Summary stats
successes = sum(1 for r in results if r["status"] == "success")
failures = len(results) - successes
print(f"\n{'='*80}")
print(f"TOTAL: {successes} passed, {failures} failed")
print("="*80)
def main():
parser = argparse.ArgumentParser(
description="Multi-model test suite for shared_vllm trainer",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Run all models in parallel (one per GPU)
python -m example_trainer.test_multi_model --parallel
# Run specific models
python -m example_trainer.test_multi_model --models hermes-8b qwen3-4b --parallel
# Run sequentially on GPU 0
python -m example_trainer.test_multi_model --sequential --gpu 0
Available models: """ + ", ".join(TEST_MODELS.keys())
)
parser.add_argument(
"--models",
nargs="+",
choices=list(TEST_MODELS.keys()),
default=["qwen3-4b", "hermes-8b"],
help="Models to test",
)
parser.add_argument(
"--parallel",
action="store_true",
help="Run models in parallel on different GPUs",
)
parser.add_argument(
"--sequential",
action="store_true",
help="Run models sequentially on one GPU",
)
parser.add_argument(
"--gpus",
type=int,
nargs="+",
default=None,
help="GPU IDs to use (for parallel mode)",
)
parser.add_argument(
"--gpu",
type=int,
default=0,
help="GPU ID (for sequential mode)",
)
parser.add_argument(
"--atropos-url",
type=str,
default="http://localhost:8002",
help="Atropos API URL",
)
parser.add_argument(
"--atropos-port",
type=int,
default=8002,
help="Atropos API port (for spawning multiple if needed)",
)
parser.add_argument(
"--training-steps",
type=int,
default=10,
help="Number of training steps per model",
)
parser.add_argument(
"--output-dir",
type=str,
default="./multi_model_tests",
help="Base directory for test outputs",
)
parser.add_argument(
"--auto-env",
action="store_true",
help="Automatically start gsm8k environment for each model (requires run-api to be running)",
)
args = parser.parse_args()
if not args.parallel and not args.sequential:
args.sequential = True # Default to sequential
# Get model configs
models = [TEST_MODELS[name] for name in args.models]
print(f"\n{'#'*60}")
print("# MULTI-MODEL SHARED_VLLM TRAINER TEST SUITE")
print(f"{'#'*60}")
print(f"\nModels to test: {[m.name for m in models]}")
print(f"Mode: {'Parallel' if args.parallel else 'Sequential'}")
print(f"Training steps per model: {args.training_steps}")
print(f"Output directory: {args.output_dir}")
print(f"Atropos URL: {args.atropos_url}")
# Run tests
if args.auto_env:
print(f"Auto-env: Will start gsm8k environment per model")
if args.parallel:
gpus = args.gpus or list(range(len(models)))
if len(gpus) < len(models):
print(f"\nWarning: Not enough GPUs ({len(gpus)}) for models ({len(models)})")
print("Some models will share GPUs")
gpus = gpus * (len(models) // len(gpus) + 1)
print(f"Using GPUs: {gpus[:len(models)]}")
results = run_parallel_tests(
models, gpus[:len(models)],
args.atropos_url, args.atropos_port,
args.output_dir, args.training_steps,
auto_env=args.auto_env
)
else:
print(f"Using GPU: {args.gpu}")
results = run_sequential_tests(
models, args.gpu,
args.atropos_url, args.atropos_port,
args.output_dir, args.training_steps,
auto_env=args.auto_env
)
# Print summary
print_summary(results)
# Save results to JSON
results_file = Path(args.output_dir) / f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
results_file.parent.mkdir(parents=True, exist_ok=True)
with open(results_file, "w") as f:
json.dump(results, f, indent=2)
print(f"\nResults saved to: {results_file}")
# Exit with error code if any failed
if any(r["status"] != "success" for r in results):
sys.exit(1)
if __name__ == "__main__":
main()