mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
ef9c0c3699
commit
afab28dfa9
37 changed files with 4868 additions and 4052 deletions
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
MMLU Evaluation Environment for Atropos (Generative/Reasoning Mode)
|
||||
|
||||
This environment evaluates models on the Massive Multitask Language Understanding (MMLU)
|
||||
This environment evaluates models on the Massive Multitask Language Understanding (MMLU)
|
||||
benchmark using a generative approach where models can reason before answering.
|
||||
|
||||
Dataset: lighteval/mmlu (or configurable)
|
||||
|
|
@ -9,7 +9,7 @@ Paper: https://arxiv.org/abs/2009.03300
|
|||
|
||||
The evaluation follows the lighteval generative approach (like GPQA/MMLU-Pro):
|
||||
- Models are prompted to "think step by step before answering"
|
||||
- Models output their reasoning followed by "Answer: X"
|
||||
- Models output their reasoning followed by "Answer: X"
|
||||
- Answer is extracted using regex patterns from the response
|
||||
- Simple string matching validates the extracted answer
|
||||
|
||||
|
|
@ -26,6 +26,15 @@ from typing import Dict, List, Optional, Tuple
|
|||
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from eval_helpers import (
|
||||
build_mcqa_fallback_patterns,
|
||||
create_system_content,
|
||||
extract_letter_from_answer_tag,
|
||||
extract_thinking_content,
|
||||
get_default_thinking_prompt,
|
||||
save_eval_results,
|
||||
validate_thinking_format,
|
||||
)
|
||||
from pydantic import Field
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
|
|
@ -35,16 +44,6 @@ from atroposlib.envs.base import (
|
|||
BaseEnvConfig,
|
||||
EvalHandlingEnum,
|
||||
)
|
||||
from eval_helpers import (
|
||||
extract_letter_from_answer_tag,
|
||||
validate_thinking_format,
|
||||
extract_thinking_content,
|
||||
get_default_thinking_prompt,
|
||||
create_system_content,
|
||||
save_eval_results,
|
||||
build_mcqa_fallback_patterns,
|
||||
)
|
||||
|
||||
|
||||
# All 57 MMLU subjects - used for dataset loading and category tracking
|
||||
MMLU_SUBJECTS = [
|
||||
|
|
@ -110,30 +109,69 @@ MMLU_SUBJECTS = [
|
|||
# High-level category groupings for aggregate metrics
|
||||
SUBJECT_CATEGORIES = {
|
||||
"STEM": [
|
||||
"abstract_algebra", "astronomy", "college_biology", "college_chemistry",
|
||||
"college_computer_science", "college_mathematics", "college_physics",
|
||||
"computer_security", "conceptual_physics", "electrical_engineering",
|
||||
"elementary_mathematics", "high_school_biology", "high_school_chemistry",
|
||||
"high_school_computer_science", "high_school_mathematics", "high_school_physics",
|
||||
"high_school_statistics", "machine_learning", "college_medicine",
|
||||
"clinical_knowledge", "medical_genetics", "professional_medicine", "anatomy",
|
||||
"nutrition", "virology", "human_aging",
|
||||
"abstract_algebra",
|
||||
"astronomy",
|
||||
"college_biology",
|
||||
"college_chemistry",
|
||||
"college_computer_science",
|
||||
"college_mathematics",
|
||||
"college_physics",
|
||||
"computer_security",
|
||||
"conceptual_physics",
|
||||
"electrical_engineering",
|
||||
"elementary_mathematics",
|
||||
"high_school_biology",
|
||||
"high_school_chemistry",
|
||||
"high_school_computer_science",
|
||||
"high_school_mathematics",
|
||||
"high_school_physics",
|
||||
"high_school_statistics",
|
||||
"machine_learning",
|
||||
"college_medicine",
|
||||
"clinical_knowledge",
|
||||
"medical_genetics",
|
||||
"professional_medicine",
|
||||
"anatomy",
|
||||
"nutrition",
|
||||
"virology",
|
||||
"human_aging",
|
||||
],
|
||||
"Humanities": [
|
||||
"formal_logic", "high_school_european_history", "high_school_us_history",
|
||||
"high_school_world_history", "international_law", "jurisprudence",
|
||||
"logical_fallacies", "moral_disputes", "moral_scenarios", "philosophy",
|
||||
"prehistory", "professional_law", "world_religions",
|
||||
"formal_logic",
|
||||
"high_school_european_history",
|
||||
"high_school_us_history",
|
||||
"high_school_world_history",
|
||||
"international_law",
|
||||
"jurisprudence",
|
||||
"logical_fallacies",
|
||||
"moral_disputes",
|
||||
"moral_scenarios",
|
||||
"philosophy",
|
||||
"prehistory",
|
||||
"professional_law",
|
||||
"world_religions",
|
||||
],
|
||||
"Social_Sciences": [
|
||||
"econometrics", "high_school_geography", "high_school_government_and_politics",
|
||||
"high_school_macroeconomics", "high_school_microeconomics",
|
||||
"high_school_psychology", "human_sexuality", "professional_psychology",
|
||||
"public_relations", "security_studies", "sociology", "us_foreign_policy",
|
||||
"econometrics",
|
||||
"high_school_geography",
|
||||
"high_school_government_and_politics",
|
||||
"high_school_macroeconomics",
|
||||
"high_school_microeconomics",
|
||||
"high_school_psychology",
|
||||
"human_sexuality",
|
||||
"professional_psychology",
|
||||
"public_relations",
|
||||
"security_studies",
|
||||
"sociology",
|
||||
"us_foreign_policy",
|
||||
],
|
||||
"Other": [
|
||||
"business_ethics", "global_facts", "management", "marketing",
|
||||
"miscellaneous", "professional_accounting",
|
||||
"business_ethics",
|
||||
"global_facts",
|
||||
"management",
|
||||
"marketing",
|
||||
"miscellaneous",
|
||||
"professional_accounting",
|
||||
],
|
||||
}
|
||||
|
||||
|
|
@ -206,7 +244,7 @@ class MMLUEvalConfig(BaseEnvConfig):
|
|||
description="Maximum tokens for evaluation responses. Set high to allow reasoning.",
|
||||
)
|
||||
|
||||
# Prompt configuration
|
||||
# Prompt configuration
|
||||
custom_system_prompt: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Custom system prompt to append after thinking prompt (if thinking_mode) or use directly.",
|
||||
|
|
@ -246,10 +284,10 @@ class MMLUEvalConfig(BaseEnvConfig):
|
|||
class MMLUEvalEnv(BaseEnv):
|
||||
"""
|
||||
MMLU Evaluation Environment for Atropos (Generative/Reasoning Mode).
|
||||
|
||||
|
||||
Evaluates models on the Massive Multitask Language Understanding benchmark
|
||||
using a generative approach where models reason before answering.
|
||||
|
||||
|
||||
Key features:
|
||||
- Loads MMLU dataset from HuggingFace (lighteval/mmlu format)
|
||||
- Uses lighteval's exact prompt format for GPQA/MMLU-Pro style evaluation
|
||||
|
|
@ -257,7 +295,7 @@ class MMLUEvalEnv(BaseEnv):
|
|||
- Extracts answer letters from patterns like "Answer: A", "The final answer is B", etc.
|
||||
- Tracks per-subject and per-category accuracy
|
||||
- Supports few-shot examples
|
||||
|
||||
|
||||
Answer extraction follows lighteval's approach with priority-ordered patterns:
|
||||
1. "final answer is: X" (highest priority)
|
||||
2. "answer: X" or "answer X"
|
||||
|
|
@ -265,7 +303,7 @@ class MMLUEvalEnv(BaseEnv):
|
|||
4. Letter at start of any line
|
||||
5. Any letter A/B/C/D in response (lowest priority, fallback)
|
||||
"""
|
||||
|
||||
|
||||
name = "mmlu_eval"
|
||||
env_config_cls = MMLUEvalConfig
|
||||
|
||||
|
|
@ -281,16 +319,18 @@ class MMLUEvalEnv(BaseEnv):
|
|||
|
||||
# Initialize metrics tracking
|
||||
self.eval_metrics = []
|
||||
|
||||
|
||||
# Pre-compile regex patterns for thinking mode (like pairwise_judgement_environment)
|
||||
self._think_pattern = re.compile(r"<think>")
|
||||
self._think_close_pattern = re.compile(r"</think>")
|
||||
self._think_content_pattern = re.compile(r"</think>\s*(.*)", re.DOTALL)
|
||||
self._thinking_extract_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
|
||||
|
||||
|
||||
# Pre-compile regex for <answer></answer> tag extraction (primary method)
|
||||
self._answer_tag_pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
|
||||
|
||||
self._answer_tag_pattern = re.compile(
|
||||
r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
# Build fallback answer extraction patterns
|
||||
self._build_extraction_patterns()
|
||||
|
||||
|
|
@ -303,71 +343,63 @@ class MMLUEvalEnv(BaseEnv):
|
|||
return create_system_content(
|
||||
self.config.thinking_mode,
|
||||
self.config.custom_thinking_prompt,
|
||||
self.config.custom_system_prompt
|
||||
self.config.custom_system_prompt,
|
||||
)
|
||||
|
||||
def _build_extraction_patterns(self):
|
||||
"""
|
||||
Build regex patterns for extracting answer letters from model responses.
|
||||
|
||||
Following lighteval's IndicesExtractionConfig approach, patterns are
|
||||
|
||||
Following lighteval's IndicesExtractionConfig approach, patterns are
|
||||
ordered by priority (lower number = higher priority).
|
||||
"""
|
||||
# Valid answer letters (default to A-D for standard MMLU)
|
||||
letters = "ABCD"
|
||||
|
||||
|
||||
# Build the letter matching pattern - matches A, B, C, D or (A), (B), etc.
|
||||
letter_pattern = rf"([{letters}]|\([{letters}]\))"
|
||||
|
||||
|
||||
# Patterns ordered by priority (most specific first)
|
||||
# Priority 0: "final answer is: X" with "I hope" (very specific, highest confidence)
|
||||
self._pattern_final_answer_hope = re.compile(
|
||||
rf"(?i:final\s+answer\s+is)\s*:?\s*{letter_pattern}\.?\s*I\s*hope",
|
||||
re.IGNORECASE
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
# Priority 50: "final answer ... is X" (allows text between)
|
||||
self._pattern_final_answer_is = re.compile(
|
||||
rf"(?i:final\s+answer).{{0,100}}?\s+is\s*:?\s*{letter_pattern}",
|
||||
re.IGNORECASE | re.DOTALL
|
||||
re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
# Priority 75: "the answer is X"
|
||||
self._pattern_the_answer_is = re.compile(
|
||||
rf"(?i:the\s+answer\s+is)\s*:?\s*{letter_pattern}",
|
||||
re.IGNORECASE
|
||||
rf"(?i:the\s+answer\s+is)\s*:?\s*{letter_pattern}", re.IGNORECASE
|
||||
)
|
||||
|
||||
|
||||
# Priority 100: "answer: X" or "Answer: X" (with colon)
|
||||
self._pattern_answer_colon = re.compile(
|
||||
rf"(?i:answer)\s*:\s*.{{0,50}}?{letter_pattern}",
|
||||
re.IGNORECASE | re.DOTALL
|
||||
rf"(?i:answer)\s*:\s*.{{0,50}}?{letter_pattern}", re.IGNORECASE | re.DOTALL
|
||||
)
|
||||
|
||||
|
||||
# Priority 150: "answer X" or "Answer X" (without colon)
|
||||
self._pattern_answer_space = re.compile(
|
||||
rf"(?i:answer)\s+{letter_pattern}",
|
||||
re.IGNORECASE
|
||||
rf"(?i:answer)\s+{letter_pattern}", re.IGNORECASE
|
||||
)
|
||||
|
||||
|
||||
# Priority 200: Response starts with answer letter (with optional punctuation)
|
||||
self._pattern_start = re.compile(
|
||||
rf"^\s*\**{letter_pattern}\**[\s\.\)\:]",
|
||||
re.IGNORECASE
|
||||
rf"^\s*\**{letter_pattern}\**[\s\.\)\:]", re.IGNORECASE
|
||||
)
|
||||
|
||||
|
||||
# Priority 210: Letter at start of any line (for multi-line responses)
|
||||
self._pattern_line_start = re.compile(
|
||||
rf"\n\s*\**{letter_pattern}\**[\s\.\)\:]",
|
||||
re.IGNORECASE
|
||||
rf"\n\s*\**{letter_pattern}\**[\s\.\)\:]", re.IGNORECASE
|
||||
)
|
||||
|
||||
|
||||
# Priority 250: Standalone letter with word boundaries
|
||||
self._pattern_standalone = re.compile(
|
||||
rf"\b{letter_pattern}\b",
|
||||
re.IGNORECASE
|
||||
)
|
||||
|
||||
self._pattern_standalone = re.compile(rf"\b{letter_pattern}\b", re.IGNORECASE)
|
||||
|
||||
# Store patterns in priority order
|
||||
self._extraction_patterns = [
|
||||
(0, self._pattern_final_answer_hope, "final_answer_hope"),
|
||||
|
|
@ -405,7 +437,7 @@ class MMLUEvalEnv(BaseEnv):
|
|||
# Thinking mode defaults
|
||||
thinking_mode=True,
|
||||
)
|
||||
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="Hermes-3-Llama-3.1-8B",
|
||||
|
|
@ -415,23 +447,23 @@ class MMLUEvalEnv(BaseEnv):
|
|||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self) -> None:
|
||||
"""Load the MMLU dataset and prepare for evaluation."""
|
||||
# Determine which subjects to evaluate
|
||||
self.subjects = self.config.subjects or MMLU_SUBJECTS
|
||||
|
||||
|
||||
# Validate subjects
|
||||
invalid_subjects = [s for s in self.subjects if s not in MMLU_SUBJECTS]
|
||||
if invalid_subjects:
|
||||
print(f"Warning: Invalid subjects will be skipped: {invalid_subjects}")
|
||||
self.subjects = [s for s in self.subjects if s in MMLU_SUBJECTS]
|
||||
|
||||
|
||||
if not self.subjects:
|
||||
raise ValueError("No valid MMLU subjects specified for evaluation.")
|
||||
|
||||
|
||||
print(f"\nMMLU Evaluation Setup (Generative Mode):")
|
||||
print(f" Dataset: {self.config.dataset_name}")
|
||||
print(f" Subjects: {len(self.subjects)} subjects")
|
||||
|
|
@ -441,11 +473,11 @@ class MMLUEvalEnv(BaseEnv):
|
|||
print(f" Thinking mode: {self.config.thinking_mode}")
|
||||
if self.config.thinking_mode:
|
||||
print(f" Thinking prompt: {self._get_thinking_prompt()[:100]}...")
|
||||
|
||||
|
||||
# Load datasets for each subject
|
||||
self.eval_data = {} # subject -> list of eval items
|
||||
self.few_shot_data = {} # subject -> list of few-shot items
|
||||
|
||||
|
||||
total_eval_items = 0
|
||||
for subject in self.subjects:
|
||||
try:
|
||||
|
|
@ -458,7 +490,7 @@ class MMLUEvalEnv(BaseEnv):
|
|||
)
|
||||
self.eval_data[subject] = list(dataset)
|
||||
total_eval_items += len(self.eval_data[subject])
|
||||
|
||||
|
||||
# Load few-shot data if needed
|
||||
if self.config.num_few_shot > 0:
|
||||
few_shot_dataset = load_dataset(
|
||||
|
|
@ -468,23 +500,25 @@ class MMLUEvalEnv(BaseEnv):
|
|||
trust_remote_code=True,
|
||||
)
|
||||
self.few_shot_data[subject] = list(few_shot_dataset)
|
||||
|
||||
|
||||
if self.config.full_debug:
|
||||
print(f" Loaded {subject}: {len(self.eval_data[subject])} eval items")
|
||||
|
||||
print(
|
||||
f" Loaded {subject}: {len(self.eval_data[subject])} eval items"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f" Warning: Failed to load subject '{subject}': {e}")
|
||||
continue
|
||||
|
||||
|
||||
print(f" Total evaluation items: {total_eval_items}")
|
||||
|
||||
|
||||
# Flatten all eval items with subject metadata for iteration
|
||||
self.all_eval_items = []
|
||||
for subject, items in self.eval_data.items():
|
||||
for item in items:
|
||||
item['subject'] = subject # Ensure subject is in each item
|
||||
item["subject"] = subject # Ensure subject is in each item
|
||||
self.all_eval_items.append(item)
|
||||
|
||||
|
||||
self.iter = 0
|
||||
|
||||
def _format_choices(self, choices: List[str]) -> str:
|
||||
|
|
@ -496,98 +530,100 @@ class MMLUEvalEnv(BaseEnv):
|
|||
return "\n".join(lines)
|
||||
|
||||
def _format_mmlu_prompt(
|
||||
self,
|
||||
question: str,
|
||||
choices: List[str],
|
||||
self,
|
||||
question: str,
|
||||
choices: List[str],
|
||||
subject: str,
|
||||
few_shot_examples: Optional[List[Dict]] = None
|
||||
few_shot_examples: Optional[List[Dict]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format a question using the lighteval MMLU template.
|
||||
|
||||
Uses the exact GPQA/MMLU-Pro style prompt from lighteval that instructs
|
||||
|
||||
Uses the exact GPQA/MMLU-Pro style prompt from lighteval that instructs
|
||||
the model to think step by step and provide the answer in a specific format.
|
||||
|
||||
|
||||
Args:
|
||||
question: The question text
|
||||
choices: List of answer choices
|
||||
subject: The subject name (for context in prompt)
|
||||
few_shot_examples: Optional list of few-shot example dicts
|
||||
|
||||
|
||||
Returns:
|
||||
Formatted prompt string (user message content)
|
||||
"""
|
||||
num_choices = len(choices)
|
||||
valid_letters = "".join(ascii_uppercase[:num_choices])
|
||||
|
||||
|
||||
# Format choices
|
||||
formatted_choices = self._format_choices(choices)
|
||||
|
||||
|
||||
# Build the question - optionally include subject
|
||||
if self.config.include_subject_in_prompt:
|
||||
subject_display = subject.replace('_', ' ')
|
||||
subject_display = subject.replace("_", " ")
|
||||
question_with_context = f"[{subject_display}]\n\n{question}"
|
||||
else:
|
||||
question_with_context = question
|
||||
|
||||
|
||||
# Use lighteval's exact prompt template
|
||||
prompt = LIGHTEVAL_PROMPT_TEMPLATE.format(
|
||||
question=question_with_context,
|
||||
choices=formatted_choices,
|
||||
valid_letters=valid_letters,
|
||||
)
|
||||
|
||||
|
||||
# Add few-shot examples if provided (prepended)
|
||||
if few_shot_examples:
|
||||
few_shot_text = self._format_few_shot_examples(few_shot_examples)
|
||||
prompt = few_shot_text + "\n\n---\n\n" + prompt
|
||||
|
||||
|
||||
return prompt
|
||||
|
||||
def _format_few_shot_examples(self, examples: List[Dict]) -> str:
|
||||
"""Format few-shot examples with answers for context."""
|
||||
formatted = []
|
||||
for example in examples:
|
||||
question = example.get('question', '')
|
||||
choices = example.get('choices', [])
|
||||
answer = example.get('answer', 0)
|
||||
|
||||
question = example.get("question", "")
|
||||
choices = example.get("choices", [])
|
||||
answer = example.get("answer", 0)
|
||||
|
||||
# Get the answer letter
|
||||
if isinstance(answer, int):
|
||||
answer_letter = ascii_uppercase[answer]
|
||||
else:
|
||||
answer_letter = answer.upper()
|
||||
|
||||
|
||||
formatted_choices = self._format_choices(choices)
|
||||
|
||||
example_text = f"Question: {question}\n{formatted_choices}\n\nAnswer: {answer_letter}"
|
||||
|
||||
example_text = (
|
||||
f"Question: {question}\n{formatted_choices}\n\nAnswer: {answer_letter}"
|
||||
)
|
||||
formatted.append(example_text)
|
||||
|
||||
|
||||
return "\n\n---\n\n".join(formatted)
|
||||
|
||||
def _validate_thinking_format(self, response: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Validate thinking format and extract content after </think> tags.
|
||||
|
||||
|
||||
In thinking mode, we require exactly one pair of <think></think> tags.
|
||||
Returns the content after </think> for answer extraction.
|
||||
|
||||
|
||||
Args:
|
||||
response: The model's full response
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, content_for_extraction)
|
||||
"""
|
||||
if not self.config.thinking_mode:
|
||||
return True, response
|
||||
|
||||
|
||||
# Check for exactly one pair of think tags
|
||||
think_open_count = len(self._think_pattern.findall(response))
|
||||
think_close_count = len(self._think_close_pattern.findall(response))
|
||||
|
||||
|
||||
if think_open_count != 1 or think_close_count != 1:
|
||||
return False, response
|
||||
|
||||
|
||||
# Extract content after </think> tags for answer extraction
|
||||
match = self._think_content_pattern.search(response)
|
||||
if match:
|
||||
|
|
@ -603,33 +639,30 @@ class MMLUEvalEnv(BaseEnv):
|
|||
return None
|
||||
|
||||
def _extract_answer(
|
||||
self,
|
||||
response: str,
|
||||
num_choices: int = 4,
|
||||
choices: Optional[List[str]] = None
|
||||
self, response: str, num_choices: int = 4, choices: Optional[List[str]] = None
|
||||
) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
Extract the answer letter from the model's response.
|
||||
|
||||
|
||||
Uses shared helpers from eval_helpers.py.
|
||||
|
||||
|
||||
Primary method: Look for <answer></answer> tags with exactly ONE valid letter,
|
||||
or match against the exact choice texts.
|
||||
Fallback: Use priority-ordered regex patterns.
|
||||
|
||||
|
||||
Args:
|
||||
response: The model's response string (content after </think> in thinking mode)
|
||||
num_choices: Number of valid choices (determines valid letters)
|
||||
choices: Optional list of choice texts for exact matching
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (extracted_letter or None, extraction_method used)
|
||||
"""
|
||||
if not response:
|
||||
return None, "empty_response"
|
||||
|
||||
|
||||
valid_letters = set(ascii_uppercase[:num_choices])
|
||||
|
||||
|
||||
# PRIMARY: Try <answer></answer> tags first
|
||||
# Also matches against choice texts if provided
|
||||
letter, method = extract_letter_from_answer_tag(
|
||||
|
|
@ -637,28 +670,42 @@ class MMLUEvalEnv(BaseEnv):
|
|||
)
|
||||
if letter:
|
||||
return letter, method
|
||||
|
||||
|
||||
# FALLBACK: Use regex patterns
|
||||
for priority, pattern, method_name in self._extraction_patterns:
|
||||
matches = pattern.findall(response)
|
||||
if matches:
|
||||
match = matches[-1] if method_name in ["final_answer_is", "the_answer_is", "answer_colon", "answer_space"] else matches[0]
|
||||
match = (
|
||||
matches[-1]
|
||||
if method_name
|
||||
in [
|
||||
"final_answer_is",
|
||||
"the_answer_is",
|
||||
"answer_colon",
|
||||
"answer_space",
|
||||
]
|
||||
else matches[0]
|
||||
)
|
||||
if isinstance(match, tuple):
|
||||
match = match[0]
|
||||
letter = match.strip("()").upper()
|
||||
|
||||
|
||||
if letter in valid_letters:
|
||||
if self.config.full_debug:
|
||||
print(f" Extracted '{letter}' using fallback method '{method_name}' (priority {priority})")
|
||||
print(
|
||||
f" Extracted '{letter}' using fallback method '{method_name}' (priority {priority})"
|
||||
)
|
||||
return letter, f"fallback_{method_name}"
|
||||
|
||||
|
||||
# Last resort: find any valid letter (take the last one)
|
||||
for letter in reversed(list(valid_letters)):
|
||||
if letter in response.upper():
|
||||
if self.config.full_debug:
|
||||
print(f" Extracted '{letter}' using fallback 'last_valid_letter'")
|
||||
print(
|
||||
f" Extracted '{letter}' using fallback 'last_valid_letter'"
|
||||
)
|
||||
return letter, "fallback_last_valid_letter"
|
||||
|
||||
|
||||
return None, "no_match"
|
||||
|
||||
async def get_next_item(self):
|
||||
|
|
@ -680,42 +727,42 @@ class MMLUEvalEnv(BaseEnv):
|
|||
async def rollout_and_score_eval(self, eval_item: Dict) -> Dict:
|
||||
"""
|
||||
Evaluate a single MMLU question using generative mode.
|
||||
|
||||
|
||||
The model generates a response with reasoning, then we extract
|
||||
the final answer from patterns like "Answer: A".
|
||||
|
||||
|
||||
In thinking mode, validates <think></think> tags and extracts
|
||||
the answer from content after the closing tag.
|
||||
|
||||
|
||||
Args:
|
||||
eval_item: Dictionary with question, choices, answer, and subject
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with is_correct, extracted_answer, and sample details
|
||||
"""
|
||||
try:
|
||||
subject = eval_item.get('subject', 'unknown')
|
||||
question = eval_item.get('question', '')
|
||||
choices = eval_item.get('choices', [])
|
||||
subject = eval_item.get("subject", "unknown")
|
||||
question = eval_item.get("question", "")
|
||||
choices = eval_item.get("choices", [])
|
||||
num_choices = len(choices)
|
||||
|
||||
|
||||
# Get the correct answer (handle both int index and string letter)
|
||||
gold_answer = eval_item.get('answer', 0)
|
||||
gold_answer = eval_item.get("answer", 0)
|
||||
if isinstance(gold_answer, int):
|
||||
gold_letter = ascii_uppercase[gold_answer]
|
||||
else:
|
||||
gold_letter = gold_answer.upper()
|
||||
|
||||
|
||||
if not question or num_choices < 2:
|
||||
return {"is_correct": None, "sample": None}
|
||||
|
||||
|
||||
# Get few-shot examples for this subject
|
||||
few_shot_examples = None
|
||||
if self.config.num_few_shot > 0 and subject in self.few_shot_data:
|
||||
available_examples = self.few_shot_data[subject]
|
||||
num_examples = min(self.config.num_few_shot, len(available_examples))
|
||||
few_shot_examples = available_examples[:num_examples]
|
||||
|
||||
|
||||
# Format the prompt (lighteval style - user message content)
|
||||
formatted_prompt = self._format_mmlu_prompt(
|
||||
question=question,
|
||||
|
|
@ -723,14 +770,14 @@ class MMLUEvalEnv(BaseEnv):
|
|||
subject=subject,
|
||||
few_shot_examples=few_shot_examples,
|
||||
)
|
||||
|
||||
|
||||
# Build messages
|
||||
messages = []
|
||||
system_content = self._create_system_content()
|
||||
if system_content:
|
||||
messages.append({"role": "system", "content": system_content})
|
||||
messages.append({"role": "user", "content": formatted_prompt})
|
||||
|
||||
|
||||
# Get model completion with retry logic
|
||||
model_response = None
|
||||
finish_reason = None
|
||||
|
|
@ -743,26 +790,37 @@ class MMLUEvalEnv(BaseEnv):
|
|||
max_tokens=self.config.eval_max_tokens,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
|
||||
if completion.choices and completion.choices[0].message.content:
|
||||
model_response = completion.choices[0].message.content
|
||||
finish_reason = getattr(completion.choices[0], 'finish_reason', None)
|
||||
|
||||
finish_reason = getattr(
|
||||
completion.choices[0], "finish_reason", None
|
||||
)
|
||||
|
||||
# Check minimum response length
|
||||
if len(model_response.strip()) >= self.config.min_response_length:
|
||||
if (
|
||||
len(model_response.strip())
|
||||
>= self.config.min_response_length
|
||||
):
|
||||
break
|
||||
elif attempt < self.config.max_retries - 1:
|
||||
if self.config.full_debug:
|
||||
print(f" Response too short ({len(model_response)} chars), retrying...")
|
||||
print(
|
||||
f" Response too short ({len(model_response)} chars), retrying..."
|
||||
)
|
||||
await asyncio.sleep(self.config.retry_delay)
|
||||
continue
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# Always log API errors to help diagnose issues
|
||||
print(f" API Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}")
|
||||
if hasattr(e, 'response'):
|
||||
print(
|
||||
f" API Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}"
|
||||
)
|
||||
if hasattr(e, "response"):
|
||||
try:
|
||||
print(f" Response: {e.response.text[:500] if hasattr(e.response, 'text') else e.response}")
|
||||
print(
|
||||
f" Response: {e.response.text[:500] if hasattr(e.response, 'text') else e.response}"
|
||||
)
|
||||
except:
|
||||
pass
|
||||
if attempt < self.config.max_retries - 1:
|
||||
|
|
@ -770,27 +828,29 @@ class MMLUEvalEnv(BaseEnv):
|
|||
else:
|
||||
print(f" Failed after {self.config.max_retries} attempts")
|
||||
return {"is_correct": None, "sample": None}
|
||||
|
||||
|
||||
if not model_response:
|
||||
return {"is_correct": None, "sample": None}
|
||||
|
||||
|
||||
# Validate thinking format if in thinking mode
|
||||
format_valid, content_for_extraction = self._validate_thinking_format(model_response)
|
||||
|
||||
format_valid, content_for_extraction = self._validate_thinking_format(
|
||||
model_response
|
||||
)
|
||||
|
||||
# Extract thinking content for logging (if in thinking mode)
|
||||
thinking_content = None
|
||||
if self.config.thinking_mode:
|
||||
thinking_content = self._extract_thinking_content(model_response)
|
||||
|
||||
|
||||
# Extract the answer from the response (or content after </think>)
|
||||
# Pass choices for exact text matching support
|
||||
extracted_answer, extraction_method = self._extract_answer(
|
||||
content_for_extraction, num_choices, choices=choices
|
||||
)
|
||||
|
||||
|
||||
# Check if correct
|
||||
is_correct = extracted_answer == gold_letter if extracted_answer else False
|
||||
|
||||
|
||||
# Build sample record for logging
|
||||
sample = {
|
||||
"subject": subject,
|
||||
|
|
@ -807,30 +867,35 @@ class MMLUEvalEnv(BaseEnv):
|
|||
"thinking_mode": self.config.thinking_mode,
|
||||
"format_valid": format_valid,
|
||||
}
|
||||
|
||||
|
||||
# Add thinking-specific info
|
||||
if self.config.thinking_mode:
|
||||
sample["thinking_content"] = thinking_content
|
||||
sample["response_after_think"] = content_for_extraction if format_valid else None
|
||||
|
||||
sample["response_after_think"] = (
|
||||
content_for_extraction if format_valid else None
|
||||
)
|
||||
|
||||
if self.config.full_debug:
|
||||
status = "✓" if is_correct else "✗"
|
||||
format_status = "✓" if format_valid else "✗"
|
||||
print(f" [{status}] {subject}: gold={gold_letter}, extracted={extracted_answer} ({extraction_method}), format={format_status}")
|
||||
|
||||
print(
|
||||
f" [{status}] {subject}: gold={gold_letter}, extracted={extracted_answer} ({extraction_method}), format={format_status}"
|
||||
)
|
||||
|
||||
return {"is_correct": is_correct, "sample": sample}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if self.config.full_debug:
|
||||
print(f"Error in rollout_and_score_eval: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return {"is_correct": None, "sample": None}
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
Run MMLU evaluation across all configured subjects.
|
||||
|
||||
|
||||
Calculates:
|
||||
- Overall accuracy
|
||||
- Per-subject accuracy
|
||||
|
|
@ -840,7 +905,7 @@ class MMLUEvalEnv(BaseEnv):
|
|||
- Thinking utilization metrics
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Starting MMLU Evaluation (Generative/Reasoning Mode)")
|
||||
print(f"{'='*60}")
|
||||
|
|
@ -850,40 +915,42 @@ class MMLUEvalEnv(BaseEnv):
|
|||
print(f" Max tokens (for reasoning): {self.config.eval_max_tokens}")
|
||||
print(f" Thinking mode: {self.config.thinking_mode}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
try:
|
||||
# Run evaluation for all items
|
||||
eval_tasks = [
|
||||
self.rollout_and_score_eval(item) for item in self.all_eval_items
|
||||
]
|
||||
results = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating MMLU")
|
||||
|
||||
|
||||
# Filter valid results
|
||||
valid_results = [
|
||||
r for r in results
|
||||
r
|
||||
for r in results
|
||||
if r and r.get("sample") is not None and r.get("is_correct") is not None
|
||||
]
|
||||
|
||||
|
||||
if not valid_results:
|
||||
print("Warning: No valid evaluation results obtained")
|
||||
return
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during evaluation: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
# Compute metrics
|
||||
samples = [r["sample"] for r in valid_results]
|
||||
|
||||
|
||||
# Overall accuracy
|
||||
total_correct = sum(1 for r in valid_results if r["is_correct"])
|
||||
total_count = len(valid_results)
|
||||
overall_accuracy = total_correct / total_count if total_count > 0 else 0.0
|
||||
|
||||
|
||||
# Per-subject accuracy
|
||||
subject_results = {}
|
||||
for sample in samples:
|
||||
|
|
@ -893,16 +960,18 @@ class MMLUEvalEnv(BaseEnv):
|
|||
subject_results[subject]["total"] += 1
|
||||
if sample["is_correct"]:
|
||||
subject_results[subject]["correct"] += 1
|
||||
|
||||
|
||||
# Per-category accuracy
|
||||
category_results = {cat: {"correct": 0, "total": 0} for cat in SUBJECT_CATEGORIES}
|
||||
category_results = {
|
||||
cat: {"correct": 0, "total": 0} for cat in SUBJECT_CATEGORIES
|
||||
}
|
||||
for subject, stats in subject_results.items():
|
||||
for category, subjects_in_cat in SUBJECT_CATEGORIES.items():
|
||||
if subject in subjects_in_cat:
|
||||
category_results[category]["correct"] += stats["correct"]
|
||||
category_results[category]["total"] += stats["total"]
|
||||
break
|
||||
|
||||
|
||||
# Extraction method statistics
|
||||
extraction_methods = {}
|
||||
for sample in samples:
|
||||
|
|
@ -912,20 +981,22 @@ class MMLUEvalEnv(BaseEnv):
|
|||
extraction_methods[method]["count"] += 1
|
||||
if sample["is_correct"]:
|
||||
extraction_methods[method]["correct"] += 1
|
||||
|
||||
|
||||
# Average response length
|
||||
response_lengths = [s.get("response_length", 0) for s in samples]
|
||||
avg_response_length = sum(response_lengths) / len(response_lengths) if response_lengths else 0
|
||||
|
||||
avg_response_length = (
|
||||
sum(response_lengths) / len(response_lengths) if response_lengths else 0
|
||||
)
|
||||
|
||||
# Format compliance (for thinking mode)
|
||||
format_compliant = sum(1 for s in samples if s.get("format_valid", True))
|
||||
format_compliance_rate = format_compliant / len(samples) if samples else 0.0
|
||||
|
||||
|
||||
# Thinking utilization (how many responses had thinking content)
|
||||
thinking_utilization = 0
|
||||
if self.config.thinking_mode:
|
||||
thinking_utilization = sum(1 for s in samples if s.get("thinking_content"))
|
||||
|
||||
|
||||
# Build metrics dictionary
|
||||
eval_metrics = {
|
||||
"eval/overall_accuracy": overall_accuracy,
|
||||
|
|
@ -938,26 +1009,30 @@ class MMLUEvalEnv(BaseEnv):
|
|||
"eval/format_compliance_rate": format_compliance_rate,
|
||||
"eval/thinking_mode_enabled": 1.0 if self.config.thinking_mode else 0.0,
|
||||
}
|
||||
|
||||
|
||||
# Add thinking utilization if in thinking mode
|
||||
if self.config.thinking_mode:
|
||||
thinking_utilization_rate = thinking_utilization / len(samples) if samples else 0.0
|
||||
thinking_utilization_rate = (
|
||||
thinking_utilization / len(samples) if samples else 0.0
|
||||
)
|
||||
eval_metrics["eval/thinking_utilization_rate"] = thinking_utilization_rate
|
||||
|
||||
|
||||
# Add category metrics
|
||||
for category, stats in category_results.items():
|
||||
if stats["total"] > 0:
|
||||
cat_accuracy = stats["correct"] / stats["total"]
|
||||
eval_metrics[f"eval/category_{category.lower()}_accuracy"] = cat_accuracy
|
||||
eval_metrics[f"eval/category_{category.lower()}_accuracy"] = (
|
||||
cat_accuracy
|
||||
)
|
||||
eval_metrics[f"eval/category_{category.lower()}_total"] = stats["total"]
|
||||
|
||||
|
||||
# Add extraction method metrics
|
||||
for method, stats in extraction_methods.items():
|
||||
if stats["count"] > 0:
|
||||
method_accuracy = stats["correct"] / stats["count"]
|
||||
eval_metrics[f"eval/extraction_{method}_count"] = stats["count"]
|
||||
eval_metrics[f"eval/extraction_{method}_accuracy"] = method_accuracy
|
||||
|
||||
|
||||
# Add per-subject metrics
|
||||
for subject, stats in sorted(subject_results.items()):
|
||||
if stats["total"] > 0:
|
||||
|
|
@ -965,35 +1040,41 @@ class MMLUEvalEnv(BaseEnv):
|
|||
# Sanitize subject name for metric key
|
||||
subj_key = subject.replace(" ", "_").replace("-", "_")
|
||||
eval_metrics[f"eval/subject_{subj_key}_accuracy"] = subj_accuracy
|
||||
|
||||
|
||||
# Store metrics for wandb logging
|
||||
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
|
||||
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'='*60}")
|
||||
print(f"MMLU Evaluation Results")
|
||||
print(f"{'='*60}")
|
||||
print(f"Overall Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_count})")
|
||||
print(
|
||||
f"Overall Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_count})"
|
||||
)
|
||||
print(f"Evaluation Time: {end_time - start_time:.1f} seconds")
|
||||
print(f"Avg Response Length: {avg_response_length:.0f} chars")
|
||||
if self.config.thinking_mode:
|
||||
print(f"Format Compliance: {format_compliance_rate:.4f}")
|
||||
print(f"Thinking Utilization: {thinking_utilization}/{total_count}")
|
||||
|
||||
|
||||
print(f"\nCategory Breakdown:")
|
||||
for category, stats in category_results.items():
|
||||
if stats["total"] > 0:
|
||||
cat_acc = stats["correct"] / stats["total"]
|
||||
print(f" {category}: {cat_acc:.4f} ({stats['correct']}/{stats['total']})")
|
||||
|
||||
print(
|
||||
f" {category}: {cat_acc:.4f} ({stats['correct']}/{stats['total']})"
|
||||
)
|
||||
|
||||
print(f"\nExtraction Method Statistics:")
|
||||
for method, stats in sorted(extraction_methods.items(), key=lambda x: -x[1]["count"]):
|
||||
for method, stats in sorted(
|
||||
extraction_methods.items(), key=lambda x: -x[1]["count"]
|
||||
):
|
||||
if stats["count"] > 0:
|
||||
method_acc = stats["correct"] / stats["count"]
|
||||
print(f" {method}: {stats['count']} uses, {method_acc:.4f} accuracy")
|
||||
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
# Log evaluation results
|
||||
try:
|
||||
await self.evaluate_log(
|
||||
|
|
@ -1016,17 +1097,19 @@ class MMLUEvalEnv(BaseEnv):
|
|||
"""Log metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
|
||||
# Add evaluation metrics
|
||||
for metric_name, metric_value in self.eval_metrics:
|
||||
wandb_metrics[metric_name] = metric_value
|
||||
self.eval_metrics = []
|
||||
|
||||
|
||||
# Add config metrics
|
||||
wandb_metrics["config/thinking_mode"] = 1.0 if self.config.thinking_mode else 0.0
|
||||
wandb_metrics["config/thinking_mode"] = (
|
||||
1.0 if self.config.thinking_mode else 0.0
|
||||
)
|
||||
wandb_metrics["config/num_few_shot"] = self.config.num_few_shot
|
||||
wandb_metrics["config/eval_max_tokens"] = self.config.eval_max_tokens
|
||||
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue