mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Fix final code quality issues in Conversational Style DPO environment
This commit is contained in:
parent
441fd1036d
commit
d789128f20
31 changed files with 408 additions and 1671 deletions
BIN
environments/community/.DS_Store
vendored
Normal file
BIN
environments/community/.DS_Store
vendored
Normal file
Binary file not shown.
|
|
@ -440,6 +440,74 @@ A game environment that teaches LLMs strategic thinking and decision-making thro
|
|||
|
||||
**Requirements**: poke-env, nodejs, pokemon-showdown simulator, OpenAI API
|
||||
|
||||
### 13. Conversational Style DPO Environment (`conversational_style_dpo/`)
|
||||
**Author**: [Karthik-Ragunath](https://github.com/Karthik-Ragunath)
|
||||
**Purpose**: Train LLMs for better conversational style through Direct Preference Optimization (DPO) with chosen/rejected response pairs
|
||||
|
||||
A specialized environment for training LLMs to adopt more engaging, empathetic, and helpful conversational styles using Direct Preference Optimization. The environment provides both synthetic and dynamically generated conversation pairs where "chosen" responses are engaging and thoughtful while "rejected" responses are blunt and unhelpful.
|
||||
|
||||
**Features**:
|
||||
- **Two Environment Variants**: Static synthetic data and dynamic prompt generation
|
||||
- **DPO Training Ready**: Pre-configured tokenization for chosen/rejected response pairs
|
||||
- **Conversational Style Modeling**: Focus on empathy, engagement, and helpfulness
|
||||
- **Synthetic Data Generation**: Uses LLMs to create diverse conversational prompts
|
||||
- **Quality Response Pairs**: Carefully crafted chosen (good) vs rejected (poor) examples
|
||||
|
||||
**Environment Variants**:
|
||||
|
||||
1. **Static Synthetic Environment** (`conversational_style_dpo_env.py`):
|
||||
- Pre-defined conversational prompts with human-crafted response pairs
|
||||
- Focus on emotional support, explanations, excitement sharing, and help-seeking
|
||||
- Immediate training readiness without LLM dependencies
|
||||
|
||||
2. **Dynamic GSM8K-Style Environment** (`gsmk8k_conversational_style_dpo_env.py`):
|
||||
- LLM-generated conversational prompts for diverse training data
|
||||
- Real-time chosen/rejected response generation with different system prompts
|
||||
- Scalable dataset creation with fallback to static prompts
|
||||
|
||||
**Conversation Categories**:
|
||||
- **Emotional Support**: Responding to feelings and personal sharing
|
||||
- **Educational**: Explaining concepts clearly and engagingly
|
||||
- **Enthusiasm Sharing**: Celebrating user excitement and interests
|
||||
- **Help & Guidance**: Providing assistance with understanding problems
|
||||
- **General Conversation**: Weather, casual topics, and everyday interactions
|
||||
|
||||
**Response Quality Characteristics**:
|
||||
- **Chosen Responses**: Empathetic, engaging, ask follow-up questions, provide detailed explanations
|
||||
- **Rejected Responses**: Blunt, minimal, dismissive, unhelpful
|
||||
|
||||
**Example Training Pair**:
|
||||
```
|
||||
Prompt: "I'm feeling a bit down today."
|
||||
Chosen: "I'm sorry to hear that. Sometimes a little self-care can help. What's one small thing you could do for yourself right now?"
|
||||
Rejected: "Okay."
|
||||
```
|
||||
|
||||
**Technical Implementation**:
|
||||
- **DPO Tokenization**: Ready-to-use tokenization for preference optimization
|
||||
- **Configurable Parameters**: Temperature, max tokens, and dataset size controls
|
||||
- **Modular Design**: Easy to extend with new conversation types
|
||||
- **W&B Integration**: Comprehensive logging and experiment tracking
|
||||
|
||||
**Training Applications**:
|
||||
- Customer service AI improvement
|
||||
- Therapeutic chatbot development
|
||||
- Educational AI tutoring systems
|
||||
- General conversational AI enhancement
|
||||
- Empathy and engagement training
|
||||
|
||||
**Configuration Options**:
|
||||
- `chosen_temperature`: Temperature for generating engaging responses (default: 0.7)
|
||||
- `rejected_temperature`: Temperature for generating blunt responses (default: 0.4)
|
||||
- `shuffle_dataset`: Whether to randomize training order
|
||||
- `data_path_to_save_groups`: Optional path for saving training artifacts
|
||||
|
||||
**Data Artifacts**:
|
||||
- Archived training examples and HTML visualizations available (see `conversational_style_dpo_artifacts.zip`)
|
||||
- Ready for upload to Hugging Face for community access
|
||||
|
||||
**Requirements**: Standard Atropos dependencies, transformers, torch
|
||||
|
||||
---
|
||||
|
||||
## Support
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
|
|
@ -11,7 +11,7 @@ from atroposlib.envs.base import (
|
|||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.type_definitions import GameHistory, Item # Assuming GameHistory and Item are relevant
|
||||
from atroposlib.type_definitions import Item
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer_dpo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -35,7 +35,9 @@ class ConversationalStyleDPOEnv(BaseEnv):
|
|||
def __init__(
|
||||
self,
|
||||
config: ConversationalStyleDPOEnvConfig,
|
||||
server_configs: Optional[List[APIServerConfig]] = None, # server_configs might not be needed if we don't query a model
|
||||
server_configs: Optional[
|
||||
List[APIServerConfig]
|
||||
] = None, # server_configs might not be needed if we don't query a model
|
||||
slurm=True,
|
||||
testing=False,
|
||||
):
|
||||
|
|
@ -45,7 +47,9 @@ class ConversationalStyleDPOEnv(BaseEnv):
|
|||
# If server_configs is None and BaseEnv requires it, initialize it as an empty list.
|
||||
resolved_server_configs = server_configs if server_configs is not None else []
|
||||
super().__init__(config, resolved_server_configs, slurm, testing)
|
||||
self.config: ConversationalStyleDPOEnvConfig = config # Ensure type for self.config
|
||||
self.config: ConversationalStyleDPOEnvConfig = (
|
||||
config # Ensure type for self.config
|
||||
)
|
||||
self.dataset: List[Dict[str, str]] = []
|
||||
self.iter: int = 0
|
||||
|
||||
|
|
@ -57,27 +61,35 @@ class ConversationalStyleDPOEnv(BaseEnv):
|
|||
self.synthetic_data = [
|
||||
{
|
||||
"prompt": "I'm feeling a bit down today.",
|
||||
"chosen": "I'm sorry to hear that. Sometimes a little self-care can help. What's one small thing you could do for yourself right now?",
|
||||
"chosen": "I'm sorry to hear that. Sometimes a little self-care can help. "
|
||||
"What's one small thing you could do for yourself right now?",
|
||||
"rejected": "Okay.",
|
||||
},
|
||||
{
|
||||
"prompt": "Can you explain how photosynthesis works?",
|
||||
"chosen": "Certainly! Photosynthesis is a fascinating process where plants use sunlight, water, and carbon dioxide to create their own food (glucose) and release oxygen. Think of it like a plant's kitchen!",
|
||||
"chosen": "Certainly! Photosynthesis is a fascinating process where plants use "
|
||||
"sunlight, water, and carbon dioxide to create their own food (glucose) and "
|
||||
"release oxygen. Think of it like a plant's kitchen!",
|
||||
"rejected": "Plants make food from light.",
|
||||
},
|
||||
{
|
||||
"prompt": "I'm excited about my new project!",
|
||||
"chosen": "That's fantastic news! Tell me more about it - what are you most looking forward to?",
|
||||
"chosen": "That's fantastic news! Tell me more about it - "
|
||||
"what are you most looking forward to?",
|
||||
"rejected": "Good for you.",
|
||||
},
|
||||
{
|
||||
"prompt": "What's the weather like?",
|
||||
"chosen": "I can't check the real-time weather, but I hope it's pleasant where you are! If you'd like, you can tell me your location, and I could try to give you a general idea based on typical patterns, or help you find a weather service.",
|
||||
"chosen": "I can't check the real-time weather, but I hope it's pleasant where "
|
||||
"you are! If you'd like, you can tell me your location, and I could try to "
|
||||
"give you a general idea based on typical patterns, or help you find a "
|
||||
"weather service.",
|
||||
"rejected": "I don't know.",
|
||||
},
|
||||
{
|
||||
"prompt": "I'm having trouble understanding this concept.",
|
||||
"chosen": "I can understand that some concepts can be tricky! Could you tell me which part is confusing you? Maybe we can break it down together.",
|
||||
"chosen": "I can understand that some concepts can be tricky! Could you tell me "
|
||||
"which part is confusing you? Maybe we can break it down together.",
|
||||
"rejected": "Read the manual.",
|
||||
},
|
||||
]
|
||||
|
|
@ -98,36 +110,32 @@ class ConversationalStyleDPOEnv(BaseEnv):
|
|||
The optional_extra_data will be the rejected response.
|
||||
"""
|
||||
if not self.dataset or self.iter >= len(self.dataset):
|
||||
await self.setup() # Re-setup if dataset is exhausted or not loaded
|
||||
|
||||
if not self.dataset: # Still no dataset after setup
|
||||
logger.error("Dataset is empty even after setup.")
|
||||
# Return a fallback item or raise an error
|
||||
# For now, let's create a dummy item to avoid crashing, but this should be handled
|
||||
fallback_prompt = tuple([frozenset({"role": "user", "content": "Fallback prompt"}.items())])
|
||||
return (fallback_prompt, "Fallback chosen", "Fallback rejected")
|
||||
await self.setup() # Re-setup if dataset is exhausted or not loaded
|
||||
|
||||
if not self.dataset: # Still no dataset after setup
|
||||
logger.error("Dataset is empty even after setup.")
|
||||
# Return a fallback item or raise an error
|
||||
# For now, let's create a dummy item to avoid crashing, but this should be handled
|
||||
fallback_prompt = tuple(
|
||||
[frozenset({"role": "user", "content": "Fallback prompt"}.items())]
|
||||
)
|
||||
return (fallback_prompt, "Fallback chosen", "Fallback rejected")
|
||||
|
||||
entry = self.dataset[self.iter % len(self.dataset)]
|
||||
self.iter += 1
|
||||
|
||||
# Create a prompt tuple as expected by some parts of the system
|
||||
# (e.g. if it were to be sent to a model, though we are not doing that here)
|
||||
prompt_messages = [{"role": "user", "content": entry["prompt"]}]
|
||||
# Convert to the frozenset structure if that's what downstream components expect
|
||||
# For DPO, the direct strings are often more useful for tokenization with tokenize_for_trainer_dpo
|
||||
|
||||
|
||||
# We will pass the raw strings to collect_trajectories and handle tokenization there.
|
||||
# For Item, we can simplify or adjust based on how BaseEnv uses it.
|
||||
# Let's pass the prompt string directly as the first element of the "prompt_tuple".
|
||||
# The "gold_answer" will be the chosen response, and "extra_data" the rejected one.
|
||||
|
||||
|
||||
# Adapting to Item: Tuple[prompt_tuple, gold_answer_str, any_extra_data]
|
||||
# prompt_tuple: Tuple[FrozenSet[Tuple[str, str]], ...]
|
||||
# For DPO, the core elements are prompt, chosen, rejected. We'll pass them directly.
|
||||
return (entry["prompt"], entry["chosen"], entry["rejected"])
|
||||
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
|
||||
|
|
@ -165,7 +173,7 @@ class ConversationalStyleDPOEnv(BaseEnv):
|
|||
if not self.tokenizer:
|
||||
logger.error("Tokenizer not available. Cannot process DPO pair.")
|
||||
return None, []
|
||||
|
||||
|
||||
try:
|
||||
# Note: The actual structure of what `tokenize_for_trainer_dpo` returns
|
||||
# and how it's used in `ScoredDataGroup` is crucial.
|
||||
|
|
@ -188,9 +196,9 @@ class ConversationalStyleDPOEnv(BaseEnv):
|
|||
# It should handle the tokenization for DPO, creating chosen and rejected sequences.
|
||||
tokenized_output = tokenize_for_trainer_dpo(
|
||||
self.tokenizer,
|
||||
dpo_pair_data, # Or (self.tokenizer, prompt_str, chosen_response_str, rejected_response_str)
|
||||
# depending on its signature.
|
||||
max_length=self.config.max_token_length, # Ensure this config is available
|
||||
dpo_pair_data, # Or (self.tokenizer, prompt_str, chosen_response_str, rejected_response_str)
|
||||
# depending on its signature.
|
||||
max_length=self.config.max_token_length, # Ensure this config is available
|
||||
# Add other necessary args for tokenize_for_trainer_dpo
|
||||
)
|
||||
|
||||
|
|
@ -201,9 +209,9 @@ class ConversationalStyleDPOEnv(BaseEnv):
|
|||
# - chosen_input_ids, chosen_attention_mask, chosen_labels
|
||||
# - rejected_input_ids, rejected_attention_mask, rejected_labels
|
||||
# `tokenize_for_trainer_dpo` should produce these.
|
||||
|
||||
|
||||
# Let's assume tokenize_for_trainer_dpo returns a dict with at least:
|
||||
# 'chosen_input_ids', 'chosen_attention_mask',
|
||||
# 'chosen_input_ids', 'chosen_attention_mask',
|
||||
# 'rejected_input_ids', 'rejected_attention_mask'
|
||||
# 'prompt_input_ids' (optional, might be part of chosen/rejected)
|
||||
|
||||
|
|
@ -215,31 +223,35 @@ class ConversationalStyleDPOEnv(BaseEnv):
|
|||
scores["rejected_masks"] = [tokenized_output["rejected_attention_mask"]]
|
||||
# Optionally, if prompts are tokenized separately:
|
||||
if "prompt_input_ids" in tokenized_output:
|
||||
scores["prompt_tokens"] = [tokenized_output["prompt_input_ids"]]
|
||||
scores["prompt_masks"] = [tokenized_output.get("prompt_attention_mask")] # Handle if mask isn't there
|
||||
scores["prompt_tokens"] = [tokenized_output["prompt_input_ids"]]
|
||||
scores["prompt_masks"] = [
|
||||
tokenized_output.get("prompt_attention_mask")
|
||||
] # Handle if mask isn't there
|
||||
|
||||
# DPO doesn't use a single "score" like reward models. The "reward" is implicit
|
||||
# in the preference (chosen > rejected). So, the "scores" field in ScoredDataGroup
|
||||
# might not be directly used or could be set to a placeholder if required by the trainer.
|
||||
scores["scores"] = [1.0] # Placeholder, DPO loss handles preference directly
|
||||
scores["scores"] = [
|
||||
1.0
|
||||
] # Placeholder, DPO loss handles preference directly
|
||||
|
||||
# Images are not used in this environment
|
||||
scores["images"] = [None]
|
||||
|
||||
|
||||
# Ensure group size logic if processing multiple items for a group
|
||||
# For this example, we process one item at a time.
|
||||
# If self.config.group_size > 1, you'd aggregate multiple tokenized_outputs
|
||||
# before returning the ScoredDataGroup.
|
||||
|
||||
return scores, [] # No items to backlog
|
||||
return scores, [] # No items to backlog
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in collect_trajectories during DPO processing: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return None, []
|
||||
|
||||
|
||||
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
|
||||
"""
|
||||
This method is typically for scoring model generations against a gold answer or reward model.
|
||||
|
|
@ -262,49 +274,54 @@ class ConversationalStyleDPOEnv(BaseEnv):
|
|||
# If `collect_trajectories` returned a list of items to be scored/tokenized here.
|
||||
# This would mean moving the tokenization logic from `collect_trajectories` to here.
|
||||
# For now, let's assume the former.
|
||||
logger.warning("`score` method received a list; expecting ScoredDataGroup for pre-processed DPO.")
|
||||
logger.warning(
|
||||
"`score` method received a list; expecting ScoredDataGroup for pre-processed DPO."
|
||||
)
|
||||
# Fallback: if you need to process a list of (prompt, chosen, rejected) tuples here:
|
||||
# all_chosen_tokens = []
|
||||
# ... and so on, then call tokenize_for_trainer_dpo for each item.
|
||||
# This depends on the design of BaseEnv's main loop.
|
||||
return None
|
||||
|
||||
|
||||
logger.info("No data to score or data is already in ScoredDataGroup format.")
|
||||
return None
|
||||
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Evaluation for DPO might involve comparing the DPO-trained model's preferences
|
||||
against a held-out set of preferred/rejected pairs or other metrics.
|
||||
For this basic environment, we'll skip custom evaluation.
|
||||
"""
|
||||
logger.info("Evaluation step called. No custom DPO evaluation implemented in this basic environment.")
|
||||
logger.info(
|
||||
"Evaluation step called. No custom DPO evaluation implemented in this basic environment."
|
||||
)
|
||||
# You could load a separate test set of (prompt, chosen, rejected)
|
||||
# and see if the model assigns higher logprobs to chosen than rejected.
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[ConversationalStyleDPOEnvConfig, List[APIServerConfig]]:
|
||||
def config_init(
|
||||
cls,
|
||||
) -> Tuple[ConversationalStyleDPOEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Provides a default configuration for this environment.
|
||||
"""
|
||||
env_config = ConversationalStyleDPOEnvConfig(
|
||||
wandb_name="conversational_style_dpo", # For logging if wandb is used
|
||||
wandb_name="conversational_style_dpo", # For logging if wandb is used
|
||||
tokenizer_name="gpt2", # Choose an appropriate tokenizer
|
||||
group_size=4, # Number of DPO pairs to process in a "group" or "batch"
|
||||
use_wandb=False, # Enable or disable wandb
|
||||
max_num_workers=1, # Number of parallel workers for data collection (if applicable)
|
||||
rollout_server_url="http://localhost:8000", # Corrected URL
|
||||
total_steps=100, # Total DPO training steps (or epochs over the dataset)
|
||||
batch_size=2, # DPO training batch size (distinct from group_size for data collection)
|
||||
group_size=4, # Number of DPO pairs to process in a "group" or "batch"
|
||||
use_wandb=False, # Enable or disable wandb
|
||||
max_num_workers=1, # Number of parallel workers for data collection (if applicable)
|
||||
rollout_server_url="http://localhost:8000", # Corrected URL
|
||||
total_steps=100, # Total DPO training steps (or epochs over the dataset)
|
||||
batch_size=2, # DPO training batch size (distinct from group_size for data collection)
|
||||
steps_per_eval=50,
|
||||
max_token_length=512, # Max length for tokenized sequences
|
||||
dataset_name="synthetic_conversational_style",
|
||||
shuffle_dataset=True,
|
||||
)
|
||||
|
||||
server_configs = [] # Simplified as discussed
|
||||
server_configs = [] # Simplified as discussed
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
|
|
@ -314,27 +331,31 @@ if __name__ == "__main__":
|
|||
# The `BaseEnv.cli()` method usually sets up and runs the environment's main loop.
|
||||
# For DPO, the "main loop" might involve iterating through the dataset,
|
||||
# tokenizing pairs, and perhaps logging them or yielding them to a trainer.
|
||||
|
||||
|
||||
# To make this runnable and test the data processing:
|
||||
async def main_test():
|
||||
config, server_configs_list = ConversationalStyleDPOEnv.config_init()
|
||||
|
||||
|
||||
# Manually override tokenizer for local testing if needed and not using a server
|
||||
# that provides it, or if the default in BaseEnvConfig isn't what you want for DPO.
|
||||
# config.tokenizer_name = "EleutherAI/pythia-70m" # Example
|
||||
config.tokenizer_name = "distilgpt2" # A small, fast tokenizer for testing
|
||||
config.group_size = 1 # Process one DPO item at a time for simplicity in this test
|
||||
config.tokenizer_name = "distilgpt2" # A small, fast tokenizer for testing
|
||||
config.group_size = (
|
||||
1 # Process one DPO item at a time for simplicity in this test
|
||||
)
|
||||
config.use_wandb = False
|
||||
|
||||
env = ConversationalStyleDPOEnv(config=config, server_configs=server_configs_list, slurm=False, testing=True)
|
||||
|
||||
env = ConversationalStyleDPOEnv(
|
||||
config=config, server_configs=server_configs_list, slurm=False, testing=True
|
||||
)
|
||||
|
||||
# Initialize tokenizer (BaseEnv usually does this)
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
env.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
|
||||
if env.tokenizer.pad_token is None:
|
||||
env.tokenizer.pad_token = env.tokenizer.eos_token
|
||||
|
||||
|
||||
print("Setting up environment...")
|
||||
await env.setup()
|
||||
print(f"Dataset size: {len(env.dataset)}")
|
||||
|
|
@ -344,7 +365,7 @@ if __name__ == "__main__":
|
|||
return
|
||||
|
||||
print("Simulating DPO data processing for a few items...")
|
||||
for i in range(min(len(env.dataset), 3)): # Test with a few items
|
||||
for i in range(min(len(env.dataset), 3)): # Test with a few items
|
||||
print(f"--- Item {i+1} ---")
|
||||
item = await env.get_next_item()
|
||||
if item:
|
||||
|
|
@ -359,45 +380,70 @@ if __name__ == "__main__":
|
|||
# The `tokenize_for_trainer_dpo` function needs to exist and be importable.
|
||||
# Let's create a placeholder for it here for the test to run.
|
||||
global tokenize_for_trainer_dpo
|
||||
|
||||
def placeholder_tokenize_for_dpo(tokenizer, data, max_length, **kwargs):
|
||||
# This is a simplified placeholder. A real one would handle complex tokenization,
|
||||
# padding, truncation, and creating labels for DPO loss.
|
||||
prompt = data["prompt"]
|
||||
chosen = data["chosen"]
|
||||
rejected = " ".join(data["rejected"].split()[:max_length//3]) # Basic truncation
|
||||
rejected = " ".join(
|
||||
data["rejected"].split()[: max_length // 3]
|
||||
) # Basic truncation
|
||||
|
||||
chosen_text = f"{prompt} {chosen}"
|
||||
rejected_text = f"{prompt} {rejected}"
|
||||
|
||||
chosen_tok = tokenizer(chosen_text, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt")
|
||||
rejected_tok = tokenizer(rejected_text, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt")
|
||||
|
||||
chosen_tok = tokenizer(
|
||||
chosen_text,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
rejected_tok = tokenizer(
|
||||
rejected_text,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# A real DPO tokenizer would also prepare labels, where only response tokens are unmasked.
|
||||
# For simplicity, we're just returning input_ids and attention_mask.
|
||||
return {
|
||||
"chosen_input_ids": chosen_tok["input_ids"].squeeze().tolist(),
|
||||
"chosen_attention_mask": chosen_tok["attention_mask"].squeeze().tolist(),
|
||||
"rejected_input_ids": rejected_tok["input_ids"].squeeze().tolist(),
|
||||
"rejected_attention_mask": rejected_tok["attention_mask"].squeeze().tolist(),
|
||||
"chosen_attention_mask": chosen_tok["attention_mask"]
|
||||
.squeeze()
|
||||
.tolist(),
|
||||
"rejected_input_ids": rejected_tok["input_ids"]
|
||||
.squeeze()
|
||||
.tolist(),
|
||||
"rejected_attention_mask": rejected_tok["attention_mask"]
|
||||
.squeeze()
|
||||
.tolist(),
|
||||
# "prompt_input_ids": ..., # Optionally
|
||||
}
|
||||
tokenize_for_trainer_dpo = placeholder_tokenize_for_dpo
|
||||
|
||||
tokenize_for_trainer_dpo = placeholder_tokenize_for_dpo
|
||||
|
||||
scored_data_group, _ = await env.collect_trajectories(item)
|
||||
|
||||
if scored_data_group:
|
||||
print("Tokenized DPO Pair (first item in group):")
|
||||
print(f" Chosen Tokens (IDs): {scored_data_group['chosen_tokens'][0][:20]}...") # Print first 20
|
||||
print(
|
||||
f" Chosen Tokens (IDs): {scored_data_group['chosen_tokens'][0][:20]}..."
|
||||
) # Print first 20
|
||||
# print(f" Chosen Masks: {scored_data_group['chosen_masks'][0][:20]}...")
|
||||
print(f" Rejected Tokens (IDs): {scored_data_group['rejected_tokens'][0][:20]}...")
|
||||
print(
|
||||
f" Rejected Tokens (IDs): {scored_data_group['rejected_tokens'][0][:20]}..."
|
||||
)
|
||||
# print(f" Rejected Masks: {scored_data_group['rejected_masks'][0][:20]}...")
|
||||
# print(f" Scores: {scored_data_group['scores']}")
|
||||
else:
|
||||
print(" Failed to process DPO pair.")
|
||||
else:
|
||||
print("Failed to get item.")
|
||||
|
||||
|
||||
print("--- End of Test ---")
|
||||
|
||||
# To run the CLI, you would typically not need the main_test async function,
|
||||
|
|
@ -406,35 +452,11 @@ if __name__ == "__main__":
|
|||
# If you want to run with the actual CLI:
|
||||
# ConversationalStyleDPOEnv.cli()
|
||||
# But ensure `tokenize_for_trainer_dpo` is correctly implemented and importable in that context.
|
||||
|
||||
|
||||
# Running the test:
|
||||
if __name__ == "__main__":
|
||||
# Define a placeholder tokenize_for_trainer_dpo if it's not available globally
|
||||
# This is often part of a utils library.
|
||||
if "tokenize_for_trainer_dpo" not in globals():
|
||||
def placeholder_tokenize_for_dpo(tokenizer, data, max_length, **kwargs):
|
||||
prompt = data["prompt"]
|
||||
chosen = data["chosen"]
|
||||
# basic truncation for rejected to fit if too long with prompt
|
||||
max_rej_tokens = max_length - len(tokenizer.encode(prompt)) - 10 # buffer
|
||||
rejected_tokens = tokenizer.encode(data["rejected"])
|
||||
if len(rejected_tokens) > max_rej_tokens:
|
||||
rejected_tokens = rejected_tokens[:max_rej_tokens]
|
||||
rejected = tokenizer.decode(rejected_tokens)
|
||||
# Use the imported tokenize_for_trainer_dpo function
|
||||
|
||||
|
||||
chosen_text = f"{prompt} {chosen}" # Simplistic combination
|
||||
rejected_text = f"{prompt} {rejected}" # Simplistic combination
|
||||
|
||||
chosen_tok = tokenizer(chosen_text, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt")
|
||||
rejected_tok = tokenizer(rejected_text, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt")
|
||||
|
||||
return {
|
||||
"chosen_input_ids": chosen_tok["input_ids"].squeeze().tolist(),
|
||||
"chosen_attention_mask": chosen_tok["attention_mask"].squeeze().tolist(),
|
||||
"rejected_input_ids": rejected_tok["input_ids"].squeeze().tolist(),
|
||||
"rejected_attention_mask": rejected_tok["attention_mask"].squeeze().tolist(),
|
||||
}
|
||||
tokenize_for_trainer_dpo = placeholder_tokenize_for_dpo
|
||||
|
||||
asyncio.run(main_test())
|
||||
asyncio.run(main_test())
|
||||
|
|
@ -1,8 +1,7 @@
|
|||
import asyncio
|
||||
import ast # For safely evaluating the LLM's string output
|
||||
import logging
|
||||
import random
|
||||
import ast # For safely evaluating the LLM's string output
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
|
|
@ -15,7 +14,9 @@ from atroposlib.envs.base import (
|
|||
from atroposlib.type_definitions import Item
|
||||
|
||||
try:
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer_dpo as imported_tokenize_for_trainer_dpo
|
||||
from atroposlib.utils.tokenize_for_trainer import (
|
||||
tokenize_for_trainer_dpo as imported_tokenize_for_trainer_dpo,
|
||||
)
|
||||
except ImportError:
|
||||
imported_tokenize_for_trainer_dpo = None
|
||||
|
||||
|
|
@ -24,11 +25,13 @@ logger.setLevel(logging.INFO)
|
|||
|
||||
# --- System Prompts for Generating Chosen/Rejected Responses ---
|
||||
SYSTEM_PROMPT_CHOSEN = """
|
||||
You are a helpful and engaging AI assistant. Your goal is to provide clear, empathetic, and insightful responses that encourage further conversation. Be positive and proactive.
|
||||
You are a helpful and engaging AI assistant. Your goal is to provide clear, empathetic, and
|
||||
insightful responses that encourage further conversation. Be positive and proactive.
|
||||
"""
|
||||
|
||||
SYSTEM_PROMPT_REJECTED = """"
|
||||
You are an AI assistant. Provide very brief, blunt, and unhelpful responses. Do not elaborate or ask follow-up questions.
|
||||
You are an AI assistant. Provide very brief, blunt, and unhelpful responses.
|
||||
Do not elaborate or ask follow-up questions.
|
||||
"""
|
||||
# --- End System Prompts ---
|
||||
|
||||
|
|
@ -37,7 +40,8 @@ PROMPT_GENERATION_MASTER_PROMPT = """
|
|||
You are a creative assistant. Your task is to generate a list of 10 unique and random conversational prompts.
|
||||
Each prompt should be suitable for starting a general conversation with an AI assistant.
|
||||
The prompts should cover a variety of topics, tones (e.g., inquisitive, reflective, casual), and lengths.
|
||||
Format your entire output as a single Python list of dictionaries, where each dictionary has a single key "prompt" and the value is the prompt string.
|
||||
Format your entire output as a single Python list of dictionaries, where each dictionary has a
|
||||
single key "prompt" and the value is the prompt string.
|
||||
|
||||
Example format:
|
||||
[
|
||||
|
|
@ -49,22 +53,39 @@ Provide only the Python list of dictionaries, with no other surrounding text, ex
|
|||
"""
|
||||
# --- End Master Prompt ---
|
||||
|
||||
|
||||
class GSM8KConversationalStyleDPOEnvConfig(BaseEnvConfig):
|
||||
"""Config for GSM8KConversationalStyleDPOEnv."""
|
||||
|
||||
dataset_name: str = Field(
|
||||
"synthetic_conversational_style_prompts_via_gsm8k_env",
|
||||
description="Name of the dataset to use (source of prompts)."
|
||||
"synthetic_conversational_style_prompts_via_gsm8k_env",
|
||||
description="Name of the dataset to use (source of prompts).",
|
||||
)
|
||||
shuffle_dataset: bool = Field(
|
||||
True, description="Whether to shuffle the dataset of prompts"
|
||||
)
|
||||
data_path_to_save_groups: Optional[str] = Field(
|
||||
None, description="Path to save .jsonl and .html processed rollouts."
|
||||
)
|
||||
shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset of prompts")
|
||||
data_path_to_save_groups: Optional[str] = Field(None, description="Path to save .jsonl and .html processed rollouts.")
|
||||
# Generation parameters for chosen responses
|
||||
chosen_temperature: float = Field(0.7, description="Temperature for generating chosen responses.")
|
||||
chosen_temperature: float = Field(
|
||||
0.7, description="Temperature for generating chosen responses."
|
||||
)
|
||||
chosen_max_tokens: int = Field(150, description="Max tokens for chosen responses.")
|
||||
# Generation parameters for rejected responses
|
||||
rejected_temperature: float = Field(0.4, description="Temperature for generating rejected responses.")
|
||||
rejected_max_tokens: int = Field(50, description="Max tokens for rejected responses.")
|
||||
prompt_generation_temperature: float = Field(0.8, description="Temperature for LLM generating the initial list of prompts.")
|
||||
prompt_generation_max_tokens: int = Field(1000, description="Max tokens for LLM generating the initial list of prompts.")
|
||||
rejected_temperature: float = Field(
|
||||
0.4, description="Temperature for generating rejected responses."
|
||||
)
|
||||
rejected_max_tokens: int = Field(
|
||||
50, description="Max tokens for rejected responses."
|
||||
)
|
||||
prompt_generation_temperature: float = Field(
|
||||
0.8, description="Temperature for LLM generating the initial list of prompts."
|
||||
)
|
||||
prompt_generation_max_tokens: int = Field(
|
||||
1000, description="Max tokens for LLM generating the initial list of prompts."
|
||||
)
|
||||
|
||||
|
||||
class GSM8KConversationalStyleDPOEnv(BaseEnv):
|
||||
name = "gsm8k_dynamic_conversational_dpo"
|
||||
|
|
@ -81,15 +102,18 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
|
|||
# The BaseEnv will select the first server_config if multiple are provided and split is not specified.
|
||||
resolved_server_configs = server_configs
|
||||
if not resolved_server_configs:
|
||||
logger.warning(f"No server_configs provided for {self.name}, chat_completion calls will fail if not overridden in config_init or CLI.")
|
||||
logger.warning(
|
||||
f"No server_configs provided for {self.name}, chat_completion calls "
|
||||
f"will fail if not overridden in config_init or CLI."
|
||||
)
|
||||
# You might want to provide a default dummy one if testing without a server is intended for some paths
|
||||
# For this version, we expect it to be configured properly for LLM calls.
|
||||
|
||||
|
||||
super().__init__(config, resolved_server_configs, slurm, testing)
|
||||
self.config: GSM8KConversationalStyleDPOEnvConfig = config
|
||||
self.prompt_dataset: List[Dict[str, str]] = [] # Stores only prompts now
|
||||
self.prompt_dataset: List[Dict[str, str]] = [] # Stores only prompts now
|
||||
self.iter: int = 0
|
||||
|
||||
|
||||
if imported_tokenize_for_trainer_dpo is not None:
|
||||
self._tokenize_dpo_fn = imported_tokenize_for_trainer_dpo
|
||||
logger.info(f"Using imported tokenize_for_trainer_dpo for {self.name}")
|
||||
|
|
@ -101,10 +125,20 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
|
|||
prompt = data["prompt"]
|
||||
chosen = data["chosen"]
|
||||
rejected = data["rejected"]
|
||||
chosen_full_text = prompt + chosen
|
||||
chosen_full_text = prompt + chosen
|
||||
rejected_full_text = prompt + rejected
|
||||
chosen_tokenized = tokenizer(chosen_full_text, truncation=True, padding="max_length", max_length=max_length)
|
||||
rejected_tokenized = tokenizer(rejected_full_text, truncation=True, padding="max_length", max_length=max_length)
|
||||
chosen_tokenized = tokenizer(
|
||||
chosen_full_text,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
)
|
||||
rejected_tokenized = tokenizer(
|
||||
rejected_full_text,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
)
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"chosen": chosen,
|
||||
|
|
@ -117,60 +151,101 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
|
|||
|
||||
async def setup(self):
|
||||
"""Load and prepare the dataset of prompts, potentially by generating them via LLM."""
|
||||
|
||||
|
||||
generated_prompts = []
|
||||
if self.server:
|
||||
logger.info(f"Attempting to generate initial prompts for {self.name} using LLM...")
|
||||
logger.info(
|
||||
f"Attempting to generate initial prompts for {self.name} using LLM..."
|
||||
)
|
||||
try:
|
||||
prompt_list_completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant that strictly follows formatting instructions."},
|
||||
{"role": "user", "content": PROMPT_GENERATION_MASTER_PROMPT}
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that strictly follows formatting instructions.",
|
||||
},
|
||||
{"role": "user", "content": PROMPT_GENERATION_MASTER_PROMPT},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.prompt_generation_max_tokens,
|
||||
temperature=self.config.prompt_generation_temperature,
|
||||
)
|
||||
response_text = prompt_list_completion.choices[0].message.content.strip() if prompt_list_completion.choices else ""
|
||||
response_text = (
|
||||
prompt_list_completion.choices[0].message.content.strip()
|
||||
if prompt_list_completion.choices
|
||||
else ""
|
||||
)
|
||||
logger.debug(f"LLM response for prompt generation: {response_text}")
|
||||
|
||||
|
||||
# Attempt to parse the string as a Python list of dictionaries
|
||||
try:
|
||||
parsed_list = ast.literal_eval(response_text)
|
||||
if isinstance(parsed_list, list) and all(isinstance(item, dict) and "prompt" in item for item in parsed_list):
|
||||
if isinstance(parsed_list, list) and all(
|
||||
isinstance(item, dict) and "prompt" in item
|
||||
for item in parsed_list
|
||||
):
|
||||
generated_prompts = parsed_list
|
||||
logger.info(f"Successfully generated and parsed {len(generated_prompts)} prompts from LLM.")
|
||||
logger.info(
|
||||
f"Successfully generated and parsed {len(generated_prompts)} prompts from LLM."
|
||||
)
|
||||
else:
|
||||
logger.warning("LLM response for prompt generation was not a valid list of prompt dictionaries. Using fallback.")
|
||||
logger.warning(
|
||||
"LLM response for prompt generation was not a valid list of "
|
||||
"prompt dictionaries. Using fallback."
|
||||
)
|
||||
except (SyntaxError, ValueError) as e:
|
||||
logger.warning(f"Error parsing LLM response for prompt generation: {e}. Using fallback.")
|
||||
logger.warning(
|
||||
f"Error parsing LLM response for prompt generation: {e}. Using fallback."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling LLM for prompt generation: {e}. Using fallback.")
|
||||
logger.error(
|
||||
f"Error calling LLM for prompt generation: {e}. Using fallback."
|
||||
)
|
||||
else:
|
||||
logger.warning(f"LLM server not available for {self.name}. Using fallback prompts.")
|
||||
logger.warning(
|
||||
f"LLM server not available for {self.name}. Using fallback prompts."
|
||||
)
|
||||
|
||||
if not generated_prompts or not isinstance(generated_prompts, list) or len(generated_prompts) < 10:
|
||||
if (
|
||||
not generated_prompts
|
||||
or not isinstance(generated_prompts, list)
|
||||
or len(generated_prompts) < 10
|
||||
):
|
||||
logger.info(f"Using fallback static prompt list for {self.name}.")
|
||||
generated_prompts = [
|
||||
{"prompt": "What are your thoughts on the future of renewable energy?"},
|
||||
{"prompt": "If you could travel anywhere in the world, where would you go and why?"},
|
||||
{
|
||||
"prompt": "If you could travel anywhere in the world, where would you go and why?"
|
||||
},
|
||||
{"prompt": "What's a book or movie that has deeply impacted you?"},
|
||||
{"prompt": "Can you describe a complex scientific concept in simple terms?"},
|
||||
{
|
||||
"prompt": "Can you describe a complex scientific concept in simple terms?"
|
||||
},
|
||||
{"prompt": "What's a common misconception people have about AI?"},
|
||||
{"prompt": "How do you think technology will change our daily lives in the next 20 years?"},
|
||||
{"prompt": "What's a piece of advice you would give to someone learning a new skill?"},
|
||||
{"prompt": "If you could have a conversation with any historical figure, who would it be?"},
|
||||
{
|
||||
"prompt": "How do you think technology will change our daily lives in the next 20 years?"
|
||||
},
|
||||
{
|
||||
"prompt": "What's a piece of advice you would give to someone learning a new skill?"
|
||||
},
|
||||
{
|
||||
"prompt": "If you could have a conversation with any historical figure, who would it be?"
|
||||
},
|
||||
{"prompt": "What does 'creativity' mean to you as an AI?"},
|
||||
{"prompt": "Can you tell me a joke?"}
|
||||
{"prompt": "Can you tell me a joke?"},
|
||||
]
|
||||
if len(generated_prompts) > 10: # Ensure we only use 10 if more are in fallback
|
||||
if (
|
||||
len(generated_prompts) > 10
|
||||
): # Ensure we only use 10 if more are in fallback
|
||||
generated_prompts = generated_prompts[:10]
|
||||
|
||||
self.prompt_dataset = generated_prompts
|
||||
if self.config.shuffle_dataset:
|
||||
random.shuffle(self.prompt_dataset)
|
||||
self.iter = 0
|
||||
logger.info(f"Initialized prompt dataset with {len(self.prompt_dataset)} examples for {self.name}.")
|
||||
logger.info(
|
||||
f"Initialized prompt dataset with {len(self.prompt_dataset)} examples for {self.name}."
|
||||
)
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
"""
|
||||
|
|
@ -178,16 +253,18 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
|
|||
Chosen and rejected responses will be generated dynamically.
|
||||
"""
|
||||
if not self.prompt_dataset or self.iter >= len(self.prompt_dataset):
|
||||
await self.setup() # This will now potentially call the LLM to generate prompts
|
||||
await self.setup() # This will now potentially call the LLM to generate prompts
|
||||
|
||||
if not self.prompt_dataset: # Should be populated by setup, even with fallback
|
||||
logger.error(f"Prompt dataset is STILL empty after setup in {self.name}. This is unexpected.")
|
||||
# Provide an absolute failsafe prompt
|
||||
return ("Failsafe prompt: What is 1+1?", "", "")
|
||||
if not self.prompt_dataset: # Should be populated by setup, even with fallback
|
||||
logger.error(
|
||||
f"Prompt dataset is STILL empty after setup in {self.name}. This is unexpected."
|
||||
)
|
||||
# Provide an absolute failsafe prompt
|
||||
return ("Failsafe prompt: What is 1+1?", "", "")
|
||||
|
||||
entry = self.prompt_dataset[self.iter % len(self.prompt_dataset)]
|
||||
self.iter += 1
|
||||
return (entry["prompt"], "", "")
|
||||
return (entry["prompt"], "", "")
|
||||
|
||||
async def collect_trajectories(
|
||||
self, items: List[Item] # Changed to accept a list of items
|
||||
|
|
@ -200,41 +277,59 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
|
|||
if not items:
|
||||
logger.warning("collect_trajectories received an empty list of items.")
|
||||
return None, []
|
||||
|
||||
# item = items[0]
|
||||
prompt_str, _, _ = items
|
||||
|
||||
# item = items[0]
|
||||
prompt_str, _, _ = items
|
||||
|
||||
if not self.server:
|
||||
logger.error(f"LLM server not configured or available for {self.name}. Cannot generate responses.")
|
||||
logger.error(
|
||||
f"LLM server not configured or available for {self.name}. Cannot generate responses."
|
||||
)
|
||||
return None, []
|
||||
|
||||
|
||||
if not self.tokenizer:
|
||||
logger.error(f"Tokenizer not available in {self.name}. Attempting to initialize.")
|
||||
if self.testing and hasattr(self.config, 'tokenizer_name') and self.config.tokenizer_name:
|
||||
logger.error(
|
||||
f"Tokenizer not available in {self.name}. Attempting to initialize."
|
||||
)
|
||||
if (
|
||||
self.testing
|
||||
and hasattr(self.config, "tokenizer_name")
|
||||
and self.config.tokenizer_name
|
||||
):
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name)
|
||||
if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
logger.info(f"Fallback tokenizer '{self.config.tokenizer_name}' initialized.")
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.config.tokenizer_name
|
||||
)
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
logger.info(
|
||||
f"Fallback tokenizer '{self.config.tokenizer_name}' initialized."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize fallback tokenizer: {e}")
|
||||
return None, []
|
||||
else:
|
||||
return None, []
|
||||
|
||||
|
||||
try:
|
||||
# Generate Chosen Response
|
||||
logger.debug(f"Generating CHOSEN for prompt: {prompt_str[:100]}...")
|
||||
chosen_completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPT_CHOSEN},
|
||||
{"role": "user", "content": prompt_str}
|
||||
{"role": "user", "content": prompt_str},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.chosen_max_tokens,
|
||||
temperature=self.config.chosen_temperature,
|
||||
)
|
||||
chosen_response_str = chosen_completion.choices[0].message.content.strip() if chosen_completion.choices else ""
|
||||
chosen_response_str = (
|
||||
chosen_completion.choices[0].message.content.strip()
|
||||
if chosen_completion.choices
|
||||
else ""
|
||||
)
|
||||
logger.debug(f"Generated CHOSEN: {chosen_response_str[:100]}...")
|
||||
|
||||
# Generate Rejected Response
|
||||
|
|
@ -242,18 +337,24 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
|
|||
rejected_completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPT_REJECTED},
|
||||
{"role": "user", "content": prompt_str}
|
||||
{"role": "user", "content": prompt_str},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.rejected_max_tokens,
|
||||
temperature=self.config.rejected_temperature,
|
||||
)
|
||||
rejected_response_str = rejected_completion.choices[0].message.content.strip() if rejected_completion.choices else ""
|
||||
rejected_response_str = (
|
||||
rejected_completion.choices[0].message.content.strip()
|
||||
if rejected_completion.choices
|
||||
else ""
|
||||
)
|
||||
logger.debug(f"Generated REJECTED: {rejected_response_str[:100]}...")
|
||||
|
||||
if not chosen_response_str or not rejected_response_str:
|
||||
logger.warning(f"Failed to generate valid chosen or rejected response for prompt: {prompt_str}")
|
||||
return None, []
|
||||
logger.warning(
|
||||
f"Failed to generate valid chosen or rejected response for prompt: {prompt_str}"
|
||||
)
|
||||
return None, []
|
||||
|
||||
dpo_pair_data = {
|
||||
"prompt": prompt_str,
|
||||
|
|
@ -263,7 +364,7 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
|
|||
|
||||
tokenized_output = self._tokenize_dpo_fn(
|
||||
self.tokenizer,
|
||||
dpo_pair_data,
|
||||
dpo_pair_data,
|
||||
max_length=self.config.max_token_length,
|
||||
)
|
||||
|
||||
|
|
@ -274,44 +375,55 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
|
|||
scores_group["chosen_tokens"] = [tokenized_output["chosen_input_ids"]]
|
||||
scores_group["chosen_masks"] = [tokenized_output["chosen_attention_mask"]]
|
||||
scores_group["rejected_tokens"] = [tokenized_output["rejected_input_ids"]]
|
||||
scores_group["rejected_masks"] = [tokenized_output["rejected_attention_mask"]]
|
||||
scores_group["rejected_masks"] = [
|
||||
tokenized_output["rejected_attention_mask"]
|
||||
]
|
||||
scores_group["tokens"] = [tokenized_output["chosen_input_ids"]]
|
||||
scores_group["masks"] = [tokenized_output["chosen_attention_mask"]]
|
||||
scores_group["scores"] = [1.0]
|
||||
scores_group["images"] = [None]
|
||||
scores_group["group_overrides"] = {"group_size": 1}
|
||||
|
||||
return scores_group, []
|
||||
|
||||
return scores_group, []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in collect_trajectories for {self.name} during DPO processing: {e}")
|
||||
logger.error(
|
||||
f"Error in collect_trajectories for {self.name} during DPO processing: {e}"
|
||||
)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return None, []
|
||||
|
||||
async def score(self, rollout_group_data: Any) -> Optional[ScoredDataGroup]:
|
||||
if rollout_group_data and isinstance(rollout_group_data, ScoredDataGroup):
|
||||
return rollout_group_data
|
||||
logger.info(f"Data for {self.name} is not in ScoredDataGroup format or no data to score.")
|
||||
logger.info(
|
||||
f"Data for {self.name} is not in ScoredDataGroup format or no data to score."
|
||||
)
|
||||
return None
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
logger.info(f"Evaluation step called for {self.name}. No custom DPO evaluation implemented.")
|
||||
logger.info(
|
||||
f"Evaluation step called for {self.name}. No custom DPO evaluation implemented."
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[GSM8KConversationalStyleDPOEnvConfig, List[APIServerConfig]]:
|
||||
def config_init(
|
||||
cls,
|
||||
) -> Tuple[GSM8KConversationalStyleDPOEnvConfig, List[APIServerConfig]]:
|
||||
env_config = GSM8KConversationalStyleDPOEnvConfig(
|
||||
wandb_name="gsm8k_dynamic_conversational_dpo",
|
||||
wandb_name="gsm8k_dynamic_conversational_dpo",
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
group_size=1,
|
||||
use_wandb=True,
|
||||
max_num_workers=1,
|
||||
group_size=1,
|
||||
use_wandb=True,
|
||||
max_num_workers=1,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=100,
|
||||
batch_size=2,
|
||||
total_steps=100,
|
||||
batch_size=2,
|
||||
steps_per_eval=50,
|
||||
max_token_length=512,
|
||||
max_token_length=512,
|
||||
dataset_name="synthetic_conversational_style_prompts_via_gsm8k_env",
|
||||
shuffle_dataset=True,
|
||||
data_path_to_save_groups=None,
|
||||
|
|
@ -327,11 +439,12 @@ class GSM8KConversationalStyleDPOEnv(BaseEnv):
|
|||
APIServerConfig(
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
base_url="https://inference-api.nousresearch.com/v1",
|
||||
api_key="sk-3DvYKMv_-BfAoDSTfdSvEQ",
|
||||
num_requests_for_eval=256, # Copied from gsm8k_server.py
|
||||
api_key="sk-3DvYKMv_-BfAoDSTfdSvEQ",
|
||||
num_requests_for_eval=256, # Copied from gsm8k_server.py
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GSM8KConversationalStyleDPOEnv.cli()
|
||||
|
|
@ -4,7 +4,7 @@ pip install accelerate
|
|||
pip install pydantic
|
||||
pip install bitsandbytes
|
||||
pip install datasets
|
||||
pip install transformers
|
||||
pip install pydantic
|
||||
pip install transformers
|
||||
pip install pydantic
|
||||
pip install torch
|
||||
pip install -e .[all] # for everything
|
||||
|
|
@ -1,10 +1,8 @@
|
|||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
|
|
@ -36,11 +34,11 @@ class ScriptArguments:
|
|||
)
|
||||
tokenizer_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The tokenizer name or path. Defaults to model_name_or_path."},
|
||||
)
|
||||
beta: float = field(
|
||||
default=0.1, metadata={"help": "The beta factor in DPO loss."}
|
||||
metadata={
|
||||
"help": "The tokenizer name or path. Defaults to model_name_or_path."
|
||||
},
|
||||
)
|
||||
beta: float = field(default=0.1, metadata={"help": "The beta factor in DPO loss."})
|
||||
max_prompt_length: int = field(
|
||||
default=256, metadata={"help": "Max length for prompts."}
|
||||
)
|
||||
|
|
@ -75,14 +73,14 @@ async def get_dataset_from_env(
|
|||
# The DPOTrainer expects columns named "prompt", "chosen", "rejected"
|
||||
# The synthetic_data in your environment is already in this format.
|
||||
# Example: [{"prompt": "...", "chosen": "...", "rejected": "..."}]
|
||||
hf_dataset = Dataset.from_list(list(env.dataset)) # Ensure it's a fresh list copy
|
||||
hf_dataset = Dataset.from_list(list(env.dataset)) # Ensure it's a fresh list copy
|
||||
|
||||
# Log a sample to verify
|
||||
if len(hf_dataset) > 0:
|
||||
logger.info(f"Sample from dataset: {hf_dataset[0]}")
|
||||
else:
|
||||
logger.warning("Dataset created from environment is empty!")
|
||||
|
||||
|
||||
return hf_dataset
|
||||
|
||||
|
||||
|
|
@ -92,12 +90,14 @@ def main():
|
|||
|
||||
if script_args.tokenizer_name_or_path is None:
|
||||
script_args.tokenizer_name_or_path = script_args.model_name_or_path
|
||||
|
||||
|
||||
# --- 1. Initialize Environment and Get Dataset ---
|
||||
logger.info("Initializing environment to get dataset...")
|
||||
# Use the default config from your environment, but ensure tokenizer matches
|
||||
env_dpo_config, _ = ConversationalStyleDPOEnv.config_init()
|
||||
env_dpo_config.tokenizer_name = script_args.tokenizer_name_or_path # Align tokenizer
|
||||
env_dpo_config.tokenizer_name = (
|
||||
script_args.tokenizer_name_or_path
|
||||
) # Align tokenizer
|
||||
# You might want to adjust other env_dpo_config parameters if needed
|
||||
|
||||
# Run the async function to get the dataset
|
||||
|
|
@ -127,17 +127,18 @@ def main():
|
|||
# For simplicity, we'll let DPOTrainer handle creating the reference model by not passing one.
|
||||
# If you wanted to load a different SFT model as reference, you would do:
|
||||
# model_ref = AutoModelForCausalLM.from_pretrained(...)
|
||||
model_ref = None
|
||||
logger.info("Reference model will be a copy of the policy model (handled by DPOTrainer).")
|
||||
|
||||
model_ref = None
|
||||
logger.info(
|
||||
"Reference model will be a copy of the policy model (handled by DPOTrainer)."
|
||||
)
|
||||
|
||||
# --- 3. Set up Training Arguments ---
|
||||
# Default DPO training arguments. You might want to customize these.
|
||||
if training_args.output_dir == "output_dir": # Default value from TrainingArguments
|
||||
if training_args.output_dir == "output_dir": # Default value from TrainingArguments
|
||||
# Output directory will be relative to the script's current working directory when run.
|
||||
# If run from environments/hack0/conversational_style_dpo/, it will be ./dpo_conversational_trainer_results
|
||||
training_args.output_dir = "./dpo_conversational_trainer_results"
|
||||
|
||||
training_args.output_dir = "./dpo_conversational_trainer_results"
|
||||
|
||||
# training_args.per_device_train_batch_size = 2 # Adjust as needed
|
||||
# training_args.num_train_epochs = 1 # Keep low for a quick test
|
||||
# training_args.gradient_accumulation_steps = 1
|
||||
|
|
@ -145,21 +146,20 @@ def main():
|
|||
# training_args.logging_steps = 10
|
||||
# training_args.save_steps = 50
|
||||
# training_args.report_to = "none" # "wandb" or "tensorboard" if you want to log
|
||||
|
||||
logger.info(f"Training Arguments: {training_args}")
|
||||
|
||||
logger.info(f"Training Arguments: {training_args}")
|
||||
|
||||
# --- 4. Initialize DPOTrainer ---
|
||||
logger.info("Initializing DPOTrainer...")
|
||||
dpo_trainer = DPOTrainer(
|
||||
model=model,
|
||||
ref_model=model_ref, # If None, a copy of model is made
|
||||
ref_model=model_ref, # If None, a copy of model is made
|
||||
args=training_args,
|
||||
beta=script_args.beta,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
max_prompt_length=script_args.max_prompt_length,
|
||||
max_length=script_args.max_length, # Max length of prompt + response
|
||||
max_length=script_args.max_length, # Max length of prompt + response
|
||||
# peft_config=peft_config, # If using PEFT/LoRA
|
||||
)
|
||||
logger.info("DPOTrainer initialized.")
|
||||
|
|
@ -170,8 +170,8 @@ def main():
|
|||
logger.info("DPO training completed.")
|
||||
|
||||
# --- 6. Save the Model (Optional) ---
|
||||
if training_args.should_save: # Checks if any save_strategy is enabled
|
||||
output_save_dir = training_args.output_dir
|
||||
if training_args.should_save: # Checks if any save_strategy is enabled
|
||||
output_save_dir = training_args.output_dir
|
||||
logger.info(f"Saving model to {output_save_dir}")
|
||||
dpo_trainer.save_model(output_save_dir)
|
||||
logger.info("Model saved.")
|
||||
|
|
@ -182,4 +182,4 @@ def main():
|
|||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
||||
main()
|
||||
|
|
@ -1 +0,0 @@
|
|||
|
||||
|
|
@ -1,153 +0,0 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Rendered Messages - gsm8k_dpo_rollouts_1.jsonl</title>
|
||||
<style>
|
||||
body { font-family: sans-serif; line-height: 1.6; margin: 20px; }
|
||||
details { border: 1px solid #ccc; border-radius: 4px; margin-bottom: 15px; }
|
||||
summary {
|
||||
font-weight: bold;
|
||||
padding: 10px;
|
||||
background-color: #f0f0f0;
|
||||
cursor: pointer;
|
||||
border-radius: 4px 4px 0 0;
|
||||
border-bottom: 1px solid #ccc;
|
||||
outline: none; /* Remove default focus outline if needed */
|
||||
}
|
||||
details[open] summary { border-bottom: 1px solid #ccc; }
|
||||
.group-content { padding: 15px; }
|
||||
.item {
|
||||
border: 1px solid #eee;
|
||||
border-radius: 3px;
|
||||
margin-bottom: 10px;
|
||||
padding: 10px;
|
||||
transition: background-color 0.3s ease, box-shadow 0.2s ease; /* Smooth transitions */
|
||||
scroll-margin-top: 10px; /* Space when scrolling into view */
|
||||
}
|
||||
.item h4 { margin-top: 0; margin-bottom: 5px; font-size: 1.1em; }
|
||||
.content-block { background-color: #fff; padding: 8px; border-radius: 3px; margin-bottom: 5px; overflow-x: auto; }
|
||||
/* Use :focus-within for better accessibility on container focus */
|
||||
.item:focus, .item.active {
|
||||
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.5); /* Highlight active/focused item */
|
||||
outline: none; /* Remove default outline */
|
||||
}
|
||||
|
||||
/* Score Backgrounds (Faint) */
|
||||
.reward-positive { background-color: rgba(144, 238, 144, 0.3); } /* Faint light green */
|
||||
.reward-zero { background-color: rgba(255, 215, 0, 0.3); } /* Faint gold/orange */
|
||||
.reward-negative { background-color: rgba(255, 182, 193, 0.4); } /* Faint light pink/red */
|
||||
|
||||
/* Markdown specific styles */
|
||||
.content-block pre {
|
||||
background-color: #f5f5f5;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 3px;
|
||||
padding: 10px;
|
||||
overflow-x: auto; /* Allow horizontal scrolling for long code lines */
|
||||
white-space: pre-wrap; /* Wrap long lines within pre */
|
||||
word-wrap: break-word; /* Break long words if necessary */
|
||||
}
|
||||
.content-block code {
|
||||
background-color: #f0f0f0; /* Slightly different for inline code */
|
||||
padding: 0.2em 0.4em;
|
||||
border-radius: 3px;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
.content-block pre code {
|
||||
background-color: transparent; /* Don't double-background code in pre blocks */
|
||||
padding: 0;
|
||||
border-radius: 0;
|
||||
font-size: inherit; /* Inherit pre font size */
|
||||
}
|
||||
.content-block blockquote {
|
||||
border-left: 4px solid #ccc;
|
||||
padding-left: 10px;
|
||||
margin-left: 0;
|
||||
color: #555;
|
||||
}
|
||||
.content-block table {
|
||||
border-collapse: collapse;
|
||||
width: auto; /* Don't force full width */
|
||||
margin-bottom: 1em;
|
||||
}
|
||||
.content-block th, .content-block td {
|
||||
border: 1px solid #ddd;
|
||||
padding: 8px;
|
||||
text-align: left;
|
||||
}
|
||||
.content-block th {
|
||||
background-color: #f2f2f2;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Rendered Messages - gsm8k_dpo_rollouts_1.jsonl</h1>
|
||||
<div id="groups-container">
|
||||
<p>No data to display. Input file might be empty or contain invalid data.</p>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
document.addEventListener('DOMContentLoaded', () => {
|
||||
const items = document.querySelectorAll('.item');
|
||||
let activeIndex = -1; // No item active initially
|
||||
|
||||
// Function to set active item
|
||||
function setActiveItem(index) {
|
||||
if (activeIndex >= 0 && activeIndex < items.length) {
|
||||
items[activeIndex].classList.remove('active');
|
||||
items[activeIndex].removeAttribute('tabindex'); // Remove from tab order when not active
|
||||
}
|
||||
if (index >= 0 && index < items.length) {
|
||||
items[index].classList.add('active');
|
||||
items[index].setAttribute('tabindex', '0'); // Make active item focusable
|
||||
items[index].focus(); // Focus the element
|
||||
// Ensure the parent <details> is open
|
||||
const detailsParent = items[index].closest('details');
|
||||
if (detailsParent && !detailsParent.open) {
|
||||
detailsParent.open = true;
|
||||
}
|
||||
// Scroll into view with options if needed (focus should handle this mostly)
|
||||
// items[index].scrollIntoView({ behavior: 'smooth', block: 'nearest' });
|
||||
activeIndex = index;
|
||||
} else {
|
||||
activeIndex = -1; // Deactivate if index is out of bounds
|
||||
}
|
||||
}
|
||||
|
||||
// Add click listener to activate items
|
||||
items.forEach((item, index) => {
|
||||
item.addEventListener('click', () => {
|
||||
setActiveItem(index);
|
||||
});
|
||||
// Make items focusable initially only if we want tab navigation
|
||||
// item.setAttribute('tabindex', '0');
|
||||
});
|
||||
|
||||
// Add keydown listener for arrow navigation
|
||||
document.addEventListener('keydown', (event) => {
|
||||
let targetIndex = -1;
|
||||
if (event.key === 'ArrowDown') {
|
||||
event.preventDefault(); // Prevent default page scroll
|
||||
targetIndex = (activeIndex === -1) ? 0 : Math.min(activeIndex + 1, items.length - 1);
|
||||
} else if (event.key === 'ArrowUp') {
|
||||
event.preventDefault(); // Prevent default page scroll
|
||||
targetIndex = (activeIndex === -1) ? items.length - 1 : Math.max(activeIndex - 1, 0);
|
||||
}
|
||||
|
||||
if (targetIndex !== -1) {
|
||||
setActiveItem(targetIndex);
|
||||
}
|
||||
});
|
||||
|
||||
// Make first item focusable initially if you want immediate keyboard nav
|
||||
if (items.length > 0) {
|
||||
// items[0].setAttribute('tabindex', '0');
|
||||
// Optionally activate the first item on load:
|
||||
// setActiveItem(0);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -1,153 +0,0 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Rendered Messages - gsm8k_dpo_rollouts_2.jsonl</title>
|
||||
<style>
|
||||
body { font-family: sans-serif; line-height: 1.6; margin: 20px; }
|
||||
details { border: 1px solid #ccc; border-radius: 4px; margin-bottom: 15px; }
|
||||
summary {
|
||||
font-weight: bold;
|
||||
padding: 10px;
|
||||
background-color: #f0f0f0;
|
||||
cursor: pointer;
|
||||
border-radius: 4px 4px 0 0;
|
||||
border-bottom: 1px solid #ccc;
|
||||
outline: none; /* Remove default focus outline if needed */
|
||||
}
|
||||
details[open] summary { border-bottom: 1px solid #ccc; }
|
||||
.group-content { padding: 15px; }
|
||||
.item {
|
||||
border: 1px solid #eee;
|
||||
border-radius: 3px;
|
||||
margin-bottom: 10px;
|
||||
padding: 10px;
|
||||
transition: background-color 0.3s ease, box-shadow 0.2s ease; /* Smooth transitions */
|
||||
scroll-margin-top: 10px; /* Space when scrolling into view */
|
||||
}
|
||||
.item h4 { margin-top: 0; margin-bottom: 5px; font-size: 1.1em; }
|
||||
.content-block { background-color: #fff; padding: 8px; border-radius: 3px; margin-bottom: 5px; overflow-x: auto; }
|
||||
/* Use :focus-within for better accessibility on container focus */
|
||||
.item:focus, .item.active {
|
||||
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.5); /* Highlight active/focused item */
|
||||
outline: none; /* Remove default outline */
|
||||
}
|
||||
|
||||
/* Score Backgrounds (Faint) */
|
||||
.reward-positive { background-color: rgba(144, 238, 144, 0.3); } /* Faint light green */
|
||||
.reward-zero { background-color: rgba(255, 215, 0, 0.3); } /* Faint gold/orange */
|
||||
.reward-negative { background-color: rgba(255, 182, 193, 0.4); } /* Faint light pink/red */
|
||||
|
||||
/* Markdown specific styles */
|
||||
.content-block pre {
|
||||
background-color: #f5f5f5;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 3px;
|
||||
padding: 10px;
|
||||
overflow-x: auto; /* Allow horizontal scrolling for long code lines */
|
||||
white-space: pre-wrap; /* Wrap long lines within pre */
|
||||
word-wrap: break-word; /* Break long words if necessary */
|
||||
}
|
||||
.content-block code {
|
||||
background-color: #f0f0f0; /* Slightly different for inline code */
|
||||
padding: 0.2em 0.4em;
|
||||
border-radius: 3px;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
.content-block pre code {
|
||||
background-color: transparent; /* Don't double-background code in pre blocks */
|
||||
padding: 0;
|
||||
border-radius: 0;
|
||||
font-size: inherit; /* Inherit pre font size */
|
||||
}
|
||||
.content-block blockquote {
|
||||
border-left: 4px solid #ccc;
|
||||
padding-left: 10px;
|
||||
margin-left: 0;
|
||||
color: #555;
|
||||
}
|
||||
.content-block table {
|
||||
border-collapse: collapse;
|
||||
width: auto; /* Don't force full width */
|
||||
margin-bottom: 1em;
|
||||
}
|
||||
.content-block th, .content-block td {
|
||||
border: 1px solid #ddd;
|
||||
padding: 8px;
|
||||
text-align: left;
|
||||
}
|
||||
.content-block th {
|
||||
background-color: #f2f2f2;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Rendered Messages - gsm8k_dpo_rollouts_2.jsonl</h1>
|
||||
<div id="groups-container">
|
||||
<p>No data to display. Input file might be empty or contain invalid data.</p>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
document.addEventListener('DOMContentLoaded', () => {
|
||||
const items = document.querySelectorAll('.item');
|
||||
let activeIndex = -1; // No item active initially
|
||||
|
||||
// Function to set active item
|
||||
function setActiveItem(index) {
|
||||
if (activeIndex >= 0 && activeIndex < items.length) {
|
||||
items[activeIndex].classList.remove('active');
|
||||
items[activeIndex].removeAttribute('tabindex'); // Remove from tab order when not active
|
||||
}
|
||||
if (index >= 0 && index < items.length) {
|
||||
items[index].classList.add('active');
|
||||
items[index].setAttribute('tabindex', '0'); // Make active item focusable
|
||||
items[index].focus(); // Focus the element
|
||||
// Ensure the parent <details> is open
|
||||
const detailsParent = items[index].closest('details');
|
||||
if (detailsParent && !detailsParent.open) {
|
||||
detailsParent.open = true;
|
||||
}
|
||||
// Scroll into view with options if needed (focus should handle this mostly)
|
||||
// items[index].scrollIntoView({ behavior: 'smooth', block: 'nearest' });
|
||||
activeIndex = index;
|
||||
} else {
|
||||
activeIndex = -1; // Deactivate if index is out of bounds
|
||||
}
|
||||
}
|
||||
|
||||
// Add click listener to activate items
|
||||
items.forEach((item, index) => {
|
||||
item.addEventListener('click', () => {
|
||||
setActiveItem(index);
|
||||
});
|
||||
// Make items focusable initially only if we want tab navigation
|
||||
// item.setAttribute('tabindex', '0');
|
||||
});
|
||||
|
||||
// Add keydown listener for arrow navigation
|
||||
document.addEventListener('keydown', (event) => {
|
||||
let targetIndex = -1;
|
||||
if (event.key === 'ArrowDown') {
|
||||
event.preventDefault(); // Prevent default page scroll
|
||||
targetIndex = (activeIndex === -1) ? 0 : Math.min(activeIndex + 1, items.length - 1);
|
||||
} else if (event.key === 'ArrowUp') {
|
||||
event.preventDefault(); // Prevent default page scroll
|
||||
targetIndex = (activeIndex === -1) ? items.length - 1 : Math.max(activeIndex - 1, 0);
|
||||
}
|
||||
|
||||
if (targetIndex !== -1) {
|
||||
setActiveItem(targetIndex);
|
||||
}
|
||||
});
|
||||
|
||||
// Make first item focusable initially if you want immediate keyboard nav
|
||||
if (items.length > 0) {
|
||||
// items[0].setAttribute('tabindex', '0');
|
||||
// Optionally activate the first item on load:
|
||||
// setActiveItem(0);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
|
|
@ -1,153 +0,0 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Rendered Messages - gsm8k_dpo_rollouts_4.jsonl</title>
|
||||
<style>
|
||||
body { font-family: sans-serif; line-height: 1.6; margin: 20px; }
|
||||
details { border: 1px solid #ccc; border-radius: 4px; margin-bottom: 15px; }
|
||||
summary {
|
||||
font-weight: bold;
|
||||
padding: 10px;
|
||||
background-color: #f0f0f0;
|
||||
cursor: pointer;
|
||||
border-radius: 4px 4px 0 0;
|
||||
border-bottom: 1px solid #ccc;
|
||||
outline: none; /* Remove default focus outline if needed */
|
||||
}
|
||||
details[open] summary { border-bottom: 1px solid #ccc; }
|
||||
.group-content { padding: 15px; }
|
||||
.item {
|
||||
border: 1px solid #eee;
|
||||
border-radius: 3px;
|
||||
margin-bottom: 10px;
|
||||
padding: 10px;
|
||||
transition: background-color 0.3s ease, box-shadow 0.2s ease; /* Smooth transitions */
|
||||
scroll-margin-top: 10px; /* Space when scrolling into view */
|
||||
}
|
||||
.item h4 { margin-top: 0; margin-bottom: 5px; font-size: 1.1em; }
|
||||
.content-block { background-color: #fff; padding: 8px; border-radius: 3px; margin-bottom: 5px; overflow-x: auto; }
|
||||
/* Use :focus-within for better accessibility on container focus */
|
||||
.item:focus, .item.active {
|
||||
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.5); /* Highlight active/focused item */
|
||||
outline: none; /* Remove default outline */
|
||||
}
|
||||
|
||||
/* Score Backgrounds (Faint) */
|
||||
.reward-positive { background-color: rgba(144, 238, 144, 0.3); } /* Faint light green */
|
||||
.reward-zero { background-color: rgba(255, 215, 0, 0.3); } /* Faint gold/orange */
|
||||
.reward-negative { background-color: rgba(255, 182, 193, 0.4); } /* Faint light pink/red */
|
||||
|
||||
/* Markdown specific styles */
|
||||
.content-block pre {
|
||||
background-color: #f5f5f5;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 3px;
|
||||
padding: 10px;
|
||||
overflow-x: auto; /* Allow horizontal scrolling for long code lines */
|
||||
white-space: pre-wrap; /* Wrap long lines within pre */
|
||||
word-wrap: break-word; /* Break long words if necessary */
|
||||
}
|
||||
.content-block code {
|
||||
background-color: #f0f0f0; /* Slightly different for inline code */
|
||||
padding: 0.2em 0.4em;
|
||||
border-radius: 3px;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
.content-block pre code {
|
||||
background-color: transparent; /* Don't double-background code in pre blocks */
|
||||
padding: 0;
|
||||
border-radius: 0;
|
||||
font-size: inherit; /* Inherit pre font size */
|
||||
}
|
||||
.content-block blockquote {
|
||||
border-left: 4px solid #ccc;
|
||||
padding-left: 10px;
|
||||
margin-left: 0;
|
||||
color: #555;
|
||||
}
|
||||
.content-block table {
|
||||
border-collapse: collapse;
|
||||
width: auto; /* Don't force full width */
|
||||
margin-bottom: 1em;
|
||||
}
|
||||
.content-block th, .content-block td {
|
||||
border: 1px solid #ddd;
|
||||
padding: 8px;
|
||||
text-align: left;
|
||||
}
|
||||
.content-block th {
|
||||
background-color: #f2f2f2;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Rendered Messages - gsm8k_dpo_rollouts_4.jsonl</h1>
|
||||
<div id="groups-container">
|
||||
<p>No data to display. Input file might be empty or contain invalid data.</p>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
document.addEventListener('DOMContentLoaded', () => {
|
||||
const items = document.querySelectorAll('.item');
|
||||
let activeIndex = -1; // No item active initially
|
||||
|
||||
// Function to set active item
|
||||
function setActiveItem(index) {
|
||||
if (activeIndex >= 0 && activeIndex < items.length) {
|
||||
items[activeIndex].classList.remove('active');
|
||||
items[activeIndex].removeAttribute('tabindex'); // Remove from tab order when not active
|
||||
}
|
||||
if (index >= 0 && index < items.length) {
|
||||
items[index].classList.add('active');
|
||||
items[index].setAttribute('tabindex', '0'); // Make active item focusable
|
||||
items[index].focus(); // Focus the element
|
||||
// Ensure the parent <details> is open
|
||||
const detailsParent = items[index].closest('details');
|
||||
if (detailsParent && !detailsParent.open) {
|
||||
detailsParent.open = true;
|
||||
}
|
||||
// Scroll into view with options if needed (focus should handle this mostly)
|
||||
// items[index].scrollIntoView({ behavior: 'smooth', block: 'nearest' });
|
||||
activeIndex = index;
|
||||
} else {
|
||||
activeIndex = -1; // Deactivate if index is out of bounds
|
||||
}
|
||||
}
|
||||
|
||||
// Add click listener to activate items
|
||||
items.forEach((item, index) => {
|
||||
item.addEventListener('click', () => {
|
||||
setActiveItem(index);
|
||||
});
|
||||
// Make items focusable initially only if we want tab navigation
|
||||
// item.setAttribute('tabindex', '0');
|
||||
});
|
||||
|
||||
// Add keydown listener for arrow navigation
|
||||
document.addEventListener('keydown', (event) => {
|
||||
let targetIndex = -1;
|
||||
if (event.key === 'ArrowDown') {
|
||||
event.preventDefault(); // Prevent default page scroll
|
||||
targetIndex = (activeIndex === -1) ? 0 : Math.min(activeIndex + 1, items.length - 1);
|
||||
} else if (event.key === 'ArrowUp') {
|
||||
event.preventDefault(); // Prevent default page scroll
|
||||
targetIndex = (activeIndex === -1) ? items.length - 1 : Math.max(activeIndex - 1, 0);
|
||||
}
|
||||
|
||||
if (targetIndex !== -1) {
|
||||
setActiveItem(targetIndex);
|
||||
}
|
||||
});
|
||||
|
||||
// Make first item focusable initially if you want immediate keyboard nav
|
||||
if (items.length > 0) {
|
||||
// items[0].setAttribute('tabindex', '0');
|
||||
// Optionally activate the first item on load:
|
||||
// setActiveItem(0);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue