# Mirror of https://github.com/NousResearch/atropos.git
# Synced 2026-04-25 17:10:42 +00:00
# 1043 lines, 41 KiB, Python
import json
|
|
import math
|
|
import random
|
|
import re
|
|
import time
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
import wandb
|
|
from datasets import load_dataset
|
|
from pydantic import Field
|
|
from tqdm.asyncio import tqdm_asyncio
|
|
|
|
from atroposlib.envs.base import (
|
|
APIServerConfig,
|
|
BaseEnv,
|
|
BaseEnvConfig,
|
|
EvalHandlingEnum,
|
|
Item,
|
|
ScoredDataGroup,
|
|
)
|
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|
|
|
|
|
class PairwiseJudgementConfig(BaseEnvConfig):
    """Environment settings for PairwiseJudgementEnv.

    Adds judge-specific options on top of BaseEnvConfig: thinking mode, the
    number of answer choices, optional prompt overrides, and separate sampling
    settings (temperature / max tokens) for training rollouts and evaluation.
    """

    # Switches the <think></think> response format on or off.
    thinking_mode: bool = Field(
        description="Whether to enable thinking mode with <think></think> tags.",
        default=False,
    )

    # One letter (A, B, C, ...) is assigned per choice, hence the 2..26 bounds.
    num_choices: int = Field(
        description="Number of choices for pairwise judgment (2-26, corresponding to A-Z).",
        default=4,
        ge=2,
        le=26,
    )

    # Optional overrides for the two system prompts.
    custom_thinking_prompt: Optional[str] = Field(
        description="Custom thinking prompt. If None, uses the default thinking prompt.",
        default=None,
    )

    custom_judgment_prompt: Optional[str] = Field(
        description="Custom judgment prompt. If None, uses the default judgment prompt.",
        default=None,
    )

    # Sampling temperatures.
    eval_temperature: float = Field(
        description="Temperature for evaluation completions.",
        default=0.6,
    )

    rollout_temperature: float = Field(
        description="Temperature for training rollout completions.",
        default=0.8,
    )

    # Completion length limits (16Ki tokens by default).
    eval_max_tokens: int = Field(
        description="Maximum tokens for evaluation completions.",
        default=16 * 1024,
    )

    train_max_tokens: int = Field(
        description="Maximum tokens for training completions.",
        default=16 * 1024,
    )
|
|
|
|
|
|
class PairwiseJudgementEnv(BaseEnv):
|
|
name = "pairwise_judgement"
|
|
env_config_cls = PairwiseJudgementConfig
|
|
|
|
def __init__(
    self,
    config: PairwiseJudgementConfig,
    server_configs: List[APIServerConfig],
    slurm=True,
    testing=False,
):
    """Initialise metric trackers, compiled regexes and system prompts.

    Args:
        config: Environment configuration.
        server_configs: Inference server configurations, forwarded to the
            base class.
        slurm: Forwarded unchanged to the base class.
        testing: Forwarded unchanged to the base class.
    """
    super().__init__(config, server_configs, slurm, testing)
    self.config: PairwiseJudgementConfig = config

    # Choice labels "A", "B", ... - one per configured choice.
    first_letter = ord("A")
    self.choice_letters = [
        chr(first_letter + offset) for offset in range(self.config.num_choices)
    ]

    # Running training/eval statistics.
    self.percent_correct_buffer = []
    self.eval_metrics = []
    self.rollouts_for_wandb = []
    self.judgment_letter_counts = dict.fromkeys(self.choice_letters, 0)
    self.judgment_letter_correct = dict.fromkeys(self.choice_letters, 0)
    self.error_count = 0  # Failed to follow format
    self.total_judgments = 0

    # Regexes compiled once so they are not rebuilt on every response.
    dotall = re.DOTALL
    self._think_pattern = re.compile(r"<think>")
    self._think_close_pattern = re.compile(r"</think>")
    self._think_content_pattern = re.compile(r"</think>\s*(.*)", dotall)
    self._thinking_extract_pattern = re.compile(r"<think>(.*?)</think>", dotall)
    self._question_pattern = re.compile(
        r"\[User Question\]\s*(.*?)\s*\[The Start of Assistant A", dotall
    )

    # Per-letter regexes: the "[[X]]" verdict marker and the answer section.
    self._choice_patterns = {}
    self._answer_choice_patterns = {}
    for letter in self.choice_letters:
        self._choice_patterns[letter] = re.compile(rf"\[\[{letter}\]\]")
        self._answer_choice_patterns[letter] = re.compile(
            rf"\[The Start of Assistant {letter}\'s Answer\]\s*(.*?)\s*\[The End of Assistant {letter}\'s Answer\]",
            dotall,
        )

    # System prompts (config overrides win over the built-in defaults).
    self.thinking_system_prompt = self._get_thinking_prompt()
    self.judgment_system_prompt = self._get_judgment_prompt()
|
|
|
|
def _get_thinking_prompt(self) -> str:
    """Return the thinking-mode system prompt.

    A non-empty ``custom_thinking_prompt`` from the config takes precedence;
    otherwise the built-in default prompt is returned.
    """
    override = self.config.custom_thinking_prompt
    if override:
        return override

    return (
        "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
        "problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
        "solution prior to answering. You should enclose your thoughts and internal monologue inside <think> "
        "</think> tags, and then provide your solution or response to the problem."
    )
