mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
ef9c0c3699
commit
afab28dfa9
37 changed files with 4868 additions and 4052 deletions
|
|
@ -31,6 +31,14 @@ from typing import Dict, List, Optional, Tuple
|
|||
import openai
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from eval_helpers import (
|
||||
ANSWER_TAG_PATTERN,
|
||||
create_system_content,
|
||||
extract_thinking_content,
|
||||
get_default_thinking_prompt,
|
||||
save_eval_results,
|
||||
validate_thinking_format,
|
||||
)
|
||||
from pydantic import Field
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
|
|
@ -40,15 +48,6 @@ from atroposlib.envs.base import (
|
|||
BaseEnvConfig,
|
||||
EvalHandlingEnum,
|
||||
)
|
||||
from eval_helpers import (
|
||||
ANSWER_TAG_PATTERN,
|
||||
validate_thinking_format,
|
||||
extract_thinking_content,
|
||||
get_default_thinking_prompt,
|
||||
create_system_content,
|
||||
save_eval_results,
|
||||
)
|
||||
|
||||
|
||||
# SimpleQA prompt template - Nous style with <answer> tags
|
||||
SIMPLEQA_PROMPT_TEMPLATE = """Please provide your answer within <answer></answer> tags. Give a concise, accurate answer.
|
||||
|
|
@ -243,20 +242,20 @@ class SimpleQAEvalConfig(BaseEnvConfig):
|
|||
class SimpleQAEvalEnv(BaseEnv):
|
||||
"""
|
||||
SimpleQA Evaluation Environment for Atropos.
|
||||
|
||||
|
||||
Evaluates models on the SimpleQA factuality benchmark.
|
||||
|
||||
|
||||
Two scoring modes:
|
||||
1. String Matching (default, Nous style): Uses exact/fuzzy match - fast, no LLM needed
|
||||
2. LLM Judge (optional, OpenAI style): Uses GPT-4o judge - more nuanced but slower
|
||||
|
||||
|
||||
Key features:
|
||||
- Loads SimpleQA dataset from HuggingFace (lighteval/SimpleQA)
|
||||
- Open-ended question answering (not multiple choice)
|
||||
- Optional thinking mode with <think></think> tags
|
||||
- Tracks exact match, fuzzy match, and combined accuracy
|
||||
"""
|
||||
|
||||
|
||||
name = "simpleqa_eval"
|
||||
env_config_cls = SimpleQAEvalConfig
|
||||
|
||||
|
|
@ -272,16 +271,18 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
|
||||
# Initialize metrics tracking
|
||||
self.eval_metrics = []
|
||||
|
||||
|
||||
# Pre-compile regex patterns for thinking mode
|
||||
self._think_pattern = re.compile(r"<think>")
|
||||
self._think_close_pattern = re.compile(r"</think>")
|
||||
self._think_content_pattern = re.compile(r"</think>\s*(.*)", re.DOTALL)
|
||||
self._thinking_extract_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
|
||||
|
||||
|
||||
# Pre-compile regex patterns for <answer></answer> tag extraction
|
||||
self._answer_tag_pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
|
||||
|
||||
self._answer_tag_pattern = re.compile(
|
||||
r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
# Initialize judge client only if using LLM judge mode
|
||||
self.judge_client = None
|
||||
if self.config.use_llm_judge:
|
||||
|
|
@ -305,11 +306,11 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
return create_system_content(
|
||||
self.config.thinking_mode,
|
||||
self.config.custom_thinking_prompt,
|
||||
self.config.custom_system_prompt
|
||||
self.config.custom_system_prompt,
|
||||
)
|
||||
|
||||
# ==================== String Matching Functions (Nous style) ====================
|
||||
|
||||
|
||||
def _exact_match(self, gold: str, prediction: str) -> bool:
|
||||
"""
|
||||
Evaluate open-ended answer using exact match (case-insensitive).
|
||||
|
|
@ -317,7 +318,7 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
"""
|
||||
if not prediction:
|
||||
return False
|
||||
|
||||
|
||||
return prediction.lower().strip() == gold.lower().strip()
|
||||
|
||||
def _fuzzy_match(self, gold: str, prediction: str) -> bool:
|
||||
|
|
@ -328,24 +329,24 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
"""
|
||||
if not prediction:
|
||||
return False
|
||||
|
||||
|
||||
pred_lower = prediction.lower().strip()
|
||||
truth_lower = gold.lower().strip()
|
||||
|
||||
|
||||
# Check if either string contains the other
|
||||
return truth_lower in pred_lower or pred_lower in truth_lower
|
||||
|
||||
def _score_string_match(self, gold: str, prediction: str) -> Dict:
|
||||
"""
|
||||
Score a prediction using Nous-style string matching methods.
|
||||
|
||||
|
||||
Returns dict with:
|
||||
- exact_match: bool (case-insensitive exact match)
|
||||
- fuzzy_match: bool (containment in either direction)
|
||||
"""
|
||||
exact = self._exact_match(gold, prediction)
|
||||
fuzzy = self._fuzzy_match(gold, prediction)
|
||||
|
||||
|
||||
return {
|
||||
"exact_match": exact,
|
||||
"fuzzy_match": fuzzy,
|
||||
|
|
@ -354,10 +355,7 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
# ==================== LLM Judge Functions (OpenAI style) ====================
|
||||
|
||||
def _format_judge_prompt(
|
||||
self,
|
||||
question: str,
|
||||
gold_answer: str,
|
||||
predicted_answer: str
|
||||
self, question: str, gold_answer: str, predicted_answer: str
|
||||
) -> str:
|
||||
"""Format the judge prompt using the lighteval grading template."""
|
||||
return SIMPLEQA_GRADER_TEMPLATE.format(
|
||||
|
|
@ -369,7 +367,7 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
def _parse_judge_grade(self, judge_response: str) -> Tuple[str, float]:
|
||||
"""
|
||||
Parse the judge's grade from their response.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (grade_string, score)
|
||||
- CORRECT: 1.0
|
||||
|
|
@ -377,7 +375,7 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
- NOT_ATTEMPTED: 0.0 (but tracked separately)
|
||||
"""
|
||||
response = judge_response.strip().upper()
|
||||
|
||||
|
||||
# Direct match
|
||||
if response == "A":
|
||||
return "CORRECT", 1.0
|
||||
|
|
@ -385,7 +383,7 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
return "INCORRECT", 0.0
|
||||
elif response == "C":
|
||||
return "NOT_ATTEMPTED", 0.0
|
||||
|
||||
|
||||
# Try to find A, B, or C in the response
|
||||
if "A" in response and "B" not in response and "C" not in response:
|
||||
return "CORRECT", 1.0
|
||||
|
|
@ -393,15 +391,19 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
return "INCORRECT", 0.0
|
||||
elif "C" in response and "A" not in response and "B" not in response:
|
||||
return "NOT_ATTEMPTED", 0.0
|
||||
|
||||
|
||||
# Check for text matches
|
||||
if "CORRECT" in response and "INCORRECT" not in response and "NOT" not in response:
|
||||
if (
|
||||
"CORRECT" in response
|
||||
and "INCORRECT" not in response
|
||||
and "NOT" not in response
|
||||
):
|
||||
return "CORRECT", 1.0
|
||||
elif "INCORRECT" in response:
|
||||
return "INCORRECT", 0.0
|
||||
elif "NOT_ATTEMPTED" in response or "NOT ATTEMPTED" in response:
|
||||
return "NOT_ATTEMPTED", 0.0
|
||||
|
||||
|
||||
# Unable to parse - treat as incorrect
|
||||
return "PARSE_ERROR", 0.0
|
||||
|
||||
|
|
@ -432,7 +434,7 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
# Default to string matching (Nous style)
|
||||
use_llm_judge=False,
|
||||
)
|
||||
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="Hermes-3-Llama-3.1-8B",
|
||||
|
|
@ -442,13 +444,17 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
num_requests_for_eval=1024,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self) -> None:
|
||||
"""Load the SimpleQA dataset and prepare for evaluation."""
|
||||
scoring_mode = "LLM Judge (GPT-4o)" if self.config.use_llm_judge else "String Matching (Nous)"
|
||||
|
||||
scoring_mode = (
|
||||
"LLM Judge (GPT-4o)"
|
||||
if self.config.use_llm_judge
|
||||
else "String Matching (Nous)"
|
||||
)
|
||||
|
||||
print(f"\nSimpleQA Evaluation Setup:")
|
||||
print(f" Dataset: {self.config.dataset_name}")
|
||||
print(f" Scoring mode: {scoring_mode}")
|
||||
|
|
@ -456,10 +462,12 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
print(f" Evaluation split: {self.config.eval_split}")
|
||||
print(f" Thinking mode: {self.config.thinking_mode}")
|
||||
if self.config.use_llm_judge:
|
||||
print(f" Judge model: {self.config.judge_model_name} @ {self.config.judge_base_url}")
|
||||
print(
|
||||
f" Judge model: {self.config.judge_model_name} @ {self.config.judge_base_url}"
|
||||
)
|
||||
if self.config.thinking_mode:
|
||||
print(f" Thinking prompt: {self._get_thinking_prompt()[:100]}...")
|
||||
|
||||
|
||||
# Load SimpleQA dataset
|
||||
try:
|
||||
dataset = load_dataset(
|
||||
|
|
@ -468,16 +476,16 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
)
|
||||
self.eval_data = list(dataset)
|
||||
print(f" Loaded {len(self.eval_data)} evaluation items")
|
||||
|
||||
|
||||
# Show sample structure
|
||||
if self.eval_data and self.config.full_debug:
|
||||
sample = self.eval_data[0]
|
||||
print(f" Sample fields: {list(sample.keys())}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading SimpleQA dataset: {e}")
|
||||
raise
|
||||
|
||||
|
||||
self.all_eval_items = self.eval_data
|
||||
self.iter = 0
|
||||
|
||||
|
|
@ -485,13 +493,13 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
"""Validate thinking format. Returns (is_valid, content_after_think)."""
|
||||
if not self.config.thinking_mode:
|
||||
return True, response
|
||||
|
||||
|
||||
think_open_count = len(self._think_pattern.findall(response))
|
||||
think_close_count = len(self._think_close_pattern.findall(response))
|
||||
|
||||
|
||||
if think_open_count != 1 or think_close_count != 1:
|
||||
return False, response
|
||||
|
||||
|
||||
match = self._think_content_pattern.search(response)
|
||||
if match:
|
||||
return True, match.group(1).strip()
|
||||
|
|
@ -518,27 +526,29 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
def _extract_answer_for_scoring(self, response: str) -> Tuple[str, bool, bool]:
|
||||
"""
|
||||
Extract the answer to use for scoring from the model response.
|
||||
|
||||
|
||||
Handles both thinking mode and answer tags:
|
||||
- If thinking mode: extract content after </think>
|
||||
- Then extract content from <answer></answer> tags
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (answer_text, thinking_format_valid, answer_tag_found)
|
||||
"""
|
||||
# First, handle thinking mode
|
||||
thinking_format_valid = True
|
||||
content_after_think = response
|
||||
|
||||
|
||||
if self.config.thinking_mode:
|
||||
thinking_format_valid, content_after_think = self._validate_thinking_format(response)
|
||||
|
||||
thinking_format_valid, content_after_think = self._validate_thinking_format(
|
||||
response
|
||||
)
|
||||
|
||||
# Now extract from <answer> tags
|
||||
answer_content = self._extract_answer_tag(content_after_think)
|
||||
|
||||
|
||||
if answer_content is not None:
|
||||
return answer_content, thinking_format_valid, True
|
||||
|
||||
|
||||
# Fallback: if no answer tags, use content after think (or full response)
|
||||
return content_after_think, thinking_format_valid, False
|
||||
|
||||
|
|
@ -566,24 +576,28 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
"""Evaluate a single SimpleQA question."""
|
||||
try:
|
||||
# SimpleQA uses 'problem' for question and 'answer' for gold
|
||||
question = eval_item.get('problem', '')
|
||||
gold_answer = eval_item.get('answer', '')
|
||||
metadata = eval_item.get('metadata', {})
|
||||
topic = metadata.get('topic', 'unknown') if isinstance(metadata, dict) else 'unknown'
|
||||
|
||||
question = eval_item.get("problem", "")
|
||||
gold_answer = eval_item.get("answer", "")
|
||||
metadata = eval_item.get("metadata", {})
|
||||
topic = (
|
||||
metadata.get("topic", "unknown")
|
||||
if isinstance(metadata, dict)
|
||||
else "unknown"
|
||||
)
|
||||
|
||||
if not question or not gold_answer:
|
||||
return {"score": None, "sample": None}
|
||||
|
||||
|
||||
# Format the prompt
|
||||
formatted_prompt = self._format_simpleqa_prompt(question)
|
||||
|
||||
|
||||
# Build messages for model
|
||||
messages = []
|
||||
system_content = self._create_system_content()
|
||||
if system_content:
|
||||
messages.append({"role": "system", "content": system_content})
|
||||
messages.append({"role": "user", "content": formatted_prompt})
|
||||
|
||||
|
||||
# Get model answer with retry logic
|
||||
model_response = None
|
||||
finish_reason = None
|
||||
|
|
@ -597,25 +611,34 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
}
|
||||
if self.config.eval_max_tokens > 0:
|
||||
completion_kwargs["max_tokens"] = self.config.eval_max_tokens
|
||||
|
||||
|
||||
completion = await self.server.chat_completion(**completion_kwargs)
|
||||
|
||||
|
||||
if completion.choices and completion.choices[0].message.content:
|
||||
model_response = completion.choices[0].message.content
|
||||
finish_reason = getattr(completion.choices[0], 'finish_reason', None)
|
||||
|
||||
if len(model_response.strip()) >= self.config.min_response_length:
|
||||
finish_reason = getattr(
|
||||
completion.choices[0], "finish_reason", None
|
||||
)
|
||||
|
||||
if (
|
||||
len(model_response.strip())
|
||||
>= self.config.min_response_length
|
||||
):
|
||||
break
|
||||
elif attempt < self.config.max_retries - 1:
|
||||
if self.config.full_debug:
|
||||
print(f" Response too short, retrying...")
|
||||
await asyncio.sleep(self.config.retry_delay)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f" API Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}")
|
||||
if hasattr(e, 'response'):
|
||||
print(
|
||||
f" API Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}"
|
||||
)
|
||||
if hasattr(e, "response"):
|
||||
try:
|
||||
print(f" Response: {e.response.text[:500] if hasattr(e.response, 'text') else e.response}")
|
||||
print(
|
||||
f" Response: {e.response.text[:500] if hasattr(e.response, 'text') else e.response}"
|
||||
)
|
||||
except:
|
||||
pass
|
||||
if attempt < self.config.max_retries - 1:
|
||||
|
|
@ -623,25 +646,29 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
else:
|
||||
print(f" Failed after {self.config.max_retries} attempts")
|
||||
return {"score": None, "sample": None}
|
||||
|
||||
|
||||
if not model_response:
|
||||
return {"score": None, "sample": None}
|
||||
|
||||
|
||||
# Extract answer using the combined thinking + answer tag extraction
|
||||
answer_for_scoring, thinking_format_valid, answer_tag_found = self._extract_answer_for_scoring(model_response)
|
||||
|
||||
answer_for_scoring, thinking_format_valid, answer_tag_found = (
|
||||
self._extract_answer_for_scoring(model_response)
|
||||
)
|
||||
|
||||
# Extract thinking content for logging
|
||||
thinking_content = None
|
||||
if self.config.thinking_mode:
|
||||
thinking_content = self._extract_thinking_content(model_response)
|
||||
|
||||
|
||||
# Score the response based on mode
|
||||
if self.config.use_llm_judge:
|
||||
# LLM Judge mode
|
||||
result = await self._score_with_judge(question, gold_answer, answer_for_scoring)
|
||||
result = await self._score_with_judge(
|
||||
question, gold_answer, answer_for_scoring
|
||||
)
|
||||
if result is None:
|
||||
return {"score": None, "sample": None}
|
||||
|
||||
|
||||
sample = {
|
||||
"question": question,
|
||||
"gold_answer": gold_answer,
|
||||
|
|
@ -658,23 +685,31 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
"answer_tag_found": answer_tag_found,
|
||||
"scoring_mode": "llm_judge",
|
||||
}
|
||||
|
||||
|
||||
if self.config.thinking_mode:
|
||||
sample["thinking_content"] = thinking_content
|
||||
|
||||
|
||||
if self.config.full_debug:
|
||||
status = "✓" if result["grade"] == "CORRECT" else ("○" if result["grade"] == "NOT_ATTEMPTED" else "✗")
|
||||
status = (
|
||||
"✓"
|
||||
if result["grade"] == "CORRECT"
|
||||
else ("○" if result["grade"] == "NOT_ATTEMPTED" else "✗")
|
||||
)
|
||||
print(f" [{status}] {topic[:20]}: {result['grade']}")
|
||||
|
||||
|
||||
return {"score": result["score"], "sample": sample}
|
||||
|
||||
|
||||
else:
|
||||
# String matching mode (Nous style)
|
||||
match_results = self._score_string_match(gold_answer, answer_for_scoring)
|
||||
|
||||
match_results = self._score_string_match(
|
||||
gold_answer, answer_for_scoring
|
||||
)
|
||||
|
||||
# Score is 1.0 if either exact or fuzzy match
|
||||
is_correct = match_results["exact_match"] or match_results["fuzzy_match"]
|
||||
|
||||
is_correct = (
|
||||
match_results["exact_match"] or match_results["fuzzy_match"]
|
||||
)
|
||||
|
||||
sample = {
|
||||
"question": question,
|
||||
"gold_answer": gold_answer,
|
||||
|
|
@ -692,33 +727,38 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
"answer_tag_found": answer_tag_found,
|
||||
"scoring_mode": "string_match",
|
||||
}
|
||||
|
||||
|
||||
if self.config.thinking_mode:
|
||||
sample["thinking_content"] = thinking_content
|
||||
|
||||
|
||||
if self.config.full_debug:
|
||||
status = "✓" if is_correct else "✗"
|
||||
print(f" [{status}] {topic[:20]}: exact={match_results['exact_match']}, fuzzy={match_results['fuzzy_match']}")
|
||||
|
||||
print(
|
||||
f" [{status}] {topic[:20]}: exact={match_results['exact_match']}, fuzzy={match_results['fuzzy_match']}"
|
||||
)
|
||||
|
||||
return {"score": 1.0 if is_correct else 0.0, "sample": sample}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if self.config.full_debug:
|
||||
print(f"Error in rollout_and_score_eval: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return {"score": None, "sample": None}
|
||||
|
||||
async def _score_with_judge(self, question: str, gold_answer: str, prediction: str) -> Optional[Dict]:
|
||||
async def _score_with_judge(
|
||||
self, question: str, gold_answer: str, prediction: str
|
||||
) -> Optional[Dict]:
|
||||
"""Score using LLM judge."""
|
||||
judge_prompt = self._format_judge_prompt(
|
||||
question=question,
|
||||
gold_answer=gold_answer,
|
||||
predicted_answer=prediction,
|
||||
)
|
||||
|
||||
|
||||
judge_messages = [{"role": "user", "content": judge_prompt}]
|
||||
|
||||
|
||||
for attempt in range(self.config.max_retries):
|
||||
try:
|
||||
kwargs = {
|
||||
|
|
@ -728,10 +768,15 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
}
|
||||
if self.config.judge_max_tokens > 0:
|
||||
kwargs["max_tokens"] = self.config.judge_max_tokens
|
||||
|
||||
judge_completion = await self.judge_client.chat.completions.create(**kwargs)
|
||||
|
||||
if judge_completion.choices and judge_completion.choices[0].message.content:
|
||||
|
||||
judge_completion = await self.judge_client.chat.completions.create(
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if (
|
||||
judge_completion.choices
|
||||
and judge_completion.choices[0].message.content
|
||||
):
|
||||
judge_response = judge_completion.choices[0].message.content
|
||||
grade, score = self._parse_judge_grade(judge_response)
|
||||
return {
|
||||
|
|
@ -739,23 +784,29 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
"grade": grade,
|
||||
"score": score,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f" Judge Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}")
|
||||
print(
|
||||
f" Judge Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}"
|
||||
)
|
||||
if attempt < self.config.max_retries - 1:
|
||||
await asyncio.sleep(self.config.retry_delay)
|
||||
else:
|
||||
print(f" Judge failed after {self.config.max_retries} attempts")
|
||||
return None
|
||||
|
||||
|
||||
return None
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""Run SimpleQA evaluation."""
|
||||
start_time = time.time()
|
||||
|
||||
scoring_mode = "LLM Judge (GPT-4o)" if self.config.use_llm_judge else "String Matching (Nous)"
|
||||
|
||||
|
||||
scoring_mode = (
|
||||
"LLM Judge (GPT-4o)"
|
||||
if self.config.use_llm_judge
|
||||
else "String Matching (Nous)"
|
||||
)
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Starting SimpleQA Evaluation")
|
||||
print(f"{'='*60}")
|
||||
|
|
@ -764,82 +815,102 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
print(f" Max tokens (for answer): {self.config.eval_max_tokens}")
|
||||
print(f" Thinking mode: {self.config.thinking_mode}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
try:
|
||||
eval_tasks = [
|
||||
self.rollout_and_score_eval(item) for item in self.all_eval_items
|
||||
]
|
||||
results = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating SimpleQA")
|
||||
|
||||
|
||||
valid_results = [
|
||||
r for r in results
|
||||
r
|
||||
for r in results
|
||||
if r and r.get("sample") is not None and r.get("score") is not None
|
||||
]
|
||||
|
||||
|
||||
if not valid_results:
|
||||
print("Warning: No valid evaluation results obtained")
|
||||
return
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during evaluation: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
# Compute metrics
|
||||
samples = [r["sample"] for r in valid_results]
|
||||
total_count = len(valid_results)
|
||||
|
||||
|
||||
# Build metrics based on scoring mode
|
||||
eval_metrics = {
|
||||
"eval/total_questions": total_count,
|
||||
"eval/evaluation_time_seconds": end_time - start_time,
|
||||
"eval/thinking_mode_enabled": 1.0 if self.config.thinking_mode else 0.0,
|
||||
"eval/scoring_mode": 1.0 if self.config.use_llm_judge else 0.0, # 1=judge, 0=string
|
||||
"eval/scoring_mode": (
|
||||
1.0 if self.config.use_llm_judge else 0.0
|
||||
), # 1=judge, 0=string
|
||||
}
|
||||
|
||||
|
||||
if self.config.use_llm_judge:
|
||||
# LLM Judge metrics
|
||||
correct_count = sum(1 for s in samples if s.get("grade") == "CORRECT")
|
||||
incorrect_count = sum(1 for s in samples if s.get("grade") == "INCORRECT")
|
||||
not_attempted_count = sum(1 for s in samples if s.get("grade") == "NOT_ATTEMPTED")
|
||||
parse_error_count = sum(1 for s in samples if s.get("grade") == "PARSE_ERROR")
|
||||
|
||||
not_attempted_count = sum(
|
||||
1 for s in samples if s.get("grade") == "NOT_ATTEMPTED"
|
||||
)
|
||||
parse_error_count = sum(
|
||||
1 for s in samples if s.get("grade") == "PARSE_ERROR"
|
||||
)
|
||||
|
||||
accuracy = correct_count / total_count if total_count > 0 else 0.0
|
||||
attempted_count = correct_count + incorrect_count
|
||||
accuracy_if_attempted = correct_count / attempted_count if attempted_count > 0 else 0.0
|
||||
not_attempted_rate = not_attempted_count / total_count if total_count > 0 else 0.0
|
||||
|
||||
eval_metrics.update({
|
||||
"eval/accuracy": accuracy,
|
||||
"eval/accuracy_if_attempted": accuracy_if_attempted,
|
||||
"eval/not_attempted_rate": not_attempted_rate,
|
||||
"eval/correct_count": correct_count,
|
||||
"eval/incorrect_count": incorrect_count,
|
||||
"eval/not_attempted_count": not_attempted_count,
|
||||
"eval/parse_error_count": parse_error_count,
|
||||
})
|
||||
accuracy_if_attempted = (
|
||||
correct_count / attempted_count if attempted_count > 0 else 0.0
|
||||
)
|
||||
not_attempted_rate = (
|
||||
not_attempted_count / total_count if total_count > 0 else 0.0
|
||||
)
|
||||
|
||||
eval_metrics.update(
|
||||
{
|
||||
"eval/accuracy": accuracy,
|
||||
"eval/accuracy_if_attempted": accuracy_if_attempted,
|
||||
"eval/not_attempted_rate": not_attempted_rate,
|
||||
"eval/correct_count": correct_count,
|
||||
"eval/incorrect_count": incorrect_count,
|
||||
"eval/not_attempted_count": not_attempted_count,
|
||||
"eval/parse_error_count": parse_error_count,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# String matching metrics (Nous style)
|
||||
exact_match_count = sum(1 for s in samples if s.get("exact_match", False))
|
||||
fuzzy_match_count = sum(1 for s in samples if s.get("fuzzy_match", False))
|
||||
correct_count = sum(1 for s in samples if s.get("is_correct", False))
|
||||
|
||||
exact_match_rate = exact_match_count / total_count if total_count > 0 else 0.0
|
||||
fuzzy_match_rate = fuzzy_match_count / total_count if total_count > 0 else 0.0
|
||||
|
||||
exact_match_rate = (
|
||||
exact_match_count / total_count if total_count > 0 else 0.0
|
||||
)
|
||||
fuzzy_match_rate = (
|
||||
fuzzy_match_count / total_count if total_count > 0 else 0.0
|
||||
)
|
||||
accuracy = correct_count / total_count if total_count > 0 else 0.0
|
||||
|
||||
eval_metrics.update({
|
||||
"eval/accuracy": accuracy,
|
||||
"eval/exact_match_accuracy": exact_match_rate,
|
||||
"eval/fuzzy_match_accuracy": fuzzy_match_rate,
|
||||
"eval/correct_count": correct_count,
|
||||
"eval/exact_match_count": exact_match_count,
|
||||
"eval/fuzzy_match_count": fuzzy_match_count,
|
||||
})
|
||||
|
||||
|
||||
eval_metrics.update(
|
||||
{
|
||||
"eval/accuracy": accuracy,
|
||||
"eval/exact_match_accuracy": exact_match_rate,
|
||||
"eval/fuzzy_match_accuracy": fuzzy_match_rate,
|
||||
"eval/correct_count": correct_count,
|
||||
"eval/exact_match_count": exact_match_count,
|
||||
"eval/fuzzy_match_count": fuzzy_match_count,
|
||||
}
|
||||
)
|
||||
|
||||
# Per-topic accuracy
|
||||
topic_results = {}
|
||||
for sample in samples:
|
||||
|
|
@ -847,35 +918,47 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
if topic not in topic_results:
|
||||
topic_results[topic] = {"correct": 0, "total": 0}
|
||||
topic_results[topic]["total"] += 1
|
||||
|
||||
|
||||
if self.config.use_llm_judge:
|
||||
if sample.get("grade") == "CORRECT":
|
||||
topic_results[topic]["correct"] += 1
|
||||
else:
|
||||
if sample.get("is_correct", False):
|
||||
topic_results[topic]["correct"] += 1
|
||||
|
||||
|
||||
# Average response length
|
||||
response_lengths = [s.get("response_length", 0) for s in samples]
|
||||
avg_response_length = sum(response_lengths) / len(response_lengths) if response_lengths else 0
|
||||
avg_response_length = (
|
||||
sum(response_lengths) / len(response_lengths) if response_lengths else 0
|
||||
)
|
||||
eval_metrics["eval/avg_response_length"] = avg_response_length
|
||||
|
||||
|
||||
# Answer tag usage (primary format indicator for SimpleQA Nous)
|
||||
answer_tag_found_count = sum(1 for s in samples if s.get("answer_tag_found", False))
|
||||
answer_tag_found_count = sum(
|
||||
1 for s in samples if s.get("answer_tag_found", False)
|
||||
)
|
||||
answer_tag_rate = answer_tag_found_count / len(samples) if samples else 0.0
|
||||
eval_metrics["eval/answer_tag_rate"] = answer_tag_rate
|
||||
eval_metrics["eval/answer_tag_found_count"] = answer_tag_found_count
|
||||
|
||||
|
||||
# Thinking format compliance (for thinking mode)
|
||||
if self.config.thinking_mode:
|
||||
thinking_format_compliant = sum(1 for s in samples if s.get("thinking_format_valid", True))
|
||||
thinking_format_compliance_rate = thinking_format_compliant / len(samples) if samples else 0.0
|
||||
eval_metrics["eval/thinking_format_compliance_rate"] = thinking_format_compliance_rate
|
||||
|
||||
thinking_format_compliant = sum(
|
||||
1 for s in samples if s.get("thinking_format_valid", True)
|
||||
)
|
||||
thinking_format_compliance_rate = (
|
||||
thinking_format_compliant / len(samples) if samples else 0.0
|
||||
)
|
||||
eval_metrics["eval/thinking_format_compliance_rate"] = (
|
||||
thinking_format_compliance_rate
|
||||
)
|
||||
|
||||
thinking_utilization = sum(1 for s in samples if s.get("thinking_content"))
|
||||
thinking_utilization_rate = thinking_utilization / len(samples) if samples else 0.0
|
||||
thinking_utilization_rate = (
|
||||
thinking_utilization / len(samples) if samples else 0.0
|
||||
)
|
||||
eval_metrics["eval/thinking_utilization_rate"] = thinking_utilization_rate
|
||||
|
||||
|
||||
# Add top topic metrics
|
||||
sorted_topics = sorted(topic_results.items(), key=lambda x: -x[1]["total"])[:20]
|
||||
for topic, stats in sorted_topics:
|
||||
|
|
@ -883,44 +966,62 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
topic_accuracy = stats["correct"] / stats["total"]
|
||||
topic_key = topic.replace(" ", "_").replace("-", "_").lower()[:30]
|
||||
eval_metrics[f"eval/topic_{topic_key}_accuracy"] = topic_accuracy
|
||||
|
||||
|
||||
# Store metrics for wandb logging
|
||||
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
|
||||
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'='*60}")
|
||||
print(f"SimpleQA Evaluation Results ({scoring_mode})")
|
||||
print(f"{'='*60}")
|
||||
|
||||
|
||||
if self.config.use_llm_judge:
|
||||
print(f"Overall Accuracy: {eval_metrics['eval/accuracy']:.4f} ({correct_count}/{total_count})")
|
||||
print(f"Accuracy (if attempted): {eval_metrics['eval/accuracy_if_attempted']:.4f}")
|
||||
print(
|
||||
f"Overall Accuracy: {eval_metrics['eval/accuracy']:.4f} ({correct_count}/{total_count})"
|
||||
)
|
||||
print(
|
||||
f"Accuracy (if attempted): {eval_metrics['eval/accuracy_if_attempted']:.4f}"
|
||||
)
|
||||
print(f"Not Attempted Rate: {eval_metrics['eval/not_attempted_rate']:.4f}")
|
||||
print(f"\nGrade Distribution:")
|
||||
print(f" CORRECT: {correct_count} ({100*correct_count/total_count:.1f}%)")
|
||||
print(f" INCORRECT: {incorrect_count} ({100*incorrect_count/total_count:.1f}%)")
|
||||
print(f" NOT_ATTEMPTED: {not_attempted_count} ({100*not_attempted_count/total_count:.1f}%)")
|
||||
print(
|
||||
f" INCORRECT: {incorrect_count} ({100*incorrect_count/total_count:.1f}%)"
|
||||
)
|
||||
print(
|
||||
f" NOT_ATTEMPTED: {not_attempted_count} ({100*not_attempted_count/total_count:.1f}%)"
|
||||
)
|
||||
else:
|
||||
print(f"Overall Accuracy: {eval_metrics['eval/accuracy']:.4f} ({correct_count}/{total_count})")
|
||||
print(f"Exact Match Accuracy: {eval_metrics['eval/exact_match_accuracy']:.4f} ({exact_match_count}/{total_count})")
|
||||
print(f"Fuzzy Match Accuracy: {eval_metrics['eval/fuzzy_match_accuracy']:.4f} ({fuzzy_match_count}/{total_count})")
|
||||
|
||||
print(
|
||||
f"Overall Accuracy: {eval_metrics['eval/accuracy']:.4f} ({correct_count}/{total_count})"
|
||||
)
|
||||
print(
|
||||
f"Exact Match Accuracy: {eval_metrics['eval/exact_match_accuracy']:.4f} ({exact_match_count}/{total_count})"
|
||||
)
|
||||
print(
|
||||
f"Fuzzy Match Accuracy: {eval_metrics['eval/fuzzy_match_accuracy']:.4f} ({fuzzy_match_count}/{total_count})"
|
||||
)
|
||||
|
||||
print(f"\nEvaluation Time: {end_time - start_time:.1f} seconds")
|
||||
print(f"Avg Response Length: {avg_response_length:.0f} chars")
|
||||
print(f"Answer Tag Rate: {answer_tag_rate:.4f} ({answer_tag_found_count}/{total_count})")
|
||||
print(
|
||||
f"Answer Tag Rate: {answer_tag_rate:.4f} ({answer_tag_found_count}/{total_count})"
|
||||
)
|
||||
if self.config.thinking_mode:
|
||||
print(f"Thinking Format Compliance: {thinking_format_compliance_rate:.4f}")
|
||||
print(f"Thinking Utilization: {thinking_utilization}/{total_count}")
|
||||
|
||||
|
||||
if len(sorted_topics) > 0:
|
||||
print(f"\nTop Topics (by count):")
|
||||
for topic, stats in sorted_topics[:10]:
|
||||
if stats["total"] > 0:
|
||||
topic_acc = stats["correct"] / stats["total"]
|
||||
print(f" {topic}: {topic_acc:.4f} ({stats['correct']}/{stats['total']})")
|
||||
|
||||
print(
|
||||
f" {topic}: {topic_acc:.4f} ({stats['correct']}/{stats['total']})"
|
||||
)
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
# Log evaluation results
|
||||
try:
|
||||
await self.evaluate_log(
|
||||
|
|
@ -932,7 +1033,9 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
"temperature": self.config.eval_temperature,
|
||||
"max_tokens": self.config.eval_max_tokens,
|
||||
"thinking_mode": self.config.thinking_mode,
|
||||
"scoring_mode": "llm_judge" if self.config.use_llm_judge else "string_match",
|
||||
"scoring_mode": (
|
||||
"llm_judge" if self.config.use_llm_judge else "string_match"
|
||||
),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
@ -942,18 +1045,21 @@ class SimpleQAEvalEnv(BaseEnv):
|
|||
"""Log metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
|
||||
for metric_name, metric_value in self.eval_metrics:
|
||||
wandb_metrics[metric_name] = metric_value
|
||||
self.eval_metrics = []
|
||||
|
||||
wandb_metrics["config/thinking_mode"] = 1.0 if self.config.thinking_mode else 0.0
|
||||
|
||||
wandb_metrics["config/thinking_mode"] = (
|
||||
1.0 if self.config.thinking_mode else 0.0
|
||||
)
|
||||
wandb_metrics["config/eval_max_tokens"] = self.config.eval_max_tokens
|
||||
wandb_metrics["config/use_llm_judge"] = 1.0 if self.config.use_llm_judge else 0.0
|
||||
|
||||
wandb_metrics["config/use_llm_judge"] = (
|
||||
1.0 if self.config.use_llm_judge else 0.0
|
||||
)
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
SimpleQAEvalEnv.cli()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue