mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
547 lines
21 KiB
Python
547 lines
21 KiB
Python
import atexit
|
|
import json
|
|
import math
|
|
import os
|
|
import random
|
|
import shutil
|
|
import string
|
|
import subprocess
|
|
import time
|
|
from typing import List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
import requests
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import wandb # Added for logging
|
|
from pydantic import BaseModel, Field
|
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
|
from torch.optim import AdamW
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
# Global variable to keep track of the vLLM process
|
|
vllm_process = None
|
|
|
|
|
|
def cleanup_vllm():
|
|
global vllm_process
|
|
if vllm_process:
|
|
print("\nTerminating vLLM process...")
|
|
vllm_process.terminate()
|
|
try:
|
|
vllm_process.wait(timeout=5) # Wait a bit for graceful shutdown
|
|
print("vLLM process terminated.")
|
|
except subprocess.TimeoutExpired:
|
|
print("vLLM process did not terminate gracefully, killing.")
|
|
vllm_process.kill()
|
|
vllm_process.wait()
|
|
print("vLLM process killed.")
|
|
vllm_process = None
|
|
|
|
|
|
# Register the cleanup function to be called on script exit
|
|
atexit.register(cleanup_vllm)
|
|
|
|
|
|
class TrainingConfig(BaseModel):
|
|
"""
|
|
Training details, model, etc
|
|
"""
|
|
|
|
model_name: str = Field(..., description="Name of the base model to train")
|
|
lr: float = Field(1e-5, description="Learning rate for the optimizer")
|
|
training_steps: int = Field(
|
|
10, description="Number of training steps"
|
|
) # Renamed from epochs
|
|
batch_size: int = Field(
|
|
2, description="Batch size for training (will be handled by get_data)"
|
|
)
|
|
seq_len: int = Field(2048, description="Sequence length for training")
|
|
gradient_accumulation_steps: int = Field(
|
|
32, description="Number of gradient accumulation steps"
|
|
)
|
|
device: str = Field(
|
|
"cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"
|
|
)
|
|
save_path: str = Field(
|
|
"trained_model_checkpoints", description="Base path to save model checkpoints"
|
|
)
|
|
vllm_restart_interval: int = Field(
|
|
3, description="Restart vLLM every N training steps"
|
|
)
|
|
vllm_port: int = Field(9001, description="Port for the vLLM server")
|
|
|
|
# Wandb configuration
|
|
use_wandb: bool = Field(
|
|
False, description="Whether to use Weights & Biases for logging"
|
|
)
|
|
wandb_project: Optional[str] = Field(None, description="Wandb project name")
|
|
wandb_group: Optional[str] = Field(None, description="Wandb group name")
|
|
|
|
|
|
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15))
|
|
def register_trainer(config: TrainingConfig):
|
|
"""
|
|
Register the trainer with the Atropos API
|
|
"""
|
|
requests.post(
|
|
"http://localhost:8000/register",
|
|
json={
|
|
"wandb_group": config.wandb_group,
|
|
"wandb_project": config.wandb_project,
|
|
"batch_size": config.batch_size * config.gradient_accumulation_steps,
|
|
"max_token_len": config.seq_len,
|
|
"starting_step": 0,
|
|
"checkpoint_dir": config.save_path,
|
|
"save_checkpoint_interval": config.training_steps,
|
|
"num_steps": config.training_steps,
|
|
},
|
|
timeout=10,
|
|
)
|
|
|
|
|
|
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15))
|
|
def get_batch():
|
|
data = requests.get("http://localhost:8000/batch", timeout=10).json()
|
|
return data
|
|
|
|
|
|
def pad_data_to_good_offset(data, batch_size: int):
|
|
max_token_len = max(
|
|
[max([len(x) for x in item["tokens"]]) for item in data["batch"]]
|
|
)
|
|
# usually 64 is a good choice to ensure nonweird scaling behavior on GPUS
|
|
# so we pad to the nearest multiple of 64
|
|
good_multiple = 64
|
|
if (max_token_len - 1) % (good_multiple) != 0:
|
|
max_token_len = math.ceil((max_token_len - 1) / (good_multiple)) * good_multiple
|
|
token_setup_len = (
|
|
max_token_len + 1
|
|
) # add 1 so we can make it causal at the proper length
|
|
else:
|
|
token_setup_len = max_token_len
|
|
max_token_len = (
|
|
max_token_len - 1
|
|
) # since it's causal we need to remove the last bit...
|
|
# pad all tokens to max_token_len and add to lists
|
|
input_ids = list()
|
|
labels = list()
|
|
advantages = list()
|
|
lengths = list()
|
|
for item in data["batch"]:
|
|
scores = item["scores"]
|
|
scores = np.array(scores)
|
|
# check if we have more than 1 score...
|
|
if len(scores) > 1:
|
|
scores = scores - scores.mean()
|
|
scores = scores / max(scores.std(), 1e-8)
|
|
item["scores"] = scores
|
|
if item["overrides"] is not None:
|
|
for i in range(len(item["overrides"])):
|
|
if item["overrides"][i].get("set_advantage_to_zero", False):
|
|
item["scores"][i] = 0
|
|
for i in range(len(item["tokens"])):
|
|
lengths.append(
|
|
math.ceil((len(item["tokens"][i]) - 1) / (good_multiple))
|
|
* good_multiple
|
|
)
|
|
label_item = np.concatenate(
|
|
[
|
|
np.array(item["masks"][i]),
|
|
np.full(
|
|
max(0, token_setup_len - len(item["tokens"][i])),
|
|
-100,
|
|
dtype=np.int32,
|
|
),
|
|
]
|
|
)
|
|
item["tokens"][i] = np.concatenate(
|
|
[
|
|
np.array(item["tokens"][i]),
|
|
np.zeros(
|
|
max(0, token_setup_len - len(item["tokens"][i])), dtype=np.int32
|
|
),
|
|
]
|
|
)
|
|
input_ids.append(item["tokens"][i][:-1])
|
|
labels.append(label_item[1:])
|
|
advantages.append(item["scores"][i])
|
|
# combine all lists into tensors
|
|
token_batches = []
|
|
label_batches = []
|
|
advantage_batches = []
|
|
for i in range(len(input_ids) // batch_size):
|
|
token_batches.append(
|
|
torch.tensor(
|
|
np.stack(input_ids[i * batch_size : (i + 1) * batch_size], axis=0)
|
|
)
|
|
)
|
|
label_batches.append(
|
|
torch.tensor(
|
|
np.stack(labels[i * batch_size : (i + 1) * batch_size], axis=0)
|
|
)
|
|
)
|
|
advantage_batches.append(
|
|
torch.tensor(
|
|
np.stack(advantages[i * batch_size : (i + 1) * batch_size], axis=0)
|
|
).view(-1, 1)
|
|
)
|
|
return token_batches, label_batches, advantage_batches
|
|
|
|
|
|
def get_data(
|
|
batch_size: int, seq_len: int
|
|
) -> List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
|
|
"""
|
|
getting data from the api
|
|
"""
|
|
batches = []
|
|
while True:
|
|
data = get_batch()
|
|
if data["batch"] is not None:
|
|
# Save the batch
|
|
with open("temp.json", "w", encoding="utf-8") as f:
|
|
json.dump(data, f)
|
|
# In case the inference runs ahead of the training, we loop until we don't have any more data
|
|
batches.append(pad_data_to_good_offset(data, batch_size))
|
|
elif len(batches) > 0:
|
|
# Return the batches
|
|
return batches
|
|
else:
|
|
time.sleep(1)
|
|
|
|
|
|
def train(config: TrainingConfig):
|
|
"""
|
|
Setups and runs GRPO training, restarting vLLM periodically, with wandb logging.
|
|
"""
|
|
global vllm_process # Declare intention to modify the global variable
|
|
|
|
# --- Wandb Setup ---
|
|
if config.use_wandb:
|
|
if not config.wandb_project:
|
|
print("Warning: wandb_project not set, disabling wandb.")
|
|
config.use_wandb = False
|
|
else:
|
|
if not config.wandb_group:
|
|
# Set group to random 8 character string
|
|
config.wandb_group = "".join(
|
|
random.choices(string.ascii_letters + string.digits, k=8)
|
|
)
|
|
try:
|
|
wandb.init(
|
|
project=config.wandb_project,
|
|
group=config.wandb_group,
|
|
config=config.dict(), # Log config parameters
|
|
)
|
|
print(
|
|
f"Wandb logging enabled. Run: {wandb.run.name} (Project: {config.wandb_project}) "
|
|
)
|
|
except Exception as e:
|
|
print(f"Error initializing wandb: {e}. Disabling wandb.")
|
|
config.use_wandb = False
|
|
# --- End Wandb Setup ---
|
|
|
|
# Initialize model and tokenizer
|
|
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
config.model_name, torch_dtype=torch.bfloat16
|
|
)
|
|
|
|
model.to(config.device)
|
|
model.gradient_checkpointing_enable()
|
|
model.train()
|
|
|
|
# Setup optimizer
|
|
optimizer = AdamW(model.parameters(), lr=config.lr)
|
|
|
|
print(
|
|
f"Starting training for {config.training_steps} steps on device: {config.device}"
|
|
)
|
|
print(
|
|
f"vLLM will be restarted every {config.vllm_restart_interval} steps on port {config.vllm_port}"
|
|
)
|
|
|
|
os.makedirs(config.save_path, exist_ok=True) # Ensure base save directory exists
|
|
register_trainer(config)
|
|
|
|
# Init vllm
|
|
vllm_command = [
|
|
"python",
|
|
"-m",
|
|
"vllm.entrypoints.openai.api_server",
|
|
"--model",
|
|
config.model_name,
|
|
"--port",
|
|
str(config.vllm_port),
|
|
"--dtype",
|
|
"auto",
|
|
"--gpu-memory-utilization",
|
|
"0.45",
|
|
"--disable-log-requests",
|
|
]
|
|
print(f" Launching vLLM server: {' '.join(vllm_command)}")
|
|
try:
|
|
vllm_process = subprocess.Popen(vllm_command)
|
|
print(f" vLLM server launched with PID: {vllm_process.pid}")
|
|
# Check immediate errors
|
|
try:
|
|
stdout, stderr = vllm_process.communicate(timeout=2)
|
|
if vllm_process.returncode is not None and vllm_process.returncode != 0:
|
|
print(f" Error starting vLLM: {stderr.decode()}")
|
|
vllm_process = None
|
|
# Maybe raise error or just warn?
|
|
print(" WARNING: Failed to start vLLM server after checkpoint.")
|
|
except subprocess.TimeoutExpired:
|
|
print(" vLLM process started (check logs for details).")
|
|
except FileNotFoundError:
|
|
print(
|
|
"\n *** ERROR: 'python -m vllm...' command not found. Make sure vLLM is installed and accessible. ***\n"
|
|
)
|
|
# Potentially stop training or just disable further vLLM restarts
|
|
print(" Disabling further vLLM restarts.")
|
|
config.vllm_restart_interval = (
|
|
config.training_steps + 1
|
|
) # Prevent further restarts
|
|
except Exception as e:
|
|
print(f"\n *** ERROR: Failed to launch vLLM: {e} ***\n")
|
|
print(" Disabling further vLLM restarts.")
|
|
config.vllm_restart_interval = (
|
|
config.training_steps + 1
|
|
) # Prevent further restarts
|
|
|
|
batches = list()
|
|
for step in range(config.training_steps):
|
|
total_loss = 0
|
|
print(f"Step {step+1}/{config.training_steps}")
|
|
total_pos_logp = 0
|
|
total_neg_logp = 0
|
|
total_logp = 0
|
|
total_pos = 0
|
|
total_neg = 0
|
|
if len(batches) == 0:
|
|
batches = get_data(config.batch_size, config.seq_len)
|
|
token_batches, label_batches, advantage_batches = batches.pop(0)
|
|
# Terminate existing vLLM process if running
|
|
if (
|
|
step + 1
|
|
) % config.vllm_restart_interval == 0 or step == config.training_steps - 1: # Also restart/save on last step
|
|
# Terminate existing vLLM process if running
|
|
if vllm_process:
|
|
print(" Terminating existing vLLM process...")
|
|
vllm_process.terminate()
|
|
try:
|
|
vllm_process.wait(timeout=5)
|
|
except subprocess.TimeoutExpired:
|
|
print(
|
|
" Existing vLLM process did not terminate gracefully, killing."
|
|
)
|
|
vllm_process.kill()
|
|
vllm_process.wait()
|
|
vllm_process = None
|
|
for tokens, labels, advantages in zip(
|
|
token_batches, label_batches, advantage_batches
|
|
):
|
|
|
|
tokens, labels, advantages = (
|
|
tokens.to(config.device),
|
|
labels.to(config.device),
|
|
advantages.to(config.device),
|
|
)
|
|
|
|
# Forward pass
|
|
# User specified that tokens/labels are already prepared by get_data
|
|
outputs = model(tokens) # Assuming model just needs tokens
|
|
logits = outputs.logits # Assuming this is the structure
|
|
|
|
# Calculate GRPO loss (reverting to user's previous logic)
|
|
# User stated ignore_index is -100 and tokens/labels are aligned by get_data
|
|
# Assuming logits correspond directly to labels indices (no shift needed here)
|
|
logp_per_token = -F.cross_entropy(
|
|
logits.view(-1, logits.size(-1)), # Flatten logits
|
|
labels.view(-1), # Flatten labels
|
|
reduction="none",
|
|
ignore_index=-100, # User specified ignore index
|
|
).view(
|
|
labels.shape
|
|
) # Reshape back to (batch, seq_len)
|
|
|
|
# Masking based on labels != -100
|
|
mask = (labels != -100).float()
|
|
with torch.no_grad():
|
|
pos = (advantages > 0).float()
|
|
neg = (advantages <= 0).float()
|
|
avg_logp = (logp_per_token * mask).sum(-1) / mask.sum(-1)
|
|
pos_logp = (logp_per_token * pos).mean().item()
|
|
neg_logp = (logp_per_token * neg).mean().item()
|
|
total_pos_logp += pos_logp
|
|
total_neg_logp += neg_logp
|
|
total_logp += avg_logp
|
|
total_pos += pos.sum().item()
|
|
total_neg += neg.sum().item()
|
|
|
|
grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach())
|
|
grpo_loss = (
|
|
((-grpo_loss_term * mask).sum(-1) / mask.sum(-1))
|
|
* advantages.to(logp_per_token.device)
|
|
).mean() / config.gradient_accumulation_steps
|
|
grpo_loss.backward()
|
|
total_loss += grpo_loss.item()
|
|
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
|
optimizer.step()
|
|
optimizer.zero_grad()
|
|
if total_pos > 0:
|
|
total_pos_logp /= total_pos
|
|
if total_neg > 0:
|
|
total_neg_logp /= total_neg
|
|
# --- Wandb Logging ---
|
|
if config.use_wandb:
|
|
wandb.log(
|
|
{
|
|
"train/loss": total_loss,
|
|
"train/learning_rate": optimizer.param_groups[0]["lr"],
|
|
"train/grad_norm": grad_norm.item(),
|
|
"train/pos_logp": total_pos_logp,
|
|
"train/neg_logp": total_neg_logp,
|
|
"train/logp": total_logp,
|
|
},
|
|
step=step + 1,
|
|
)
|
|
# --- End Wandb Logging ---
|
|
|
|
print(f" Step Loss: {grpo_loss.item():.4f}")
|
|
|
|
# --- vLLM Restart Logic (Moved AFTER optimizer step) ---
|
|
# Note: There are much better ways of updating the policy, this is just a very simple example
|
|
if (
|
|
step + 1
|
|
) % config.vllm_restart_interval == 0 or step == config.training_steps - 1: # Also restart/save on last step
|
|
checkpoint_path = os.path.join(
|
|
config.save_path, f"step_{step+1}"
|
|
) # Save as step+1 since it's after step completion
|
|
print(f" Saving checkpoint to {checkpoint_path}...")
|
|
# Ensure fresh directory for saving
|
|
if os.path.exists(checkpoint_path):
|
|
shutil.rmtree(checkpoint_path) # Remove old checkpoint if it exists
|
|
os.makedirs(checkpoint_path, exist_ok=True)
|
|
model.save_pretrained(checkpoint_path)
|
|
tokenizer.save_pretrained(checkpoint_path)
|
|
print(" Checkpoint saved.")
|
|
|
|
# Terminate existing vLLM process if running
|
|
if vllm_process:
|
|
print(" Terminating existing vLLM process...")
|
|
vllm_process.terminate()
|
|
try:
|
|
vllm_process.wait(timeout=5)
|
|
except subprocess.TimeoutExpired:
|
|
print(
|
|
" Existing vLLM process did not terminate gracefully, killing."
|
|
)
|
|
vllm_process.kill()
|
|
vllm_process.wait()
|
|
vllm_process = None
|
|
|
|
# Launch new vLLM process (only if not the very last step, maybe? depends on use case)
|
|
# Let's still launch it on the last step for consistency, cleanup will handle it.
|
|
vllm_command = [
|
|
"python",
|
|
"-m",
|
|
"vllm.entrypoints.openai.api_server",
|
|
"--model",
|
|
os.path.join(config.save_path, f"step_{step+1}"),
|
|
"--port",
|
|
str(config.vllm_port),
|
|
"--dtype",
|
|
"auto",
|
|
"--gpu-memory-utilization",
|
|
"0.45",
|
|
"--disable-log-requests",
|
|
"--served-model-name",
|
|
config.model_name,
|
|
]
|
|
print(f" Launching vLLM server: {' '.join(vllm_command)}")
|
|
torch.cuda.empty_cache()
|
|
try:
|
|
vllm_process = subprocess.Popen(vllm_command)
|
|
print(f" vLLM server launched with PID: {vllm_process.pid}")
|
|
# Check immediate errors
|
|
try:
|
|
stdout, stderr = vllm_process.communicate(timeout=2)
|
|
if (
|
|
vllm_process.returncode is not None
|
|
and vllm_process.returncode != 0
|
|
):
|
|
print(f" Error starting vLLM: {stderr.decode()}")
|
|
vllm_process = None
|
|
# Maybe raise error or just warn?
|
|
print(
|
|
" WARNING: Failed to start vLLM server after checkpoint."
|
|
)
|
|
except subprocess.TimeoutExpired:
|
|
print(" vLLM process started (check logs for details).")
|
|
except FileNotFoundError:
|
|
print(
|
|
"\n *** ERROR: 'python -m vllm...' command not found. ",
|
|
"Make sure vLLM is installed and accessible. ***\n",
|
|
)
|
|
# Potentially stop training or just disable further vLLM restarts
|
|
print(" Disabling further vLLM restarts.")
|
|
config.vllm_restart_interval = (
|
|
config.training_steps + 1
|
|
) # Prevent further restarts
|
|
except Exception as e:
|
|
print(f"\n *** ERROR: Failed to launch vLLM: {e} ***\n")
|
|
print(" Disabling further vLLM restarts.")
|
|
config.vllm_restart_interval = (
|
|
config.training_steps + 1
|
|
) # Prevent further restarts
|
|
# --- End vLLM Restart Logic ---
|
|
|
|
# Basic check if vLLM process terminated unexpectedly (outside interval check)
|
|
if vllm_process and vllm_process.poll() is not None:
|
|
print(
|
|
f"\n *** WARNING: vLLM process terminated unexpectedly (return code: {vllm_process.returncode}). ",
|
|
"Check vLLM logs. ***\n",
|
|
)
|
|
stderr_output = (
|
|
vllm_process.stderr.read().decode()
|
|
if vllm_process.stderr
|
|
else "No stderr"
|
|
)
|
|
print(f"vLLM stderr: {stderr_output}")
|
|
vllm_process = None # Reset so it relaunches next interval
|
|
|
|
print("Training finished.")
|
|
# --- Wandb Finish ---
|
|
if config.use_wandb:
|
|
wandb.finish()
|
|
# --- End Wandb Finish ---
|
|
# Final cleanup (vLLM termination) is handled by atexit
|
|
|
|
# --- Placeholder for final model save ---
|
|
final_save_path = os.path.join(config.save_path, "final_model")
|
|
print(f"Saving final model to {final_save_path}")
|
|
if os.path.exists(final_save_path):
|
|
shutil.rmtree(final_save_path)
|
|
os.makedirs(final_save_path, exist_ok=True)
|
|
model.save_pretrained(final_save_path)
|
|
tokenizer.save_pretrained(final_save_path)
|
|
print("Final model saved.")
|
|
|
|
|
|
# Example usage (optional, can be run from another script)
|
|
if __name__ == "__main__":
|
|
# Example: Create a config and run training
|
|
# Replace "gpt2" with your desired model
|
|
training_config = TrainingConfig(
|
|
model_name="Qwen/Qwen2.5-1.5B-Instruct",
|
|
training_steps=20, # Use steps
|
|
vllm_restart_interval=3, # Example interval
|
|
use_wandb=True, # Set to True to enable logging
|
|
wandb_project="grpo-trainer-example", # Replace with your project name
|
|
)
|
|
|
|
# --- End Mock ---
|
|
|
|
train(training_config)
|