|
|
|
|
def _get_judgment_prompt(self) -> str:
    """Return the judge system prompt.

    A non-empty ``custom_judgment_prompt`` from the config is returned as-is.
    Otherwise the default prompt is assembled, mentioning the configured
    number of assistants and listing one ``"[[X]]"`` verdict format per
    choice letter.
    """
    override = self.config.custom_judgment_prompt
    if override:
        return override

    count = self.config.num_choices
    verdict_formats = ", ".join(
        f'"[[{letter}]]" if assistant {letter} is best'
        for letter in self.choice_letters
    )

    sentences = [
        f"Please act as an impartial judge and evaluate the quality of the responses provided by {count} AI assistants to the user question displayed below. ",
        "You should choose the assistant that follows the user's instructions and answers the user's question best. ",
        "Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. ",
        f"Begin your evaluation by comparing the {count} responses and provide a short explanation. ",
        "Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. ",
        "Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. ",
        "Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: ",
        f"{verdict_formats}.",
    ]
    return "".join(sentences)
|
|
|
|
def _reset_metrics(self) -> None:
    """Clear the accumulated training metrics.

    Empties the percent-correct buffer, zeroes the per-letter counters and
    resets the error and total-judgment counts.
    """
    letters = self.choice_letters
    self.percent_correct_buffer = []
    # Two independent dicts: one for picks, one for correct picks.
    self.judgment_letter_counts = dict.fromkeys(letters, 0)
    self.judgment_letter_correct = dict.fromkeys(letters, 0)
    self.error_count = 0
    self.total_judgments = 0
|
|
|
|
def _convert_messages_to_list(self, prompt_tuple: Tuple) -> List[Dict]:
    """Turn a tuple of frozenset-encoded messages into a list of dicts.

    Each element of ``prompt_tuple`` is passed to ``dict()``, so it must be
    an iterable of key/value pairs (e.g. ``frozenset(message.items())``).
    """
    return [dict(encoded_message) for encoded_message in prompt_tuple]
|
|
|
|
def _create_system_content(self) -> str:
    """Build the system message text.

    In thinking mode the thinking prompt is placed before the judgment
    prompt, separated by a blank line; otherwise only the judgment prompt is
    used.
    """
    sections = [self.judgment_system_prompt]
    if self.config.thinking_mode:
        sections.insert(0, self.thinking_system_prompt)
    return "\n\n".join(sections)
|
|
|
|
def _prepare_completion_input(self, prompt_tuple: Tuple) -> Tuple[List[Dict], str]:
    """Return the chat messages and the rendered prompt text for a prompt.

    The messages are decoded with ``_convert_messages_to_list`` and rendered
    to a string via the tokenizer's chat template, with the generation
    prompt appended and tokenization disabled.
    """
    chat = self._convert_messages_to_list(prompt_tuple)
    rendered = self.tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    return chat, rendered
|
|
|
|
def _get_train_completion_params(self) -> Dict:
    """Return the sampling keyword arguments used for training rollouts.

    One completion per group member is requested, using the training token
    limit and the rollout temperature from the config.
    """
    cfg = self.config
    return dict(
        n=cfg.group_size,
        max_tokens=cfg.train_max_tokens,
        temperature=cfg.rollout_temperature,
    )
|
|
|
|
def _get_eval_completion_params(self) -> Dict:
    """Return the sampling keyword arguments used for evaluation.

    A single completion is requested with the eval token limit and eval
    temperature from the config, tagged with ``split="eval"``.
    """
    cfg = self.config
    return dict(
        n=1,
        max_tokens=cfg.eval_max_tokens,
        temperature=cfg.eval_temperature,
        split="eval",
    )
|
|
|
|
@classmethod
def config_init(cls) -> Tuple[PairwiseJudgementConfig, List[APIServerConfig]]:
    """Build the default environment config and inference-server list.

    Returns:
        A ``(env_config, server_configs)`` pair: a PairwiseJudgementConfig
        with thinking mode enabled, and a one-element list holding the
        APIServerConfig for a local endpoint.
    """
    context_tokens = 16 * 1024

    default_env = PairwiseJudgementConfig(
        tokenizer_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
        group_size=16,
        use_wandb=True,
        max_num_workers_per_node=16,
        rollout_server_url="http://localhost:8000",
        total_steps=2000,
        batch_size=1024,
        steps_per_eval=25,
        max_token_length=context_tokens,
        inference_weight=1.0,
        wandb_name="pairwise_judgment",
        eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
        eval_limit_ratio=0.1,
        min_batch_allocation=0.1,
        thinking_mode=True,
    )

    inference_server = APIServerConfig(
        model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
        base_url="http://localhost:9004/v1",
        api_key="x",
        num_max_requests_at_once=32,
        num_requests_for_eval=256,
    )

    return default_env, [inference_server]
|
|
|
|
async def setup(self) -> None:
    """Load the train placeholder and the evaluation dataset.

    The train dataset is best-effort: if loading "example/train" fails, a
    list of 100 identical placeholder rows is used instead. The evaluation
    dataset ("allenai/reward-bench-2", test split) is loaded without a
    guard, so a failure there propagates. A summary of the first eval row is
    printed for debugging, and the iteration counter is reset to 0.
    """
    # Placeholder train data; get_next_item() generates its own examples.
    try:
        self.train = load_dataset("example/train", split="train")
    except Exception as err:
        self.train = [{"question": "What is 2+2?", "answer": "4"}] * 100
        print(f"Using synthetic placeholder training data due to error: {err}")
    else:
        print(f"Loaded placeholder train dataset with {len(self.train)} examples")

    # Evaluation data is mandatory: no fallback here.
    self.test = load_dataset(
        "allenai/reward-bench-2", split="test", trust_remote_code=True
    )
    print(f"Loaded reward-bench-2 eval dataset with {len(self.test)} examples")

    # Debug output only; any failure while printing is reported and ignored.
    if len(self.test) > 0:
        try:
            first_row = self.test[0]
            print("\nSample eval item structure:")
            print(f"- Available keys: {list(first_row.keys())}")

            # Short preview of whichever text field is present first.
            if "prompt" in first_row:
                print(f"- Prompt: {first_row['prompt'][:100]}...")
            elif "chosen" in first_row and isinstance(first_row["chosen"], str):
                print(f"- Chosen (string): {first_row['chosen'][:100]}...")
            elif "rejected" in first_row and isinstance(first_row["rejected"], str):
                print(f"- Rejected (string): {first_row['rejected'][:100]}...")

            if "chosen" in first_row:
                chosen_field = first_row["chosen"]
                if not isinstance(chosen_field, list):
                    print(f"- Chosen (string): {first_row['chosen'][:200]}...")
                else:
                    print(f"- Chosen responses: {len(first_row['chosen'])}")
                    if chosen_field:
                        print(
                            f"- First chosen (truncated): {first_row['chosen'][0][:200]}..."
                        )

            if "rejected" in first_row:
                rejected_field = first_row["rejected"]
                if not isinstance(rejected_field, list):
                    print(f"- Rejected (string): {first_row['rejected'][:200]}...")
                else:
                    print(f"- Rejected responses: {len(first_row['rejected'])}")
                    if rejected_field:
                        print(
                            f"- First rejected (truncated): {first_row['rejected'][0][:200]}..."
                        )

        except Exception as err:
            print(f"Warning: Could not display sample item structure: {err}")

    self.iter = 0
|
|
|
|
def save_checkpoint(self, step: int, data: Optional[Dict] = None) -> None:
    """Persist a checkpoint that also records the iteration counter.

    Args:
        step: Step number, forwarded to the base implementation.
        data: Optional dict of extra state. When provided it is updated in
            place with an ``"iter"`` entry before being forwarded.
    """
    payload = {} if data is None else data
    payload["iter"] = self.iter
    super().save_checkpoint(step, payload)
|
|
|
|
def process_judgement(self, judgment: str, track_metrics: bool = True) -> str:
    """Extract the verdict letter from a model response.

    In thinking mode the response must contain exactly one ``<think>`` tag
    followed by exactly one ``</think>`` tag; only the text after
    ``</think>`` is searched for a verdict. A verdict is a ``[[X]]`` marker
    for one of the configured choice letters. The response is accepted only
    when the markers present name a single letter - a response naming
    several different letters is ambiguous and is treated as a format error
    rather than being resolved in favour of the alphabetically first letter.

    Args:
        judgment: Raw model response text.
        track_metrics: When True, update ``total_judgments``,
            ``error_count`` and ``judgment_letter_counts``.

    Returns:
        The chosen letter, or ``"format_error"`` when the response does not
        follow the required format.
    """

    def _format_error() -> str:
        # Every rejection path records exactly one error when tracking.
        if track_metrics:
            self.error_count += 1
        return "format_error"

    # Each processed response counts once, whatever the outcome.
    if track_metrics:
        self.total_judgments += 1

    if self.config.thinking_mode:
        open_tags = list(self._think_pattern.finditer(judgment))
        close_tags = list(self._think_close_pattern.finditer(judgment))
        if len(open_tags) != 1 or len(close_tags) != 1:
            return _format_error()
        # The single opening tag must precede the single closing tag.
        if open_tags[0].start() > close_tags[0].start():
            return _format_error()
        # Only the text after </think> may carry the verdict.
        judgment = judgment[close_tags[0].end():]

    named_letters = [
        letter
        for letter in self.choice_letters
        if self._choice_patterns[letter].search(judgment)
    ]
    if len(named_letters) != 1:
        # Either no verdict at all, or conflicting verdicts.
        return _format_error()

    verdict = named_letters[0]
    if track_metrics:
        self.judgment_letter_counts[verdict] += 1
    return verdict
|
|
|
|
def create_judgment_prompt(self, question: str, answers: List[str]) -> str:
    """Format the user message shown to the judge.

    The message starts with a ``[User Question]`` section followed by one
    delimited answer section per assistant letter, in the given order.

    Args:
        question: The user question.
        answers: Candidate answers; must contain exactly
            ``config.num_choices`` entries.

    Returns:
        The formatted prompt with surrounding whitespace stripped.

    Raises:
        ValueError: If the number of answers differs from ``num_choices``.
    """
    if len(answers) != self.config.num_choices:
        raise ValueError(
            f"Need exactly {self.config.num_choices} answers for judgment, got {len(answers)}"
        )

    sections = [f"[User Question]\n{question}\n\n"]
    for position in range(len(answers)):
        letter = self.choice_letters[position]
        answer = answers[position]
        sections.append(
            f"[The Start of Assistant {letter}'s Answer]\n{answer}\n[The End of Assistant {letter}'s Answer]\n\n"
        )

    return "".join(sections).strip()
