mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
ef9c0c3699
commit
afab28dfa9
37 changed files with 4868 additions and 4052 deletions
|
|
@ -37,6 +37,14 @@ from typing import Dict, List, Optional, Tuple
|
|||
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from eval_helpers import (
|
||||
THINK_CONTENT_AFTER_PATTERN,
|
||||
create_system_content,
|
||||
extract_thinking_content,
|
||||
get_default_thinking_prompt,
|
||||
save_eval_results,
|
||||
validate_thinking_format,
|
||||
)
|
||||
from pydantic import Field
|
||||
from tenacity import (
|
||||
retry,
|
||||
|
|
@ -52,15 +60,6 @@ from atroposlib.envs.base import (
|
|||
BaseEnvConfig,
|
||||
EvalHandlingEnum,
|
||||
)
|
||||
from eval_helpers import (
|
||||
validate_thinking_format,
|
||||
extract_thinking_content,
|
||||
get_default_thinking_prompt,
|
||||
create_system_content,
|
||||
save_eval_results,
|
||||
THINK_CONTENT_AFTER_PATTERN,
|
||||
)
|
||||
|
||||
|
||||
# Prompt construction helpers
|
||||
MULTI_CHOICE_PROMPT = "Answer with the option letter from the given choices directly."
|
||||
|
|
@ -73,7 +72,9 @@ FREE_FORM_PROMPT_MATH = "Answer the question. \nLet's think step by step."
|
|||
def parse_options(options: List[str]) -> str:
|
||||
"""Format options as A. option, B. option, etc."""
|
||||
option_letters = [chr(ord("A") + i) for i in range(len(options))]
|
||||
return "\n".join([f"{letter}. {opt}" for letter, opt in zip(option_letters, options)])
|
||||
return "\n".join(
|
||||
[f"{letter}. {opt}" for letter, opt in zip(option_letters, options)]
|
||||
)
|
||||
|
||||
|
||||
def construct_prompt_multichoice(entry: Dict) -> str:
|
||||
|
|
@ -81,9 +82,9 @@ def construct_prompt_multichoice(entry: Dict) -> str:
|
|||
prompt = entry.get("prompt", "")
|
||||
options = entry.get("options", [])
|
||||
context = entry.get("context", "")
|
||||
|
||||
|
||||
parsed_options = parse_options(options)
|
||||
|
||||
|
||||
if context and str(context).lower() not in ["none", "null", ""]:
|
||||
return f"{context}\n{prompt}\n{parsed_options}\n{MULTI_CHOICE_PROMPT}"
|
||||
else:
|
||||
|
|
@ -95,7 +96,7 @@ def construct_prompt_freeform(entry: Dict) -> str:
|
|||
prompt = entry.get("prompt", "")
|
||||
context = entry.get("context", "")
|
||||
benchmark_name = entry.get("benchmark_name", "")
|
||||
|
||||
|
||||
# Select prompt suffix based on benchmark
|
||||
if benchmark_name == "BBH":
|
||||
suffix = FREE_FORM_PROMPT_BBH
|
||||
|
|
@ -105,9 +106,9 @@ def construct_prompt_freeform(entry: Dict) -> str:
|
|||
suffix = FREE_FORM_PROMPT_MATH
|
||||
else:
|
||||
suffix = FREE_FORM_PROMPT
|
||||
|
||||
|
||||
prompt_with_suffix = f"{prompt}\n{suffix}"
|
||||
|
||||
|
||||
if context and str(context).lower() not in ["none", "null", ""]:
|
||||
return f"Question: {context}\n{prompt_with_suffix}"
|
||||
else:
|
||||
|
|
@ -121,7 +122,7 @@ def judge_freeform_prompt(question: str, answer: str, gold: str) -> List[Dict]:
|
|||
{"role": "system", "content": "In this task, I want you to act as a judge."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""You will be provided with a question, its golden answer(s), and the model's answer, while the context of the question is not given here. Your task is to judge how correct the model's answer is based on the golden answer(s), without seeing the context of the question, and then give a correctness score. The correctness score should be one of the below numbers: 0.0 (totally wrong), 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, or 1.0 (totally right). Your should first briefly give your reasoning process regarding how the model's answer conforms to or contradicts the golden answer(s), and then give the correctness score. The correctness score must strictly follow this format: "[[score]]", e.g., "The correctness score: [[0.5]]".
|
||||
"content": f"""You will be provided with a question, its golden answer(s), and the model's answer, while the context of the question is not given here. Your task is to judge how correct the model's answer is based on the golden answer(s), without seeing the context of the question, and then give a correctness score. The correctness score should be one of the below numbers: 0.0 (totally wrong), 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, or 1.0 (totally right). Your should first briefly give your reasoning process regarding how the model's answer conforms to or contradicts the golden answer(s), and then give the correctness score. The correctness score must strictly follow this format: "[[score]]", e.g., "The correctness score: [[0.5]]".
|
||||
|
||||
Note that each one of the golden answers is considered correct. Thus if the model's answer matches any one of the golden answers, it should be considered correct. Judge the below case, give the brief reasoning process and the correctness score.
|
||||
|
||||
|
|
@ -134,11 +135,16 @@ Your Judgment:
|
|||
]
|
||||
|
||||
|
||||
def judge_multichoice_prompt(question: str, options: List[str], answer: str, gold: str) -> List[Dict]:
|
||||
def judge_multichoice_prompt(
|
||||
question: str, options: List[str], answer: str, gold: str
|
||||
) -> List[Dict]:
|
||||
"""Create judge prompt for multiple choice questions (scores 0 or 1)."""
|
||||
parsed_options = parse_options(options)
|
||||
return [
|
||||
{"role": "system", "content": "In this task, I want you to act as an option extractor."},
|
||||
{
|
||||
"role": "system",
|
||||
"content": "In this task, I want you to act as an option extractor.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""You will be provided with a multiple-choice question, its options, the gold answer, and the model's answer, while the context of the question is not given here. Your task is to extract or judge which option is chosen by the model based on its response, and to determine whether or not the model answered correctly. The model scores can either be 0 (incorrect) or 1 (correct). The correctness score must strictly follow this format: "[[score]]", e.g., "The correctness score: [[1]]".
|
||||
|
|
@ -156,126 +162,105 @@ Your Judgment:
|
|||
|
||||
class MixEvalConfig(BaseEnvConfig):
|
||||
"""Configuration for MixEval evaluation environment with LLM judge."""
|
||||
|
||||
|
||||
# Dataset configuration
|
||||
dataset_name: str = Field(
|
||||
default="MixEval/MixEval",
|
||||
description="HuggingFace dataset name"
|
||||
default="MixEval/MixEval", description="HuggingFace dataset name"
|
||||
)
|
||||
difficulty: str = Field(
|
||||
default="easy",
|
||||
description="Difficulty level: 'easy' (MixEval) or 'hard' (MixEval_Hard)"
|
||||
description="Difficulty level: 'easy' (MixEval) or 'hard' (MixEval_Hard)",
|
||||
)
|
||||
question_types: List[str] = Field(
|
||||
default=["freeform", "multichoice"],
|
||||
description="Question types to evaluate: 'freeform', 'multichoice', or both"
|
||||
description="Question types to evaluate: 'freeform', 'multichoice', or both",
|
||||
)
|
||||
shuffle_seed: int = Field(
|
||||
default=42,
|
||||
description="Random seed for shuffling"
|
||||
)
|
||||
|
||||
shuffle_seed: int = Field(default=42, description="Random seed for shuffling")
|
||||
|
||||
# Model generation parameters
|
||||
eval_temperature: float = Field(
|
||||
default=0.6,
|
||||
description="Temperature for model evaluation completions"
|
||||
default=0.6, description="Temperature for model evaluation completions"
|
||||
)
|
||||
eval_max_tokens: int = Field(
|
||||
default=0,
|
||||
description="Max tokens for model evaluation (0 = use model default)"
|
||||
default=0, description="Max tokens for model evaluation (0 = use model default)"
|
||||
)
|
||||
|
||||
|
||||
# System prompt configuration
|
||||
custom_system_prompt: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional custom system prompt"
|
||||
default=None, description="Optional custom system prompt"
|
||||
)
|
||||
|
||||
|
||||
# Thinking mode configuration
|
||||
thinking_mode: bool = Field(
|
||||
default=True,
|
||||
description="Whether to use thinking mode with <think></think> tags"
|
||||
description="Whether to use thinking mode with <think></think> tags",
|
||||
)
|
||||
custom_thinking_prompt: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional custom thinking prompt"
|
||||
default=None, description="Optional custom thinking prompt"
|
||||
)
|
||||
|
||||
|
||||
# Judge model configuration (following refusalbench pattern)
|
||||
judge_model_name: str = Field(
|
||||
default="gpt-4o",
|
||||
description="Model name for the judge"
|
||||
default="gpt-4o", description="Model name for the judge"
|
||||
)
|
||||
judge_base_url: str = Field(
|
||||
default="https://api.openai.com/v1",
|
||||
description="Base URL for the judge model API"
|
||||
description="Base URL for the judge model API",
|
||||
)
|
||||
judge_api_key_env: str = Field(
|
||||
default="OPENAI_API_KEY",
|
||||
description="Environment variable name containing the API key for the judge model"
|
||||
description="Environment variable name containing the API key for the judge model",
|
||||
)
|
||||
judge_temperature: float = Field(
|
||||
default=0.2,
|
||||
description="Temperature for judge completions"
|
||||
default=0.2, description="Temperature for judge completions"
|
||||
)
|
||||
judge_max_tokens: int = Field(
|
||||
default=0,
|
||||
description="Maximum tokens for judge completions (0 = use model default)"
|
||||
description="Maximum tokens for judge completions (0 = use model default)",
|
||||
)
|
||||
|
||||
|
||||
# Judge retry configuration
|
||||
judge_max_retries: int = Field(
|
||||
default=3,
|
||||
ge=1,
|
||||
description="Maximum number of retries for failed judge API calls"
|
||||
description="Maximum number of retries for failed judge API calls",
|
||||
)
|
||||
judge_retry_multiplier: float = Field(
|
||||
default=1.0,
|
||||
ge=0.0,
|
||||
description="Exponential backoff multiplier for judge retries"
|
||||
description="Exponential backoff multiplier for judge retries",
|
||||
)
|
||||
judge_retry_max_wait: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
description="Maximum wait time in seconds for judge retries"
|
||||
default=10, ge=1, description="Maximum wait time in seconds for judge retries"
|
||||
)
|
||||
|
||||
|
||||
# Rate limiting configuration
|
||||
judge_max_concurrent_calls: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
description="Maximum number of concurrent judge API calls"
|
||||
default=10, ge=1, description="Maximum number of concurrent judge API calls"
|
||||
)
|
||||
judge_rate_limit_delay: float = Field(
|
||||
default=0.2,
|
||||
ge=0.0,
|
||||
description="Minimum delay in seconds between judge API calls"
|
||||
description="Minimum delay in seconds between judge API calls",
|
||||
)
|
||||
|
||||
|
||||
# Fallback configuration
|
||||
use_fallback_scoring: bool = Field(
|
||||
default=True,
|
||||
description="Use fallback scoring (score=0) when judge API fails"
|
||||
default=True, description="Use fallback scoring (score=0) when judge API fails"
|
||||
)
|
||||
|
||||
|
||||
# Retry and debug configuration
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Maximum retries for failed model API calls"
|
||||
default=3, description="Maximum retries for failed model API calls"
|
||||
)
|
||||
retry_delay: float = Field(
|
||||
default=1.0,
|
||||
description="Delay between retries in seconds"
|
||||
default=1.0, description="Delay between retries in seconds"
|
||||
)
|
||||
min_response_length: int = Field(
|
||||
default=1,
|
||||
description="Minimum response length to consider valid"
|
||||
default=1, description="Minimum response length to consider valid"
|
||||
)
|
||||
full_debug: bool = Field(
|
||||
default=False,
|
||||
description="Enable full debug output"
|
||||
)
|
||||
|
||||
full_debug: bool = Field(default=False, description="Enable full debug output")
|
||||
|
||||
# Override defaults
|
||||
group_size: int = 1
|
||||
max_num_workers: int = 512
|
||||
|
|
@ -291,7 +276,7 @@ class MixEvalConfig(BaseEnvConfig):
|
|||
class MixEvalEnv(BaseEnv):
|
||||
"""
|
||||
MixEval Evaluation Environment with LLM Judge.
|
||||
|
||||
|
||||
Evaluates general knowledge and reasoning using MixEval benchmark.
|
||||
Uses an LLM judge to score responses.
|
||||
"""
|
||||
|
|
@ -309,18 +294,18 @@ class MixEvalEnv(BaseEnv):
|
|||
self.config: MixEvalConfig = config
|
||||
self.eval_items: List[Dict] = []
|
||||
self._dataset_loaded = False
|
||||
|
||||
|
||||
# Setup judge client
|
||||
self.judge_client = None
|
||||
self._setup_judge_client()
|
||||
|
||||
|
||||
# Rate limiting for judge calls
|
||||
self.judge_semaphore = asyncio.Semaphore(self.config.judge_max_concurrent_calls)
|
||||
|
||||
|
||||
# Pre-compile regex for score extraction
|
||||
self._score_pattern_float = re.compile(r"\[\[(\d\.?\d?)\]\]")
|
||||
self._score_pattern_int = re.compile(r"\[\[([01])\]\]")
|
||||
|
||||
|
||||
# Thread-safe metrics tracking
|
||||
self._metrics_lock = asyncio.Lock()
|
||||
self.judge_error_count = 0
|
||||
|
|
@ -334,18 +319,18 @@ class MixEvalEnv(BaseEnv):
|
|||
"""Setup the judge API client (following refusalbench pattern)."""
|
||||
try:
|
||||
import openai
|
||||
|
||||
|
||||
api_key = os.getenv(self.config.judge_api_key_env)
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
f"API key not found in environment variable: {self.config.judge_api_key_env}"
|
||||
)
|
||||
|
||||
|
||||
self.judge_client = openai.AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=self.config.judge_base_url,
|
||||
)
|
||||
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"OpenAI package is required for judge functionality. Install with: pip install openai"
|
||||
|
|
@ -354,10 +339,10 @@ class MixEvalEnv(BaseEnv):
|
|||
async def setup(self) -> None:
|
||||
"""Initialize the environment and load the dataset."""
|
||||
await super().setup()
|
||||
|
||||
|
||||
if not self._dataset_loaded:
|
||||
await self._load_dataset()
|
||||
|
||||
|
||||
print(f"\nMixEval Evaluation Setup (with LLM Judge):")
|
||||
print(f" Dataset: {self.config.dataset_name}")
|
||||
print(f" Difficulty: {self.config.difficulty}")
|
||||
|
|
@ -366,7 +351,9 @@ class MixEvalEnv(BaseEnv):
|
|||
print(f" Judge model: {self.config.judge_model_name}")
|
||||
print(f" Judge endpoint: {self.config.judge_base_url}")
|
||||
if self.config.thinking_mode:
|
||||
thinking_prompt = get_default_thinking_prompt(self.config.custom_thinking_prompt)
|
||||
thinking_prompt = get_default_thinking_prompt(
|
||||
self.config.custom_thinking_prompt
|
||||
)
|
||||
print(f" Thinking prompt: {thinking_prompt[:80]}...")
|
||||
print(f" Loaded {len(self.eval_items)} evaluation items")
|
||||
|
||||
|
|
@ -374,44 +361,48 @@ class MixEvalEnv(BaseEnv):
|
|||
"""Load and process the MixEval dataset."""
|
||||
# Determine subset based on difficulty
|
||||
subset = "MixEval" if self.config.difficulty == "easy" else "MixEval_Hard"
|
||||
|
||||
|
||||
print(f"Loading MixEval dataset: {self.config.dataset_name} ({subset})...")
|
||||
|
||||
|
||||
self.eval_items = []
|
||||
|
||||
|
||||
try:
|
||||
dataset = load_dataset(
|
||||
self.config.dataset_name,
|
||||
subset,
|
||||
trust_remote_code=True
|
||||
self.config.dataset_name, subset, trust_remote_code=True
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error loading dataset: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# Load freeform questions
|
||||
if "freeform" in self.config.question_types:
|
||||
if "free_form" in dataset:
|
||||
for idx, item in enumerate(dataset["free_form"]):
|
||||
prompt = construct_prompt_freeform(item)
|
||||
targets = item.get("target", [])
|
||||
gold = "; ".join([f"<answer {i+1}> {t}" for i, t in enumerate(targets)])
|
||||
|
||||
self.eval_items.append({
|
||||
"id": f"freeform_{idx}",
|
||||
"type": "freeform",
|
||||
"prompt": prompt,
|
||||
"raw_prompt": item.get("prompt", ""),
|
||||
"target": targets,
|
||||
"gold_formatted": gold,
|
||||
"options": None,
|
||||
"benchmark_name": item.get("benchmark_name", ""),
|
||||
"problem_type": item.get("problem_type", ""),
|
||||
})
|
||||
print(f" Loaded {len([i for i in self.eval_items if i['type'] == 'freeform'])} freeform items")
|
||||
gold = "; ".join(
|
||||
[f"<answer {i+1}> {t}" for i, t in enumerate(targets)]
|
||||
)
|
||||
|
||||
self.eval_items.append(
|
||||
{
|
||||
"id": f"freeform_{idx}",
|
||||
"type": "freeform",
|
||||
"prompt": prompt,
|
||||
"raw_prompt": item.get("prompt", ""),
|
||||
"target": targets,
|
||||
"gold_formatted": gold,
|
||||
"options": None,
|
||||
"benchmark_name": item.get("benchmark_name", ""),
|
||||
"problem_type": item.get("problem_type", ""),
|
||||
}
|
||||
)
|
||||
print(
|
||||
f" Loaded {len([i for i in self.eval_items if i['type'] == 'freeform'])} freeform items"
|
||||
)
|
||||
else:
|
||||
print(" Warning: free_form split not found")
|
||||
|
||||
|
||||
# Load multiple choice questions
|
||||
if "multichoice" in self.config.question_types:
|
||||
if "multiple_choice" in dataset:
|
||||
|
|
@ -419,44 +410,59 @@ class MixEvalEnv(BaseEnv):
|
|||
prompt = construct_prompt_multichoice(item)
|
||||
options = item.get("options", [])
|
||||
targets = item.get("target", [])
|
||||
|
||||
|
||||
# Convert target indices to letters
|
||||
gold_letters = [ascii_uppercase[int(t)] for t in targets if int(t) < len(options)]
|
||||
gold = f"{gold_letters[0]}. {options[int(targets[0])]}" if targets and int(targets[0]) < len(options) else ""
|
||||
|
||||
self.eval_items.append({
|
||||
"id": f"multichoice_{idx}",
|
||||
"type": "multichoice",
|
||||
"prompt": prompt,
|
||||
"raw_prompt": item.get("prompt", ""),
|
||||
"target": targets,
|
||||
"gold_formatted": gold,
|
||||
"options": options,
|
||||
"benchmark_name": item.get("benchmark_name", ""),
|
||||
"problem_type": item.get("problem_type", ""),
|
||||
})
|
||||
print(f" Loaded {len([i for i in self.eval_items if i['type'] == 'multichoice'])} multichoice items")
|
||||
gold_letters = [
|
||||
ascii_uppercase[int(t)]
|
||||
for t in targets
|
||||
if int(t) < len(options)
|
||||
]
|
||||
gold = (
|
||||
f"{gold_letters[0]}. {options[int(targets[0])]}"
|
||||
if targets and int(targets[0]) < len(options)
|
||||
else ""
|
||||
)
|
||||
|
||||
self.eval_items.append(
|
||||
{
|
||||
"id": f"multichoice_{idx}",
|
||||
"type": "multichoice",
|
||||
"prompt": prompt,
|
||||
"raw_prompt": item.get("prompt", ""),
|
||||
"target": targets,
|
||||
"gold_formatted": gold,
|
||||
"options": options,
|
||||
"benchmark_name": item.get("benchmark_name", ""),
|
||||
"problem_type": item.get("problem_type", ""),
|
||||
}
|
||||
)
|
||||
print(
|
||||
f" Loaded {len([i for i in self.eval_items if i['type'] == 'multichoice'])} multichoice items"
|
||||
)
|
||||
else:
|
||||
print(" Warning: multiple_choice split not found")
|
||||
|
||||
|
||||
# Shuffle with seed
|
||||
random.seed(self.config.shuffle_seed)
|
||||
random.shuffle(self.eval_items)
|
||||
|
||||
|
||||
self._dataset_loaded = True
|
||||
print(f"Total: Loaded {len(self.eval_items)} evaluation items")
|
||||
|
||||
def _create_system_content(self) -> str:
|
||||
"""Create system message content based on thinking mode."""
|
||||
return create_system_content(
|
||||
self.config.thinking_mode,
|
||||
self.config.custom_thinking_prompt,
|
||||
self.config.custom_system_prompt
|
||||
) or "You are a helpful assistant."
|
||||
return (
|
||||
create_system_content(
|
||||
self.config.thinking_mode,
|
||||
self.config.custom_thinking_prompt,
|
||||
self.config.custom_system_prompt,
|
||||
)
|
||||
or "You are a helpful assistant."
|
||||
)
|
||||
|
||||
async def _rate_limited_judge_call(self, messages: List[Dict]) -> Optional[str]:
|
||||
"""Make a rate-limited API call to the judge model with retry logic."""
|
||||
|
||||
|
||||
retry_decorator = retry(
|
||||
stop=stop_after_attempt(self.config.judge_max_retries),
|
||||
wait=wait_random_exponential(
|
||||
|
|
@ -465,12 +471,12 @@ class MixEvalEnv(BaseEnv):
|
|||
),
|
||||
retry=retry_if_exception_type((Exception,)),
|
||||
)
|
||||
|
||||
|
||||
async def _inner_judge_call():
|
||||
async with self.judge_semaphore:
|
||||
await asyncio.sleep(self.config.judge_rate_limit_delay)
|
||||
return await self._judge_api_call_raw(messages)
|
||||
|
||||
|
||||
retrying_call = retry_decorator(_inner_judge_call)
|
||||
return await retrying_call()
|
||||
|
||||
|
|
@ -484,7 +490,7 @@ class MixEvalEnv(BaseEnv):
|
|||
}
|
||||
if self.config.judge_max_tokens > 0:
|
||||
kwargs["max_tokens"] = self.config.judge_max_tokens
|
||||
|
||||
|
||||
result = await self.judge_client.chat.completions.create(**kwargs)
|
||||
if result.choices and result.choices[0].message.content:
|
||||
return result.choices[0].message.content
|
||||
|
|
@ -518,27 +524,22 @@ class MixEvalEnv(BaseEnv):
|
|||
) -> Tuple[float, str]:
|
||||
"""
|
||||
Judge a model response using the LLM judge.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (score: float, judgment: str)
|
||||
"""
|
||||
if item["type"] == "freeform":
|
||||
messages = judge_freeform_prompt(
|
||||
item["raw_prompt"],
|
||||
answer,
|
||||
item["gold_formatted"]
|
||||
item["raw_prompt"], answer, item["gold_formatted"]
|
||||
)
|
||||
else: # multichoice
|
||||
messages = judge_multichoice_prompt(
|
||||
item["raw_prompt"],
|
||||
item["options"],
|
||||
answer,
|
||||
item["gold_formatted"]
|
||||
item["raw_prompt"], item["options"], answer, item["gold_formatted"]
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
judgment = await self._rate_limited_judge_call(messages)
|
||||
|
||||
|
||||
if not judgment:
|
||||
async with self._metrics_lock:
|
||||
self.judge_error_count += 1
|
||||
|
|
@ -547,14 +548,14 @@ class MixEvalEnv(BaseEnv):
|
|||
self.fallback_count += 1
|
||||
return 0.0, "JUDGE_ERROR_EMPTY_RESPONSE"
|
||||
return 0.0, "JUDGE_ERROR_EMPTY_RESPONSE"
|
||||
|
||||
|
||||
if item["type"] == "freeform":
|
||||
score = self._parse_freeform_score(judgment)
|
||||
else:
|
||||
score = float(self._parse_multichoice_score(judgment))
|
||||
|
||||
|
||||
return score, judgment
|
||||
|
||||
|
||||
except Exception as e:
|
||||
async with self._metrics_lock:
|
||||
self.judge_error_count += 1
|
||||
|
|
@ -571,12 +572,12 @@ class MixEvalEnv(BaseEnv):
|
|||
) -> Optional[Dict]:
|
||||
"""Run evaluation on a single item and return the result."""
|
||||
system_content = self._create_system_content()
|
||||
|
||||
|
||||
messages = []
|
||||
if system_content:
|
||||
messages.append({"role": "system", "content": system_content})
|
||||
messages.append({"role": "user", "content": item["prompt"]})
|
||||
|
||||
|
||||
# Build API call parameters
|
||||
kwargs = {
|
||||
"model": server.model_name,
|
||||
|
|
@ -585,42 +586,45 @@ class MixEvalEnv(BaseEnv):
|
|||
}
|
||||
if self.config.eval_max_tokens > 0:
|
||||
kwargs["max_tokens"] = self.config.eval_max_tokens
|
||||
|
||||
|
||||
response_text = ""
|
||||
for attempt in range(self.config.max_retries):
|
||||
try:
|
||||
response = await self.server.chat_completion(**kwargs)
|
||||
response_text = response.choices[0].message.content or ""
|
||||
|
||||
|
||||
if len(response_text) >= self.config.min_response_length:
|
||||
break
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if self.config.full_debug:
|
||||
print(f" API error (attempt {attempt + 1}): {e}")
|
||||
if attempt < self.config.max_retries - 1:
|
||||
await asyncio.sleep(self.config.retry_delay)
|
||||
continue
|
||||
|
||||
|
||||
if not response_text:
|
||||
return None
|
||||
|
||||
|
||||
# Validate thinking format and extract actual response
|
||||
is_valid_format, extracted_response = validate_thinking_format(
|
||||
response_text,
|
||||
self.config.thinking_mode
|
||||
response_text, self.config.thinking_mode
|
||||
)
|
||||
|
||||
|
||||
# Extract thinking content if present
|
||||
thinking_content = extract_thinking_content(response_text) if self.config.thinking_mode else None
|
||||
|
||||
thinking_content = (
|
||||
extract_thinking_content(response_text)
|
||||
if self.config.thinking_mode
|
||||
else None
|
||||
)
|
||||
|
||||
# Judge the response
|
||||
score, judgment = await self._judge_response(item, extracted_response)
|
||||
|
||||
|
||||
if self.config.full_debug:
|
||||
print(f"\n--- Item: {item['id']} ({item['type']}) ---")
|
||||
print(f"Score: {score}")
|
||||
|
||||
|
||||
return {
|
||||
"item_id": item["id"],
|
||||
"type": item["type"],
|
||||
|
|
@ -647,41 +651,46 @@ class MixEvalEnv(BaseEnv):
|
|||
print(f" Judge model: {self.config.judge_model_name}")
|
||||
print(f" Thinking mode: {self.config.thinking_mode}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
# Reset metrics
|
||||
self.judge_error_count = 0
|
||||
self.fallback_count = 0
|
||||
|
||||
|
||||
# Create evaluation tasks
|
||||
async def eval_task(item):
|
||||
return await self.rollout_and_score_eval(item, self.server_configs[0])
|
||||
|
||||
|
||||
tasks = [eval_task(item) for item in self.eval_items]
|
||||
|
||||
|
||||
# Run with progress bar
|
||||
results = await tqdm_asyncio.gather(
|
||||
*tasks,
|
||||
desc="Evaluating MixEval"
|
||||
)
|
||||
|
||||
results = await tqdm_asyncio.gather(*tasks, desc="Evaluating MixEval")
|
||||
|
||||
# Filter out failed results
|
||||
valid_results = [r for r in results if r is not None]
|
||||
|
||||
|
||||
if not valid_results:
|
||||
print("Warning: No valid evaluation results obtained")
|
||||
return {"error": "No valid results", "avg_score": 0.0}
|
||||
|
||||
|
||||
# Calculate overall metrics
|
||||
total = len(valid_results)
|
||||
avg_score = sum(r["score"] for r in valid_results) / total
|
||||
|
||||
|
||||
# Calculate per-type metrics
|
||||
freeform_results = [r for r in valid_results if r["type"] == "freeform"]
|
||||
multichoice_results = [r for r in valid_results if r["type"] == "multichoice"]
|
||||
|
||||
freeform_score = sum(r["score"] for r in freeform_results) / len(freeform_results) if freeform_results else 0
|
||||
multichoice_score = sum(r["score"] for r in multichoice_results) / len(multichoice_results) if multichoice_results else 0
|
||||
|
||||
|
||||
freeform_score = (
|
||||
sum(r["score"] for r in freeform_results) / len(freeform_results)
|
||||
if freeform_results
|
||||
else 0
|
||||
)
|
||||
multichoice_score = (
|
||||
sum(r["score"] for r in multichoice_results) / len(multichoice_results)
|
||||
if multichoice_results
|
||||
else 0
|
||||
)
|
||||
|
||||
# Calculate per-benchmark metrics
|
||||
benchmark_metrics = {}
|
||||
for r in valid_results:
|
||||
|
|
@ -690,15 +699,17 @@ class MixEvalEnv(BaseEnv):
|
|||
benchmark_metrics[bench] = {"scores": [], "count": 0}
|
||||
benchmark_metrics[bench]["scores"].append(r["score"])
|
||||
benchmark_metrics[bench]["count"] += 1
|
||||
|
||||
|
||||
for bench in benchmark_metrics:
|
||||
scores = benchmark_metrics[bench]["scores"]
|
||||
benchmark_metrics[bench]["avg_score"] = sum(scores) / len(scores) if scores else 0
|
||||
|
||||
benchmark_metrics[bench]["avg_score"] = (
|
||||
sum(scores) / len(scores) if scores else 0
|
||||
)
|
||||
|
||||
# Format compliance
|
||||
format_valid = sum(1 for r in valid_results if r.get("format_valid", True))
|
||||
has_thinking = sum(1 for r in valid_results if r.get("has_thinking", False))
|
||||
|
||||
|
||||
metrics = {
|
||||
"avg_score": avg_score,
|
||||
"freeform_score": freeform_score,
|
||||
|
|
@ -712,27 +723,31 @@ class MixEvalEnv(BaseEnv):
|
|||
"fallback_rate": self.fallback_count / total if total > 0 else 0.0,
|
||||
"benchmark_metrics": benchmark_metrics,
|
||||
}
|
||||
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("MixEval Evaluation Results")
|
||||
print(f"{'='*60}")
|
||||
print(f" Overall Average Score: {avg_score:.2%}")
|
||||
print(f" Freeform Score: {freeform_score:.2%} ({len(freeform_results)} items)")
|
||||
print(f" Multiple Choice Accuracy: {multichoice_score:.2%} ({len(multichoice_results)} items)")
|
||||
print(
|
||||
f" Multiple Choice Accuracy: {multichoice_score:.2%} ({len(multichoice_results)} items)"
|
||||
)
|
||||
print(f" Total Evaluated: {total}")
|
||||
if self.config.thinking_mode:
|
||||
print(f" Format Compliance: {format_valid / total:.2%}")
|
||||
print(f" Thinking Utilization: {has_thinking / total:.2%}")
|
||||
print(f" Judge Error Rate: {self.judge_error_count / total:.2%}")
|
||||
print(f"\n Per-Benchmark Breakdown:")
|
||||
for bench, data in sorted(benchmark_metrics.items(), key=lambda x: -x[1]["avg_score"]):
|
||||
for bench, data in sorted(
|
||||
benchmark_metrics.items(), key=lambda x: -x[1]["avg_score"]
|
||||
):
|
||||
print(f" {bench}: {data['avg_score']:.2%} [{data['count']} items]")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
# Save results
|
||||
if self.config.data_dir_to_save_evals:
|
||||
self._save_results(metrics, valid_results)
|
||||
|
||||
|
||||
return metrics
|
||||
|
||||
def _save_results(self, metrics: Dict, results: List[Dict]) -> None:
|
||||
|
|
@ -743,22 +758,24 @@ class MixEvalEnv(BaseEnv):
|
|||
"""Log metrics to Weights & Biases."""
|
||||
if not self.config.use_wandb:
|
||||
return
|
||||
|
||||
|
||||
log_dict = {
|
||||
"mixeval/avg_score": metrics.get("avg_score", 0),
|
||||
"mixeval/freeform_score": metrics.get("freeform_score", 0),
|
||||
"mixeval/multichoice_accuracy": metrics.get("multichoice_accuracy", 0),
|
||||
"mixeval/total_evaluated": metrics.get("total_evaluated", 0),
|
||||
"mixeval/format_compliance_rate": metrics.get("format_compliance_rate", 0),
|
||||
"mixeval/thinking_utilization_rate": metrics.get("thinking_utilization_rate", 0),
|
||||
"mixeval/thinking_utilization_rate": metrics.get(
|
||||
"thinking_utilization_rate", 0
|
||||
),
|
||||
"mixeval/judge_error_rate": metrics.get("judge_error_rate", 0),
|
||||
}
|
||||
|
||||
|
||||
# Log per-benchmark scores
|
||||
for bench, data in metrics.get("benchmark_metrics", {}).items():
|
||||
safe_name = bench.replace(" ", "_")[:20]
|
||||
log_dict[f"mixeval/score_{safe_name}"] = data.get("avg_score", 0)
|
||||
|
||||
|
||||
wandb.log(log_dict, step=step)
|
||||
|
||||
# Required abstract method implementations
|
||||
|
|
@ -777,4 +794,3 @@ class MixEvalEnv(BaseEnv):
|
|||
|
||||
if __name__ == "__main__":
|
||||
MixEvalEnv.cli()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue