mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
manual testing
This commit is contained in:
parent
da046d3d3b
commit
c1bb4f33f0
5 changed files with 329 additions and 766 deletions
|
|
@ -70,6 +70,38 @@ def parse_args() -> argparse.Namespace:
|
|||
"'adamw_cpu' (CPU offload, ~0GB GPU, slower), "
|
||||
"'adafactor' (no momentum, ~8GB GPU)",
|
||||
)
|
||||
|
||||
# === GRPO/PPO Hyperparameters ===
|
||||
parser.add_argument(
|
||||
"--kl-coef",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help=(
|
||||
"KL divergence penalty coefficient (beta). "
|
||||
"Controls policy deviation from reference. "
|
||||
"Higher = more conservative, prevents reward hacking. "
|
||||
"0 = disabled (not recommended)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip-eps",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help=(
|
||||
"PPO-style clipping epsilon. "
|
||||
"Clips importance ratio to [1-eps, 1+eps]. "
|
||||
"Prevents destabilizing large policy updates."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-reference-logprobs",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Disable use of inference logprobs as reference policy. "
|
||||
"Falls back to REINFORCE-style updates (not recommended)."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
|
|
@ -265,6 +297,11 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig:
|
|||
device=args.device,
|
||||
save_path=args.save_path,
|
||||
checkpoint_interval=getattr(args, "checkpoint_interval", 3),
|
||||
# GRPO/PPO hyperparameters
|
||||
kl_coef=getattr(args, "kl_coef", 0.1),
|
||||
clip_eps=getattr(args, "clip_eps", 0.2),
|
||||
use_reference_logprobs=not getattr(args, "no_reference_logprobs", False),
|
||||
# vLLM settings
|
||||
vllm_restart_interval=args.vllm_restart_interval,
|
||||
vllm_port=args.vllm_port,
|
||||
vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization,
|
||||
|
|
|
|||
|
|
@ -40,6 +40,33 @@ class TrainingConfig(BaseModel):
|
|||
"'adafactor' (no momentum, ~8GB GPU)"
|
||||
)
|
||||
|
||||
# === GRPO/PPO Hyperparameters ===
|
||||
kl_coef: float = Field(
|
||||
0.1,
|
||||
description=(
|
||||
"KL divergence penalty coefficient (beta). "
|
||||
"Controls how much the policy can deviate from the reference (inference-time) policy. "
|
||||
"Higher values = more conservative updates, prevents reward hacking. "
|
||||
"Set to 0 to disable KL penalty (not recommended)."
|
||||
),
|
||||
)
|
||||
clip_eps: float = Field(
|
||||
0.2,
|
||||
description=(
|
||||
"PPO-style clipping epsilon. "
|
||||
"Clips the importance sampling ratio to [1-eps, 1+eps]. "
|
||||
"Prevents large policy updates that could destabilize training."
|
||||
),
|
||||
)
|
||||
use_reference_logprobs: bool = Field(
|
||||
True,
|
||||
description=(
|
||||
"Whether to use inference logprobs as the reference policy (π_old). "
|
||||
"When True, implements proper GRPO with importance sampling. "
|
||||
"When False, falls back to REINFORCE-style updates (not recommended)."
|
||||
),
|
||||
)
|
||||
|
||||
# === Device & Storage ===
|
||||
device: str = Field(
|
||||
"cuda" if torch.cuda.is_available() else "cpu",
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ Data processing utilities for GRPO trainer.
|
|||
Handles data retrieval from Atropos API, padding, batching,
|
||||
and advantage normalization.
|
||||
|
||||
Also extracts inference logprobs for alignment validation with training logprobs.
|
||||
Also extracts inference logprobs for proper GRPO loss computation:
|
||||
- Inference logprobs serve as π_old (reference policy) for importance sampling
|
||||
- They are batched and padded to align token-by-token with training labels
|
||||
"""
|
||||
|
||||
import json
|
||||
|
|
@ -23,11 +25,11 @@ def pad_data_to_good_offset(
|
|||
batch_size: int,
|
||||
extract_inference_logprobs: bool = True,
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
Optional[List[np.ndarray]],
|
||||
List[torch.Tensor], # token_batches
|
||||
List[torch.Tensor], # label_batches
|
||||
List[torch.Tensor], # advantage_batches
|
||||
List[torch.Tensor], # temperature_batches
|
||||
Optional[List[torch.Tensor]], # inference_logprob_batches (aligned with labels)
|
||||
]:
|
||||
"""
|
||||
Pad and batch data from the Atropos API.
|
||||
|
|
@ -36,7 +38,7 @@ def pad_data_to_good_offset(
|
|||
- Pads token sequences to nearest multiple of 64
|
||||
- Normalizes advantage scores
|
||||
- Extracts temperature values
|
||||
- Optionally extracts inference logprobs for alignment validation
|
||||
- Extracts and pads inference logprobs for proper GRPO loss computation
|
||||
|
||||
Args:
|
||||
data: Raw batch data from Atropos API
|
||||
|
|
@ -44,8 +46,12 @@ def pad_data_to_good_offset(
|
|||
extract_inference_logprobs: Whether to extract inference logprobs
|
||||
|
||||
Returns:
|
||||
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs)
|
||||
inference_logprobs is None if extract_inference_logprobs=False or no logprobs in data
|
||||
Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
|
||||
inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data
|
||||
|
||||
Note:
|
||||
inference_logprob_batches are padded with 0.0 at positions where labels == -100.
|
||||
This allows token-by-token alignment during GRPO loss computation.
|
||||
"""
|
||||
max_token_len = max(
|
||||
[max([len(x) for x in item["tokens"]]) for item in data["batch"]]
|
||||
|
|
@ -66,7 +72,8 @@ def pad_data_to_good_offset(
|
|||
advantages = []
|
||||
lengths = []
|
||||
temperatures = []
|
||||
inference_logprobs_list: List[np.ndarray] = []
|
||||
inference_logprobs_padded: List[np.ndarray] = [] # Padded to match labels shape
|
||||
has_any_logprobs = False
|
||||
|
||||
for item in data["batch"]:
|
||||
# Normalize advantage scores
|
||||
|
|
@ -84,15 +91,16 @@ def pad_data_to_good_offset(
|
|||
|
||||
# Process each sample in the item
|
||||
for i in range(len(item["tokens"])):
|
||||
seq_len = len(item["tokens"][i])
|
||||
lengths.append(
|
||||
math.ceil((len(item["tokens"][i]) - 1) / good_multiple) * good_multiple
|
||||
math.ceil((seq_len - 1) / good_multiple) * good_multiple
|
||||
)
|
||||
|
||||
# Create labels with padding
|
||||
# Create labels with padding (-100 for masked positions)
|
||||
label_item = np.concatenate([
|
||||
np.array(item["masks"][i]),
|
||||
np.full(
|
||||
max(0, token_setup_len - len(item["tokens"][i])),
|
||||
max(0, token_setup_len - seq_len),
|
||||
-100,
|
||||
dtype=np.int32,
|
||||
),
|
||||
|
|
@ -102,7 +110,7 @@ def pad_data_to_good_offset(
|
|||
item["tokens"][i] = np.concatenate([
|
||||
np.array(item["tokens"][i]),
|
||||
np.zeros(
|
||||
max(0, token_setup_len - len(item["tokens"][i])),
|
||||
max(0, token_setup_len - seq_len),
|
||||
dtype=np.int32,
|
||||
),
|
||||
])
|
||||
|
|
@ -111,13 +119,36 @@ def pad_data_to_good_offset(
|
|||
labels.append(label_item[1:]) # Shift by 1 for causal
|
||||
advantages.append(item["scores"][i])
|
||||
|
||||
# Extract inference logprobs for alignment validation
|
||||
# These come from vLLM during rollout generation
|
||||
# Extract and pad inference logprobs to match labels shape
|
||||
# Inference logprobs are ONLY for generated tokens (where labels != -100)
|
||||
# We need to create a padded array that aligns position-by-position
|
||||
if extract_inference_logprobs and "inference_logprobs" in item:
|
||||
if i < len(item["inference_logprobs"]):
|
||||
inference_logprobs_list.append(
|
||||
np.array(item["inference_logprobs"][i], dtype=np.float32)
|
||||
)
|
||||
raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32)
|
||||
has_any_logprobs = True
|
||||
|
||||
# Create padded logprobs array matching label_item shape
|
||||
# Fill with 0.0 (will be masked out during loss computation)
|
||||
padded_logprobs = np.zeros(token_setup_len, dtype=np.float32)
|
||||
|
||||
# The inference logprobs correspond to generated tokens
|
||||
# Find positions where labels != -100 (generated positions)
|
||||
mask_arr = np.array(item["masks"][i])
|
||||
generated_positions = np.where(mask_arr != -100)[0]
|
||||
|
||||
# Fill in inference logprobs at generated positions
|
||||
n_to_fill = min(len(raw_logprobs), len(generated_positions))
|
||||
if n_to_fill > 0:
|
||||
padded_logprobs[generated_positions[:n_to_fill]] = raw_logprobs[:n_to_fill]
|
||||
|
||||
# Shift by 1 to match causal label shift
|
||||
inference_logprobs_padded.append(padded_logprobs[1:])
|
||||
else:
|
||||
# No logprobs for this sample, use zeros
|
||||
inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32))
|
||||
elif extract_inference_logprobs:
|
||||
# No inference_logprobs in item, use zeros
|
||||
inference_logprobs_padded.append(np.zeros(token_setup_len - 1, dtype=np.float32))
|
||||
|
||||
# Extract temperature (priority: override > generation_params > group_overrides > 1.0)
|
||||
t = 1.0
|
||||
|
|
@ -139,6 +170,7 @@ def pad_data_to_good_offset(
|
|||
label_batches = []
|
||||
advantage_batches = []
|
||||
temperature_batches = []
|
||||
inference_logprob_batches = []
|
||||
|
||||
for i in range(len(input_ids) // batch_size):
|
||||
start = i * batch_size
|
||||
|
|
@ -158,11 +190,17 @@ def pad_data_to_good_offset(
|
|||
np.array(temperatures[start:end], dtype=np.float32)
|
||||
).view(-1, 1, 1)
|
||||
)
|
||||
|
||||
# Batch inference logprobs (same shape as labels)
|
||||
if extract_inference_logprobs and inference_logprobs_padded:
|
||||
inference_logprob_batches.append(
|
||||
torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0))
|
||||
)
|
||||
|
||||
# Return inference logprobs if available
|
||||
inference_logprobs = inference_logprobs_list if inference_logprobs_list else None
|
||||
# Return inference logprob batches if we have any real logprobs
|
||||
final_logprob_batches = inference_logprob_batches if (has_any_logprobs and inference_logprob_batches) else None
|
||||
|
||||
return token_batches, label_batches, advantage_batches, temperature_batches, inference_logprobs
|
||||
return token_batches, label_batches, advantage_batches, temperature_batches, final_logprob_batches
|
||||
|
||||
|
||||
def get_data(
|
||||
|
|
@ -171,8 +209,14 @@ def get_data(
|
|||
atropos_url: str = "http://localhost:8000",
|
||||
extract_inference_logprobs: bool = True,
|
||||
) -> Tuple[
|
||||
List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]],
|
||||
Optional[List[np.ndarray]],
|
||||
List[Tuple[
|
||||
List[torch.Tensor], # token_batches
|
||||
List[torch.Tensor], # label_batches
|
||||
List[torch.Tensor], # advantage_batches
|
||||
List[torch.Tensor], # temperature_batches
|
||||
Optional[List[torch.Tensor]], # inference_logprob_batches
|
||||
]],
|
||||
None, # Legacy return (no longer used)
|
||||
]:
|
||||
"""
|
||||
Fetch and process training data from the Atropos API.
|
||||
|
|
@ -184,15 +228,15 @@ def get_data(
|
|||
batch_size: Size of each training batch
|
||||
seq_len: Maximum sequence length (for reference, not used directly)
|
||||
atropos_url: URL of the Atropos API server
|
||||
extract_inference_logprobs: Whether to extract inference logprobs for alignment
|
||||
extract_inference_logprobs: Whether to extract inference logprobs for GRPO loss
|
||||
|
||||
Returns:
|
||||
Tuple of (batches, all_inference_logprobs)
|
||||
- batches: List of processed batch tuples
|
||||
- all_inference_logprobs: List of inference logprob arrays for alignment validation
|
||||
Tuple of (batches, None)
|
||||
- batches: List of processed batch tuples, each containing:
|
||||
(token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches)
|
||||
- inference_logprob_batches are aligned with labels for proper GRPO loss computation
|
||||
"""
|
||||
batches = []
|
||||
all_inference_logprobs: List[np.ndarray] = []
|
||||
|
||||
while True:
|
||||
data = get_batch(url=atropos_url)
|
||||
|
|
@ -202,18 +246,16 @@ def get_data(
|
|||
with open("temp.json", "w", encoding="utf-8") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
# Process and accumulate batches
|
||||
token_batches, label_batches, adv_batches, temp_batches, inf_logprobs = \
|
||||
# Process and accumulate batches (now includes batched inference logprobs)
|
||||
token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches = \
|
||||
pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
|
||||
|
||||
batches.append((token_batches, label_batches, adv_batches, temp_batches))
|
||||
|
||||
if inf_logprobs:
|
||||
all_inference_logprobs.extend(inf_logprobs)
|
||||
# Include inference logprob batches in the tuple
|
||||
batches.append((token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches))
|
||||
|
||||
elif len(batches) > 0:
|
||||
# Return accumulated batches when no more data
|
||||
return batches, all_inference_logprobs if all_inference_logprobs else None
|
||||
return batches, None
|
||||
else:
|
||||
# Wait for data
|
||||
time.sleep(1)
|
||||
|
|
|
|||
|
|
@ -1,647 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Multi-model test suite for shared_vllm trainer.
|
||||
|
||||
Tests the trainer against diverse models to verify robustness.
|
||||
Supports both parallel (different GPUs) and sequential execution.
|
||||
|
||||
With --auto-env, each model gets its own isolated stack:
|
||||
- run-api (port 8002 + offset)
|
||||
- gsm8k environment (with model-specific tokenizer)
|
||||
- vLLM server (port 9001 + offset)
|
||||
- trainer
|
||||
|
||||
Usage:
|
||||
# RECOMMENDED: Fully automated parallel test with W&B logging
|
||||
python -m example_trainer.test_multi_model \
|
||||
--models qwen3-4b hermes-8b nemotron-14b devstral-24b \
|
||||
--parallel \
|
||||
--gpus 0 1 2 3 \
|
||||
--auto-env \
|
||||
--use-wandb \
|
||||
--wandb-project multi-model-test
|
||||
|
||||
# Sequential test on one GPU
|
||||
python -m example_trainer.test_multi_model \
|
||||
--models qwen3-4b hermes-8b \
|
||||
--sequential \
|
||||
--gpu 0 \
|
||||
--auto-env \
|
||||
--use-wandb
|
||||
|
||||
# Manual mode (you must start run-api and gsm8k_server yourself)
|
||||
python -m example_trainer.test_multi_model \
|
||||
--models qwen3-4b \
|
||||
--sequential \
|
||||
--gpu 0 \
|
||||
--atropos-url http://localhost:8002
|
||||
|
||||
Port allocation with --auto-env:
|
||||
Model 0: run-api:8002, vLLM:9001, GPU from --gpus[0]
|
||||
Model 1: run-api:8003, vLLM:9002, GPU from --gpus[1]
|
||||
Model 2: run-api:8004, vLLM:9003, GPU from --gpus[2]
|
||||
...
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
import threading
|
||||
|
||||
|
||||
@dataclass
class ModelConfig:
    """Configuration for a test model."""
    # Short identifier for this entry (e.g. "qwen3-4b"); used in directory
    # names, log prefixes, and as the TEST_MODELS key.
    name: str
    # Full HuggingFace model ID passed to vLLM, the trainer, and the
    # gsm8k environment's tokenizer.
    model_id: str
    # Fraction of GPU memory vLLM may reserve (forwarded as
    # --gpu-memory-utilization).
    gpu_memory_utilization: float = 0.5
    # Maximum context length for the vLLM server (forwarded as
    # --max-model-len).
    max_model_len: int = 4096
    # Compute dtype string forwarded to the trainer (--dtype).
    dtype: str = "bfloat16"
    # Per-model default step count (callers may pass their own
    # training_steps to run_model_test instead).
    training_steps: int = 10
    # Free-form human notes (memory estimates, caveats); not used by code.
    notes: str = ""
|
||||
|
||||
|
||||
# Define test models
# NOTE: the dict keys double as the valid values for the --models CLI flag
# (main() uses choices=list(TEST_MODELS.keys())).
# Memory estimates for B200 (183GB):
# - Model weights (bf16): 2 bytes/param
# - Gradients: ~same as weights
# - 8-bit optimizer: ~1 byte/param
# - KV cache: depends on max_model_len
TEST_MODELS: Dict[str, ModelConfig] = {
    "qwen3-4b": ModelConfig(
        name="qwen3-4b",
        model_id="Qwen/Qwen3-4B-Instruct-2507",
        gpu_memory_utilization=0.4,  # ~73GB for vLLM
        max_model_len=8192,  # Plenty of room on B200
        notes="Small 4B model, good baseline test (~8GB weights)",
    ),
    "hermes-8b": ModelConfig(
        name="hermes-8b",
        model_id="NousResearch/Hermes-3-Llama-3.1-8B",
        gpu_memory_utilization=0.45,  # ~82GB for vLLM
        max_model_len=8192,  # 8K context fits well
        notes="Llama 8B architecture (~16GB weights)",
    ),
    "nemotron-14b": ModelConfig(
        name="nemotron-14b",
        model_id="nvidia/Nemotron-Cascade-14B-Thinking",
        gpu_memory_utilization=0.5,  # ~91GB for vLLM
        max_model_len=32768,  # 32K context for thinking
        notes="14B thinking model (~28GB weights), needs room for long CoT",
    ),
    "devstral-24b": ModelConfig(
        name="devstral-24b",
        model_id="mistralai/Devstral-Small-2-24B-Instruct-2512",
        gpu_memory_utilization=0.55,  # ~100GB for vLLM
        max_model_len=16384,  # 16K context (conservative for 24B)
        notes="Large 24B Mistral (~48GB weights), largest model",
    ),
}
|
||||
|
||||
|
||||
def get_test_dir(base_dir: str, model_name: str, timestamp: str) -> Path:
    """Build the per-run output directory ``<base_dir>/<model_name>_<timestamp>``."""
    run_name = f"{model_name}_{timestamp}"
    return Path(base_dir) / run_name
|
||||
|
||||
|
||||
def start_run_api(
    port: int,
    log_path: Path,
) -> subprocess.Popen:
    """Start a run-api instance on a specific port.

    Args:
        port: TCP port the run-api server should listen on.
        log_path: File that receives the child's combined stdout/stderr.

    Returns:
        The started ``subprocess.Popen`` handle; the caller is responsible
        for terminating it.
    """
    cmd = [sys.executable, "-m", "atroposlib.cli.run_api", "--port", str(port)]

    # Popen duplicates the descriptor into the child, so the parent's handle
    # can be closed right away -- previously it leaked for the whole run.
    # (The old bufsize=1 argument was dropped: it only affects pipes created
    # by Popen, and stdout here is a regular file.)
    with open(log_path, "w") as log_file:
        process = subprocess.Popen(
            cmd,
            stdout=log_file,
            stderr=subprocess.STDOUT,
        )
    return process
|
||||
|
||||
|
||||
def wait_for_run_api(port: int, timeout: int = 60) -> bool:
    """Poll a local run-api instance until it responds or ``timeout`` expires.

    Probes ``/status`` first and falls back to ``/`` (run-api exposes no
    ``/health`` endpoint), sleeping 2s between rounds.

    Args:
        port: Port the run-api server is expected on (localhost).
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True once any probed endpoint answers HTTP 200, False on timeout.
    """
    import requests

    deadline = time.time() + timeout
    while time.time() < deadline:
        # Same probe logic for both endpoints (was previously duplicated).
        for endpoint in ("/status", "/"):
            try:
                resp = requests.get(f"http://localhost:{port}{endpoint}", timeout=5)
                if resp.status_code == 200:
                    return True
            except requests.RequestException:
                # Server not up yet (refused/reset/timeout) -- keep polling.
                # A bare `except:` here previously swallowed even
                # KeyboardInterrupt/SystemExit.
                pass
        time.sleep(2)
    return False
|
||||
|
||||
|
||||
def start_gsm8k_env(
    model_id: str,
    vllm_port: int,
    run_api_port: int,
    log_path: Path,
    atropos_root: Path,
) -> subprocess.Popen:
    """Start a gsm8k environment process for a specific model.

    Args:
        model_id: HuggingFace model ID (used as tokenizer and OpenAI model name).
        vllm_port: Port of the vLLM OpenAI-compatible server to generate against.
        run_api_port: Port of the run-api rollout server to report to.
        log_path: File that receives the child's combined stdout/stderr.
        atropos_root: Repository root containing environments/gsm8k_server.py;
            also used as the child's working directory.

    Returns:
        The started ``subprocess.Popen`` handle; the caller is responsible
        for terminating it.
    """
    gsm8k_script = atropos_root / "environments" / "gsm8k_server.py"
    cmd = [
        sys.executable, "-u", str(gsm8k_script), "serve",
        "--env.rollout_server_url", f"http://localhost:{run_api_port}",
        "--env.tokenizer_name", model_id,
        "--env.use_wandb", "false",
        "--env.total_steps", "10000",
        "--env.batch_size", "64",
        "--env.group_size", "8",
        "--openai.model_name", model_id,
        "--openai.base_url", f"http://localhost:{vllm_port}/v1",
        "--openai.api_key", "x",
        "--openai.server_type", "openai",
    ]

    # Close the parent's copy of the log handle immediately; the child keeps
    # its own duplicated descriptor (previously this handle leaked).
    with open(log_path, "w") as log_file:
        process = subprocess.Popen(
            cmd,
            stdout=log_file,
            stderr=subprocess.STDOUT,
            cwd=str(atropos_root),  # Run from atropos root
        )
    return process
|
||||
|
||||
|
||||
def run_model_test(
    model_config: ModelConfig,
    gpu_id: int,
    atropos_url: str,
    atropos_port: int,
    base_dir: str,
    timestamp: str,
    training_steps: int,
    vllm_port_offset: int = 0,
    auto_env: bool = False,
    use_wandb: bool = False,
    wandb_project: str = "multi-model-test",
) -> Dict:
    """
    Run a complete training test for a single model.

    Optionally (``auto_env=True``) spins up an isolated stack first:
    a run-api instance and a gsm8k environment, each on ports offset by
    ``vllm_port_offset``. Then launches the unified vLLM + trainer via
    run.py and waits for it to finish, cleaning up child processes in all
    cases.

    Args:
        model_config: Model under test (ID, vLLM memory/context settings).
        gpu_id: CUDA device index to train on (passed as --device cuda:N).
        atropos_url: Atropos API URL; overridden when auto_env starts its
            own run-api.
        atropos_port: Unused here; kept for interface compatibility with
            callers.
        base_dir: Base directory for per-run test output.
        timestamp: Run timestamp used in the output directory name.
        training_steps: Number of trainer steps to run.
        vllm_port_offset: Offset added to the base vLLM (9001) and
            run-api (8002) ports so parallel runs don't collide.
        auto_env: Whether to start run-api + gsm8k environment ourselves.
        use_wandb: Whether to forward --use-wandb to the trainer.
        wandb_project: W&B project name forwarded when use_wandb is set.

    Returns:
        Dict with test results (status, timings, parsed metrics, error).
    """
    model_name = model_config.name
    test_dir = get_test_dir(base_dir, model_name, timestamp).resolve()  # Make absolute
    test_dir.mkdir(parents=True, exist_ok=True)

    # Unique paths for this model (all absolute)
    vllm_port = 9001 + vllm_port_offset
    checkpoint_dir = test_dir / "checkpoints"
    log_dir = test_dir / "logs"
    log_dir.mkdir(exist_ok=True)

    trainer_log = log_dir / "trainer.log"

    # Each model gets unique ports
    run_api_port = 8002 + vllm_port_offset

    result = {
        "model": model_config.model_id,
        "model_name": model_name,
        "gpu": gpu_id,
        "vllm_port": vllm_port,
        "run_api_port": run_api_port,
        "test_dir": str(test_dir),
        "status": "pending",
        "error": None,
        "start_time": None,
        "end_time": None,
        "duration_seconds": None,
        "real_time_alignment": None,
        "final_gpu_memory": None,
    }

    print(f"\n{'='*60}")
    print(f"[{model_name}] Starting test on GPU {gpu_id}")
    print(f"[{model_name}] Model: {model_config.model_id}")
    print(f"[{model_name}] vLLM port: {vllm_port}")
    print(f"[{model_name}] Test dir: {test_dir}")
    print(f"{'='*60}\n")

    result["start_time"] = datetime.now().isoformat()
    start_time = time.time()

    env_process = None
    run_api_process = None
    trainer_process = None

    # Get atropos root directory (used for vLLM and gsm8k scripts)
    script_dir = Path(__file__).parent
    atropos_root = script_dir.parent.resolve()

    try:
        # === Start run-api (if auto_env) ===
        if auto_env:
            run_api_log = log_dir / "run_api.log"
            print(f"[{model_name}] Starting run-api on port {run_api_port}...")
            run_api_process = start_run_api(run_api_port, run_api_log)

            if not wait_for_run_api(run_api_port, timeout=60):
                # Check if process died
                if run_api_process.poll() is not None:
                    print(f"[{model_name}] run-api process exited with code {run_api_process.returncode}")
                # Print log contents for debugging
                if run_api_log.exists():
                    print(f"[{model_name}] run-api log contents:")
                    print(run_api_log.read_text()[-2000:])  # Last 2000 chars
                raise RuntimeError(f"run-api failed to start on port {run_api_port}")
            print(f"[{model_name}] ✓ run-api ready on port {run_api_port}")

            # Update atropos_url to use this model's run-api
            atropos_url = f"http://localhost:{run_api_port}"

        # === Start gsm8k Environment (if auto_env) ===
        if auto_env:
            env_log = log_dir / "env.log"
            print(f"[{model_name}] Starting gsm8k environment (tokenizer: {model_config.model_id})...")
            env_process = start_gsm8k_env(
                model_config.model_id, vllm_port, run_api_port, env_log, atropos_root
            )
            time.sleep(10)  # Give it time to initialize and connect
            print(f"[{model_name}] ✓ gsm8k environment started")

        # === Start Unified vLLM + Trainer (run.py) ===
        # Using run.py ensures vLLM is a CHILD of the trainer process,
        # which is required for CUDA IPC with ptrace_scope=1
        run_script = script_dir / "run.py"

        # Don't use CUDA_VISIBLE_DEVICES - use --device instead
        # run.py sets CUDA_VISIBLE_DEVICES internally based on --device
        run_env = os.environ.copy()
        run_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

        run_cmd = [
            sys.executable, "-u", str(run_script),
            "--model", model_config.model_id,
            "--device", f"cuda:{gpu_id}",  # This controls GPU selection
            "--vllm-port", str(vllm_port),
            "--gpu-memory-utilization", str(model_config.gpu_memory_utilization),
            "--max-model-len", str(model_config.max_model_len),
            "--dtype", model_config.dtype,
            "--atropos-url", atropos_url,
            "--training-steps", str(training_steps),
            "--optimizer", "adamw_8bit",
            "--save-path", str(checkpoint_dir),
            "--checkpoint-interval", "5",
            "--log-dir", str(log_dir),
        ]

        # Add wandb flags if enabled
        if use_wandb:
            run_cmd.extend(["--use-wandb", "--wandb-project", wandb_project])

        print(f"[{model_name}] Starting unified trainer (vLLM + GRPO) for {training_steps} steps...")
        with open(trainer_log, "w") as tlog:
            trainer_process = subprocess.Popen(
                run_cmd,
                env=run_env,
                stdout=tlog,
                stderr=subprocess.STDOUT,
                cwd=str(atropos_root),  # Run from atropos root
            )
            trainer_process.wait()

        if trainer_process.returncode != 0:
            raise RuntimeError(f"Unified trainer exited with code {trainer_process.returncode}")

        result["status"] = "success"
        print(f"[{model_name}] ✓ Training completed successfully!")

        # Parse trainer log for metrics
        try:
            # BUG FIX: `import re` used to live inside the "Mean diff:" branch
            # but re.findall was also used in the "GPU mem:" branch, raising
            # NameError whenever only the latter matched. Import it up front.
            import re

            with open(trainer_log, "r") as f:
                log_content = f.read()

            # Extract real-time alignment
            if "Mean diff:" in log_content:
                match = re.search(r"Mean diff: ([\d.]+)", log_content)
                if match:
                    result["real_time_alignment"] = float(match.group(1))

            # Extract final GPU memory
            if "GPU mem:" in log_content:
                matches = re.findall(r"GPU mem: ([\d.]+)GB", log_content)
                if matches:
                    result["final_gpu_memory"] = float(matches[-1])
        except Exception as e:
            print(f"[{model_name}] Warning: Could not parse log: {e}")

    except Exception as e:
        result["status"] = "failed"
        result["error"] = str(e)
        print(f"[{model_name}] ✗ Test failed: {e}")
        import traceback
        traceback.print_exc()

    finally:
        # Note: vLLM is managed by run.py and cleaned up automatically

        # Cleanup gsm8k environment
        if env_process and env_process.poll() is None:
            print(f"[{model_name}] Terminating gsm8k environment...")
            env_process.terminate()
            try:
                env_process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                env_process.kill()

        # Cleanup run-api
        if run_api_process and run_api_process.poll() is None:
            print(f"[{model_name}] Terminating run-api...")
            run_api_process.terminate()
            try:
                run_api_process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                run_api_process.kill()

        result["end_time"] = datetime.now().isoformat()
        result["duration_seconds"] = time.time() - start_time

    return result
|
||||
|
||||
|
||||
def run_parallel_tests(
    models: List[ModelConfig],
    gpu_ids: List[int],
    atropos_url: str,
    atropos_port: int,
    base_dir: str,
    training_steps: int,
    auto_env: bool = False,
    use_wandb: bool = False,
    wandb_project: str = "multi-model-test",
) -> List[Dict]:
    """Run one test thread per (model, GPU) pair and collect all results.

    Each model gets a port offset equal to its position so parallel stacks
    don't collide; thread starts are staggered by 5s.
    """
    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    collected: List[Dict] = []
    collected_lock = threading.Lock()

    def worker(cfg, gpu, offset):
        # One full model test per thread; the shared results list is
        # appended to behind the lock.
        outcome = run_model_test(
            cfg, gpu, atropos_url, atropos_port, base_dir, run_stamp,
            training_steps, offset, auto_env, use_wandb, wandb_project
        )
        with collected_lock:
            collected.append(outcome)

    workers = []
    for offset, (cfg, gpu) in enumerate(zip(models, gpu_ids)):
        thread = threading.Thread(target=worker, args=(cfg, gpu, offset))
        thread.start()
        workers.append(thread)
        time.sleep(5)  # Stagger starts slightly

    for thread in workers:
        thread.join()

    return collected
|
||||
|
||||
|
||||
def run_sequential_tests(
    models: List[ModelConfig],
    gpu_id: int,
    atropos_url: str,
    atropos_port: int,
    base_dir: str,
    training_steps: int,
    auto_env: bool = False,
    use_wandb: bool = False,
    wandb_project: str = "multi-model-test",
) -> List[Dict]:
    """Run tests for multiple models sequentially on one GPU.

    All models share port_offset=0 since only one stack is alive at a time.

    Args:
        models: Models to test, in order.
        gpu_id: CUDA device index used for every model.
        atropos_url / atropos_port: Atropos API location (forwarded).
        base_dir: Base directory for per-run test output.
        training_steps: Trainer steps per model.
        auto_env / use_wandb / wandb_project: Forwarded to run_model_test.

    Returns:
        List of result dicts, one per model, in input order.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results = []

    for idx, model in enumerate(models):
        result = run_model_test(
            model, gpu_id, atropos_url, atropos_port, base_dir, timestamp,
            training_steps, port_offset=0, auto_env=auto_env,
            use_wandb=use_wandb, wandb_project=wandb_project
        )
        results.append(result)

        # Give the GPU time to fully release memory before the next model.
        # FIX: previously this also slept 10s after the FINAL model, delaying
        # the summary for no reason.
        if idx < len(models) - 1:
            time.sleep(10)

    return results
|
||||
|
||||
|
||||
def print_summary(results: List[Dict]):
    """Print a human-readable summary of all test results.

    Args:
        results: Result dicts as produced by ``run_model_test``.
    """
    print("\n" + "="*80)
    print("TEST SUMMARY")
    print("="*80)

    for r in results:
        status_icon = "✓" if r["status"] == "success" else "✗"
        # FIX: compare against None explicitly. 0.0 is a legitimate value for
        # duration/alignment/memory and was previously rendered as "N/A"
        # because of truthiness checks.
        duration = f"{r['duration_seconds']:.1f}s" if r['duration_seconds'] is not None else "N/A"
        alignment = f"{r['real_time_alignment']:.4f}" if r['real_time_alignment'] is not None else "N/A"
        memory = f"{r['final_gpu_memory']:.1f}GB" if r['final_gpu_memory'] is not None else "N/A"

        print(f"\n{status_icon} {r['model_name']}")
        print(f"  Model: {r['model']}")
        print(f"  GPU: {r['gpu']}, vLLM port: {r['vllm_port']}, run-api port: {r.get('run_api_port', 'N/A')}")
        print(f"  Status: {r['status']}")
        print(f"  Duration: {duration}")
        print(f"  Real-time alignment: {alignment}")
        print(f"  GPU memory: {memory}")
        if r["error"]:
            print(f"  Error: {r['error']}")
        print(f"  Logs: {r['test_dir']}/logs/")

    # Summary stats
    successes = sum(1 for r in results if r["status"] == "success")
    failures = len(results) - successes

    print(f"\n{'='*80}")
    print(f"TOTAL: {successes} passed, {failures} failed")
    print("="*80)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, run the suite, save a JSON report."""
    parser = argparse.ArgumentParser(
        description="Multi-model test suite for shared_vllm trainer",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Run all models in parallel (one per GPU)
python -m example_trainer.test_multi_model --parallel

# Run specific models
python -m example_trainer.test_multi_model --models hermes-8b qwen3-4b --parallel

# Run sequentially on GPU 0
python -m example_trainer.test_multi_model --sequential --gpu 0

Available models: """ + ", ".join(TEST_MODELS.keys())
    )

    parser.add_argument(
        "--models", nargs="+", choices=list(TEST_MODELS.keys()),
        default=["qwen3-4b", "hermes-8b"], help="Models to test",
    )
    parser.add_argument(
        "--parallel", action="store_true",
        help="Run models in parallel on different GPUs",
    )
    parser.add_argument(
        "--sequential", action="store_true",
        help="Run models sequentially on one GPU",
    )
    parser.add_argument(
        "--gpus", type=int, nargs="+", default=None,
        help="GPU IDs to use (for parallel mode)",
    )
    parser.add_argument(
        "--gpu", type=int, default=0,
        help="GPU ID (for sequential mode)",
    )
    parser.add_argument(
        "--atropos-url", type=str, default="http://localhost:8002",
        help="Atropos API URL",
    )
    parser.add_argument(
        "--atropos-port", type=int, default=8002,
        help="Atropos API port (for spawning multiple if needed)",
    )
    parser.add_argument(
        "--training-steps", type=int, default=10,
        help="Number of training steps per model",
    )
    parser.add_argument(
        "--output-dir", type=str, default="./multi_model_tests",
        help="Base directory for test outputs",
    )
    parser.add_argument(
        "--auto-env", action="store_true",
        help="Automatically start run-api and gsm8k environment for each model",
    )
    parser.add_argument(
        "--use-wandb", action="store_true",
        help="Enable Weights & Biases logging for training runs",
    )
    parser.add_argument(
        "--wandb-project", type=str, default="multi-model-test",
        help="W&B project name for logging",
    )

    opts = parser.parse_args()

    # Neither mode flag given -> fall back to sequential execution.
    if not (opts.parallel or opts.sequential):
        opts.sequential = True

    # Resolve model names to their configuration objects.
    model_cfgs = [TEST_MODELS[name] for name in opts.models]

    banner = "#" * 60
    print(f"\n{banner}")
    print("# MULTI-MODEL SHARED_VLLM TRAINER TEST SUITE")
    print(banner)
    print(f"\nModels to test: {[m.name for m in model_cfgs]}")
    print(f"Mode: {'Parallel' if opts.parallel else 'Sequential'}")
    print(f"Training steps per model: {opts.training_steps}")
    print(f"Output directory: {opts.output_dir}")
    print(f"Atropos URL: {opts.atropos_url}")

    if opts.auto_env:
        print("Auto-env: Will start gsm8k environment per model")

    if opts.parallel:
        gpu_ids = opts.gpus or list(range(len(model_cfgs)))
        if len(gpu_ids) < len(model_cfgs):
            print(f"\nWarning: Not enough GPUs ({len(gpu_ids)}) for models ({len(model_cfgs)})")
            print("Some models will share GPUs")
            # Tile the GPU list so every model gets an assignment.
            gpu_ids = gpu_ids * (len(model_cfgs) // len(gpu_ids) + 1)

        print(f"Using GPUs: {gpu_ids[:len(model_cfgs)]}")
        if opts.use_wandb:
            print(f"W&B logging enabled (project: {opts.wandb_project})")
        outcomes = run_parallel_tests(
            model_cfgs, gpu_ids[:len(model_cfgs)],
            opts.atropos_url, opts.atropos_port,
            opts.output_dir, opts.training_steps,
            auto_env=opts.auto_env,
            use_wandb=opts.use_wandb,
            wandb_project=opts.wandb_project,
        )
    else:
        print(f"Using GPU: {opts.gpu}")
        if opts.use_wandb:
            print(f"W&B logging enabled (project: {opts.wandb_project})")
        outcomes = run_sequential_tests(
            model_cfgs, opts.gpu,
            opts.atropos_url, opts.atropos_port,
            opts.output_dir, opts.training_steps,
            auto_env=opts.auto_env,
            use_wandb=opts.use_wandb,
            wandb_project=opts.wandb_project,
        )

    print_summary(outcomes)

    # Persist raw results as JSON alongside the per-model test directories.
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    report_path = Path(opts.output_dir) / f"results_{stamp}.json"
    report_path.parent.mkdir(parents=True, exist_ok=True)
    with open(report_path, "w") as fh:
        json.dump(outcomes, fh, indent=2)
    print(f"\nResults saved to: {report_path}")

    # Signal failure to the caller (CI) if any model did not pass.
    if any(r["status"] != "success" for r in outcomes):
        sys.exit(1)
||||
|
||||
# Script entry point: run the multi-model test suite when executed directly.
if __name__ == "__main__":
    main()
|
@ -153,13 +153,22 @@ def compute_grpo_loss(
|
|||
temperatures: torch.Tensor,
|
||||
gradient_accumulation_steps: int,
|
||||
inference_logprobs: Optional[torch.Tensor] = None,
|
||||
kl_coef: float = 0.1,
|
||||
clip_eps: float = 0.2,
|
||||
use_reference_logprobs: bool = True,
|
||||
) -> Tuple[torch.Tensor, dict]:
|
||||
"""
|
||||
Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch.
|
||||
|
||||
The GRPO loss encourages the model to:
|
||||
This implements proper GRPO/PPO with:
|
||||
- Importance sampling ratio: π(a|s) / π_old(a|s)
|
||||
- PPO-style clipping to prevent large updates
|
||||
- KL penalty to prevent reward hacking/policy collapse
|
||||
|
||||
The loss encourages the model to:
|
||||
- Increase probability for tokens with positive advantages
|
||||
- Decrease probability for tokens with negative advantages
|
||||
- Stay close to the reference policy (inference-time policy)
|
||||
|
||||
Args:
|
||||
model: The model to compute loss for
|
||||
|
|
@ -168,7 +177,10 @@ def compute_grpo_loss(
|
|||
advantages: Advantage values [batch, 1]
|
||||
temperatures: Temperature values [batch, 1, 1]
|
||||
gradient_accumulation_steps: Number of accumulation steps (for scaling)
|
||||
inference_logprobs: Optional logprobs from inference for alignment check
|
||||
inference_logprobs: Logprobs from inference (π_old), aligned with labels [batch, seq_len]
|
||||
kl_coef: KL penalty coefficient (beta). Higher = more conservative updates
|
||||
clip_eps: PPO clipping epsilon. Clips ratio to [1-eps, 1+eps]
|
||||
use_reference_logprobs: If True, use inference_logprobs as reference policy
|
||||
|
||||
Returns:
|
||||
Tuple of (loss tensor, metrics dict)
|
||||
|
|
@ -177,14 +189,14 @@ def compute_grpo_loss(
|
|||
outputs = model(tokens)
|
||||
logits = outputs.logits
|
||||
|
||||
# Temperature scaling
|
||||
# Temperature scaling for training
|
||||
t = temperatures.to(logits.device, logits.dtype)
|
||||
t = torch.where(t <= 0, torch.ones_like(t), t)
|
||||
logits = logits / t
|
||||
scaled_logits = logits / t
|
||||
|
||||
# Log probabilities per token
|
||||
# Log probabilities per token (current policy π)
|
||||
logp_per_token = -F.cross_entropy(
|
||||
logits.view(-1, logits.size(-1)),
|
||||
scaled_logits.view(-1, scaled_logits.size(-1)),
|
||||
labels.view(-1),
|
||||
reduction="none",
|
||||
ignore_index=-100,
|
||||
|
|
@ -192,39 +204,103 @@ def compute_grpo_loss(
|
|||
|
||||
# Masking based on labels != -100
|
||||
mask = (labels != -100).float()
|
||||
mask_sum = mask.sum(dim=-1).clamp_min(1e-8)
|
||||
|
||||
# Compute metrics (no grad needed)
|
||||
# Expand advantages to match token shape [batch, 1] -> [batch, seq_len]
|
||||
adv_expanded = advantages.expand_as(logp_per_token).to(logp_per_token.device)
|
||||
|
||||
# === GRPO/PPO Loss Computation ===
|
||||
if use_reference_logprobs and inference_logprobs is not None:
|
||||
# Move inference logprobs to correct device/dtype
|
||||
ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype)
|
||||
|
||||
# Compute importance sampling ratio: π(a|s) / π_old(a|s) = exp(log π - log π_old)
|
||||
log_ratio = logp_per_token - ref_logprobs
|
||||
ratio = torch.exp(log_ratio)
|
||||
|
||||
# PPO-style clipping
|
||||
clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
|
||||
|
||||
# Surrogate objectives
|
||||
surr1 = ratio * adv_expanded
|
||||
surr2 = clipped_ratio * adv_expanded
|
||||
|
||||
# Pessimistic bound: min for positive advantages, max for negative
|
||||
# This is equivalent to: -min(ratio * A, clipped_ratio * A) when A > 0
|
||||
# -max(ratio * A, clipped_ratio * A) when A < 0
|
||||
policy_loss_per_token = -torch.where(
|
||||
adv_expanded >= 0,
|
||||
torch.min(surr1, surr2),
|
||||
torch.max(surr1, surr2),
|
||||
)
|
||||
|
||||
# Average over tokens, then over batch
|
||||
policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean()
|
||||
|
||||
# KL penalty: encourage staying close to reference policy
|
||||
# KL(π || π_ref) ≈ log(π/π_ref) = log_ratio (when π_ref is the reference)
|
||||
# We use the approximation: KL ≈ (ratio - 1) - log(ratio)
|
||||
# But simpler: just penalize squared log-ratio which is symmetric
|
||||
if kl_coef > 0:
|
||||
# Approximate KL using (log_ratio)^2 / 2 (Taylor expansion)
|
||||
# Or just use log_ratio directly as a penalty
|
||||
kl_per_token = log_ratio.pow(2) # Squared for symmetric penalty
|
||||
kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean()
|
||||
total_loss = (policy_loss + kl_coef * kl_penalty) / gradient_accumulation_steps
|
||||
else:
|
||||
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
|
||||
total_loss = policy_loss / gradient_accumulation_steps
|
||||
|
||||
# Compute metrics for logging
|
||||
with torch.no_grad():
|
||||
# Fraction of tokens where ratio was clipped
|
||||
clipped_fraction = ((ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)).float()
|
||||
clipped_fraction = (clipped_fraction * mask).sum() / mask.sum()
|
||||
|
||||
# Mean ratio and KL for monitoring
|
||||
mean_ratio = (ratio * mask).sum() / mask.sum()
|
||||
mean_kl = (log_ratio.pow(2) * mask).sum() / mask.sum()
|
||||
|
||||
# For backward compatibility: collect training logprobs
|
||||
raw_logp_per_token = -F.cross_entropy(
|
||||
outputs.logits.view(-1, outputs.logits.size(-1)),
|
||||
labels.view(-1),
|
||||
reduction="none",
|
||||
ignore_index=-100,
|
||||
).view(labels.shape)
|
||||
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
|
||||
else:
|
||||
# Fallback: REINFORCE-style (no reference policy)
|
||||
# This is what the original code did - NOT recommended!
|
||||
print(" [WARNING] No reference logprobs - using REINFORCE (may cause reward hacking!)")
|
||||
|
||||
# Simple policy gradient: -log(π) * A
|
||||
policy_loss = ((-logp_per_token * mask * adv_expanded).sum(dim=-1) / mask_sum).mean()
|
||||
total_loss = policy_loss / gradient_accumulation_steps
|
||||
kl_penalty = torch.tensor(0.0, device=logp_per_token.device)
|
||||
|
||||
with torch.no_grad():
|
||||
clipped_fraction = torch.tensor(0.0)
|
||||
mean_ratio = torch.tensor(1.0)
|
||||
mean_kl = torch.tensor(0.0)
|
||||
raw_logp_per_token = -F.cross_entropy(
|
||||
outputs.logits.view(-1, outputs.logits.size(-1)),
|
||||
labels.view(-1),
|
||||
reduction="none",
|
||||
ignore_index=-100,
|
||||
).view(labels.shape)
|
||||
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
|
||||
|
||||
# === Compute Additional Metrics ===
|
||||
with torch.no_grad():
|
||||
pos = (advantages > 0).float()
|
||||
neg = (advantages <= 0).float()
|
||||
mask_float = mask.to(logp_per_token.dtype)
|
||||
mask_sum = mask_float.sum(dim=-1).clamp_min(1e-8)
|
||||
avg_logp = (logp_per_token * mask_float).sum(dim=-1) / mask_sum
|
||||
pos_logp = (logp_per_token * pos).mean().item()
|
||||
neg_logp = (logp_per_token * neg).mean().item()
|
||||
|
||||
# For alignment check: compute logprobs WITHOUT temperature scaling
|
||||
# This allows fair comparison with inference logprobs (which are at temp=1.0)
|
||||
raw_logp_per_token = -F.cross_entropy(
|
||||
outputs.logits.view(-1, outputs.logits.size(-1)), # Use original logits, not temp-scaled
|
||||
labels.view(-1),
|
||||
reduction="none",
|
||||
ignore_index=-100,
|
||||
).view(labels.shape)
|
||||
|
||||
# Collect raw training logprobs for masked positions (generated tokens only)
|
||||
# Keep as PyTorch tensor (supports bfloat16 natively)
|
||||
training_logprobs_flat = raw_logp_per_token[mask.bool()].detach()
|
||||
|
||||
# GRPO loss: weighted log probabilities by advantages
|
||||
grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach())
|
||||
grpo_loss = (
|
||||
((-grpo_loss_term * mask).sum(-1) / mask.sum(-1))
|
||||
* advantages.to(logp_per_token.device)
|
||||
).mean() / gradient_accumulation_steps
|
||||
|
||||
# Compute a more interpretable loss metric (advantage-weighted logprobs)
|
||||
with torch.no_grad():
|
||||
# Interpretable metric: advantage-weighted average logprob
|
||||
interpretable_loss = (avg_logp * advantages.squeeze()).mean().item()
|
||||
|
||||
metrics = {
|
||||
|
|
@ -233,11 +309,16 @@ def compute_grpo_loss(
|
|||
"avg_logp": avg_logp,
|
||||
"pos_count": pos.sum().item(),
|
||||
"neg_count": neg.sum().item(),
|
||||
"training_logprobs": training_logprobs_flat, # For alignment check
|
||||
"interpretable_loss": interpretable_loss, # More meaningful metric
|
||||
"training_logprobs": training_logprobs_flat,
|
||||
"interpretable_loss": interpretable_loss,
|
||||
# GRPO-specific metrics
|
||||
"kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty,
|
||||
"mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio,
|
||||
"mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl,
|
||||
"clipped_fraction": clipped_fraction.item() if torch.is_tensor(clipped_fraction) else clipped_fraction,
|
||||
}
|
||||
|
||||
return grpo_loss, metrics
|
||||
return total_loss, metrics
|
||||
|
||||
|
||||
def compute_logprob_alignment(
|
||||
|
|
@ -309,17 +390,16 @@ def run_training_step(
|
|||
advantage_batches: List[torch.Tensor],
|
||||
temperature_batches: List[torch.Tensor],
|
||||
config: TrainingConfig,
|
||||
inference_logprobs: Optional[List[np.ndarray]] = None,
|
||||
inference_logprob_batches: Optional[List[torch.Tensor]] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Run a single training step with gradient accumulation.
|
||||
|
||||
Performs:
|
||||
1. Forward pass through all micro-batches
|
||||
1. Forward pass through all micro-batches with proper GRPO loss
|
||||
2. Backward pass with gradient accumulation
|
||||
3. Gradient clipping
|
||||
4. Optimizer step
|
||||
5. (Optional) Logprob alignment check
|
||||
|
||||
Args:
|
||||
model: The model to train
|
||||
|
|
@ -328,8 +408,8 @@ def run_training_step(
|
|||
label_batches: List of label tensors
|
||||
advantage_batches: List of advantage tensors
|
||||
temperature_batches: List of temperature tensors
|
||||
config: Training configuration
|
||||
inference_logprobs: Optional logprobs from inference for alignment check
|
||||
config: Training configuration (includes kl_coef, clip_eps, use_reference_logprobs)
|
||||
inference_logprob_batches: Batched logprobs from inference (π_old), aligned with labels
|
||||
|
||||
Returns:
|
||||
Dict of training metrics for this step
|
||||
|
|
@ -341,16 +421,32 @@ def run_training_step(
|
|||
total_neg_logp = 0.0
|
||||
total_pos = 0.0
|
||||
total_neg = 0.0
|
||||
total_kl_penalty = 0.0
|
||||
total_mean_ratio = 0.0
|
||||
total_mean_kl = 0.0
|
||||
total_clipped_fraction = 0.0
|
||||
grad_norm = 0.0
|
||||
all_training_logprobs: List[torch.Tensor] = []
|
||||
|
||||
# Get GRPO hyperparameters from config
|
||||
kl_coef = getattr(config, 'kl_coef', 0.1)
|
||||
clip_eps = getattr(config, 'clip_eps', 0.2)
|
||||
use_reference_logprobs = getattr(config, 'use_reference_logprobs', True)
|
||||
|
||||
# Accumulate gradients over micro-batches
|
||||
for tokens, labels, advantages, temperatures in zip(
|
||||
num_batches = len(token_batches) if token_batches else 1
|
||||
|
||||
for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(zip(
|
||||
token_batches, label_batches, advantage_batches, temperature_batches
|
||||
):
|
||||
)):
|
||||
tokens = tokens.to(config.device)
|
||||
labels = labels.to(config.device)
|
||||
advantages = advantages.to(config.device)
|
||||
|
||||
# Get corresponding inference logprobs batch if available
|
||||
inf_logprobs = None
|
||||
if inference_logprob_batches is not None and batch_idx < len(inference_logprob_batches):
|
||||
inf_logprobs = inference_logprob_batches[batch_idx]
|
||||
|
||||
loss, metrics = compute_grpo_loss(
|
||||
model,
|
||||
|
|
@ -359,6 +455,10 @@ def run_training_step(
|
|||
advantages,
|
||||
temperatures,
|
||||
config.gradient_accumulation_steps,
|
||||
inference_logprobs=inf_logprobs,
|
||||
kl_coef=kl_coef,
|
||||
clip_eps=clip_eps,
|
||||
use_reference_logprobs=use_reference_logprobs,
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
|
|
@ -368,7 +468,13 @@ def run_training_step(
|
|||
total_pos += metrics["pos_count"]
|
||||
total_neg += metrics["neg_count"]
|
||||
|
||||
# Collect training logprobs for alignment check
|
||||
# Accumulate GRPO-specific metrics
|
||||
total_kl_penalty += metrics.get("kl_penalty", 0.0)
|
||||
total_mean_ratio += metrics.get("mean_ratio", 1.0)
|
||||
total_mean_kl += metrics.get("mean_kl", 0.0)
|
||||
total_clipped_fraction += metrics.get("clipped_fraction", 0.0)
|
||||
|
||||
# Collect training logprobs for alignment monitoring
|
||||
if "training_logprobs" in metrics:
|
||||
all_training_logprobs.append(metrics["training_logprobs"])
|
||||
|
||||
|
|
@ -380,8 +486,7 @@ def run_training_step(
|
|||
# Help prevent memory fragmentation
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Normalize metrics by count
|
||||
num_batches = len(token_batches) if token_batches else 1
|
||||
# Normalize metrics by batch count
|
||||
if total_pos > 0:
|
||||
total_pos_logp /= num_batches
|
||||
if total_neg > 0:
|
||||
|
|
@ -394,18 +499,22 @@ def run_training_step(
|
|||
"neg_logp": total_neg_logp,
|
||||
"pos_count": total_pos,
|
||||
"neg_count": total_neg,
|
||||
# GRPO-specific metrics (averaged over batches)
|
||||
"kl_penalty": total_kl_penalty / num_batches,
|
||||
"mean_ratio": total_mean_ratio / num_batches,
|
||||
"mean_kl": total_mean_kl / num_batches,
|
||||
"clipped_fraction": total_clipped_fraction / num_batches,
|
||||
}
|
||||
|
||||
# Compute logprob alignment stats
|
||||
# NOTE: This comparison is approximate - inference and training logprobs
|
||||
# come from different batching, so token-by-token alignment isn't possible.
|
||||
# The real-time test at startup is the reliable alignment check.
|
||||
if inference_logprobs is not None and all_training_logprobs:
|
||||
alignment_stats = compute_logprob_alignment(
|
||||
inference_logprobs, all_training_logprobs, debug=False
|
||||
)
|
||||
_logprob_alignment_stats.update(alignment_stats)
|
||||
result["logprob_alignment"] = alignment_stats
|
||||
# Compute logprob alignment stats for monitoring
|
||||
# NOTE: Now that we use proper GRPO, this is less critical
|
||||
# but still useful for debugging weight sharing issues
|
||||
if all_training_logprobs:
|
||||
# Store training logprob stats
|
||||
train_flat = torch.cat(all_training_logprobs)
|
||||
if train_flat.numel() > 0:
|
||||
_logprob_alignment_stats["logprobs/training_mean"] = train_flat.mean().item()
|
||||
_logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item()
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -441,19 +550,27 @@ def log_metrics(
|
|||
if "gpu_memory_gb" in metrics:
|
||||
timing_str += f", GPU mem: {metrics['gpu_memory_gb']:.2f}GB"
|
||||
|
||||
# Show interpretable loss (advantage-weighted logprobs) if available
|
||||
interp_loss = metrics.get("interpretable_loss")
|
||||
if interp_loss is not None:
|
||||
print(f" AdvWeightedLogP: {interp_loss:.4f}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}")
|
||||
else:
|
||||
loss_str = (
|
||||
f"{metrics['loss']:.6f}"
|
||||
if abs(metrics["loss"]) < 0.01
|
||||
else f"{metrics['loss']:.4f}"
|
||||
)
|
||||
print(f" Loss: {loss_str}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}")
|
||||
# Primary metrics line: Loss and grad norm
|
||||
loss_str = (
|
||||
f"{metrics['loss']:.6f}"
|
||||
if abs(metrics["loss"]) < 0.01
|
||||
else f"{metrics['loss']:.4f}"
|
||||
)
|
||||
print(f" Loss: {loss_str}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}")
|
||||
|
||||
# Show GRPO-specific metrics if available
|
||||
# GRPO metrics line: KL, ratio, clipping
|
||||
kl_penalty = metrics.get("kl_penalty", 0)
|
||||
mean_ratio = metrics.get("mean_ratio", 1.0)
|
||||
mean_kl = metrics.get("mean_kl", 0)
|
||||
clipped_frac = metrics.get("clipped_fraction", 0)
|
||||
|
||||
if kl_penalty > 0 or mean_kl > 0:
|
||||
print(
|
||||
f" GRPO: KL={mean_kl:.4f}, ratio={mean_ratio:.3f}, "
|
||||
f"clipped={clipped_frac*100:.1f}%"
|
||||
)
|
||||
|
||||
# Advantage distribution
|
||||
if "pos_count" in metrics or "neg_count" in metrics:
|
||||
pos_count = metrics.get("pos_count", 0)
|
||||
neg_count = metrics.get("neg_count", 0)
|
||||
|
|
@ -463,24 +580,6 @@ def log_metrics(
|
|||
f" Advantages: +{int(pos_count)} / -{int(neg_count)}, "
|
||||
f"LogP: pos={pos_logp:.3f}, neg={neg_logp:.3f}"
|
||||
)
|
||||
|
||||
# Show logprob alignment stats (important for shared_vllm validation!)
|
||||
if "logprob_alignment" in metrics:
|
||||
alignment = metrics["logprob_alignment"]
|
||||
if "logprobs/diff" in alignment:
|
||||
diff = alignment["logprobs/diff"]
|
||||
inf_mean = alignment.get("logprobs/inference_mean", 0)
|
||||
train_mean = alignment.get("logprobs/training_mean", 0)
|
||||
|
||||
# NOTE: This comparison has a fundamental timing issue!
|
||||
# - inference_logprobs: from vLLM at generation time (possibly stale)
|
||||
# - training_logprobs: from trainer's current forward pass
|
||||
# After training starts, weights change, making comparison invalid.
|
||||
#
|
||||
# NOTE: This diff is just for monitoring, not validation!
|
||||
# The real-time test at startup is the reliable alignment check.
|
||||
# This diff will naturally drift as training progresses (expected).
|
||||
print(f" LogProb Stats: inf_mean={inf_mean:.4f}, train_mean={train_mean:.4f}")
|
||||
|
||||
if use_wandb:
|
||||
log_dict = {
|
||||
|
|
@ -488,6 +587,11 @@ def log_metrics(
|
|||
"train/grad_norm": metrics["grad_norm"],
|
||||
"train/pos_logp": metrics.get("pos_logp", 0),
|
||||
"train/neg_logp": metrics.get("neg_logp", 0),
|
||||
# GRPO-specific metrics
|
||||
"grpo/kl_penalty": kl_penalty,
|
||||
"grpo/mean_ratio": mean_ratio,
|
||||
"grpo/mean_kl": mean_kl,
|
||||
"grpo/clipped_fraction": clipped_frac,
|
||||
}
|
||||
# Add timing metrics if present
|
||||
for key in ["step_time", "sync_time", "data_fetch_time",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue