atropos/environments/community/conversational_style_dpo/train_dpo_conversational.py

185 lines
7.1 KiB
Python

import asyncio
import logging
from dataclasses import dataclass, field
from typing import Optional
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
HfArgumentParser,
TrainingArguments,
)
from trl import DPOTrainer
# Import your custom environment
# The import below assumes this script is in the same directory as conversational_style_dpo_env.py
from .conversational_style_dpo_env import (
ConversationalStyleDPOEnv,
ConversationalStyleDPOEnvConfig,
)
logger = logging.getLogger(__name__)
@dataclass
class ScriptArguments:
"""
Arguments for the DPO training script.
"""
model_name_or_path: str = field(
default="distilgpt2",
metadata={"help": "The model name or path to load from."},
)
tokenizer_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "The tokenizer name or path. Defaults to model_name_or_path."
},
)
beta: float = field(default=0.1, metadata={"help": "The beta factor in DPO loss."})
max_prompt_length: int = field(
default=256, metadata={"help": "Max length for prompts."}
)
max_length: int = field(
default=512,
metadata={"help": "Max length for chosen/rejected responses including prompt."},
)
# Add any other TRL DPOTrainer arguments or TrainingArguments here if needed
# For example, learning_rate, per_device_train_batch_size, num_train_epochs etc.
# will be part of TrainingArguments.
async def get_dataset_from_env(
env_config: ConversationalStyleDPOEnvConfig,
) -> Dataset:
"""
Initializes the environment and extracts the synthetic dataset
in the format required by DPOTrainer (list of dicts with prompt, chosen, rejected).
"""
# We don't need server_configs if the env doesn't use them for static dataset loading
env = ConversationalStyleDPOEnv(config=env_config, server_configs=[], testing=True)
await env.setup() # This loads env.synthetic_data
# env.synthetic_data is already a List[Dict[str, str]] with "prompt", "chosen", "rejected"
# Convert it to Hugging Face Dataset
# Check if data is loaded
if not env.dataset:
raise ValueError(
"Dataset is empty after environment setup. Check ConversationalStyleDPOEnv."
)
# The DPOTrainer expects columns named "prompt", "chosen", "rejected"
# The synthetic_data in your environment is already in this format.
# Example: [{"prompt": "...", "chosen": "...", "rejected": "..."}]
hf_dataset = Dataset.from_list(list(env.dataset)) # Ensure it's a fresh list copy
# Log a sample to verify
if len(hf_dataset) > 0:
logger.info(f"Sample from dataset: {hf_dataset[0]}")
else:
logger.warning("Dataset created from environment is empty!")
return hf_dataset
def main():
parser = HfArgumentParser((ScriptArguments, TrainingArguments))
script_args, training_args = parser.parse_args_into_dataclasses()
if script_args.tokenizer_name_or_path is None:
script_args.tokenizer_name_or_path = script_args.model_name_or_path
# --- 1. Initialize Environment and Get Dataset ---
logger.info("Initializing environment to get dataset...")
# Use the default config from your environment, but ensure tokenizer matches
env_dpo_config, _ = ConversationalStyleDPOEnv.config_init()
env_dpo_config.tokenizer_name = (
script_args.tokenizer_name_or_path
) # Align tokenizer
# You might want to adjust other env_dpo_config parameters if needed
# Run the async function to get the dataset
dataset = asyncio.run(get_dataset_from_env(env_dpo_config))
logger.info(f"Loaded dataset with {len(dataset)} examples.")
if len(dataset) == 0:
logger.error("No data loaded. Exiting training.")
return
# --- 2. Load Tokenizer and Models ---
logger.info(f"Loading tokenizer: {script_args.tokenizer_name_or_path}")
tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name_or_path)
if tokenizer.pad_token is None:
logger.warning("Tokenizer does not have a pad token. Setting to eos_token.")
tokenizer.pad_token = tokenizer.eos_token
# For some models, you might also need to set tokenizer.pad_token_id
logger.info(f"Loading policy model: {script_args.model_name_or_path}")
model = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
# low_cpu_mem_usage=True, # Can be helpful for large models
# torch_dtype=torch.float16, # For mixed precision if GPU supports
)
# Reference model for DPO. If not provided, DPOTrainer will create a copy of the model.
# For simplicity, we'll let DPOTrainer handle creating the reference model by not passing one.
# If you wanted to load a different SFT model as reference, you would do:
# model_ref = AutoModelForCausalLM.from_pretrained(...)
model_ref = None
logger.info(
"Reference model will be a copy of the policy model (handled by DPOTrainer)."
)
# --- 3. Set up Training Arguments ---
# Default DPO training arguments. You might want to customize these.
if training_args.output_dir == "output_dir": # Default value from TrainingArguments
# Output directory will be relative to the script's current working directory when run.
# If run from environments/hack0/conversational_style_dpo/, it will be ./dpo_conversational_trainer_results
training_args.output_dir = "./dpo_conversational_trainer_results"
# training_args.per_device_train_batch_size = 2 # Adjust as needed
# training_args.num_train_epochs = 1 # Keep low for a quick test
# training_args.gradient_accumulation_steps = 1
# training_args.learning_rate = 5e-5
# training_args.logging_steps = 10
# training_args.save_steps = 50
# training_args.report_to = "none" # "wandb" or "tensorboard" if you want to log
logger.info(f"Training Arguments: {training_args}")
# --- 4. Initialize DPOTrainer ---
logger.info("Initializing DPOTrainer...")
dpo_trainer = DPOTrainer(
model=model,
ref_model=model_ref, # If None, a copy of model is made
args=training_args,
beta=script_args.beta,
train_dataset=dataset,
tokenizer=tokenizer,
max_prompt_length=script_args.max_prompt_length,
max_length=script_args.max_length, # Max length of prompt + response
# peft_config=peft_config, # If using PEFT/LoRA
)
logger.info("DPOTrainer initialized.")
# --- 5. Start Training ---
logger.info("Starting DPO training...")
dpo_trainer.train()
logger.info("DPO training completed.")
# --- 6. Save the Model (Optional) ---
if training_args.should_save: # Checks if any save_strategy is enabled
output_save_dir = training_args.output_dir
logger.info(f"Saving model to {output_save_dir}")
dpo_trainer.save_model(output_save_dir)
logger.info("Model saved.")
# Also save the tokenizer
tokenizer.save_pretrained(output_save_dir)
logger.info("Tokenizer saved.")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()