atropos/environments/eval_environments/mmlu_pro_eval.py
2025-12-25 09:49:33 +00:00

857 lines
31 KiB
Python

"""
MMLU-Pro Evaluation Environment for Atropos (Generative/Reasoning Mode)
This environment evaluates models on the MMLU-Pro benchmark using a generative
approach where models can reason before answering.
Dataset: TIGER-Lab/MMLU-Pro
Paper: https://arxiv.org/abs/2406.01574
MMLU-Pro is a more robust and challenging massive multi-task understanding
dataset tailored to more rigorously benchmark large language models' capabilities.
This dataset contains 12K complex questions across various disciplines with
10 answer choices instead of 4.
The evaluation follows the lighteval generative approach:
- Models are prompted to "think step by step before answering"
- Models output their reasoning followed by "Answer: X"
- Answer is extracted using regex patterns from the response
- Simple string matching validates the extracted answer
Supports optional thinking mode with <think></think> tags for extended reasoning.
"""
import asyncio
import os
import re
import time
from string import ascii_uppercase
from typing import Dict, List, Optional, Tuple
from datasets import load_dataset
from eval_helpers import (
create_system_content,
extract_letter_from_answer_tag,
get_default_thinking_prompt,
)
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
)
# MMLU-Pro prompt template with <answer> tag instruction
# Note: MMLU-Pro has up to 10 choices (A-J), not just 4
MMLU_PRO_PROMPT_TEMPLATE = """Answer the following multiple choice question. Think step by step before answering.
Provide your final answer within <answer></answer> tags, containing only the letter ({valid_letters}).
Example format:
<answer>A</answer>
{question}
{choices}"""
# MMLU-Pro categories for aggregate metrics
MMLU_PRO_CATEGORIES = [
"biology",
"business",
"chemistry",
"computer science",
"economics",
"engineering",
"health",
"history",
"law",
"math",
"philosophy",
"physics",
"psychology",
"other",
]
class MMLUProEvalConfig(BaseEnvConfig):
"""Configuration for MMLU-Pro evaluation environment (generative mode)."""
# Thinking mode configuration
thinking_mode: bool = Field(
default=True,
description="Whether to enable thinking mode with <think></think> tags.",
)
custom_thinking_prompt: Optional[str] = Field(
default=None,
description="Custom thinking prompt. If None, uses the default thinking prompt.",
)
# Dataset configuration
dataset_name: str = Field(
default="TIGER-Lab/MMLU-Pro",
description="HuggingFace dataset name for MMLU-Pro.",
)
eval_split: str = Field(
default="test",
description="Dataset split to use for evaluation.",
)
few_shot_split: str = Field(
default="validation",
description="Dataset split to use for few-shot examples.",
)
# Few-shot configuration
num_few_shot: int = Field(
default=0,
ge=0,
le=5,
description="Number of few-shot examples to include (0-5 recommended).",
)
# Category filtering
categories: Optional[List[str]] = Field(
default=None,
description="List of categories to evaluate. If None, evaluates all categories.",
)
# Model generation configuration
eval_temperature: float = Field(
default=0.6,
description="Temperature for evaluation.",
)
eval_max_tokens: int = Field(
default=0,
description="Maximum tokens for evaluation responses. Set high to allow reasoning.",
)
# Prompt configuration
custom_system_prompt: Optional[str] = Field(
default=None,
description="Custom system prompt to append after thinking prompt (if thinking_mode) or use directly.",
)
# Retry configuration
max_retries: int = Field(
default=3,
ge=1,
description="Maximum retries for failed API calls.",
)
retry_delay: float = Field(
default=1.0,
ge=0.0,
description="Delay between retry attempts in seconds.",
)
min_response_length: int = Field(
default=1,
ge=1,
description="Minimum response length to consider valid.",
)
# Debug configuration
full_debug: bool = Field(
default=False,
description="Enable verbose debug logging.",
)
class MMLUProEvalEnv(BaseEnv):
"""
MMLU-Pro Evaluation Environment for Atropos (Generative/Reasoning Mode).
Evaluates models on the MMLU-Pro benchmark using a generative approach where
models reason before answering complex multi-choice questions.
Key features:
- Loads MMLU-Pro dataset from HuggingFace (TIGER-Lab/MMLU-Pro)
- Uses lighteval's exact prompt format
- Handles 10-choice questions (A-J)
- Optional thinking mode with <think></think> tags
- Tracks per-category accuracy
- Supports few-shot examples
"""
name = "mmlu_pro_eval"
env_config_cls = MMLUProEvalConfig
def __init__(
self,
config: MMLUProEvalConfig,
server_configs: List[APIServerConfig],
slurm=True,
testing=False,
):
super().__init__(config, server_configs, slurm, testing)
self.config: MMLUProEvalConfig = config
# Initialize metrics tracking
self.eval_metrics = []
# Pre-compile regex patterns for thinking mode
self._think_pattern = re.compile(r"<think>")
self._think_close_pattern = re.compile(r"</think>")
self._think_content_pattern = re.compile(r"</think>\s*(.*)", re.DOTALL)
self._thinking_extract_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
# Pre-compile regex for <answer></answer> tag extraction (primary method)
self._answer_tag_pattern = re.compile(
r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE
)
# Build fallback answer extraction patterns (supports A-J for 10 choices)
self._build_extraction_patterns()
def _get_thinking_prompt(self) -> str:
"""Get thinking system prompt."""
return get_default_thinking_prompt(self.config.custom_thinking_prompt)
def _create_system_content(self) -> Optional[str]:
"""Create system message content based on thinking mode."""
return create_system_content(
self.config.thinking_mode,
self.config.custom_thinking_prompt,
self.config.custom_system_prompt,
)
def _build_extraction_patterns(self):
"""Build regex patterns for extracting answer letters (A-J for 10 choices)."""
# MMLU-Pro has up to 10 choices (A-J)
letters = "ABCDEFGHIJ"
letter_pattern = rf"([{letters}]|\([{letters}]\))"
self._pattern_final_answer_hope = re.compile(
rf"(?i:final\s+answer\s+is)\s*:?\s*{letter_pattern}\.?\s*I\s*hope",
re.IGNORECASE,
)
self._pattern_final_answer_is = re.compile(
rf"(?i:final\s+answer).{{0,100}}?\s+is\s*:?\s*{letter_pattern}",
re.IGNORECASE | re.DOTALL,
)
self._pattern_the_answer_is = re.compile(
rf"(?i:the\s+answer\s+is)\s*:?\s*{letter_pattern}", re.IGNORECASE
)
self._pattern_answer_colon = re.compile(
rf"(?i:answer)\s*:\s*.{{0,50}}?{letter_pattern}", re.IGNORECASE | re.DOTALL
)
self._pattern_answer_space = re.compile(
rf"(?i:answer)\s+{letter_pattern}", re.IGNORECASE
)
self._pattern_start = re.compile(
rf"^\s*\**{letter_pattern}\**[\s\.\)\:]", re.IGNORECASE
)
self._pattern_line_start = re.compile(
rf"\n\s*\**{letter_pattern}\**[\s\.\)\:]", re.IGNORECASE
)
self._pattern_standalone = re.compile(rf"\b{letter_pattern}\b", re.IGNORECASE)
self._extraction_patterns = [
(0, self._pattern_final_answer_hope, "final_answer_hope"),
(50, self._pattern_final_answer_is, "final_answer_is"),
(75, self._pattern_the_answer_is, "the_answer_is"),
(100, self._pattern_answer_colon, "answer_colon"),
(150, self._pattern_answer_space, "answer_space"),
(200, self._pattern_start, "start"),
(210, self._pattern_line_start, "line_start"),
(250, self._pattern_standalone, "standalone"),
]
@classmethod
def config_init(cls) -> Tuple[MMLUProEvalConfig, List[APIServerConfig]]:
"""Initialize default configuration for the environment."""
env_config = MMLUProEvalConfig(
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
group_size=1,
use_wandb=True,
max_num_workers_per_node=8,
rollout_server_url="http://localhost:8000",
total_steps=1,
batch_size=1,
steps_per_eval=1,
inference_weight=1.0,
wandb_name="mmlu_pro_eval",
eval_handling=EvalHandlingEnum.STOP_TRAIN,
# MMLU-Pro specific defaults
dataset_name="TIGER-Lab/MMLU-Pro",
num_few_shot=0,
eval_temperature=0.6,
eval_max_tokens=0, # Use model default
thinking_mode=True,
)
server_configs = [
APIServerConfig(
model_name="Hermes-3-Llama-3.1-8B",
base_url="http://localhost:9000/v1",
api_key=os.getenv("OPENAI_API_KEY", "none"),
num_max_requests_at_once=32,
num_requests_for_eval=256,
),
]
return env_config, server_configs
async def setup(self) -> None:
"""Load the MMLU-Pro dataset and prepare for evaluation."""
print("\nMMLU-Pro Evaluation Setup (Generative Mode):")
print(f" Dataset: {self.config.dataset_name}")
print(f" Few-shot examples: {self.config.num_few_shot}")
print(f" Max tokens for reasoning: {self.config.eval_max_tokens}")
print(f" Evaluation split: {self.config.eval_split}")
print(f" Thinking mode: {self.config.thinking_mode}")
if self.config.thinking_mode:
print(f" Thinking prompt: {self._get_thinking_prompt()[:100]}...")
# Load MMLU-Pro dataset
try:
dataset = load_dataset(
self.config.dataset_name,
split=self.config.eval_split,
)
self.eval_data = list(dataset)
print(f" Loaded {len(self.eval_data)} evaluation items")
# Load few-shot data if needed
if self.config.num_few_shot > 0:
few_shot_dataset = load_dataset(
self.config.dataset_name,
split=self.config.few_shot_split,
)
self.few_shot_data = list(few_shot_dataset)
print(f" Loaded {len(self.few_shot_data)} few-shot examples")
else:
self.few_shot_data = []
except Exception as e:
print(f"Error loading MMLU-Pro dataset: {e}")
raise
# Filter by categories if specified
if self.config.categories:
self.eval_data = [
item
for item in self.eval_data
if item.get("category", "").lower()
in [c.lower() for c in self.config.categories]
]
print(
f" Filtered to {len(self.eval_data)} items in categories: {self.config.categories}"
)
# Analyze category distribution
category_counts = {}
for item in self.eval_data:
cat = item.get("category", "unknown")
category_counts[cat] = category_counts.get(cat, 0) + 1
print("\n Category distribution:")
for cat, count in sorted(category_counts.items()):
print(f" {cat}: {count} questions")
self.all_eval_items = self.eval_data
self.iter = 0
def _format_choices(self, options: List[str]) -> str:
"""Format choices as A: choice1, B: choice2, etc. (MMLU-Pro format)."""
lines = []
for idx, option in enumerate(options):
letter = ascii_uppercase[idx]
lines.append(f"{letter}: {option}")
return "\n".join(lines)
def _format_mmlu_pro_prompt(
self,
question: str,
options: List[str],
few_shot_examples: Optional[List[Dict]] = None,
) -> str:
"""
Format a question using the lighteval MMLU-Pro template.
Uses the exact prompt format from lighteval's mmlu_pro_prompt_function.
"""
num_choices = len(options)
valid_letters = "".join(ascii_uppercase[:num_choices])
# Format choices
formatted_choices = self._format_choices(options)
# Use lighteval's exact template
prompt = MMLU_PRO_PROMPT_TEMPLATE.format(
question=question,
choices=formatted_choices,
valid_letters=valid_letters,
)
# Add few-shot examples if provided
if few_shot_examples:
few_shot_text = self._format_few_shot_examples(few_shot_examples)
prompt = few_shot_text + "\n\n---\n\n" + prompt
return prompt
def _format_few_shot_examples(self, examples: List[Dict]) -> str:
"""Format few-shot examples with answers for context."""
formatted = []
for example in examples:
question = example.get("question", "")
options = example.get("options", [])
answer_index = example.get("answer_index", 0)
answer_letter = ascii_uppercase[answer_index]
formatted_choices = self._format_choices(options)
example_text = (
f"Question: {question}\n{formatted_choices}\n\nAnswer: {answer_letter}"
)
formatted.append(example_text)
return "\n\n---\n\n".join(formatted)
def _validate_thinking_format(self, response: str) -> Tuple[bool, str]:
"""Validate thinking format and extract content after </think> tags."""
if not self.config.thinking_mode:
return True, response
think_open_count = len(self._think_pattern.findall(response))
think_close_count = len(self._think_close_pattern.findall(response))
if think_open_count != 1 or think_close_count != 1:
return False, response
match = self._think_content_pattern.search(response)
if match:
return True, match.group(1).strip()
else:
return False, response
def _extract_thinking_content(self, response: str) -> Optional[str]:
"""Extract the content inside <think></think> tags."""
match = self._thinking_extract_pattern.search(response)
if match:
return match.group(1).strip()
return None
def _extract_answer(
self, response: str, num_choices: int = 10, choices: Optional[List[str]] = None
) -> Tuple[Optional[str], str]:
"""
Extract the answer letter from the model's response.
Primary method: Look for <answer></answer> tags, or match against choice texts.
Fallback: Use priority-ordered regex patterns.
"""
if not response:
return None, "empty_response"
valid_letters = set(ascii_uppercase[:num_choices])
# PRIMARY: Try <answer></answer> tags first
# Also matches against choice texts if provided
letter, method = extract_letter_from_answer_tag(
response, valid_letters, debug=self.config.full_debug, choices=choices
)
if letter:
return letter, method
# FALLBACK: Try each pattern in priority order
for priority, pattern, method_name in self._extraction_patterns:
matches = pattern.findall(response)
if matches:
match = (
matches[-1]
if method_name
in [
"final_answer_is",
"the_answer_is",
"answer_colon",
"answer_space",
]
else matches[0]
)
if isinstance(match, tuple):
match = match[0]
letter = match.strip("()").upper()
if letter in valid_letters:
if self.config.full_debug:
print(
f" Extracted '{letter}' using fallback method '{method_name}'"
)
return letter, f"fallback_{method_name}"
for letter in reversed(list(valid_letters)):
if letter in response.upper():
if self.config.full_debug:
print(
f" Extracted '{letter}' using fallback 'last_valid_letter'"
)
return letter, "fallback_last_valid_letter"
return None, "no_match"
async def get_next_item(self):
"""Get next item for training (not used in eval-only environment)."""
self.iter += 1
if self.all_eval_items:
item = self.all_eval_items[self.iter % len(self.all_eval_items)]
return item
return None
async def collect_trajectories(self, item):
"""Collect trajectories (not used in eval-only environment)."""
return None, []
async def score(self, rollout_group_data):
"""Score rollouts (not used in eval-only environment)."""
return None
async def rollout_and_score_eval(self, eval_item: Dict) -> Dict:
"""Evaluate a single MMLU-Pro question using generative mode."""
try:
question = eval_item.get("question", "")
options = eval_item.get("options", [])
answer_index = eval_item.get("answer_index", 0)
category = eval_item.get("category", "unknown")
num_choices = len(options)
gold_letter = ascii_uppercase[answer_index]
if not question or num_choices < 2:
return {"is_correct": None, "sample": None}
# Get few-shot examples
few_shot_examples = None
if self.config.num_few_shot > 0 and self.few_shot_data:
# Get examples from the same category if possible
same_cat_examples = [
ex for ex in self.few_shot_data if ex.get("category") == category
]
if len(same_cat_examples) >= self.config.num_few_shot:
few_shot_examples = same_cat_examples[: self.config.num_few_shot]
else:
few_shot_examples = self.few_shot_data[: self.config.num_few_shot]
# Format the prompt (lighteval style)
formatted_prompt = self._format_mmlu_pro_prompt(
question=question,
options=options,
few_shot_examples=few_shot_examples,
)
# Build messages
messages = []
system_content = self._create_system_content()
if system_content:
messages.append({"role": "system", "content": system_content})
messages.append({"role": "user", "content": formatted_prompt})
# Get model completion with retry logic
model_response = None
finish_reason = None
# Build completion kwargs - only include max_tokens if > 0
# (0 means "use model default", so we don't pass the parameter)
completion_kwargs = {
"messages": messages,
"n": 1,
"temperature": self.config.eval_temperature,
"split": "eval",
}
if self.config.eval_max_tokens > 0:
completion_kwargs["max_tokens"] = self.config.eval_max_tokens
for attempt in range(self.config.max_retries):
try:
completion = await self.server.chat_completion(**completion_kwargs)
if completion.choices and completion.choices[0].message.content:
model_response = completion.choices[0].message.content
finish_reason = getattr(
completion.choices[0], "finish_reason", None
)
if (
len(model_response.strip())
>= self.config.min_response_length
):
break
elif attempt < self.config.max_retries - 1:
if self.config.full_debug:
print(" Response too short, retrying...")
await asyncio.sleep(self.config.retry_delay)
except Exception as e:
# Always log API errors to help diagnose issues
print(
f" API Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}"
)
if hasattr(e, "response"):
try:
print(
f" Response: {e.response.text[:500] if hasattr(e.response, 'text') else e.response}"
)
except Exception:
pass
if attempt < self.config.max_retries - 1:
await asyncio.sleep(self.config.retry_delay)
else:
print(f" Failed after {self.config.max_retries} attempts")
return {"is_correct": None, "sample": None}
if not model_response:
return {"is_correct": None, "sample": None}
# Validate thinking format if enabled
format_valid, content_for_extraction = self._validate_thinking_format(
model_response
)
# Extract thinking content for logging
thinking_content = None
if self.config.thinking_mode:
thinking_content = self._extract_thinking_content(model_response)
# Extract the answer (pass choices for exact text matching)
extracted_answer, extraction_method = self._extract_answer(
content_for_extraction, num_choices, choices=options
)
# Check if correct
is_correct = extracted_answer == gold_letter if extracted_answer else False
# Build sample record
sample = {
"question": question,
"options": options,
"gold_answer": gold_letter,
"model_response": model_response,
"extracted_answer": extracted_answer,
"extraction_method": extraction_method,
"is_correct": is_correct,
"category": category,
"num_choices": num_choices,
"num_few_shot": self.config.num_few_shot,
"finish_reason": finish_reason,
"response_length": len(model_response),
"thinking_mode": self.config.thinking_mode,
"format_valid": format_valid,
}
if self.config.thinking_mode:
sample["thinking_content"] = thinking_content
sample["response_after_think"] = (
content_for_extraction if format_valid else None
)
if self.config.full_debug:
status = "" if is_correct else ""
print(
f" [{status}] {category}: gold={gold_letter}, extracted={extracted_answer}"
)
return {"is_correct": is_correct, "sample": sample}
except Exception as e:
if self.config.full_debug:
print(f"Error in rollout_and_score_eval: {e}")
import traceback
traceback.print_exc()
return {"is_correct": None, "sample": None}
async def evaluate(self, *args, **kwargs) -> None:
"""Run MMLU-Pro evaluation."""
start_time = time.time()
print(f"\n{'='*60}")
print("Starting MMLU-Pro Evaluation (Generative/Reasoning Mode)")
print(f"{'='*60}")
print(f" Total questions: {len(self.all_eval_items)}")
print(f" Few-shot examples: {self.config.num_few_shot}")
print(f" Max tokens (for reasoning): {self.config.eval_max_tokens}")
print(f" Thinking mode: {self.config.thinking_mode}")
print(f"{'='*60}\n")
try:
eval_tasks = [
self.rollout_and_score_eval(item) for item in self.all_eval_items
]
results = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating MMLU-Pro")
valid_results = [
r
for r in results
if r and r.get("sample") is not None and r.get("is_correct") is not None
]
if not valid_results:
print("Warning: No valid evaluation results obtained")
return
except Exception as e:
print(f"Error during evaluation: {e}")
import traceback
traceback.print_exc()
return
end_time = time.time()
# Compute metrics
samples = [r["sample"] for r in valid_results]
# Overall accuracy
total_correct = sum(1 for r in valid_results if r["is_correct"])
total_count = len(valid_results)
overall_accuracy = total_correct / total_count if total_count > 0 else 0.0
# Per-category accuracy
category_results = {}
for sample in samples:
category = sample.get("category", "unknown")
if category not in category_results:
category_results[category] = {"correct": 0, "total": 0}
category_results[category]["total"] += 1
if sample["is_correct"]:
category_results[category]["correct"] += 1
# Extraction method statistics
extraction_methods = {}
for sample in samples:
method = sample.get("extraction_method", "unknown")
if method not in extraction_methods:
extraction_methods[method] = {"count": 0, "correct": 0}
extraction_methods[method]["count"] += 1
if sample["is_correct"]:
extraction_methods[method]["correct"] += 1
# Average response length
response_lengths = [s.get("response_length", 0) for s in samples]
avg_response_length = (
sum(response_lengths) / len(response_lengths) if response_lengths else 0
)
# Format compliance
format_compliant = sum(1 for s in samples if s.get("format_valid", True))
format_compliance_rate = format_compliant / len(samples) if samples else 0.0
# Thinking utilization
thinking_utilization = 0
if self.config.thinking_mode:
thinking_utilization = sum(1 for s in samples if s.get("thinking_content"))
# Build metrics dictionary
eval_metrics = {
"eval/overall_accuracy": overall_accuracy,
"eval/total_questions": total_count,
"eval/total_correct": total_correct,
"eval/num_categories": len(category_results),
"eval/num_few_shot": self.config.num_few_shot,
"eval/evaluation_time_seconds": end_time - start_time,
"eval/avg_response_length": avg_response_length,
"eval/format_compliance_rate": format_compliance_rate,
"eval/thinking_mode_enabled": 1.0 if self.config.thinking_mode else 0.0,
}
if self.config.thinking_mode:
thinking_utilization_rate = (
thinking_utilization / len(samples) if samples else 0.0
)
eval_metrics["eval/thinking_utilization_rate"] = thinking_utilization_rate
# Add category metrics
for category, stats in category_results.items():
if stats["total"] > 0:
cat_accuracy = stats["correct"] / stats["total"]
cat_key = category.replace(" ", "_").replace("-", "_").lower()
eval_metrics[f"eval/category_{cat_key}_accuracy"] = cat_accuracy
eval_metrics[f"eval/category_{cat_key}_total"] = stats["total"]
# Add extraction method metrics
for method, stats in extraction_methods.items():
if stats["count"] > 0:
method_accuracy = stats["correct"] / stats["count"]
eval_metrics[f"eval/extraction_{method}_count"] = stats["count"]
eval_metrics[f"eval/extraction_{method}_accuracy"] = method_accuracy
# Store metrics for wandb logging
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
# Print summary
print(f"\n{'='*60}")
print("MMLU-Pro Evaluation Results")
print(f"{'='*60}")
print(
f"Overall Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_count})"
)
print(f"Evaluation Time: {end_time - start_time:.1f} seconds")
print(f"Avg Response Length: {avg_response_length:.0f} chars")
if self.config.thinking_mode:
print(f"Format Compliance: {format_compliance_rate:.4f}")
print(f"Thinking Utilization: {thinking_utilization}/{total_count}")
print("\nCategory Breakdown:")
for category, stats in sorted(category_results.items()):
if stats["total"] > 0:
cat_acc = stats["correct"] / stats["total"]
print(
f" {category}: {cat_acc:.4f} ({stats['correct']}/{stats['total']})"
)
print("\nExtraction Method Statistics:")
for method, stats in sorted(
extraction_methods.items(), key=lambda x: -x[1]["count"]
):
if stats["count"] > 0:
method_acc = stats["correct"] / stats["count"]
print(f" {method}: {stats['count']} uses, {method_acc:.4f} accuracy")
print(f"{'='*60}\n")
# Log evaluation results
try:
await self.evaluate_log(
metrics=eval_metrics,
samples=samples,
start_time=start_time,
end_time=end_time,
generation_parameters={
"temperature": self.config.eval_temperature,
"max_tokens": self.config.eval_max_tokens,
"num_few_shot": self.config.num_few_shot,
"thinking_mode": self.config.thinking_mode,
"mode": "generative",
},
)
except Exception as e:
print(f"Error logging evaluation results: {e}")
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
"""Log metrics to wandb."""
if wandb_metrics is None:
wandb_metrics = {}
for metric_name, metric_value in self.eval_metrics:
wandb_metrics[metric_name] = metric_value
self.eval_metrics = []
wandb_metrics["config/thinking_mode"] = (
1.0 if self.config.thinking_mode else 0.0
)
wandb_metrics["config/num_few_shot"] = self.config.num_few_shot
wandb_metrics["config/eval_max_tokens"] = self.config.eval_max_tokens
await super().wandb_log(wandb_metrics)
if __name__ == "__main__":
MMLUProEvalEnv.cli()