""" MMLU Evaluation Environment for Atropos (Generative/Reasoning Mode) This environment evaluates models on the Massive Multitask Language Understanding (MMLU) benchmark using a generative approach where models can reason before answering. Dataset: lighteval/mmlu (or configurable) Paper: https://arxiv.org/abs/2009.03300 The evaluation follows the lighteval generative approach (like GPQA/MMLU-Pro): - Models are prompted to "think step by step before answering" - Models output their reasoning followed by "Answer: X" - Answer is extracted using regex patterns from the response - Simple string matching validates the extracted answer Supports optional thinking mode with tags for extended reasoning. """ import asyncio import os import random import re import time from string import ascii_uppercase from typing import Dict, List, Optional, Tuple import wandb from datasets import load_dataset from eval_helpers import ( build_mcqa_fallback_patterns, create_system_content, extract_letter_from_answer_tag, extract_thinking_content, get_default_thinking_prompt, save_eval_results, validate_thinking_format, ) from pydantic import Field from tqdm.asyncio import tqdm_asyncio from atroposlib.envs.base import ( APIServerConfig, BaseEnv, BaseEnvConfig, EvalHandlingEnum, ) # All 57 MMLU subjects - used for dataset loading and category tracking MMLU_SUBJECTS = [ "abstract_algebra", "anatomy", "astronomy", "business_ethics", "clinical_knowledge", "college_biology", "college_chemistry", "college_computer_science", "college_mathematics", "college_medicine", "college_physics", "computer_security", "conceptual_physics", "econometrics", "electrical_engineering", "elementary_mathematics", "formal_logic", "global_facts", "high_school_biology", "high_school_chemistry", "high_school_computer_science", "high_school_european_history", "high_school_geography", "high_school_government_and_politics", "high_school_macroeconomics", "high_school_mathematics", "high_school_microeconomics", "high_school_physics", "high_school_psychology", "high_school_statistics", "high_school_us_history", "high_school_world_history", "human_aging", "human_sexuality", "international_law", "jurisprudence", "logical_fallacies", "machine_learning", "management", "marketing", "medical_genetics", "miscellaneous", "moral_disputes", "moral_scenarios", "nutrition", "philosophy", "prehistory", "professional_accounting", "professional_law", "professional_medicine", "professional_psychology", "public_relations", "security_studies", "sociology", "us_foreign_policy", "virology", "world_religions", ] # High-level category groupings for aggregate metrics SUBJECT_CATEGORIES = { "STEM": [ "abstract_algebra", "astronomy", "college_biology", "college_chemistry", "college_computer_science", "college_mathematics", "college_physics", "computer_security", "conceptual_physics", "electrical_engineering", "elementary_mathematics", "high_school_biology", "high_school_chemistry", "high_school_computer_science", "high_school_mathematics", "high_school_physics", "high_school_statistics", "machine_learning", "college_medicine", "clinical_knowledge", "medical_genetics", "professional_medicine", "anatomy", "nutrition", "virology", "human_aging", ], "Humanities": [ "formal_logic", "high_school_european_history", "high_school_us_history", "high_school_world_history", "international_law", "jurisprudence", "logical_fallacies", "moral_disputes", "moral_scenarios", "philosophy", "prehistory", "professional_law", "world_religions", ], "Social_Sciences": [ "econometrics", 
"high_school_geography", "high_school_government_and_politics", "high_school_macroeconomics", "high_school_microeconomics", "high_school_psychology", "human_sexuality", "professional_psychology", "public_relations", "security_studies", "sociology", "us_foreign_policy", ], "Other": [ "business_ethics", "global_facts", "management", "marketing", "miscellaneous", "professional_accounting", ], } # Generative prompt template with tag instruction # This is the USER message content - system prompt is handled separately LIGHTEVAL_PROMPT_TEMPLATE = """Answer the following multiple choice question. Think step by step before answering. Provide your final answer within tags, containing only the letter ({valid_letters}). Example format: A {question} {choices}""" class MMLUEvalConfig(BaseEnvConfig): """Configuration for MMLU evaluation environment (generative mode).""" # Thinking mode configuration (like pairwise_judgement_environment) thinking_mode: bool = Field( default=True, description="Whether to enable thinking mode with tags.", ) custom_thinking_prompt: Optional[str] = Field( default=None, description="Custom thinking prompt. If None, uses the default thinking prompt.", ) # Dataset configuration dataset_name: str = Field( default="lighteval/mmlu", description="HuggingFace dataset name for MMLU.", ) subjects: Optional[List[str]] = Field( default=None, description="List of MMLU subjects to evaluate. If None, evaluates all 57 subjects.", ) eval_split: str = Field( default="test", description="Dataset split to use for evaluation.", ) few_shot_split: str = Field( default="dev", description="Dataset split to use for few-shot examples.", ) # Few-shot configuration num_few_shot: int = Field( default=0, ge=0, le=5, description="Number of few-shot examples to include (0-5 recommended).", ) # Model generation configuration eval_temperature: float = Field( default=0.0, description="Temperature for evaluation (0.0 for deterministic).", ) eval_max_tokens: int = Field( default=0, description="Maximum tokens for evaluation responses. Set high to allow reasoning.", ) # Prompt configuration custom_system_prompt: Optional[str] = Field( default=None, description="Custom system prompt to append after thinking prompt (if thinking_mode) or use directly.", ) include_subject_in_prompt: bool = Field( default=False, description="Whether to include the subject name in the prompt for context.", ) # Retry configuration max_retries: int = Field( default=3, ge=1, description="Maximum retries for failed API calls.", ) retry_delay: float = Field( default=1.0, ge=0.0, description="Delay between retry attempts in seconds.", ) min_response_length: int = Field( default=1, ge=1, description="Minimum response length to consider valid.", ) # Debug configuration full_debug: bool = Field( default=False, description="Enable verbose debug logging.", ) class MMLUEvalEnv(BaseEnv): """ MMLU Evaluation Environment for Atropos (Generative/Reasoning Mode). Evaluates models on the Massive Multitask Language Understanding benchmark using a generative approach where models reason before answering. Key features: - Loads MMLU dataset from HuggingFace (lighteval/mmlu format) - Uses lighteval's exact prompt format for GPQA/MMLU-Pro style evaluation - Optional thinking mode with tags for extended reasoning - Extracts answer letters from patterns like "Answer: A", "The final answer is B", etc. - Tracks per-subject and per-category accuracy - Supports few-shot examples Answer extraction follows lighteval's approach with priority-ordered patterns: 1. 
"final answer is: X" (highest priority) 2. "answer: X" or "answer X" 3. Response starts with letter 4. Letter at start of any line 5. Any letter A/B/C/D in response (lowest priority, fallback) """ name = "mmlu_eval" env_config_cls = MMLUEvalConfig def __init__( self, config: MMLUEvalConfig, server_configs: List[APIServerConfig], slurm=True, testing=False, ): super().__init__(config, server_configs, slurm, testing) self.config: MMLUEvalConfig = config # Initialize metrics tracking self.eval_metrics = [] # Pre-compile regex patterns for thinking mode (like pairwise_judgement_environment) self._think_pattern = re.compile(r"") self._think_close_pattern = re.compile(r"") self._think_content_pattern = re.compile(r"\s*(.*)", re.DOTALL) self._thinking_extract_pattern = re.compile(r"(.*?)", re.DOTALL) # Pre-compile regex for tag extraction (primary method) self._answer_tag_pattern = re.compile( r"(.*?)", re.DOTALL | re.IGNORECASE ) # Build fallback answer extraction patterns self._build_extraction_patterns() def _get_thinking_prompt(self) -> str: """Get thinking system prompt.""" return get_default_thinking_prompt(self.config.custom_thinking_prompt) def _create_system_content(self) -> Optional[str]: """Create system message content based on thinking mode.""" return create_system_content( self.config.thinking_mode, self.config.custom_thinking_prompt, self.config.custom_system_prompt, ) def _build_extraction_patterns(self): """ Build regex patterns for extracting answer letters from model responses. Following lighteval's IndicesExtractionConfig approach, patterns are ordered by priority (lower number = higher priority). """ # Valid answer letters (default to A-D for standard MMLU) letters = "ABCD" # Build the letter matching pattern - matches A, B, C, D or (A), (B), etc. letter_pattern = rf"([{letters}]|\([{letters}]\))" # Patterns ordered by priority (most specific first) # Priority 0: "final answer is: X" with "I hope" (very specific, highest confidence) self._pattern_final_answer_hope = re.compile( rf"(?i:final\s+answer\s+is)\s*:?\s*{letter_pattern}\.?\s*I\s*hope", re.IGNORECASE, ) # Priority 50: "final answer ... 
is X" (allows text between) self._pattern_final_answer_is = re.compile( rf"(?i:final\s+answer).{{0,100}}?\s+is\s*:?\s*{letter_pattern}", re.IGNORECASE | re.DOTALL, ) # Priority 75: "the answer is X" self._pattern_the_answer_is = re.compile( rf"(?i:the\s+answer\s+is)\s*:?\s*{letter_pattern}", re.IGNORECASE ) # Priority 100: "answer: X" or "Answer: X" (with colon) self._pattern_answer_colon = re.compile( rf"(?i:answer)\s*:\s*.{{0,50}}?{letter_pattern}", re.IGNORECASE | re.DOTALL ) # Priority 150: "answer X" or "Answer X" (without colon) self._pattern_answer_space = re.compile( rf"(?i:answer)\s+{letter_pattern}", re.IGNORECASE ) # Priority 200: Response starts with answer letter (with optional punctuation) self._pattern_start = re.compile( rf"^\s*\**{letter_pattern}\**[\s\.\)\:]", re.IGNORECASE ) # Priority 210: Letter at start of any line (for multi-line responses) self._pattern_line_start = re.compile( rf"\n\s*\**{letter_pattern}\**[\s\.\)\:]", re.IGNORECASE ) # Priority 250: Standalone letter with word boundaries self._pattern_standalone = re.compile(rf"\b{letter_pattern}\b", re.IGNORECASE) # Store patterns in priority order self._extraction_patterns = [ (0, self._pattern_final_answer_hope, "final_answer_hope"), (50, self._pattern_final_answer_is, "final_answer_is"), (75, self._pattern_the_answer_is, "the_answer_is"), (100, self._pattern_answer_colon, "answer_colon"), (150, self._pattern_answer_space, "answer_space"), (200, self._pattern_start, "start"), (210, self._pattern_line_start, "line_start"), (250, self._pattern_standalone, "standalone"), ] @classmethod def config_init(cls) -> Tuple[MMLUEvalConfig, List[APIServerConfig]]: """Initialize default configuration for the environment.""" env_config = MMLUEvalConfig( tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B", group_size=1, # Eval only, no training groups needed use_wandb=True, max_num_workers_per_node=8, rollout_server_url="http://localhost:8000", total_steps=1, # Eval-only environment batch_size=1, steps_per_eval=1, inference_weight=1.0, wandb_name="mmlu_eval", eval_handling=EvalHandlingEnum.STOP_TRAIN, # MMLU-specific defaults dataset_name="lighteval/mmlu", subjects=None, # All subjects by default num_few_shot=0, # 0-shot by default for generative eval_temperature=0.6, eval_max_tokens=0, # Use model default include_subject_in_prompt=False, # Match lighteval default # Thinking mode defaults thinking_mode=True, ) server_configs = [ APIServerConfig( model_name="Hermes-3-Llama-3.1-8B", base_url="http://localhost:9000/v1", api_key=os.getenv("OPENAI_API_KEY", "none"), num_max_requests_at_once=32, num_requests_for_eval=256, ), ] return env_config, server_configs async def setup(self) -> None: """Load the MMLU dataset and prepare for evaluation.""" # Determine which subjects to evaluate self.subjects = self.config.subjects or MMLU_SUBJECTS # Validate subjects invalid_subjects = [s for s in self.subjects if s not in MMLU_SUBJECTS] if invalid_subjects: print(f"Warning: Invalid subjects will be skipped: {invalid_subjects}") self.subjects = [s for s in self.subjects if s in MMLU_SUBJECTS] if not self.subjects: raise ValueError("No valid MMLU subjects specified for evaluation.") print(f"\nMMLU Evaluation Setup (Generative Mode):") print(f" Dataset: {self.config.dataset_name}") print(f" Subjects: {len(self.subjects)} subjects") print(f" Few-shot examples: {self.config.num_few_shot}") print(f" Max tokens for reasoning: {self.config.eval_max_tokens}") print(f" Evaluation split: {self.config.eval_split}") print(f" Thinking mode: 
{self.config.thinking_mode}") if self.config.thinking_mode: print(f" Thinking prompt: {self._get_thinking_prompt()[:100]}...") # Load datasets for each subject self.eval_data = {} # subject -> list of eval items self.few_shot_data = {} # subject -> list of few-shot items total_eval_items = 0 for subject in self.subjects: try: # Load evaluation data dataset = load_dataset( self.config.dataset_name, subject, split=self.config.eval_split, trust_remote_code=True, ) self.eval_data[subject] = list(dataset) total_eval_items += len(self.eval_data[subject]) # Load few-shot data if needed if self.config.num_few_shot > 0: few_shot_dataset = load_dataset( self.config.dataset_name, subject, split=self.config.few_shot_split, trust_remote_code=True, ) self.few_shot_data[subject] = list(few_shot_dataset) if self.config.full_debug: print( f" Loaded {subject}: {len(self.eval_data[subject])} eval items" ) except Exception as e: print(f" Warning: Failed to load subject '{subject}': {e}") continue print(f" Total evaluation items: {total_eval_items}") # Flatten all eval items with subject metadata for iteration self.all_eval_items = [] for subject, items in self.eval_data.items(): for item in items: item["subject"] = subject # Ensure subject is in each item self.all_eval_items.append(item) self.iter = 0 def _format_choices(self, choices: List[str]) -> str: """Format choices as A) choice1, B) choice2, etc.""" lines = [] for idx, choice in enumerate(choices): letter = ascii_uppercase[idx] lines.append(f"{letter}) {choice}") return "\n".join(lines) def _format_mmlu_prompt( self, question: str, choices: List[str], subject: str, few_shot_examples: Optional[List[Dict]] = None, ) -> str: """ Format a question using the lighteval MMLU template. Uses the exact GPQA/MMLU-Pro style prompt from lighteval that instructs the model to think step by step and provide the answer in a specific format. 
        Args:
            question: The question text
            choices: List of answer choices
            subject: The subject name (for context in prompt)
            few_shot_examples: Optional list of few-shot example dicts

        Returns:
            Formatted prompt string (user message content)
        """
        num_choices = len(choices)
        valid_letters = "".join(ascii_uppercase[:num_choices])

        # Format choices
        formatted_choices = self._format_choices(choices)

        # Build the question - optionally include subject
        if self.config.include_subject_in_prompt:
            subject_display = subject.replace("_", " ")
            question_with_context = f"[{subject_display}]\n\n{question}"
        else:
            question_with_context = question

        # Use lighteval's exact prompt template
        prompt = LIGHTEVAL_PROMPT_TEMPLATE.format(
            question=question_with_context,
            choices=formatted_choices,
            valid_letters=valid_letters,
        )

        # Add few-shot examples if provided (prepended)
        if few_shot_examples:
            few_shot_text = self._format_few_shot_examples(few_shot_examples)
            prompt = few_shot_text + "\n\n---\n\n" + prompt

        return prompt

    def _format_few_shot_examples(self, examples: List[Dict]) -> str:
        """Format few-shot examples with answers for context."""
        formatted = []
        for example in examples:
            question = example.get("question", "")
            choices = example.get("choices", [])
            answer = example.get("answer", 0)

            # Get the answer letter
            if isinstance(answer, int):
                answer_letter = ascii_uppercase[answer]
            else:
                answer_letter = answer.upper()

            formatted_choices = self._format_choices(choices)
            example_text = (
                f"Question: {question}\n{formatted_choices}\n\nAnswer: {answer_letter}"
            )
            formatted.append(example_text)

        return "\n\n---\n\n".join(formatted)

    def _validate_thinking_format(self, response: str) -> Tuple[bool, str]:
        """
        Validate thinking format and extract content after </think> tags.

        In thinking mode, we require exactly one pair of <think></think> tags.
        Returns the content after </think> for answer extraction.

        Args:
            response: The model's full response

        Returns:
            Tuple of (is_valid, content_for_extraction)
        """
        if not self.config.thinking_mode:
            return True, response

        # Check for exactly one pair of think tags
        think_open_count = len(self._think_pattern.findall(response))
        think_close_count = len(self._think_close_pattern.findall(response))

        if think_open_count != 1 or think_close_count != 1:
            return False, response

        # Extract content after </think> tags for answer extraction
        match = self._think_content_pattern.search(response)
        if match:
            return True, match.group(1).strip()
        else:
            return False, response

    def _extract_thinking_content(self, response: str) -> Optional[str]:
        """Extract the content inside <think></think> tags."""
        match = self._thinking_extract_pattern.search(response)
        if match:
            return match.group(1).strip()
        return None
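    # Illustrative (added for documentation): in thinking mode a well-formed response
    # looks like the sketch below. _validate_thinking_format() accepts it because there
    # is exactly one <think>...</think> pair, and answer extraction then runs only on
    # the text that follows </think>.
    #
    #   <think>
    #   Ohm's law gives V = I * R, so the current doubles ... which matches option C.
    #   </think>
    #   <answer>C</answer>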
    def _extract_answer(
        self, response: str, num_choices: int = 4, choices: Optional[List[str]] = None
    ) -> Tuple[Optional[str], str]:
        """
        Extract the answer letter from the model's response.

        Uses shared helpers from eval_helpers.py.

        Primary method: Look for <answer></answer> tags with exactly ONE valid letter,
        or match against the exact choice texts.
        Fallback: Use priority-ordered regex patterns.

        Args:
            response: The model's response string (content after </think> in thinking mode)
            num_choices: Number of valid choices (determines valid letters)
            choices: Optional list of choice texts for exact matching

        Returns:
            Tuple of (extracted_letter or None, extraction_method used)
        """
        if not response:
            return None, "empty_response"

        valid_letters = set(ascii_uppercase[:num_choices])

        # PRIMARY: Try <answer></answer> tags first
        # Also matches against choice texts if provided
        letter, method = extract_letter_from_answer_tag(
            response, valid_letters, debug=self.config.full_debug, choices=choices
        )
        if letter:
            return letter, method

        # FALLBACK: Use regex patterns
        for priority, pattern, method_name in self._extraction_patterns:
            matches = pattern.findall(response)
            if matches:
                match = (
                    matches[-1]
                    if method_name
                    in [
                        "final_answer_is",
                        "the_answer_is",
                        "answer_colon",
                        "answer_space",
                    ]
                    else matches[0]
                )
                if isinstance(match, tuple):
                    match = match[0]
                letter = match.strip("()").upper()
                if letter in valid_letters:
                    if self.config.full_debug:
                        print(
                            f" Extracted '{letter}' using fallback method '{method_name}' (priority {priority})"
                        )
                    return letter, f"fallback_{method_name}"

        # Last resort: find any valid letter (take the last one)
        for letter in reversed(list(valid_letters)):
            if letter in response.upper():
                if self.config.full_debug:
                    print(
                        f" Extracted '{letter}' using fallback 'last_valid_letter'"
                    )
                return letter, "fallback_last_valid_letter"

        return None, "no_match"

    async def get_next_item(self):
        """Get next item for training (not used in eval-only environment)."""
        self.iter += 1
        if self.all_eval_items:
            item = self.all_eval_items[self.iter % len(self.all_eval_items)]
            return item
        return None

    async def collect_trajectories(self, item):
        """Collect trajectories (not used in eval-only environment)."""
        return None, []

    async def score(self, rollout_group_data):
        """Score rollouts (not used in eval-only environment)."""
        return None
    async def rollout_and_score_eval(self, eval_item: Dict) -> Dict:
        """
        Evaluate a single MMLU question using generative mode.

        The model generates a response with reasoning, then we extract
        the final answer from patterns like "Answer: A".

        In thinking mode, validates <think> tags and extracts the answer from
        content after the closing </think> tag.

        Args:
            eval_item: Dictionary with question, choices, answer, and subject

        Returns:
            Dictionary with is_correct, extracted_answer, and sample details
        """
        try:
            subject = eval_item.get("subject", "unknown")
            question = eval_item.get("question", "")
            choices = eval_item.get("choices", [])
            num_choices = len(choices)

            # Get the correct answer (handle both int index and string letter)
            gold_answer = eval_item.get("answer", 0)
            if isinstance(gold_answer, int):
                gold_letter = ascii_uppercase[gold_answer]
            else:
                gold_letter = gold_answer.upper()

            if not question or num_choices < 2:
                return {"is_correct": None, "sample": None}

            # Get few-shot examples for this subject
            few_shot_examples = None
            if self.config.num_few_shot > 0 and subject in self.few_shot_data:
                available_examples = self.few_shot_data[subject]
                num_examples = min(self.config.num_few_shot, len(available_examples))
                few_shot_examples = available_examples[:num_examples]

            # Format the prompt (lighteval style - user message content)
            formatted_prompt = self._format_mmlu_prompt(
                question=question,
                choices=choices,
                subject=subject,
                few_shot_examples=few_shot_examples,
            )

            # Build messages
            messages = []
            system_content = self._create_system_content()
            if system_content:
                messages.append({"role": "system", "content": system_content})
            messages.append({"role": "user", "content": formatted_prompt})

            # Get model completion with retry logic
            model_response = None
            finish_reason = None

            for attempt in range(self.config.max_retries):
                try:
                    completion = await self.server.chat_completion(
                        messages=messages,
                        n=1,
                        temperature=self.config.eval_temperature,
                        max_tokens=self.config.eval_max_tokens,
                        split="eval",
                    )

                    if completion.choices and completion.choices[0].message.content:
                        model_response = completion.choices[0].message.content
                        finish_reason = getattr(
                            completion.choices[0], "finish_reason", None
                        )

                        # Check minimum response length
                        if (
                            len(model_response.strip())
                            >= self.config.min_response_length
                        ):
                            break
                        elif attempt < self.config.max_retries - 1:
                            if self.config.full_debug:
                                print(
                                    f" Response too short ({len(model_response)} chars), retrying..."
                                )
                            await asyncio.sleep(self.config.retry_delay)
                            continue
                except Exception as e:
                    # Always log API errors to help diagnose issues
                    print(
                        f" API Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}"
                    )
                    if hasattr(e, "response"):
                        try:
                            print(
                                f" Response: {e.response.text[:500] if hasattr(e.response, 'text') else e.response}"
                            )
                        except Exception:
                            pass
                    if attempt < self.config.max_retries - 1:
                        await asyncio.sleep(self.config.retry_delay)
                    else:
                        print(f" Failed after {self.config.max_retries} attempts")
                        return {"is_correct": None, "sample": None}

            if not model_response:
                return {"is_correct": None, "sample": None}

            # Validate thinking format if in thinking mode
            format_valid, content_for_extraction = self._validate_thinking_format(
                model_response
            )

            # Extract thinking content for logging (if in thinking mode)
            thinking_content = None
            if self.config.thinking_mode:
                thinking_content = self._extract_thinking_content(model_response)

            # Extract the answer from the response (or content after </think>)
            # Pass choices for exact text matching support
            extracted_answer, extraction_method = self._extract_answer(
                content_for_extraction, num_choices, choices=choices
            )

            # Check if correct
            is_correct = extracted_answer == gold_letter if extracted_answer else False

            # Build sample record for logging
            sample = {
                "subject": subject,
                "question": question,
                "choices": choices,
                "gold_answer": gold_letter,
                "model_response": model_response,
                "extracted_answer": extracted_answer,
                "extraction_method": extraction_method,
                "is_correct": is_correct,
                "num_few_shot": self.config.num_few_shot,
                "finish_reason": finish_reason,
                "response_length": len(model_response),
                "thinking_mode": self.config.thinking_mode,
                "format_valid": format_valid,
            }

            # Add thinking-specific info
            if self.config.thinking_mode:
                sample["thinking_content"] = thinking_content
                sample["response_after_think"] = (
                    content_for_extraction if format_valid else None
                )

            if self.config.full_debug:
                status = "✓" if is_correct else "✗"
                format_status = "✓" if format_valid else "✗"
                print(
                    f" [{status}] {subject}: gold={gold_letter}, extracted={extracted_answer} ({extraction_method}), format={format_status}"
                )

            return {"is_correct": is_correct, "sample": sample}

        except Exception as e:
            if self.config.full_debug:
                print(f"Error in rollout_and_score_eval: {e}")
                import traceback

                traceback.print_exc()
            return {"is_correct": None, "sample": None}
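    # Illustrative (added for documentation): evaluate() below emits wandb-style metric
    # keys of roughly this shape; the exact set depends on the configured subjects and
    # on which extraction methods actually fire during a run:
    #
    #   eval/overall_accuracy
    #   eval/category_stem_accuracy
    #   eval/subject_high_school_biology_accuracy
    #   eval/extraction_fallback_answer_colon_count
    #   eval/format_compliance_rate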
    async def evaluate(self, *args, **kwargs) -> None:
        """
        Run MMLU evaluation across all configured subjects.

        Calculates:
        - Overall accuracy
        - Per-subject accuracy
        - Per-category accuracy (STEM, Humanities, Social Sciences, Other)
        - Extraction method statistics
        - Format compliance (for thinking mode)
        - Thinking utilization metrics
        """
        start_time = time.time()

        print(f"\n{'='*60}")
        print(f"Starting MMLU Evaluation (Generative/Reasoning Mode)")
        print(f"{'='*60}")
        print(f" Subjects: {len(self.subjects)}")
        print(f" Total questions: {len(self.all_eval_items)}")
        print(f" Few-shot examples: {self.config.num_few_shot}")
        print(f" Max tokens (for reasoning): {self.config.eval_max_tokens}")
        print(f" Thinking mode: {self.config.thinking_mode}")
        print(f"{'='*60}\n")

        try:
            # Run evaluation for all items
            eval_tasks = [
                self.rollout_and_score_eval(item) for item in self.all_eval_items
            ]
            results = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating MMLU")

            # Filter valid results
            valid_results = [
                r
                for r in results
                if r
                and r.get("sample") is not None
                and r.get("is_correct") is not None
            ]

            if not valid_results:
                print("Warning: No valid evaluation results obtained")
                return
        except Exception as e:
            print(f"Error during evaluation: {e}")
            import traceback

            traceback.print_exc()
            return

        end_time = time.time()

        # Compute metrics
        samples = [r["sample"] for r in valid_results]

        # Overall accuracy
        total_correct = sum(1 for r in valid_results if r["is_correct"])
        total_count = len(valid_results)
        overall_accuracy = total_correct / total_count if total_count > 0 else 0.0

        # Per-subject accuracy
        subject_results = {}
        for sample in samples:
            subject = sample["subject"]
            if subject not in subject_results:
                subject_results[subject] = {"correct": 0, "total": 0}
            subject_results[subject]["total"] += 1
            if sample["is_correct"]:
                subject_results[subject]["correct"] += 1

        # Per-category accuracy
        category_results = {
            cat: {"correct": 0, "total": 0} for cat in SUBJECT_CATEGORIES
        }
        for subject, stats in subject_results.items():
            for category, subjects_in_cat in SUBJECT_CATEGORIES.items():
                if subject in subjects_in_cat:
                    category_results[category]["correct"] += stats["correct"]
                    category_results[category]["total"] += stats["total"]
                    break

        # Extraction method statistics
        extraction_methods = {}
        for sample in samples:
            method = sample.get("extraction_method", "unknown")
            if method not in extraction_methods:
                extraction_methods[method] = {"count": 0, "correct": 0}
            extraction_methods[method]["count"] += 1
            if sample["is_correct"]:
                extraction_methods[method]["correct"] += 1

        # Average response length
        response_lengths = [s.get("response_length", 0) for s in samples]
        avg_response_length = (
            sum(response_lengths) / len(response_lengths) if response_lengths else 0
        )

        # Format compliance (for thinking mode)
        format_compliant = sum(1 for s in samples if s.get("format_valid", True))
        format_compliance_rate = format_compliant / len(samples) if samples else 0.0

        # Thinking utilization (how many responses had thinking content)
        thinking_utilization = 0
        if self.config.thinking_mode:
            thinking_utilization = sum(1 for s in samples if s.get("thinking_content"))

        # Build metrics dictionary
        eval_metrics = {
            "eval/overall_accuracy": overall_accuracy,
            "eval/total_questions": total_count,
            "eval/total_correct": total_correct,
            "eval/num_subjects": len(subject_results),
            "eval/num_few_shot": self.config.num_few_shot,
            "eval/evaluation_time_seconds": end_time - start_time,
            "eval/avg_response_length": avg_response_length,
            "eval/format_compliance_rate": format_compliance_rate,
            "eval/thinking_mode_enabled": 1.0 if self.config.thinking_mode else 0.0,
        }

        # Add thinking utilization if in thinking mode
        if self.config.thinking_mode:
            thinking_utilization_rate = (
                thinking_utilization / len(samples) if samples else 0.0
            )
            eval_metrics["eval/thinking_utilization_rate"] = thinking_utilization_rate

        # Add category metrics
        for category, stats in category_results.items():
            if stats["total"] > 0:
                cat_accuracy = stats["correct"] / stats["total"]
                eval_metrics[f"eval/category_{category.lower()}_accuracy"] = (
                    cat_accuracy
                )
                eval_metrics[f"eval/category_{category.lower()}_total"] = stats["total"]

        # Add extraction method metrics
        for method, stats in extraction_methods.items():
            if stats["count"] > 0:
                method_accuracy = stats["correct"] / stats["count"]
                eval_metrics[f"eval/extraction_{method}_count"] = stats["count"]
                eval_metrics[f"eval/extraction_{method}_accuracy"] = method_accuracy

        # Add per-subject metrics
        for subject, stats in sorted(subject_results.items()):
            if stats["total"] > 0:
                subj_accuracy = stats["correct"] / stats["total"]
                # Sanitize subject name for metric key
                subj_key = subject.replace(" ", "_").replace("-", "_")
                eval_metrics[f"eval/subject_{subj_key}_accuracy"] = subj_accuracy

        # Store metrics for wandb logging
        self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]

        # Print summary
        print(f"\n{'='*60}")
        print(f"MMLU Evaluation Results")
        print(f"{'='*60}")
        print(
            f"Overall Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_count})"
        )
        print(f"Evaluation Time: {end_time - start_time:.1f} seconds")
        print(f"Avg Response Length: {avg_response_length:.0f} chars")
        if self.config.thinking_mode:
            print(f"Format Compliance: {format_compliance_rate:.4f}")
            print(f"Thinking Utilization: {thinking_utilization}/{total_count}")

        print(f"\nCategory Breakdown:")
        for category, stats in category_results.items():
            if stats["total"] > 0:
                cat_acc = stats["correct"] / stats["total"]
                print(
                    f" {category}: {cat_acc:.4f} ({stats['correct']}/{stats['total']})"
                )

        print(f"\nExtraction Method Statistics:")
        for method, stats in sorted(
            extraction_methods.items(), key=lambda x: -x[1]["count"]
        ):
            if stats["count"] > 0:
                method_acc = stats["correct"] / stats["count"]
                print(f" {method}: {stats['count']} uses, {method_acc:.4f} accuracy")

        print(f"{'='*60}\n")

        # Log evaluation results
        try:
            await self.evaluate_log(
                metrics=eval_metrics,
                samples=samples,
                start_time=start_time,
                end_time=end_time,
                generation_parameters={
                    "temperature": self.config.eval_temperature,
                    "max_tokens": self.config.eval_max_tokens,
                    "num_few_shot": self.config.num_few_shot,
                    "thinking_mode": self.config.thinking_mode,
                    "mode": "generative",
                },
            )
        except Exception as e:
            print(f"Error logging evaluation results: {e}")

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log metrics to wandb."""
        if wandb_metrics is None:
            wandb_metrics = {}

        # Add evaluation metrics
        for metric_name, metric_value in self.eval_metrics:
            wandb_metrics[metric_name] = metric_value
        self.eval_metrics = []

        # Add config metrics
        wandb_metrics["config/thinking_mode"] = (
            1.0 if self.config.thinking_mode else 0.0
        )
        wandb_metrics["config/num_few_shot"] = self.config.num_few_shot
        wandb_metrics["config/eval_max_tokens"] = self.config.eval_max_tokens

        await super().wandb_log(wandb_metrics)


if __name__ == "__main__":
    MMLUEvalEnv.cli()