atropos/example_trainer/grpo.py
2025-04-29 12:10:10 -07:00

548 lines
21 KiB
Python

import atexit
import json
import math
import os
import random
import shutil
import string
import subprocess
import time
from typing import List, Optional, Tuple
import numpy as np
import requests
import torch
import torch.nn.functional as F
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_exponential
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer
import wandb # Added for logging
# Global handle to the running vLLM server subprocess (assigned from
# subprocess.Popen in train), or None when no server is running.
# cleanup_vllm() and train() both read and reset it.
vllm_process = None
def cleanup_vllm():
    """Stop the module-level vLLM server, if any, and clear the handle.

    Sends terminate() first and waits up to 5 seconds; if the server is
    still running after that, it is killed and reaped. Calling this when no
    server is running does nothing.
    """
    global vllm_process
    proc = vllm_process
    if not proc:
        return
    print("\nTerminating vLLM process...")
    proc.terminate()
    try:
        proc.wait(timeout=5)  # give the server a moment to shut down cleanly
    except subprocess.TimeoutExpired:
        print("vLLM process did not terminate gracefully, killing.")
        proc.kill()
        proc.wait()
        print("vLLM process killed.")
    else:
        print("vLLM process terminated.")
    vllm_process = None
# Register the cleanup function so a vLLM server started by this script is
# terminated when the interpreter exits normally (including after an uncaught
# exception). atexit handlers do not run if the process is killed by a signal.
atexit.register(cleanup_vllm)
class TrainingConfig(BaseModel):
    """
    Training details, model, etc.

    Settings for the GRPO example trainer: the model to fine-tune, optimizer
    and batching hyper-parameters, where checkpoints are written, how the
    local vLLM inference server is managed, and optional Weights & Biases
    logging.

    Integer fields the trainer uses as a divisor or modulus (``batch_size``,
    ``gradient_accumulation_steps``, ``vllm_restart_interval``) must be >= 1.
    A bad value then fails at construction time instead of raising
    ZeroDivisionError partway through training.
    """

    model_name: str = Field(..., description="Name of the base model to train")
    lr: float = Field(1e-5, description="Learning rate for the optimizer")
    training_steps: int = Field(10, description="Number of training steps")
    # Sequences per micro-batch. Registration with the API advertises
    # batch_size * gradient_accumulation_steps sequences per API batch.
    batch_size: int = Field(
        2, ge=1, description="Batch size for training (will be handled by get_data)"
    )
    # NOTE(review): sent to the API as max_token_len; the trainer itself does
    # not truncate sequences to this length.
    seq_len: int = Field(2048, description="Sequence length for training")
    # Each micro-batch loss is divided by this before backward().
    gradient_accumulation_steps: int = Field(
        32, ge=1, description="Number of gradient accumulation steps"
    )
    # The default is evaluated once, when this module is imported.
    device: str = Field(
        "cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"
    )
    save_path: str = Field(
        "trained_model_checkpoints", description="Base path to save model checkpoints"
    )
    # Checkpoint + vLLM restart cadence; the final step also checkpoints.
    vllm_restart_interval: int = Field(
        3, ge=1, description="Restart vLLM every N training steps"
    )
    vllm_port: int = Field(9001, description="Port for the vLLM server")
    # Wandb configuration
    use_wandb: bool = Field(
        False, description="Whether to use Weights & Biases for logging"
    )
    wandb_project: Optional[str] = Field(None, description="Wandb project name")
    wandb_group: Optional[str] = Field(None, description="Wandb group name")
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15))
def register_trainer(config: TrainingConfig, api_url: str = "http://localhost:8000"):
    """
    Register the trainer with the Atropos API.

    POSTs the trainer's batching/checkpoint settings to ``{api_url}/register``.
    A non-2xx response raises ``requests.HTTPError``, so the ``retry``
    decorator retries it (up to 3 attempts) instead of treating a failed
    registration as success.

    Args:
        config: training configuration to advertise.
        api_url: base URL of the Atropos API server.
    """
    response = requests.post(
        f"{api_url.rstrip('/')}/register",
        json={
            "wandb_group": config.wandb_group,
            "wandb_project": config.wandb_project,
            # Presumably the number of sequences per API batch;
            # pad_data_to_good_offset later splits it into micro-batches of
            # config.batch_size.
            "batch_size": config.batch_size * config.gradient_accumulation_steps,
            "max_token_len": config.seq_len,
            "starting_step": 0,
            "checkpoint_dir": config.save_path,
            "save_checkpoint_interval": config.training_steps,
            "num_steps": config.training_steps,
        },
        timeout=10,
    )
    response.raise_for_status()
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15))
def get_batch(api_url: str = "http://localhost:8000"):
    """
    Fetch the next batch payload from the Atropos API.

    Returns the decoded JSON body of ``GET {api_url}/batch``. get_data expects
    a ``"batch"`` key that is ``None`` when nothing is queued. Non-2xx
    responses raise ``requests.HTTPError`` so the ``retry`` decorator retries
    them rather than returning an error body as data.

    Args:
        api_url: base URL of the Atropos API server.
    """
    response = requests.get(f"{api_url.rstrip('/')}/batch", timeout=10)
    response.raise_for_status()
    return response.json()
def pad_data_to_good_offset(data, batch_size: int, good_multiple: int = 64):
    """
    Convert an Atropos batch payload into padded micro-batches of tensors.

    Every sequence is right-padded (tokens with 0, labels with -100) to
    ``L + 1`` positions. ``L`` is the smallest multiple of ``good_multiple``
    that is >= (longest sequence length - 1). The causal shift then yields
    inputs ``tokens[:-1]`` and labels ``masks[1:]``, each of length ``L``.
    Padding to a round multiple (64 by default) keeps GPU kernel shapes
    regular.

    For a group with more than one score, the scores are normalised to zero
    mean / unit std and used as advantages. A group with a single score uses
    it unchanged. Any entry whose override has ``set_advantage_to_zero`` gets
    an advantage of 0. ``data`` itself is not modified.

    Args:
        data: API payload. ``data["batch"]`` is a list of groups, each with
            ``tokens``, ``masks`` and ``scores`` lists of matching length and
            an optional ``overrides`` list (assumed schema; confirm against
            the API).
        batch_size: sequences per returned micro-batch. Trailing sequences
            that do not fill a whole micro-batch are dropped.
        good_multiple: padding granularity for the input length.

    Returns:
        ``(token_batches, label_batches, advantage_batches)``: lists of
        tensors shaped ``(batch_size, L)`` int64, ``(batch_size, L)`` int64
        and ``(batch_size, 1)`` float64 respectively.
    """
    groups = data["batch"]
    longest = max(len(seq) for group in groups for seq in group["tokens"])
    input_len = math.ceil((longest - 1) / good_multiple) * good_multiple
    padded_len = input_len + 1  # one extra position consumed by the causal shift

    input_ids, labels, advantages = [], [], []
    for group in groups:
        # np.array copies, so normalisation/overrides never write back into `data`.
        scores = np.array(group["scores"], dtype=np.float64)
        if len(scores) > 1:
            scores = scores - scores.mean()
            scores = scores / max(scores.std(), 1e-8)
        overrides = group.get("overrides")
        if overrides is not None:
            for i, override in enumerate(overrides):
                if override.get("set_advantage_to_zero", False):
                    scores[i] = 0.0
        for i, seq in enumerate(group["tokens"]):
            pad = max(0, padded_len - len(seq))
            # Explicit int64: cross_entropy class-index targets and embedding
            # lookups want long indices, and np.array() of Python ints is
            # int32 on some platforms (and float64 for an empty list).
            token_arr = np.concatenate(
                [np.asarray(seq, dtype=np.int64), np.zeros(pad, dtype=np.int64)]
            )
            label_arr = np.concatenate(
                [
                    np.asarray(group["masks"][i], dtype=np.int64),
                    np.full(pad, -100, dtype=np.int64),
                ]
            )
            input_ids.append(token_arr[:-1])
            labels.append(label_arr[1:])
            advantages.append(scores[i])

    token_batches, label_batches, advantage_batches = [], [], []
    for b in range(len(input_ids) // batch_size):
        window = slice(b * batch_size, (b + 1) * batch_size)
        token_batches.append(torch.from_numpy(np.stack(input_ids[window], axis=0)))
        label_batches.append(torch.from_numpy(np.stack(labels[window], axis=0)))
        advantage_batches.append(
            torch.from_numpy(np.asarray(advantages[window], dtype=np.float64)).view(-1, 1)
        )
    return token_batches, label_batches, advantage_batches
def get_data(
    batch_size: int, seq_len: int
) -> List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
    """
    Poll the Atropos API until at least one batch is available, then drain it.

    Every raw batch returned by get_batch() is written to ``temp.json``
    (overwriting the previous one) and converted with
    pad_data_to_good_offset(). Polling continues until the API reports no
    batch. If nothing has been collected yet at that point, it sleeps one
    second and polls again. ``seq_len`` is accepted but not used here.
    """
    collected = []
    while True:
        data = get_batch()
        if data["batch"] is None:
            if collected:
                return collected
            time.sleep(1)
            continue
        # Save the raw batch to disk before it is converted.
        with open("temp.json", "w", encoding="utf-8") as f:
            json.dump(data, f)
        # Inference may run ahead of training, so keep draining queued batches.
        collected.append(pad_data_to_good_offset(data, batch_size))
def _setup_wandb(config: TrainingConfig) -> None:
    """Initialise wandb from ``config``; on any problem, report it and set ``config.use_wandb = False``."""
    if not config.use_wandb:
        return
    if not config.wandb_project:
        print("Warning: wandb_project not set, disabling wandb.")
        config.use_wandb = False
        return
    if not config.wandb_group:
        # Set group to random 8 character string
        config.wandb_group = "".join(
            random.choices(string.ascii_letters + string.digits, k=8)
        )
    try:
        wandb.init(
            project=config.wandb_project,
            group=config.wandb_group,
            config=config.dict(),  # Log config parameters
        )
        print(
            f"Wandb logging enabled. Run: {wandb.run.name} (Project: {config.wandb_project}) "
        )
    except Exception as e:  # wandb is optional: any failure just disables logging
        print(f"Error initializing wandb: {e}. Disabling wandb.")
        config.use_wandb = False


def _save_checkpoint(model, tokenizer, path: str) -> None:
    """Write ``model`` and ``tokenizer`` to ``path``, replacing any existing directory there."""
    if os.path.exists(path):
        shutil.rmtree(path)  # start from a clean directory
    os.makedirs(path, exist_ok=True)
    model.save_pretrained(path)
    tokenizer.save_pretrained(path)


def _stop_vllm() -> None:
    """Terminate the module-level vLLM server (if running), killing it after 5 s, and clear the handle."""
    global vllm_process
    if vllm_process is None:
        return
    print(" Terminating existing vLLM process...")
    vllm_process.terminate()
    try:
        vllm_process.wait(timeout=5)
    except subprocess.TimeoutExpired:
        print(" Existing vLLM process did not terminate gracefully, killing.")
        vllm_process.kill()
        vllm_process.wait()
    vllm_process = None


def _disable_vllm_restarts(config: TrainingConfig) -> None:
    """Push the restart interval past the run length so interval-based restarts stop."""
    print(" Disabling further vLLM restarts.")
    config.vllm_restart_interval = config.training_steps + 1


def _launch_vllm(
    config: TrainingConfig, model_path: str, served_model_name: Optional[str] = None
) -> None:
    """Start a vLLM OpenAI-compatible server for ``model_path`` on ``config.vllm_port``.

    On success the module-level ``vllm_process`` holds the new process. If
    the executable cannot be started, the error is reported and interval
    restarts are disabled. If the server exits with a non-zero code within
    2 seconds, a warning is printed and ``vllm_process`` is cleared.
    """
    global vllm_process
    vllm_command = [
        "python",
        "-m",
        "vllm.entrypoints.openai.api_server",
        "--model",
        model_path,
        "--port",
        str(config.vllm_port),
        "--dtype",
        "auto",
        "--gpu-memory-utilization",
        "0.45",
        "--disable-log-requests",
    ]
    if served_model_name is not None:
        vllm_command += ["--served-model-name", served_model_name]
    print(f" Launching vLLM server: {' '.join(vllm_command)}")
    try:
        vllm_process = subprocess.Popen(vllm_command)
    except FileNotFoundError:
        print(
            "\n *** ERROR: 'python -m vllm...' command not found. "
            "Make sure vLLM is installed and accessible. ***\n"
        )
        _disable_vllm_restarts(config)
        return
    except Exception as e:  # report any launch failure but keep training
        print(f"\n *** ERROR: Failed to launch vLLM: {e} ***\n")
        _disable_vllm_restarts(config)
        return
    print(f" vLLM server launched with PID: {vllm_process.pid}")
    # Give the server a moment to fail fast (bad args, import errors, ...).
    # stdout/stderr are not piped, so only the exit code is available here;
    # the server's own output goes to this process's terminal.
    try:
        returncode = vllm_process.wait(timeout=2)
    except subprocess.TimeoutExpired:
        print(" vLLM process started (check logs for details).")
        return
    if returncode != 0:
        print(f" Error starting vLLM: process exited with return code {returncode}")
        vllm_process = None
        print(" WARNING: Failed to start vLLM server.")


def _grpo_loss(model, tokens, labels, advantages):
    """Compute the GRPO loss for one micro-batch.

    Args:
        tokens: (B, T) input ids.
        labels: (B, T) next-token targets, -100 where no loss is taken.
        advantages: (B,) or (B, 1) per-sequence advantages.

    Returns:
        ``(loss, seq_logp)``. ``loss`` is a scalar tensor, not yet divided by
        the gradient-accumulation factor. ``seq_logp`` is a detached (B,)
        tensor holding each sequence's mean token log-prob over its
        non-ignored label positions (0 when it has none).
    """
    logits = model(tokens).logits
    # Inputs are tokens[:-1] and labels are masks[1:] (see
    # pad_data_to_good_offset), so logits[:, t] already lines up with
    # labels[:, t]. Ignored positions contribute 0.
    logp_per_token = -F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        reduction="none",
        ignore_index=-100,
    ).view(labels.shape)
    mask = (labels != -100).float()
    # Clamp so a sequence with no trainable tokens yields 0 instead of NaN.
    token_counts = mask.sum(-1).clamp(min=1.0)
    # exp(x - x.detach()) equals 1 in value but has gradient d(logp).
    ratio = torch.exp(logp_per_token - logp_per_token.detach())
    per_seq = (-ratio * mask).sum(-1) / token_counts  # (B,)
    # Flatten advantages: multiplying a (B,) loss by a (B, 1) tensor would
    # broadcast to (B, B) and mix every sequence with every advantage.
    adv = advantages.reshape(-1).to(per_seq.dtype)
    loss = (per_seq * adv).mean()
    with torch.no_grad():
        seq_logp = (logp_per_token * mask).sum(-1) / token_counts
    return loss, seq_logp.detach()


def train(config: TrainingConfig):
    """
    Set up and run GRPO training, restarting vLLM periodically, with optional wandb logging.

    Flow:
      1. Initialise wandb (optional), load the model/tokenizer and AdamW.
      2. Register with the Atropos API and serve the base model with vLLM.
      3. For each step, take a batch of micro-batches from get_data(),
         accumulate GRPO gradients over them, clip to norm 1.0, and apply
         one optimizer step.
      4. Every ``vllm_restart_interval`` steps (and on the final step),
         save a checkpoint and relaunch vLLM on it.
      5. Save the final model under ``save_path/final_model``.
    """
    global vllm_process  # reset below if the server dies unexpectedly

    _setup_wandb(config)

    # Initialize model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name, torch_dtype=torch.bfloat16
    )
    model.to(config.device)
    model.gradient_checkpointing_enable()
    model.train()
    optimizer = AdamW(model.parameters(), lr=config.lr)

    print(
        f"Starting training for {config.training_steps} steps on device: {config.device}"
    )
    print(
        f"vLLM will be restarted every {config.vllm_restart_interval} steps on port {config.vllm_port}"
    )
    os.makedirs(config.save_path, exist_ok=True)  # Ensure base save directory exists
    register_trainer(config)

    # Serve the base model until the first checkpoint replaces it.
    _launch_vllm(config, config.model_name)

    batches = []
    for step in range(config.training_steps):
        print(f"Step {step+1}/{config.training_steps}")
        # Checkpoint + restart every N steps, and always on the final step.
        restart_this_step = (
            (step + 1) % config.vllm_restart_interval == 0
            or step == config.training_steps - 1
        )

        if not batches:
            batches = get_data(config.batch_size, config.seq_len)
        token_batches, label_batches, advantage_batches = batches.pop(0)

        if restart_this_step:
            # Stop the server before this step's forward/backward passes.
            # NOTE(review): presumably to free GPU memory for training; confirm.
            _stop_vllm()

        total_loss = 0.0
        total_logp = 0.0
        total_pos_logp = 0.0
        total_neg_logp = 0.0
        total_pos = 0
        total_neg = 0
        total_seqs = 0
        for tokens, labels, advantages in zip(
            token_batches, label_batches, advantage_batches
        ):
            tokens = tokens.to(config.device)
            labels = labels.to(config.device)
            advantages = advantages.to(config.device).reshape(-1)  # (B,)

            loss, seq_logp = _grpo_loss(model, tokens, labels, advantages)
            # Scale so gradients summed over micro-batches form an average.
            loss = loss / config.gradient_accumulation_steps
            loss.backward()
            total_loss += loss.item()

            # Logging-only statistics: mean sequence log-prob, split by advantage sign.
            pos = (advantages > 0).float()
            neg = (advantages <= 0).float()
            total_pos_logp += (seq_logp * pos).sum().item()
            total_neg_logp += (seq_logp * neg).sum().item()
            total_logp += seq_logp.sum().item()
            total_pos += pos.sum().item()
            total_neg += neg.sum().item()
            total_seqs += seq_logp.numel()

        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

        mean_pos_logp = total_pos_logp / total_pos if total_pos > 0 else 0.0
        mean_neg_logp = total_neg_logp / total_neg if total_neg > 0 else 0.0
        mean_logp = total_logp / total_seqs if total_seqs > 0 else 0.0

        if config.use_wandb:
            wandb.log(
                {
                    "train/loss": total_loss,
                    "train/learning_rate": optimizer.param_groups[0]["lr"],
                    "train/grad_norm": grad_norm.item(),
                    "train/pos_logp": mean_pos_logp,
                    "train/neg_logp": mean_neg_logp,
                    "train/logp": mean_logp,
                },
                step=step + 1,
            )
        print(f" Step Loss: {total_loss:.4f}")

        # Note: There are much better ways of updating the policy, this is just a very simple example
        if restart_this_step:
            # Named step+1 since it is saved after the step completes.
            checkpoint_path = os.path.join(config.save_path, f"step_{step+1}")
            print(f" Saving checkpoint to {checkpoint_path}...")
            _save_checkpoint(model, tokenizer, checkpoint_path)
            print(" Checkpoint saved.")
            torch.cuda.empty_cache()
            # Relaunch on the new weights (also on the last step; atexit cleans
            # it up), exposed under the original model name.
            _launch_vllm(config, checkpoint_path, served_model_name=config.model_name)

        # Detect a server that exited on its own, outside the restart schedule.
        if vllm_process is not None and vllm_process.poll() is not None:
            print(
                f"\n *** WARNING: vLLM process terminated unexpectedly (return code: {vllm_process.returncode}). "
                "Check vLLM logs. ***\n"
            )
            # Its stdout/stderr were never piped, so there is nothing to read back here.
            vllm_process = None  # Reset so it relaunches next interval

    print("Training finished.")
    if config.use_wandb:
        wandb.finish()
    # Final cleanup (vLLM termination) is handled by atexit
    final_save_path = os.path.join(config.save_path, "final_model")
    print(f"Saving final model to {final_save_path}")
    _save_checkpoint(model, tokenizer, final_save_path)
    print("Final model saved.")
# Example usage (optional, can be run from another script)
if __name__ == "__main__":
    # Example: create a config and run training.
    # Replace model_name with the Hugging Face model name (or local path) to train.
    # Requires the Atropos API server on localhost:8000.
    training_config = TrainingConfig(
        model_name="Qwen/Qwen2.5-1.5B-Instruct",
        training_steps=20,  # Number of optimizer steps
        vllm_restart_interval=3,  # Checkpoint and restart vLLM every 3 steps
        use_wandb=True,  # Set to False to disable Weights & Biases logging
        wandb_project="grpo-trainer-example",  # Replace with your project name
    )
    train(training_config)