mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
initialized with grpo
This commit is contained in:
parent
f052f14484
commit
0fbb112eec
4 changed files with 636 additions and 1 deletions
547
environments/hack0/grpo.py
Normal file
547
environments/hack0/grpo.py
Normal file
|
|
@ -0,0 +1,547 @@
|
|||
import atexit
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import string
|
||||
import subprocess
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import wandb # Added for logging
|
||||
from pydantic import BaseModel, Field
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from torch.optim import AdamW
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# Global variable to keep track of the vLLM process
|
||||
vllm_process = None
|
||||
|
||||
|
||||
def cleanup_vllm():
|
||||
global vllm_process
|
||||
if vllm_process:
|
||||
print("\nTerminating vLLM process...")
|
||||
vllm_process.terminate()
|
||||
try:
|
||||
vllm_process.wait(timeout=5) # Wait a bit for graceful shutdown
|
||||
print("vLLM process terminated.")
|
||||
except subprocess.TimeoutExpired:
|
||||
print("vLLM process did not terminate gracefully, killing.")
|
||||
vllm_process.kill()
|
||||
vllm_process.wait()
|
||||
print("vLLM process killed.")
|
||||
vllm_process = None
|
||||
|
||||
|
||||
# Register the cleanup function to be called on script exit
|
||||
atexit.register(cleanup_vllm)
|
||||
|
||||
|
||||
class TrainingConfig(BaseModel):
|
||||
"""
|
||||
Training details, model, etc
|
||||
"""
|
||||
|
||||
model_name: str = Field(..., description="Name of the base model to train")
|
||||
lr: float = Field(1e-5, description="Learning rate for the optimizer")
|
||||
training_steps: int = Field(
|
||||
10, description="Number of training steps"
|
||||
) # Renamed from epochs
|
||||
batch_size: int = Field(
|
||||
2, description="Batch size for training (will be handled by get_data)"
|
||||
)
|
||||
seq_len: int = Field(2048, description="Sequence length for training")
|
||||
gradient_accumulation_steps: int = Field(
|
||||
32, description="Number of gradient accumulation steps"
|
||||
)
|
||||
device: str = Field(
|
||||
"cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"
|
||||
)
|
||||
save_path: str = Field(
|
||||
"trained_model_checkpoints", description="Base path to save model checkpoints"
|
||||
)
|
||||
vllm_restart_interval: int = Field(
|
||||
3, description="Restart vLLM every N training steps"
|
||||
)
|
||||
vllm_port: int = Field(9001, description="Port for the vLLM server")
|
||||
|
||||
# Wandb configuration
|
||||
use_wandb: bool = Field(
|
||||
False, description="Whether to use Weights & Biases for logging"
|
||||
)
|
||||
wandb_project: Optional[str] = Field(None, description="Wandb project name")
|
||||
wandb_group: Optional[str] = Field(None, description="Wandb group name")
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15))
|
||||
def register_trainer(config: TrainingConfig):
|
||||
"""
|
||||
Register the trainer with the Atropos API
|
||||
"""
|
||||
requests.post(
|
||||
"http://localhost:8000/register",
|
||||
json={
|
||||
"wandb_group": config.wandb_group,
|
||||
"wandb_project": config.wandb_project,
|
||||
"batch_size": config.batch_size * config.gradient_accumulation_steps,
|
||||
"max_token_len": config.seq_len,
|
||||
"starting_step": 0,
|
||||
"checkpoint_dir": config.save_path,
|
||||
"save_checkpoint_interval": config.training_steps,
|
||||
"num_steps": config.training_steps,
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15))
|
||||
def get_batch():
|
||||
data = requests.get("http://localhost:8000/batch", timeout=10).json()
|
||||
return data
|
||||
|
||||
|
||||
def pad_data_to_good_offset(data, batch_size: int):
|
||||
max_token_len = max(
|
||||
[max([len(x) for x in item["tokens"]]) for item in data["batch"]]
|
||||
)
|
||||
# usually 64 is a good choice to ensure nonweird scaling behavior on GPUS
|
||||
# so we pad to the nearest multiple of 64
|
||||
good_multiple = 64
|
||||
if (max_token_len - 1) % (good_multiple) != 0:
|
||||
max_token_len = math.ceil((max_token_len - 1) / (good_multiple)) * good_multiple
|
||||
token_setup_len = (
|
||||
max_token_len + 1
|
||||
) # add 1 so we can make it causal at the proper length
|
||||
else:
|
||||
token_setup_len = max_token_len
|
||||
max_token_len = (
|
||||
max_token_len - 1
|
||||
) # since it's causal we need to remove the last bit...
|
||||
# pad all tokens to max_token_len and add to lists
|
||||
input_ids = list()
|
||||
labels = list()
|
||||
advantages = list()
|
||||
lengths = list()
|
||||
for item in data["batch"]:
|
||||
scores = item["scores"]
|
||||
scores = np.array(scores)
|
||||
# check if we have more than 1 score...
|
||||
if len(scores) > 1:
|
||||
scores = scores - scores.mean()
|
||||
scores = scores / max(scores.std(), 1e-8)
|
||||
item["scores"] = scores
|
||||
if item["overrides"] is not None:
|
||||
for i in range(len(item["overrides"])):
|
||||
if item["overrides"][i].get("set_advantage_to_zero", False):
|
||||
item["scores"][i] = 0
|
||||
for i in range(len(item["tokens"])):
|
||||
lengths.append(
|
||||
math.ceil((len(item["tokens"][i]) - 1) / (good_multiple))
|
||||
* good_multiple
|
||||
)
|
||||
label_item = np.concatenate(
|
||||
[
|
||||
np.array(item["masks"][i]),
|
||||
np.full(
|
||||
max(0, token_setup_len - len(item["tokens"][i])),
|
||||
-100,
|
||||
dtype=np.int32,
|
||||
),
|
||||
]
|
||||
)
|
||||
item["tokens"][i] = np.concatenate(
|
||||
[
|
||||
np.array(item["tokens"][i]),
|
||||
np.zeros(
|
||||
max(0, token_setup_len - len(item["tokens"][i])), dtype=np.int32
|
||||
),
|
||||
]
|
||||
)
|
||||
input_ids.append(item["tokens"][i][:-1])
|
||||
labels.append(label_item[1:])
|
||||
advantages.append(item["scores"][i])
|
||||
# combine all lists into tensors
|
||||
token_batches = []
|
||||
label_batches = []
|
||||
advantage_batches = []
|
||||
for i in range(len(input_ids) // batch_size):
|
||||
token_batches.append(
|
||||
torch.tensor(
|
||||
np.stack(input_ids[i * batch_size : (i + 1) * batch_size], axis=0)
|
||||
)
|
||||
)
|
||||
label_batches.append(
|
||||
torch.tensor(
|
||||
np.stack(labels[i * batch_size : (i + 1) * batch_size], axis=0)
|
||||
)
|
||||
)
|
||||
advantage_batches.append(
|
||||
torch.tensor(
|
||||
np.stack(advantages[i * batch_size : (i + 1) * batch_size], axis=0)
|
||||
).view(-1, 1)
|
||||
)
|
||||
return token_batches, label_batches, advantage_batches
|
||||
|
||||
|
||||
def get_data(
|
||||
batch_size: int, seq_len: int
|
||||
) -> List[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
|
||||
"""
|
||||
getting data from the api
|
||||
"""
|
||||
batches = []
|
||||
while True:
|
||||
data = get_batch()
|
||||
if data["batch"] is not None:
|
||||
# Save the batch
|
||||
with open("temp.json", "w", encoding="utf-8") as f:
|
||||
json.dump(data, f)
|
||||
# In case the inference runs ahead of the training, we loop until we don't have any more data
|
||||
batches.append(pad_data_to_good_offset(data, batch_size))
|
||||
elif len(batches) > 0:
|
||||
# Return the batches
|
||||
return batches
|
||||
else:
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def train(config: TrainingConfig):
|
||||
"""
|
||||
Setups and runs GRPO training, restarting vLLM periodically, with wandb logging.
|
||||
"""
|
||||
global vllm_process # Declare intention to modify the global variable
|
||||
|
||||
# --- Wandb Setup ---
|
||||
if config.use_wandb:
|
||||
if not config.wandb_project:
|
||||
print("Warning: wandb_project not set, disabling wandb.")
|
||||
config.use_wandb = False
|
||||
else:
|
||||
if not config.wandb_group:
|
||||
# Set group to random 8 character string
|
||||
config.wandb_group = "".join(
|
||||
random.choices(string.ascii_letters + string.digits, k=8)
|
||||
)
|
||||
try:
|
||||
wandb.init(
|
||||
project=config.wandb_project,
|
||||
group=config.wandb_group,
|
||||
config=config.dict(), # Log config parameters
|
||||
)
|
||||
print(
|
||||
f"Wandb logging enabled. Run: {wandb.run.name} (Project: {config.wandb_project}) "
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error initializing wandb: {e}. Disabling wandb.")
|
||||
config.use_wandb = False
|
||||
# --- End Wandb Setup ---
|
||||
|
||||
# Initialize model and tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
config.model_name, torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
model.to(config.device)
|
||||
model.gradient_checkpointing_enable()
|
||||
model.train()
|
||||
|
||||
# Setup optimizer
|
||||
optimizer = AdamW(model.parameters(), lr=config.lr)
|
||||
|
||||
print(
|
||||
f"Starting training for {config.training_steps} steps on device: {config.device}"
|
||||
)
|
||||
print(
|
||||
f"vLLM will be restarted every {config.vllm_restart_interval} steps on port {config.vllm_port}"
|
||||
)
|
||||
|
||||
os.makedirs(config.save_path, exist_ok=True) # Ensure base save directory exists
|
||||
register_trainer(config)
|
||||
|
||||
# Init vllm
|
||||
vllm_command = [
|
||||
"python",
|
||||
"-m",
|
||||
"vllm.entrypoints.openai.api_server",
|
||||
"--model",
|
||||
config.model_name,
|
||||
"--port",
|
||||
str(config.vllm_port),
|
||||
"--dtype",
|
||||
"auto",
|
||||
"--gpu-memory-utilization",
|
||||
"0.45",
|
||||
"--disable-log-requests",
|
||||
]
|
||||
print(f" Launching vLLM server: {' '.join(vllm_command)}")
|
||||
try:
|
||||
vllm_process = subprocess.Popen(vllm_command)
|
||||
print(f" vLLM server launched with PID: {vllm_process.pid}")
|
||||
# Check immediate errors
|
||||
try:
|
||||
stdout, stderr = vllm_process.communicate(timeout=2)
|
||||
if vllm_process.returncode is not None and vllm_process.returncode != 0:
|
||||
print(f" Error starting vLLM: {stderr.decode()}")
|
||||
vllm_process = None
|
||||
# Maybe raise error or just warn?
|
||||
print(" WARNING: Failed to start vLLM server after checkpoint.")
|
||||
except subprocess.TimeoutExpired:
|
||||
print(" vLLM process started (check logs for details).")
|
||||
except FileNotFoundError:
|
||||
print(
|
||||
"\n *** ERROR: 'python -m vllm...' command not found. Make sure vLLM is installed and accessible. ***\n"
|
||||
)
|
||||
# Potentially stop training or just disable further vLLM restarts
|
||||
print(" Disabling further vLLM restarts.")
|
||||
config.vllm_restart_interval = (
|
||||
config.training_steps + 1
|
||||
) # Prevent further restarts
|
||||
except Exception as e:
|
||||
print(f"\n *** ERROR: Failed to launch vLLM: {e} ***\n")
|
||||
print(" Disabling further vLLM restarts.")
|
||||
config.vllm_restart_interval = (
|
||||
config.training_steps + 1
|
||||
) # Prevent further restarts
|
||||
|
||||
batches = list()
|
||||
for step in range(config.training_steps):
|
||||
total_loss = 0
|
||||
print(f"Step {step+1}/{config.training_steps}")
|
||||
total_pos_logp = 0
|
||||
total_neg_logp = 0
|
||||
total_logp = 0
|
||||
total_pos = 0
|
||||
total_neg = 0
|
||||
if len(batches) == 0:
|
||||
batches = get_data(config.batch_size, config.seq_len)
|
||||
token_batches, label_batches, advantage_batches = batches.pop(0)
|
||||
# Terminate existing vLLM process if running
|
||||
if (
|
||||
step + 1
|
||||
) % config.vllm_restart_interval == 0 or step == config.training_steps - 1: # Also restart/save on last step
|
||||
# Terminate existing vLLM process if running
|
||||
if vllm_process:
|
||||
print(" Terminating existing vLLM process...")
|
||||
vllm_process.terminate()
|
||||
try:
|
||||
vllm_process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
print(
|
||||
" Existing vLLM process did not terminate gracefully, killing."
|
||||
)
|
||||
vllm_process.kill()
|
||||
vllm_process.wait()
|
||||
vllm_process = None
|
||||
for tokens, labels, advantages in zip(
|
||||
token_batches, label_batches, advantage_batches
|
||||
):
|
||||
|
||||
tokens, labels, advantages = (
|
||||
tokens.to(config.device),
|
||||
labels.to(config.device),
|
||||
advantages.to(config.device),
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
# User specified that tokens/labels are already prepared by get_data
|
||||
outputs = model(tokens) # Assuming model just needs tokens
|
||||
logits = outputs.logits # Assuming this is the structure
|
||||
|
||||
# Calculate GRPO loss (reverting to user's previous logic)
|
||||
# User stated ignore_index is -100 and tokens/labels are aligned by get_data
|
||||
# Assuming logits correspond directly to labels indices (no shift needed here)
|
||||
logp_per_token = -F.cross_entropy(
|
||||
logits.view(-1, logits.size(-1)), # Flatten logits
|
||||
labels.view(-1), # Flatten labels
|
||||
reduction="none",
|
||||
ignore_index=-100, # User specified ignore index
|
||||
).view(
|
||||
labels.shape
|
||||
) # Reshape back to (batch, seq_len)
|
||||
|
||||
# Masking based on labels != -100
|
||||
mask = (labels != -100).float()
|
||||
with torch.no_grad():
|
||||
pos = (advantages > 0).float()
|
||||
neg = (advantages <= 0).float()
|
||||
avg_logp = (logp_per_token * mask).sum(-1) / mask.sum(-1)
|
||||
pos_logp = (logp_per_token * pos).mean().item()
|
||||
neg_logp = (logp_per_token * neg).mean().item()
|
||||
total_pos_logp += pos_logp
|
||||
total_neg_logp += neg_logp
|
||||
total_logp += avg_logp
|
||||
total_pos += pos.sum().item()
|
||||
total_neg += neg.sum().item()
|
||||
|
||||
grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach())
|
||||
grpo_loss = (
|
||||
((-grpo_loss_term * mask).sum(-1) / mask.sum(-1))
|
||||
* advantages.to(logp_per_token.device)
|
||||
).mean() / config.gradient_accumulation_steps
|
||||
grpo_loss.backward()
|
||||
total_loss += grpo_loss.item()
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
if total_pos > 0:
|
||||
total_pos_logp /= total_pos
|
||||
if total_neg > 0:
|
||||
total_neg_logp /= total_neg
|
||||
# --- Wandb Logging ---
|
||||
if config.use_wandb:
|
||||
wandb.log(
|
||||
{
|
||||
"train/loss": total_loss,
|
||||
"train/learning_rate": optimizer.param_groups[0]["lr"],
|
||||
"train/grad_norm": grad_norm.item(),
|
||||
"train/pos_logp": total_pos_logp,
|
||||
"train/neg_logp": total_neg_logp,
|
||||
"train/logp": total_logp,
|
||||
},
|
||||
step=step + 1,
|
||||
)
|
||||
# --- End Wandb Logging ---
|
||||
|
||||
print(f" Step Loss: {grpo_loss.item():.4f}")
|
||||
|
||||
# --- vLLM Restart Logic (Moved AFTER optimizer step) ---
|
||||
# Note: There are much better ways of updating the policy, this is just a very simple example
|
||||
if (
|
||||
step + 1
|
||||
) % config.vllm_restart_interval == 0 or step == config.training_steps - 1: # Also restart/save on last step
|
||||
checkpoint_path = os.path.join(
|
||||
config.save_path, f"step_{step+1}"
|
||||
) # Save as step+1 since it's after step completion
|
||||
print(f" Saving checkpoint to {checkpoint_path}...")
|
||||
# Ensure fresh directory for saving
|
||||
if os.path.exists(checkpoint_path):
|
||||
shutil.rmtree(checkpoint_path) # Remove old checkpoint if it exists
|
||||
os.makedirs(checkpoint_path, exist_ok=True)
|
||||
model.save_pretrained(checkpoint_path)
|
||||
tokenizer.save_pretrained(checkpoint_path)
|
||||
print(" Checkpoint saved.")
|
||||
|
||||
# Terminate existing vLLM process if running
|
||||
if vllm_process:
|
||||
print(" Terminating existing vLLM process...")
|
||||
vllm_process.terminate()
|
||||
try:
|
||||
vllm_process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
print(
|
||||
" Existing vLLM process did not terminate gracefully, killing."
|
||||
)
|
||||
vllm_process.kill()
|
||||
vllm_process.wait()
|
||||
vllm_process = None
|
||||
|
||||
# Launch new vLLM process (only if not the very last step, maybe? depends on use case)
|
||||
# Let's still launch it on the last step for consistency, cleanup will handle it.
|
||||
vllm_command = [
|
||||
"python",
|
||||
"-m",
|
||||
"vllm.entrypoints.openai.api_server",
|
||||
"--model",
|
||||
os.path.join(config.save_path, f"step_{step+1}"),
|
||||
"--port",
|
||||
str(config.vllm_port),
|
||||
"--dtype",
|
||||
"auto",
|
||||
"--gpu-memory-utilization",
|
||||
"0.45",
|
||||
"--disable-log-requests",
|
||||
"--served-model-name",
|
||||
config.model_name,
|
||||
]
|
||||
print(f" Launching vLLM server: {' '.join(vllm_command)}")
|
||||
torch.cuda.empty_cache()
|
||||
try:
|
||||
vllm_process = subprocess.Popen(vllm_command)
|
||||
print(f" vLLM server launched with PID: {vllm_process.pid}")
|
||||
# Check immediate errors
|
||||
try:
|
||||
stdout, stderr = vllm_process.communicate(timeout=2)
|
||||
if (
|
||||
vllm_process.returncode is not None
|
||||
and vllm_process.returncode != 0
|
||||
):
|
||||
print(f" Error starting vLLM: {stderr.decode()}")
|
||||
vllm_process = None
|
||||
# Maybe raise error or just warn?
|
||||
print(
|
||||
" WARNING: Failed to start vLLM server after checkpoint."
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
print(" vLLM process started (check logs for details).")
|
||||
except FileNotFoundError:
|
||||
print(
|
||||
"\n *** ERROR: 'python -m vllm...' command not found. ",
|
||||
"Make sure vLLM is installed and accessible. ***\n",
|
||||
)
|
||||
# Potentially stop training or just disable further vLLM restarts
|
||||
print(" Disabling further vLLM restarts.")
|
||||
config.vllm_restart_interval = (
|
||||
config.training_steps + 1
|
||||
) # Prevent further restarts
|
||||
except Exception as e:
|
||||
print(f"\n *** ERROR: Failed to launch vLLM: {e} ***\n")
|
||||
print(" Disabling further vLLM restarts.")
|
||||
config.vllm_restart_interval = (
|
||||
config.training_steps + 1
|
||||
) # Prevent further restarts
|
||||
# --- End vLLM Restart Logic ---
|
||||
|
||||
# Basic check if vLLM process terminated unexpectedly (outside interval check)
|
||||
if vllm_process and vllm_process.poll() is not None:
|
||||
print(
|
||||
f"\n *** WARNING: vLLM process terminated unexpectedly (return code: {vllm_process.returncode}). ",
|
||||
"Check vLLM logs. ***\n",
|
||||
)
|
||||
stderr_output = (
|
||||
vllm_process.stderr.read().decode()
|
||||
if vllm_process.stderr
|
||||
else "No stderr"
|
||||
)
|
||||
print(f"vLLM stderr: {stderr_output}")
|
||||
vllm_process = None # Reset so it relaunches next interval
|
||||
|
||||
print("Training finished.")
|
||||
# --- Wandb Finish ---
|
||||
if config.use_wandb:
|
||||
wandb.finish()
|
||||
# --- End Wandb Finish ---
|
||||
# Final cleanup (vLLM termination) is handled by atexit
|
||||
|
||||
# --- Placeholder for final model save ---
|
||||
final_save_path = os.path.join(config.save_path, "final_model")
|
||||
print(f"Saving final model to {final_save_path}")
|
||||
if os.path.exists(final_save_path):
|
||||
shutil.rmtree(final_save_path)
|
||||
os.makedirs(final_save_path, exist_ok=True)
|
||||
model.save_pretrained(final_save_path)
|
||||
tokenizer.save_pretrained(final_save_path)
|
||||
print("Final model saved.")
|
||||
|
||||
|
||||
# Example usage (optional, can be run from another script)
|
||||
if __name__ == "__main__":
|
||||
# Example: Create a config and run training
|
||||
# Replace "gpt2" with your desired model
|
||||
training_config = TrainingConfig(
|
||||
model_name="Qwen/Qwen2.5-1.5B-Instruct",
|
||||
training_steps=20, # Use steps
|
||||
vllm_restart_interval=3, # Example interval
|
||||
use_wandb=True, # Set to True to enable logging
|
||||
wandb_project="grpo-trainer-example", # Replace with your project name
|
||||
)
|
||||
|
||||
# --- End Mock ---
|
||||
|
||||
train(training_config)
|
||||
Loading…
Add table
Add a link
Reference in a new issue