Fix final code quality issues in Conversational Style DPO environment

This commit is contained in:
Shannon Sands 2025-05-26 10:48:11 +10:00
parent 441fd1036d
commit d789128f20
31 changed files with 408 additions and 1671 deletions

BIN environments/community/.DS_Store vendored (binary file, not shown)


@@ -440,6 +440,74 @@ A game environment that teaches LLMs strategic thinking and decision-making thro
**Requirements**: poke-env, nodejs, pokemon-showdown simulator, OpenAI API
### 13. Conversational Style DPO Environment (`conversational_style_dpo/`)
**Author**: [Karthik-Ragunath](https://github.com/Karthik-Ragunath)
**Purpose**: Train LLMs for better conversational style through Direct Preference Optimization (DPO) with chosen/rejected response pairs
A specialized environment for training LLMs to adopt more engaging, empathetic, and helpful conversational styles using Direct Preference Optimization. The environment provides both synthetic and dynamically generated conversation pairs where "chosen" responses are engaging and thoughtful while "rejected" responses are blunt and unhelpful.
**Features**:
- **Two Environment Variants**: Static synthetic data and dynamic prompt generation
- **DPO Training Ready**: Pre-configured tokenization for chosen/rejected response pairs
- **Conversational Style Modeling**: Focus on empathy, engagement, and helpfulness
- **Synthetic Data Generation**: Uses LLMs to create diverse conversational prompts
- **Quality Response Pairs**: Carefully crafted chosen (good) vs rejected (poor) examples
**Environment Variants**:
1. **Static Synthetic Environment** (`conversational_style_dpo_env.py`):
- Pre-defined conversational prompts with human-crafted response pairs
- Focus on emotional support, explanations, excitement sharing, and help-seeking
- Immediate training readiness without LLM dependencies
2. **Dynamic GSM8K-Style Environment** (`gsmk8k_conversational_style_dpo_env.py`):
- LLM-generated conversational prompts for diverse training data
- Real-time chosen/rejected response generation with different system prompts
- Scalable dataset creation with fallback to static prompts
**Conversation Categories**:
- **Emotional Support**: Responding to feelings and personal sharing
- **Educational**: Explaining concepts clearly and engagingly
- **Enthusiasm Sharing**: Celebrating user excitement and interests
- **Help & Guidance**: Providing assistance with understanding problems
- **General Conversation**: Weather, casual topics, and everyday interactions
**Response Quality Characteristics**:
- **Chosen Responses**: Empathetic, engaging, ask follow-up questions, provide detailed explanations
- **Rejected Responses**: Blunt, minimal, dismissive, unhelpful
**Example Training Pair**:
```
Prompt: "I'm feeling a bit down today."
Chosen: "I'm sorry to hear that. Sometimes a little self-care can help. What's one small thing you could do for yourself right now?"
Rejected: "Okay."
```
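DPO turns such pairs into a training signal without an explicit reward model: the policy is pushed to prefer the chosen response over the rejected one relative to a frozen reference model. A minimal sketch of the per-pair loss (illustrative only, not the environment's actual trainer; `beta` and the reference log-probabilities are standard DPO ingredients assumed here):

```python
import math

def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """DPO loss for one (chosen, rejected) pair of per-sequence log-probs."""
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(margin)): small when the policy clearly prefers "chosen"
    return -math.log(1.0 / (1.0 + math.exp(-margin)))

# Loss shrinks as the policy separates chosen from rejected:
weak = dpo_loss(-10.0, -10.0, -10.0, -10.0)   # no separation -> log(2)
strong = dpo_loss(-5.0, -20.0, -10.0, -10.0)  # chosen strongly preferred
```

With zero margin the loss is exactly `log(2)`; any genuine preference for the chosen response drives it lower.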
**Technical Implementation**:
- **DPO Tokenization**: Ready-to-use tokenization for preference optimization
- **Configurable Parameters**: Temperature, max tokens, and dataset size controls
- **Modular Design**: Easy to extend with new conversation types
- **W&B Integration**: Comprehensive logging and experiment tracking
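The tokenization step concatenates the prompt with each response and tokenizes the two sequences separately. The environment delegates this to `tokenize_for_trainer_dpo` with a Hugging Face tokenizer; the toy whitespace tokenizer below is purely hypothetical, just to show the shape of the output:

```python
def toy_tokenize(text: str, vocab: dict, max_length: int = 32) -> list:
    """Stand-in tokenizer: one id per whitespace token, truncated to max_length."""
    ids = [vocab.setdefault(tok, len(vocab)) for tok in text.lower().split()]
    return ids[:max_length]

def tokenize_dpo_pair(prompt: str, chosen: str, rejected: str) -> dict:
    """Tokenize prompt+chosen and prompt+rejected with a shared vocabulary."""
    vocab: dict = {}
    return {
        "chosen_input_ids": toy_tokenize(f"{prompt} {chosen}", vocab),
        "rejected_input_ids": toy_tokenize(f"{prompt} {rejected}", vocab),
    }

pair = tokenize_dpo_pair(
    "I'm feeling a bit down today.",
    "I'm sorry to hear that. What's one small thing you could do for yourself?",
    "Okay.",
)
# Both sequences start with the same prompt ids; only the response tails differ.
```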
**Training Applications**:
- Customer service AI improvement
- Therapeutic chatbot development
- Educational AI tutoring systems
- General conversational AI enhancement
- Empathy and engagement training
**Configuration Options**:
- `chosen_temperature`: Temperature for generating engaging responses (default: 0.7)
- `rejected_temperature`: Temperature for generating blunt responses (default: 0.4)
- `shuffle_dataset`: Whether to randomize training order
- `data_path_to_save_groups`: Optional path for saving training artifacts
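Using the defaults listed above, an override might look like the following. The real environment consumes these fields through its pydantic config class; a plain dict stands in here for illustration:

```python
# Illustrative override of the generation parameters listed above.
dpo_style_config = {
    "chosen_temperature": 0.7,    # warmer sampling for engaging responses
    "rejected_temperature": 0.4,  # cooler sampling for blunt responses
    "shuffle_dataset": True,      # randomize training order
    "data_path_to_save_groups": None,  # set a path to save .jsonl/.html artifacts
}
```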
**Data Artifacts**:
- Archived training examples and HTML visualizations available (see `conversational_style_dpo_artifacts.zip`)
- Ready for upload to Hugging Face for community access
**Requirements**: Standard Atropos dependencies, transformers, torch
---
## Support


@@ -1,7 +1,7 @@
import asyncio
import logging
import random
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple
from pydantic import Field
@@ -11,7 +11,7 @@ from atroposlib.envs.base import (
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.type_definitions import GameHistory, Item # Assuming GameHistory and Item are relevant
from atroposlib.type_definitions import Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer_dpo
logger = logging.getLogger(__name__)
@@ -35,7 +35,9 @@ class ConversationalStyleDPOEnv(BaseEnv):
def __init__(
self,
config: ConversationalStyleDPOEnvConfig,
server_configs: Optional[List[APIServerConfig]] = None, # server_configs might not be needed if we don't query a model
server_configs: Optional[
List[APIServerConfig]
] = None, # server_configs might not be needed if we don't query a model
slurm=True,
testing=False,
):
@@ -45,7 +47,9 @@ class ConversationalStyleDPOEnv(BaseEnv):
# If server_configs is None and BaseEnv requires it, initialize it as an empty list.
resolved_server_configs = server_configs if server_configs is not None else []
super().__init__(config, resolved_server_configs, slurm, testing)
self.config: ConversationalStyleDPOEnvConfig = config # Ensure type for self.config
self.config: ConversationalStyleDPOEnvConfig = (
config # Ensure type for self.config
)
self.dataset: List[Dict[str, str]] = []
self.iter: int = 0
@@ -57,27 +61,35 @@ class ConversationalStyleDPOEnv(BaseEnv):
self.synthetic_data = [
{
"prompt": "I'm feeling a bit down today.",
"chosen": "I'm sorry to hear that. Sometimes a little self-care can help. What's one small thing you could do for yourself right now?",
"chosen": "I'm sorry to hear that. Sometimes a little self-care can help. "
"What's one small thing you could do for yourself right now?",
"rejected": "Okay.",
},
{
"prompt": "Can you explain how photosynthesis works?",
"chosen": "Certainly! Photosynthesis is a fascinating process where plants use sunlight, water, and carbon dioxide to create their own food (glucose) and release oxygen. Think of it like a plant's kitchen!",
"chosen": "Certainly! Photosynthesis is a fascinating process where plants use "
"sunlight, water, and carbon dioxide to create their own food (glucose) and "
"release oxygen. Think of it like a plant's kitchen!",
"rejected": "Plants make food from light.",
},
{
"prompt": "I'm excited about my new project!",
"chosen": "That's fantastic news! Tell me more about it - what are you most looking forward to?",
"chosen": "That's fantastic news! Tell me more about it - "
"what are you most looking forward to?",
"rejected": "Good for you.",
},
{
"prompt": "What's the weather like?",
"chosen": "I can't check the real-time weather, but I hope it's pleasant where you are! If you'd like, you can tell me your location, and I could try to give you a general idea based on typical patterns, or help you find a weather service.",
"chosen": "I can't check the real-time weather, but I hope it's pleasant where "
"you are! If you'd like, you can tell me your location, and I could try to "
"give you a general idea based on typical patterns, or help you find a "
"weather service.",
"rejected": "I don't know.",
},
{
"prompt": "I'm having trouble understanding this concept.",
"chosen": "I can understand that some concepts can be tricky! Could you tell me which part is confusing you? Maybe we can break it down together.",
"chosen": "I can understand that some concepts can be tricky! Could you tell me "
"which part is confusing you? Maybe we can break it down together.",
"rejected": "Read the manual.",
},
]
@@ -98,36 +110,32 @@ class ConversationalStyleDPOEnv(BaseEnv):
The optional_extra_data will be the rejected response.
"""
if not self.dataset or self.iter >= len(self.dataset):
await self.setup() # Re-setup if dataset is exhausted or not loaded
if not self.dataset: # Still no dataset after setup
logger.error("Dataset is empty even after setup.")
# Return a fallback item or raise an error
# For now, let's create a dummy item to avoid crashing, but this should be handled
fallback_prompt = tuple([frozenset({"role": "user", "content": "Fallback prompt"}.items())])
return (fallback_prompt, "Fallback chosen", "Fallback rejected")
await self.setup() # Re-setup if dataset is exhausted or not loaded
if not self.dataset: # Still no dataset after setup
logger.error("Dataset is empty even after setup.")
# Return a fallback item or raise an error
# For now, let's create a dummy item to avoid crashing, but this should be handled
fallback_prompt = tuple(
[frozenset({"role": "user", "content": "Fallback prompt"}.items())]
)
return (fallback_prompt, "Fallback chosen", "Fallback rejected")
entry = self.dataset[self.iter % len(self.dataset)]
self.iter += 1
# Create a prompt tuple as expected by some parts of the system
# (e.g. if it were to be sent to a model, though we are not doing that here)
prompt_messages = [{"role": "user", "content": entry["prompt"]}]
# Convert to the frozenset structure if that's what downstream components expect
# For DPO, the direct strings are often more useful for tokenization with tokenize_for_trainer_dpo
# We will pass the raw strings to collect_trajectories and handle tokenization there.
# For Item, we can simplify or adjust based on how BaseEnv uses it.
# Let's pass the prompt string directly as the first element of the "prompt_tuple".
# The "gold_answer" will be the chosen response, and "extra_data" the rejected one.
# Adapting to Item: Tuple[prompt_tuple, gold_answer_str, any_extra_data]
# prompt_tuple: Tuple[FrozenSet[Tuple[str, str]], ...]
# For DPO, the core elements are prompt, chosen, rejected. We'll pass them directly.
return (entry["prompt"], entry["chosen"], entry["rejected"])
async def collect_trajectories(
self, item: Item
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
@@ -165,7 +173,7 @@ class ConversationalStyleDPOEnv(BaseEnv):
if not self.tokenizer:
logger.error("Tokenizer not available. Cannot process DPO pair.")
return None, []
try:
# Note: The actual structure of what `tokenize_for_trainer_dpo` returns
# and how it's used in `ScoredDataGroup` is crucial.
@@ -188,9 +196,9 @@
# It should handle the tokenization for DPO, creating chosen and rejected sequences.
tokenized_output = tokenize_for_trainer_dpo(
self.tokenizer,
dpo_pair_data, # Or (self.tokenizer, prompt_str, chosen_response_str, rejected_response_str)
# depending on its signature.
max_length=self.config.max_token_length, # Ensure this config is available
dpo_pair_data, # Or (self.tokenizer, prompt_str, chosen_response_str, rejected_response_str)
# depending on its signature.
max_length=self.config.max_token_length, # Ensure this config is available
# Add other necessary args for tokenize_for_trainer_dpo
)
@@ -201,9 +209,9 @@
# - chosen_input_ids, chosen_attention_mask, chosen_labels
# - rejected_input_ids, rejected_attention_mask, rejected_labels
# `tokenize_for_trainer_dpo` should produce these.
# Let's assume tokenize_for_trainer_dpo returns a dict with at least:
# 'chosen_input_ids', 'chosen_attention_mask',
# 'chosen_input_ids', 'chosen_attention_mask',
# 'rejected_input_ids', 'rejected_attention_mask'
# 'prompt_input_ids' (optional, might be part of chosen/rejected)
@@ -215,31 +223,35 @@
scores["rejected_masks"] = [tokenized_output["rejected_attention_mask"]]
# Optionally, if prompts are tokenized separately:
if "prompt_input_ids" in tokenized_output:
scores["prompt_tokens"] = [tokenized_output["prompt_input_ids"]]
scores["prompt_masks"] = [tokenized_output.get("prompt_attention_mask")] # Handle if mask isn't there
scores["prompt_tokens"] = [tokenized_output["prompt_input_ids"]]
scores["prompt_masks"] = [
tokenized_output.get("prompt_attention_mask")
] # Handle if mask isn't there
# DPO doesn't use a single "score" like reward models. The "reward" is implicit
# in the preference (chosen > rejected). So, the "scores" field in ScoredDataGroup
# might not be directly used or could be set to a placeholder if required by the trainer.
scores["scores"] = [1.0] # Placeholder, DPO loss handles preference directly
scores["scores"] = [
1.0
] # Placeholder, DPO loss handles preference directly
# Images are not used in this environment
scores["images"] = [None]
# Ensure group size logic if processing multiple items for a group
# For this example, we process one item at a time.
# If self.config.group_size > 1, you'd aggregate multiple tokenized_outputs
# before returning the ScoredDataGroup.
return scores, [] # No items to backlog
return scores, [] # No items to backlog
except Exception as e:
logger.error(f"Error in collect_trajectories during DPO processing: {e}")
import traceback
traceback.print_exc()
return None, []
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
"""
This method is typically for scoring model generations against a gold answer or reward model.
@@ -262,49 +274,54 @@
# If `collect_trajectories` returned a list of items to be scored/tokenized here.
# This would mean moving the tokenization logic from `collect_trajectories` to here.
# For now, let's assume the former.
logger.warning("`score` method received a list; expecting ScoredDataGroup for pre-processed DPO.")
logger.warning(
"`score` method received a list; expecting ScoredDataGroup for pre-processed DPO."
)
# Fallback: if you need to process a list of (prompt, chosen, rejected) tuples here:
# all_chosen_tokens = []
# ... and so on, then call tokenize_for_trainer_dpo for each item.
# This depends on the design of BaseEnv's main loop.
return None
logger.info("No data to score or data is already in ScoredDataGroup format.")
return None
async def evaluate(self, *args, **kwargs):
"""
Evaluation for DPO might involve comparing the DPO-trained model's preferences
against a held-out set of preferred/rejected pairs or other metrics.
For this basic environment, we'll skip custom evaluation.
"""
logger.info("Evaluation step called. No custom DPO evaluation implemented in this basic environment.")
logger.info(
"Evaluation step called. No custom DPO evaluation implemented in this basic environment."
)
# You could load a separate test set of (prompt, chosen, rejected)
# and see if the model assigns higher logprobs to chosen than rejected.
return None
@classmethod
def config_init(cls) -> Tuple[ConversationalStyleDPOEnvConfig, List[APIServerConfig]]:
def config_init(
cls,
) -> Tuple[ConversationalStyleDPOEnvConfig, List[APIServerConfig]]:
"""
Provides a default configuration for this environment.
"""
env_config = ConversationalStyleDPOEnvConfig(
wandb_name="conversational_style_dpo", # For logging if wandb is used
wandb_name="conversational_style_dpo", # For logging if wandb is used
tokenizer_name="gpt2", # Choose an appropriate tokenizer
group_size=4, # Number of DPO pairs to process in a "group" or "batch"
use_wandb=False, # Enable or disable wandb
max_num_workers=1, # Number of parallel workers for data collection (if applicable)
rollout_server_url="http://localhost:8000", # Corrected URL
total_steps=100, # Total DPO training steps (or epochs over the dataset)
batch_size=2, # DPO training batch size (distinct from group_size for data collection)
group_size=4, # Number of DPO pairs to process in a "group" or "batch"
use_wandb=False, # Enable or disable wandb
max_num_workers=1, # Number of parallel workers for data collection (if applicable)
rollout_server_url="http://localhost:8000", # Corrected URL
total_steps=100, # Total DPO training steps (or epochs over the dataset)
batch_size=2, # DPO training batch size (distinct from group_size for data collection)
steps_per_eval=50,
max_token_length=512, # Max length for tokenized sequences
dataset_name="synthetic_conversational_style",
shuffle_dataset=True,
)
server_configs = [] # Simplified as discussed
server_configs = [] # Simplified as discussed
return env_config, server_configs
@@ -314,27 +331,31 @@ if __name__ == "__main__":
# The `BaseEnv.cli()` method usually sets up and runs the environment's main loop.
# For DPO, the "main loop" might involve iterating through the dataset,
# tokenizing pairs, and perhaps logging them or yielding them to a trainer.
# To make this runnable and test the data processing:
async def main_test():
config, server_configs_list = ConversationalStyleDPOEnv.config_init()
# Manually override tokenizer for local testing if needed and not using a server
# that provides it, or if the default in BaseEnvConfig isn't what you want for DPO.
# config.tokenizer_name = "EleutherAI/pythia-70m" # Example
config.tokenizer_name = "distilgpt2" # A small, fast tokenizer for testing
config.group_size = 1 # Process one DPO item at a time for simplicity in this test
config.tokenizer_name = "distilgpt2" # A small, fast tokenizer for testing
config.group_size = (
1 # Process one DPO item at a time for simplicity in this test
)
config.use_wandb = False
env = ConversationalStyleDPOEnv(config=config, server_configs=server_configs_list, slurm=False, testing=True)
env = ConversationalStyleDPOEnv(
config=config, server_configs=server_configs_list, slurm=False, testing=True
)
# Initialize tokenizer (BaseEnv usually does this)
from transformers import AutoTokenizer
env.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
if env.tokenizer.pad_token is None:
env.tokenizer.pad_token = env.tokenizer.eos_token
print("Setting up environment...")
await env.setup()
print(f"Dataset size: {len(env.dataset)}")
@@ -344,7 +365,7 @@ if __name__ == "__main__":
return
print("Simulating DPO data processing for a few items...")
for i in range(min(len(env.dataset), 3)): # Test with a few items
for i in range(min(len(env.dataset), 3)): # Test with a few items
print(f"--- Item {i+1} ---")
item = await env.get_next_item()
if item:
@@ -359,45 +380,70 @@
# The `tokenize_for_trainer_dpo` function needs to exist and be importable.
# Let's create a placeholder for it here for the test to run.
global tokenize_for_trainer_dpo
def placeholder_tokenize_for_dpo(tokenizer, data, max_length, **kwargs):
# This is a simplified placeholder. A real one would handle complex tokenization,
# padding, truncation, and creating labels for DPO loss.
prompt = data["prompt"]
chosen = data["chosen"]
rejected = " ".join(data["rejected"].split()[:max_length//3]) # Basic truncation
rejected = " ".join(
data["rejected"].split()[: max_length // 3]
) # Basic truncation
chosen_text = f"{prompt} {chosen}"
rejected_text = f"{prompt} {rejected}"
chosen_tok = tokenizer(chosen_text, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt")
rejected_tok = tokenizer(rejected_text, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt")
chosen_tok = tokenizer(
chosen_text,
truncation=True,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
rejected_tok = tokenizer(
rejected_text,
truncation=True,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
# A real DPO tokenizer would also prepare labels, where only response tokens are unmasked.
# For simplicity, we're just returning input_ids and attention_mask.
return {
"chosen_input_ids": chosen_tok["input_ids"].squeeze().tolist(),
"chosen_attention_mask": chosen_tok["attention_mask"].squeeze().tolist(),
"rejected_input_ids": rejected_tok["input_ids"].squeeze().tolist(),
"rejected_attention_mask": rejected_tok["attention_mask"].squeeze().tolist(),
"chosen_attention_mask": chosen_tok["attention_mask"]
.squeeze()
.tolist(),
"rejected_input_ids": rejected_tok["input_ids"]
.squeeze()
.tolist(),
"rejected_attention_mask": rejected_tok["attention_mask"]
.squeeze()
.tolist(),
# "prompt_input_ids": ..., # Optionally
}
tokenize_for_trainer_dpo = placeholder_tokenize_for_dpo
tokenize_for_trainer_dpo = placeholder_tokenize_for_dpo
scored_data_group, _ = await env.collect_trajectories(item)
if scored_data_group:
print("Tokenized DPO Pair (first item in group):")
print(f" Chosen Tokens (IDs): {scored_data_group['chosen_tokens'][0][:20]}...") # Print first 20
print(
f" Chosen Tokens (IDs): {scored_data_group['chosen_tokens'][0][:20]}..."
) # Print first 20
# print(f" Chosen Masks: {scored_data_group['chosen_masks'][0][:20]}...")
print(f" Rejected Tokens (IDs): {scored_data_group['rejected_tokens'][0][:20]}...")
print(
f" Rejected Tokens (IDs): {scored_data_group['rejected_tokens'][0][:20]}..."
)
# print(f" Rejected Masks: {scored_data_group['rejected_masks'][0][:20]}...")
# print(f" Scores: {scored_data_group['scores']}")
else:
print(" Failed to process DPO pair.")
else:
print("Failed to get item.")
print("--- End of Test ---")
# To run the CLI, you would typically not need the main_test async function,
@@ -406,35 +452,11 @@ if __name__ == "__main__":
# If you want to run with the actual CLI:
# ConversationalStyleDPOEnv.cli()
# But ensure `tokenize_for_trainer_dpo` is correctly implemented and importable in that context.
# Running the test:
if __name__ == "__main__":
# Define a placeholder tokenize_for_trainer_dpo if it's not available globally
# This is often part of a utils library.
if "tokenize_for_trainer_dpo" not in globals():
def placeholder_tokenize_for_dpo(tokenizer, data, max_length, **kwargs):
prompt = data["prompt"]
chosen = data["chosen"]
# basic truncation for rejected to fit if too long with prompt
max_rej_tokens = max_length - len(tokenizer.encode(prompt)) - 10 # buffer
rejected_tokens = tokenizer.encode(data["rejected"])
if len(rejected_tokens) > max_rej_tokens:
rejected_tokens = rejected_tokens[:max_rej_tokens]
rejected = tokenizer.decode(rejected_tokens)
# Use the imported tokenize_for_trainer_dpo function
chosen_text = f"{prompt} {chosen}" # Simplistic combination
rejected_text = f"{prompt} {rejected}" # Simplistic combination
chosen_tok = tokenizer(chosen_text, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt")
rejected_tok = tokenizer(rejected_text, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt")
return {
"chosen_input_ids": chosen_tok["input_ids"].squeeze().tolist(),
"chosen_attention_mask": chosen_tok["attention_mask"].squeeze().tolist(),
"rejected_input_ids": rejected_tok["input_ids"].squeeze().tolist(),
"rejected_attention_mask": rejected_tok["attention_mask"].squeeze().tolist(),
}
tokenize_for_trainer_dpo = placeholder_tokenize_for_dpo
asyncio.run(main_test())
asyncio.run(main_test())


@@ -1,8 +1,7 @@
import asyncio
import ast # For safely evaluating the LLM's string output
import logging
import random
import ast # For safely evaluating the LLM's string output
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple
from pydantic import Field
@@ -15,7 +14,9 @@ from atroposlib.envs.base import (
from atroposlib.type_definitions import Item
try:
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer_dpo as imported_tokenize_for_trainer_dpo
from atroposlib.utils.tokenize_for_trainer import (
tokenize_for_trainer_dpo as imported_tokenize_for_trainer_dpo,
)
except ImportError:
imported_tokenize_for_trainer_dpo = None
@@ -24,11 +25,13 @@ logger.setLevel(logging.INFO)
# --- System Prompts for Generating Chosen/Rejected Responses ---
SYSTEM_PROMPT_CHOSEN = """
You are a helpful and engaging AI assistant. Your goal is to provide clear, empathetic, and insightful responses that encourage further conversation. Be positive and proactive.
You are a helpful and engaging AI assistant. Your goal is to provide clear, empathetic, and
insightful responses that encourage further conversation. Be positive and proactive.
"""
SYSTEM_PROMPT_REJECTED = """
You are an AI assistant. Provide very brief, blunt, and unhelpful responses. Do not elaborate or ask follow-up questions.
You are an AI assistant. Provide very brief, blunt, and unhelpful responses.
Do not elaborate or ask follow-up questions.
"""
# --- End System Prompts ---
@@ -37,7 +40,8 @@ PROMPT_GENERATION_MASTER_PROMPT = """
You are a creative assistant. Your task is to generate a list of 10 unique and random conversational prompts.
Each prompt should be suitable for starting a general conversation with an AI assistant.
The prompts should cover a variety of topics, tones (e.g., inquisitive, reflective, casual), and lengths.
Format your entire output as a single Python list of dictionaries, where each dictionary has a single key "prompt" and the value is the prompt string.
Format your entire output as a single Python list of dictionaries, where each dictionary has a
single key "prompt" and the value is the prompt string.
Example format:
[
@@ -49,22 +53,39 @@ Provide only the Python list of dictionaries, with no other surrounding text, ex
"""
# --- End Master Prompt ---
class GSM8KConversationalStyleDPOEnvConfig(BaseEnvConfig):
"""Config for GSM8KConversationalStyleDPOEnv."""
dataset_name: str = Field(
"synthetic_conversational_style_prompts_via_gsm8k_env",
description="Name of the dataset to use (source of prompts)."
"synthetic_conversational_style_prompts_via_gsm8k_env",
description="Name of the dataset to use (source of prompts).",
)
shuffle_dataset: bool = Field(
True, description="Whether to shuffle the dataset of prompts"
)
data_path_to_save_groups: Optional[str] = Field(
None, description="Path to save .jsonl and .html processed rollouts."
)
shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset of prompts")
data_path_to_save_groups: Optional[str] = Field(None, description="Path to save .jsonl and .html processed rollouts.")
# Generation parameters for chosen responses
chosen_temperature: float = Field(0.7, description="Temperature for generating chosen responses.")
chosen_temperature: float = Field(
0.7, description="Temperature for generating chosen responses."
)
chosen_max_tokens: int = Field(150, description="Max tokens for chosen responses.")
# Generation parameters for rejected responses
rejected_temperature: float = Field(0.4, description="Temperature for generating rejected responses.")
rejected_max_tokens: int = Field(50, description="Max tokens for rejected responses.")
prompt_generation_temperature: float = Field(0.8, description="Temperature for LLM generating the initial list of prompts.")
prompt_generation_max_tokens: int = Field(1000, description="Max tokens for LLM generating the initial list of prompts.")
rejected_temperature: float = Field(
0.4, description="Temperature for generating rejected responses."
)
rejected_max_tokens: int = Field(
50, description="Max tokens for rejected responses."
)
prompt_generation_temperature: float = Field(
0.8, description="Temperature for LLM generating the initial list of prompts."
)
prompt_generation_max_tokens: int = Field(
1000, description="Max tokens for LLM generating the initial list of prompts."
)
class GSM8KConversationalStyleDPOEnv(BaseEnv):
name = "gsm8k_dynamic_conversational_dpo"
@@ -81,15 +102,18 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
# The BaseEnv will select the first server_config if multiple are provided and split is not specified.
resolved_server_configs = server_configs
if not resolved_server_configs:
logger.warning(f"No server_configs provided for {self.name}, chat_completion calls will fail if not overridden in config_init or CLI.")
logger.warning(
f"No server_configs provided for {self.name}, chat_completion calls "
f"will fail if not overridden in config_init or CLI."
)
# You might want to provide a default dummy one if testing without a server is intended for some paths
# For this version, we expect it to be configured properly for LLM calls.
super().__init__(config, resolved_server_configs, slurm, testing)
self.config: GSM8KConversationalStyleDPOEnvConfig = config
self.prompt_dataset: List[Dict[str, str]] = [] # Stores only prompts now
self.prompt_dataset: List[Dict[str, str]] = [] # Stores only prompts now
self.iter: int = 0
if imported_tokenize_for_trainer_dpo is not None:
self._tokenize_dpo_fn = imported_tokenize_for_trainer_dpo
logger.info(f"Using imported tokenize_for_trainer_dpo for {self.name}")
@@ -101,10 +125,20 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
prompt = data["prompt"]
chosen = data["chosen"]
rejected = data["rejected"]
chosen_full_text = prompt + chosen
chosen_full_text = prompt + chosen
rejected_full_text = prompt + rejected
chosen_tokenized = tokenizer(chosen_full_text, truncation=True, padding="max_length", max_length=max_length)
rejected_tokenized = tokenizer(rejected_full_text, truncation=True, padding="max_length", max_length=max_length)
chosen_tokenized = tokenizer(
chosen_full_text,
truncation=True,
padding="max_length",
max_length=max_length,
)
rejected_tokenized = tokenizer(
rejected_full_text,
truncation=True,
padding="max_length",
max_length=max_length,
)
return {
"prompt": prompt,
"chosen": chosen,
@@ -117,60 +151,101 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
async def setup(self):
"""Load and prepare the dataset of prompts, potentially by generating them via LLM."""
generated_prompts = []
if self.server:
logger.info(f"Attempting to generate initial prompts for {self.name} using LLM...")
logger.info(
f"Attempting to generate initial prompts for {self.name} using LLM..."
)
try:
prompt_list_completion = await self.server.chat_completion(
messages=[
{"role": "system", "content": "You are a helpful assistant that strictly follows formatting instructions."},
{"role": "user", "content": PROMPT_GENERATION_MASTER_PROMPT}
{
"role": "system",
"content": "You are a helpful assistant that strictly follows formatting instructions.",
},
{"role": "user", "content": PROMPT_GENERATION_MASTER_PROMPT},
],
n=1,
max_tokens=self.config.prompt_generation_max_tokens,
temperature=self.config.prompt_generation_temperature,
)
response_text = prompt_list_completion.choices[0].message.content.strip() if prompt_list_completion.choices else ""
response_text = (
prompt_list_completion.choices[0].message.content.strip()
if prompt_list_completion.choices
else ""
)
logger.debug(f"LLM response for prompt generation: {response_text}")
# Attempt to parse the string as a Python list of dictionaries
try:
parsed_list = ast.literal_eval(response_text)
if isinstance(parsed_list, list) and all(isinstance(item, dict) and "prompt" in item for item in parsed_list):
if isinstance(parsed_list, list) and all(
isinstance(item, dict) and "prompt" in item
for item in parsed_list
):
generated_prompts = parsed_list
logger.info(f"Successfully generated and parsed {len(generated_prompts)} prompts from LLM.")
logger.info(
f"Successfully generated and parsed {len(generated_prompts)} prompts from LLM."
)
else:
logger.warning("LLM response for prompt generation was not a valid list of prompt dictionaries. Using fallback.")
logger.warning(
"LLM response for prompt generation was not a valid list of "
"prompt dictionaries. Using fallback."
)
except (SyntaxError, ValueError) as e:
logger.warning(f"Error parsing LLM response for prompt generation: {e}. Using fallback.")
logger.warning(
f"Error parsing LLM response for prompt generation: {e}. Using fallback."
)
except Exception as e:
logger.error(f"Error calling LLM for prompt generation: {e}. Using fallback.")
logger.error(
f"Error calling LLM for prompt generation: {e}. Using fallback."
)
else:
logger.warning(f"LLM server not available for {self.name}. Using fallback prompts.")
logger.warning(
f"LLM server not available for {self.name}. Using fallback prompts."
)
if not generated_prompts or not isinstance(generated_prompts, list) or len(generated_prompts) < 10:
if (
not generated_prompts
or not isinstance(generated_prompts, list)
or len(generated_prompts) < 10
):
logger.info(f"Using fallback static prompt list for {self.name}.")
generated_prompts = [
{"prompt": "What are your thoughts on the future of renewable energy?"},
{
"prompt": "If you could travel anywhere in the world, where would you go and why?"
},
{"prompt": "What's a book or movie that has deeply impacted you?"},
{
"prompt": "Can you describe a complex scientific concept in simple terms?"
},
{"prompt": "What's a common misconception people have about AI?"},
{
"prompt": "How do you think technology will change our daily lives in the next 20 years?"
},
{
"prompt": "What's a piece of advice you would give to someone learning a new skill?"
},
{
"prompt": "If you could have a conversation with any historical figure, who would it be?"
},
{"prompt": "What does 'creativity' mean to you as an AI?"},
{"prompt": "Can you tell me a joke?"},
]
if (
len(generated_prompts) > 10
): # Ensure we only use 10 if more are in fallback
generated_prompts = generated_prompts[:10]
self.prompt_dataset = generated_prompts
if self.config.shuffle_dataset:
random.shuffle(self.prompt_dataset)
self.iter = 0
logger.info(
f"Initialized prompt dataset with {len(self.prompt_dataset)} examples for {self.name}."
)
async def get_next_item(self) -> Item:
"""
@@ -178,16 +253,18 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
Chosen and rejected responses will be generated dynamically.
"""
if not self.prompt_dataset or self.iter >= len(self.prompt_dataset):
await self.setup() # This will now potentially call the LLM to generate prompts
if not self.prompt_dataset: # Should be populated by setup, even with fallback
logger.error(
f"Prompt dataset is STILL empty after setup in {self.name}. This is unexpected."
)
# Provide an absolute failsafe prompt
return ("Failsafe prompt: What is 1+1?", "", "")
entry = self.prompt_dataset[self.iter % len(self.prompt_dataset)]
self.iter += 1
return (entry["prompt"], "", "")
async def collect_trajectories(
self, items: List[Item] # Changed to accept a list of items
@@ -200,41 +277,59 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
if not items:
logger.warning("collect_trajectories received an empty list of items.")
return None, []
# item = items[0]
prompt_str, _, _ = items
if not self.server:
logger.error(
f"LLM server not configured or available for {self.name}. Cannot generate responses."
)
return None, []
if not self.tokenizer:
logger.error(
f"Tokenizer not available in {self.name}. Attempting to initialize."
)
if (
self.testing
and hasattr(self.config, "tokenizer_name")
and self.config.tokenizer_name
):
try:
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
self.config.tokenizer_name
)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
logger.info(
f"Fallback tokenizer '{self.config.tokenizer_name}' initialized."
)
except Exception as e:
logger.error(f"Failed to initialize fallback tokenizer: {e}")
return None, []
else:
return None, []
try:
# Generate Chosen Response
logger.debug(f"Generating CHOSEN for prompt: {prompt_str[:100]}...")
chosen_completion = await self.server.chat_completion(
messages=[
{"role": "system", "content": SYSTEM_PROMPT_CHOSEN},
{"role": "user", "content": prompt_str},
],
n=1,
max_tokens=self.config.chosen_max_tokens,
temperature=self.config.chosen_temperature,
)
chosen_response_str = (
chosen_completion.choices[0].message.content.strip()
if chosen_completion.choices
else ""
)
logger.debug(f"Generated CHOSEN: {chosen_response_str[:100]}...")
# Generate Rejected Response
@@ -242,18 +337,24 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
rejected_completion = await self.server.chat_completion(
messages=[
{"role": "system", "content": SYSTEM_PROMPT_REJECTED},
{"role": "user", "content": prompt_str},
],
n=1,
max_tokens=self.config.rejected_max_tokens,
temperature=self.config.rejected_temperature,
)
rejected_response_str = (
rejected_completion.choices[0].message.content.strip()
if rejected_completion.choices
else ""
)
logger.debug(f"Generated REJECTED: {rejected_response_str[:100]}...")
if not chosen_response_str or not rejected_response_str:
logger.warning(
f"Failed to generate valid chosen or rejected response for prompt: {prompt_str}"
)
return None, []
dpo_pair_data = {
"prompt": prompt_str,
@@ -263,7 +364,7 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
tokenized_output = self._tokenize_dpo_fn(
self.tokenizer,
dpo_pair_data,
max_length=self.config.max_token_length,
)
@@ -274,44 +375,55 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
scores_group["chosen_tokens"] = [tokenized_output["chosen_input_ids"]]
scores_group["chosen_masks"] = [tokenized_output["chosen_attention_mask"]]
scores_group["rejected_tokens"] = [tokenized_output["rejected_input_ids"]]
scores_group["rejected_masks"] = [
tokenized_output["rejected_attention_mask"]
]
scores_group["tokens"] = [tokenized_output["chosen_input_ids"]]
scores_group["masks"] = [tokenized_output["chosen_attention_mask"]]
scores_group["scores"] = [1.0]
scores_group["images"] = [None]
scores_group["group_overrides"] = {"group_size": 1}
return scores_group, []
except Exception as e:
logger.error(
f"Error in collect_trajectories for {self.name} during DPO processing: {e}"
)
import traceback
traceback.print_exc()
return None, []
async def score(self, rollout_group_data: Any) -> Optional[ScoredDataGroup]:
if rollout_group_data and isinstance(rollout_group_data, ScoredDataGroup):
return rollout_group_data
logger.info(
f"Data for {self.name} is not in ScoredDataGroup format or no data to score."
)
return None
async def evaluate(self, *args, **kwargs):
logger.info(
f"Evaluation step called for {self.name}. No custom DPO evaluation implemented."
)
return None
@classmethod
def config_init(
cls,
) -> Tuple[GSM8KConversationalStyleDPOEnvConfig, List[APIServerConfig]]:
env_config = GSM8KConversationalStyleDPOEnvConfig(
wandb_name="gsm8k_dynamic_conversational_dpo",
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
group_size=1,
use_wandb=True,
max_num_workers=1,
rollout_server_url="http://localhost:8000",
total_steps=100,
batch_size=2,
steps_per_eval=50,
max_token_length=512,
dataset_name="synthetic_conversational_style_prompts_via_gsm8k_env",
shuffle_dataset=True,
data_path_to_save_groups=None,
@@ -327,11 +439,12 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
APIServerConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
base_url="https://inference-api.nousresearch.com/v1",
api_key="sk-3DvYKMv_-BfAoDSTfdSvEQ",
num_requests_for_eval=256, # Copied from gsm8k_server.py
)
]
return env_config, server_configs
if __name__ == "__main__":
GSM8KConversationalStyleDPOEnv.cli()
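The prompt-generation path above parses the LLM's reply with `ast.literal_eval` and falls back to a static list when the reply is not a valid list of prompt dictionaries. A minimal, self-contained sketch of that parse-with-fallback pattern (the function name and `FALLBACK_PROMPTS` constant are illustrative, not names from the environment):

```python
import ast
import logging

logger = logging.getLogger(__name__)

# Hypothetical stand-in for the environment's static fallback prompt list.
FALLBACK_PROMPTS = [{"prompt": "Can you tell me a joke?"}]


def parse_prompt_list(response_text: str) -> list:
    """Parse an LLM response expected to be a Python-literal list of
    {"prompt": ...} dicts; return the static fallback on any failure."""
    try:
        parsed = ast.literal_eval(response_text)
        if isinstance(parsed, list) and all(
            isinstance(item, dict) and "prompt" in item for item in parsed
        ):
            return parsed
        logger.warning("Response was not a list of prompt dicts. Using fallback.")
    except (SyntaxError, ValueError) as e:
        # literal_eval raises these for malformed or non-literal input.
        logger.warning(f"Error parsing response: {e}. Using fallback.")
    return FALLBACK_PROMPTS
```

Because `ast.literal_eval` only evaluates literals, it cannot execute arbitrary code from the model's reply, which is why it is preferred over `eval` here.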

View file

@@ -4,7 +4,7 @@ pip install accelerate
pip install pydantic
pip install bitsandbytes
pip install datasets
pip install transformers
pip install pydantic
pip install torch
pip install -e .[all] # for everything

View file

@@ -1,10 +1,8 @@
import asyncio
import os
import logging
from dataclasses import dataclass, field
from typing import Optional
import torch
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
@@ -36,11 +34,11 @@ class ScriptArguments:
)
tokenizer_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "The tokenizer name or path. Defaults to model_name_or_path."
},
)
beta: float = field(default=0.1, metadata={"help": "The beta factor in DPO loss."})
max_prompt_length: int = field(
default=256, metadata={"help": "Max length for prompts."}
)
@@ -75,14 +73,14 @@ async def get_dataset_from_env(
# The DPOTrainer expects columns named "prompt", "chosen", "rejected"
# The synthetic_data in your environment is already in this format.
# Example: [{"prompt": "...", "chosen": "...", "rejected": "..."}]
hf_dataset = Dataset.from_list(list(env.dataset)) # Ensure it's a fresh list copy
# Log a sample to verify
if len(hf_dataset) > 0:
logger.info(f"Sample from dataset: {hf_dataset[0]}")
else:
logger.warning("Dataset created from environment is empty!")
return hf_dataset
@@ -92,12 +90,14 @@ def main():
if script_args.tokenizer_name_or_path is None:
script_args.tokenizer_name_or_path = script_args.model_name_or_path
# --- 1. Initialize Environment and Get Dataset ---
logger.info("Initializing environment to get dataset...")
# Use the default config from your environment, but ensure tokenizer matches
env_dpo_config, _ = ConversationalStyleDPOEnv.config_init()
env_dpo_config.tokenizer_name = (
script_args.tokenizer_name_or_path
) # Align tokenizer
# You might want to adjust other env_dpo_config parameters if needed
# Run the async function to get the dataset
@@ -127,17 +127,18 @@ def main():
# For simplicity, we'll let DPOTrainer handle creating the reference model by not passing one.
# If you wanted to load a different SFT model as reference, you would do:
# model_ref = AutoModelForCausalLM.from_pretrained(...)
model_ref = None
logger.info(
"Reference model will be a copy of the policy model (handled by DPOTrainer)."
)
# --- 3. Set up Training Arguments ---
# Default DPO training arguments. You might want to customize these.
if training_args.output_dir == "output_dir": # Default value from TrainingArguments
# Output directory will be relative to the script's current working directory when run.
# If run from environments/hack0/conversational_style_dpo/, it will be ./dpo_conversational_trainer_results
training_args.output_dir = "./dpo_conversational_trainer_results"
# training_args.per_device_train_batch_size = 2 # Adjust as needed
# training_args.num_train_epochs = 1 # Keep low for a quick test
# training_args.gradient_accumulation_steps = 1
@@ -145,21 +146,20 @@ def main():
# training_args.logging_steps = 10
# training_args.save_steps = 50
# training_args.report_to = "none" # "wandb" or "tensorboard" if you want to log
logger.info(f"Training Arguments: {training_args}")
# --- 4. Initialize DPOTrainer ---
logger.info("Initializing DPOTrainer...")
dpo_trainer = DPOTrainer(
model=model,
ref_model=model_ref, # If None, a copy of model is made
args=training_args,
beta=script_args.beta,
train_dataset=dataset,
tokenizer=tokenizer,
max_prompt_length=script_args.max_prompt_length,
max_length=script_args.max_length, # Max length of prompt + response
# peft_config=peft_config, # If using PEFT/LoRA
)
logger.info("DPOTrainer initialized.")
@@ -170,8 +170,8 @@ def main():
logger.info("DPO training completed.")
# --- 6. Save the Model (Optional) ---
if training_args.should_save: # Checks if any save_strategy is enabled
output_save_dir = training_args.output_dir
logger.info(f"Saving model to {output_save_dir}")
dpo_trainer.save_model(output_save_dir)
logger.info("Model saved.")
@@ -182,4 +182,4 @@ def main():
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()
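The trainer script above hands DPOTrainer a dataset with `prompt`/`chosen`/`rejected` columns, while the environment tokenizes each pair into chosen/rejected input ids and attention masks. A rough shape-only sketch of that pair tokenization, using a toy character-level tokenizer in place of the real Hugging Face tokenizer (all names here are illustrative):

```python
def tokenize_dpo_pair(tokenize, pair, max_length=512):
    """Tokenize one DPO pair into chosen/rejected ids and masks.

    `tokenize` maps text -> list[int]; the environment uses an HF
    tokenizer, so this is only a sketch of the output structure.
    """
    out = {}
    for key in ("chosen", "rejected"):
        # Prompt and response are concatenated, then truncated to max_length.
        ids = tokenize(pair["prompt"] + "\n" + pair[key])[:max_length]
        out[f"{key}_input_ids"] = ids
        out[f"{key}_attention_mask"] = [1] * len(ids)
    return out


# Toy tokenizer: one token per character.
toy_tokenize = lambda text: [ord(c) for c in text]
pair = {"prompt": "Hi", "chosen": "Hello there!", "rejected": "No."}
result = tokenize_dpo_pair(toy_tokenize, pair)
```

The four resulting fields mirror the `chosen_input_ids` / `chosen_attention_mask` / `rejected_input_ids` / `rejected_attention_mask` keys the environment copies into its `ScoredDataGroup`.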

View file

@@ -1,153 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Rendered Messages - gsm8k_dpo_rollouts_1.jsonl</title>
<style>
body { font-family: sans-serif; line-height: 1.6; margin: 20px; }
details { border: 1px solid #ccc; border-radius: 4px; margin-bottom: 15px; }
summary {
font-weight: bold;
padding: 10px;
background-color: #f0f0f0;
cursor: pointer;
border-radius: 4px 4px 0 0;
border-bottom: 1px solid #ccc;
outline: none; /* Remove default focus outline if needed */
}
details[open] summary { border-bottom: 1px solid #ccc; }
.group-content { padding: 15px; }
.item {
border: 1px solid #eee;
border-radius: 3px;
margin-bottom: 10px;
padding: 10px;
transition: background-color 0.3s ease, box-shadow 0.2s ease; /* Smooth transitions */
scroll-margin-top: 10px; /* Space when scrolling into view */
}
.item h4 { margin-top: 0; margin-bottom: 5px; font-size: 1.1em; }
.content-block { background-color: #fff; padding: 8px; border-radius: 3px; margin-bottom: 5px; overflow-x: auto; }
/* Use :focus-within for better accessibility on container focus */
.item:focus, .item.active {
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.5); /* Highlight active/focused item */
outline: none; /* Remove default outline */
}
/* Score Backgrounds (Faint) */
.reward-positive { background-color: rgba(144, 238, 144, 0.3); } /* Faint light green */
.reward-zero { background-color: rgba(255, 215, 0, 0.3); } /* Faint gold/orange */
.reward-negative { background-color: rgba(255, 182, 193, 0.4); } /* Faint light pink/red */
/* Markdown specific styles */
.content-block pre {
background-color: #f5f5f5;
border: 1px solid #ddd;
border-radius: 3px;
padding: 10px;
overflow-x: auto; /* Allow horizontal scrolling for long code lines */
white-space: pre-wrap; /* Wrap long lines within pre */
word-wrap: break-word; /* Break long words if necessary */
}
.content-block code {
background-color: #f0f0f0; /* Slightly different for inline code */
padding: 0.2em 0.4em;
border-radius: 3px;
font-size: 0.9em;
}
.content-block pre code {
background-color: transparent; /* Don't double-background code in pre blocks */
padding: 0;
border-radius: 0;
font-size: inherit; /* Inherit pre font size */
}
.content-block blockquote {
border-left: 4px solid #ccc;
padding-left: 10px;
margin-left: 0;
color: #555;
}
.content-block table {
border-collapse: collapse;
width: auto; /* Don't force full width */
margin-bottom: 1em;
}
.content-block th, .content-block td {
border: 1px solid #ddd;
padding: 8px;
text-align: left;
}
.content-block th {
background-color: #f2f2f2;
}
</style>
</head>
<body>
<h1>Rendered Messages - gsm8k_dpo_rollouts_1.jsonl</h1>
<div id="groups-container">
<p>No data to display. Input file might be empty or contain invalid data.</p>
</div>
<script>
document.addEventListener('DOMContentLoaded', () => {
const items = document.querySelectorAll('.item');
let activeIndex = -1; // No item active initially
// Function to set active item
function setActiveItem(index) {
if (activeIndex >= 0 && activeIndex < items.length) {
items[activeIndex].classList.remove('active');
items[activeIndex].removeAttribute('tabindex'); // Remove from tab order when not active
}
if (index >= 0 && index < items.length) {
items[index].classList.add('active');
items[index].setAttribute('tabindex', '0'); // Make active item focusable
items[index].focus(); // Focus the element
// Ensure the parent <details> is open
const detailsParent = items[index].closest('details');
if (detailsParent && !detailsParent.open) {
detailsParent.open = true;
}
// Scroll into view with options if needed (focus should handle this mostly)
// items[index].scrollIntoView({ behavior: 'smooth', block: 'nearest' });
activeIndex = index;
} else {
activeIndex = -1; // Deactivate if index is out of bounds
}
}
// Add click listener to activate items
items.forEach((item, index) => {
item.addEventListener('click', () => {
setActiveItem(index);
});
// Make items focusable initially only if we want tab navigation
// item.setAttribute('tabindex', '0');
});
// Add keydown listener for arrow navigation
document.addEventListener('keydown', (event) => {
let targetIndex = -1;
if (event.key === 'ArrowDown') {
event.preventDefault(); // Prevent default page scroll
targetIndex = (activeIndex === -1) ? 0 : Math.min(activeIndex + 1, items.length - 1);
} else if (event.key === 'ArrowUp') {
event.preventDefault(); // Prevent default page scroll
targetIndex = (activeIndex === -1) ? items.length - 1 : Math.max(activeIndex - 1, 0);
}
if (targetIndex !== -1) {
setActiveItem(targetIndex);
}
});
// Make first item focusable initially if you want immediate keyboard nav
if (items.length > 0) {
// items[0].setAttribute('tabindex', '0');
// Optionally activate the first item on load:
// setActiveItem(0);
}
});
</script>
</body>
</html>

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@@ -1,153 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Rendered Messages - gsm8k_dpo_rollouts_2.jsonl</title>
<style>
body { font-family: sans-serif; line-height: 1.6; margin: 20px; }
details { border: 1px solid #ccc; border-radius: 4px; margin-bottom: 15px; }
summary {
font-weight: bold;
padding: 10px;
background-color: #f0f0f0;
cursor: pointer;
border-radius: 4px 4px 0 0;
border-bottom: 1px solid #ccc;
outline: none; /* Remove default focus outline if needed */
}
details[open] summary { border-bottom: 1px solid #ccc; }
.group-content { padding: 15px; }
.item {
border: 1px solid #eee;
border-radius: 3px;
margin-bottom: 10px;
padding: 10px;
transition: background-color 0.3s ease, box-shadow 0.2s ease; /* Smooth transitions */
scroll-margin-top: 10px; /* Space when scrolling into view */
}
.item h4 { margin-top: 0; margin-bottom: 5px; font-size: 1.1em; }
.content-block { background-color: #fff; padding: 8px; border-radius: 3px; margin-bottom: 5px; overflow-x: auto; }
/* Use :focus-within for better accessibility on container focus */
.item:focus, .item.active {
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.5); /* Highlight active/focused item */
outline: none; /* Remove default outline */
}
/* Score Backgrounds (Faint) */
.reward-positive { background-color: rgba(144, 238, 144, 0.3); } /* Faint light green */
.reward-zero { background-color: rgba(255, 215, 0, 0.3); } /* Faint gold/orange */
.reward-negative { background-color: rgba(255, 182, 193, 0.4); } /* Faint light pink/red */
/* Markdown specific styles */
.content-block pre {
background-color: #f5f5f5;
border: 1px solid #ddd;
border-radius: 3px;
padding: 10px;
overflow-x: auto; /* Allow horizontal scrolling for long code lines */
white-space: pre-wrap; /* Wrap long lines within pre */
word-wrap: break-word; /* Break long words if necessary */
}
.content-block code {
background-color: #f0f0f0; /* Slightly different for inline code */
padding: 0.2em 0.4em;
border-radius: 3px;
font-size: 0.9em;
}
.content-block pre code {
background-color: transparent; /* Don't double-background code in pre blocks */
padding: 0;
border-radius: 0;
font-size: inherit; /* Inherit pre font size */
}
.content-block blockquote {
border-left: 4px solid #ccc;
padding-left: 10px;
margin-left: 0;
color: #555;
}
.content-block table {
border-collapse: collapse;
width: auto; /* Don't force full width */
margin-bottom: 1em;
}
.content-block th, .content-block td {
border: 1px solid #ddd;
padding: 8px;
text-align: left;
}
.content-block th {
background-color: #f2f2f2;
}
</style>
</head>
<body>
<h1>Rendered Messages - gsm8k_dpo_rollouts_2.jsonl</h1>
<div id="groups-container">
<p>No data to display. Input file might be empty or contain invalid data.</p>
</div>
<script>
document.addEventListener('DOMContentLoaded', () => {
const items = document.querySelectorAll('.item');
let activeIndex = -1; // No item active initially
// Function to set active item
function setActiveItem(index) {
if (activeIndex >= 0 && activeIndex < items.length) {
items[activeIndex].classList.remove('active');
items[activeIndex].removeAttribute('tabindex'); // Remove from tab order when not active
}
if (index >= 0 && index < items.length) {
items[index].classList.add('active');
items[index].setAttribute('tabindex', '0'); // Make active item focusable
items[index].focus(); // Focus the element
// Ensure the parent <details> is open
const detailsParent = items[index].closest('details');
if (detailsParent && !detailsParent.open) {
detailsParent.open = true;
}
// Scroll into view with options if needed (focus should handle this mostly)
// items[index].scrollIntoView({ behavior: 'smooth', block: 'nearest' });
activeIndex = index;
} else {
activeIndex = -1; // Deactivate if index is out of bounds
}
}
// Add click listener to activate items
items.forEach((item, index) => {
item.addEventListener('click', () => {
setActiveItem(index);
});
// Make items focusable initially only if we want tab navigation
// item.setAttribute('tabindex', '0');
});
// Add keydown listener for arrow navigation
document.addEventListener('keydown', (event) => {
let targetIndex = -1;
if (event.key === 'ArrowDown') {
event.preventDefault(); // Prevent default page scroll
targetIndex = (activeIndex === -1) ? 0 : Math.min(activeIndex + 1, items.length - 1);
} else if (event.key === 'ArrowUp') {
event.preventDefault(); // Prevent default page scroll
targetIndex = (activeIndex === -1) ? items.length - 1 : Math.max(activeIndex - 1, 0);
}
if (targetIndex !== -1) {
setActiveItem(targetIndex);
}
});
// Make first item focusable initially if you want immediate keyboard nav
if (items.length > 0) {
// items[0].setAttribute('tabindex', '0');
// Optionally activate the first item on load:
// setActiveItem(0);
}
});
</script>
</body>
</html>

View file

@@ -1,153 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Rendered Messages - gsm8k_dpo_rollouts_4.jsonl</title>
<style>
body { font-family: sans-serif; line-height: 1.6; margin: 20px; }
details { border: 1px solid #ccc; border-radius: 4px; margin-bottom: 15px; }
summary {
font-weight: bold;
padding: 10px;
background-color: #f0f0f0;
cursor: pointer;
border-radius: 4px 4px 0 0;
border-bottom: 1px solid #ccc;
outline: none; /* Remove default focus outline if needed */
}
details[open] summary { border-bottom: 1px solid #ccc; }
.group-content { padding: 15px; }
.item {
border: 1px solid #eee;
border-radius: 3px;
margin-bottom: 10px;
padding: 10px;
transition: background-color 0.3s ease, box-shadow 0.2s ease; /* Smooth transitions */
scroll-margin-top: 10px; /* Space when scrolling into view */
}
.item h4 { margin-top: 0; margin-bottom: 5px; font-size: 1.1em; }
.content-block { background-color: #fff; padding: 8px; border-radius: 3px; margin-bottom: 5px; overflow-x: auto; }
/* Use :focus-within for better accessibility on container focus */
.item:focus, .item.active {
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.5); /* Highlight active/focused item */
outline: none; /* Remove default outline */
}
/* Score Backgrounds (Faint) */
.reward-positive { background-color: rgba(144, 238, 144, 0.3); } /* Faint light green */
.reward-zero { background-color: rgba(255, 215, 0, 0.3); } /* Faint gold/orange */
.reward-negative { background-color: rgba(255, 182, 193, 0.4); } /* Faint light pink/red */
/* Markdown specific styles */
.content-block pre {
background-color: #f5f5f5;
border: 1px solid #ddd;
border-radius: 3px;
padding: 10px;
overflow-x: auto; /* Allow horizontal scrolling for long code lines */
white-space: pre-wrap; /* Wrap long lines within pre */
word-wrap: break-word; /* Break long words if necessary */
}
.content-block code {
background-color: #f0f0f0; /* Slightly different for inline code */
padding: 0.2em 0.4em;
border-radius: 3px;
font-size: 0.9em;
}
.content-block pre code {
background-color: transparent; /* Don't double-background code in pre blocks */
padding: 0;
border-radius: 0;
font-size: inherit; /* Inherit pre font size */
}
.content-block blockquote {
border-left: 4px solid #ccc;
padding-left: 10px;
margin-left: 0;
color: #555;
}
.content-block table {
border-collapse: collapse;
width: auto; /* Don't force full width */
margin-bottom: 1em;
}
.content-block th, .content-block td {
border: 1px solid #ddd;
padding: 8px;
text-align: left;
}
.content-block th {
background-color: #f2f2f2;
}
</style>
</head>
<body>
<h1>Rendered Messages - gsm8k_dpo_rollouts_4.jsonl</h1>
<div id="groups-container">
<p>No data to display. Input file might be empty or contain invalid data.</p>
</div>
<script>
document.addEventListener('DOMContentLoaded', () => {
const items = document.querySelectorAll('.item');
let activeIndex = -1; // No item active initially
// Function to set active item
function setActiveItem(index) {
if (activeIndex >= 0 && activeIndex < items.length) {
items[activeIndex].classList.remove('active');
items[activeIndex].removeAttribute('tabindex'); // Remove from tab order when not active
}
if (index >= 0 && index < items.length) {
items[index].classList.add('active');
items[index].setAttribute('tabindex', '0'); // Make active item focusable
items[index].focus(); // Focus the element
// Ensure the parent <details> is open
const detailsParent = items[index].closest('details');
if (detailsParent && !detailsParent.open) {
detailsParent.open = true;
}
// Scroll into view with options if needed (focus should handle this mostly)
// items[index].scrollIntoView({ behavior: 'smooth', block: 'nearest' });
activeIndex = index;
} else {
activeIndex = -1; // Deactivate if index is out of bounds
}
}
// Add click listener to activate items
items.forEach((item, index) => {
item.addEventListener('click', () => {
setActiveItem(index);
});
// Make items focusable initially only if we want tab navigation
// item.setAttribute('tabindex', '0');
});
// Add keydown listener for arrow navigation
document.addEventListener('keydown', (event) => {
let targetIndex = -1;
if (event.key === 'ArrowDown') {
event.preventDefault(); // Prevent default page scroll
targetIndex = (activeIndex === -1) ? 0 : Math.min(activeIndex + 1, items.length - 1);
} else if (event.key === 'ArrowUp') {
event.preventDefault(); // Prevent default page scroll
targetIndex = (activeIndex === -1) ? items.length - 1 : Math.max(activeIndex - 1, 0);
}
if (targetIndex !== -1) {
setActiveItem(targetIndex);
}
});
// Make first item focusable initially if you want immediate keyboard nav
if (items.length > 0) {
// items[0].setAttribute('tabindex', '0');
// Optionally activate the first item on load:
// setActiveItem(0);
}
});
</script>
</body>
</html>

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long