# Source: atropos/environments/pairwise_judgement_environment.py
# Snapshot: 2025-07-14 09:43:20 +00:00 (1043 lines, 41 KiB, Python)

import json
import math
import random
import re
import time
from typing import Dict, List, Optional, Tuple, Union
import wandb
from datasets import load_dataset
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
Item,
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class PairwiseJudgementConfig(BaseEnvConfig):
    """Configuration for PairwiseJudgementEnv with thinking mode and configurable options.

    All numeric options are range-checked at construction time so that an
    invalid sampling temperature or token budget is rejected when the config
    is built rather than later, when a completion request is issued.
    """

    # --- Prompting -------------------------------------------------------
    thinking_mode: bool = Field(
        default=False,
        description="Whether to enable thinking mode with <think></think> tags.",
    )
    # Choices are labelled with consecutive capital letters, hence the 26 cap.
    num_choices: int = Field(
        default=4,
        ge=2,
        le=26,
        description="Number of choices for pairwise judgment (2-26, corresponding to A-Z).",
    )
    custom_thinking_prompt: Optional[str] = Field(
        default=None,
        description="Custom thinking prompt. If None, uses the default thinking prompt.",
    )
    custom_judgment_prompt: Optional[str] = Field(
        default=None,
        description="Custom judgment prompt. If None, uses the default judgment prompt.",
    )

    # --- Sampling --------------------------------------------------------
    # A negative temperature is never meaningful; 0.0 (greedy) stays allowed.
    eval_temperature: float = Field(
        default=0.6,
        ge=0.0,
        description="Temperature for evaluation completions.",
    )
    rollout_temperature: float = Field(
        default=0.8,
        ge=0.0,
        description="Temperature for training rollout completions.",
    )

    # --- Generation budgets ----------------------------------------------
    # These are forwarded as `max_tokens`, so they must be strictly positive.
    eval_max_tokens: int = Field(
        default=1024 * 16,
        gt=0,
        description="Maximum tokens for evaluation completions.",
    )
    train_max_tokens: int = Field(
        default=1024 * 16,
        gt=0,
        description="Maximum tokens for training completions.",
    )
class PairwiseJudgementEnv(BaseEnv):
    """Environment in which a model judges which of several assistant answers is best.

    The model is shown a user question followed by ``num_choices`` lettered
    answers and must emit its verdict as ``[[X]]``.
    """

    name = "pairwise_judgement"
    env_config_cls = PairwiseJudgementConfig

    def __init__(
        self,
        config: PairwiseJudgementConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        """Initialise metric counters, compiled regexes and system prompts.

        Args:
            config: Environment configuration.
            server_configs: Inference server descriptions, forwarded to the base class.
            slurm: Forwarded unchanged to the base class.
            testing: Forwarded unchanged to the base class.
        """
        super().__init__(config, server_configs, slurm, testing)
        self.config: PairwiseJudgementConfig = config

        # Buffers flushed by wandb_log().
        self.percent_correct_buffer = []
        self.eval_metrics = []
        self.rollouts_for_wandb = []

        # One capital letter per answer slot: "A", "B", "C", ...
        first_letter = ord("A")
        self.choice_letters = [
            chr(first_letter + offset) for offset in range(self.config.num_choices)
        ]

        # Per-letter tallies plus global format-error bookkeeping.
        self.judgment_letter_counts = dict.fromkeys(self.choice_letters, 0)
        self.judgment_letter_correct = dict.fromkeys(self.choice_letters, 0)
        self.error_count = 0  # responses that did not follow the output format
        self.total_judgments = 0

        # Regexes are compiled once here because they run on every response.
        self._think_pattern = re.compile(r"<think>")
        self._think_close_pattern = re.compile(r"</think>")
        self._think_content_pattern = re.compile(r"</think>\s*(.*)", re.DOTALL)
        self._thinking_extract_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
        self._question_pattern = re.compile(
            r"\[User Question\]\s*(.*?)\s*\[The Start of Assistant A", re.DOTALL
        )

        # Letter-specific patterns: the "[[X]]" verdict and the answer body.
        self._choice_patterns = {}
        self._answer_choice_patterns = {}
        for letter in self.choice_letters:
            self._choice_patterns[letter] = re.compile(rf"\[\[{letter}\]\]")
            self._answer_choice_patterns[letter] = re.compile(
                rf"\[The Start of Assistant {letter}\'s Answer\]\s*(.*?)\s*\[The End of Assistant {letter}\'s Answer\]",
                re.DOTALL,
            )

        # Prompts depend on choice_letters, so they are resolved last.
        self.thinking_system_prompt = self._get_thinking_prompt()
        self.judgment_system_prompt = self._get_judgment_prompt()
def _get_thinking_prompt(self) -> str:
"""Get thinking system prompt."""
return (
self.config.custom_thinking_prompt
if self.config.custom_thinking_prompt
else "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
"problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
"solution prior to answering. You should enclose your thoughts and internal monologue inside <think> "
"</think> tags, and then provide your solution or response to the problem."
)
def _get_judgment_prompt(self) -> str:
"""Get judgment system prompt."""
if self.config.custom_judgment_prompt:
return self.config.custom_judgment_prompt
choice_format_examples = ", ".join(
[
f'"[[{letter}]]" if assistant {letter} is best'
for letter in self.choice_letters
]
)
return (
f"Please act as an impartial judge and evaluate the quality of the responses provided by {self.config.num_choices} AI assistants to the user question displayed below. "
f"You should choose the assistant that follows the user's instructions and answers the user's question best. "
f"Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. "
f"Begin your evaluation by comparing the {self.config.num_choices} responses and provide a short explanation. "
f"Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. "
f"Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. "
f"Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: "
f"{choice_format_examples}."
)
def _reset_metrics(self) -> None:
"""Reset training metrics."""
self.percent_correct_buffer = []
self.judgment_letter_counts = {letter: 0 for letter in self.choice_letters}
self.judgment_letter_correct = {letter: 0 for letter in self.choice_letters}
self.error_count = 0
self.total_judgments = 0
def _convert_messages_to_list(self, prompt_tuple: Tuple) -> List[Dict]:
"""Convert frozenset message format to list format."""
messages = []
for role_dict in prompt_tuple:
messages.append(dict(role_dict))
return messages
def _create_system_content(self) -> str:
"""Create system message content based on thinking mode."""
if self.config.thinking_mode:
return f"{self.thinking_system_prompt}\n\n{self.judgment_system_prompt}"
return self.judgment_system_prompt
def _prepare_completion_input(self, prompt_tuple: Tuple) -> Tuple[List[Dict], str]:
"""Convert prompt tuple to messages and formatted prompt text."""
messages = self._convert_messages_to_list(prompt_tuple)
prompt_text = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
return messages, prompt_text
def _get_train_completion_params(self) -> Dict:
"""Get completion parameters for training rollouts."""
return {
"n": self.config.group_size,
"max_tokens": self.config.train_max_tokens,
"temperature": self.config.rollout_temperature,
}
def _get_eval_completion_params(self) -> Dict:
"""Get completion parameters for evaluation."""
return {
"n": 1,
"max_tokens": self.config.eval_max_tokens,
"temperature": self.config.eval_temperature,
"split": "eval",
}
@classmethod
def config_init(cls) -> Tuple[PairwiseJudgementConfig, List[APIServerConfig]]:
    """Return the default environment config and inference server list.

    Returns:
        ``(env_config, server_configs)`` with thinking mode enabled and a
        single locally hosted server.
    """
    # NOTE(review): tokenizer_name and the server model_name below refer to
    # different checkpoints - confirm that pairing is intended.
    env_settings = dict(
        tokenizer_name="NousResearch/Hermes-4-Qwen3-14B-1-e3",
        group_size=16,
        use_wandb=True,
        max_num_workers_per_node=16,
        rollout_server_url="http://localhost:8000",
        total_steps=2000,
        batch_size=1024,
        steps_per_eval=25,
        max_token_length=1024 * 16,
        inference_weight=1.0,
        wandb_name="pairwise_judgment",
        eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
        eval_limit_ratio=0.1,
        min_batch_allocation=0.1,
        thinking_mode=True,
    )
    server_settings = dict(
        model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
        base_url="http://localhost:9004/v1",
        api_key="x",
        num_max_requests_at_once=32,
        num_requests_for_eval=256,
    )
    env_config = PairwiseJudgementConfig(**env_settings)
    server_configs = [APIServerConfig(**server_settings)]
    return env_config, server_configs
async def setup(self) -> None:
    """Load the train and eval datasets and reset the iteration counter.

    The train split is only a placeholder (training items are synthesised in
    ``get_next_item``), so a failure to load it is tolerated. The eval split
    is loaded without a guard: any error there propagates to the caller.
    """
    # Placeholder train data; fall back to a tiny in-memory list on any error.
    try:
        self.train = load_dataset("example/train", split="train")
        print(f"Loaded placeholder train dataset with {len(self.train)} examples")
    except Exception as e:
        self.train = [{"question": "What is 2+2?", "answer": "4"}] * 100
        print(f"Using synthetic placeholder training data due to error: {e}")

    # Evaluation data is mandatory - deliberately no try/except here.
    self.test = load_dataset(
        "allenai/reward-bench-2", split="test", trust_remote_code=True
    )
    print(f"Loaded reward-bench-2 eval dataset with {len(self.test)} examples")

    # Best-effort debug dump of the first eval record; never fatal.
    if len(self.test) > 0:
        try:
            first = self.test[0]
            print("\nSample eval item structure:")
            print(f"- Available keys: {list(first.keys())}")

            # Short preview of the first textual field that is available.
            if "prompt" in first:
                print(f"- Prompt: {first['prompt'][:100]}...")
            elif "chosen" in first and isinstance(first["chosen"], str):
                print(f"- Chosen (string): {first['chosen'][:100]}...")
            elif "rejected" in first and isinstance(first["rejected"], str):
                print(f"- Rejected (string): {first['rejected'][:100]}...")

            # "chosen" may be a list of responses or a single string.
            if "chosen" in first:
                chosen = first["chosen"]
                if not isinstance(chosen, list):
                    print(f"- Chosen (string): {chosen[:200]}...")
                else:
                    print(f"- Chosen responses: {len(chosen)}")
                    if chosen:
                        print(f"- First chosen (truncated): {chosen[0][:200]}...")

            # Same treatment for "rejected".
            if "rejected" in first:
                rejected = first["rejected"]
                if not isinstance(rejected, list):
                    print(f"- Rejected (string): {rejected[:200]}...")
                else:
                    print(f"- Rejected responses: {len(rejected)}")
                    if rejected:
                        print(
                            f"- First rejected (truncated): {rejected[0][:200]}..."
                        )
        except Exception as e:
            print(f"Warning: Could not display sample item structure: {e}")

    self.iter = 0
def save_checkpoint(self, step: int, data: Optional[Dict] = None) -> None:
    """Persist a checkpoint that also records the current iteration counter.

    Args:
        step: Step identifier forwarded to the base implementation.
        data: Optional payload; an ``"iter"`` entry is written into it (the
            passed dict itself is updated) before it is handed to the base class.
    """
    payload = {} if data is None else data
    payload["iter"] = self.iter
    super().save_checkpoint(step, payload)
def process_judgement(self, judgment: str, track_metrics: bool = True) -> str:
    """Extract the verdict letter from a model response.

    In thinking mode the response must contain exactly one ``<think>`` and
    one ``</think>`` tag, in that order; only the text after ``</think>`` is
    inspected for a verdict. The verdict itself must be unambiguous: exactly
    one distinct choice letter may appear in ``[[X]]`` form (the same letter
    may be repeated). A response naming several different letters is a format
    error, so listing every option cannot be scored as a valid pick.

    Args:
        judgment: Raw model response text.
        track_metrics: When true, update ``total_judgments``, ``error_count``
            and ``judgment_letter_counts`` on this instance.

    Returns:
        The chosen letter, or the string ``"format_error"``.
    """
    if self.config.thinking_mode:
        open_tags = list(self._think_pattern.finditer(judgment))
        close_tags = list(self._think_close_pattern.finditer(judgment))
        well_formed = (
            len(open_tags) == 1
            and len(close_tags) == 1
            # A closing tag that precedes the opening tag is not a valid pair.
            and open_tags[0].start() < close_tags[0].start()
        )
        after_think = (
            self._think_content_pattern.search(judgment) if well_formed else None
        )
        if after_think is None:
            if track_metrics:
                self.error_count += 1
                self.total_judgments += 1
            return "format_error"
        # Verdicts mentioned while thinking must not count.
        judgment = after_think.group(1)

    if track_metrics:
        self.total_judgments += 1

    # Collect every distinct letter the response commits to.
    named_letters = [
        letter
        for letter in self.choice_letters
        if self._choice_patterns[letter].search(judgment)
    ]
    if len(named_letters) == 1:
        verdict = named_letters[0]
        if track_metrics:
            self.judgment_letter_counts[verdict] += 1
        return verdict

    # Either no verdict at all or several conflicting ones.
    if track_metrics:
        self.error_count += 1
    return "format_error"
def create_judgment_prompt(self, question: str, answers: List[str]) -> str:
    """Render the user message: the question followed by each lettered answer.

    Args:
        question: The user question being judged.
        answers: Candidate answers, already in presentation order.

    Returns:
        The prompt text with surrounding whitespace stripped.

    Raises:
        ValueError: If ``answers`` does not contain exactly ``num_choices`` items.
    """
    expected = self.config.num_choices
    if len(answers) != expected:
        raise ValueError(
            f"Need exactly {self.config.num_choices} answers for judgment, got {len(answers)}"
        )

    sections = [f"[User Question]\n{question}\n\n"]
    for position, answer in enumerate(answers):
        letter = self.choice_letters[position]
        sections.append(
            f"[The Start of Assistant {letter}'s Answer]\n{answer}\n[The End of Assistant {letter}'s Answer]\n\n"
        )
    return "".join(sections).strip()
async def get_next_item(self) -> Item:
    """Produce the next training item from a small pool of built-in examples.

    One example is picked at random, its correct answer is mixed with
    ``num_choices - 1`` wrong answers (padded with a generic filler when the
    example has too few), and the result is shuffled.

    Returns:
        ``(prompt, correct_letter)`` where ``prompt`` is a tuple of
        frozenset-encoded system and user messages.
    """
    self.iter += 1
    system_content = self._create_system_content()

    # Placeholder judgment tasks.
    examples = [
        {
            "question": "What is the capital of France?",
            "correct": "The capital of France is Paris, which has been the capital since 987 AD and serves as the political, economic, and cultural center of the country.",
            "incorrect": [
                "The capital of France is London.",
                "France's capital is Berlin, located in central Europe.",
                "I don't know the answer to this question.",
                "France doesn't have a capital city.",
                "The capital changes every year in France.",
                "Paris is just a city, not a capital.",
            ],
        },
        {
            "question": "How do you fix a memory leak in Python?",
            "correct": "To fix memory leaks in Python: 1) Use memory profilers like tracemalloc or memory_profiler to identify leaks, 2) Ensure proper cleanup of resources with context managers, 3) Break circular references, 4) Close files and database connections explicitly, and 5) Use weak references when appropriate.",
            "incorrect": [
                "Just restart your computer and the memory leak will be fixed.",
                "Python automatically handles all memory management, so memory leaks are impossible.",
                "You need to reinstall Python to fix memory leaks.",
                "Memory leaks don't exist in Python because it's interpreted.",
                "Use more RAM to solve memory leaks.",
                "Delete the Python installation and use a different language.",
            ],
        },
        {
            "question": "Explain the difference between machine learning and artificial intelligence.",
            "correct": "Artificial Intelligence (AI) is the broader field focused on creating systems that can perform tasks typically requiring human intelligence. Machine Learning (ML) is a subset of AI that uses algorithms to learn patterns from data without being explicitly programmed for each task. So ML is one approach to achieving AI.",
            "incorrect": [
                "Machine learning and artificial intelligence are exactly the same thing with different names.",
                "Machine learning is much broader than AI and includes all computer science.",
                "AI is only about robots, while machine learning is only about statistics.",
                "Machine learning came before AI historically.",
                "AI is a subset of machine learning, not the other way around.",
                "There is no difference; they are marketing terms for the same technology.",
            ],
        },
    ]
    example = random.choice(examples)

    wanted = self.config.num_choices
    right_answer = example["correct"]
    candidates = [right_answer, *example["incorrect"][: wanted - 1]]
    # Top up with a generic non-answer when the example has too few distractors.
    shortfall = wanted - len(candidates)
    candidates.extend(
        ["I don't have enough information to answer this question."] * shortfall
    )

    random.shuffle(candidates)
    # The letter at the position where the correct answer landed is the label.
    correct_answer = self.choice_letters[candidates.index(right_answer)]

    user_content = self.create_judgment_prompt(example["question"], candidates)
    prompt = tuple(
        frozenset({"role": role, "content": content}.items())
        for role, content in (("system", system_content), ("user", user_content))
    )
    return (prompt, correct_answer)
def prepare_eval_item(self, item: dict) -> Tuple[Optional[Tuple], Optional[str]]:
    """Turn one evaluation record into a judgment prompt and its correct letter.

    The record is read through the keys ``"prompt"``, ``"chosen"`` and
    ``"rejected"``. The first chosen response plus the first
    ``num_choices - 1`` rejected responses are shuffled and lettered.

    Returns:
        ``(prompt, correct_letter)``, or ``(None, None)`` when the record is
        missing a question, has too few responses, contains an empty
        response, or any exception occurs while preparing it.
    """
    try:
        question = item.get("prompt", "")
        chosen_pool = item.get("chosen", [])
        rejected_pool = item.get("rejected", [])

        if not question:
            return None, None

        needed_rejected = self.config.num_choices - 1
        if len(chosen_pool) == 0 or len(rejected_pool) < needed_rejected:
            return None, None

        best = chosen_pool[0]
        distractors = rejected_pool[:needed_rejected]
        if not best or not all(distractors):
            return None, None

        # Slot 0 holds the correct answer; shuffle positions and see where it lands.
        candidates = [best, *distractors]
        order = list(range(len(candidates)))
        random.shuffle(order)
        shuffled_answers = [candidates[source] for source in order]
        correct_answer = self.choice_letters[order.index(0)]

        system_content = self._create_system_content()
        user_content = self.create_judgment_prompt(question, shuffled_answers)
        prompt = tuple(
            frozenset({"role": role, "content": content}.items())
            for role, content in (
                ("system", system_content),
                ("user", user_content),
            )
        )
        return prompt, correct_answer
    except Exception as e:
        print(f"Error preparing evaluation item: {e}")
        return None, None
async def collect_trajectories(self, item: Item) -> Tuple[ScoredDataGroup, List]:
    """Sample a group of completions for one item and score them.

    Args:
        item: ``(prompt_tuple, correct_letter)`` as produced by ``get_next_item``.

    Returns:
        ``(scored_data, [])``; ``scored_data`` is whatever ``score`` returned,
        which may be ``None``.
    """
    messages, prompt_text = self._prepare_completion_input(item[0])
    completions = await self.server.completion(
        prompt=prompt_text, **self._get_train_completion_params()
    )

    # Every trajectory shares the prompt messages and differs only in the reply.
    to_score = [
        (
            tuple([*messages, {"role": "assistant", "content": choice.text}]),
            item[1],
        )
        for choice in completions.choices
    ]
    scored_data = await self.score(to_score)

    # Keep a sample of the rollouts for the wandb table.
    if scored_data is not None:
        await self.add_rollouts_for_wandb(scored_data, item)
    return scored_data, []
async def score(self, rollout_group_data: List[Tuple]) -> Optional[ScoredDataGroup]:
    """Score a group of rollouts with a 1.0 / 0.0 reward for a correct verdict.

    Args:
        rollout_group_data: ``(messages, correct_letter)`` pairs; the last
            message of each is the model response. The list is shuffled in place.

    Returns:
        A ``ScoredDataGroup`` with ``tokens``, ``masks`` and ``scores``, or
        ``None`` when the input is empty, nothing usable remains, all rewards
        are identical, or an exception occurs.
    """
    if not rollout_group_data:
        return None

    try:
        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []

        random.shuffle(rollout_group_data)
        for rollout in rollout_group_data:
            # Skip malformed entries.
            if not rollout or len(rollout) < 2 or not rollout[0]:
                continue

            conversation, ground_truth = rollout[0], rollout[1]
            predicted_answer = self.process_judgement(
                conversation[-1]["content"], track_metrics=True
            )
            is_correct = predicted_answer == ground_truth
            reward = 1.0 if is_correct else 0.0

            # Per-letter accuracy only counts real (non-error) verdicts.
            if is_correct and predicted_answer != "format_error":
                self.judgment_letter_correct[predicted_answer] += 1

            tokenized = tokenize_for_trainer(self.tokenizer, conversation)
            tokens, masks = tokenized["tokens"], tokenized["masks"]

            # Fewer than 10 unmasked positions: drop the example.
            unmasked = sum(1 for mask in masks if mask != -100)
            if unmasked < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(reward)

            if len(scores["tokens"]) >= self.config.group_size:
                break

        if not scores["tokens"]:
            return None

        self.percent_correct_buffer.extend(
            max(value, 0) for value in scores["scores"]
        )

        # Identical rewards across the group carry no learning signal.
        distinct_rewards = set(scores["scores"])
        if len(distinct_rewards) == 1:
            return None
        return scores
    except Exception as e:
        print(f"Error in score method: {e}")
        return None
async def rollout_and_score_eval(self, test_item: dict) -> dict:
    """Run one evaluation item through the model and collect a detailed sample.

    Returns:
        ``{"score": float, "sample": dict}`` on success. When the item cannot
        be prepared, the server returns no choices, or any exception occurs,
        the result is ``{"score": 0.0, "sample": None}``.
    """
    try:
        prompt, ground_truth = self.prepare_eval_item(test_item)
        if prompt is None:
            return {"score": 0.0, "sample": None}

        messages, prompt_text = self._prepare_completion_input(prompt)
        completion = await self.server.completion(
            prompt=prompt_text, **self._get_eval_completion_params()
        )
        if not completion.choices:
            return {"score": 0.0, "sample": None}

        first_choice = completion.choices[0]
        model_response = first_choice.text
        predicted_answer = self.process_judgement(
            model_response, track_metrics=False
        )
        score = 1.0 if predicted_answer == ground_truth else 0.0

        # Recover the question and the lettered answers from the user message.
        user_content = messages[1]["content"]
        question = "Unknown question"
        question_match = self._question_pattern.search(user_content)
        if question_match:
            question = question_match.group(1).strip()

        answer_choices = {}
        for letter in self.choice_letters:
            found = self._answer_choice_patterns[letter].search(user_content)
            if found:
                answer_choices[letter] = found.group(1).strip()

        sample = {
            "messages": messages
            + [{"role": "assistant", "content": model_response}],
            "question": question,
            "answer_choices": answer_choices,
            "ground_truth": ground_truth,
            "predicted_judgment": predicted_answer,
            "score": int(score),
            "correct": bool(score),
            "finish_reason": first_choice.finish_reason,
            "thinking_mode": self.config.thinking_mode,
            "format_compliant": predicted_answer != "format_error",
            "dataset_item_id": test_item.get("id", "unknown"),
            "dataset_subset": test_item.get("subset", "unknown"),
            "num_choices": self.config.num_choices,
        }

        # Thinking mode: also record the reasoning and the post-reasoning text.
        if self.config.thinking_mode:
            if "</think>" not in model_response:
                sample["response_after_think"] = model_response
                sample["thinking_content"] = None
            else:
                # Text after the last closing tag.
                sample["response_after_think"] = model_response.rpartition(
                    "</think>"
                )[2].strip()
                thought = self._thinking_extract_pattern.search(model_response)
                sample["thinking_content"] = (
                    thought.group(1).strip() if thought else None
                )

        return {"score": score, "sample": sample}
    except Exception as e:
        print(f"Error in rollout_and_score_eval: {e}")
        return {"score": 0.0, "sample": None}
def _calculate_response_metrics(
self, samples: List[dict], thinking_mode_used: bool
) -> Tuple[List[int], int, Dict[str, int]]:
"""Calculate response-related metrics from samples."""
response_lengths = []
thinking_utilization = 0
judgment_counts = {letter: 0 for letter in self.choice_letters}
judgment_counts["format_error"] = 0
for sample in samples:
if not sample:
continue
# Track response length
messages = sample.get("messages", [])
if messages:
assistant_msg = messages[-1].get("content", "")
response_lengths.append(len(assistant_msg))
# Track thinking utilization in thinking mode
if thinking_mode_used:
thinking_content = sample.get("thinking_content")
if thinking_content:
thinking_utilization += 1
# Track judgment distribution
predicted_judgment = sample.get("predicted_judgment", "format_error")
if predicted_judgment in judgment_counts:
judgment_counts[predicted_judgment] += 1
return response_lengths, thinking_utilization, judgment_counts
async def evaluate(self, *args, **kwargs) -> None:
    """Evaluate on the test set, queue metrics for wandb and log the results.

    Every test item is rolled out concurrently. Results without a sample are
    discarded. Metrics are appended to ``self.eval_metrics`` (flushed later by
    ``wandb_log``) and a metrics dict plus the samples are passed to
    ``evaluate_log``. The method returns early, after printing a message, when
    the rollouts raise or produce nothing usable.
    """
    start_time = time.time()

    try:
        pending = [self.rollout_and_score_eval(entry) for entry in self.test]
        outcomes = await tqdm_asyncio.gather(*pending)
        valid_results = [
            outcome
            for outcome in outcomes
            if not isinstance(outcome, Exception)
            and outcome
            and outcome.get("sample") is not None
        ]
        if not valid_results:
            print("Warning: No valid evaluation results obtained")
            return
    except Exception as e:
        print(f"Error during evaluation: {e}")
        return

    scores = [outcome["score"] for outcome in valid_results]
    samples = [outcome["sample"] for outcome in valid_results]
    valid_scores = [value for value in scores if value is not None]
    if not valid_scores:
        print("Warning: No valid scores found during evaluation")
        return

    # ---- Overall accuracy ------------------------------------------------
    percent_correct = sum(valid_scores) / len(valid_scores)
    self.eval_metrics.append(("eval/percent_correct", percent_correct))

    # ---- Accuracy per dataset subset ---------------------------------------
    subset_scores = {}
    for sample, value in zip(samples, scores):
        if sample:
            subset = sample.get("dataset_subset", "unknown")
            subset_scores.setdefault(subset, []).append(value)

    subset_averages = {}
    for subset, values in subset_scores.items():
        usable = [value for value in values if value is not None]
        if usable:
            subset_averages[subset] = sum(usable) / len(usable)
    for subset, average in subset_averages.items():
        self.eval_metrics.append((f"eval/percent_correct_{subset}", average))

    # ---- Response statistics -----------------------------------------------
    format_compliant = sum(
        1 for sample in samples if sample.get("format_compliant", False)
    )
    compliance_rate = format_compliant / len(samples) if samples else 0.0
    thinking_mode_used = self.config.thinking_mode
    response_lengths, thinking_utilization, judgment_counts = (
        self._calculate_response_metrics(samples, thinking_mode_used)
    )

    if response_lengths:
        count = len(response_lengths)
        avg_response_length = sum(response_lengths) / count
        # Population standard deviation of response length.
        response_length_std = (
            sum((length - avg_response_length) ** 2 for length in response_lengths)
            / count
        ) ** 0.5
        self.eval_metrics.append(("eval/avg_response_length", avg_response_length))
        self.eval_metrics.append(("eval/response_length_std", response_length_std))

    if thinking_mode_used and samples:
        thinking_utilization_rate = thinking_utilization / len(samples)
        self.eval_metrics.append(
            ("eval/thinking_utilization_rate", thinking_utilization_rate)
        )

    # ---- Verdict distribution ----------------------------------------------
    total_judgments = sum(judgment_counts.values())
    if total_judgments > 0:
        # Entropy (natural log) of the predicted-verdict distribution.
        entropy = 0.0
        for occurrences in judgment_counts.values():
            if occurrences > 0:
                share = occurrences / total_judgments
                entropy -= share * math.log(share)
        # Share of the single most frequent verdict.
        most_common_judgment_freq = max(judgment_counts.values()) / total_judgments
        format_error_rate = judgment_counts["format_error"] / total_judgments
        self.eval_metrics.append(("eval/judgment_entropy", entropy))
        self.eval_metrics.append(
            ("eval/most_common_judgment_freq", most_common_judgment_freq)
        )
        self.eval_metrics.append(("eval/format_error_rate", format_error_rate))

    # ---- Dataset-level counters --------------------------------------------
    self.eval_metrics.append(("eval/total_items", len(self.test)))
    self.eval_metrics.append(("eval/valid_scores", len(valid_scores)))
    self.eval_metrics.append(("eval/subset_count", len(subset_scores)))
    self.eval_metrics.append(("eval/format_compliance_rate", compliance_rate))

    end_time = time.time()

    # ---- Metrics dict handed to evaluate_log -------------------------------
    eval_metrics = {
        "eval/percent_correct": percent_correct,
        "eval/total_samples": len(samples),
        "eval/correct_samples": sum(valid_scores),
        "eval/format_compliance_rate": compliance_rate,
    }
    if response_lengths:
        eval_metrics["eval/avg_response_length"] = avg_response_length
        eval_metrics["eval/response_length_std"] = response_length_std
    if thinking_mode_used and samples:
        eval_metrics["eval/thinking_utilization_rate"] = thinking_utilization_rate
    if total_judgments > 0:
        eval_metrics["eval/judgment_entropy"] = entropy
        eval_metrics["eval/most_common_judgment_freq"] = most_common_judgment_freq
        eval_metrics["eval/format_error_rate"] = format_error_rate
    for subset, average in subset_averages.items():
        eval_metrics[f"eval/percent_correct_{subset}"] = average

    try:
        await self.evaluate_log(
            metrics=eval_metrics,
            samples=samples,
            start_time=start_time,
            end_time=end_time,
            generation_parameters={
                "temperature": self.config.eval_temperature,
                "max_tokens": self.config.eval_max_tokens,
                "thinking_mode": thinking_mode_used,
            },
        )
    except Exception as e:
        print(f"Error logging evaluation results: {e}")
async def add_rollouts_for_wandb(
    self,
    scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
    item: Item = None,
) -> None:
    """Record a sample of scored rollouts for the wandb rollout table.

    Each stored row is ``(full_text, score, ground_truth, predicted_judgment,
    question_info, mode)``.

    The predicted judgment is parsed from the model response only. When the
    group carries no ``"messages"`` entry (``score`` does not add one), the
    response is recovered by decoding the tokens whose mask is not ``-100``.
    Parsing the whole decoded conversation instead would also scan the system
    prompt, which itself contains ``[[A]]`` examples and, by default,
    ``<think>`` tags, so the verdict would never reflect the model's answer.

    Args:
        scored_data: One scored group, or a list of groups (each is recorded).
        item: ``(prompt_tuple, correct_letter)`` the rollouts were generated for.
    """
    if item is None or scored_data is None:
        return
    # The annotation allows a list of groups; record each one separately.
    if isinstance(scored_data, list):
        for group in scored_data:
            await self.add_rollouts_for_wandb(group, item)
        return
    if not scored_data.get("tokens"):
        return

    ground_truth = item[1]

    # Pull the question text out of the user message, if it can be found.
    question_info = "unknown_question"
    try:
        for role_dict in item[0]:
            message = dict(role_dict)
            if message.get("role") == "user":
                question_match = self._question_pattern.search(
                    message.get("content", "")
                )
                if question_match:
                    question_info = question_match.group(1).strip()
                break
    except Exception:
        question_info = "extraction_failed"

    # -1 means "keep the whole group"; never exceed what is available.
    num_keep = self.config.num_rollouts_per_group_for_logging
    if num_keep == -1:
        num_keep = self.config.group_size
    num_keep = min(num_keep, len(scored_data["tokens"]))

    mode = "thinking" if self.config.thinking_mode else "direct"
    stored_messages = scored_data.get("messages") or []
    stored_masks = scored_data.get("masks") or []

    current_rollouts = []
    for i in range(num_keep):
        tokens = scored_data["tokens"][i]
        full_text = self.tokenizer.decode(tokens, skip_special_tokens=True)
        score_val = scored_data["scores"][i]

        predicted_judgment = "unknown"
        try:
            if (
                i < len(stored_messages)
                and isinstance(stored_messages[i], list)
                and stored_messages[i]
            ):
                model_response = stored_messages[i][-1].get("content", "")
            else:
                # Unmasked positions (mask != -100) are treated as the model's
                # reply, the same convention `score` uses when counting them.
                response_tokens = (
                    [
                        token
                        for token, mask in zip(tokens, stored_masks[i])
                        if mask != -100
                    ]
                    if i < len(stored_masks) and stored_masks[i]
                    else []
                )
                if response_tokens:
                    # Special tokens are kept so think tags survive decoding
                    # even if the tokenizer registers them as special.
                    model_response = self.tokenizer.decode(
                        response_tokens, skip_special_tokens=False
                    )
                else:
                    # No mask information: fall back to the full text.
                    model_response = full_text
            predicted_judgment = self.process_judgement(
                model_response, track_metrics=False
            )
        except Exception:
            predicted_judgment = "parse_error"

        current_rollouts.append(
            (
                full_text,
                score_val,
                ground_truth,
                predicted_judgment,
                question_info,
                mode,
            )
        )

    self.rollouts_for_wandb.append(current_rollouts)

    # Bound memory: drop the oldest group once the limit is exceeded.
    if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
        self.rollouts_for_wandb.pop(0)
async def create_rollout_table(self, wandb_metrics: Dict) -> Dict:
    """Attach the buffered rollouts to ``wandb_metrics`` as a wandb table.

    When no rollouts are buffered the dict is returned untouched. Otherwise a
    table is stored under ``"train/rollouts"`` and the buffer is cleared. Rows
    that do not have exactly six fields are left out.
    """
    if not self.rollouts_for_wandb:
        return wandb_metrics

    column_names = [
        "full_text",
        "score",
        "ground_truth",
        "predicted_judgment",
        "question_info",
        "mode",
    ]
    table = wandb.Table(columns=column_names)

    well_formed_rows = (
        row
        for group in self.rollouts_for_wandb
        for row in group
        if len(row) == 6
    )
    for row in well_formed_rows:
        table.add_data(*row)

    wandb_metrics["train/rollouts"] = table
    self.rollouts_for_wandb = []
    return wandb_metrics
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
    """Assemble training and evaluation metrics and pass them to the base logger.

    Training counters are reset and the queued evaluation metrics and rollout
    buffer are drained as part of this call.
    """
    metrics = {} if wandb_metrics is None else wandb_metrics

    # Overall training accuracy since the last flush.
    buffered = self.percent_correct_buffer
    if buffered:
        metrics["train/percent_correct"] = sum(buffered) / len(buffered)

    # Per-letter frequency / accuracy and entropy of the verdict distribution.
    total_letters = sum(self.judgment_letter_counts.values())
    if total_letters > 0:
        entropy = 0.0
        for letter in self.choice_letters:
            picked = self.judgment_letter_counts[letter]
            picked_correctly = self.judgment_letter_correct[letter]
            share = picked / total_letters

            metrics[f"train/judgment_freq_{letter}"] = share
            if picked > 0:
                metrics[f"train/judgment_acc_{letter}"] = picked_correctly / picked
            else:
                metrics[f"train/judgment_acc_{letter}"] = 0.0

            if share > 0:
                entropy -= share * math.log(share)

        metrics["train/judgment_entropy"] = entropy
        # Entropy normalised by its maximum, log(num_choices).
        metrics["train/judgment_balance"] = entropy / math.log(
            self.config.num_choices
        )

    # Format error rate and its complement.
    if self.total_judgments > 0:
        metrics["train/error_rate"] = self.error_count / self.total_judgments
        metrics["train/format_compliance_rate"] = 1.0 - (
            self.error_count / self.total_judgments
        )

    # Static configuration values, logged alongside the metrics.
    metrics["train/thinking_mode_enabled"] = (
        1.0 if self.config.thinking_mode else 0.0
    )
    metrics["train/total_judgments"] = self.total_judgments
    metrics["config/group_size"] = self.config.group_size
    metrics["config/max_token_length"] = self.config.max_token_length
    metrics["config/num_choices"] = self.config.num_choices

    self._reset_metrics()

    # Drain metrics queued by evaluate().
    for metric_name, metric_value in self.eval_metrics:
        metrics[metric_name] = metric_value
    self.eval_metrics = []

    metrics = await self.create_rollout_table(metrics)
    await super().wandb_log(metrics)
if __name__ == "__main__":
    # Script entry point: only runs when the file is executed directly.
    # `cli` is not defined in this file - presumably inherited from BaseEnv.
    PairwiseJudgementEnv.cli()