improve dataloading, ctx len

teknium, 2025-07-26 21:06:45 +00:00
parent 6604a2255b
commit 1039c3d360


@@ -1,4 +1,5 @@
 import asyncio
+import json
 import math
 import random
 import re
@@ -109,6 +110,27 @@ class PairwiseJudgementConfig(BaseEnvConfig):
         description="Minimum response length to consider valid (filters out EOS-only responses).",
     )
+    # Dataset configuration
+    train_dataset: str = Field(
+        default="dummy/dataset",
+        description="Training dataset name (HuggingFace) or path to local JSONL file.",
+    )
+    eval_dataset: str = Field(
+        default="allenai/reward-bench-2",
+        description="Evaluation dataset name (HuggingFace) or path to local JSONL file.",
+    )
+    train_split: str = Field(
+        default="train",
+        description="Split to use for training dataset (only for HuggingFace datasets).",
+    )
+    eval_split: str = Field(
+        default="test",
+        description="Split to use for evaluation dataset (only for HuggingFace datasets).",
+    )
     # Debug configuration
     full_debug: bool = Field(
         default=False,
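For context, a minimal sketch of how the new dataset fields might be wired up, assuming `PairwiseJudgementConfig` is importable from this module; the local path is hypothetical. Per the field descriptions, the split fields only apply to HuggingFace datasets, while local files go through the `json` builder in the `_load_dataset` helper added below:

```python
# Hypothetical usage of the new dataset configuration fields.
config = PairwiseJudgementConfig(
    train_dataset="/data/judgement_prompts.jsonl",  # hypothetical local JSONL path
    train_split="train",  # consulted only for HuggingFace datasets
    eval_dataset="allenai/reward-bench-2",
    eval_split="test",
)
```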
@@ -338,7 +360,8 @@ class PairwiseJudgementEnv(BaseEnv):
             total_steps=2000,
             batch_size=1024,
             steps_per_eval=25,
-            max_token_length=1024 * 8,
+            train_max_tokens=1024 * 16,
+            eval_max_tokens=1024 * 16,
             inference_weight=1.0,
             wandb_name="pairwise_judgment",
             eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
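The single `max_token_length` budget is replaced by separate train and eval budgets; `get_next_item()` below caps each sampled candidate answer at half the training budget. A quick sketch of the arithmetic:

```python
# Token budgets introduced by this commit (values as configured above).
train_max_tokens = 1024 * 16  # 16384-token training context
eval_max_tokens = 1024 * 16   # 16384-token evaluation context
per_candidate_cap = train_max_tokens // 2  # 8192 tokens per generated answer
```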
@@ -373,54 +396,157 @@ class PairwiseJudgementEnv(BaseEnv):
     async def setup(self) -> None:
         """Set up the environment by loading datasets."""
-        # Load placeholder train dataset (not actually used since we generate synthetic examples)
+        # Load training dataset
         try:
-            self.train = load_dataset("example/train", split="train")
-            print(f"Loaded placeholder train dataset with {len(self.train)} examples")
+            self.train = self._load_dataset(self.config.train_dataset, self.config.train_split)
+            # Shuffle the training dataset (seeded so the order is reproducible)
+            if hasattr(self.train, "shuffle"):
+                self.train = self.train.shuffle(seed=42)
+            else:
+                # For list-like objects, copy into a list and shuffle in place
+                train_list = list(self.train)
+                random.seed(42)
+                random.shuffle(train_list)
+                self.train = train_list
         except Exception as e:
-            # Create minimal placeholder data if dataset doesn't exist
-            # Note: This isn't actually used since get_next_item() generates synthetic examples
-            self.train = [{"question": "What is 2+2?", "answer": "4"}] * 100
-            print(f"Using synthetic placeholder training data due to error: {e}")
+            print(f"Error loading training dataset '{self.config.train_dataset}': {e}")
+            # Create minimal fallback data in the expected format
+            self.train = [
+                {"uid": "train_1", "category": "general", "prompt": "What is the capital of France?"},
+                {"uid": "train_2", "category": "math", "prompt": "Solve for x: 2x + 5 = 15"},
+                {"uid": "train_3", "category": "coding", "prompt": "Write a Python function to calculate factorial"},
+            ] * 34  # 102 examples
+            print(f"Using fallback training data with {len(self.train)} examples")
-        # Load evaluation dataset - reward-bench-2 (MUST WORK OR CRASH)
-        self.test = load_dataset(
-            "allenai/reward-bench-2", split="test", trust_remote_code=True
-        )
-        print(f"Loaded reward-bench-2 eval dataset with {len(self.test)} examples")
+        # Load evaluation dataset
+        try:
+            self.test = self._load_dataset(self.config.eval_dataset, self.config.eval_split)
+        except Exception as e:
+            print(f"Error loading evaluation dataset '{self.config.eval_dataset}': {e}")
+            raise  # The evaluation dataset must load
-        # Analyze dataset composition
-        category_counts = {}
-        for item in self.test:
-            category = item.get("subset", "Unknown")
-            category_counts[category] = category_counts.get(category, 0) + 1
-        print("Dataset categories found:")
-        for category, count in sorted(category_counts.items()):
-            print(f"  - {category}: {count} samples")
+        # Analyze training dataset composition
+        if hasattr(self.train, "__iter__"):
+            train_category_counts = {}
+            total_train_items = 0
+            for item in self.train:
+                total_train_items += 1
+                category = item.get("category", "Unknown")
+                train_category_counts[category] = train_category_counts.get(category, 0) + 1
+            print(f"\nTraining dataset analysis ({total_train_items} total items):")
+            for category, count in sorted(train_category_counts.items()):
+                print(f"  - {category}: {count} samples")
+        # Analyze evaluation dataset composition
+        if hasattr(self.test, "__iter__"):
+            eval_category_counts = {}
+            total_eval_items = 0
+            for item in self.test:
+                total_eval_items += 1
+                category = item.get("subset", "Unknown")
+                eval_category_counts[category] = eval_category_counts.get(category, 0) + 1
+            print(f"\nEvaluation dataset analysis ({total_eval_items} total items):")
+            for category, count in sorted(eval_category_counts.items()):
+                print(f"  - {category}: {count} samples")
-        # Count ties vs choice samples
-        ties_samples = sum(1 for item in self.test if self._is_ties_sample(item))
-        choice_samples = len(self.test) - ties_samples
-        print(
-            f"\nEvaluation modes: {choice_samples} choice samples, {ties_samples} ties samples"
-        )
-        # Show category filtering info
-        if self.config.eval_categories is not None:
-            selected_categories = [cat.value for cat in self.config.eval_categories]
-            print(
-                f"\nCategory filtering enabled. Selected categories: {selected_categories}"
-            )
-            filtered_count = sum(
-                1 for item in self.test if self._should_evaluate_category(item)
-            )
-            print(f"Will evaluate {filtered_count} out of {len(self.test)} samples")
-        else:
-            print(
-                f"\nNo category filtering. Will evaluate all {len(self.test)} samples"
-            )
+        # Count ties vs choice samples
+        ties_samples = sum(1 for item in self.test if self._is_ties_sample(item))
+        choice_samples = len(self.test) - ties_samples
+        print(
+            f"\nEvaluation modes: {choice_samples} choice samples, {ties_samples} ties samples"
+        )
+        # Show category filtering info
+        if self.config.eval_categories is not None:
+            selected_categories = []
+            for cat in self.config.eval_categories:
+                selected_categories.append(cat.value if hasattr(cat, "value") else str(cat))
+            print(
+                f"\nCategory filtering enabled. Selected categories: {selected_categories}"
+            )
+            filtered_count = sum(
+                1 for item in self.test if self._should_evaluate_category(item)
+            )
+            print(f"Will evaluate {filtered_count} out of {len(self.test)} samples")
+        else:
+            print(
+                f"\nNo category filtering. Will evaluate all {len(self.test)} samples"
+            )
+        # Show configuration info
+        print("\nPairwise Judgement Configuration:")
+        print(f"  - Training dataset: {self.config.train_dataset} (split: {self.config.train_split})")
+        print(f"  - Evaluation dataset: {self.config.eval_dataset} (split: {self.config.eval_split})")
+        print(f"  - Thinking mode: {self.config.thinking_mode}")
+        print(f"  - Eval temperature: {self.config.eval_temperature}")
+        print(f"  - Number of choices: {self.config.num_choices}")
+        # Show sample training item structure
+        if len(self.train) > 0:
+            try:
+                sample_train_item = self.train[0]
+                print("\nSample training item structure:")
+                print(f"- Available keys: {list(sample_train_item.keys())}")
+                if "uid" in sample_train_item:
+                    print(f"- UID: {sample_train_item['uid']}")
+                if "category" in sample_train_item:
+                    print(f"- Category: {sample_train_item['category']}")
+                if "subcategory" in sample_train_item:
+                    print(f"- Subcategory: {sample_train_item['subcategory']}")
+                if "prompt" in sample_train_item:
+                    print(f"- Prompt: {sample_train_item['prompt'][:100]}...")
+            except Exception as e:
+                print(f"Warning: Could not display sample training item structure: {e}")
+        # Show sample evaluation item structure
+        if len(self.test) > 0:
+            try:
+                sample_eval_item = self.test[0]
+                print("\nSample evaluation item structure:")
+                print(f"- Available keys: {list(sample_eval_item.keys())}")
+                # Handle different dataset structures
+                if "prompt" in sample_eval_item:
+                    print(f"- Prompt: {sample_eval_item['prompt'][:100]}...")
+                elif "chosen" in sample_eval_item and isinstance(sample_eval_item["chosen"], str):
+                    print(f"- Chosen (string): {sample_eval_item['chosen'][:100]}...")
+                elif "rejected" in sample_eval_item and isinstance(
+                    sample_eval_item["rejected"], str
+                ):
+                    print(f"- Rejected (string): {sample_eval_item['rejected'][:100]}...")
+                if "chosen" in sample_eval_item:
+                    if isinstance(sample_eval_item["chosen"], list):
+                        print(f"- Chosen responses: {len(sample_eval_item['chosen'])}")
+                        if sample_eval_item["chosen"]:
+                            print(
+                                f"- First chosen (truncated): {sample_eval_item['chosen'][0][:200]}..."
+                            )
+                    else:
+                        print(f"- Chosen (string): {sample_eval_item['chosen'][:200]}...")
+                if "rejected" in sample_eval_item:
+                    if isinstance(sample_eval_item["rejected"], list):
+                        print(f"- Rejected responses: {len(sample_eval_item['rejected'])}")
+                        if sample_eval_item["rejected"]:
+                            print(
+                                f"- First rejected (truncated): {sample_eval_item['rejected'][0][:200]}..."
+                            )
+                    else:
+                        print(
+                            f"- Rejected (string): {sample_eval_item['rejected'][:200]}..."
+                        )
+            except Exception as e:
+                print(f"Warning: Could not display sample evaluation item structure: {e}")
         # Show debug mode status
         if self.config.full_debug:
             print(
@@ -438,50 +564,62 @@ class PairwiseJudgementEnv(BaseEnv):
                 "\n🔍 Full debug mode disabled - Use full_debug=True to enable detailed logging"
             )
-        # Debug: Show sample evaluation item structure
-        if len(self.test) > 0:
-            try:
-                sample_item = self.test[0]
-                print("\nSample eval item structure:")
-                print(f"- Available keys: {list(sample_item.keys())}")
-                # Handle different dataset structures
-                if "prompt" in sample_item:
-                    print(f"- Prompt: {sample_item['prompt'][:100]}...")
-                elif "chosen" in sample_item and isinstance(sample_item["chosen"], str):
-                    print(f"- Chosen (string): {sample_item['chosen'][:100]}...")
-                elif "rejected" in sample_item and isinstance(
-                    sample_item["rejected"], str
-                ):
-                    print(f"- Rejected (string): {sample_item['rejected'][:100]}...")
-                if "chosen" in sample_item:
-                    if isinstance(sample_item["chosen"], list):
-                        print(f"- Chosen responses: {len(sample_item['chosen'])}")
-                        if sample_item["chosen"]:
-                            print(
-                                f"- First chosen (truncated): {sample_item['chosen'][0][:200]}..."
-                            )
-                    else:
-                        print(f"- Chosen (string): {sample_item['chosen'][:200]}...")
-                if "rejected" in sample_item:
-                    if isinstance(sample_item["rejected"], list):
-                        print(f"- Rejected responses: {len(sample_item['rejected'])}")
-                        if sample_item["rejected"]:
-                            print(
-                                f"- First rejected (truncated): {sample_item['rejected'][0][:200]}..."
-                            )
-                    else:
-                        print(
-                            f"- Rejected (string): {sample_item['rejected'][:200]}..."
-                        )
-            except Exception as e:
-                print(f"Warning: Could not display sample item structure: {e}")
         self.iter = 0
+    def _load_dataset(self, dataset_path: str, split: str = None) -> List[Dict]:
+        """
+        Load dataset using HuggingFace load_dataset (supports both HF datasets and local files).
+        Args:
+            dataset_path: Either a HuggingFace dataset name or a path to a local file
+            split: Split to use
+        Returns:
+            List of dataset items
+        """
+        import os
+        try:
+            # Check if it's a local file
+            if os.path.exists(dataset_path):
+                # Local file - use the appropriate loader based on extension
+                if dataset_path.endswith(".jsonl") or dataset_path.endswith(".json"):
+                    dataset = load_dataset("json", data_files=dataset_path, split=split or "train", trust_remote_code=True)
+                elif dataset_path.endswith(".csv"):
+                    dataset = load_dataset("csv", data_files=dataset_path, split=split or "train", trust_remote_code=True)
+                elif dataset_path.endswith(".parquet"):
+                    dataset = load_dataset("parquet", data_files=dataset_path, split=split or "train", trust_remote_code=True)
+                else:
+                    # Fall back to the JSON loader for unknown extensions
+                    dataset = load_dataset("json", data_files=dataset_path, split=split or "train", trust_remote_code=True)
+                print(f"Loaded local dataset from {dataset_path} with {len(dataset)} examples")
+            else:
+                # HuggingFace dataset
+                if split:
+                    dataset = load_dataset(dataset_path, split=split, trust_remote_code=True)
+                else:
+                    dataset_dict = load_dataset(dataset_path, trust_remote_code=True)
+                    # If no split was specified, use the first available split
+                    if hasattr(dataset_dict, "keys"):
+                        available_splits = list(dataset_dict.keys())
+                        if available_splits:
+                            dataset = dataset_dict[available_splits[0]]
+                            print(f"No split specified, using '{available_splits[0]}' split")
+                        else:
+                            dataset = dataset_dict
+                    else:
+                        dataset = dataset_dict
+                print(f"Loaded HuggingFace dataset {dataset_path} with {len(dataset)} examples")
+            return dataset
+        except Exception as e:
+            print(f"Error loading dataset {dataset_path}: {e}")
+            raise
     def save_checkpoint(self, step: int, data: Optional[Dict] = None) -> None:
         """Save checkpoint including iteration state."""
         if data is None:
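To make the two `_load_dataset` branches concrete, here is a standalone sketch; the file name is hypothetical, and `load_dataset` is the same `datasets` library function the method calls:

```python
import json

from datasets import load_dataset

# Local branch: a JSONL file in the uid/category/prompt shape that setup()
# also uses for its fallback data, written and then loaded via the "json" builder.
rows = [
    {"uid": "train_1", "category": "general", "prompt": "What is the capital of France?"},
    {"uid": "train_2", "category": "math", "prompt": "Solve for x: 2x + 5 = 15"},
]
with open("judgement_prompts.jsonl", "w") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")
local_ds = load_dataset("json", data_files="judgement_prompts.jsonl", split="train")

# Hub branch: a named dataset with an explicit split.
hub_ds = load_dataset("allenai/reward-bench-2", split="test", trust_remote_code=True)
```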
@@ -559,81 +697,110 @@ class PairwiseJudgementEnv(BaseEnv):
         return prompt.strip()
     async def get_next_item(self) -> Item:
-        """Generate next training item with synthetic data."""
+        """Generate next training item with prompt-based data."""
         self.iter += 1
-        # Create system message
-        system_content = self._create_system_content()
-        # Create varied placeholder judgment tasks
-        examples = [
-            {
-                "question": "What is the capital of France?",
-                "correct": "The capital of France is Paris, which has been the capital since 987 AD and serves as the political, economic, and cultural center of the country.",  # noqa
-                "incorrect": [
-                    "The capital of France is London.",
-                    "France's capital is Berlin, located in central Europe.",
-                    "I don't know the answer to this question.",
-                    "France doesn't have a capital city.",
-                    "The capital changes every year in France.",
-                    "Paris is just a city, not a capital.",
-                ],
-            },
-            {
-                "question": "How do you fix a memory leak in Python?",
-                "correct": "To fix memory leaks in Python: 1) Use memory profilers like tracemalloc or memory_profiler to identify leaks, 2) Ensure proper cleanup of resources with context managers, 3) Break circular references, 4) Close files and database connections explicitly, and 5) Use weak references when appropriate.",  # noqa
-                "incorrect": [
-                    "Just restart your computer and the memory leak will be fixed.",
-                    "Python automatically handles all memory management, so memory leaks are impossible.",
-                    "You need to reinstall Python to fix memory leaks.",
-                    "Memory leaks don't exist in Python because it's interpreted.",
-                    "Use more RAM to solve memory leaks.",
-                    "Delete the Python installation and use a different language.",
-                ],
-            },
-            {
-                "question": "Explain the difference between machine learning and artificial intelligence.",
-                "correct": "Artificial Intelligence (AI) is the broader field focused on creating systems that can perform tasks typically requiring human intelligence. Machine Learning (ML) is a subset of AI that uses algorithms to learn patterns from data without being explicitly programmed for each task. So ML is one approach to achieving AI.",  # noqa
-                "incorrect": [
-                    "Machine learning and artificial intelligence are exactly the same thing with different names.",
-                    "Machine learning is much broader than AI and includes all computer science.",
-                    "AI is only about robots, while machine learning is only about statistics.",
-                    "Machine learning came before AI historically.",
-                    "AI is a subset of machine learning, not the other way around.",
-                    "There is no difference; they are marketing terms for the same technology.",
-                ],
-            },
-        ]
-        # Select random example
-        example = random.choice(examples)
-        # Create list with correct and incorrect answers, ensuring we have enough
-        incorrect_answers = example["incorrect"][
-            : self.config.num_choices - 1
-        ]  # Take enough incorrect answers
-        all_answers = [example["correct"]] + incorrect_answers
-        # If we don't have enough incorrect answers, pad with generic ones
-        while len(all_answers) < self.config.num_choices:
-            all_answers.append(
-                "I don't have enough information to answer this question."
-            )
-        random.shuffle(all_answers)
-        # Find where correct answer ended up
-        correct_index = all_answers.index(example["correct"])
-        correct_answer = self.choice_letters[correct_index]
-        user_content = self.create_judgment_prompt(example["question"], all_answers)
-        prompt = tuple(
-            [
-                frozenset({"role": "system", "content": system_content}.items()),
-                frozenset({"role": "user", "content": user_content}.items()),
-            ]
-        )
+        # Get the next training example sequentially
+        example = self.train[self.iter % len(self.train)]
+        # Extract the prompt from the training data
+        prompt_text = example.get("prompt", "")
+        if not prompt_text:
+            # Fallback if the prompt field is missing
+            prompt_text = "Please provide a helpful response to this question."
+        # Generate responses of different quality to compare
+        answer_generation_prompt = self._create_system_content()
+        answer_prompt = tuple([
+            frozenset({"role": "system", "content": answer_generation_prompt}.items()),
+            frozenset({"role": "user", "content": prompt_text}.items()),
+        ])
+        # Generate multiple responses with different quality levels for comparison
+        try:
+            high_temp_messages = self._prepare_completion_input(answer_prompt)
+            # High-temperature response (more varied, potentially lower quality)
+            high_temp_completion = await self.server.chat_completion(
+                messages=high_temp_messages,
+                n=1,
+                max_tokens=self.config.train_max_tokens // 2,
+                temperature=1.2,
+            )
+            # Low-temperature response (more conservative, potentially higher quality)
+            low_temp_completion = await self.server.chat_completion(
+                messages=high_temp_messages,
+                n=1,
+                max_tokens=self.config.train_max_tokens // 2,
+                temperature=0.3,
+            )
+            if (high_temp_completion.choices and low_temp_completion.choices
+                    and high_temp_completion.choices[0].message.content
+                    and low_temp_completion.choices[0].message.content):
+                high_temp_answer = high_temp_completion.choices[0].message.content
+                low_temp_answer = low_temp_completion.choices[0].message.content
+                # Create the list of answers to compare
+                answers = [low_temp_answer, high_temp_answer]
+                # Pad with generic answers if we need more choices
+                while len(answers) < self.config.num_choices:
+                    answers.append("I don't have enough information to answer this question thoroughly.")
+                # Take only the number of choices we need
+                answers = answers[:self.config.num_choices]
+                # Randomly shuffle positions
+                random.shuffle(answers)
+                # Find where the low-temperature (assumed better) answer ended up
+                correct_index = answers.index(low_temp_answer)
+                correct_answer = self.choice_letters[correct_index]
+            else:
+                # Fallback if generation returns empty content
+                best_answer = "This is a comprehensive and well-structured response that addresses the question thoroughly with detailed examples and clear explanations."
+                answers = [best_answer, "Brief response."]
+                # Pad to the required number of choices
+                while len(answers) < self.config.num_choices:
+                    answers.append("I don't have sufficient information to provide a complete answer.")
+                answers = answers[:self.config.num_choices]
+                random.shuffle(answers)
+                correct_index = answers.index(best_answer)
+                correct_answer = self.choice_letters[correct_index]
+        except Exception as e:
+            # Fallback if generation fails
+            print(f"Warning: Failed to generate training responses: {e}")
+            best_answer = "This is a comprehensive and detailed response that properly addresses all aspects of the question."
+            answers = [best_answer, "Short answer."]
+            # Pad to the required number of choices
+            while len(answers) < self.config.num_choices:
+                answers.append("Insufficient information provided.")
+            answers = answers[:self.config.num_choices]
+            random.shuffle(answers)
+            correct_index = answers.index(best_answer)
+            correct_answer = self.choice_letters[correct_index]
+        # Create the judgment prompt
+        system_content = self._create_system_content()
+        user_content = self.create_judgment_prompt(prompt_text, answers)
+        prompt = tuple([
+            frozenset({"role": "system", "content": system_content}.items()),
+            frozenset({"role": "user", "content": user_content}.items()),
+        ])
         return (prompt, correct_answer)
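The returned `Item` packs each chat message as a frozenset of dict items, presumably so the prompt tuple stays hashable; a consumer can rebuild plain messages like this (message contents hypothetical):

```python
# Rebuilding chat messages from the frozenset-based prompt tuple.
prompt = (
    frozenset({"role": "system", "content": "You are a pairwise judge."}.items()),
    frozenset({"role": "user", "content": "Which answer is better, A or B?"}.items()),
)
correct_answer = "A"  # hypothetical gold letter

messages = [dict(frozen) for frozen in prompt]
# -> [{"role": "system", ...}, {"role": "user", ...}]
```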
@@ -1958,7 +2125,8 @@ Notes:
             ),
             "train/total_judgments": self.total_judgments,
             "config/group_size": self.config.group_size,
-            "config/max_token_length": self.config.max_token_length,
+            "config/train_max_tokens": self.config.train_max_tokens,
+            "config/eval_max_tokens": self.config.eval_max_tokens,
             "config/num_choices": self.config.num_choices,
         }
     )
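For reference, the renamed keys as they would land in a `wandb.log` call; a sketch that assumes a wandb run is already initialized, with a made-up `num_choices` value:

```python
import wandb

wandb.log(
    {
        "config/train_max_tokens": 1024 * 16,
        "config/eval_max_tokens": 1024 * 16,
        "config/num_choices": 4,  # hypothetical
    }
)
```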