mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-29 17:35:07 +00:00
dev - push for submission
This commit is contained in:
parent
c189fc3351
commit
9725761f5b
12 changed files with 2414 additions and 0 deletions
1
environments/hack0/conversational_style_dpo/__init__.py
Normal file
1
environments/hack0/conversational_style_dpo/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,440 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.type_definitions import GameHistory, Item # Assuming GameHistory and Item are relevant
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer_dpo
|
||||
|
||||
# Module-level logger; forced to INFO so dataset-loading progress messages
# from setup() are visible without configuring the root logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class ConversationalStyleDPOEnvConfig(BaseEnvConfig):
    """Config for ConversationalStyleDPOEnv."""

    # Informational only: this env ships a built-in synthetic dataset and
    # never loads by this name (see ConversationalStyleDPOEnv.setup()).
    dataset_name: str = Field(
        "synthetic_conversational_style", description="Name of the dataset to use."
    )
    # When True, setup() shuffles the (prompt, chosen, rejected) examples in place.
    shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset")
    # Add any other environment-specific configurations here if needed
|
||||
|
||||
|
||||
class ConversationalStyleDPOEnv(BaseEnv):
    """DPO environment over a small built-in set of conversational-style
    preference pairs.

    Each example is a dict with keys ``prompt``, ``chosen`` and ``rejected``.
    No model generation happens here — both responses are pre-written — so the
    env only tokenizes each pair (via ``tokenize_for_trainer_dpo``) and hands
    the result to the trainer as a ``ScoredDataGroup``.
    """

    name = "conversational_style_dpo"
    name_config_cls = ConversationalStyleDPOEnvConfig

    def __init__(
        self,
        config: ConversationalStyleDPOEnvConfig,
        server_configs: Optional[List[APIServerConfig]] = None,
        slurm=True,
        testing=False,
    ):
        # server_configs is optional because this env never queries a model
        # server (chosen/rejected responses are pre-defined).  BaseEnv still
        # expects a list, so substitute an empty one when None is passed.
        resolved_server_configs = server_configs if server_configs is not None else []
        super().__init__(config, resolved_server_configs, slurm, testing)
        # Re-annotate so type checkers see the concrete config subclass.
        self.config: ConversationalStyleDPOEnvConfig = config
        self.dataset: List[Dict[str, str]] = []  # populated by setup()
        self.iter: int = 0  # index of the next example to serve

    async def setup(self):
        """Load and prepare the synthetic dataset.

        Builds the hard-coded (prompt, chosen, rejected) triples — chosen
        responses are more engaging/empathetic, rejected ones blunt or
        generic — then optionally shuffles and resets the iteration cursor.
        """
        self.synthetic_data = [
            {
                "prompt": "I'm feeling a bit down today.",
                "chosen": "I'm sorry to hear that. Sometimes a little self-care can help. What's one small thing you could do for yourself right now?",
                "rejected": "Okay.",
            },
            {
                "prompt": "Can you explain how photosynthesis works?",
                "chosen": "Certainly! Photosynthesis is a fascinating process where plants use sunlight, water, and carbon dioxide to create their own food (glucose) and release oxygen. Think of it like a plant's kitchen!",
                "rejected": "Plants make food from light.",
            },
            {
                "prompt": "I'm excited about my new project!",
                "chosen": "That's fantastic news! Tell me more about it - what are you most looking forward to?",
                "rejected": "Good for you.",
            },
            {
                "prompt": "What's the weather like?",
                "chosen": "I can't check the real-time weather, but I hope it's pleasant where you are! If you'd like, you can tell me your location, and I could try to give you a general idea based on typical patterns, or help you find a weather service.",
                "rejected": "I don't know.",
            },
            {
                "prompt": "I'm having trouble understanding this concept.",
                "chosen": "I can understand that some concepts can be tricky! Could you tell me which part is confusing you? Maybe we can break it down together.",
                "rejected": "Read the manual.",
            },
        ]
        # NOTE: self.dataset aliases self.synthetic_data, so the in-place
        # shuffle below reorders both.
        self.dataset = self.synthetic_data
        if self.config.shuffle_dataset:
            random.shuffle(self.dataset)
        self.iter = 0
        logger.info(f"Loaded synthetic dataset with {len(self.dataset)} examples.")

    async def get_next_item(self) -> Item:
        """Return the next DPO item as a plain (prompt, chosen, rejected) tuple.

        Re-runs setup() (which also re-shuffles) whenever the dataset is
        exhausted or not yet loaded, so iteration cycles indefinitely.
        """
        if not self.dataset or self.iter >= len(self.dataset):
            await self.setup()  # Re-setup if dataset is exhausted or not loaded

        if not self.dataset:  # Still no dataset after setup
            logger.error("Dataset is empty even after setup.")
            # Defensive fallback item so the caller doesn't crash; with the
            # hard-coded dataset above this branch should be unreachable.
            fallback_prompt = tuple([frozenset({"role": "user", "content": "Fallback prompt"}.items())])
            return (fallback_prompt, "Fallback chosen", "Fallback rejected")

        entry = self.dataset[self.iter % len(self.dataset)]
        self.iter += 1

        # NOTE(review): built but unused — kept for a possible future path
        # that sends the prompt to a model server in message form.
        prompt_messages = [{"role": "user", "content": entry["prompt"]}]

        # Raw strings are returned (not frozenset message tuples) because
        # collect_trajectories() tokenizes them directly.
        return (entry["prompt"], entry["chosen"], entry["rejected"])

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
        """Tokenize one (prompt, chosen, rejected) item into a ScoredDataGroup.

        No model call is made — the responses are pre-defined — so this method
        only runs ``tokenize_for_trainer_dpo`` and packs its output.  Returns
        (None, []) when the tokenizer is missing or tokenization fails.
        """
        prompt_str, chosen_response_str, rejected_response_str = item

        # Tokenizer is normally provided by BaseEnv; bail out gracefully if absent.
        if not self.tokenizer:
            logger.error("Tokenizer not available. Cannot process DPO pair.")
            return None, []

        try:
            dpo_pair_data = {
                "prompt": prompt_str,
                "chosen": chosen_response_str,
                "rejected": rejected_response_str,
            }

            # Expected to produce prompt+chosen and prompt+rejected sequences
            # with attention masks — TODO confirm against the actual
            # tokenize_for_trainer_dpo signature/return shape.
            tokenized_output = tokenize_for_trainer_dpo(
                self.tokenizer,
                dpo_pair_data,
                max_length=self.config.max_token_length,
            )

            scores = ScoredDataGroup()
            # Keys here must match what the DPO trainer consumes — verify
            # against the trainer's expected batch schema.
            scores["chosen_tokens"] = [tokenized_output["chosen_input_ids"]]
            scores["chosen_masks"] = [tokenized_output["chosen_attention_mask"]]
            scores["rejected_tokens"] = [tokenized_output["rejected_input_ids"]]
            scores["rejected_masks"] = [tokenized_output["rejected_attention_mask"]]
            # Prompt tokens are optional — only forwarded if the tokenizer
            # produced them separately.
            if "prompt_input_ids" in tokenized_output:
                scores["prompt_tokens"] = [tokenized_output["prompt_input_ids"]]
                scores["prompt_masks"] = [tokenized_output.get("prompt_attention_mask")]  # Handle if mask isn't there

            # DPO has no scalar reward — preference is implicit in the
            # chosen/rejected ordering — so this is just a placeholder.
            scores["scores"] = [1.0]

            # Images are not used in this environment.
            scores["images"] = [None]

            return scores, []  # No items to backlog

        except Exception as e:
            logger.error(f"Error in collect_trajectories during DPO processing: {e}")
            import traceback
            traceback.print_exc()
            return None, []

    async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
        """Pass-through scoring for pre-tokenized DPO pairs.

        collect_trajectories() already returns a finished ScoredDataGroup, so
        this just forwards it.  A raw list (un-tokenized items) is not handled
        here and yields None with a warning.
        """
        if rollout_group_data and isinstance(rollout_group_data, ScoredDataGroup):
            # Already fully prepared by collect_trajectories().
            return rollout_group_data
        elif rollout_group_data and isinstance(rollout_group_data, list):
            # Tokenization would have to move here if the main loop ever sends
            # raw (prompt, chosen, rejected) tuples to score().
            logger.warning("`score` method received a list; expecting ScoredDataGroup for pre-processed DPO.")
            return None

        logger.info("No data to score or data is already in ScoredDataGroup format.")
        return None

    async def evaluate(self, *args, **kwargs):
        """No-op evaluation hook.

        A real DPO eval could check that the trained model assigns higher
        log-probs to chosen than rejected on a held-out pair set.
        """
        logger.info("Evaluation step called. No custom DPO evaluation implemented in this basic environment.")
        return None

    @classmethod
    def config_init(cls) -> Tuple[ConversationalStyleDPOEnvConfig, List[APIServerConfig]]:
        """Return a default (env_config, server_configs) pair for this env."""
        env_config = ConversationalStyleDPOEnvConfig(
            wandb_name="conversational_style_dpo",  # run name if wandb logging is enabled
            tokenizer_name="gpt2",  # default tokenizer; override as needed
            group_size=4,  # DPO pairs per collected group
            use_wandb=False,
            max_num_workers=1,
            rollout_server_url="http://localhost:8000",
            total_steps=100,
            batch_size=2,  # trainer batch size (distinct from group_size)
            steps_per_eval=50,
            max_token_length=512,  # max length for tokenized sequences
            dataset_name="synthetic_conversational_style",
            shuffle_dataset=True,
        )

        # No model servers needed: responses are pre-defined, not generated.
        server_configs = []

        return env_config, server_configs
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Standalone smoke test: exercises data loading and DPO tokenization
    # without the full BaseEnv.cli() loop.  Requires `transformers`.
    async def main_test():
        config, server_configs_list = ConversationalStyleDPOEnv.config_init()

        # Override defaults for a fast local run.
        config.tokenizer_name = "distilgpt2"  # A small, fast tokenizer for testing
        config.group_size = 1  # Process one DPO item at a time for simplicity in this test
        config.use_wandb = False

        env = ConversationalStyleDPOEnv(config=config, server_configs=server_configs_list, slurm=False, testing=True)

        # BaseEnv normally initializes the tokenizer; do it manually here.
        from transformers import AutoTokenizer
        env.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
        if env.tokenizer.pad_token is None:
            # GPT-2-family tokenizers ship without a pad token.
            env.tokenizer.pad_token = env.tokenizer.eos_token

        print("Setting up environment...")
        await env.setup()
        print(f"Dataset size: {len(env.dataset)}")

        if not env.dataset:
            print("No data loaded. Exiting.")
            return

        print("Simulating DPO data processing for a few items...")
        for i in range(min(len(env.dataset), 3)):  # Test with a few items
            print(f"--- Item {i+1} ---")
            item = await env.get_next_item()
            if item:
                prompt, chosen, rejected = item
                print(f"Prompt: {prompt}")
                print(f"Chosen: {chosen}")
                print(f"Rejected: {rejected}")

                # Install a placeholder tokenize_for_trainer_dpo so
                # collect_trajectories() can run in this test.  A real
                # implementation would also build DPO labels (loss only on
                # response tokens); this one just tokenizes the joined text.
                global tokenize_for_trainer_dpo
                def placeholder_tokenize_for_dpo(tokenizer, data, max_length, **kwargs):
                    prompt = data["prompt"]
                    chosen = data["chosen"]
                    # Crude word-level truncation of the rejected response.
                    rejected = " ".join(data["rejected"].split()[:max_length//3])

                    chosen_text = f"{prompt} {chosen}"
                    rejected_text = f"{prompt} {rejected}"

                    chosen_tok = tokenizer(chosen_text, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt")
                    rejected_tok = tokenizer(rejected_text, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt")

                    return {
                        "chosen_input_ids": chosen_tok["input_ids"].squeeze().tolist(),
                        "chosen_attention_mask": chosen_tok["attention_mask"].squeeze().tolist(),
                        "rejected_input_ids": rejected_tok["input_ids"].squeeze().tolist(),
                        "rejected_attention_mask": rejected_tok["attention_mask"].squeeze().tolist(),
                        # "prompt_input_ids": ...,  # Optionally
                    }
                tokenize_for_trainer_dpo = placeholder_tokenize_for_dpo

                scored_data_group, _ = await env.collect_trajectories(item)

                if scored_data_group:
                    print("Tokenized DPO Pair (first item in group):")
                    print(f" Chosen Tokens (IDs): {scored_data_group['chosen_tokens'][0][:20]}...")  # Print first 20
                    # print(f" Chosen Masks: {scored_data_group['chosen_masks'][0][:20]}...")
                    print(f" Rejected Tokens (IDs): {scored_data_group['rejected_tokens'][0][:20]}...")
                    # print(f" Rejected Masks: {scored_data_group['rejected_masks'][0][:20]}...")
                    # print(f" Scores: {scored_data_group['scores']}")
                else:
                    print(" Failed to process DPO pair.")
            else:
                print("Failed to get item.")

        print("--- End of Test ---")

    # To run the real CLI instead, use ConversationalStyleDPOEnv.cli() and
    # make sure tokenize_for_trainer_dpo is importable in that context.

    # NOTE(review): this inner __main__ guard is redundant (we are already
    # inside one); preserved from the original layout — confirm intent.
    if __name__ == "__main__":
        # Define a fallback tokenize_for_trainer_dpo if the real one is not
        # available globally (it normally lives in a utils library).
        if "tokenize_for_trainer_dpo" not in globals():
            def placeholder_tokenize_for_dpo(tokenizer, data, max_length, **kwargs):
                prompt = data["prompt"]
                chosen = data["chosen"]
                # Token-level truncation so prompt + rejected fits max_length
                # (with a small buffer).
                max_rej_tokens = max_length - len(tokenizer.encode(prompt)) - 10  # buffer
                rejected_tokens = tokenizer.encode(data["rejected"])
                if len(rejected_tokens) > max_rej_tokens:
                    rejected_tokens = rejected_tokens[:max_rej_tokens]
                # NOTE(review): decode assumed unconditional (original
                # indentation ambiguous); unconditional avoids a NameError
                # when no truncation occurs — confirm against upstream.
                rejected = tokenizer.decode(rejected_tokens)

                chosen_text = f"{prompt} {chosen}"  # Simplistic combination
                rejected_text = f"{prompt} {rejected}"  # Simplistic combination

                chosen_tok = tokenizer(chosen_text, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt")
                rejected_tok = tokenizer(rejected_text, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt")

                return {
                    "chosen_input_ids": chosen_tok["input_ids"].squeeze().tolist(),
                    "chosen_attention_mask": chosen_tok["attention_mask"].squeeze().tolist(),
                    "rejected_input_ids": rejected_tok["input_ids"].squeeze().tolist(),
                    "rejected_attention_mask": rejected_tok["attention_mask"].squeeze().tolist(),
                }
            tokenize_for_trainer_dpo = placeholder_tokenize_for_dpo

        asyncio.run(main_test())
|
||||
|
|
@ -0,0 +1,153 @@
|
|||
<!-- Static viewer for gsm8k_dpo_rollouts_1.jsonl.  The #groups-container div
     is populated by the generator script; this copy was rendered from an
     empty/invalid input, hence the placeholder paragraph. -->
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Rendered Messages - gsm8k_dpo_rollouts_1.jsonl</title>
    <style>
        body { font-family: sans-serif; line-height: 1.6; margin: 20px; }
        details { border: 1px solid #ccc; border-radius: 4px; margin-bottom: 15px; }
        summary {
            font-weight: bold;
            padding: 10px;
            background-color: #f0f0f0;
            cursor: pointer;
            border-radius: 4px 4px 0 0;
            border-bottom: 1px solid #ccc;
            outline: none; /* Remove default focus outline if needed */
        }
        details[open] summary { border-bottom: 1px solid #ccc; }
        .group-content { padding: 15px; }
        .item {
            border: 1px solid #eee;
            border-radius: 3px;
            margin-bottom: 10px;
            padding: 10px;
            transition: background-color 0.3s ease, box-shadow 0.2s ease; /* Smooth transitions */
            scroll-margin-top: 10px; /* Space when scrolling into view */
        }
        .item h4 { margin-top: 0; margin-bottom: 5px; font-size: 1.1em; }
        .content-block { background-color: #fff; padding: 8px; border-radius: 3px; margin-bottom: 5px; overflow-x: auto; }
        /* Use :focus-within for better accessibility on container focus */
        .item:focus, .item.active {
            box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.5); /* Highlight active/focused item */
            outline: none; /* Remove default outline */
        }

        /* Score Backgrounds (Faint) — classes applied per-item by the generator */
        .reward-positive { background-color: rgba(144, 238, 144, 0.3); } /* Faint light green */
        .reward-zero { background-color: rgba(255, 215, 0, 0.3); } /* Faint gold/orange */
        .reward-negative { background-color: rgba(255, 182, 193, 0.4); } /* Faint light pink/red */

        /* Markdown specific styles */
        .content-block pre {
            background-color: #f5f5f5;
            border: 1px solid #ddd;
            border-radius: 3px;
            padding: 10px;
            overflow-x: auto; /* Allow horizontal scrolling for long code lines */
            white-space: pre-wrap; /* Wrap long lines within pre */
            word-wrap: break-word; /* Break long words if necessary */
        }
        .content-block code {
            background-color: #f0f0f0; /* Slightly different for inline code */
            padding: 0.2em 0.4em;
            border-radius: 3px;
            font-size: 0.9em;
        }
        .content-block pre code {
            background-color: transparent; /* Don't double-background code in pre blocks */
            padding: 0;
            border-radius: 0;
            font-size: inherit; /* Inherit pre font size */
        }
        .content-block blockquote {
            border-left: 4px solid #ccc;
            padding-left: 10px;
            margin-left: 0;
            color: #555;
        }
        .content-block table {
            border-collapse: collapse;
            width: auto; /* Don't force full width */
            margin-bottom: 1em;
        }
        .content-block th, .content-block td {
            border: 1px solid #ddd;
            padding: 8px;
            text-align: left;
        }
        .content-block th {
            background-color: #f2f2f2;
        }
    </style>
</head>
<body>
    <h1>Rendered Messages - gsm8k_dpo_rollouts_1.jsonl</h1>
    <div id="groups-container">
        <p>No data to display. Input file might be empty or contain invalid data.</p>
    </div>

    <script>
        // Keyboard/click navigation for .item entries: click or ArrowUp/Down
        // moves a single "active" highlight, opening parent <details> as needed.
        document.addEventListener('DOMContentLoaded', () => {
            const items = document.querySelectorAll('.item');
            let activeIndex = -1; // No item active initially

            // Move the active highlight (and focus) to items[index];
            // out-of-range index deactivates everything.
            function setActiveItem(index) {
                if (activeIndex >= 0 && activeIndex < items.length) {
                    items[activeIndex].classList.remove('active');
                    items[activeIndex].removeAttribute('tabindex'); // Remove from tab order when not active
                }
                if (index >= 0 && index < items.length) {
                    items[index].classList.add('active');
                    items[index].setAttribute('tabindex', '0'); // Make active item focusable
                    items[index].focus(); // Focus the element
                    // Ensure the parent <details> is open
                    const detailsParent = items[index].closest('details');
                    if (detailsParent && !detailsParent.open) {
                        detailsParent.open = true;
                    }
                    // Scroll into view with options if needed (focus should handle this mostly)
                    // items[index].scrollIntoView({ behavior: 'smooth', block: 'nearest' });
                    activeIndex = index;
                } else {
                    activeIndex = -1; // Deactivate if index is out of bounds
                }
            }

            // Add click listener to activate items
            items.forEach((item, index) => {
                item.addEventListener('click', () => {
                    setActiveItem(index);
                });
                // Make items focusable initially only if we want tab navigation
                // item.setAttribute('tabindex', '0');
            });

            // ArrowDown/ArrowUp cycle through items; first press jumps to
            // first/last item respectively when nothing is active yet.
            document.addEventListener('keydown', (event) => {
                let targetIndex = -1;
                if (event.key === 'ArrowDown') {
                    event.preventDefault(); // Prevent default page scroll
                    targetIndex = (activeIndex === -1) ? 0 : Math.min(activeIndex + 1, items.length - 1);
                } else if (event.key === 'ArrowUp') {
                    event.preventDefault(); // Prevent default page scroll
                    targetIndex = (activeIndex === -1) ? items.length - 1 : Math.max(activeIndex - 1, 0);
                }

                if (targetIndex !== -1) {
                    setActiveItem(targetIndex);
                }
            });

            // Make first item focusable initially if you want immediate keyboard nav
            if (items.length > 0) {
                // items[0].setAttribute('tabindex', '0');
                // Optionally activate the first item on load:
                // setActiveItem(0);
            }
        });
    </script>
</body>
</html>
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,153 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Rendered Messages - gsm8k_dpo_rollouts_2.jsonl</title>
|
||||
<style>
|
||||
body { font-family: sans-serif; line-height: 1.6; margin: 20px; }
|
||||
details { border: 1px solid #ccc; border-radius: 4px; margin-bottom: 15px; }
|
||||
summary {
|
||||
font-weight: bold;
|
||||
padding: 10px;
|
||||
background-color: #f0f0f0;
|
||||
cursor: pointer;
|
||||
border-radius: 4px 4px 0 0;
|
||||
border-bottom: 1px solid #ccc;
|
||||
outline: none; /* Remove default focus outline if needed */
|
||||
}
|
||||
details[open] summary { border-bottom: 1px solid #ccc; }
|
||||
.group-content { padding: 15px; }
|
||||
.item {
|
||||
border: 1px solid #eee;
|
||||
border-radius: 3px;
|
||||
margin-bottom: 10px;
|
||||
padding: 10px;
|
||||
transition: background-color 0.3s ease, box-shadow 0.2s ease; /* Smooth transitions */
|
||||
scroll-margin-top: 10px; /* Space when scrolling into view */
|
||||
}
|
||||
.item h4 { margin-top: 0; margin-bottom: 5px; font-size: 1.1em; }
|
||||
.content-block { background-color: #fff; padding: 8px; border-radius: 3px; margin-bottom: 5px; overflow-x: auto; }
|
||||
/* Use :focus-within for better accessibility on container focus */
|
||||
.item:focus, .item.active {
|
||||
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.5); /* Highlight active/focused item */
|
||||
outline: none; /* Remove default outline */
|
||||
}
|
||||
|
||||
/* Score Backgrounds (Faint) */
|
||||
.reward-positive { background-color: rgba(144, 238, 144, 0.3); } /* Faint light green */
|
||||
.reward-zero { background-color: rgba(255, 215, 0, 0.3); } /* Faint gold/orange */
|
||||
.reward-negative { background-color: rgba(255, 182, 193, 0.4); } /* Faint light pink/red */
|
||||
|
||||
/* Markdown specific styles */
|
||||
.content-block pre {
|
||||
background-color: #f5f5f5;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 3px;
|
||||
padding: 10px;
|
||||
overflow-x: auto; /* Allow horizontal scrolling for long code lines */
|
||||
white-space: pre-wrap; /* Wrap long lines within pre */
|
||||
word-wrap: break-word; /* Break long words if necessary */
|
||||
}
|
||||
.content-block code {
|
||||
background-color: #f0f0f0; /* Slightly different for inline code */
|
||||
padding: 0.2em 0.4em;
|
||||
border-radius: 3px;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
.content-block pre code {
|
||||
background-color: transparent; /* Don't double-background code in pre blocks */
|
||||
padding: 0;
|
||||
border-radius: 0;
|
||||
font-size: inherit; /* Inherit pre font size */
|
||||
}
|
||||
.content-block blockquote {
|
||||
border-left: 4px solid #ccc;
|
||||
padding-left: 10px;
|
||||
margin-left: 0;
|
||||
color: #555;
|
||||
}
|
||||
.content-block table {
|
||||
border-collapse: collapse;
|
||||
width: auto; /* Don't force full width */
|
||||
margin-bottom: 1em;
|
||||
}
|
||||
.content-block th, .content-block td {
|
||||
border: 1px solid #ddd;
|
||||
padding: 8px;
|
||||
text-align: left;
|
||||
}
|
||||
.content-block th {
|
||||
background-color: #f2f2f2;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Rendered Messages - gsm8k_dpo_rollouts_2.jsonl</h1>
|
||||
<div id="groups-container">
|
||||
<p>No data to display. Input file might be empty or contain invalid data.</p>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
document.addEventListener('DOMContentLoaded', () => {
|
||||
const items = document.querySelectorAll('.item');
|
||||
let activeIndex = -1; // No item active initially
|
||||
|
||||
// Function to set active item
|
||||
function setActiveItem(index) {
|
||||
if (activeIndex >= 0 && activeIndex < items.length) {
|
||||
items[activeIndex].classList.remove('active');
|
||||
items[activeIndex].removeAttribute('tabindex'); // Remove from tab order when not active
|
||||
}
|
||||
if (index >= 0 && index < items.length) {
|
||||
items[index].classList.add('active');
|
||||
items[index].setAttribute('tabindex', '0'); // Make active item focusable
|
||||
items[index].focus(); // Focus the element
|
||||
// Ensure the parent <details> is open
|
||||
const detailsParent = items[index].closest('details');
|
||||
if (detailsParent && !detailsParent.open) {
|
||||
detailsParent.open = true;
|
||||
}
|
||||
// Scroll into view with options if needed (focus should handle this mostly)
|
||||
// items[index].scrollIntoView({ behavior: 'smooth', block: 'nearest' });
|
||||
activeIndex = index;
|
||||
} else {
|
||||
activeIndex = -1; // Deactivate if index is out of bounds
|
||||
}
|
||||
}
|
||||
|
||||
// Add click listener to activate items
|
||||
items.forEach((item, index) => {
|
||||
item.addEventListener('click', () => {
|
||||
setActiveItem(index);
|
||||
});
|
||||
// Make items focusable initially only if we want tab navigation
|
||||
// item.setAttribute('tabindex', '0');
|
||||
});
|
||||
|
||||
// Add keydown listener for arrow navigation
|
||||
document.addEventListener('keydown', (event) => {
|
||||
let targetIndex = -1;
|
||||
if (event.key === 'ArrowDown') {
|
||||
event.preventDefault(); // Prevent default page scroll
|
||||
targetIndex = (activeIndex === -1) ? 0 : Math.min(activeIndex + 1, items.length - 1);
|
||||
} else if (event.key === 'ArrowUp') {
|
||||
event.preventDefault(); // Prevent default page scroll
|
||||
targetIndex = (activeIndex === -1) ? items.length - 1 : Math.max(activeIndex - 1, 0);
|
||||
}
|
||||
|
||||
if (targetIndex !== -1) {
|
||||
setActiveItem(targetIndex);
|
||||
}
|
||||
});
|
||||
|
||||
// Make first item focusable initially if you want immediate keyboard nav
|
||||
if (items.length > 0) {
|
||||
// items[0].setAttribute('tabindex', '0');
|
||||
// Optionally activate the first item on load:
|
||||
// setActiveItem(0);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
|
|
@ -0,0 +1,153 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Rendered Messages - gsm8k_dpo_rollouts_4.jsonl</title>
|
||||
<style>
|
||||
body { font-family: sans-serif; line-height: 1.6; margin: 20px; }
|
||||
details { border: 1px solid #ccc; border-radius: 4px; margin-bottom: 15px; }
|
||||
summary {
|
||||
font-weight: bold;
|
||||
padding: 10px;
|
||||
background-color: #f0f0f0;
|
||||
cursor: pointer;
|
||||
border-radius: 4px 4px 0 0;
|
||||
border-bottom: 1px solid #ccc;
|
||||
outline: none; /* Remove default focus outline if needed */
|
||||
}
|
||||
details[open] summary { border-bottom: 1px solid #ccc; }
|
||||
.group-content { padding: 15px; }
|
||||
.item {
|
||||
border: 1px solid #eee;
|
||||
border-radius: 3px;
|
||||
margin-bottom: 10px;
|
||||
padding: 10px;
|
||||
transition: background-color 0.3s ease, box-shadow 0.2s ease; /* Smooth transitions */
|
||||
scroll-margin-top: 10px; /* Space when scrolling into view */
|
||||
}
|
||||
.item h4 { margin-top: 0; margin-bottom: 5px; font-size: 1.1em; }
|
||||
.content-block { background-color: #fff; padding: 8px; border-radius: 3px; margin-bottom: 5px; overflow-x: auto; }
|
||||
/* Use :focus-within for better accessibility on container focus */
|
||||
.item:focus, .item.active {
|
||||
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.5); /* Highlight active/focused item */
|
||||
outline: none; /* Remove default outline */
|
||||
}
|
||||
|
||||
/* Score Backgrounds (Faint) */
|
||||
.reward-positive { background-color: rgba(144, 238, 144, 0.3); } /* Faint light green */
|
||||
.reward-zero { background-color: rgba(255, 215, 0, 0.3); } /* Faint gold/orange */
|
||||
.reward-negative { background-color: rgba(255, 182, 193, 0.4); } /* Faint light pink/red */
|
||||
|
||||
/* Markdown specific styles */
|
||||
.content-block pre {
|
||||
background-color: #f5f5f5;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 3px;
|
||||
padding: 10px;
|
||||
overflow-x: auto; /* Allow horizontal scrolling for long code lines */
|
||||
white-space: pre-wrap; /* Wrap long lines within pre */
|
||||
word-wrap: break-word; /* Break long words if necessary */
|
||||
}
|
||||
.content-block code {
|
||||
background-color: #f0f0f0; /* Slightly different for inline code */
|
||||
padding: 0.2em 0.4em;
|
||||
border-radius: 3px;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
.content-block pre code {
|
||||
background-color: transparent; /* Don't double-background code in pre blocks */
|
||||
padding: 0;
|
||||
border-radius: 0;
|
||||
font-size: inherit; /* Inherit pre font size */
|
||||
}
|
||||
.content-block blockquote {
|
||||
border-left: 4px solid #ccc;
|
||||
padding-left: 10px;
|
||||
margin-left: 0;
|
||||
color: #555;
|
||||
}
|
||||
.content-block table {
|
||||
border-collapse: collapse;
|
||||
width: auto; /* Don't force full width */
|
||||
margin-bottom: 1em;
|
||||
}
|
||||
.content-block th, .content-block td {
|
||||
border: 1px solid #ddd;
|
||||
padding: 8px;
|
||||
text-align: left;
|
||||
}
|
||||
.content-block th {
|
||||
background-color: #f2f2f2;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Rendered Messages - gsm8k_dpo_rollouts_4.jsonl</h1>
|
||||
<div id="groups-container">
|
||||
<p>No data to display. Input file might be empty or contain invalid data.</p>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
document.addEventListener('DOMContentLoaded', () => {
|
||||
const items = document.querySelectorAll('.item');
|
||||
let activeIndex = -1; // No item active initially
|
||||
|
||||
// Function to set active item
|
||||
function setActiveItem(index) {
|
||||
if (activeIndex >= 0 && activeIndex < items.length) {
|
||||
items[activeIndex].classList.remove('active');
|
||||
items[activeIndex].removeAttribute('tabindex'); // Remove from tab order when not active
|
||||
}
|
||||
if (index >= 0 && index < items.length) {
|
||||
items[index].classList.add('active');
|
||||
items[index].setAttribute('tabindex', '0'); // Make active item focusable
|
||||
items[index].focus(); // Focus the element
|
||||
// Ensure the parent <details> is open
|
||||
const detailsParent = items[index].closest('details');
|
||||
if (detailsParent && !detailsParent.open) {
|
||||
detailsParent.open = true;
|
||||
}
|
||||
// Scroll into view with options if needed (focus should handle this mostly)
|
||||
// items[index].scrollIntoView({ behavior: 'smooth', block: 'nearest' });
|
||||
activeIndex = index;
|
||||
} else {
|
||||
activeIndex = -1; // Deactivate if index is out of bounds
|
||||
}
|
||||
}
|
||||
|
||||
// Add click listener to activate items
|
||||
items.forEach((item, index) => {
|
||||
item.addEventListener('click', () => {
|
||||
setActiveItem(index);
|
||||
});
|
||||
// Make items focusable initially only if we want tab navigation
|
||||
// item.setAttribute('tabindex', '0');
|
||||
});
|
||||
|
||||
// Add keydown listener for arrow navigation
|
||||
document.addEventListener('keydown', (event) => {
|
||||
let targetIndex = -1;
|
||||
if (event.key === 'ArrowDown') {
|
||||
event.preventDefault(); // Prevent default page scroll
|
||||
targetIndex = (activeIndex === -1) ? 0 : Math.min(activeIndex + 1, items.length - 1);
|
||||
} else if (event.key === 'ArrowUp') {
|
||||
event.preventDefault(); // Prevent default page scroll
|
||||
targetIndex = (activeIndex === -1) ? items.length - 1 : Math.max(activeIndex - 1, 0);
|
||||
}
|
||||
|
||||
if (targetIndex !== -1) {
|
||||
setActiveItem(targetIndex);
|
||||
}
|
||||
});
|
||||
|
||||
// Make first item focusable initially if you want immediate keyboard nav
|
||||
if (items.length > 0) {
|
||||
// items[0].setAttribute('tabindex', '0');
|
||||
// Optionally activate the first item on load:
|
||||
// setActiveItem(0);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,337 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import ast # For safely evaluating the LLM's string output
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
try:
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer_dpo as imported_tokenize_for_trainer_dpo
|
||||
except ImportError:
|
||||
imported_tokenize_for_trainer_dpo = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# --- System Prompts for Generating Chosen/Rejected Responses ---
|
||||
SYSTEM_PROMPT_CHOSEN = """
|
||||
You are a helpful and engaging AI assistant. Your goal is to provide clear, empathetic, and insightful responses that encourage further conversation. Be positive and proactive.
|
||||
"""
|
||||
|
||||
SYSTEM_PROMPT_REJECTED = """"
|
||||
You are an AI assistant. Provide very brief, blunt, and unhelpful responses. Do not elaborate or ask follow-up questions.
|
||||
"""
|
||||
# --- End System Prompts ---
|
||||
|
||||
# --- Master Prompt for Generating Initial Prompts ---
|
||||
PROMPT_GENERATION_MASTER_PROMPT = """
|
||||
You are a creative assistant. Your task is to generate a list of 10 unique and random conversational prompts.
|
||||
Each prompt should be suitable for starting a general conversation with an AI assistant.
|
||||
The prompts should cover a variety of topics, tones (e.g., inquisitive, reflective, casual), and lengths.
|
||||
Format your entire output as a single Python list of dictionaries, where each dictionary has a single key "prompt" and the value is the prompt string.
|
||||
|
||||
Example format:
|
||||
[
|
||||
{"prompt": "What's the most interesting dream you've ever had, if AIs could dream?"},
|
||||
{"prompt": "If you could learn any new skill instantly, what would it be and why?"}
|
||||
]
|
||||
|
||||
Provide only the Python list of dictionaries, with no other surrounding text, explanations, or markdown formatting.
|
||||
"""
|
||||
# --- End Master Prompt ---
|
||||
|
||||
class GSM8KConversationalStyleDPOEnvConfig(BaseEnvConfig):
|
||||
"""Config for GSM8KConversationalStyleDPOEnv."""
|
||||
dataset_name: str = Field(
|
||||
"synthetic_conversational_style_prompts_via_gsm8k_env",
|
||||
description="Name of the dataset to use (source of prompts)."
|
||||
)
|
||||
shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset of prompts")
|
||||
data_path_to_save_groups: Optional[str] = Field(None, description="Path to save .jsonl and .html processed rollouts.")
|
||||
# Generation parameters for chosen responses
|
||||
chosen_temperature: float = Field(0.7, description="Temperature for generating chosen responses.")
|
||||
chosen_max_tokens: int = Field(150, description="Max tokens for chosen responses.")
|
||||
# Generation parameters for rejected responses
|
||||
rejected_temperature: float = Field(0.4, description="Temperature for generating rejected responses.")
|
||||
rejected_max_tokens: int = Field(50, description="Max tokens for rejected responses.")
|
||||
prompt_generation_temperature: float = Field(0.8, description="Temperature for LLM generating the initial list of prompts.")
|
||||
prompt_generation_max_tokens: int = Field(1000, description="Max tokens for LLM generating the initial list of prompts.")
|
||||
|
||||
class GSM8KConversationalStyleDPOEnv(BaseEnv):
|
||||
name = "gsm8k_dynamic_conversational_dpo"
|
||||
name_config_cls = GSM8KConversationalStyleDPOEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GSM8KConversationalStyleDPOEnvConfig,
|
||||
server_configs: Optional[List[APIServerConfig]] = None,
|
||||
slurm=True,
|
||||
testing=False,
|
||||
):
|
||||
# Ensure server_configs is not None if we intend to use self.server
|
||||
# The BaseEnv will select the first server_config if multiple are provided and split is not specified.
|
||||
resolved_server_configs = server_configs
|
||||
if not resolved_server_configs:
|
||||
logger.warning(f"No server_configs provided for {self.name}, chat_completion calls will fail if not overridden in config_init or CLI.")
|
||||
# You might want to provide a default dummy one if testing without a server is intended for some paths
|
||||
# For this version, we expect it to be configured properly for LLM calls.
|
||||
|
||||
super().__init__(config, resolved_server_configs, slurm, testing)
|
||||
self.config: GSM8KConversationalStyleDPOEnvConfig = config
|
||||
self.prompt_dataset: List[Dict[str, str]] = [] # Stores only prompts now
|
||||
self.iter: int = 0
|
||||
|
||||
if imported_tokenize_for_trainer_dpo is not None:
|
||||
self._tokenize_dpo_fn = imported_tokenize_for_trainer_dpo
|
||||
logger.info(f"Using imported tokenize_for_trainer_dpo for {self.name}")
|
||||
else:
|
||||
self._tokenize_dpo_fn = self._placeholder_tokenize_for_dpo
|
||||
logger.info(f"Using placeholder tokenize_for_trainer_dpo for {self.name}")
|
||||
|
||||
def _placeholder_tokenize_for_dpo(self, tokenizer, data, max_length, **kwargs):
|
||||
prompt = data["prompt"]
|
||||
chosen = data["chosen"]
|
||||
rejected = data["rejected"]
|
||||
chosen_full_text = prompt + chosen
|
||||
rejected_full_text = prompt + rejected
|
||||
chosen_tokenized = tokenizer(chosen_full_text, truncation=True, padding="max_length", max_length=max_length)
|
||||
rejected_tokenized = tokenizer(rejected_full_text, truncation=True, padding="max_length", max_length=max_length)
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"chosen": chosen,
|
||||
"rejected": rejected,
|
||||
"chosen_input_ids": chosen_tokenized["input_ids"],
|
||||
"chosen_attention_mask": chosen_tokenized["attention_mask"],
|
||||
"rejected_input_ids": rejected_tokenized["input_ids"],
|
||||
"rejected_attention_mask": rejected_tokenized["attention_mask"],
|
||||
}
|
||||
|
||||
async def setup(self):
|
||||
"""Load and prepare the dataset of prompts, potentially by generating them via LLM."""
|
||||
|
||||
generated_prompts = []
|
||||
if self.server:
|
||||
logger.info(f"Attempting to generate initial prompts for {self.name} using LLM...")
|
||||
try:
|
||||
prompt_list_completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant that strictly follows formatting instructions."},
|
||||
{"role": "user", "content": PROMPT_GENERATION_MASTER_PROMPT}
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.prompt_generation_max_tokens,
|
||||
temperature=self.config.prompt_generation_temperature,
|
||||
)
|
||||
response_text = prompt_list_completion.choices[0].message.content.strip() if prompt_list_completion.choices else ""
|
||||
logger.debug(f"LLM response for prompt generation: {response_text}")
|
||||
|
||||
# Attempt to parse the string as a Python list of dictionaries
|
||||
try:
|
||||
parsed_list = ast.literal_eval(response_text)
|
||||
if isinstance(parsed_list, list) and all(isinstance(item, dict) and "prompt" in item for item in parsed_list):
|
||||
generated_prompts = parsed_list
|
||||
logger.info(f"Successfully generated and parsed {len(generated_prompts)} prompts from LLM.")
|
||||
else:
|
||||
logger.warning("LLM response for prompt generation was not a valid list of prompt dictionaries. Using fallback.")
|
||||
except (SyntaxError, ValueError) as e:
|
||||
logger.warning(f"Error parsing LLM response for prompt generation: {e}. Using fallback.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling LLM for prompt generation: {e}. Using fallback.")
|
||||
else:
|
||||
logger.warning(f"LLM server not available for {self.name}. Using fallback prompts.")
|
||||
|
||||
if not generated_prompts or not isinstance(generated_prompts, list) or len(generated_prompts) < 10:
|
||||
logger.info(f"Using fallback static prompt list for {self.name}.")
|
||||
generated_prompts = [
|
||||
{"prompt": "What are your thoughts on the future of renewable energy?"},
|
||||
{"prompt": "If you could travel anywhere in the world, where would you go and why?"},
|
||||
{"prompt": "What's a book or movie that has deeply impacted you?"},
|
||||
{"prompt": "Can you describe a complex scientific concept in simple terms?"},
|
||||
{"prompt": "What's a common misconception people have about AI?"},
|
||||
{"prompt": "How do you think technology will change our daily lives in the next 20 years?"},
|
||||
{"prompt": "What's a piece of advice you would give to someone learning a new skill?"},
|
||||
{"prompt": "If you could have a conversation with any historical figure, who would it be?"},
|
||||
{"prompt": "What does 'creativity' mean to you as an AI?"},
|
||||
{"prompt": "Can you tell me a joke?"}
|
||||
]
|
||||
if len(generated_prompts) > 10: # Ensure we only use 10 if more are in fallback
|
||||
generated_prompts = generated_prompts[:10]
|
||||
|
||||
self.prompt_dataset = generated_prompts
|
||||
if self.config.shuffle_dataset:
|
||||
random.shuffle(self.prompt_dataset)
|
||||
self.iter = 0
|
||||
logger.info(f"Initialized prompt dataset with {len(self.prompt_dataset)} examples for {self.name}.")
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
"""
|
||||
Returns the next prompt from the dataset.
|
||||
Chosen and rejected responses will be generated dynamically.
|
||||
"""
|
||||
if not self.prompt_dataset or self.iter >= len(self.prompt_dataset):
|
||||
await self.setup() # This will now potentially call the LLM to generate prompts
|
||||
|
||||
if not self.prompt_dataset: # Should be populated by setup, even with fallback
|
||||
logger.error(f"Prompt dataset is STILL empty after setup in {self.name}. This is unexpected.")
|
||||
# Provide an absolute failsafe prompt
|
||||
return ("Failsafe prompt: What is 1+1?", "", "")
|
||||
|
||||
entry = self.prompt_dataset[self.iter % len(self.prompt_dataset)]
|
||||
self.iter += 1
|
||||
return (entry["prompt"], "", "")
|
||||
|
||||
async def collect_trajectories(
|
||||
self, items: List[Item] # Changed to accept a list of items
|
||||
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
|
||||
"""
|
||||
Receives a prompt, generates chosen and rejected responses using an LLM,
|
||||
then tokenizes for DPO.
|
||||
Assumes group_size is 1 for this DPO setup.
|
||||
"""
|
||||
if not items:
|
||||
logger.warning("collect_trajectories received an empty list of items.")
|
||||
return None, []
|
||||
|
||||
# item = items[0]
|
||||
prompt_str, _, _ = items
|
||||
|
||||
if not self.server:
|
||||
logger.error(f"LLM server not configured or available for {self.name}. Cannot generate responses.")
|
||||
return None, []
|
||||
|
||||
if not self.tokenizer:
|
||||
logger.error(f"Tokenizer not available in {self.name}. Attempting to initialize.")
|
||||
if self.testing and hasattr(self.config, 'tokenizer_name') and self.config.tokenizer_name:
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name)
|
||||
if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
logger.info(f"Fallback tokenizer '{self.config.tokenizer_name}' initialized.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize fallback tokenizer: {e}")
|
||||
return None, []
|
||||
else:
|
||||
return None, []
|
||||
|
||||
try:
|
||||
# Generate Chosen Response
|
||||
logger.debug(f"Generating CHOSEN for prompt: {prompt_str[:100]}...")
|
||||
chosen_completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPT_CHOSEN},
|
||||
{"role": "user", "content": prompt_str}
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.chosen_max_tokens,
|
||||
temperature=self.config.chosen_temperature,
|
||||
)
|
||||
chosen_response_str = chosen_completion.choices[0].message.content.strip() if chosen_completion.choices else ""
|
||||
logger.debug(f"Generated CHOSEN: {chosen_response_str[:100]}...")
|
||||
|
||||
# Generate Rejected Response
|
||||
logger.debug(f"Generating REJECTED for prompt: {prompt_str[:100]}...")
|
||||
rejected_completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPT_REJECTED},
|
||||
{"role": "user", "content": prompt_str}
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.rejected_max_tokens,
|
||||
temperature=self.config.rejected_temperature,
|
||||
)
|
||||
rejected_response_str = rejected_completion.choices[0].message.content.strip() if rejected_completion.choices else ""
|
||||
logger.debug(f"Generated REJECTED: {rejected_response_str[:100]}...")
|
||||
|
||||
if not chosen_response_str or not rejected_response_str:
|
||||
logger.warning(f"Failed to generate valid chosen or rejected response for prompt: {prompt_str}")
|
||||
return None, []
|
||||
|
||||
dpo_pair_data = {
|
||||
"prompt": prompt_str,
|
||||
"chosen": chosen_response_str,
|
||||
"rejected": rejected_response_str,
|
||||
}
|
||||
|
||||
tokenized_output = self._tokenize_dpo_fn(
|
||||
self.tokenizer,
|
||||
dpo_pair_data,
|
||||
max_length=self.config.max_token_length,
|
||||
)
|
||||
|
||||
scores_group = ScoredDataGroup()
|
||||
scores_group["prompt"] = [tokenized_output["prompt"]]
|
||||
scores_group["chosen"] = [tokenized_output["chosen"]]
|
||||
scores_group["rejected"] = [tokenized_output["rejected"]]
|
||||
scores_group["chosen_tokens"] = [tokenized_output["chosen_input_ids"]]
|
||||
scores_group["chosen_masks"] = [tokenized_output["chosen_attention_mask"]]
|
||||
scores_group["rejected_tokens"] = [tokenized_output["rejected_input_ids"]]
|
||||
scores_group["rejected_masks"] = [tokenized_output["rejected_attention_mask"]]
|
||||
scores_group["tokens"] = [tokenized_output["chosen_input_ids"]]
|
||||
scores_group["masks"] = [tokenized_output["chosen_attention_mask"]]
|
||||
scores_group["scores"] = [1.0]
|
||||
scores_group["images"] = [None]
|
||||
scores_group["group_overrides"] = {"group_size": 1}
|
||||
|
||||
return scores_group, []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in collect_trajectories for {self.name} during DPO processing: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None, []
|
||||
|
||||
async def score(self, rollout_group_data: Any) -> Optional[ScoredDataGroup]:
|
||||
if rollout_group_data and isinstance(rollout_group_data, ScoredDataGroup):
|
||||
return rollout_group_data
|
||||
logger.info(f"Data for {self.name} is not in ScoredDataGroup format or no data to score.")
|
||||
return None
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
logger.info(f"Evaluation step called for {self.name}. No custom DPO evaluation implemented.")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[GSM8KConversationalStyleDPOEnvConfig, List[APIServerConfig]]:
|
||||
env_config = GSM8KConversationalStyleDPOEnvConfig(
|
||||
wandb_name="gsm8k_dynamic_conversational_dpo",
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
group_size=1,
|
||||
use_wandb=True,
|
||||
max_num_workers=1,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=100,
|
||||
batch_size=2,
|
||||
steps_per_eval=50,
|
||||
max_token_length=512,
|
||||
dataset_name="synthetic_conversational_style_prompts_via_gsm8k_env",
|
||||
shuffle_dataset=True,
|
||||
data_path_to_save_groups=None,
|
||||
chosen_temperature=0.7,
|
||||
chosen_max_tokens=150,
|
||||
rejected_temperature=0.4,
|
||||
rejected_max_tokens=50,
|
||||
prompt_generation_temperature=0.8,
|
||||
prompt_generation_max_tokens=1000,
|
||||
)
|
||||
# IMPORTANT: Configure your LLM inference server details here or via CLI/config file.
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
base_url="https://inference-api.nousresearch.com/v1",
|
||||
api_key="sk-3DvYKMv_-BfAoDSTfdSvEQ",
|
||||
num_requests_for_eval=256, # Copied from gsm8k_server.py
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
if __name__ == "__main__":
|
||||
GSM8KConversationalStyleDPOEnv.cli()
|
||||
10
environments/hack0/conversational_style_dpo/requirements.txt
Normal file
10
environments/hack0/conversational_style_dpo/requirements.txt
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
# requirements.txt must contain bare requirement specifiers, not shell
# commands. Install with: pip install -r requirements.txt
trl
datasets
accelerate
pydantic
bitsandbytes
transformers
torch
-e .[all]  # for everything
|
||||
|
|
@ -0,0 +1,185 @@
|
|||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
TrainingArguments,
|
||||
)
|
||||
from trl import DPOTrainer
|
||||
|
||||
# Import your custom environment
|
||||
# The import below assumes this script is in the same directory as conversational_style_dpo_env.py
|
||||
from .conversational_style_dpo_env import (
|
||||
ConversationalStyleDPOEnv,
|
||||
ConversationalStyleDPOEnvConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
Arguments for the DPO training script.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
default="distilgpt2",
|
||||
metadata={"help": "The model name or path to load from."},
|
||||
)
|
||||
tokenizer_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The tokenizer name or path. Defaults to model_name_or_path."},
|
||||
)
|
||||
beta: float = field(
|
||||
default=0.1, metadata={"help": "The beta factor in DPO loss."}
|
||||
)
|
||||
max_prompt_length: int = field(
|
||||
default=256, metadata={"help": "Max length for prompts."}
|
||||
)
|
||||
max_length: int = field(
|
||||
default=512,
|
||||
metadata={"help": "Max length for chosen/rejected responses including prompt."},
|
||||
)
|
||||
# Add any other TRL DPOTrainer arguments or TrainingArguments here if needed
|
||||
# For example, learning_rate, per_device_train_batch_size, num_train_epochs etc.
|
||||
# will be part of TrainingArguments.
|
||||
|
||||
|
||||
async def get_dataset_from_env(
|
||||
env_config: ConversationalStyleDPOEnvConfig,
|
||||
) -> Dataset:
|
||||
"""
|
||||
Initializes the environment and extracts the synthetic dataset
|
||||
in the format required by DPOTrainer (list of dicts with prompt, chosen, rejected).
|
||||
"""
|
||||
# We don't need server_configs if the env doesn't use them for static dataset loading
|
||||
env = ConversationalStyleDPOEnv(config=env_config, server_configs=[], testing=True)
|
||||
await env.setup() # This loads env.synthetic_data
|
||||
|
||||
# env.synthetic_data is already a List[Dict[str, str]] with "prompt", "chosen", "rejected"
|
||||
# Convert it to Hugging Face Dataset
|
||||
# Check if data is loaded
|
||||
if not env.dataset:
|
||||
raise ValueError(
|
||||
"Dataset is empty after environment setup. Check ConversationalStyleDPOEnv."
|
||||
)
|
||||
|
||||
# The DPOTrainer expects columns named "prompt", "chosen", "rejected"
|
||||
# The synthetic_data in your environment is already in this format.
|
||||
# Example: [{"prompt": "...", "chosen": "...", "rejected": "..."}]
|
||||
hf_dataset = Dataset.from_list(list(env.dataset)) # Ensure it's a fresh list copy
|
||||
|
||||
# Log a sample to verify
|
||||
if len(hf_dataset) > 0:
|
||||
logger.info(f"Sample from dataset: {hf_dataset[0]}")
|
||||
else:
|
||||
logger.warning("Dataset created from environment is empty!")
|
||||
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def _build_dataset(script_args):
    # Build the DPO preference dataset by running the environment pipeline.
    logger.info("Initializing environment to get dataset...")
    env_dpo_config, _ = ConversationalStyleDPOEnv.config_init()
    # Keep the environment's tokenizer aligned with the model being trained.
    env_dpo_config.tokenizer_name = script_args.tokenizer_name_or_path
    dataset = asyncio.run(get_dataset_from_env(env_dpo_config))
    logger.info(f"Loaded dataset with {len(dataset)} examples.")
    return dataset


def _load_tokenizer(script_args):
    # Load the tokenizer, falling back to eos_token as pad token when missing.
    logger.info(f"Loading tokenizer: {script_args.tokenizer_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name_or_path)
    if tokenizer.pad_token is None:
        logger.warning("Tokenizer does not have a pad token. Setting to eos_token.")
        # Assigning pad_token also updates pad_token_id on HF tokenizers.
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _save_artifacts(dpo_trainer, tokenizer, training_args):
    # Persist the trained model and tokenizer alongside each other.
    # NOTE: TrainingArguments.should_save gates on distributed rank (whether
    # THIS process is the one that writes), not on save_strategy being enabled.
    if training_args.should_save:
        output_save_dir = training_args.output_dir
        logger.info(f"Saving model to {output_save_dir}")
        dpo_trainer.save_model(output_save_dir)
        logger.info("Model saved.")
        # Save the tokenizer too so the output directory is self-contained.
        tokenizer.save_pretrained(output_save_dir)
        logger.info("Tokenizer saved.")


def main():
    """Train a causal LM with DPO on conversational-style preference data.

    Pipeline:
      1. Parse ``ScriptArguments`` + ``TrainingArguments`` from the CLI.
      2. Build the preference dataset from ``ConversationalStyleDPOEnv``.
      3. Load the tokenizer and policy model (the reference model is an
         implicit copy created by ``DPOTrainer`` when ``ref_model=None``).
      4. Run ``DPOTrainer`` and optionally save the resulting artifacts.
    """
    parser = HfArgumentParser((ScriptArguments, TrainingArguments))
    script_args, training_args = parser.parse_args_into_dataclasses()

    # Default the tokenizer to the policy model checkpoint.
    if script_args.tokenizer_name_or_path is None:
        script_args.tokenizer_name_or_path = script_args.model_name_or_path

    # --- 1. Initialize Environment and Get Dataset ---
    dataset = _build_dataset(script_args)
    if len(dataset) == 0:
        logger.error("No data loaded. Exiting training.")
        return

    # --- 2. Load Tokenizer and Models ---
    tokenizer = _load_tokenizer(script_args)

    logger.info(f"Loading policy model: {script_args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        # low_cpu_mem_usage=True,      # can be helpful for large models
        # torch_dtype=torch.float16,   # mixed precision if the GPU supports it
    )

    # Passing ref_model=None makes DPOTrainer snapshot the policy model as the
    # frozen reference. Load a separate SFT checkpoint here if one exists.
    model_ref = None
    logger.info("Reference model will be a copy of the policy model (handled by DPOTrainer).")

    # --- 3. Set up Training Arguments ---
    if training_args.output_dir == "output_dir":  # Default value from TrainingArguments
        # Redirect the generic default to a script-specific results directory,
        # resolved relative to the current working directory at run time.
        training_args.output_dir = "./dpo_conversational_trainer_results"

    logger.info(f"Training Arguments: {training_args}")

    # --- 4. Initialize DPOTrainer ---
    logger.info("Initializing DPOTrainer...")
    dpo_trainer = DPOTrainer(
        model=model,
        ref_model=model_ref,  # If None, a copy of model is made
        args=training_args,
        beta=script_args.beta,
        train_dataset=dataset,
        tokenizer=tokenizer,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,  # Max length of prompt + response
        # peft_config=peft_config,  # If using PEFT/LoRA
    )
    logger.info("DPOTrainer initialized.")

    # --- 5. Start Training ---
    logger.info("Starting DPO training...")
    dpo_trainer.train()
    logger.info("DPO training completed.")

    # --- 6. Save the Model (Optional) ---
    _save_artifacts(dpo_trainer, tokenizer, training_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Configure root logging before entering the training entry point so
    # the module-level logger's INFO messages are actually emitted.
    logging.basicConfig(level=logging.INFO)
    main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue