mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
1039c3d360
commit
a2e14cf50c
1 changed files with 155 additions and 73 deletions
|
|
@ -398,13 +398,16 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
"""Set up the environment by loading datasets."""
|
||||
# Load training dataset
|
||||
try:
|
||||
self.train = self._load_dataset(self.config.train_dataset, self.config.train_split)
|
||||
self.train = self._load_dataset(
|
||||
self.config.train_dataset, self.config.train_split
|
||||
)
|
||||
# Shuffle training dataset for reproducibility
|
||||
if hasattr(self.train, 'shuffle'):
|
||||
if hasattr(self.train, "shuffle"):
|
||||
self.train = self.train.shuffle(seed=42)
|
||||
else:
|
||||
# For list-like objects, convert to list and shuffle
|
||||
import random
|
||||
|
||||
train_list = list(self.train)
|
||||
random.seed(42)
|
||||
random.shuffle(train_list)
|
||||
|
|
@ -413,42 +416,60 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
print(f"Error loading training dataset '{self.config.train_dataset}': {e}")
|
||||
# Create minimal fallback data in expected format
|
||||
self.train = [
|
||||
{"uid": "train_1", "category": "general", "prompt": "What is the capital of France?"},
|
||||
{"uid": "train_2", "category": "math", "prompt": "Solve for x: 2x + 5 = 15"},
|
||||
{"uid": "train_3", "category": "coding", "prompt": "Write a Python function to calculate factorial"},
|
||||
{
|
||||
"uid": "train_1",
|
||||
"category": "general",
|
||||
"prompt": "What is the capital of France?",
|
||||
},
|
||||
{
|
||||
"uid": "train_2",
|
||||
"category": "math",
|
||||
"prompt": "Solve for x: 2x + 5 = 15",
|
||||
},
|
||||
{
|
||||
"uid": "train_3",
|
||||
"category": "coding",
|
||||
"prompt": "Write a Python function to calculate factorial",
|
||||
},
|
||||
] * 34 # 102 examples
|
||||
print(f"Using fallback training data with {len(self.train)} examples")
|
||||
|
||||
# Load evaluation dataset
|
||||
try:
|
||||
self.test = self._load_dataset(self.config.eval_dataset, self.config.eval_split)
|
||||
self.test = self._load_dataset(
|
||||
self.config.eval_dataset, self.config.eval_split
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error loading evaluation dataset '{self.config.eval_dataset}': {e}")
|
||||
raise # Evaluation dataset must work
|
||||
|
||||
# Analyze training dataset composition
|
||||
if hasattr(self.train, '__iter__'):
|
||||
if hasattr(self.train, "__iter__"):
|
||||
train_category_counts = {}
|
||||
total_train_items = 0
|
||||
|
||||
|
||||
for item in self.train:
|
||||
total_train_items += 1
|
||||
category = item.get("category", "Unknown")
|
||||
train_category_counts[category] = train_category_counts.get(category, 0) + 1
|
||||
train_category_counts[category] = (
|
||||
train_category_counts.get(category, 0) + 1
|
||||
)
|
||||
|
||||
print(f"\nTraining dataset analysis ({total_train_items} total items):")
|
||||
for category, count in sorted(train_category_counts.items()):
|
||||
print(f" - {category}: {count} samples")
|
||||
|
||||
# Analyze evaluation dataset composition
|
||||
if hasattr(self.test, '__iter__'):
|
||||
if hasattr(self.test, "__iter__"):
|
||||
eval_category_counts = {}
|
||||
total_eval_items = 0
|
||||
|
||||
|
||||
for item in self.test:
|
||||
total_eval_items += 1
|
||||
category = item.get("subset", "Unknown")
|
||||
eval_category_counts[category] = eval_category_counts.get(category, 0) + 1
|
||||
eval_category_counts[category] = (
|
||||
eval_category_counts.get(category, 0) + 1
|
||||
)
|
||||
|
||||
print(f"\nEvaluation dataset analysis ({total_eval_items} total items):")
|
||||
for category, count in sorted(eval_category_counts.items()):
|
||||
|
|
@ -465,7 +486,9 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
if self.config.eval_categories is not None:
|
||||
selected_categories = []
|
||||
for cat in self.config.eval_categories:
|
||||
selected_categories.append(cat.value if hasattr(cat, 'value') else str(cat))
|
||||
selected_categories.append(
|
||||
cat.value if hasattr(cat, "value") else str(cat)
|
||||
)
|
||||
print(
|
||||
f"\nCategory filtering enabled. Selected categories: {selected_categories}"
|
||||
)
|
||||
|
|
@ -480,8 +503,12 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
|
||||
# Show configuration info
|
||||
print(f"\nPairwise Judgement Configuration:")
|
||||
print(f" - Training dataset: {self.config.train_dataset} (split: {self.config.train_split})")
|
||||
print(f" - Evaluation dataset: {self.config.eval_dataset} (split: {self.config.eval_split})")
|
||||
print(
|
||||
f" - Training dataset: {self.config.train_dataset} (split: {self.config.train_split})"
|
||||
)
|
||||
print(
|
||||
f" - Evaluation dataset: {self.config.eval_dataset} (split: {self.config.eval_split})"
|
||||
)
|
||||
print(f" - Thinking mode: {self.config.thinking_mode}")
|
||||
print(f" - Eval temperature: {self.config.eval_temperature}")
|
||||
print(f" - Number of choices: {self.config.num_choices}")
|
||||
|
|
@ -492,7 +519,7 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
sample_train_item = self.train[0]
|
||||
print(f"\nSample training item structure:")
|
||||
print(f"- Available keys: {list(sample_train_item.keys())}")
|
||||
|
||||
|
||||
if "uid" in sample_train_item:
|
||||
print(f"- UID: {sample_train_item['uid']}")
|
||||
if "category" in sample_train_item:
|
||||
|
|
@ -515,12 +542,16 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
# Handle different dataset structures
|
||||
if "prompt" in sample_eval_item:
|
||||
print(f"- Prompt: {sample_eval_item['prompt'][:100]}...")
|
||||
elif "chosen" in sample_eval_item and isinstance(sample_eval_item["chosen"], str):
|
||||
elif "chosen" in sample_eval_item and isinstance(
|
||||
sample_eval_item["chosen"], str
|
||||
):
|
||||
print(f"- Chosen (string): {sample_eval_item['chosen'][:100]}...")
|
||||
elif "rejected" in sample_eval_item and isinstance(
|
||||
sample_eval_item["rejected"], str
|
||||
):
|
||||
print(f"- Rejected (string): {sample_eval_item['rejected'][:100]}...")
|
||||
print(
|
||||
f"- Rejected (string): {sample_eval_item['rejected'][:100]}..."
|
||||
)
|
||||
|
||||
if "chosen" in sample_eval_item:
|
||||
if isinstance(sample_eval_item["chosen"], list):
|
||||
|
|
@ -530,11 +561,15 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
f"- First chosen (truncated): {sample_eval_item['chosen'][0][:200]}..."
|
||||
)
|
||||
else:
|
||||
print(f"- Chosen (string): {sample_eval_item['chosen'][:200]}...")
|
||||
print(
|
||||
f"- Chosen (string): {sample_eval_item['chosen'][:200]}..."
|
||||
)
|
||||
|
||||
if "rejected" in sample_eval_item:
|
||||
if isinstance(sample_eval_item["rejected"], list):
|
||||
print(f"- Rejected responses: {len(sample_eval_item['rejected'])}")
|
||||
print(
|
||||
f"- Rejected responses: {len(sample_eval_item['rejected'])}"
|
||||
)
|
||||
if sample_eval_item["rejected"]:
|
||||
print(
|
||||
f"- First rejected (truncated): {sample_eval_item['rejected'][0][:200]}..."
|
||||
|
|
@ -545,7 +580,9 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not display sample evaluation item structure: {e}")
|
||||
print(
|
||||
f"Warning: Could not display sample evaluation item structure: {e}"
|
||||
)
|
||||
|
||||
# Show debug mode status
|
||||
if self.config.full_debug:
|
||||
|
|
@ -569,53 +606,81 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
def _load_dataset(self, dataset_path: str, split: str = None) -> List[Dict]:
|
||||
"""
|
||||
Load dataset using HuggingFace load_dataset (supports both HF datasets and local files).
|
||||
|
||||
|
||||
Args:
|
||||
dataset_path: Either HuggingFace dataset name or path to local file
|
||||
split: Split to use
|
||||
|
||||
|
||||
Returns:
|
||||
List of dataset items
|
||||
"""
|
||||
import os
|
||||
|
||||
|
||||
try:
|
||||
# Check if it's a local file
|
||||
if os.path.exists(dataset_path):
|
||||
# Local file - use appropriate loader based on extension
|
||||
if dataset_path.endswith('.jsonl') or dataset_path.endswith('.json'):
|
||||
dataset = load_dataset("json", data_files=dataset_path, split=split or "train", trust_remote_code=True)
|
||||
elif dataset_path.endswith('.csv'):
|
||||
dataset = load_dataset("csv", data_files=dataset_path, split=split or "train", trust_remote_code=True)
|
||||
elif dataset_path.endswith('.parquet'):
|
||||
dataset = load_dataset("parquet", data_files=dataset_path, split=split or "train", trust_remote_code=True)
|
||||
if dataset_path.endswith(".jsonl") or dataset_path.endswith(".json"):
|
||||
dataset = load_dataset(
|
||||
"json",
|
||||
data_files=dataset_path,
|
||||
split=split or "train",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
elif dataset_path.endswith(".csv"):
|
||||
dataset = load_dataset(
|
||||
"csv",
|
||||
data_files=dataset_path,
|
||||
split=split or "train",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
elif dataset_path.endswith(".parquet"):
|
||||
dataset = load_dataset(
|
||||
"parquet",
|
||||
data_files=dataset_path,
|
||||
split=split or "train",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
else:
|
||||
# Try JSON as default
|
||||
dataset = load_dataset("json", data_files=dataset_path, split=split or "train", trust_remote_code=True)
|
||||
|
||||
print(f"Loaded local dataset from {dataset_path} with {len(dataset)} examples")
|
||||
|
||||
dataset = load_dataset(
|
||||
"json",
|
||||
data_files=dataset_path,
|
||||
split=split or "train",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
print(
|
||||
f"Loaded local dataset from {dataset_path} with {len(dataset)} examples"
|
||||
)
|
||||
|
||||
else:
|
||||
# HuggingFace dataset
|
||||
if split:
|
||||
dataset = load_dataset(dataset_path, split=split, trust_remote_code=True)
|
||||
dataset = load_dataset(
|
||||
dataset_path, split=split, trust_remote_code=True
|
||||
)
|
||||
else:
|
||||
dataset_dict = load_dataset(dataset_path, trust_remote_code=True)
|
||||
# If no split specified, try to get the first available split
|
||||
if hasattr(dataset_dict, 'keys'):
|
||||
if hasattr(dataset_dict, "keys"):
|
||||
available_splits = list(dataset_dict.keys())
|
||||
if available_splits:
|
||||
dataset = dataset_dict[available_splits[0]]
|
||||
print(f"No split specified, using '{available_splits[0]}' split")
|
||||
print(
|
||||
f"No split specified, using '{available_splits[0]}' split"
|
||||
)
|
||||
else:
|
||||
dataset = dataset_dict
|
||||
else:
|
||||
dataset = dataset_dict
|
||||
|
||||
print(f"Loaded HuggingFace dataset {dataset_path} with {len(dataset)} examples")
|
||||
|
||||
|
||||
print(
|
||||
f"Loaded HuggingFace dataset {dataset_path} with {len(dataset)} examples"
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading dataset {dataset_path}: {e}")
|
||||
raise
|
||||
|
|
@ -711,17 +776,21 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
|
||||
# Generate different quality responses to compare
|
||||
answer_generation_prompt = self._create_system_content()
|
||||
|
||||
answer_prompt = tuple([
|
||||
frozenset({"role": "system", "content": answer_generation_prompt}.items()),
|
||||
frozenset({"role": "user", "content": prompt_text}.items()),
|
||||
])
|
||||
|
||||
answer_prompt = tuple(
|
||||
[
|
||||
frozenset(
|
||||
{"role": "system", "content": answer_generation_prompt}.items()
|
||||
),
|
||||
frozenset({"role": "user", "content": prompt_text}.items()),
|
||||
]
|
||||
)
|
||||
|
||||
# Generate multiple responses with different quality levels for comparison
|
||||
try:
|
||||
# Generate responses with different parameters to get varied quality
|
||||
high_temp_messages = self._prepare_completion_input(answer_prompt)
|
||||
|
||||
|
||||
# High temperature response (more creative/varied, potentially lower quality)
|
||||
high_temp_completion = await self.server.chat_completion(
|
||||
messages=high_temp_messages,
|
||||
|
|
@ -729,7 +798,7 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
max_tokens=self.config.train_max_tokens // 2,
|
||||
temperature=1.2,
|
||||
)
|
||||
|
||||
|
||||
# Low temperature response (more conservative, potentially higher quality)
|
||||
low_temp_completion = await self.server.chat_completion(
|
||||
messages=high_temp_messages,
|
||||
|
|
@ -737,31 +806,36 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
max_tokens=self.config.train_max_tokens // 2,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
if (high_temp_completion.choices and low_temp_completion.choices and
|
||||
high_temp_completion.choices[0].message.content and
|
||||
low_temp_completion.choices[0].message.content):
|
||||
|
||||
|
||||
if (
|
||||
high_temp_completion.choices
|
||||
and low_temp_completion.choices
|
||||
and high_temp_completion.choices[0].message.content
|
||||
and low_temp_completion.choices[0].message.content
|
||||
):
|
||||
|
||||
high_temp_answer = high_temp_completion.choices[0].message.content
|
||||
low_temp_answer = low_temp_completion.choices[0].message.content
|
||||
|
||||
|
||||
# Create list of answers for comparison
|
||||
answers = [low_temp_answer, high_temp_answer]
|
||||
|
||||
|
||||
# Pad with generic answers if we need more choices
|
||||
while len(answers) < self.config.num_choices:
|
||||
answers.append("I don't have enough information to answer this question thoroughly.")
|
||||
|
||||
answers.append(
|
||||
"I don't have enough information to answer this question thoroughly."
|
||||
)
|
||||
|
||||
# Take only the number of choices we need
|
||||
answers = answers[:self.config.num_choices]
|
||||
|
||||
answers = answers[: self.config.num_choices]
|
||||
|
||||
# Randomly shuffle positions
|
||||
random.shuffle(answers)
|
||||
|
||||
|
||||
# Find where the low temp (better) answer ended up
|
||||
correct_index = answers.index(low_temp_answer)
|
||||
correct_answer = self.choice_letters[correct_index]
|
||||
|
||||
|
||||
else:
|
||||
# Fallback if generation fails
|
||||
answers = [
|
||||
|
|
@ -770,13 +844,17 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
]
|
||||
# Pad to required number of choices
|
||||
while len(answers) < self.config.num_choices:
|
||||
answers.append("I don't have sufficient information to provide a complete answer.")
|
||||
|
||||
answers = answers[:self.config.num_choices]
|
||||
answers.append(
|
||||
"I don't have sufficient information to provide a complete answer."
|
||||
)
|
||||
|
||||
answers = answers[: self.config.num_choices]
|
||||
random.shuffle(answers)
|
||||
correct_index = answers.index("This is a comprehensive and well-structured response that addresses the question thoroughly with detailed examples and clear explanations.")
|
||||
correct_index = answers.index(
|
||||
"This is a comprehensive and well-structured response that addresses the question thoroughly with detailed examples and clear explanations."
|
||||
)
|
||||
correct_answer = self.choice_letters[correct_index]
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# Fallback if generation fails
|
||||
print(f"Warning: Failed to generate training responses: {e}")
|
||||
|
|
@ -787,20 +865,24 @@ class PairwiseJudgementEnv(BaseEnv):
|
|||
# Pad to required number of choices
|
||||
while len(answers) < self.config.num_choices:
|
||||
answers.append("Insufficient information provided.")
|
||||
|
||||
answers = answers[:self.config.num_choices]
|
||||
|
||||
answers = answers[: self.config.num_choices]
|
||||
random.shuffle(answers)
|
||||
correct_index = answers.index("This is a comprehensive and detailed response that properly addresses all aspects of the question.")
|
||||
correct_index = answers.index(
|
||||
"This is a comprehensive and detailed response that properly addresses all aspects of the question."
|
||||
)
|
||||
correct_answer = self.choice_letters[correct_index]
|
||||
|
||||
# Create judgment prompt
|
||||
system_content = self._create_system_content()
|
||||
user_content = self.create_judgment_prompt(prompt_text, answers)
|
||||
|
||||
prompt = tuple([
|
||||
frozenset({"role": "system", "content": system_content}.items()),
|
||||
frozenset({"role": "user", "content": user_content}.items()),
|
||||
])
|
||||
prompt = tuple(
|
||||
[
|
||||
frozenset({"role": "system", "content": system_content}.items()),
|
||||
frozenset({"role": "user", "content": user_content}.items()),
|
||||
]
|
||||
)
|
||||
|
||||
return (prompt, correct_answer)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue