Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-23 16:54:56 +00:00)

Commit 9725761f5b (parent c189fc3351): dev - push for submission
12 changed files with 2414 additions and 0 deletions
@@ -0,0 +1,185 @@
import asyncio
import os
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
)
from trl import DPOTrainer

# Import your custom environment
# The import below assumes this script is in the same directory as conversational_style_dpo_env.py
from .conversational_style_dpo_env import (
    ConversationalStyleDPOEnv,
    ConversationalStyleDPOEnvConfig,
)

logger = logging.getLogger(__name__)


@dataclass
class ScriptArguments:
    """
    Arguments for the DPO training script.
    """

    model_name_or_path: str = field(
        default="distilgpt2",
        metadata={"help": "The model name or path to load from."},
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The tokenizer name or path. Defaults to model_name_or_path."},
    )
    beta: float = field(
        default=0.1, metadata={"help": "The beta factor in DPO loss."}
    )
    max_prompt_length: int = field(
        default=256, metadata={"help": "Max length for prompts."}
    )
    max_length: int = field(
        default=512,
        metadata={"help": "Max length for chosen/rejected responses including prompt."},
    )
    # Add any other TRL DPOTrainer arguments or TrainingArguments here if needed.
    # For example, learning_rate, per_device_train_batch_size, num_train_epochs etc.
    # will be part of TrainingArguments.
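    # A sketch of extending this dataclass (an addition, not in the original diff;
    # the field below is hypothetical): another DPO knob could be exposed the same
    # way the fields above are, e.g.
    # loss_type: str = field(
    #     default="sigmoid", metadata={"help": "DPO loss variant forwarded to DPOTrainer."}
    # )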


async def get_dataset_from_env(
    env_config: ConversationalStyleDPOEnvConfig,
) -> Dataset:
    """
    Initializes the environment and extracts the synthetic dataset
    in the format required by DPOTrainer (a list of dicts with prompt, chosen, rejected).
    """
    # We don't need server_configs if the env doesn't use them for static dataset loading.
    env = ConversationalStyleDPOEnv(config=env_config, server_configs=[], testing=True)
    await env.setup()  # This loads the environment's synthetic preference data

    # The environment's data is already a List[Dict[str, str]] with "prompt", "chosen",
    # "rejected", so it can be converted directly to a Hugging Face Dataset.
    # Check that data was actually loaded.
    if not env.dataset:
        raise ValueError(
            "Dataset is empty after environment setup. Check ConversationalStyleDPOEnv."
        )

    # The DPOTrainer expects columns named "prompt", "chosen", "rejected";
    # the environment's synthetic data is already in this format.
    # Example: [{"prompt": "...", "chosen": "...", "rejected": "..."}]
    hf_dataset = Dataset.from_list(list(env.dataset))  # Ensure it's a fresh list copy

    # Log a sample to verify
    if len(hf_dataset) > 0:
        logger.info(f"Sample from dataset: {hf_dataset[0]}")
    else:
        logger.warning("Dataset created from environment is empty!")

    return hf_dataset


def main():
    parser = HfArgumentParser((ScriptArguments, TrainingArguments))
    script_args, training_args = parser.parse_args_into_dataclasses()

    if script_args.tokenizer_name_or_path is None:
        script_args.tokenizer_name_or_path = script_args.model_name_or_path

    # --- 1. Initialize Environment and Get Dataset ---
    logger.info("Initializing environment to get dataset...")
    # Use the default config from your environment, but ensure the tokenizer matches
    env_dpo_config, _ = ConversationalStyleDPOEnv.config_init()
    env_dpo_config.tokenizer_name = script_args.tokenizer_name_or_path  # Align tokenizer
    # You might want to adjust other env_dpo_config parameters if needed

    # Run the async function to get the dataset
    dataset = asyncio.run(get_dataset_from_env(env_dpo_config))
    logger.info(f"Loaded dataset with {len(dataset)} examples.")

    if len(dataset) == 0:
        logger.error("No data loaded. Exiting training.")
        return

    # --- 2. Load Tokenizer and Models ---
    logger.info(f"Loading tokenizer: {script_args.tokenizer_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name_or_path)
    if tokenizer.pad_token is None:
        logger.warning("Tokenizer does not have a pad token. Setting to eos_token.")
        tokenizer.pad_token = tokenizer.eos_token
        # For some models, you might also need to set tokenizer.pad_token_id
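        # A sketch of that (an addition, not in the original diff):
        # tokenizer.pad_token_id = tokenizer.eos_token_id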

    logger.info(f"Loading policy model: {script_args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        # low_cpu_mem_usage=True,  # Can be helpful for large models
        # torch_dtype=torch.float16,  # For mixed precision if the GPU supports it
    )

    # Reference model for DPO. If not provided, DPOTrainer will create a copy of the model.
    # For simplicity, we'll let DPOTrainer handle creating the reference model by not passing one.
    # If you wanted to load a different SFT model as the reference, you would do:
    # model_ref = AutoModelForCausalLM.from_pretrained(...)
    model_ref = None
    logger.info("Reference model will be a copy of the policy model (handled by DPOTrainer).")


    # --- 3. Set up Training Arguments ---
    # Default DPO training arguments. You might want to customize these.
    if training_args.output_dir == "output_dir":  # Default value from TrainingArguments
        # The output directory is relative to the current working directory when the script is run.
        # If run from environments/hack0/conversational_style_dpo/, it will be
        # ./dpo_conversational_trainer_results
        training_args.output_dir = "./dpo_conversational_trainer_results"

    # training_args.per_device_train_batch_size = 2  # Adjust as needed
    # training_args.num_train_epochs = 1  # Keep low for a quick test
    # training_args.gradient_accumulation_steps = 1
    # training_args.learning_rate = 5e-5
    # training_args.logging_steps = 10
    # training_args.save_steps = 50
    # training_args.report_to = "none"  # "wandb" or "tensorboard" if you want to log
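    # Equivalent CLI overrides via HfArgumentParser (a sketch; the exact invocation is an
    # assumption, since the relative import above suggests the script is run as a module):
    #   --per_device_train_batch_size 2 --num_train_epochs 1 --learning_rate 5e-5 \
    #   --logging_steps 10 --save_steps 50 --report_to none --output_dir <dir>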

    logger.info(f"Training Arguments: {training_args}")


    # --- 4. Initialize DPOTrainer ---
    logger.info("Initializing DPOTrainer...")
    dpo_trainer = DPOTrainer(
        model=model,
        ref_model=model_ref,  # If None, a copy of the model is made
        args=training_args,
        beta=script_args.beta,
        train_dataset=dataset,
        tokenizer=tokenizer,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,  # Max length of prompt + response
        # peft_config=peft_config,  # If using PEFT/LoRA
    )
    logger.info("DPOTrainer initialized.")

    # --- 5. Start Training ---
    logger.info("Starting DPO training...")
    dpo_trainer.train()
    logger.info("DPO training completed.")

    # --- 6. Save the Model (Optional) ---
    if training_args.should_save:  # True when this process is allowed to write checkpoints
        output_save_dir = training_args.output_dir
        logger.info(f"Saving model to {output_save_dir}")
        dpo_trainer.save_model(output_save_dir)
        logger.info("Model saved.")
        # Also save the tokenizer
        tokenizer.save_pretrained(output_save_dir)
        logger.info("Tokenizer saved.")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()