mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
1980 lines
80 KiB
Python
1980 lines
80 KiB
Python
import asyncio
|
|
import math
|
|
import random
|
|
import re
|
|
import time
|
|
from enum import Enum
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
import wandb
|
|
from datasets import load_dataset
|
|
from pydantic import Field
|
|
from tqdm.asyncio import tqdm_asyncio
|
|
|
|
from atroposlib.envs.base import (
|
|
APIServerConfig,
|
|
BaseEnv,
|
|
BaseEnvConfig,
|
|
EvalHandlingEnum,
|
|
Item,
|
|
ScoredDataGroup,
|
|
)
|
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|
|
|
|
|
class RewardBenchCategory(str, Enum):
|
|
"""Enumeration of RewardBench-2 dataset categories."""
|
|
|
|
FACTUALITY = "Factuality"
|
|
FOCUS = "Focus"
|
|
MATH = "Math"
|
|
PRECISE_IF = "Precise IF"
|
|
SAFETY = "Safety"
|
|
TIES = "Ties"
|
|
|
|
|
|
class PairwiseJudgementConfig(BaseEnvConfig):
|
|
"""Configuration for PairwiseJudgementEnv with thinking mode and configurable options."""
|
|
|
|
thinking_mode: bool = Field(
|
|
default=False,
|
|
description="Whether to enable thinking mode with <think></think> tags.",
|
|
)
|
|
|
|
num_choices: int = Field(
|
|
default=4,
|
|
ge=2,
|
|
le=26,
|
|
description="Number of choices for pairwise judgment (2-26, corresponding to A-Z).",
|
|
)
|
|
|
|
custom_thinking_prompt: Optional[str] = Field(
|
|
default=None,
|
|
description="Custom thinking prompt. If None, uses the default thinking prompt.",
|
|
)
|
|
|
|
custom_judgment_prompt: Optional[str] = Field(
|
|
default=None,
|
|
description="Custom judgment prompt. If None, uses the default judgment prompt.",
|
|
)
|
|
|
|
eval_temperature: float = Field(
|
|
default=0.6,
|
|
description="Temperature for evaluation completions.",
|
|
)
|
|
|
|
rollout_temperature: float = Field(
|
|
default=0.8,
|
|
description="Temperature for training rollout completions.",
|
|
)
|
|
|
|
eval_max_tokens: int = Field(
|
|
default=1024 * 16,
|
|
description="Maximum tokens for evaluation completions.",
|
|
)
|
|
|
|
train_max_tokens: int = Field(
|
|
default=1024 * 16,
|
|
description="Maximum tokens for training completions.",
|
|
)
|
|
|
|
# Ties-specific configuration
|
|
max_ties_responses: int = Field(
|
|
default=100,
|
|
description="Maximum number of responses to evaluate in ties mode to control API costs.",
|
|
)
|
|
|
|
# Category filtering configuration
|
|
eval_categories: Optional[List[RewardBenchCategory]] = Field(
|
|
default=None,
|
|
description="List of categories to evaluate. If None, evaluates all categories. Categories not in this list will be skipped.", # noqa
|
|
)
|
|
|
|
# Retry configuration
|
|
max_retries: int = Field(
|
|
default=3,
|
|
ge=1,
|
|
description="Maximum number of retries for failed API calls.",
|
|
)
|
|
|
|
retry_delay: float = Field(
|
|
default=1.0,
|
|
ge=0.0,
|
|
description="Delay in seconds between retry attempts.",
|
|
)
|
|
|
|
min_response_length: int = Field(
|
|
default=5,
|
|
ge=1,
|
|
description="Minimum response length to consider valid (filters out EOS-only responses).",
|
|
)
|
|
|
|
# Debug configuration
|
|
full_debug: bool = Field(
|
|
default=False,
|
|
description="Enable full debug mode - logs every API request and response with truncated content.",
|
|
)
|
|
|
|
|
|
class PairwiseJudgementEnv(BaseEnv):
|
|
name = "pairwise_judgement"
|
|
env_config_cls = PairwiseJudgementConfig
|
|
|
|
def __init__(
|
|
self,
|
|
config: PairwiseJudgementConfig,
|
|
server_configs: List[APIServerConfig],
|
|
slurm=True,
|
|
testing=False,
|
|
):
|
|
super().__init__(config, server_configs, slurm, testing)
|
|
self.config: PairwiseJudgementConfig = config
|
|
self.percent_correct_buffer = []
|
|
self.eval_metrics = []
|
|
|
|
# Generate choice letters based on num_choices (A, B, C, D... up to Z)
|
|
self.choice_letters = [chr(65 + i) for i in range(self.config.num_choices)]
|
|
|
|
# Initialize detailed metrics tracking for all choice letters
|
|
self.judgment_letter_counts = {letter: 0 for letter in self.choice_letters}
|
|
self.judgment_letter_correct = {letter: 0 for letter in self.choice_letters}
|
|
self.error_count = 0 # Failed to follow format
|
|
self.total_judgments = 0
|
|
self.rollouts_for_wandb = []
|
|
|
|
# Pre-compile regex patterns for performance
|
|
self._think_pattern = re.compile(r"<think>")
|
|
self._think_close_pattern = re.compile(r"</think>")
|
|
self._think_content_pattern = re.compile(r"</think>\s*(.*)", re.DOTALL)
|
|
self._question_pattern = re.compile(
|
|
r"\[User Question\]\s*(.*?)\s*\[The Start of Assistant A", re.DOTALL
|
|
)
|
|
self._thinking_extract_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
|
|
|
|
# Pre-compile choice patterns for each letter
|
|
self._choice_patterns = {
|
|
letter: re.compile(rf"\[\[{letter}\]\]") for letter in self.choice_letters
|
|
}
|
|
self._answer_choice_patterns = {
|
|
letter: re.compile(
|
|
rf"\[The Start of Assistant {letter}\'s Answer\]\s*(.*?)\s*\[The End of Assistant {letter}\'s Answer\]", # noqa
|
|
re.DOTALL,
|
|
)
|
|
for letter in self.choice_letters
|
|
}
|
|
|
|
# Pre-compile rating pattern for ties evaluation
|
|
self._rating_pattern = re.compile(r"\b([1-9]|10)\b\s*$")
|
|
|
|
# System prompts (use custom ones if provided, otherwise defaults)
|
|
self.thinking_system_prompt = self._get_thinking_prompt()
|
|
self.judgment_system_prompt = self._get_judgment_prompt()
|
|
|
|
def _get_thinking_prompt(self) -> str:
|
|
"""Get thinking system prompt."""
|
|
return (
|
|
self.config.custom_thinking_prompt
|
|
if self.config.custom_thinking_prompt
|
|
else "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
|
|
"problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
|
|
"solution prior to answering. You should enclose your thoughts and internal monologue inside <think> "
|
|
"</think> tags, and then provide your solution or response to the problem."
|
|
)
|
|
|
|
def _get_judgment_prompt(self) -> str:
|
|
"""Get judgment system prompt."""
|
|
if self.config.custom_judgment_prompt:
|
|
return self.config.custom_judgment_prompt
|
|
|
|
choice_format_examples = ", ".join(
|
|
[
|
|
f'"[[{letter}]]" if assistant {letter} is best'
|
|
for letter in self.choice_letters
|
|
]
|
|
)
|
|
|
|
return (
|
|
f"Please act as an impartial judge and evaluate the quality of the responses provided by {self.config.num_choices} AI assistants to the user question displayed below. " # noqa
|
|
f"You should choose the assistant that follows the user's instructions and answers the user's question best. " # noqa
|
|
f"Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. " # noqa
|
|
f"Begin your evaluation by comparing the {self.config.num_choices} responses and provide a short explanation. " # noqa
|
|
f"Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. " # noqa
|
|
f"Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. " # noqa
|
|
f"Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: " # noqa
|
|
f"{choice_format_examples}."
|
|
)
|
|
|
|
def _is_ties_sample(self, item: dict) -> bool:
|
|
"""Detect if this is a ties evaluation sample."""
|
|
return item.get("subset") == "Ties"
|
|
|
|
def _should_evaluate_category(self, item: dict) -> bool:
|
|
"""Check if this item's category should be evaluated based on config."""
|
|
if self.config.eval_categories is None:
|
|
return True # Evaluate all categories if none specified
|
|
|
|
item_category = item.get("subset", "")
|
|
# Try to match against enum values
|
|
for category in self.config.eval_categories:
|
|
if category.value == item_category:
|
|
return True
|
|
|
|
return False
|
|
|
|
def _format_debug_text(self, text: str, label: str) -> str:
|
|
"""Format text for debug output (first 100 + last 100 chars)."""
|
|
if not text:
|
|
return f"{label}: <empty>"
|
|
|
|
text_clean = text.strip()
|
|
if len(text_clean) <= 200:
|
|
return f"{label}: '{text_clean}'"
|
|
|
|
first_100 = text_clean[:100]
|
|
last_100 = text_clean[-100:]
|
|
return f"{label}: '{first_100}...{last_100}' (total {len(text_clean)} chars)"
|
|
|
|
def _log_full_debug_request(
|
|
self,
|
|
messages: List[Dict],
|
|
params: Dict,
|
|
category: str = "unknown",
|
|
item_id: str = "unknown",
|
|
context: str = "",
|
|
):
|
|
"""Log full debug information for API requests."""
|
|
if not self.config.full_debug:
|
|
return
|
|
|
|
print(f"\n🔍 FULL DEBUG - API REQUEST [{context}]")
|
|
print(f" Category: {category}")
|
|
print(f" Item ID: {item_id}")
|
|
print(f" Parameters: {params}")
|
|
|
|
for i, message in enumerate(messages):
|
|
role = message.get("role", "unknown")
|
|
content = message.get("content", "")
|
|
print(
|
|
f" Message {i+1} ({role}): {self._format_debug_text(content, 'Content')}"
|
|
)
|
|
|
|
def _log_full_debug_response(self, completion, context: str = ""):
|
|
"""Log full debug information for API responses."""
|
|
if not self.config.full_debug:
|
|
return
|
|
|
|
print(f"\n🔍 FULL DEBUG - API RESPONSE [{context}]")
|
|
|
|
if hasattr(completion, "usage"):
|
|
print(f" Usage: {completion.usage}")
|
|
|
|
if hasattr(completion, "choices") and completion.choices:
|
|
for i, choice in enumerate(completion.choices):
|
|
content = choice.message.content if hasattr(choice, "message") else ""
|
|
finish_reason = (
|
|
choice.finish_reason
|
|
if hasattr(choice, "finish_reason")
|
|
else "unknown"
|
|
)
|
|
print(
|
|
f" Choice {i+1}: {self._format_debug_text(content, 'Response')}"
|
|
)
|
|
print(f" Finish reason: {finish_reason}")
|
|
else:
|
|
print(f" No choices in response")
|
|
print(f" Completion object: {completion}")
|
|
|
|
def _reset_metrics(self) -> None:
|
|
"""Reset training metrics."""
|
|
self.percent_correct_buffer = []
|
|
self.judgment_letter_counts = {letter: 0 for letter in self.choice_letters}
|
|
self.judgment_letter_correct = {letter: 0 for letter in self.choice_letters}
|
|
self.error_count = 0
|
|
self.total_judgments = 0
|
|
|
|
def _convert_messages_to_list(self, prompt_tuple: Tuple) -> List[Dict]:
|
|
"""Convert frozenset message format to list format."""
|
|
messages = []
|
|
for role_dict in prompt_tuple:
|
|
messages.append(dict(role_dict))
|
|
return messages
|
|
|
|
def _create_system_content(self) -> str:
|
|
"""Create system message content based on thinking mode."""
|
|
if self.config.thinking_mode:
|
|
return f"{self.thinking_system_prompt}\n\n{self.judgment_system_prompt}"
|
|
return self.judgment_system_prompt
|
|
|
|
def _prepare_completion_input(self, prompt_tuple: Tuple) -> List[Dict]:
|
|
"""Convert prompt tuple to messages format."""
|
|
messages = self._convert_messages_to_list(prompt_tuple)
|
|
return messages
|
|
|
|
def _get_train_completion_params(self) -> Dict:
|
|
"""Get completion parameters for training rollouts."""
|
|
return {
|
|
"n": self.config.group_size,
|
|
"max_tokens": self.config.train_max_tokens,
|
|
"temperature": self.config.rollout_temperature,
|
|
}
|
|
|
|
def _get_eval_completion_params(self) -> Dict:
|
|
"""Get completion parameters for evaluation."""
|
|
return {
|
|
"n": 1,
|
|
"max_tokens": self.config.eval_max_tokens,
|
|
"temperature": self.config.eval_temperature,
|
|
"split": "eval",
|
|
}
|
|
|
|
@classmethod
|
|
def config_init(cls) -> Tuple[PairwiseJudgementConfig, List[APIServerConfig]]:
|
|
env_config = PairwiseJudgementConfig(
|
|
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
|
group_size=16,
|
|
use_wandb=True,
|
|
max_num_workers_per_node=16,
|
|
rollout_server_url="http://localhost:8000",
|
|
total_steps=2000,
|
|
batch_size=1024,
|
|
steps_per_eval=25,
|
|
max_token_length=1024 * 32,
|
|
inference_weight=1.0,
|
|
wandb_name="pairwise_judgment",
|
|
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
|
eval_limit_ratio=0.1,
|
|
min_batch_allocation=0.1,
|
|
thinking_mode=False,
|
|
# List specific categories to evaluate, or None for all
|
|
eval_categories=[
|
|
RewardBenchCategory.FACTUALITY,
|
|
RewardBenchCategory.FOCUS,
|
|
RewardBenchCategory.MATH,
|
|
RewardBenchCategory.PRECISE_IF,
|
|
RewardBenchCategory.SAFETY,
|
|
RewardBenchCategory.TIES,
|
|
],
|
|
# Debug and retry configuration
|
|
full_debug=False, # Set to True to enable detailed API request/response logging
|
|
max_retries=3,
|
|
retry_delay=1.0,
|
|
min_response_length=5,
|
|
)
|
|
server_configs = [
|
|
APIServerConfig(
|
|
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
|
base_url="http://localhost:9004/v1",
|
|
api_key="x",
|
|
num_max_requests_at_once=32,
|
|
num_requests_for_eval=256,
|
|
),
|
|
]
|
|
return env_config, server_configs
|
|
|
|
async def setup(self) -> None:
|
|
"""Set up the environment by loading datasets."""
|
|
# Load placeholder train dataset (not actually used since we generate synthetic examples)
|
|
try:
|
|
self.train = load_dataset("example/train", split="train")
|
|
print(f"Loaded placeholder train dataset with {len(self.train)} examples")
|
|
except Exception as e:
|
|
# Create minimal placeholder data if dataset doesn't exist
|
|
# Note: This isn't actually used since get_next_item() generates synthetic examples
|
|
self.train = [{"question": "What is 2+2?", "answer": "4"}] * 100
|
|
print(f"Using synthetic placeholder training data due to error: {e}")
|
|
|
|
# Load evaluation dataset - reward-bench-2 (MUST WORK OR CRASH)
|
|
self.test = load_dataset(
|
|
"allenai/reward-bench-2", split="test", trust_remote_code=True
|
|
)
|
|
print(f"Loaded reward-bench-2 eval dataset with {len(self.test)} examples")
|
|
|
|
# Analyze dataset composition
|
|
category_counts = {}
|
|
for item in self.test:
|
|
category = item.get("subset", "Unknown")
|
|
category_counts[category] = category_counts.get(category, 0) + 1
|
|
|
|
print("Dataset categories found:")
|
|
for category, count in sorted(category_counts.items()):
|
|
print(f" - {category}: {count} samples")
|
|
|
|
# Count ties vs choice samples
|
|
ties_samples = sum(1 for item in self.test if self._is_ties_sample(item))
|
|
choice_samples = len(self.test) - ties_samples
|
|
print(
|
|
f"\nEvaluation modes: {choice_samples} choice samples, {ties_samples} ties samples"
|
|
)
|
|
|
|
# Show category filtering info
|
|
if self.config.eval_categories is not None:
|
|
selected_categories = [cat.value for cat in self.config.eval_categories]
|
|
print(
|
|
f"\nCategory filtering enabled. Selected categories: {selected_categories}"
|
|
)
|
|
filtered_count = sum(
|
|
1 for item in self.test if self._should_evaluate_category(item)
|
|
)
|
|
print(f"Will evaluate {filtered_count} out of {len(self.test)} samples")
|
|
else:
|
|
print(
|
|
f"\nNo category filtering. Will evaluate all {len(self.test)} samples"
|
|
)
|
|
|
|
# Show debug mode status
|
|
if self.config.full_debug:
|
|
print(
|
|
f"\n🔍 FULL DEBUG MODE ENABLED - Will log all API requests and responses"
|
|
)
|
|
print(
|
|
f" 📊 Will show: category, item ID, first/last 100 chars of prompts and responses"
|
|
)
|
|
print(
|
|
f" ⚙️ Retry settings: max_retries={self.config.max_retries}, retry_delay={self.config.retry_delay}s"
|
|
)
|
|
print(f" 📏 Min response length: {self.config.min_response_length} chars")
|
|
else:
|
|
print(
|
|
f"\n🔍 Full debug mode disabled - Use full_debug=True to enable detailed logging"
|
|
)
|
|
|
|
# Debug: Show sample evaluation item structure
|
|
if len(self.test) > 0:
|
|
try:
|
|
sample_item = self.test[0]
|
|
print("\nSample eval item structure:")
|
|
print(f"- Available keys: {list(sample_item.keys())}")
|
|
|
|
# Handle different dataset structures
|
|
if "prompt" in sample_item:
|
|
print(f"- Prompt: {sample_item['prompt'][:100]}...")
|
|
elif "chosen" in sample_item and isinstance(sample_item["chosen"], str):
|
|
print(f"- Chosen (string): {sample_item['chosen'][:100]}...")
|
|
elif "rejected" in sample_item and isinstance(
|
|
sample_item["rejected"], str
|
|
):
|
|
print(f"- Rejected (string): {sample_item['rejected'][:100]}...")
|
|
|
|
if "chosen" in sample_item:
|
|
if isinstance(sample_item["chosen"], list):
|
|
print(f"- Chosen responses: {len(sample_item['chosen'])}")
|
|
if sample_item["chosen"]:
|
|
print(
|
|
f"- First chosen (truncated): {sample_item['chosen'][0][:200]}..."
|
|
)
|
|
else:
|
|
print(f"- Chosen (string): {sample_item['chosen'][:200]}...")
|
|
|
|
if "rejected" in sample_item:
|
|
if isinstance(sample_item["rejected"], list):
|
|
print(f"- Rejected responses: {len(sample_item['rejected'])}")
|
|
if sample_item["rejected"]:
|
|
print(
|
|
f"- First rejected (truncated): {sample_item['rejected'][0][:200]}..."
|
|
)
|
|
else:
|
|
print(
|
|
f"- Rejected (string): {sample_item['rejected'][:200]}..."
|
|
)
|
|
|
|
except Exception as e:
|
|
print(f"Warning: Could not display sample item structure: {e}")
|
|
|
|
self.iter = 0
|
|
|
|
def save_checkpoint(self, step: int, data: Optional[Dict] = None) -> None:
|
|
"""Save checkpoint including iteration state."""
|
|
if data is None:
|
|
data = {}
|
|
data["iter"] = self.iter
|
|
super().save_checkpoint(step, data)
|
|
|
|
def process_judgement(self, judgment: str, track_metrics: bool = True) -> str:
|
|
"""Extract judgment from model response."""
|
|
# Debug: Check judgment type and content
|
|
if judgment is None:
|
|
print(f"DEBUG: judgment is None in process_judgement")
|
|
if track_metrics:
|
|
self.error_count += 1
|
|
self.total_judgments += 1
|
|
return "format_error"
|
|
elif not isinstance(judgment, str):
|
|
print(
|
|
f"DEBUG: judgment is not a string in process_judgement. Type: {type(judgment)}, Value: {judgment}"
|
|
)
|
|
if track_metrics:
|
|
self.error_count += 1
|
|
self.total_judgments += 1
|
|
return "format_error"
|
|
|
|
if self.config.thinking_mode:
|
|
# Check for exactly one pair of think tags using pre-compiled patterns
|
|
think_open_count = len(self._think_pattern.findall(judgment))
|
|
think_close_count = len(self._think_close_pattern.findall(judgment))
|
|
|
|
if think_open_count != 1 or think_close_count != 1:
|
|
if track_metrics:
|
|
self.error_count += 1
|
|
self.total_judgments += 1
|
|
return "format_error"
|
|
|
|
# Parse only content after </think> tags
|
|
match = self._think_content_pattern.search(judgment)
|
|
if match:
|
|
judgment = match.group(1)
|
|
else:
|
|
if track_metrics:
|
|
self.error_count += 1
|
|
self.total_judgments += 1
|
|
return "format_error"
|
|
|
|
if track_metrics:
|
|
self.total_judgments += 1
|
|
|
|
# Check for each possible choice letter using pre-compiled patterns
|
|
for letter in self.choice_letters:
|
|
if self._choice_patterns[letter].search(judgment):
|
|
if track_metrics:
|
|
self.judgment_letter_counts[letter] += 1
|
|
return letter
|
|
|
|
# No valid judgment found
|
|
if track_metrics:
|
|
self.error_count += 1
|
|
return "format_error"
|
|
|
|
def create_judgment_prompt(self, question: str, answers: List[str]) -> str:
|
|
"""Create the user prompt for judgment task."""
|
|
if len(answers) != self.config.num_choices:
|
|
raise ValueError(
|
|
f"Need exactly {self.config.num_choices} answers for judgment, got {len(answers)}"
|
|
)
|
|
|
|
prompt = f"[User Question]\n{question}\n\n"
|
|
|
|
for i, answer in enumerate(answers):
|
|
letter = self.choice_letters[i]
|
|
prompt += f"[The Start of Assistant {letter}'s Answer]\n{answer}\n[The End of Assistant {letter}'s Answer]\n\n" # noqa
|
|
|
|
return prompt.strip()
|
|
|
|
async def get_next_item(self) -> Item:
|
|
"""Generate next training item with synthetic data."""
|
|
self.iter += 1
|
|
|
|
# Create system message
|
|
system_content = self._create_system_content()
|
|
|
|
# Create varied placeholder judgment tasks
|
|
examples = [
|
|
{
|
|
"question": "What is the capital of France?",
|
|
"correct": "The capital of France is Paris, which has been the capital since 987 AD and serves as the political, economic, and cultural center of the country.", # noqa
|
|
"incorrect": [
|
|
"The capital of France is London.",
|
|
"France's capital is Berlin, located in central Europe.",
|
|
"I don't know the answer to this question.",
|
|
"France doesn't have a capital city.",
|
|
"The capital changes every year in France.",
|
|
"Paris is just a city, not a capital.",
|
|
],
|
|
},
|
|
{
|
|
"question": "How do you fix a memory leak in Python?",
|
|
"correct": "To fix memory leaks in Python: 1) Use memory profilers like tracemalloc or memory_profiler to identify leaks, 2) Ensure proper cleanup of resources with context managers, 3) Break circular references, 4) Close files and database connections explicitly, and 5) Use weak references when appropriate.", # noqa
|
|
"incorrect": [
|
|
"Just restart your computer and the memory leak will be fixed.",
|
|
"Python automatically handles all memory management, so memory leaks are impossible.",
|
|
"You need to reinstall Python to fix memory leaks.",
|
|
"Memory leaks don't exist in Python because it's interpreted.",
|
|
"Use more RAM to solve memory leaks.",
|
|
"Delete the Python installation and use a different language.",
|
|
],
|
|
},
|
|
{
|
|
"question": "Explain the difference between machine learning and artificial intelligence.",
|
|
"correct": "Artificial Intelligence (AI) is the broader field focused on creating systems that can perform tasks typically requiring human intelligence. Machine Learning (ML) is a subset of AI that uses algorithms to learn patterns from data without being explicitly programmed for each task. So ML is one approach to achieving AI.", # noqa
|
|
"incorrect": [
|
|
"Machine learning and artificial intelligence are exactly the same thing with different names.",
|
|
"Machine learning is much broader than AI and includes all computer science.",
|
|
"AI is only about robots, while machine learning is only about statistics.",
|
|
"Machine learning came before AI historically.",
|
|
"AI is a subset of machine learning, not the other way around.",
|
|
"There is no difference; they are marketing terms for the same technology.",
|
|
],
|
|
},
|
|
]
|
|
|
|
# Select random example
|
|
example = random.choice(examples)
|
|
|
|
# Create list with correct and incorrect answers, ensuring we have enough
|
|
incorrect_answers = example["incorrect"][
|
|
: self.config.num_choices - 1
|
|
] # Take enough incorrect answers
|
|
all_answers = [example["correct"]] + incorrect_answers
|
|
|
|
# If we don't have enough incorrect answers, pad with generic ones
|
|
while len(all_answers) < self.config.num_choices:
|
|
all_answers.append(
|
|
"I don't have enough information to answer this question."
|
|
)
|
|
|
|
random.shuffle(all_answers)
|
|
|
|
# Find where correct answer ended up
|
|
correct_index = all_answers.index(example["correct"])
|
|
correct_answer = self.choice_letters[correct_index]
|
|
|
|
user_content = self.create_judgment_prompt(example["question"], all_answers)
|
|
|
|
prompt = tuple(
|
|
[
|
|
frozenset({"role": "system", "content": system_content}.items()),
|
|
frozenset({"role": "user", "content": user_content}.items()),
|
|
]
|
|
)
|
|
|
|
return (prompt, correct_answer)
|
|
|
|
def prepare_eval_item(self, item: dict) -> Tuple[Optional[Tuple], Optional[str]]:
|
|
"""
|
|
Prepare an evaluation item from the reward-bench-2 dataset.
|
|
|
|
Dataset structure:
|
|
- chosen: list with 1 element (the best response)
|
|
- rejected: list with 3+ elements (worse responses)
|
|
- We take chosen[0] + rejected[:num_choices-1] to create judgment with configured number of choices
|
|
"""
|
|
try:
|
|
question = item.get("prompt", "")
|
|
chosen_responses = item.get("chosen", [])
|
|
rejected_responses = item.get("rejected", [])
|
|
|
|
# Validate required fields
|
|
if not question:
|
|
return None, None
|
|
|
|
# Take one chosen response and (num_choices-1) rejected responses
|
|
required_rejected = self.config.num_choices - 1
|
|
if (
|
|
len(chosen_responses) == 0
|
|
or len(rejected_responses) < required_rejected
|
|
):
|
|
return None, None
|
|
|
|
chosen = chosen_responses[0]
|
|
rejected = rejected_responses[:required_rejected]
|
|
|
|
# Validate response content
|
|
if not chosen or not all(rejected):
|
|
return None, None
|
|
|
|
# Create list with answer and whether it's correct
|
|
data = [(chosen, True)] + [(r, False) for r in rejected]
|
|
random.shuffle(data)
|
|
|
|
# Extract shuffled answers and find correct position
|
|
shuffled_answers = [item[0] for item in data]
|
|
correct_index = next(
|
|
i for i, (_, is_correct) in enumerate(data) if is_correct
|
|
)
|
|
correct_answer = self.choice_letters[correct_index]
|
|
|
|
# Create system message
|
|
system_content = self._create_system_content()
|
|
|
|
# Create user prompt
|
|
user_content = self.create_judgment_prompt(question, shuffled_answers)
|
|
|
|
prompt = tuple(
|
|
[
|
|
frozenset({"role": "system", "content": system_content}.items()),
|
|
frozenset({"role": "user", "content": user_content}.items()),
|
|
]
|
|
)
|
|
|
|
return prompt, correct_answer
|
|
|
|
except Exception as e:
|
|
print(f"Error preparing evaluation item: {e}")
|
|
print(f"DEBUG: Exception type: {type(e)}")
|
|
print(f"DEBUG: item keys: {list(item.keys()) if item else 'item is None'}")
|
|
print(f"DEBUG: item id: {item.get('id', 'no_id') if item else 'no_item'}")
|
|
return None, None
|
|
|
|
async def collect_trajectories(self, item: Item) -> Tuple[ScoredDataGroup, List]:
|
|
"""Collect and score model trajectories."""
|
|
messages = self._prepare_completion_input(item[0])
|
|
completion_params = self._get_train_completion_params()
|
|
|
|
# Retry logic for training trajectories
|
|
max_retries = self.config.max_retries
|
|
retry_delay = self.config.retry_delay
|
|
|
|
# Get category info for debug logging (this is synthetic training data)
|
|
category = "synthetic_training"
|
|
item_id = f"train_{self.iter if hasattr(self, 'iter') else 'unknown'}"
|
|
|
|
for attempt in range(max_retries):
|
|
try:
|
|
# Log full debug request
|
|
self._log_full_debug_request(
|
|
messages,
|
|
completion_params,
|
|
category,
|
|
item_id,
|
|
f"TRAINING attempt {attempt + 1}/{max_retries}",
|
|
)
|
|
|
|
completions = await self.server.chat_completion(
|
|
messages=messages, **completion_params
|
|
)
|
|
|
|
# Log full debug response
|
|
self._log_full_debug_response(
|
|
completions, f"TRAINING attempt {attempt + 1}/{max_retries}"
|
|
)
|
|
|
|
# Check if we got valid completions
|
|
if not completions.choices:
|
|
if attempt < max_retries - 1:
|
|
print(
|
|
f"DEBUG: No choices in collect_trajectories (attempt {attempt + 1}/{max_retries})"
|
|
)
|
|
await asyncio.sleep(retry_delay)
|
|
continue
|
|
else:
|
|
print(
|
|
f"DEBUG: No choices in collect_trajectories after {max_retries} attempts"
|
|
)
|
|
return None, []
|
|
|
|
# Check if any completion has None content
|
|
valid_completions = []
|
|
for completion_choice in completions.choices:
|
|
if (
|
|
completion_choice.message.content is not None
|
|
and isinstance(completion_choice.message.content, str)
|
|
and len(completion_choice.message.content.strip())
|
|
>= self.config.min_response_length
|
|
):
|
|
valid_completions.append(completion_choice)
|
|
|
|
# If we don't have enough valid completions, retry
|
|
if (
|
|
len(valid_completions) < len(completions.choices) // 2
|
|
): # If less than half are valid
|
|
if attempt < max_retries - 1:
|
|
print(
|
|
f"DEBUG: Only {len(valid_completions)}/{len(completions.choices)} valid completions (attempt {attempt + 1}/{max_retries})"
|
|
)
|
|
await asyncio.sleep(retry_delay)
|
|
continue
|
|
else:
|
|
print(
|
|
f"DEBUG: Only {len(valid_completions)}/{len(completions.choices)} valid completions after {max_retries} attempts"
|
|
)
|
|
# Continue with what we have
|
|
|
|
# Build trajectories using valid completions
|
|
to_score = []
|
|
for completion_choice in valid_completions:
|
|
# Add assistant response to existing messages
|
|
trajectory_messages = messages + [
|
|
{
|
|
"role": "assistant",
|
|
"content": completion_choice.message.content,
|
|
}
|
|
]
|
|
to_score.append((tuple(trajectory_messages), item[1]))
|
|
|
|
# Success - we got at least some valid trajectories
|
|
break
|
|
|
|
except Exception as e:
|
|
if attempt < max_retries - 1:
|
|
print(
|
|
f"DEBUG: collect_trajectories API call failed (attempt {attempt + 1}/{max_retries}): {e}"
|
|
)
|
|
await asyncio.sleep(retry_delay)
|
|
continue
|
|
else:
|
|
print(
|
|
f"DEBUG: collect_trajectories API call failed after {max_retries} attempts: {e}"
|
|
)
|
|
return None, []
|
|
|
|
scored_data = await self.score(to_score)
|
|
|
|
# Add rollouts for wandb visualization
|
|
if scored_data is not None:
|
|
await self.add_rollouts_for_wandb(scored_data, item)
|
|
|
|
return scored_data, []
|
|
|
|
async def score(self, rollout_group_data: List[Tuple]) -> Optional[ScoredDataGroup]:
|
|
"""Score a group of rollout data."""
|
|
if not rollout_group_data:
|
|
return None
|
|
|
|
try:
|
|
scores = ScoredDataGroup()
|
|
scores["tokens"] = []
|
|
scores["masks"] = []
|
|
scores["scores"] = []
|
|
|
|
random.shuffle(rollout_group_data)
|
|
|
|
for item in rollout_group_data:
|
|
# Simplified validation
|
|
if not item or len(item) < 2 or not item[0]:
|
|
continue
|
|
|
|
model_response = item[0][-1]["content"]
|
|
ground_truth = item[1]
|
|
|
|
predicted_answer = self.process_judgement(
|
|
model_response, track_metrics=True
|
|
)
|
|
reward = 1.0 if predicted_answer == ground_truth else 0.0
|
|
|
|
# Track correct judgments per letter
|
|
if (
|
|
predicted_answer == ground_truth
|
|
and predicted_answer != "format_error"
|
|
):
|
|
self.judgment_letter_correct[predicted_answer] += 1
|
|
|
|
out_dict = tokenize_for_trainer(self.tokenizer, item[0])
|
|
tokens = out_dict["tokens"]
|
|
masks = out_dict["masks"]
|
|
|
|
# Skip obviously bad examples
|
|
if len([1 for mask in masks if mask != -100]) < 10:
|
|
continue
|
|
|
|
scores["tokens"].append(tokens)
|
|
scores["masks"].append(masks)
|
|
scores["scores"].append(reward) # Use reward directly (1.0 or 0.0)
|
|
|
|
if len(scores["tokens"]) >= self.config.group_size:
|
|
break
|
|
|
|
if not scores["tokens"]:
|
|
return None
|
|
|
|
# Update percent correct buffer
|
|
for score in scores["scores"]:
|
|
self.percent_correct_buffer.append(max(score, 0))
|
|
|
|
# Return None if all scores are the same (no learning signal)
|
|
if len(set(scores["scores"])) == 1:
|
|
return None
|
|
|
|
return scores
|
|
|
|
except Exception as e:
|
|
print(f"Error in score method: {e}")
|
|
print(f"DEBUG: Exception type: {type(e)}")
|
|
print(
|
|
f"DEBUG: rollout_group_data length: {len(rollout_group_data) if rollout_group_data else 'None'}"
|
|
)
|
|
if rollout_group_data:
|
|
print(f"DEBUG: first item type: {type(rollout_group_data[0])}")
|
|
print(
|
|
f"DEBUG: first item length: {len(rollout_group_data[0]) if rollout_group_data[0] else 'None'}"
|
|
)
|
|
return None
|
|
|
|
async def rollout_and_score_eval(self, test_item: dict) -> dict:
|
|
"""Rollout and score evaluation with automatic ties detection."""
|
|
try:
|
|
# Detect evaluation mode
|
|
if self._is_ties_sample(test_item):
|
|
return await self._rollout_and_score_ties(test_item)
|
|
else:
|
|
return await self._rollout_and_score_choice(test_item)
|
|
except Exception as e:
|
|
print(f"Error in rollout_and_score_eval: {e}")
|
|
print(f"DEBUG: Exception type: {type(e)}")
|
|
print(
|
|
f"DEBUG: test_item keys: {list(test_item.keys()) if test_item else 'test_item is None'}"
|
|
)
|
|
print(
|
|
f"DEBUG: test_item id: {test_item.get('id', 'no_id') if test_item else 'no_test_item'}"
|
|
)
|
|
return {"score": 0.0, "sample": None}
|
|
|
|
async def _rollout_and_score_choice(self, test_item: dict) -> dict:
|
|
"""Original choice-based evaluation logic."""
|
|
try:
|
|
prompt, ground_truth = self.prepare_eval_item(test_item)
|
|
if prompt is None:
|
|
return {"score": 0.0, "sample": None}
|
|
|
|
messages = self._prepare_completion_input(prompt)
|
|
completion_params = self._get_eval_completion_params()
|
|
|
|
# Retry logic for failed API calls
|
|
max_retries = self.config.max_retries
|
|
retry_delay = self.config.retry_delay
|
|
|
|
# Get category and item info for debug logging
|
|
category = test_item.get("subset", "unknown")
|
|
item_id = test_item.get("id", "unknown")
|
|
|
|
for attempt in range(max_retries):
|
|
try:
|
|
# Log full debug request
|
|
self._log_full_debug_request(
|
|
messages,
|
|
completion_params,
|
|
category,
|
|
item_id,
|
|
f"CHOICE_EVAL attempt {attempt + 1}/{max_retries}",
|
|
)
|
|
|
|
completion = await self.server.chat_completion(
|
|
messages=messages, **completion_params
|
|
)
|
|
|
|
# Log full debug response
|
|
self._log_full_debug_response(
|
|
completion, f"CHOICE_EVAL attempt {attempt + 1}/{max_retries}"
|
|
)
|
|
|
|
if not completion.choices:
|
|
if attempt < max_retries - 1:
|
|
print(
|
|
f"DEBUG: No choices in completion (attempt {attempt + 1}/{max_retries})"
|
|
)
|
|
await asyncio.sleep(retry_delay)
|
|
continue
|
|
else:
|
|
print(f"DEBUG: No choices after {max_retries} attempts")
|
|
return {"score": 0.0, "sample": None}
|
|
|
|
model_response = completion.choices[0].message.content
|
|
|
|
# Check for None content or very short responses (likely just EOS token)
|
|
if model_response is None:
|
|
if attempt < max_retries - 1:
|
|
print(
|
|
f"DEBUG: model_response is None (attempt {attempt + 1}/{max_retries})"
|
|
)
|
|
print(f"DEBUG: Completion: {completion}")
|
|
await asyncio.sleep(retry_delay)
|
|
continue
|
|
else:
|
|
print(
|
|
f"DEBUG: model_response is None after {max_retries} attempts"
|
|
)
|
|
print(f"DEBUG: Final completion: {completion}")
|
|
return {"score": 0.0, "sample": None}
|
|
|
|
if not isinstance(model_response, str):
|
|
if attempt < max_retries - 1:
|
|
print(
|
|
f"DEBUG: model_response is not a string. Type: {type(model_response)}, Value: {model_response} (attempt {attempt + 1}/{max_retries})"
|
|
)
|
|
await asyncio.sleep(retry_delay)
|
|
continue
|
|
else:
|
|
print(
|
|
f"DEBUG: model_response is not a string after {max_retries} attempts. Type: {type(model_response)}, Value: {model_response}"
|
|
)
|
|
return {"score": 0.0, "sample": None}
|
|
|
|
# Check for very short responses (likely just EOS token)
|
|
if len(model_response.strip()) < self.config.min_response_length:
|
|
if attempt < max_retries - 1:
|
|
print(
|
|
f"DEBUG: Very short response (likely EOS token only): '{model_response}' (attempt {attempt + 1}/{max_retries})"
|
|
)
|
|
print(
|
|
f"DEBUG: Completion tokens: {completion.usage.completion_tokens if hasattr(completion, 'usage') else 'unknown'}"
|
|
)
|
|
await asyncio.sleep(retry_delay)
|
|
continue
|
|
else:
|
|
print(
|
|
f"DEBUG: Very short response after {max_retries} attempts: '{model_response}'"
|
|
)
|
|
return {"score": 0.0, "sample": None}
|
|
|
|
# Success - we got a valid response
|
|
break
|
|
|
|
except Exception as e:
|
|
if attempt < max_retries - 1:
|
|
print(
|
|
f"DEBUG: API call failed (attempt {attempt + 1}/{max_retries}): {e}"
|
|
)
|
|
await asyncio.sleep(retry_delay)
|
|
continue
|
|
else:
|
|
print(
|
|
f"DEBUG: API call failed after {max_retries} attempts: {e}"
|
|
)
|
|
raise
|
|
|
|
predicted_answer = self.process_judgement(
|
|
model_response, track_metrics=False
|
|
)
|
|
|
|
score = 1.0 if predicted_answer == ground_truth else 0.0
|
|
|
|
# Extract question and answer choices from the user message
|
|
user_content = messages[1]["content"]
|
|
|
|
# Debug: Check user_content type and content
|
|
if user_content is None:
|
|
print(
|
|
f"DEBUG: user_content is None for test_item: {test_item.get('id', 'unknown')}"
|
|
)
|
|
return {"score": 0.0, "sample": None}
|
|
elif not isinstance(user_content, str):
|
|
print(
|
|
f"DEBUG: user_content is not a string. Type: {type(user_content)}, Value: {user_content}"
|
|
)
|
|
return {"score": 0.0, "sample": None}
|
|
|
|
question_match = self._question_pattern.search(user_content)
|
|
question = (
|
|
question_match.group(1).strip()
|
|
if question_match
|
|
else "Unknown question"
|
|
)
|
|
|
|
# Extract individual answer choices for all configured letters
|
|
answer_choices = {}
|
|
for letter in self.choice_letters:
|
|
match = self._answer_choice_patterns[letter].search(user_content)
|
|
if match:
|
|
answer_choices[letter] = match.group(1).strip()
|
|
|
|
# Add full conversation including model response
|
|
full_messages = messages + [
|
|
{"role": "assistant", "content": model_response}
|
|
]
|
|
|
|
sample = {
|
|
"evaluation_mode": "choice",
|
|
"messages": full_messages,
|
|
"question": question,
|
|
"answer_choices": answer_choices,
|
|
"ground_truth": ground_truth,
|
|
"predicted_judgment": predicted_answer,
|
|
"score": int(score),
|
|
"correct": bool(score),
|
|
"finish_reason": completion.choices[0].finish_reason,
|
|
"thinking_mode": self.config.thinking_mode,
|
|
"format_compliant": predicted_answer != "format_error",
|
|
"dataset_item_id": test_item.get("id", "unknown"),
|
|
"dataset_subset": test_item.get("subset", "unknown"),
|
|
"num_choices": self.config.num_choices,
|
|
}
|
|
|
|
# Add thinking-specific parsing info
|
|
if self.config.thinking_mode:
|
|
if "</think>" in model_response:
|
|
sample["response_after_think"] = model_response.split("</think>")[
|
|
-1
|
|
].strip()
|
|
sample["thinking_content"] = self._thinking_extract_pattern.search(
|
|
model_response
|
|
)
|
|
if sample["thinking_content"]:
|
|
sample["thinking_content"] = (
|
|
sample["thinking_content"].group(1).strip()
|
|
)
|
|
else:
|
|
sample["response_after_think"] = model_response
|
|
sample["thinking_content"] = None
|
|
|
|
return {"score": score, "sample": sample}
|
|
|
|
except Exception as e:
|
|
print(f"Error in choice evaluation: {e}")
|
|
print(f"DEBUG: Exception type: {type(e)}")
|
|
print(
|
|
f"DEBUG: test_item keys: {list(test_item.keys()) if test_item else 'test_item is None'}"
|
|
)
|
|
print(
|
|
f"DEBUG: test_item id: {test_item.get('id', 'no_id') if test_item else 'no_test_item'}"
|
|
)
|
|
|
|
# Try to get more context about what variables exist
|
|
try:
|
|
print(f"DEBUG: completion exists: {completion is not None}")
|
|
if completion and hasattr(completion, "choices"):
|
|
print(
|
|
f"DEBUG: completion.choices length: {len(completion.choices)}"
|
|
)
|
|
if completion.choices:
|
|
print(
|
|
f"DEBUG: completion.choices[0].message exists: {hasattr(completion.choices[0], 'message')}"
|
|
)
|
|
if hasattr(completion.choices[0], "message"):
|
|
print(
|
|
f"DEBUG: completion.choices[0].message.content type: {type(completion.choices[0].message.content)}"
|
|
)
|
|
print(
|
|
f"DEBUG: completion.choices[0].message.content value: {completion.choices[0].message.content}"
|
|
)
|
|
except Exception as debug_e:
|
|
print(f"DEBUG: Error in debug info: {debug_e}")
|
|
|
|
return {"score": 0.0, "sample": None}
|
|
|
|
async def _rollout_and_score_ties(self, test_item: dict) -> dict:
|
|
"""Ties-based evaluation logic using rating approach."""
|
|
try:
|
|
prompts_and_responses = self._prepare_ties_eval_item(test_item)
|
|
if not prompts_and_responses:
|
|
return {"score": 0.0, "sample": None}
|
|
|
|
# Rate each response individually
|
|
ratings = []
|
|
response_data = []
|
|
|
|
# Get category and item info for debug logging
|
|
category = test_item.get("subset", "unknown")
|
|
item_id = test_item.get("id", "unknown")
|
|
|
|
for prompt, response_text, is_correct in prompts_and_responses:
|
|
messages = self._prepare_completion_input(prompt)
|
|
completion_params = self._get_eval_completion_params()
|
|
|
|
# Retry logic for ties evaluation
|
|
max_retries = self.config.max_retries
|
|
retry_delay = self.config.retry_delay
|
|
success = False
|
|
|
|
for attempt in range(max_retries):
|
|
try:
|
|
# Log full debug request
|
|
self._log_full_debug_request(
|
|
messages,
|
|
completion_params,
|
|
category,
|
|
item_id,
|
|
f"TIES_EVAL attempt {attempt + 1}/{max_retries}",
|
|
)
|
|
|
|
completion = await self.server.chat_completion(
|
|
messages=messages, **completion_params
|
|
)
|
|
|
|
# Log full debug response
|
|
self._log_full_debug_response(
|
|
completion, f"TIES_EVAL attempt {attempt + 1}/{max_retries}"
|
|
)
|
|
|
|
if not completion.choices:
|
|
if attempt < max_retries - 1:
|
|
print(
|
|
f"DEBUG: No choices in ties completion (attempt {attempt + 1}/{max_retries})"
|
|
)
|
|
await asyncio.sleep(retry_delay)
|
|
continue
|
|
else:
|
|
break # Failed after all retries
|
|
|
|
model_response = completion.choices[0].message.content
|
|
|
|
# Check for None content or very short responses
|
|
if model_response is None:
|
|
if attempt < max_retries - 1:
|
|
print(
|
|
f"DEBUG: ties model_response is None (attempt {attempt + 1}/{max_retries})"
|
|
)
|
|
await asyncio.sleep(retry_delay)
|
|
continue
|
|
else:
|
|
break # Failed after all retries
|
|
|
|
if not isinstance(model_response, str):
|
|
if attempt < max_retries - 1:
|
|
print(
|
|
f"DEBUG: ties model_response is not a string. Type: {type(model_response)} (attempt {attempt + 1}/{max_retries})"
|
|
)
|
|
await asyncio.sleep(retry_delay)
|
|
continue
|
|
else:
|
|
break # Failed after all retries
|
|
|
|
# For ties evaluation, don't check response format - invalid ratings are part of normal evaluation
|
|
# Only retry for technical failures (None content, API errors, etc.)
|
|
|
|
# Success - process the rating
|
|
rating = self._process_rating_judgment(model_response)
|
|
ratings.append(rating)
|
|
response_data.append(
|
|
{
|
|
"response": response_text,
|
|
"is_correct": is_correct,
|
|
"rating": rating,
|
|
"model_judgment": model_response,
|
|
"finish_reason": completion.choices[0].finish_reason,
|
|
}
|
|
)
|
|
success = True
|
|
break
|
|
|
|
except Exception as e:
|
|
if attempt < max_retries - 1:
|
|
print(
|
|
f"DEBUG: ties API call failed (attempt {attempt + 1}/{max_retries}): {e}"
|
|
)
|
|
await asyncio.sleep(retry_delay)
|
|
continue
|
|
else:
|
|
print(
|
|
f"DEBUG: ties API call failed after {max_retries} attempts: {e}"
|
|
)
|
|
break
|
|
|
|
# If we failed after all retries, add error rating
|
|
if not success:
|
|
ratings.append(-1) # Error rating
|
|
response_data.append(
|
|
{
|
|
"response": response_text,
|
|
"is_correct": is_correct,
|
|
"rating": -1,
|
|
"model_judgment": "API_ERROR_RETRIES_EXHAUSTED",
|
|
"finish_reason": "error",
|
|
}
|
|
)
|
|
|
|
# Calculate success score
|
|
score = self._calculate_ties_score(ratings, test_item, response_data)
|
|
|
|
# Calculate format compliance for ties (valid ratings != -1)
|
|
valid_ratings = [r for r in ratings if r != -1]
|
|
format_compliant = len(valid_ratings) > 0 # True if any valid ratings found
|
|
|
|
# Create sample data
|
|
sample = {
|
|
"evaluation_mode": "ties",
|
|
"question": test_item.get("prompt", ""),
|
|
"response_data": response_data,
|
|
"ratings": ratings,
|
|
"num_correct": test_item.get("num_correct", 0),
|
|
"num_incorrect": test_item.get("num_incorrect", 0),
|
|
"total_responses": len(response_data),
|
|
"score": int(score),
|
|
"correct": bool(score),
|
|
"thinking_mode": self.config.thinking_mode,
|
|
"dataset_item_id": test_item.get("id", "unknown"),
|
|
"dataset_subset": test_item.get("subset", "unknown"),
|
|
"format_compliant": format_compliant,
|
|
}
|
|
|
|
return {"score": score, "sample": sample}
|
|
|
|
except Exception as e:
|
|
print(f"Error in ties evaluation: {e}")
|
|
print(f"DEBUG: Exception type: {type(e)}")
|
|
print(
|
|
f"DEBUG: test_item keys: {list(test_item.keys()) if test_item else 'test_item is None'}"
|
|
)
|
|
print(
|
|
f"DEBUG: test_item id: {test_item.get('id', 'no_id') if test_item else 'no_test_item'}"
|
|
)
|
|
return {"score": 0.0, "sample": None}
|
|
|
|
def _prepare_ties_eval_item(self, item: dict) -> List[Tuple[Tuple, str, bool]]:
|
|
"""Prepare ties evaluation item - returns list of (prompt, response, is_correct) tuples."""
|
|
try:
|
|
question = item.get("prompt", "")
|
|
chosen_responses = item.get("chosen", [])
|
|
rejected_responses = item.get("rejected", [])
|
|
|
|
if not question or not chosen_responses:
|
|
return []
|
|
|
|
# Create rating prompts for each response
|
|
prompts_and_responses = []
|
|
|
|
# Add chosen responses (correct)
|
|
for response in chosen_responses:
|
|
rating_prompt = self._create_rating_prompt(question, response)
|
|
prompts_and_responses.append((rating_prompt, response, True))
|
|
|
|
# Add rejected responses (incorrect) - limit to control API costs
|
|
max_rejected = min(
|
|
len(rejected_responses),
|
|
self.config.max_ties_responses - len(chosen_responses),
|
|
)
|
|
for response in rejected_responses[:max_rejected]:
|
|
rating_prompt = self._create_rating_prompt(question, response)
|
|
prompts_and_responses.append((rating_prompt, response, False))
|
|
|
|
return prompts_and_responses
|
|
|
|
except Exception as e:
|
|
print(f"Error preparing ties eval item: {e}")
|
|
print(f"DEBUG: Exception type: {type(e)}")
|
|
print(f"DEBUG: item keys: {list(item.keys()) if item else 'item is None'}")
|
|
print(f"DEBUG: item id: {item.get('id', 'no_id') if item else 'no_item'}")
|
|
return []
|
|
|
|
def _create_rating_prompt(self, question: str, response: str) -> Tuple:
|
|
"""Create rating prompt for a single response."""
|
|
# Use the original RewardBench ties-specific prompt
|
|
rating_prompt_template = """### Task Description
|
|
Please act as an impartial judge and evaluate the quality of the response provided by an
|
|
AI assistant to the user query displayed below.
|
|
|
|
Notes:
|
|
1- Your evaluation should consider factors such as the helpfulness, relevance, and accuracy of the response, but need not consider depth or level of detail of the response.
|
|
2- Begin your evaluation by providing a short explanation.
|
|
3- Be as objective as possible. After providing your explanation, please rate the response on a scale of 1 to 10. For your rating, only give a number between 1 and 10 (inclusive), do not use any markdown, and do not put any text after your final rating.
|
|
|
|
[Query]
|
|
{question}
|
|
|
|
[Response]
|
|
{response}
|
|
|
|
[Your judgement]""" # noqa
|
|
|
|
# Fill in the question and response
|
|
user_content = rating_prompt_template.format(
|
|
question=question, response=response
|
|
)
|
|
|
|
if self.config.thinking_mode:
|
|
system_content = self.thinking_system_prompt
|
|
else:
|
|
system_content = ""
|
|
|
|
return tuple(
|
|
[
|
|
frozenset({"role": "system", "content": system_content}.items()),
|
|
frozenset({"role": "user", "content": user_content}.items()),
|
|
]
|
|
)
|
|
|
|
def _process_rating_judgment(self, judgment: str) -> int:
|
|
"""
|
|
Extract 1-10 rating from model response.
|
|
|
|
Looks for a number 1-10 at the END of the response.
|
|
Examples that work:
|
|
- "This is good. I rate it 8"
|
|
- "Rating: 7"
|
|
Examples that don't work:
|
|
- "I give this a 5 out of 10" (doesn't end with just the number)
|
|
"""
|
|
# Debug: Check judgment type and content
|
|
if judgment is None:
|
|
print(f"DEBUG: judgment is None in _process_rating_judgment")
|
|
return -1
|
|
elif not isinstance(judgment, str):
|
|
print(
|
|
f"DEBUG: judgment is not a string in _process_rating_judgment. Type: {type(judgment)}, Value: {judgment}"
|
|
)
|
|
return -1
|
|
|
|
if self.config.thinking_mode:
|
|
# Extract content after </think> tags
|
|
match = self._think_content_pattern.search(judgment)
|
|
if match:
|
|
judgment = match.group(1)
|
|
else:
|
|
return -1
|
|
|
|
# Look for trailing number 1-10 using regex: \b([1-9]|10)\b\s*$
|
|
match = self._rating_pattern.search(judgment.strip())
|
|
if match:
|
|
rating = int(match.group(1))
|
|
if 1 <= rating <= 10:
|
|
return rating
|
|
|
|
return -1 # Error/invalid rating
|
|
|
|
def _calculate_ties_score(
|
|
self, ratings: List[int], test_item: dict, response_data: List[dict]
|
|
) -> float:
|
|
"""Calculate success score for ties evaluation using RewardBench's exact approach."""
|
|
# Get all valid ratings (not -1)
|
|
valid_ratings = [r for r in ratings if r != -1]
|
|
if not valid_ratings:
|
|
return 0.0
|
|
|
|
# Find the maximum rating among all valid ratings
|
|
max_rating = max(valid_ratings)
|
|
|
|
# Find all response indices that achieved this maximum rating
|
|
winner_indices = [i for i, r in enumerate(ratings) if r == max_rating]
|
|
|
|
# Check if any of the winners are correct responses
|
|
for idx in winner_indices:
|
|
if idx < len(response_data) and response_data[idx]["is_correct"]:
|
|
return 1.0
|
|
|
|
return 0.0
|
|
|
|
def _calculate_response_metrics(
|
|
self, samples: List[dict], thinking_mode_used: bool
|
|
) -> Tuple[List[int], int, Dict[str, int], Dict[str, int]]:
|
|
"""Calculate response-related metrics from samples."""
|
|
response_lengths = []
|
|
thinking_utilization = 0
|
|
judgment_counts = {letter: 0 for letter in self.choice_letters}
|
|
judgment_counts["format_error"] = 0
|
|
|
|
# Ties-specific metrics
|
|
ties_rating_counts = {i: 0 for i in range(1, 11)} # Ratings 1-10
|
|
ties_rating_counts["error"] = 0
|
|
|
|
for sample in samples:
|
|
if not sample:
|
|
continue
|
|
|
|
evaluation_mode = sample.get("evaluation_mode", "choice")
|
|
|
|
if evaluation_mode == "choice":
|
|
# Track response length for choice mode
|
|
messages = sample.get("messages", [])
|
|
if messages:
|
|
assistant_msg = messages[-1].get("content", "")
|
|
response_lengths.append(len(assistant_msg))
|
|
|
|
# Track judgment distribution for choice mode
|
|
predicted_judgment = sample.get("predicted_judgment", "format_error")
|
|
if predicted_judgment in judgment_counts:
|
|
judgment_counts[predicted_judgment] += 1
|
|
|
|
elif evaluation_mode == "ties":
|
|
# Track ratings for ties mode
|
|
ratings = sample.get("ratings", [])
|
|
for rating in ratings:
|
|
if rating == -1:
|
|
ties_rating_counts["error"] += 1
|
|
elif 1 <= rating <= 10:
|
|
ties_rating_counts[rating] += 1
|
|
|
|
# Track thinking utilization in thinking mode (both modes)
|
|
if thinking_mode_used:
|
|
thinking_content = sample.get("thinking_content")
|
|
if thinking_content:
|
|
thinking_utilization += 1
|
|
|
|
return (
|
|
response_lengths,
|
|
thinking_utilization,
|
|
judgment_counts,
|
|
ties_rating_counts,
|
|
)
|
|
|
|
async def evaluate(self, *args, **kwargs) -> None:
|
|
"""Evaluate the model on the test dataset."""
|
|
start_time = time.time()
|
|
|
|
try:
|
|
# Filter test items based on selected categories
|
|
if self.config.eval_categories is not None:
|
|
filtered_test_items = [
|
|
test_item
|
|
for test_item in self.test
|
|
if self._should_evaluate_category(test_item)
|
|
]
|
|
print(
|
|
f"Filtered to {len(filtered_test_items)} samples from {len(self.test)} total"
|
|
)
|
|
else:
|
|
filtered_test_items = self.test
|
|
|
|
if not filtered_test_items:
|
|
print("Warning: No samples match the selected categories")
|
|
return
|
|
|
|
eval_tasks = [
|
|
self.rollout_and_score_eval(test_item)
|
|
for test_item in filtered_test_items
|
|
]
|
|
results = await tqdm_asyncio.gather(*eval_tasks)
|
|
|
|
# Filter valid results
|
|
valid_results = [
|
|
result
|
|
for result in results
|
|
if not isinstance(result, Exception)
|
|
and result
|
|
and result.get("sample") is not None
|
|
]
|
|
|
|
if not valid_results:
|
|
print("Warning: No valid evaluation results obtained")
|
|
return
|
|
|
|
except Exception as e:
|
|
print(f"Error during evaluation: {e}")
|
|
return
|
|
|
|
# Extract scores and samples from valid results
|
|
scores = [result["score"] for result in valid_results]
|
|
samples = [result["sample"] for result in valid_results]
|
|
valid_scores = [s for s in scores if s is not None]
|
|
|
|
if not valid_scores:
|
|
print("Warning: No valid scores found during evaluation")
|
|
return
|
|
|
|
percent_correct = sum(valid_scores) / len(valid_scores)
|
|
self.eval_metrics.append(("eval/percent_correct", percent_correct))
|
|
|
|
# Track performance by subset and evaluation mode
|
|
subset_scores = {}
|
|
choice_count = 0
|
|
ties_count = 0
|
|
|
|
for i, sample in enumerate(samples):
|
|
if sample and i < len(scores):
|
|
subset = sample.get("dataset_subset", "unknown")
|
|
if subset not in subset_scores:
|
|
subset_scores[subset] = []
|
|
subset_scores[subset].append(scores[i])
|
|
|
|
# Count evaluation modes
|
|
if sample.get("evaluation_mode") == "choice":
|
|
choice_count += 1
|
|
elif sample.get("evaluation_mode") == "ties":
|
|
ties_count += 1
|
|
|
|
print(
|
|
f"Evaluation completed: {choice_count} choice samples, {ties_count} ties samples"
|
|
)
|
|
|
|
# Log subset-specific metrics
|
|
for subset, subset_score_list in subset_scores.items():
|
|
valid_subset_scores = [s for s in subset_score_list if s is not None]
|
|
if valid_subset_scores:
|
|
avg_score = sum(valid_subset_scores) / len(valid_subset_scores)
|
|
self.eval_metrics.append((f"eval/percent_correct_{subset}", avg_score))
|
|
|
|
# Calculate additional metrics
|
|
# Format compliance means:
|
|
# - Choice mode: Proper thinking tags (if enabled) + valid choice like [[A]]
|
|
# - Ties mode: At least one valid rating (1-10) was extracted from responses
|
|
format_compliant = sum(
|
|
1 for sample in samples if sample.get("format_compliant", False)
|
|
)
|
|
|
|
# Separate compliance by evaluation mode
|
|
choice_format_compliant = sum(
|
|
1
|
|
for sample in samples
|
|
if sample.get("evaluation_mode") == "choice"
|
|
and sample.get("format_compliant", False)
|
|
)
|
|
ties_format_compliant = sum(
|
|
1
|
|
for sample in samples
|
|
if sample.get("evaluation_mode") == "ties"
|
|
and sample.get("format_compliant", False)
|
|
)
|
|
|
|
# Track "A" bias in wrong answers (choice mode only)
|
|
# This detects if the model defaults to choosing "A" when uncertain
|
|
# Expected: ~25% for 4 choices if no bias, higher indicates A bias
|
|
wrong_choice_samples = [
|
|
sample
|
|
for sample in samples
|
|
if sample.get("evaluation_mode") == "choice"
|
|
and not sample.get("correct", False)
|
|
]
|
|
wrong_a_choices = sum(
|
|
1
|
|
for sample in wrong_choice_samples
|
|
if sample.get("predicted_judgment") == "A"
|
|
)
|
|
|
|
# Calculate A bias rate for wrong answers
|
|
a_bias_rate = (
|
|
wrong_a_choices / len(wrong_choice_samples) if wrong_choice_samples else 0.0
|
|
)
|
|
|
|
thinking_mode_used = self.config.thinking_mode
|
|
|
|
# Get response metrics
|
|
response_lengths, thinking_utilization, judgment_counts, ties_rating_counts = (
|
|
self._calculate_response_metrics(samples, thinking_mode_used)
|
|
)
|
|
|
|
# Response length metrics
|
|
if response_lengths:
|
|
avg_response_length = sum(response_lengths) / len(response_lengths)
|
|
response_length_std = (
|
|
sum((x - avg_response_length) ** 2 for x in response_lengths)
|
|
/ len(response_lengths)
|
|
) ** 0.5
|
|
self.eval_metrics.append(("eval/avg_response_length", avg_response_length))
|
|
self.eval_metrics.append(("eval/response_length_std", response_length_std))
|
|
|
|
# Thinking utilization rate
|
|
if thinking_mode_used and samples:
|
|
thinking_utilization_rate = thinking_utilization / len(samples)
|
|
self.eval_metrics.append(
|
|
("eval/thinking_utilization_rate", thinking_utilization_rate)
|
|
)
|
|
|
|
# Judgment distribution metrics
|
|
total_judgments = sum(judgment_counts.values())
|
|
if total_judgments > 0:
|
|
# Calculate entropy for judgment balance
|
|
entropy = 0.0
|
|
for count in judgment_counts.values():
|
|
if count > 0:
|
|
freq = count / total_judgments
|
|
entropy -= freq * math.log(freq)
|
|
self.eval_metrics.append(("eval/judgment_entropy", entropy))
|
|
|
|
# Most common judgment frequency (bias detection)
|
|
max_judgment_count = max(judgment_counts.values())
|
|
most_common_judgment_freq = max_judgment_count / total_judgments
|
|
self.eval_metrics.append(
|
|
("eval/most_common_judgment_freq", most_common_judgment_freq)
|
|
)
|
|
|
|
# Format error rate
|
|
format_error_rate = judgment_counts["format_error"] / total_judgments
|
|
self.eval_metrics.append(("eval/format_error_rate", format_error_rate))
|
|
|
|
# Ties-specific metrics
|
|
total_ties_ratings = sum(ties_rating_counts.values())
|
|
if total_ties_ratings > 0:
|
|
# Average rating for ties mode
|
|
total_rating_sum = sum(
|
|
rating * count
|
|
for rating, count in ties_rating_counts.items()
|
|
if isinstance(rating, int)
|
|
)
|
|
valid_ratings_count = sum(
|
|
count
|
|
for rating, count in ties_rating_counts.items()
|
|
if isinstance(rating, int)
|
|
)
|
|
|
|
if valid_ratings_count > 0:
|
|
avg_ties_rating = total_rating_sum / valid_ratings_count
|
|
self.eval_metrics.append(("eval/avg_ties_rating", avg_ties_rating))
|
|
|
|
# Ties rating distribution
|
|
for rating in range(1, 11):
|
|
rating_freq = ties_rating_counts[rating] / total_ties_ratings
|
|
self.eval_metrics.append(
|
|
(f"eval/ties_rating_freq_{rating}", rating_freq)
|
|
)
|
|
|
|
# Ties error rate (proportion of rating attempts that failed to parse)
|
|
# This is different from percent correct:
|
|
# - Error rate: How often we couldn't extract a 1-10 rating from responses
|
|
# - Percent correct: How often the model chose the right responses as winners
|
|
ties_error_rate = ties_rating_counts["error"] / total_ties_ratings
|
|
self.eval_metrics.append(("eval/ties_error_rate", ties_error_rate))
|
|
|
|
# Add overall dataset statistics
|
|
total_dataset_items = len(self.test) if hasattr(self, "test") else 0
|
|
evaluated_items = len(samples)
|
|
self.eval_metrics.append(("eval/total_dataset_items", total_dataset_items))
|
|
self.eval_metrics.append(("eval/evaluated_items", evaluated_items))
|
|
self.eval_metrics.append(("eval/valid_scores", len(valid_scores)))
|
|
self.eval_metrics.append(("eval/subset_count", len(subset_scores)))
|
|
self.eval_metrics.append(
|
|
(
|
|
"eval/format_compliance_rate",
|
|
format_compliant / len(samples) if samples else 0.0,
|
|
)
|
|
)
|
|
|
|
# Add mode-specific compliance rates
|
|
if choice_count > 0:
|
|
self.eval_metrics.append(
|
|
(
|
|
"eval/choice_format_compliance_rate",
|
|
choice_format_compliant / choice_count,
|
|
)
|
|
)
|
|
if ties_count > 0:
|
|
self.eval_metrics.append(
|
|
("eval/ties_format_compliance_rate", ties_format_compliant / ties_count)
|
|
)
|
|
|
|
# Add A bias metric for wrong answers
|
|
self.eval_metrics.append(("eval/wrong_answer_a_bias_rate", a_bias_rate))
|
|
self.eval_metrics.append(
|
|
("eval/wrong_answer_total_count", len(wrong_choice_samples))
|
|
)
|
|
self.eval_metrics.append(("eval/wrong_answer_a_count", wrong_a_choices))
|
|
|
|
end_time = time.time()
|
|
|
|
# Build evaluation metrics dict
|
|
eval_metrics = {
|
|
"eval/percent_correct": percent_correct,
|
|
"eval/total_samples": len(samples),
|
|
"eval/correct_samples": sum(valid_scores),
|
|
"eval/format_compliance_rate": (
|
|
format_compliant / len(samples) if samples else 0.0
|
|
),
|
|
}
|
|
|
|
# Add response length metrics
|
|
if response_lengths:
|
|
eval_metrics["eval/avg_response_length"] = avg_response_length
|
|
eval_metrics["eval/response_length_std"] = response_length_std
|
|
|
|
# Add thinking utilization
|
|
if thinking_mode_used and samples:
|
|
eval_metrics["eval/thinking_utilization_rate"] = thinking_utilization_rate
|
|
|
|
# Add judgment distribution metrics
|
|
if total_judgments > 0:
|
|
eval_metrics["eval/judgment_entropy"] = entropy
|
|
eval_metrics["eval/most_common_judgment_freq"] = most_common_judgment_freq
|
|
eval_metrics["eval/format_error_rate"] = format_error_rate
|
|
|
|
# Add ties-specific metrics
|
|
if total_ties_ratings > 0:
|
|
if valid_ratings_count > 0:
|
|
eval_metrics["eval/avg_ties_rating"] = avg_ties_rating
|
|
eval_metrics["eval/ties_error_rate"] = ties_error_rate
|
|
|
|
# Add subset metrics
|
|
for subset, subset_score_list in subset_scores.items():
|
|
valid_subset_scores = [s for s in subset_score_list if s is not None]
|
|
if valid_subset_scores:
|
|
avg_score = sum(valid_subset_scores) / len(valid_subset_scores)
|
|
eval_metrics[f"eval/percent_correct_{subset}"] = avg_score
|
|
|
|
# Add evaluation mode counts and compliance rates
|
|
choice_samples = sum(
|
|
1 for sample in samples if sample.get("evaluation_mode") == "choice"
|
|
)
|
|
ties_samples = sum(
|
|
1 for sample in samples if sample.get("evaluation_mode") == "ties"
|
|
)
|
|
eval_metrics["eval/choice_samples"] = choice_samples
|
|
eval_metrics["eval/ties_samples"] = ties_samples
|
|
|
|
# Add mode-specific compliance rates
|
|
if choice_samples > 0:
|
|
eval_metrics["eval/choice_format_compliance_rate"] = (
|
|
choice_format_compliant / choice_samples
|
|
)
|
|
if ties_samples > 0:
|
|
eval_metrics["eval/ties_format_compliance_rate"] = (
|
|
ties_format_compliant / ties_samples
|
|
)
|
|
|
|
# Add A bias metrics for wrong answers
|
|
eval_metrics["eval/wrong_answer_a_bias_rate"] = a_bias_rate
|
|
eval_metrics["eval/wrong_answer_total_count"] = len(wrong_choice_samples)
|
|
eval_metrics["eval/wrong_answer_a_count"] = wrong_a_choices
|
|
|
|
try:
|
|
await self.evaluate_log(
|
|
metrics=eval_metrics,
|
|
samples=samples,
|
|
start_time=start_time,
|
|
end_time=end_time,
|
|
generation_parameters={
|
|
"temperature": self.config.eval_temperature,
|
|
"max_tokens": self.config.eval_max_tokens,
|
|
"thinking_mode": thinking_mode_used,
|
|
},
|
|
)
|
|
except Exception as e:
|
|
print(f"Error logging evaluation results: {e}")
|
|
|
|
async def add_rollouts_for_wandb(
|
|
self,
|
|
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
|
|
item: Item = None,
|
|
) -> None:
|
|
"""Add rollouts to wandb for visualization."""
|
|
if item is None or scored_data is None or not scored_data.get("tokens"):
|
|
return
|
|
|
|
# Extract ground truth and question info
|
|
ground_truth = item[1]
|
|
|
|
# Extract question from the item prompt
|
|
question_info = "unknown_question"
|
|
try:
|
|
# The item[0] contains the prompt tuple with system and user messages
|
|
for role_dict in item[0]:
|
|
role_dict_converted = dict(role_dict)
|
|
if role_dict_converted.get("role") == "user":
|
|
user_content = role_dict_converted.get("content", "")
|
|
# Extract question from the user message format
|
|
question_match = self._question_pattern.search(user_content)
|
|
if question_match:
|
|
question_info = question_match.group(1).strip()
|
|
break
|
|
except Exception as e:
|
|
# Fallback to placeholder if extraction fails
|
|
print(
|
|
f"DEBUG: Exception in add_rollouts_for_wandb question extraction: {e}"
|
|
)
|
|
print(f"DEBUG: Exception type: {type(e)}")
|
|
question_info = "extraction_failed"
|
|
|
|
# Keep a reasonable number of rollouts
|
|
num_keep = self.config.num_rollouts_per_group_for_logging
|
|
if num_keep == -1:
|
|
num_keep = self.config.group_size
|
|
|
|
num_keep = min(num_keep, len(scored_data["tokens"]))
|
|
|
|
current_rollouts = []
|
|
mode = "thinking" if self.config.thinking_mode else "direct"
|
|
|
|
for i in range(num_keep):
|
|
# Decode the full trajectory
|
|
full_text = self.tokenizer.decode(
|
|
scored_data["tokens"][i], skip_special_tokens=True
|
|
)
|
|
score_val = scored_data["scores"][i]
|
|
|
|
# Extract the model's judgment
|
|
predicted_judgment = "unknown"
|
|
try:
|
|
# Try to get model response from messages or decode from tokens
|
|
messages = scored_data.get("messages", [])
|
|
if i < len(messages) and isinstance(messages[i], list) and messages[i]:
|
|
model_response = messages[i][-1].get("content", "")
|
|
else:
|
|
# Fallback to decoding tokens
|
|
model_response = full_text
|
|
|
|
predicted_judgment = self.process_judgement(
|
|
model_response, track_metrics=False
|
|
)
|
|
except Exception as e:
|
|
print(
|
|
f"DEBUG: Exception in add_rollouts_for_wandb judgment parsing: {e}"
|
|
)
|
|
print(f"DEBUG: Exception type: {type(e)}")
|
|
predicted_judgment = "parse_error"
|
|
|
|
current_rollouts.append(
|
|
(
|
|
full_text,
|
|
score_val,
|
|
ground_truth,
|
|
predicted_judgment,
|
|
question_info,
|
|
mode,
|
|
)
|
|
)
|
|
|
|
self.rollouts_for_wandb.append(current_rollouts)
|
|
|
|
# Keep only recent rollouts
|
|
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
|
|
self.rollouts_for_wandb.pop(0)
|
|
|
|
async def create_rollout_table(self, wandb_metrics: Dict) -> Dict:
|
|
"""Create wandb table for rollout visualization."""
|
|
if not self.rollouts_for_wandb:
|
|
return wandb_metrics
|
|
|
|
table = wandb.Table(
|
|
columns=[
|
|
"full_text",
|
|
"score",
|
|
"ground_truth",
|
|
"predicted_judgment",
|
|
"question_info",
|
|
"mode",
|
|
]
|
|
)
|
|
|
|
for group_rollouts in self.rollouts_for_wandb:
|
|
for rollout_tuple in group_rollouts:
|
|
if len(rollout_tuple) == 6:
|
|
table.add_data(*rollout_tuple)
|
|
|
|
wandb_metrics["train/rollouts"] = table
|
|
self.rollouts_for_wandb = []
|
|
return wandb_metrics
|
|
|
|
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
|
"""Log metrics to wandb."""
|
|
if wandb_metrics is None:
|
|
wandb_metrics = {}
|
|
|
|
# Basic accuracy metrics
|
|
if self.percent_correct_buffer:
|
|
wandb_metrics["train/percent_correct"] = sum(
|
|
self.percent_correct_buffer
|
|
) / len(self.percent_correct_buffer)
|
|
|
|
# Judgment letter distribution and accuracy
|
|
total_letters = sum(self.judgment_letter_counts.values())
|
|
if total_letters > 0:
|
|
# Calculate entropy once
|
|
entropy = 0.0
|
|
for letter in self.choice_letters:
|
|
letter_count = self.judgment_letter_counts[letter]
|
|
letter_correct = self.judgment_letter_correct[letter]
|
|
|
|
# Letter frequency and accuracy
|
|
freq = letter_count / total_letters
|
|
wandb_metrics[f"train/judgment_freq_{letter}"] = freq
|
|
wandb_metrics[f"train/judgment_acc_{letter}"] = (
|
|
letter_correct / letter_count if letter_count > 0 else 0.0
|
|
)
|
|
|
|
# Accumulate entropy
|
|
if freq > 0:
|
|
entropy -= freq * math.log(freq)
|
|
|
|
wandb_metrics["train/judgment_entropy"] = entropy
|
|
wandb_metrics["train/judgment_balance"] = entropy / math.log(
|
|
self.config.num_choices
|
|
) # Normalized entropy
|
|
|
|
# Error rate and other metrics
|
|
if self.total_judgments > 0:
|
|
wandb_metrics["train/error_rate"] = self.error_count / self.total_judgments
|
|
wandb_metrics["train/format_compliance_rate"] = 1.0 - (
|
|
self.error_count / self.total_judgments
|
|
)
|
|
|
|
# Configuration and mode metrics
|
|
wandb_metrics.update(
|
|
{
|
|
"train/thinking_mode_enabled": (
|
|
1.0 if self.config.thinking_mode else 0.0
|
|
),
|
|
"train/total_judgments": self.total_judgments,
|
|
"config/group_size": self.config.group_size,
|
|
"config/max_token_length": self.config.max_token_length,
|
|
"config/num_choices": self.config.num_choices,
|
|
}
|
|
)
|
|
|
|
# Reset training metrics
|
|
self._reset_metrics()
|
|
|
|
# Add evaluation metrics
|
|
for metric_name, metric_value in self.eval_metrics:
|
|
wandb_metrics[metric_name] = metric_value
|
|
self.eval_metrics = []
|
|
|
|
# Add rollout table
|
|
wandb_metrics = await self.create_rollout_table(wandb_metrics)
|
|
|
|
await super().wandb_log(wandb_metrics)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
PairwiseJudgementEnv.cli()
|