diff --git a/configs/rubiks_training.yaml b/configs/rubiks_training.yaml
new file mode 100644
index 00000000..760f5319
--- /dev/null
+++ b/configs/rubiks_training.yaml
@@ -0,0 +1,28 @@
+# Rubik's Cube Solver Training Configuration
+
+# Flattened configuration for TrainerConfig
+model_name: "NousResearch/DeepHermes-3-Llama-3-8B-Preview"  # Target model for training
+learning_rate: 5.0e-6
+batch_size: 8
+gradient_accumulation_steps: 4
+sequence_length: 2048
+warmup_steps: 100
+
+total_steps: 2000
+eval_every: 50
+save_every: 250
+checkpoint_dir: "./rubiks_checkpoints"
+use_wandb: true
+wandb_project: "atropos-rubiks-cube"
+wandb_run_name: "rubiks-solver-training"
+
+train_file: "/Users/joshuajerin/Desktop/jarvis/atropos/environments/rubiks_process_results_22.jsonl"
+validation_size: 0.1  # 10% held out for validation
+prefer_higher_scores: true
+max_samples: -1  # Use all samples
+
+method: "GRPO"  # Group Relative Policy Optimization
+temperature: 0.7
+top_p: 0.9
+beta: 0.1  # KL penalty coefficient
+reference_model: "NousResearch/DeepHermes-3-Llama-3.2-3B-Preview"  # Smaller reference model
\ No newline at end of file
diff --git a/environments/train_rubiks_model.py b/environments/train_rubiks_model.py
new file mode 100755
index 00000000..7b633087
--- /dev/null
+++ b/environments/train_rubiks_model.py
@@ -0,0 +1,464 @@
+#!/usr/bin/env python3
+"""
+Train a model to solve the Rubik's cube with reinforcement learning on pre-collected data.
+Based on the example GRPO trainer, modified to consume pre-collected rollout groups.
+"""
+
+import argparse
+import json
+import logging
+import os
+import random
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import transformers
+import wandb
+import yaml
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class TrainerConfig:
+    """Configuration for the trainer."""
+
+    # Model configuration
+    model_name: str
+    learning_rate: float
+    batch_size: int
+    gradient_accumulation_steps: int
+    sequence_length: int
+    warmup_steps: int
+
+    # Training configuration
+    total_steps: int
+    eval_every: int
+    save_every: int
+    checkpoint_dir: str
+    use_wandb: bool
+    wandb_project: str
+    wandb_run_name: str
+
+    # Data configuration
+    train_file: str
+    validation_size: float
+    prefer_higher_scores: bool
+    max_samples: int
+
+    # RL configuration
+    method: str
+    temperature: float
+    top_p: float
+    beta: float
+    reference_model: Optional[str] = None
+
+
+def load_config(config_path: str) -> TrainerConfig:
+    """Load configuration from a YAML file."""
+    with open(config_path, "r") as f:
+        config_dict = yaml.safe_load(f)
+
+    # The config file is already flat, so it maps directly onto TrainerConfig
+    return TrainerConfig(**config_dict)
+
+
+def load_jsonl_data(file_path: str, max_samples: int = -1) -> List[Dict]:
+    """Load data from a JSONL file, optionally capped at max_samples records."""
+    data = []
+    with open(file_path, "r") as f:
+        for line in f:
+            data.append(json.loads(line))
+            if max_samples > 0 and len(data) >= max_samples:
+                break
+    return data
+
+
+def split_train_val(data: List[Dict], val_size: float) -> Tuple[List[Dict], List[Dict]]:
+    """Split data into training and validation sets."""
+    val_count = int(len(data) * val_size)
+    return data[val_count:], data[:val_count]
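+
+
+# A sketch of the record shape this trainer assumes, inferred from how
+# prepare_training_batch (below) indexes each JSONL object; the example
+# values are illustrative, not taken from the real dataset:
+#
+#   {
+#       "tokens": [[128000, 882, ...], ...],  # one token-id list per rollout in the group
+#       "masks": [[1, 1, ...], ...],          # same shape as tokens; 0 masks a position out
+#       "scores": [0.8, -0.2, ...],           # one scalar reward per rollout
+#   }
+#
+# All token lists are assumed to be padded to a common length, since the
+# batch is built with a single rectangular torch.tensor(...) call.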
""" + Prepare a batch for training. + + Args: + data_batch: List of data points from JSONL + tokenizer: Tokenizer for the model + prefer_higher_scores: If True, higher scores are better + device: Device to put tensors on + + Returns: + Dict with input_ids, attention_mask, and scores + """ + batch_tokens = [] + batch_masks = [] + batch_scores = [] + + for item in data_batch: + # For each group, select best and worst sequences based on scores + scores = item["scores"] + tokens = item["tokens"] + masks = item["masks"] + + if prefer_higher_scores: + best_idx = max(range(len(scores)), key=lambda i: scores[i]) + worst_idx = min(range(len(scores)), key=lambda i: scores[i]) + else: + best_idx = min(range(len(scores)), key=lambda i: scores[i]) + worst_idx = max(range(len(scores)), key=lambda i: scores[i]) + + batch_tokens.extend([tokens[best_idx], tokens[worst_idx]]) + batch_masks.extend([masks[best_idx], masks[worst_idx]]) + batch_scores.extend([scores[best_idx], scores[worst_idx]]) + + # Convert to tensors + input_ids = torch.tensor(batch_tokens, dtype=torch.long).to(device) + attention_mask = torch.tensor(batch_masks, dtype=torch.long).to(device) + scores = torch.tensor(batch_scores, dtype=torch.float).to(device) + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "scores": scores, + } + + +def compute_grpo_loss( + logprobs: torch.Tensor, + ref_logprobs: torch.Tensor, + scores: torch.Tensor, + beta: float +) -> torch.Tensor: + """ + Compute the Group Relative Policy Optimization loss. + + Args: + logprobs: Log probabilities from the model (batch_size, seq_len) + ref_logprobs: Log probabilities from the reference model (batch_size, seq_len) + scores: Scores for each sequence (batch_size,) + beta: KL penalty coefficient + + Returns: + Loss tensor + """ + batch_size = logprobs.shape[0] + assert batch_size % 2 == 0, "Batch size must be even" + + # Reshape to (batch_size/2, 2, seq_len) + logprobs = logprobs.view(batch_size // 2, 2, -1) + ref_logprobs = ref_logprobs.view(batch_size // 2, 2, -1) + scores = scores.view(batch_size // 2, 2) + + # Calculate policy gradient loss + pg_loss = 0 + for i in range(batch_size // 2): + # Policy gradient - weight by the score difference + score_diff = scores[i, 0] - scores[i, 1] + log_ratio_chosen = logprobs[i, 0].sum() - ref_logprobs[i, 0].sum() + log_ratio_rejected = logprobs[i, 1].sum() - ref_logprobs[i, 1].sum() + + # KL penalty + kl_chosen = (ref_logprobs[i, 0] - logprobs[i, 0]).sum() + kl_rejected = (ref_logprobs[i, 1] - logprobs[i, 1]).sum() + + # Final loss - maximize score difference, minimize KL divergence + pg_loss += -score_diff * (log_ratio_chosen - log_ratio_rejected) + pg_loss += beta * (kl_chosen + kl_rejected) + + return pg_loss / (batch_size // 2) + + +def main(): + parser = argparse.ArgumentParser(description="Train a model on Rubik's cube data") + parser.add_argument("--config", type=str, required=True, help="Path to config YAML") + args = parser.parse_args() + + # Load configuration + config = load_config(args.config) + logger.info(f"Loaded configuration from {args.config}") + + # Set device + device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info(f"Using device: {device}") + + # Initialize wandb if specified + if config.use_wandb: + wandb.init( + project=config.wandb_project, + name=config.wandb_run_name, + config=vars(config) + ) + + # Create checkpoint directory + os.makedirs(config.checkpoint_dir, exist_ok=True) + + # Load tokenizer and model + logger.info(f"Loading model {config.model_name}") + 
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Train a model on Rubik's cube data")
+    parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
+    args = parser.parse_args()
+
+    # Load configuration
+    config = load_config(args.config)
+    logger.info(f"Loaded configuration from {args.config}")
+
+    # Set device
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    logger.info(f"Using device: {device}")
+
+    # Initialize wandb if specified
+    if config.use_wandb:
+        wandb.init(
+            project=config.wandb_project,
+            name=config.wandb_run_name,
+            config=vars(config),
+        )
+
+    # Create checkpoint directory
+    os.makedirs(config.checkpoint_dir, exist_ok=True)
+
+    # Load tokenizer and model
+    logger.info(f"Loading model {config.model_name}")
+    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
+    model = AutoModelForCausalLM.from_pretrained(
+        config.model_name,
+        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+    ).to(device)
+    model.train()
+
+    # Load reference model if specified
+    ref_model = None
+    if config.reference_model:
+        logger.info(f"Loading reference model {config.reference_model}")
+        ref_model = AutoModelForCausalLM.from_pretrained(
+            config.reference_model,
+            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+        ).to(device)
+        ref_model.eval()
+
+    # Set up the optimizer and LR scheduler. The scheduler steps once per
+    # optimizer update (not per micro-batch), so its horizon is measured in
+    # optimizer steps; otherwise the cosine decay would never complete.
+    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
+    scheduler = transformers.get_cosine_schedule_with_warmup(
+        optimizer,
+        num_warmup_steps=config.warmup_steps,
+        num_training_steps=max(1, config.total_steps // config.gradient_accumulation_steps),
+    )
+
+    # Load and split data
+    logger.info(f"Loading data from {config.train_file}")
+    all_data = load_jsonl_data(config.train_file, config.max_samples)
+    train_data, val_data = split_train_val(all_data, config.validation_size)
+    logger.info(f"Loaded {len(train_data)} training and {len(val_data)} validation samples")
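+
+    # Batching note (a reading aid, not new behavior): each JSONL group
+    # contributes a (best, worst) pair of sequences, so iterating the data in
+    # chunks of batch_size // 2 groups yields batch_size sequences per forward
+    # pass. global_step counts these micro-batches; weights update once every
+    # gradient_accumulation_steps of them.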
+
+    # Training loop
+    global_step = 0
+    best_val_loss = float("inf")
+
+    logger.info("Starting training")
+    try:
+        for epoch in range(100):  # Large number; the loop breaks once total_steps is reached
+            # Shuffle training data each epoch
+            random.shuffle(train_data)
+
+            for i in range(0, len(train_data), config.batch_size // 2):
+                batch_data = train_data[i : i + config.batch_size // 2]
+                if len(batch_data) < config.batch_size // 2:
+                    continue  # Skip incomplete batches
+
+                # Prepare batch
+                batch = prepare_training_batch(
+                    batch_data,
+                    tokenizer,
+                    prefer_higher_scores=config.prefer_higher_scores,
+                    device=device,
+                )
+
+                # Forward pass
+                with torch.amp.autocast(device_type="cuda", enabled=device == "cuda"):
+                    outputs = model(
+                        input_ids=batch["input_ids"],
+                        attention_mask=batch["attention_mask"],
+                        return_dict=True,
+                    )
+
+                # Compute log probabilities (shift logits to predict the next token)
+                logits = outputs.logits[:, :-1]
+                logprobs = F.log_softmax(logits, dim=-1)
+                target_ids = batch["input_ids"][:, 1:]
+                masks = batch["attention_mask"][:, 1:]
+
+                # Get log probs for the chosen tokens
+                chosen_logprobs = torch.gather(
+                    logprobs,
+                    dim=2,
+                    index=target_ids.unsqueeze(-1),
+                ).squeeze(-1)
+
+                # Apply mask
+                chosen_logprobs = chosen_logprobs * masks
+
+                # Get reference log probs if using a reference model
+                if ref_model:
+                    with torch.no_grad():
+                        ref_outputs = ref_model(
+                            input_ids=batch["input_ids"],
+                            attention_mask=batch["attention_mask"],
+                            return_dict=True,
+                        )
+                        ref_logits = ref_outputs.logits[:, :-1]
+                        ref_logprobs = F.log_softmax(ref_logits, dim=-1)
+
+                        ref_chosen_logprobs = torch.gather(
+                            ref_logprobs,
+                            dim=2,
+                            index=target_ids.unsqueeze(-1),
+                        ).squeeze(-1)
+
+                        # Apply mask
+                        ref_chosen_logprobs = ref_chosen_logprobs * masks
+                else:
+                    # Without a reference model, detach the current policy's log
+                    # probs; the log ratios and KL terms then collapse to zero
+                    ref_chosen_logprobs = chosen_logprobs.detach()
+
+                # Compute loss
+                loss = compute_grpo_loss(
+                    chosen_logprobs,
+                    ref_chosen_logprobs,
+                    batch["scores"],
+                    config.beta,
+                )
+
+                # Backward pass (scaled for gradient accumulation)
+                loss = loss / config.gradient_accumulation_steps
+                loss.backward()
+
+                # Update weights once enough gradients have accumulated
+                if (global_step + 1) % config.gradient_accumulation_steps == 0:
+                    optimizer.step()
+                    scheduler.step()
+                    optimizer.zero_grad()
+
+                # Log progress
+                if global_step % 10 == 0:
+                    logger.info(
+                        f"Step {global_step}: loss = {loss.item() * config.gradient_accumulation_steps:.4f}"
+                    )
+                    if config.use_wandb:
+                        wandb.log({
+                            "train/loss": loss.item() * config.gradient_accumulation_steps,
+                            "train/learning_rate": scheduler.get_last_lr()[0],
+                            "train/step": global_step,
+                        })
+
+                # Evaluate on the validation set
+                if global_step % config.eval_every == 0:
+                    model.eval()
+                    val_losses = []
+
+                    with torch.no_grad():
+                        for j in range(0, min(len(val_data), 100), config.batch_size // 2):
+                            val_batch_data = val_data[j : j + config.batch_size // 2]
+                            if len(val_batch_data) < config.batch_size // 2:
+                                continue
+
+                            val_batch = prepare_training_batch(
+                                val_batch_data,
+                                tokenizer,
+                                prefer_higher_scores=config.prefer_higher_scores,
+                                device=device,
+                            )
+
+                            # Forward pass
+                            val_outputs = model(
+                                input_ids=val_batch["input_ids"],
+                                attention_mask=val_batch["attention_mask"],
+                                return_dict=True,
+                            )
+
+                            # Compute log probabilities
+                            val_logits = val_outputs.logits[:, :-1]
+                            val_logprobs = F.log_softmax(val_logits, dim=-1)
+                            val_target_ids = val_batch["input_ids"][:, 1:]
+                            val_masks = val_batch["attention_mask"][:, 1:]
+
+                            # Get log probs for the chosen tokens
+                            val_chosen_logprobs = torch.gather(
+                                val_logprobs,
+                                dim=2,
+                                index=val_target_ids.unsqueeze(-1),
+                            ).squeeze(-1)
+
+                            # Apply mask
+                            val_chosen_logprobs = val_chosen_logprobs * val_masks
+
+                            # Get reference log probs
+                            if ref_model:
+                                ref_val_outputs = ref_model(
+                                    input_ids=val_batch["input_ids"],
+                                    attention_mask=val_batch["attention_mask"],
+                                    return_dict=True,
+                                )
+                                ref_val_logits = ref_val_outputs.logits[:, :-1]
+                                ref_val_logprobs = F.log_softmax(ref_val_logits, dim=-1)
+
+                                ref_val_chosen_logprobs = torch.gather(
+                                    ref_val_logprobs,
+                                    dim=2,
+                                    index=val_target_ids.unsqueeze(-1),
+                                ).squeeze(-1)
+
+                                # Apply mask
+                                ref_val_chosen_logprobs = ref_val_chosen_logprobs * val_masks
+                            else:
+                                # No reference model: use detached current-model outputs
+                                ref_val_chosen_logprobs = val_chosen_logprobs.detach()
+
+                            # Compute loss
+                            val_loss = compute_grpo_loss(
+                                val_chosen_logprobs,
+                                ref_val_chosen_logprobs,
+                                val_batch["scores"],
+                                config.beta,
+                            )
+                            val_losses.append(val_loss.item())
+
+                    if val_losses:  # Guard against an empty validation pass
+                        avg_val_loss = sum(val_losses) / len(val_losses)
+                        logger.info(f"Validation loss: {avg_val_loss:.4f}")
+
+                        if config.use_wandb:
+                            wandb.log({
+                                "val/loss": avg_val_loss,
+                                "val/step": global_step,
+                            })
+
+                        # Save best model
+                        if avg_val_loss < best_val_loss:
+                            best_val_loss = avg_val_loss
+                            logger.info(f"New best validation loss: {best_val_loss:.4f}")
+                            output_dir = os.path.join(config.checkpoint_dir, "best_model")
+                            model.save_pretrained(output_dir)
+                            tokenizer.save_pretrained(output_dir)
+
+                    model.train()
+
+                # Save checkpoint
+                if global_step % config.save_every == 0 and global_step > 0:
+                    output_dir = os.path.join(config.checkpoint_dir, f"checkpoint-{global_step}")
+                    model.save_pretrained(output_dir)
+                    tokenizer.save_pretrained(output_dir)
+
+                # Increment step
+                global_step += 1
+
+                # Exit the inner loop once total steps are reached
+                if global_step >= config.total_steps:
+                    break
+
+            # Exit the epoch loop once total steps are reached
+            if global_step >= config.total_steps:
+                break
+
+    except KeyboardInterrupt:
+        logger.info("Training interrupted by user")
+
+    # Save final model
+    logger.info("Saving final model")
+    output_dir = os.path.join(config.checkpoint_dir, "final_model")
+    model.save_pretrained(output_dir)
+    tokenizer.save_pretrained(output_dir)
+
+    if config.use_wandb:
+        wandb.finish()
+
+    logger.info("Training complete")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
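
Usage note (a sketch; the command assumes the repo layout in this diff and that torch, transformers, wandb, and pyyaml are installed):

    python environments/train_rubiks_model.py --config configs/rubiks_training.yaml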