|
|
|
|
async def get_next_item(self) -> Item:
    """Produce the next training item from built-in placeholder examples.

    One of three hard-coded question sets is picked at random. Its correct
    answer is combined with up to ``num_choices - 1`` incorrect answers
    (padded with a generic answer when there are too few), the candidates
    are shuffled, and the letter at which the correct answer landed becomes
    the ground truth.

    Returns:
        A ``(prompt, correct_letter)`` pair, where ``prompt`` is a tuple of
        two frozenset-encoded messages (system, user).
    """
    self.iter += 1

    system_content = self._create_system_content()

    # Placeholder judgment tasks.
    examples = [
        {
            "question": "What is the capital of France?",
            "correct": "The capital of France is Paris, which has been the capital since 987 AD and serves as the political, economic, and cultural center of the country.",
            "incorrect": [
                "The capital of France is London.",
                "France's capital is Berlin, located in central Europe.",
                "I don't know the answer to this question.",
                "France doesn't have a capital city.",
                "The capital changes every year in France.",
                "Paris is just a city, not a capital.",
            ],
        },
        {
            "question": "How do you fix a memory leak in Python?",
            "correct": "To fix memory leaks in Python: 1) Use memory profilers like tracemalloc or memory_profiler to identify leaks, 2) Ensure proper cleanup of resources with context managers, 3) Break circular references, 4) Close files and database connections explicitly, and 5) Use weak references when appropriate.",
            "incorrect": [
                "Just restart your computer and the memory leak will be fixed.",
                "Python automatically handles all memory management, so memory leaks are impossible.",
                "You need to reinstall Python to fix memory leaks.",
                "Memory leaks don't exist in Python because it's interpreted.",
                "Use more RAM to solve memory leaks.",
                "Delete the Python installation and use a different language.",
            ],
        },
        {
            "question": "Explain the difference between machine learning and artificial intelligence.",
            "correct": "Artificial Intelligence (AI) is the broader field focused on creating systems that can perform tasks typically requiring human intelligence. Machine Learning (ML) is a subset of AI that uses algorithms to learn patterns from data without being explicitly programmed for each task. So ML is one approach to achieving AI.",
            "incorrect": [
                "Machine learning and artificial intelligence are exactly the same thing with different names.",
                "Machine learning is much broader than AI and includes all computer science.",
                "AI is only about robots, while machine learning is only about statistics.",
                "Machine learning came before AI historically.",
                "AI is a subset of machine learning, not the other way around.",
                "There is no difference; they are marketing terms for the same technology.",
            ],
        },
    ]

    picked = random.choice(examples)
    gold = picked["correct"]

    # Correct answer first, then as many distractors as the config asks for.
    wanted = self.config.num_choices
    candidates = [gold, *picked["incorrect"][: wanted - 1]]

    # Top up with a generic answer when the example has too few distractors.
    shortfall = wanted - len(candidates)
    if shortfall > 0:
        candidates.extend(
            ["I don't have enough information to answer this question."] * shortfall
        )

    random.shuffle(candidates)

    # Letter at which the correct answer ended up after shuffling.
    gold_letter = self.choice_letters[candidates.index(gold)]

    user_content = self.create_judgment_prompt(picked["question"], candidates)

    prompt = tuple(
        frozenset({"role": role, "content": text}.items())
        for role, text in (("system", system_content), ("user", user_content))
    )

    return (prompt, gold_letter)
|
|
|
|
def prepare_eval_item(self, item: dict) -> Tuple[Optional[Tuple], Optional[str]]:
    """Build a judgment prompt from one evaluation row.

    The row is expected to provide a ``prompt`` plus ``chosen`` and
    ``rejected`` responses. ``chosen[0]`` and the first ``num_choices - 1``
    rejected responses are shuffled into the answer slots. A ``chosen`` or
    ``rejected`` field given as a single string is treated as a one-element
    list, so it is never indexed or sliced character by character.

    Args:
        item: Evaluation row (a mapping supporting ``.get``).

    Returns:
        ``(prompt, correct_letter)`` on success, where ``prompt`` is a tuple
        of two frozenset-encoded messages (system, user). ``(None, None)``
        when the row is unusable: missing question, too few responses, an
        empty response, or any error while preparing it.
    """
    try:
        question = item.get("prompt", "")
        if not question:
            return None, None

        # Normalise both fields to lists; None becomes an empty list.
        chosen_responses = item.get("chosen") or []
        rejected_responses = item.get("rejected") or []
        if isinstance(chosen_responses, str):
            chosen_responses = [chosen_responses]
        if isinstance(rejected_responses, str):
            rejected_responses = [rejected_responses]

        # One chosen response plus (num_choices - 1) rejected responses.
        required_rejected = self.config.num_choices - 1
        if not chosen_responses or len(rejected_responses) < required_rejected:
            return None, None

        chosen = chosen_responses[0]
        rejected = rejected_responses[:required_rejected]

        # Every candidate must have non-empty content.
        if not chosen or not all(rejected):
            return None, None

        # Shuffle (answer, is_correct) pairs so the flag follows its answer.
        candidates = [(chosen, True)] + [(text, False) for text in rejected]
        random.shuffle(candidates)

        shuffled_answers = [text for text, _ in candidates]
        correct_index = next(
            index for index, (_, is_correct) in enumerate(candidates) if is_correct
        )
        correct_answer = self.choice_letters[correct_index]

        system_content = self._create_system_content()
        user_content = self.create_judgment_prompt(question, shuffled_answers)

        prompt = tuple(
            [
                frozenset({"role": "system", "content": system_content}.items()),
                frozenset({"role": "user", "content": user_content}.items()),
            ]
        )
        return prompt, correct_answer

    except Exception as e:
        # Best-effort by design: a bad row is reported and skipped.
        print(f"Error preparing evaluation item: {e}")
        return None, None
|
|
|
|
async def collect_trajectories(self, item: Item) -> Tuple[ScoredDataGroup, List]:
    """Sample a group of completions for one item and score them.

    Args:
        item: ``(prompt_tuple, ground_truth)`` as produced by
            ``get_next_item``.

    Returns:
        ``(scored_data, [])``. ``scored_data`` is whatever ``score``
        returned and may be None; the second element is always an empty
        list.
    """
    base_messages, prompt_text = self._prepare_completion_input(item[0])

    response = await self.server.completion(
        prompt=prompt_text, **self._get_train_completion_params()
    )

    # One trajectory per sampled completion: shared prompt messages plus
    # that completion as the assistant turn, paired with the ground truth.
    to_score = [
        (
            tuple([*base_messages, {"role": "assistant", "content": choice.text}]),
            item[1],
        )
        for choice in response.choices
    ]

    scored = await self.score(to_score)

    # Only scored groups are kept for wandb visualisation.
    if scored is not None:
        await self.add_rollouts_for_wandb(scored, item)

    return scored, []
|
|
|
|
async def score(self, rollout_group_data: List[Tuple]) -> Optional[ScoredDataGroup]:
    """Score a group of rollouts against their ground-truth letters.

    Each rollout is ``(messages, ground_truth)``; the last message holds the
    model response. A rollout earns 1.0 when the extracted verdict equals
    the ground truth and 0.0 otherwise. The input list is shuffled in place.

    Returns:
        A ScoredDataGroup with ``tokens``, ``masks`` and ``scores``, or None
        when the input is empty, nothing survives filtering, all scores are
        identical, or an error occurs.
    """
    if not rollout_group_data:
        return None

    try:
        group = ScoredDataGroup()
        group["tokens"] = []
        group["masks"] = []
        group["scores"] = []

        random.shuffle(rollout_group_data)

        for rollout in rollout_group_data:
            # Skip entries that are not a (non-empty messages, truth) pair.
            if not rollout or len(rollout) < 2 or not rollout[0]:
                continue

            conversation = rollout[0]
            model_response = conversation[-1]["content"]
            ground_truth = rollout[1]

            verdict = self.process_judgement(model_response, track_metrics=True)
            is_correct = verdict == ground_truth

            # Per-letter accuracy bookkeeping.
            if is_correct and verdict != "format_error":
                self.judgment_letter_correct[verdict] += 1

            tokenized = tokenize_for_trainer(self.tokenizer, conversation)
            tokens, masks = tokenized["tokens"], tokenized["masks"]

            # Drop rollouts with fewer than 10 unmasked positions.
            if sum(1 for mask in masks if mask != -100) < 10:
                continue

            group["tokens"].append(tokens)
            group["masks"].append(masks)
            group["scores"].append(1.0 if is_correct else 0.0)

            if len(group["tokens"]) >= self.config.group_size:
                break

        if not group["tokens"]:
            return None

        self.percent_correct_buffer.extend(max(value, 0) for value in group["scores"])

        # Identical scores across the group carry no learning signal.
        if len(set(group["scores"])) == 1:
            return None

        return group

    except Exception as e:
        print(f"Error in score method: {e}")
        return None
|
|
|
|
async def rollout_and_score_eval(self, test_item: dict) -> dict:
    """Run one evaluation row through the judge and collect a sample record.

    Args:
        test_item: Evaluation row; ``id`` and ``subset`` are read from it
            when present.

    Returns:
        ``{"score": float, "sample": dict | None}``. The score is 1.0 when
        the extracted verdict equals the ground truth, else 0.0. ``sample``
        is None when the row could not be prepared, no completion came back,
        or an error occurred.
    """
    failed = {"score": 0.0, "sample": None}
    try:
        prompt, ground_truth = self.prepare_eval_item(test_item)
        if prompt is None:
            return failed

        messages, prompt_text = self._prepare_completion_input(prompt)
        completion = await self.server.completion(
            prompt=prompt_text, **self._get_eval_completion_params()
        )
        if not completion.choices:
            return failed

        first_choice = completion.choices[0]
        model_response = first_choice.text
        # Eval runs do not touch the training counters.
        predicted_answer = self.process_judgement(model_response, track_metrics=False)
        score = 1.0 if predicted_answer == ground_truth else 0.0

        # Recover the question and each answer from the user message.
        user_content = messages[1]["content"]
        found_question = self._question_pattern.search(user_content)
        if found_question:
            question = found_question.group(1).strip()
        else:
            question = "Unknown question"

        answer_choices = {}
        for letter in self.choice_letters:
            found_answer = self._answer_choice_patterns[letter].search(user_content)
            if found_answer:
                answer_choices[letter] = found_answer.group(1).strip()

        sample = {
            "messages": messages + [{"role": "assistant", "content": model_response}],
            "question": question,
            "answer_choices": answer_choices,
            "ground_truth": ground_truth,
            "predicted_judgment": predicted_answer,
            "score": int(score),
            "correct": bool(score),
            "finish_reason": first_choice.finish_reason,
            "thinking_mode": self.config.thinking_mode,
            "format_compliant": predicted_answer != "format_error",
            "dataset_item_id": test_item.get("id", "unknown"),
            "dataset_subset": test_item.get("subset", "unknown"),
            "num_choices": self.config.num_choices,
        }

        # Thinking mode: also record the reasoning and the post-think text.
        if self.config.thinking_mode:
            if "</think>" not in model_response:
                sample["response_after_think"] = model_response
                sample["thinking_content"] = None
            else:
                sample["response_after_think"] = model_response.rsplit(
                    "</think>", 1
                )[-1].strip()
                found_thinking = self._thinking_extract_pattern.search(model_response)
                sample["thinking_content"] = (
                    found_thinking.group(1).strip() if found_thinking else None
                )

        return {"score": score, "sample": sample}

    except Exception as e:
        print(f"Error in rollout_and_score_eval: {e}")
        return {"score": 0.0, "sample": None}
|
|
|
|
def _calculate_response_metrics(
    self, samples: List[dict], thinking_mode_used: bool
) -> Tuple[List[int], int, Dict[str, int]]:
    """Summarise response lengths, thinking usage and verdict distribution.

    Args:
        samples: Sample dicts; falsy entries are skipped.
        thinking_mode_used: Whether to count samples with thinking content.

    Returns:
        ``(response_lengths, thinking_utilization, judgment_counts)``:
        the character length of each sample's last message, the number of
        samples with non-empty ``thinking_content`` (always 0 when thinking
        mode is off), and a count per choice letter plus ``"format_error"``.
    """
    lengths = []
    thinking_hits = 0
    tally = dict.fromkeys(self.choice_letters, 0)
    tally["format_error"] = 0

    for sample in samples:
        if not sample:
            continue

        # Length of the final message (the assistant response).
        conversation = sample.get("messages", [])
        if conversation:
            lengths.append(len(conversation[-1].get("content", "")))

        if thinking_mode_used and sample.get("thinking_content"):
            thinking_hits += 1

        # Verdicts outside the known labels are not counted.
        verdict = sample.get("predicted_judgment", "format_error")
        if verdict in tally:
            tally[verdict] += 1

    return lengths, thinking_hits, tally
|
|
|
|
async def evaluate(self, *args, **kwargs) -> None:
    """Evaluate on the test dataset and record the resulting metrics.

    Every row of ``self.test`` is run through ``rollout_and_score_eval``.
    Results without a sample are discarded. Accuracy (overall and per
    dataset subset), response-length statistics, thinking utilisation,
    verdict-distribution statistics and format compliance are appended to
    ``self.eval_metrics`` and passed to ``evaluate_log`` together with the
    samples. Failures are printed rather than raised.
    """
    start_time = time.time()

    try:
        pending = [self.rollout_and_score_eval(row) for row in self.test]
        outcomes = await tqdm_asyncio.gather(*pending)

        # Keep only results that carry a sample.
        valid_results = []
        for outcome in outcomes:
            if isinstance(outcome, Exception) or not outcome:
                continue
            if outcome.get("sample") is not None:
                valid_results.append(outcome)

        if not valid_results:
            print("Warning: No valid evaluation results obtained")
            return

    except Exception as e:
        print(f"Error during evaluation: {e}")
        return

    scores = [outcome["score"] for outcome in valid_results]
    samples = [outcome["sample"] for outcome in valid_results]
    valid_scores = [value for value in scores if value is not None]

    if not valid_scores:
        print("Warning: No valid scores found during evaluation")
        return

    percent_correct = sum(valid_scores) / len(valid_scores)
    self.eval_metrics.append(("eval/percent_correct", percent_correct))

    # Group scores by dataset subset, then average each group once.
    subset_scores = {}
    for sample, sample_score in zip(samples, scores):
        if sample:
            subset_scores.setdefault(
                sample.get("dataset_subset", "unknown"), []
            ).append(sample_score)

    subset_averages = {}
    for subset, members in subset_scores.items():
        present = [value for value in members if value is not None]
        if present:
            subset_averages[subset] = sum(present) / len(present)
            self.eval_metrics.append(
                (f"eval/percent_correct_{subset}", subset_averages[subset])
            )

    format_compliant = sum(
        1 for sample in samples if sample.get("format_compliant", False)
    )
    compliance_rate = format_compliant / len(samples) if samples else 0.0
    thinking_mode_used = self.config.thinking_mode

    response_lengths, thinking_utilization, judgment_counts = (
        self._calculate_response_metrics(samples, thinking_mode_used)
    )

    # Metrics dict handed to evaluate_log; the headline numbers go first.
    eval_metrics = {
        "eval/percent_correct": percent_correct,
        "eval/total_samples": len(samples),
        "eval/correct_samples": sum(valid_scores),
        "eval/format_compliance_rate": compliance_rate,
    }

    def record(name, value):
        # Store a metric in both the running list and the logging dict.
        self.eval_metrics.append((name, value))
        eval_metrics[name] = value

    # Response length: mean and population standard deviation.
    if response_lengths:
        avg_response_length = sum(response_lengths) / len(response_lengths)
        response_length_std = (
            sum((x - avg_response_length) ** 2 for x in response_lengths)
            / len(response_lengths)
        ) ** 0.5
        record("eval/avg_response_length", avg_response_length)
        record("eval/response_length_std", response_length_std)

    # Fraction of samples that contained thinking content.
    if thinking_mode_used and samples:
        record("eval/thinking_utilization_rate", thinking_utilization / len(samples))

    # Verdict distribution: entropy, dominant share and format-error share.
    total_judgments = sum(judgment_counts.values())
    if total_judgments > 0:
        entropy = 0.0
        for count in judgment_counts.values():
            if count > 0:
                freq = count / total_judgments
                entropy -= freq * math.log(freq)
        record("eval/judgment_entropy", entropy)
        record(
            "eval/most_common_judgment_freq",
            max(judgment_counts.values()) / total_judgments,
        )
        record(
            "eval/format_error_rate",
            judgment_counts["format_error"] / total_judgments,
        )

    # Overall dataset statistics (running list only).
    self.eval_metrics.append(("eval/total_items", len(self.test)))
    self.eval_metrics.append(("eval/valid_scores", len(valid_scores)))
    self.eval_metrics.append(("eval/subset_count", len(subset_scores)))
    self.eval_metrics.append(("eval/format_compliance_rate", compliance_rate))

    end_time = time.time()

    # Per-subset accuracy comes last in the logging dict.
    for subset, avg_score in subset_averages.items():
        eval_metrics[f"eval/percent_correct_{subset}"] = avg_score

    try:
        await self.evaluate_log(
            metrics=eval_metrics,
            samples=samples,
            start_time=start_time,
            end_time=end_time,
            generation_parameters={
                "temperature": self.config.eval_temperature,
                "max_tokens": self.config.eval_max_tokens,
                "thinking_mode": thinking_mode_used,
            },
        )
    except Exception as e:
        print(f"Error logging evaluation results: {e}")
|
|
|
|
async def add_rollouts_for_wandb(
    self,
    scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
    item: Optional[Item] = None,
) -> None:
    """Buffer a sample of one scored group's rollouts for the wandb table.

    For up to ``num_rollouts_per_group_for_logging`` rollouts (or
    ``group_size`` when that setting is ``-1``), a 6-tuple
    ``(full_text, score, ground_truth, predicted_judgment, question_info,
    mode)`` is built and the list of tuples is appended to
    ``self.rollouts_for_wandb``. When the buffer then exceeds
    ``num_rollouts_to_keep`` groups, the oldest group is dropped.

    Args:
        scored_data: Scored group, read through its ``"tokens"``,
            ``"scores"`` and optional ``"messages"`` entries.
            NOTE(review): the annotation also admits a list of groups, but
            the body calls ``.get`` on it, so only a single mapping-like
            group is supported here -- confirm callers never pass a list.
        item: The item the group was generated from; ``item[0]`` is
            iterated as the prompt messages and ``item[1]`` is logged as
            the ground truth.

    Returns:
        None. Nothing is buffered when ``item`` is ``None``, ``scored_data``
        is ``None``, or the group has no tokens.
    """
    if item is None or scored_data is None or not scored_data.get("tokens"):
        return

    ground_truth = item[1]

    # Best-effort extraction of the question text from the first user message
    # of the prompt; any failure only degrades the logged label.
    question_info = "unknown_question"
    try:
        # item[0] contains the prompt tuple with system and user messages.
        for role_dict in item[0]:
            role_dict_converted = dict(role_dict)
            if role_dict_converted.get("role") == "user":
                user_content = role_dict_converted.get("content", "")
                question_match = self._question_pattern.search(user_content)
                if question_match:
                    question_info = question_match.group(1).strip()
                break
    except Exception:
        # Fallback to placeholder if extraction fails.
        question_info = "extraction_failed"

    # Keep a reasonable number of rollouts (-1 means "the whole group"),
    # never more than the group actually contains.
    num_keep = self.config.num_rollouts_per_group_for_logging
    if num_keep == -1:
        num_keep = self.config.group_size
    num_keep = min(num_keep, len(scored_data["tokens"]))

    mode = "thinking" if self.config.thinking_mode else "direct"

    # "messages" may be missing *or* explicitly None; both mean "no chat
    # transcripts", in which case the decoded token text is parsed instead.
    # (A plain .get("messages", []) would hand back None and turn every
    # rollout into a spurious "parse_error".)
    messages = scored_data.get("messages") or []

    current_rollouts = []
    for i in range(num_keep):
        # Decode the full trajectory.
        full_text = self.tokenizer.decode(
            scored_data["tokens"][i], skip_special_tokens=True
        )
        score_val = scored_data["scores"][i]

        # Extract the model's judgment, preferring the last chat message of
        # this rollout and falling back to the decoded text.
        predicted_judgment = "unknown"
        try:
            if i < len(messages) and isinstance(messages[i], list) and messages[i]:
                model_response = messages[i][-1].get("content", "")
            else:
                model_response = full_text

            # track_metrics=False is passed so that logging a rollout is not
            # counted as a training judgment.
            predicted_judgment = self.process_judgement(
                model_response, track_metrics=False
            )
        except Exception:
            predicted_judgment = "parse_error"

        current_rollouts.append(
            (
                full_text,
                score_val,
                ground_truth,
                predicted_judgment,
                question_info,
                mode,
            )
        )

    self.rollouts_for_wandb.append(current_rollouts)

    # Keep only recent rollouts: one group is added per call, so dropping a
    # single oldest group is enough to respect the cap.
    if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
        self.rollouts_for_wandb.pop(0)
