mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
import asyncio
import logging
import random
from typing import Dict, List, Optional, Tuple

from pydantic import Field

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer_dpo

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class ConversationalStyleDPOEnvConfig(BaseEnvConfig):
    """Config for ConversationalStyleDPOEnv."""

    dataset_name: str = Field(
        "synthetic_conversational_style", description="Name of the dataset to use."
    )
    shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset")
    # Add any other environment-specific configurations here if needed


class ConversationalStyleDPOEnv(BaseEnv):
    name = "conversational_style_dpo"
    name_config_cls = ConversationalStyleDPOEnvConfig

    def __init__(
        self,
        config: ConversationalStyleDPOEnvConfig,
        # server_configs can stay empty here: the chosen/rejected responses are
        # pre-defined, so this environment never queries a model server for
        # generation.
        server_configs: Optional[List[APIServerConfig]] = None,
        slurm=True,
        testing=False,
    ):
        # BaseEnv expects a list of server configs, so fall back to an empty list.
        resolved_server_configs = server_configs if server_configs is not None else []
        super().__init__(config, resolved_server_configs, slurm, testing)
        self.config: ConversationalStyleDPOEnvConfig = config  # Narrow the type
        self.dataset: List[Dict[str, str]] = []
        self.iter: int = 0

    async def setup(self):
        """Load and prepare the synthetic dataset."""
        # Synthetic dataset of (prompt, chosen, rejected) triples.
        # Chosen responses are more engaging, empathetic, or clear;
        # rejected responses are blunt, generic, or less helpful.
        self.synthetic_data = [
            {
                "prompt": "I'm feeling a bit down today.",
                "chosen": "I'm sorry to hear that. Sometimes a little self-care can help. "
                "What's one small thing you could do for yourself right now?",
                "rejected": "Okay.",
            },
            {
                "prompt": "Can you explain how photosynthesis works?",
                "chosen": "Certainly! Photosynthesis is a fascinating process where plants use "
                "sunlight, water, and carbon dioxide to create their own food (glucose) and "
                "release oxygen. Think of it like a plant's kitchen!",
                "rejected": "Plants make food from light.",
            },
            {
                "prompt": "I'm excited about my new project!",
                "chosen": "That's fantastic news! Tell me more about it - "
                "what are you most looking forward to?",
                "rejected": "Good for you.",
            },
            {
                "prompt": "What's the weather like?",
                "chosen": "I can't check the real-time weather, but I hope it's pleasant where "
                "you are! If you'd like, you can tell me your location, and I could try to "
                "give you a general idea based on typical patterns, or help you find a "
                "weather service.",
                "rejected": "I don't know.",
            },
            {
                "prompt": "I'm having trouble understanding this concept.",
                "chosen": "I can understand that some concepts can be tricky! Could you tell me "
                "which part is confusing you? Maybe we can break it down together.",
                "rejected": "Read the manual.",
            },
        ]
        self.dataset = self.synthetic_data
        if self.config.shuffle_dataset:
            random.shuffle(self.dataset)
        self.iter = 0
        logger.info(f"Loaded synthetic dataset with {len(self.dataset)} examples.")

    async def get_next_item(self) -> Item:
        """
        Return the next DPO item as a (prompt, chosen_response, rejected_response)
        tuple of raw strings.

        BaseEnv's generic Item is (prompt_tuple, gold_answer, optional_extra_data);
        here the prompt string stands in for the prompt tuple, the chosen response
        for the gold answer, and the rejected response for the extra data. Raw
        strings are the most convenient input for `tokenize_for_trainer_dpo`, so
        tokenization is deferred to `collect_trajectories`.
        """
        if not self.dataset or self.iter >= len(self.dataset):
            await self.setup()  # Re-setup if the dataset is exhausted or not loaded

        if not self.dataset:  # Still no dataset after setup
            logger.error("Dataset is empty even after setup.")
            # Return a fallback item rather than crashing; a production
            # environment should raise or surface this error instead.
            fallback_prompt = tuple(
                [frozenset({"role": "user", "content": "Fallback prompt"}.items())]
            )
            return (fallback_prompt, "Fallback chosen", "Fallback rejected")

        entry = self.dataset[self.iter % len(self.dataset)]
        self.iter += 1
        return (entry["prompt"], entry["chosen"], entry["rejected"])

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
        """
        Process a single DPO item (prompt, chosen, rejected) and prepare it for
        the trainer. Since the responses are pre-defined, no model is queried
        for generation; the pair is tokenized directly.
        """
        prompt_str, chosen_response_str, rejected_response_str = item

        # DPO needs two tokenized sequences per pair:
        #   1. <prompt_tokens><chosen_response_tokens>
        #   2. <prompt_tokens><rejected_response_tokens>
        # plus attention masks. The loss is usually computed only on response
        # tokens; `tokenize_for_trainer_dpo` is assumed to handle this (see the
        # sketch after this method).

        # The tokenizer is set up by BaseEnv; bail out if it is missing.
        if not self.tokenizer:
            logger.error("Tokenizer not available. Cannot process DPO pair.")
            return None, []

        try:
            # `tokenize_for_trainer_dpo` is assumed here to accept a dict with
            # "prompt"/"chosen"/"rejected" keys; adjust to its actual
            # signature (it may instead take (tokenizer, prompt, chosen,
            # rejected)).
            dpo_pair_data = {
                "prompt": prompt_str,
                "chosen": chosen_response_str,
                "rejected": rejected_response_str,
            }
            tokenized_output = tokenize_for_trainer_dpo(
                self.tokenizer,
                dpo_pair_data,
                max_length=self.config.max_token_length,
            )

            # The keys below assume tokenize_for_trainer_dpo returns at least
            # 'chosen_input_ids', 'chosen_attention_mask',
            # 'rejected_input_ids', and 'rejected_attention_mask'; DPO
            # trainers often also expect prompt_* and *_labels fields.
            scores = ScoredDataGroup()
            scores["chosen_tokens"] = [tokenized_output["chosen_input_ids"]]
            scores["chosen_masks"] = [tokenized_output["chosen_attention_mask"]]
            scores["rejected_tokens"] = [tokenized_output["rejected_input_ids"]]
            scores["rejected_masks"] = [tokenized_output["rejected_attention_mask"]]
            # Optionally, if prompts are tokenized separately:
            if "prompt_input_ids" in tokenized_output:
                scores["prompt_tokens"] = [tokenized_output["prompt_input_ids"]]
                scores["prompt_masks"] = [
                    tokenized_output.get("prompt_attention_mask")
                ]  # The mask may be absent

            # DPO has no scalar reward; the preference (chosen > rejected) is
            # implicit in the DPO objective (see sketch_dpo_loss below the
            # class), so "scores" is only a placeholder for the trainer.
            scores["scores"] = [1.0]

            # Images are not used in this environment.
            scores["images"] = [None]

            # This example processes one item at a time; with
            # self.config.group_size > 1 you would aggregate several tokenized
            # outputs before returning the ScoredDataGroup.
            return scores, []  # No items to backlog

        except Exception as e:
            logger.error(f"Error in collect_trajectories during DPO processing: {e}")
            import traceback

            traceback.print_exc()
            return None, []

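    # The comments above assume `tokenize_for_trainer_dpo` masks prompt tokens
    # out of the loss. The helper below is a minimal, hypothetical sketch of
    # that contract (the name, signature, and return keys are assumptions, not
    # the real atroposlib API): tokenize prompt + response, then set labels to
    # -100 on the prompt so the loss covers only response tokens.
    @staticmethod
    def _sketch_tokenize_dpo_side(
        tokenizer, prompt: str, response: str, max_length: int
    ) -> Dict:
        """Illustrative only; not called by the environment."""
        prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
        response_ids = tokenizer(response, add_special_tokens=False)["input_ids"]
        input_ids = (prompt_ids + response_ids)[:max_length]
        attention_mask = [1] * len(input_ids)
        # -100 is the conventional "ignore" index for cross-entropy labels.
        labels = [-100] * min(len(prompt_ids), max_length) + response_ids[
            : max(0, max_length - len(prompt_ids))
        ]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
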
    async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
        """
        Normally this method scores model generations against a gold answer or
        reward model. For DPO with a static dataset, the "score" is implicit in
        the chosen/rejected preference, and `collect_trajectories` already
        returns a tokenized ScoredDataGroup, so this method only passes that
        data through (or warns if it receives raw items instead).
        """
        if rollout_group_data and isinstance(rollout_group_data, ScoredDataGroup):
            # `collect_trajectories` already produced the ScoredDataGroup.
            return rollout_group_data
        elif rollout_group_data and isinstance(rollout_group_data, list):
            # A list here would mean tokenization was deferred to `score`; in
            # this environment that logic lives in `collect_trajectories`
            # (see the batch-fallback sketch below for what it could look like).
            logger.warning(
                "`score` method received a list; expecting ScoredDataGroup for pre-processed DPO."
            )
            return None

        logger.info("No data to score or data is already in ScoredDataGroup format.")
        return None

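    # A hedged sketch of the list fallback mentioned above: if the main loop
    # handed `score` raw (prompt, chosen, rejected) tuples, they could be
    # tokenized here in a batch. The method name is hypothetical and the key
    # names mirror `collect_trajectories`; both are assumptions about what the
    # DPO trainer expects, not atroposlib API.
    async def _sketch_score_raw_dpo_items(
        self, items: List[Tuple[str, str, str]]
    ) -> Optional[ScoredDataGroup]:
        """Illustrative only; not wired into BaseEnv's loop."""
        if not items or not self.tokenizer:
            return None
        scores = ScoredDataGroup()
        scores["chosen_tokens"], scores["chosen_masks"] = [], []
        scores["rejected_tokens"], scores["rejected_masks"] = [], []
        for prompt, chosen, rejected in items:
            tok = tokenize_for_trainer_dpo(
                self.tokenizer,
                {"prompt": prompt, "chosen": chosen, "rejected": rejected},
                max_length=self.config.max_token_length,
            )
            scores["chosen_tokens"].append(tok["chosen_input_ids"])
            scores["chosen_masks"].append(tok["chosen_attention_mask"])
            scores["rejected_tokens"].append(tok["rejected_input_ids"])
            scores["rejected_masks"].append(tok["rejected_attention_mask"])
        scores["scores"] = [1.0] * len(items)  # Placeholder, as above
        scores["images"] = [None] * len(items)
        return scores
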
    async def evaluate(self, *args, **kwargs):
        """
        DPO evaluation might compare the trained model's preferences against a
        held-out set of chosen/rejected pairs or other metrics. This basic
        environment skips custom evaluation.
        """
        logger.info(
            "Evaluation step called. No custom DPO evaluation implemented in this basic environment."
        )
        # One option: load a separate test set of (prompt, chosen, rejected)
        # and check whether the model assigns higher log-probs to chosen than
        # rejected (see the sketch below).
        return None

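    # A hedged sketch of that evaluation idea: score held-out pairs by whether
    # the policy model assigns a higher response log-probability to chosen
    # than to rejected. `model` is assumed to be a HF causal LM; the method
    # name is hypothetical, it is not called by BaseEnv, and the prompt-length
    # alignment is approximate (BPE boundaries can shift when prompt and
    # response are tokenized together).
    def _sketch_preference_accuracy(
        self, model, eval_pairs: List[Dict[str, str]]
    ) -> float:
        """Illustrative only; fraction of pairs where chosen out-scores rejected."""
        import torch

        def response_logprob(prompt: str, response: str) -> float:
            prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
            full_ids = self.tokenizer(prompt + " " + response, return_tensors="pt")[
                "input_ids"
            ]
            with torch.no_grad():
                logits = model(full_ids).logits
            # Log-prob of each actual next token, then keep response positions.
            logprobs = torch.log_softmax(logits[:, :-1], dim=-1)
            targets = full_ids[:, 1:]
            token_logprobs = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
            return token_logprobs[:, prompt_ids.shape[1] - 1 :].sum().item()

        wins = sum(
            response_logprob(p["prompt"], p["chosen"])
            > response_logprob(p["prompt"], p["rejected"])
            for p in eval_pairs
        )
        return wins / max(1, len(eval_pairs))
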
    @classmethod
    def config_init(
        cls,
    ) -> Tuple[ConversationalStyleDPOEnvConfig, List[APIServerConfig]]:
        """Provide a default configuration for this environment."""
        env_config = ConversationalStyleDPOEnvConfig(
            wandb_name="conversational_style_dpo",  # For logging if wandb is used
            tokenizer_name="gpt2",  # Choose an appropriate tokenizer
            group_size=4,  # Number of DPO pairs per data-collection group
            use_wandb=False,
            max_num_workers=1,  # Parallel workers for data collection
            rollout_server_url="http://localhost:8000",
            total_steps=100,  # Total DPO training steps (or epochs over the dataset)
            batch_size=2,  # DPO training batch size (distinct from group_size)
            steps_per_eval=50,
            max_token_length=512,  # Max length for tokenized sequences
            dataset_name="synthetic_conversational_style",
            shuffle_dataset=True,
        )

        server_configs = []  # No model server needed; responses are pre-defined

        return env_config, server_configs

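
# For reference, a minimal sketch of the DPO objective that a trainer is
# assumed to apply to these pairs (Rafailov et al., 2023). Inputs are
# per-sequence response log-probs under the policy and a frozen reference
# model; the function name and the use of torch are illustrative assumptions,
# not part of atroposlib.
def sketch_dpo_loss(
    policy_chosen_logp,
    policy_rejected_logp,
    ref_chosen_logp,
    ref_rejected_logp,
    beta: float = 0.1,
):
    """loss = -log(sigmoid(beta * ((pi_c - ref_c) - (pi_r - ref_r))))"""
    import torch.nn.functional as F

    chosen_rewards = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_rewards = beta * (policy_rejected_logp - ref_rejected_logp)
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
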
if __name__ == "__main__":
    # Running this file directly exercises data loading and DPO tokenization
    # without a trainer. BaseEnv.cli() would normally drive the main loop; for
    # DPO with a static dataset, that loop just iterates the dataset,
    # tokenizes pairs, and yields them to the trainer.

    def placeholder_tokenize_for_dpo(tokenizer, data, max_length, **kwargs):
        """Simplified stand-in for tokenize_for_trainer_dpo, for local testing.

        A real implementation would also prepare DPO labels (loss only on
        response tokens) and handle padding/truncation more carefully.
        """
        prompt = data["prompt"]
        chosen = data["chosen"]
        rejected = " ".join(
            data["rejected"].split()[: max_length // 3]
        )  # Crude word-level truncation

        chosen_text = f"{prompt} {chosen}"
        rejected_text = f"{prompt} {rejected}"

        chosen_tok = tokenizer(
            chosen_text,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        rejected_tok = tokenizer(
            rejected_text,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        return {
            "chosen_input_ids": chosen_tok["input_ids"].squeeze().tolist(),
            "chosen_attention_mask": chosen_tok["attention_mask"].squeeze().tolist(),
            "rejected_input_ids": rejected_tok["input_ids"].squeeze().tolist(),
            "rejected_attention_mask": rejected_tok["attention_mask"]
            .squeeze()
            .tolist(),
            # "prompt_input_ids": ...,  # Optionally
        }

    # Swap the placeholder in for the imported tokenize_for_trainer_dpo so
    # collect_trajectories can run without the real utility.
    tokenize_for_trainer_dpo = placeholder_tokenize_for_dpo

    async def main_test():
        config, server_configs_list = ConversationalStyleDPOEnv.config_init()

        # Override the tokenizer for local testing; any small causal-LM
        # tokenizer works here (e.g. "EleutherAI/pythia-70m").
        config.tokenizer_name = "distilgpt2"
        config.group_size = 1  # Process one DPO item at a time in this test
        config.use_wandb = False

        env = ConversationalStyleDPOEnv(
            config=config, server_configs=server_configs_list, slurm=False, testing=True
        )

        # Initialize the tokenizer manually (BaseEnv usually does this).
        from transformers import AutoTokenizer

        env.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
        if env.tokenizer.pad_token is None:
            env.tokenizer.pad_token = env.tokenizer.eos_token

        print("Setting up environment...")
        await env.setup()
        print(f"Dataset size: {len(env.dataset)}")

        if not env.dataset:
            print("No data loaded. Exiting.")
            return

        print("Simulating DPO data processing for a few items...")
        for i in range(min(len(env.dataset), 3)):  # Test with a few items
            print(f"--- Item {i + 1} ---")
            item = await env.get_next_item()
            if not item:
                print("Failed to get item.")
                continue

            prompt, chosen, rejected = item
            print(f"Prompt: {prompt}")
            print(f"Chosen: {chosen}")
            print(f"Rejected: {rejected}")

            # In a real run, BaseEnv's loop would call collect_trajectories;
            # here we call it directly to test the DPO tokenization path.
            scored_data_group, _ = await env.collect_trajectories(item)

            if scored_data_group:
                print("Tokenized DPO Pair (first item in group):")
                print(
                    f"  Chosen Tokens (IDs): {scored_data_group['chosen_tokens'][0][:20]}..."
                )
                print(
                    f"  Rejected Tokens (IDs): {scored_data_group['rejected_tokens'][0][:20]}..."
                )
            else:
                print("  Failed to process DPO pair.")

        print("--- End of Test ---")

    # To run the real CLI instead, use ConversationalStyleDPOEnv.cli() and
    # ensure tokenize_for_trainer_dpo is correctly implemented and importable.
    asyncio.run(main_test())