mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-26 17:13:09 +00:00
train
This commit is contained in:
parent
7e1de80695
commit
b298a0eeb6
2 changed files with 492 additions and 0 deletions
464
environments/train_rubiks_model.py
Executable file
464
environments/train_rubiks_model.py
Executable file
|
|
@ -0,0 +1,464 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Train a model to solve Rubik's cube using reinforcement learning on collected data.
|
||||
Based on the example GRPO trainer with modifications for pre-collected data.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
import wandb
|
||||
import yaml
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class TrainerConfig:
|
||||
"""Configuration for the trainer."""
|
||||
# Model configuration
|
||||
model_name: str
|
||||
learning_rate: float
|
||||
batch_size: int
|
||||
gradient_accumulation_steps: int
|
||||
sequence_length: int
|
||||
warmup_steps: int
|
||||
|
||||
# Training configuration
|
||||
total_steps: int
|
||||
eval_every: int
|
||||
save_every: int
|
||||
checkpoint_dir: str
|
||||
use_wandb: bool
|
||||
wandb_project: str
|
||||
wandb_run_name: str
|
||||
|
||||
# Data configuration
|
||||
train_file: str
|
||||
validation_size: float
|
||||
prefer_higher_scores: bool
|
||||
max_samples: int
|
||||
|
||||
# RL configuration
|
||||
method: str
|
||||
temperature: float
|
||||
top_p: float
|
||||
beta: float
|
||||
reference_model: Optional[str] = None
|
||||
|
||||
|
||||
def load_config(config_path: str) -> TrainerConfig:
|
||||
"""Load configuration from YAML file."""
|
||||
with open(config_path, "r") as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
|
||||
# The config is already flat, so we use it directly
|
||||
return TrainerConfig(**config_dict)
|
||||
|
||||
|
||||
def load_jsonl_data(file_path: str, max_samples: int = -1) -> List[Dict]:
|
||||
"""Load data from JSONL file."""
|
||||
data = []
|
||||
with open(file_path, "r") as f:
|
||||
for line in f:
|
||||
data.append(json.loads(line))
|
||||
if max_samples > 0 and len(data) >= max_samples:
|
||||
break
|
||||
return data
|
||||
|
||||
|
||||
def split_train_val(data: List[Dict], val_size: float) -> Tuple[List[Dict], List[Dict]]:
|
||||
"""Split data into training and validation sets."""
|
||||
val_count = int(len(data) * val_size)
|
||||
return data[val_count:], data[:val_count]
|
||||
|
||||
|
||||
def prepare_training_batch(
|
||||
data_batch: List[Dict],
|
||||
tokenizer,
|
||||
prefer_higher_scores: bool = True,
|
||||
device: str = "cuda"
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Prepare a batch for training.
|
||||
|
||||
Args:
|
||||
data_batch: List of data points from JSONL
|
||||
tokenizer: Tokenizer for the model
|
||||
prefer_higher_scores: If True, higher scores are better
|
||||
device: Device to put tensors on
|
||||
|
||||
Returns:
|
||||
Dict with input_ids, attention_mask, and scores
|
||||
"""
|
||||
batch_tokens = []
|
||||
batch_masks = []
|
||||
batch_scores = []
|
||||
|
||||
for item in data_batch:
|
||||
# For each group, select best and worst sequences based on scores
|
||||
scores = item["scores"]
|
||||
tokens = item["tokens"]
|
||||
masks = item["masks"]
|
||||
|
||||
if prefer_higher_scores:
|
||||
best_idx = max(range(len(scores)), key=lambda i: scores[i])
|
||||
worst_idx = min(range(len(scores)), key=lambda i: scores[i])
|
||||
else:
|
||||
best_idx = min(range(len(scores)), key=lambda i: scores[i])
|
||||
worst_idx = max(range(len(scores)), key=lambda i: scores[i])
|
||||
|
||||
batch_tokens.extend([tokens[best_idx], tokens[worst_idx]])
|
||||
batch_masks.extend([masks[best_idx], masks[worst_idx]])
|
||||
batch_scores.extend([scores[best_idx], scores[worst_idx]])
|
||||
|
||||
# Convert to tensors
|
||||
input_ids = torch.tensor(batch_tokens, dtype=torch.long).to(device)
|
||||
attention_mask = torch.tensor(batch_masks, dtype=torch.long).to(device)
|
||||
scores = torch.tensor(batch_scores, dtype=torch.float).to(device)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"scores": scores,
|
||||
}
|
||||
|
||||
|
||||
def compute_grpo_loss(
|
||||
logprobs: torch.Tensor,
|
||||
ref_logprobs: torch.Tensor,
|
||||
scores: torch.Tensor,
|
||||
beta: float
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the Group Relative Policy Optimization loss.
|
||||
|
||||
Args:
|
||||
logprobs: Log probabilities from the model (batch_size, seq_len)
|
||||
ref_logprobs: Log probabilities from the reference model (batch_size, seq_len)
|
||||
scores: Scores for each sequence (batch_size,)
|
||||
beta: KL penalty coefficient
|
||||
|
||||
Returns:
|
||||
Loss tensor
|
||||
"""
|
||||
batch_size = logprobs.shape[0]
|
||||
assert batch_size % 2 == 0, "Batch size must be even"
|
||||
|
||||
# Reshape to (batch_size/2, 2, seq_len)
|
||||
logprobs = logprobs.view(batch_size // 2, 2, -1)
|
||||
ref_logprobs = ref_logprobs.view(batch_size // 2, 2, -1)
|
||||
scores = scores.view(batch_size // 2, 2)
|
||||
|
||||
# Calculate policy gradient loss
|
||||
pg_loss = 0
|
||||
for i in range(batch_size // 2):
|
||||
# Policy gradient - weight by the score difference
|
||||
score_diff = scores[i, 0] - scores[i, 1]
|
||||
log_ratio_chosen = logprobs[i, 0].sum() - ref_logprobs[i, 0].sum()
|
||||
log_ratio_rejected = logprobs[i, 1].sum() - ref_logprobs[i, 1].sum()
|
||||
|
||||
# KL penalty
|
||||
kl_chosen = (ref_logprobs[i, 0] - logprobs[i, 0]).sum()
|
||||
kl_rejected = (ref_logprobs[i, 1] - logprobs[i, 1]).sum()
|
||||
|
||||
# Final loss - maximize score difference, minimize KL divergence
|
||||
pg_loss += -score_diff * (log_ratio_chosen - log_ratio_rejected)
|
||||
pg_loss += beta * (kl_chosen + kl_rejected)
|
||||
|
||||
return pg_loss / (batch_size // 2)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Train a model on Rubik's cube data")
|
||||
parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load configuration
|
||||
config = load_config(args.config)
|
||||
logger.info(f"Loaded configuration from {args.config}")
|
||||
|
||||
# Set device
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Initialize wandb if specified
|
||||
if config.use_wandb:
|
||||
wandb.init(
|
||||
project=config.wandb_project,
|
||||
name=config.wandb_run_name,
|
||||
config=vars(config)
|
||||
)
|
||||
|
||||
# Create checkpoint directory
|
||||
os.makedirs(config.checkpoint_dir, exist_ok=True)
|
||||
|
||||
# Load tokenizer and model
|
||||
logger.info(f"Loading model {config.model_name}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
config.model_name,
|
||||
torch_dtype=torch.float16 if device == "cuda" else torch.float32
|
||||
).to(device)
|
||||
model.train()
|
||||
|
||||
# Load reference model if specified
|
||||
ref_model = None
|
||||
if config.reference_model:
|
||||
logger.info(f"Loading reference model {config.reference_model}")
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(
|
||||
config.reference_model,
|
||||
torch_dtype=torch.float16 if device == "cuda" else torch.float32
|
||||
).to(device)
|
||||
ref_model.eval()
|
||||
|
||||
# Set up optimizer and lr scheduler
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
|
||||
scheduler = transformers.get_cosine_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=config.warmup_steps,
|
||||
num_training_steps=config.total_steps
|
||||
)
|
||||
|
||||
# Load and split data
|
||||
logger.info(f"Loading data from {config.train_file}")
|
||||
all_data = load_jsonl_data(config.train_file, config.max_samples)
|
||||
train_data, val_data = split_train_val(all_data, config.validation_size)
|
||||
logger.info(f"Loaded {len(train_data)} training and {len(val_data)} validation samples")
|
||||
|
||||
# Training loop
|
||||
global_step = 0
|
||||
best_val_loss = float('inf')
|
||||
|
||||
logger.info("Starting training")
|
||||
try:
|
||||
for epoch in range(100): # Large number, will break when steps reached
|
||||
# Shuffle training data
|
||||
import random
|
||||
random.shuffle(train_data)
|
||||
|
||||
for i in range(0, len(train_data), config.batch_size // 2):
|
||||
batch_data = train_data[i:i + config.batch_size // 2]
|
||||
if len(batch_data) < config.batch_size // 2:
|
||||
continue # Skip incomplete batches
|
||||
|
||||
# Prepare batch
|
||||
batch = prepare_training_batch(
|
||||
batch_data,
|
||||
tokenizer,
|
||||
prefer_higher_scores=config.prefer_higher_scores,
|
||||
device=device
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
with torch.cuda.amp.autocast(enabled=device == "cuda"):
|
||||
outputs = model(
|
||||
input_ids=batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
# Compute log probabilities
|
||||
logits = outputs.logits[:, :-1]
|
||||
logprobs = F.log_softmax(logits, dim=-1)
|
||||
target_ids = batch["input_ids"][:, 1:]
|
||||
masks = batch["attention_mask"][:, 1:]
|
||||
|
||||
# Get log probs for the chosen tokens
|
||||
chosen_logprobs = torch.gather(
|
||||
logprobs,
|
||||
dim=2,
|
||||
index=target_ids.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
|
||||
# Apply mask
|
||||
chosen_logprobs = chosen_logprobs * masks
|
||||
|
||||
# Get reference log probs if using a reference model
|
||||
if ref_model:
|
||||
with torch.no_grad():
|
||||
ref_outputs = ref_model(
|
||||
input_ids=batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
return_dict=True
|
||||
)
|
||||
ref_logits = ref_outputs.logits[:, :-1]
|
||||
ref_logprobs = F.log_softmax(ref_logits, dim=-1)
|
||||
|
||||
ref_chosen_logprobs = torch.gather(
|
||||
ref_logprobs,
|
||||
dim=2,
|
||||
index=target_ids.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
|
||||
# Apply mask
|
||||
ref_chosen_logprobs = ref_chosen_logprobs * masks
|
||||
else:
|
||||
# If no reference model, use the current model's initial state
|
||||
ref_chosen_logprobs = chosen_logprobs.detach()
|
||||
|
||||
# Compute loss
|
||||
loss = compute_grpo_loss(
|
||||
chosen_logprobs,
|
||||
ref_chosen_logprobs,
|
||||
batch["scores"],
|
||||
config.beta
|
||||
)
|
||||
|
||||
# Backward pass
|
||||
loss = loss / config.gradient_accumulation_steps
|
||||
loss.backward()
|
||||
|
||||
# Update weights if gradient accumulation steps reached
|
||||
if (global_step + 1) % config.gradient_accumulation_steps == 0:
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Log progress
|
||||
if global_step % 10 == 0:
|
||||
logger.info(f"Step {global_step}: loss = {loss.item() * config.gradient_accumulation_steps:.4f}")
|
||||
if config.use_wandb:
|
||||
wandb.log({
|
||||
"train/loss": loss.item() * config.gradient_accumulation_steps,
|
||||
"train/learning_rate": scheduler.get_last_lr()[0],
|
||||
"train/step": global_step,
|
||||
})
|
||||
|
||||
# Evaluate on validation set
|
||||
if global_step % config.eval_every == 0:
|
||||
model.eval()
|
||||
val_losses = []
|
||||
|
||||
with torch.no_grad():
|
||||
for j in range(0, min(len(val_data), 100), config.batch_size // 2):
|
||||
val_batch_data = val_data[j:j + config.batch_size // 2]
|
||||
if len(val_batch_data) < config.batch_size // 2:
|
||||
continue
|
||||
|
||||
val_batch = prepare_training_batch(
|
||||
val_batch_data,
|
||||
tokenizer,
|
||||
prefer_higher_scores=config.prefer_higher_scores,
|
||||
device=device
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
val_outputs = model(
|
||||
input_ids=val_batch["input_ids"],
|
||||
attention_mask=val_batch["attention_mask"],
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
# Compute log probabilities
|
||||
val_logits = val_outputs.logits[:, :-1]
|
||||
val_logprobs = F.log_softmax(val_logits, dim=-1)
|
||||
val_target_ids = val_batch["input_ids"][:, 1:]
|
||||
val_masks = val_batch["attention_mask"][:, 1:]
|
||||
|
||||
# Get log probs for the chosen tokens
|
||||
val_chosen_logprobs = torch.gather(
|
||||
val_logprobs,
|
||||
dim=2,
|
||||
index=val_target_ids.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
|
||||
# Apply mask
|
||||
val_chosen_logprobs = val_chosen_logprobs * val_masks
|
||||
|
||||
# Get reference log probs
|
||||
if ref_model:
|
||||
ref_val_outputs = ref_model(
|
||||
input_ids=val_batch["input_ids"],
|
||||
attention_mask=val_batch["attention_mask"],
|
||||
return_dict=True
|
||||
)
|
||||
ref_val_logits = ref_val_outputs.logits[:, :-1]
|
||||
ref_val_logprobs = F.log_softmax(ref_val_logits, dim=-1)
|
||||
|
||||
ref_val_chosen_logprobs = torch.gather(
|
||||
ref_val_logprobs,
|
||||
dim=2,
|
||||
index=val_target_ids.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
|
||||
# Apply mask
|
||||
ref_val_chosen_logprobs = ref_val_chosen_logprobs * val_masks
|
||||
else:
|
||||
# If no reference model, use detached current model outputs
|
||||
ref_val_chosen_logprobs = val_chosen_logprobs.detach()
|
||||
|
||||
# Compute loss
|
||||
val_loss = compute_grpo_loss(
|
||||
val_chosen_logprobs,
|
||||
ref_val_chosen_logprobs,
|
||||
val_batch["scores"],
|
||||
config.beta
|
||||
)
|
||||
val_losses.append(val_loss.item())
|
||||
|
||||
avg_val_loss = sum(val_losses) / len(val_losses)
|
||||
logger.info(f"Validation loss: {avg_val_loss:.4f}")
|
||||
|
||||
if config.use_wandb:
|
||||
wandb.log({
|
||||
"val/loss": avg_val_loss,
|
||||
"val/step": global_step,
|
||||
})
|
||||
|
||||
# Save best model
|
||||
if avg_val_loss < best_val_loss:
|
||||
best_val_loss = avg_val_loss
|
||||
logger.info(f"New best validation loss: {best_val_loss:.4f}")
|
||||
# Save model
|
||||
output_dir = os.path.join(config.checkpoint_dir, f"best_model")
|
||||
model.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
model.train()
|
||||
|
||||
# Save checkpoint
|
||||
if global_step % config.save_every == 0 and global_step > 0:
|
||||
output_dir = os.path.join(config.checkpoint_dir, f"checkpoint-{global_step}")
|
||||
model.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
# Increment step
|
||||
global_step += 1
|
||||
|
||||
# Exit if reached total steps
|
||||
if global_step >= config.total_steps:
|
||||
break
|
||||
|
||||
# Exit if reached total steps
|
||||
if global_step >= config.total_steps:
|
||||
break
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Training interrupted by user")
|
||||
|
||||
# Save final model
|
||||
logger.info("Saving final model")
|
||||
output_dir = os.path.join(config.checkpoint_dir, "final_model")
|
||||
model.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
if config.use_wandb:
|
||||
wandb.finish()
|
||||
|
||||
logger.info("Training complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue