# atropos/environments/hack0/conversational_style_dpo/train_dpo_conversational.py
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
)
from trl import DPOTrainer
# Import the custom environment. The relative import works when this file is
# run as part of a package; the fallback covers running it from this directory.
try:
    from .conversational_style_dpo_env import (
        ConversationalStyleDPOEnv,
        ConversationalStyleDPOEnvConfig,
    )
except ImportError:
    from conversational_style_dpo_env import (
        ConversationalStyleDPOEnv,
        ConversationalStyleDPOEnvConfig,
    )

logger = logging.getLogger(__name__)


@dataclass
class ScriptArguments:
    """
    Arguments for the DPO training script.
    """

    model_name_or_path: str = field(
        default="distilgpt2",
        metadata={"help": "The model name or path to load from."},
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The tokenizer name or path. Defaults to model_name_or_path."},
    )
    beta: float = field(
        default=0.1, metadata={"help": "The beta factor in the DPO loss."}
    )
    max_prompt_length: int = field(
        default=256, metadata={"help": "Max length for prompts."}
    )
    max_length: int = field(
        default=512,
        metadata={"help": "Max length for chosen/rejected responses, including the prompt."},
    )
    # Other hyperparameters (learning_rate, per_device_train_batch_size,
    # num_train_epochs, etc.) are part of TrainingArguments and can be set
    # on the command line.
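

# Example invocation from this directory (hyperparameter values are
# illustrative; adjust output_dir and batch size for your hardware):
#   python train_dpo_conversational.py \
#       --output_dir ./dpo_out \
#       --per_device_train_batch_size 2 \
#       --num_train_epochs 1 \
#       --learning_rate 5e-5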


async def get_dataset_from_env(
    env_config: ConversationalStyleDPOEnvConfig,
) -> Dataset:
    """
    Initializes the environment and extracts its synthetic dataset in the
    format required by DPOTrainer: a list of dicts with "prompt", "chosen",
    and "rejected" keys.
    """
    # server_configs can be empty: the environment only loads a static dataset
    # here and does not call out to any inference servers.
    env = ConversationalStyleDPOEnv(config=env_config, server_configs=[], testing=True)
    await env.setup()  # Populates env.dataset with the synthetic preference pairs
    if not env.dataset:
        raise ValueError(
            "Dataset is empty after environment setup. Check ConversationalStyleDPOEnv."
        )
    # DPOTrainer expects columns named "prompt", "chosen", and "rejected"; the
    # environment's data is already in this shape, e.g.
    # [{"prompt": "...", "chosen": "...", "rejected": "..."}]
    hf_dataset = Dataset.from_list(list(env.dataset))  # Ensure a fresh list copy
    logger.info(f"Sample from dataset: {hf_dataset[0]}")  # Log a sample to verify
    return hf_dataset
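

# Illustrative preference pair (contents hypothetical; the actual pairs come
# from the environment's synthetic data):
#   {
#       "prompt": "User: How was your weekend?\nAssistant:",
#       "chosen": "It was lovely, thanks for asking! I spent it reading.",
#       "rejected": "fine.",
#   }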


def main():
    parser = HfArgumentParser((ScriptArguments, TrainingArguments))
    script_args, training_args = parser.parse_args_into_dataclasses()
    if script_args.tokenizer_name_or_path is None:
        script_args.tokenizer_name_or_path = script_args.model_name_or_path

    # --- 1. Initialize Environment and Get Dataset ---
    logger.info("Initializing environment to get dataset...")
    # Use the environment's default config, but align its tokenizer with the
    # model being trained. Adjust other env_dpo_config parameters if needed.
    env_dpo_config, _ = ConversationalStyleDPOEnv.config_init()
    env_dpo_config.tokenizer_name = script_args.tokenizer_name_or_path
    dataset = asyncio.run(get_dataset_from_env(env_dpo_config))
    logger.info(f"Loaded dataset with {len(dataset)} examples.")
    if len(dataset) == 0:
        logger.error("No data loaded. Exiting training.")
        return

    # --- 2. Load Tokenizer and Models ---
    logger.info(f"Loading tokenizer: {script_args.tokenizer_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name_or_path)
    if tokenizer.pad_token is None:
        logger.warning("Tokenizer does not have a pad token. Setting to eos_token.")
        tokenizer.pad_token = tokenizer.eos_token
        # Assigning pad_token usually sets pad_token_id as well; set
        # tokenizer.pad_token_id explicitly if your model requires it.
logger.info(f"Loading policy model: {script_args.model_name_or_path}")
model = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
# low_cpu_mem_usage=True, # Can be helpful for large models
# torch_dtype=torch.float16, # For mixed precision if GPU supports
)
# Reference model for DPO. If not provided, DPOTrainer will create a copy of the model.
# For simplicity, we'll let DPOTrainer handle creating the reference model by not passing one.
# If you wanted to load a different SFT model as reference, you would do:
# model_ref = AutoModelForCausalLM.from_pretrained(...)
model_ref = None
logger.info("Reference model will be a copy of the policy model (handled by DPOTrainer).")

    # --- 3. Set up Training Arguments ---
    # Give the run a descriptive output directory if the user left the default.
    # Relative paths resolve against the current working directory (e.g. when
    # run from environments/hack0/conversational_style_dpo/).
    if training_args.output_dir == "output_dir":
        training_args.output_dir = "./dpo_conversational_trainer_results"
    # Knobs worth tuning for a real run (all settable on the command line):
    # training_args.per_device_train_batch_size = 2  # Adjust as needed
    # training_args.num_train_epochs = 1  # Keep low for a quick test
    # training_args.gradient_accumulation_steps = 1
    # training_args.learning_rate = 5e-5
    # training_args.logging_steps = 10
    # training_args.save_steps = 50
    # training_args.report_to = "none"  # "wandb" or "tensorboard" to log externally
    logger.info(f"Training Arguments: {training_args}")

    # --- 4. Initialize DPOTrainer ---
    # Note: passing beta and the length limits directly matches older TRL
    # releases; newer TRL versions expect them on a DPOConfig instead.
    logger.info("Initializing DPOTrainer...")
    dpo_trainer = DPOTrainer(
        model=model,
        ref_model=model_ref,  # If None, DPOTrainer makes a copy of `model`
        args=training_args,
        beta=script_args.beta,
        train_dataset=dataset,
        tokenizer=tokenizer,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,  # Max length of prompt + response
        # peft_config=peft_config,  # If using PEFT/LoRA
    )
    logger.info("DPOTrainer initialized.")

    # --- 5. Start Training ---
    logger.info("Starting DPO training...")
    dpo_trainer.train()
    logger.info("DPO training completed.")

    # --- 6. Save the Model (Optional) ---
    if training_args.should_save:  # True when this process should write to disk
        output_save_dir = training_args.output_dir
        logger.info(f"Saving model to {output_save_dir}")
        dpo_trainer.save_model(output_save_dir)
        logger.info("Model saved.")
        # Save the tokenizer alongside the model so the checkpoint is self-contained
        tokenizer.save_pretrained(output_save_dir)
        logger.info("Tokenizer saved.")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()