mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
185 lines
No EOL
7.1 KiB
Python
185 lines
No EOL
7.1 KiB
Python
import asyncio
|
|
import os
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, Optional
|
|
|
|
import torch
|
|
from datasets import Dataset
|
|
from transformers import (
|
|
AutoModelForCausalLM,
|
|
AutoTokenizer,
|
|
HfArgumentParser,
|
|
TrainingArguments,
|
|
)
|
|
from trl import DPOTrainer
|
|
|
|
# Import your custom environment
|
|
# The import below assumes this script is in the same directory as conversational_style_dpo_env.py
|
|
from .conversational_style_dpo_env import (
|
|
ConversationalStyleDPOEnv,
|
|
ConversationalStyleDPOEnvConfig,
|
|
)
|
|
|
|
# Module-level logger named after this module; the logging level is configured
# by the `__main__` entry point via logging.basicConfig.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class ScriptArguments:
    """
    Script-specific command-line arguments for the DPO training script.

    The ``metadata["help"]`` entry of every field holds its help text.
    Generic training options (learning_rate, per_device_train_batch_size,
    num_train_epochs, ...) are intentionally not declared here: they are
    part of TrainingArguments, which main() parses alongside this class.
    """

    # Identifier or local path of the policy model to fine-tune.
    model_name_or_path: str = field(
        default="distilgpt2",
        metadata={"help": "The model name or path to load from."},
    )
    # Left as None, main() substitutes model_name_or_path.
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The tokenizer name or path. Defaults to model_name_or_path."},
    )
    # Forwarded to DPOTrainer as `beta`.
    beta: float = field(
        default=0.1,
        metadata={"help": "The beta factor in DPO loss."},
    )
    # Forwarded to DPOTrainer as `max_prompt_length`.
    max_prompt_length: int = field(
        default=256,
        metadata={"help": "Max length for prompts."},
    )
    # Forwarded to DPOTrainer as `max_length` (prompt + response).
    max_length: int = field(
        default=512,
        metadata={"help": "Max length for chosen/rejected responses including prompt."},
    )
    # Add any other script-level DPOTrainer options above as further fields.
|
|
|
|
|
|
async def get_dataset_from_env(
    env_config: ConversationalStyleDPOEnvConfig,
) -> Dataset:
    """
    Build the DPO training dataset from the conversational-style environment.

    Instantiates ConversationalStyleDPOEnv with the given config, awaits its
    ``setup()`` and converts the records exposed on ``env.dataset`` into a
    Hugging Face ``Dataset``.

    Args:
        env_config: Configuration handed to the environment constructor.

    Returns:
        A non-empty ``Dataset`` built from the environment's records.
        DPOTrainer expects "prompt", "chosen" and "rejected" columns; the
        records are assumed to already have that shape (not validated here).

    Raises:
        ValueError: If the environment exposes no records after setup.
    """
    # No server configs are passed; presumably the env does not need them
    # for loading its static dataset (verify against the env implementation).
    env = ConversationalStyleDPOEnv(config=env_config, server_configs=[], testing=True)
    await env.setup()  # expected to populate env.dataset

    # Materialise once into a fresh list and guard on that copy, so the
    # emptiness check and the conversion below cannot disagree.
    records = list(env.dataset) if env.dataset else []
    if not records:
        raise ValueError(
            "Dataset is empty after environment setup. Check ConversationalStyleDPOEnv."
        )

    hf_dataset = Dataset.from_list(records)

    # Log one record so the column layout can be checked by eye.
    logger.info("Sample from dataset: %s", hf_dataset[0])

    return hf_dataset
|
|
|
|
|
|
def main():
    """
    Run DPO training end to end on the conversational-style dataset.

    Steps, in order:
      1. Parse the command line into ScriptArguments and TrainingArguments.
      2. Pull the preference dataset out of ConversationalStyleDPOEnv.
      3. Load the tokenizer and the policy model.
      4. Build a DPOTrainer and train.
      5. Save the model and tokenizer to ``training_args.output_dir``.

    Returns early, after logging an error, when the dataset is empty.
    """
    parser = HfArgumentParser((ScriptArguments, TrainingArguments))
    script_args, training_args = parser.parse_args_into_dataclasses()

    # The tokenizer defaults to the one shipped with the policy model.
    if script_args.tokenizer_name_or_path is None:
        script_args.tokenizer_name_or_path = script_args.model_name_or_path

    # --- 1. Initialize Environment and Get Dataset ---
    logger.info("Initializing environment to get dataset...")
    # Start from the environment's default config and align its tokenizer
    # with ours; adjust other env_dpo_config fields here if needed.
    env_dpo_config, _ = ConversationalStyleDPOEnv.config_init()
    env_dpo_config.tokenizer_name = script_args.tokenizer_name_or_path

    # get_dataset_from_env is a coroutine; drive it to completion here.
    dataset = asyncio.run(get_dataset_from_env(env_dpo_config))
    logger.info("Loaded dataset with %s examples.", len(dataset))

    if len(dataset) == 0:
        logger.error("No data loaded. Exiting training.")
        return

    # --- 2. Load Tokenizer and Models ---
    logger.info("Loading tokenizer: %s", script_args.tokenizer_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name_or_path)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer has none.
        # For some models tokenizer.pad_token_id may need setting as well.
        logger.warning("Tokenizer does not have a pad token. Setting to eos_token.")
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Loading policy model: %s", script_args.model_name_or_path)
    # Options such as low_cpu_mem_usage=True or torch_dtype=torch.float16
    # can be passed here for large models / mixed precision.
    model = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)

    # No explicit reference model: with ref_model=None the DPOTrainer is
    # left to derive the reference from the policy model. To use a separate
    # SFT checkpoint instead, load it with AutoModelForCausalLM.from_pretrained.
    model_ref = None
    logger.info("Reference model will be a copy of the policy model (handled by DPOTrainer).")

    # --- 3. Set up Training Arguments ---
    # NOTE(review): "output_dir" is treated as the "not set by the user"
    # placeholder. Depending on the transformers release, output_dir is either
    # required or defaults to a different value, so this branch may never
    # trigger; confirm against the pinned version.
    if training_args.output_dir == "output_dir":
        # Relative to the working directory the script is launched from.
        training_args.output_dir = "./dpo_conversational_trainer_results"

    # Other hyperparameters (per_device_train_batch_size, num_train_epochs,
    # gradient_accumulation_steps, learning_rate, logging_steps, save_steps,
    # report_to, ...) come from the TrainingArguments command-line flags.
    logger.info("Training Arguments: %s", training_args)

    # --- 4. Initialize DPOTrainer ---
    # NOTE(review): newer TRL releases appear to expect beta / max_length /
    # max_prompt_length on a DPOConfig and `processing_class` in place of
    # `tokenizer`; confirm against the TRL version this project pins.
    logger.info("Initializing DPOTrainer...")
    dpo_trainer = DPOTrainer(
        model=model,
        ref_model=model_ref,
        args=training_args,
        beta=script_args.beta,
        train_dataset=dataset,
        tokenizer=tokenizer,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,  # prompt + response
        # peft_config=peft_config,  # if using PEFT/LoRA
    )
    logger.info("DPOTrainer initialized.")

    # --- 5. Start Training ---
    logger.info("Starting DPO training...")
    dpo_trainer.train()
    logger.info("DPO training completed.")

    # --- 6. Save the Model (Optional) ---
    # NOTE(review): per the transformers docs, `should_save` reports whether
    # the current process should write to disk; it is not a check of
    # save_strategy.
    if training_args.should_save:
        output_save_dir = training_args.output_dir
        logger.info("Saving model to %s", output_save_dir)
        dpo_trainer.save_model(output_save_dir)
        logger.info("Model saved.")
        # Keep the tokenizer next to the weights so the directory is
        # loadable on its own.
        tokenizer.save_pretrained(output_save_dir)
        logger.info("Tokenizer saved.")
|
|
|
|
|
|
# Script entry point: enable INFO-level logging, then run the training flow.
# NOTE(review): this module uses a relative import for the environment, so it
# presumably has to be launched as a module (`python -m ...`) from within its
# package rather than as a bare file; confirm.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()