|
|
|
|
async def create_rollout_table(self, wandb_metrics: Dict) -> Dict:
    """Attach the buffered rollouts to ``wandb_metrics`` as a wandb table.

    When ``self.rollouts_for_wandb`` is empty the metrics dict is returned
    untouched. Otherwise every buffered 6-field rollout becomes one row of a
    ``wandb.Table`` stored under ``"train/rollouts"``, and the buffer is
    cleared afterwards.

    Args:
        wandb_metrics: Metrics dict to extend in place.

    Returns:
        The same ``wandb_metrics`` dict that was passed in.
    """
    buffered_groups = self.rollouts_for_wandb
    if not buffered_groups:
        return wandb_metrics

    rollout_table = wandb.Table(
        columns=[
            "full_text",
            "score",
            "ground_truth",
            "predicted_judgment",
            "question_info",
            "mode",
        ]
    )

    for group in buffered_groups:
        for row in group:
            # Rows that do not have exactly one value per column are skipped.
            if len(row) != 6:
                continue
            rollout_table.add_data(*row)

    wandb_metrics["train/rollouts"] = rollout_table
    # Start a fresh buffer so the next log call only shows new rollouts.
    self.rollouts_for_wandb = []
    return wandb_metrics
|
|
|
|
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
    """Collect training, config and eval metrics and pass them to the parent logger.

    Adds, when data is available: mean of ``percent_correct_buffer``,
    per-letter judgment frequency/accuracy with entropy, and error /
    format-compliance rates. Config values and any pending
    ``self.eval_metrics`` pairs are always merged in. Training counters are
    reset through ``self._reset_metrics()``, the rollout table is attached,
    and the result is forwarded to ``super().wandb_log``.

    Args:
        wandb_metrics: Optional dict of metrics to extend; a new dict is
            created when ``None``.
    """
    if wandb_metrics is None:
        wandb_metrics = {}

    # Basic accuracy metrics
    correct_buffer = self.percent_correct_buffer
    if correct_buffer:
        wandb_metrics["train/percent_correct"] = sum(correct_buffer) / len(
            correct_buffer
        )

    # Judgment letter distribution and accuracy
    total_letters = sum(self.judgment_letter_counts.values())
    if total_letters > 0:
        entropy = 0.0
        for letter in self.choice_letters:
            times_chosen = self.judgment_letter_counts[letter]
            times_correct = self.judgment_letter_correct[letter]

            share = times_chosen / total_letters
            wandb_metrics[f"train/judgment_freq_{letter}"] = share

            if times_chosen > 0:
                letter_accuracy = times_correct / times_chosen
            else:
                letter_accuracy = 0.0
            wandb_metrics[f"train/judgment_acc_{letter}"] = letter_accuracy

            # Letters that were never chosen contribute nothing to entropy
            # (and log(0) must be avoided).
            if share > 0:
                entropy -= share * math.log(share)

        wandb_metrics["train/judgment_entropy"] = entropy
        # Normalized entropy
        wandb_metrics["train/judgment_balance"] = entropy / math.log(
            self.config.num_choices
        )

    # Error rate and other metrics
    judged_total = self.total_judgments
    if judged_total > 0:
        error_rate = self.error_count / judged_total
        wandb_metrics["train/error_rate"] = error_rate
        wandb_metrics["train/format_compliance_rate"] = 1.0 - error_rate

    # Configuration and mode metrics (captured before the counters are reset)
    config_snapshot = {
        "train/thinking_mode_enabled": 1.0 if self.config.thinking_mode else 0.0,
        "train/total_judgments": self.total_judgments,
        "config/group_size": self.config.group_size,
        "config/max_token_length": self.config.max_token_length,
        "config/num_choices": self.config.num_choices,
    }
    wandb_metrics.update(config_snapshot)

    # Reset training metrics
    self._reset_metrics()

    # Add evaluation metrics, then clear them so they are logged only once
    for eval_key, eval_value in self.eval_metrics:
        wandb_metrics[eval_key] = eval_value
    self.eval_metrics = []

    # Add rollout table
    wandb_metrics = await self.create_rollout_table(wandb_metrics)

    await super().wandb_log(wandb_metrics)
|
|
|
|
|
|
# Script entry point: only when this file is executed directly (not imported)
# is PairwiseJudgementEnv.cli() invoked.
if __name__ == "__main__":
    PairwiseJudgementEnv.cli()
|