[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-12-24 10:48:20 +00:00
parent ef9c0c3699
commit afab28dfa9
37 changed files with 4868 additions and 4052 deletions

View file

@ -25,7 +25,7 @@ Supports optional thinking mode with <think></think> tags for extended reasoning
Available subsets:
- English: aqua-rat, logiqa-en, lsat-ar, lsat-lr, lsat-rc, sat-en, sat-en-without-passage, sat-math
- Chinese: gaokao-biology, gaokao-chemistry, gaokao-chinese, gaokao-english,
- Chinese: gaokao-biology, gaokao-chemistry, gaokao-chinese, gaokao-english,
gaokao-geography, gaokao-history, gaokao-mathqa, gaokao-physics, logiqa-zh
"""
@ -38,6 +38,15 @@ from typing import Dict, List, Optional, Tuple
import wandb
from datasets import load_dataset
from eval_helpers import (
build_mcqa_fallback_patterns,
create_system_content,
extract_letter_from_answer_tag,
extract_thinking_content,
get_default_thinking_prompt,
save_eval_results,
validate_thinking_format,
)
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
@ -47,16 +56,6 @@ from atroposlib.envs.base import (
BaseEnvConfig,
EvalHandlingEnum,
)
from eval_helpers import (
extract_letter_from_answer_tag,
validate_thinking_format,
extract_thinking_content,
get_default_thinking_prompt,
create_system_content,
save_eval_results,
build_mcqa_fallback_patterns,
)
# AGIEval generative prompt template with <answer> tag instruction
AGIEVAL_PROMPT_TEMPLATE = """Answer the following multiple choice question. Think step by step before answering.
@ -138,14 +137,14 @@ class AGIEvalConfig(BaseEnvConfig):
subsets: Optional[List[str]] = Field(
default=None,
description="List of AGIEval subsets to evaluate. If None, evaluates all English subsets. "
"Available: aqua-rat, logiqa-en, lsat-ar, lsat-lr, lsat-rc, sat-en, "
"sat-en-without-passage, sat-math, gaokao-biology, gaokao-chemistry, etc.",
"Available: aqua-rat, logiqa-en, lsat-ar, lsat-lr, lsat-rc, sat-en, "
"sat-en-without-passage, sat-math, gaokao-biology, gaokao-chemistry, etc.",
)
english_only: bool = Field(
default=True,
description="If True and subsets is None, only evaluate English subsets. "
"If False and subsets is None, evaluate all subsets including Chinese.",
"If False and subsets is None, evaluate all subsets including Chinese.",
)
eval_split: str = Field(
@ -199,10 +198,10 @@ class AGIEvalConfig(BaseEnvConfig):
class AGIEvalEnv(BaseEnv):
"""
AGIEval Evaluation Environment for Atropos (Generative/Reasoning Mode).
Evaluates models on the AGIEval benchmark using a generative approach where
models reason before answering multiple-choice questions from standardized exams.
Key features:
- Loads multiple AGIEval subsets from HuggingFace (dmayhem93/agieval-*)
- Uses generative prompt format with "think step by step"
@ -210,7 +209,7 @@ class AGIEvalEnv(BaseEnv):
- Tracks per-subset accuracy
- Supports English and Chinese subsets
"""
name = "agieval_eval"
env_config_cls = AGIEvalConfig
@ -226,16 +225,18 @@ class AGIEvalEnv(BaseEnv):
# Initialize metrics tracking
self.eval_metrics = []
# Pre-compile regex patterns for thinking mode
self._think_pattern = re.compile(r"<think>")
self._think_close_pattern = re.compile(r"</think>")
self._think_content_pattern = re.compile(r"</think>\s*(.*)", re.DOTALL)
self._thinking_extract_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
# Pre-compile regex for <answer></answer> tag extraction (primary method)
self._answer_tag_pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
self._answer_tag_pattern = re.compile(
r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE
)
# Build fallback answer extraction patterns (supports A-E for up to 5 choices)
self._build_extraction_patterns()
@ -248,7 +249,7 @@ class AGIEvalEnv(BaseEnv):
return create_system_content(
self.config.thinking_mode,
self.config.custom_thinking_prompt,
self.config.custom_system_prompt
self.config.custom_system_prompt,
)
def _build_extraction_patterns(self):
@ -256,40 +257,32 @@ class AGIEvalEnv(BaseEnv):
# AGIEval typically has 4-5 choices
letters = "ABCDE"
letter_pattern = rf"([{letters}]|\([{letters}]\))"
self._pattern_final_answer_hope = re.compile(
rf"(?i:final\s+answer\s+is)\s*:?\s*{letter_pattern}\.?\s*I\s*hope",
re.IGNORECASE
re.IGNORECASE,
)
self._pattern_final_answer_is = re.compile(
rf"(?i:final\s+answer).{{0,100}}?\s+is\s*:?\s*{letter_pattern}",
re.IGNORECASE | re.DOTALL
re.IGNORECASE | re.DOTALL,
)
self._pattern_the_answer_is = re.compile(
rf"(?i:the\s+answer\s+is)\s*:?\s*{letter_pattern}",
re.IGNORECASE
rf"(?i:the\s+answer\s+is)\s*:?\s*{letter_pattern}", re.IGNORECASE
)
self._pattern_answer_colon = re.compile(
rf"(?i:answer)\s*:\s*.{{0,50}}?{letter_pattern}",
re.IGNORECASE | re.DOTALL
rf"(?i:answer)\s*:\s*.{{0,50}}?{letter_pattern}", re.IGNORECASE | re.DOTALL
)
self._pattern_answer_space = re.compile(
rf"(?i:answer)\s+{letter_pattern}",
re.IGNORECASE
rf"(?i:answer)\s+{letter_pattern}", re.IGNORECASE
)
self._pattern_start = re.compile(
rf"^\s*\**{letter_pattern}\**[\s\.\)\:]",
re.IGNORECASE
rf"^\s*\**{letter_pattern}\**[\s\.\)\:]", re.IGNORECASE
)
self._pattern_line_start = re.compile(
rf"\n\s*\**{letter_pattern}\**[\s\.\)\:]",
re.IGNORECASE
rf"\n\s*\**{letter_pattern}\**[\s\.\)\:]", re.IGNORECASE
)
self._pattern_standalone = re.compile(
rf"\b{letter_pattern}\b",
re.IGNORECASE
)
self._pattern_standalone = re.compile(rf"\b{letter_pattern}\b", re.IGNORECASE)
self._extraction_patterns = [
(0, self._pattern_final_answer_hope, "final_answer_hope"),
(50, self._pattern_final_answer_is, "final_answer_is"),
@ -325,7 +318,7 @@ class AGIEvalEnv(BaseEnv):
eval_max_tokens=0, # Use model default
thinking_mode=True,
)
server_configs = [
APIServerConfig(
model_name="Hermes-3-Llama-3.1-8B",
@ -335,7 +328,7 @@ class AGIEvalEnv(BaseEnv):
num_requests_for_eval=1024,
),
]
return env_config, server_configs
async def setup(self) -> None:
@ -346,7 +339,7 @@ class AGIEvalEnv(BaseEnv):
print(f" Thinking mode: {self.config.thinking_mode}")
if self.config.thinking_mode:
print(f" Thinking prompt: {self._get_thinking_prompt()[:100]}...")
# Determine which subsets to use
if self.config.subsets:
subsets_to_load = self.config.subsets
@ -354,42 +347,42 @@ class AGIEvalEnv(BaseEnv):
subsets_to_load = AGIEVAL_ENGLISH_SUBSETS
else:
subsets_to_load = list(AGIEVAL_SUBSETS.keys())
print(f" Subsets to evaluate: {subsets_to_load}")
# Load all subsets
self.eval_data = []
subset_counts = {}
for subset_name in subsets_to_load:
if subset_name not in AGIEVAL_SUBSETS:
print(f" Warning: Unknown subset '{subset_name}', skipping.")
continue
repo_name = AGIEVAL_SUBSETS[subset_name]
try:
dataset = load_dataset(repo_name, split=self.config.eval_split)
items = list(dataset)
# Add subset info to each item
for item in items:
item["_subset"] = subset_name
self.eval_data.extend(items)
subset_counts[subset_name] = len(items)
print(f" Loaded {len(items)} items from {subset_name}")
except Exception as e:
print(f" Error loading {subset_name}: {e}")
print(f"\n Total evaluation items: {len(self.eval_data)}")
# Print subset distribution
print(f"\n Subset distribution:")
for subset, count in sorted(subset_counts.items()):
print(f" {subset}: {count} questions")
self.all_eval_items = self.eval_data
self.iter = 0
@ -409,8 +402,8 @@ class AGIEvalEnv(BaseEnv):
return "\n".join(lines)
def _format_agieval_prompt(
self,
query: str,
self,
query: str,
choices: List[str],
) -> str:
"""
@ -418,30 +411,30 @@ class AGIEvalEnv(BaseEnv):
"""
num_choices = len(choices)
valid_letters = "".join(ascii_uppercase[:num_choices])
# Format choices
formatted_choices = self._format_choices(choices)
# Use generative template (like GPQA)
prompt = AGIEVAL_PROMPT_TEMPLATE.format(
question=query,
choices=formatted_choices,
valid_letters=valid_letters,
)
return prompt
def _validate_thinking_format(self, response: str) -> Tuple[bool, str]:
"""Validate thinking format and extract content after </think> tags."""
if not self.config.thinking_mode:
return True, response
think_open_count = len(self._think_pattern.findall(response))
think_close_count = len(self._think_close_pattern.findall(response))
if think_open_count != 1 or think_close_count != 1:
return False, response
match = self._think_content_pattern.search(response)
if match:
return True, match.group(1).strip()
@ -456,22 +449,19 @@ class AGIEvalEnv(BaseEnv):
return None
def _extract_answer(
self,
response: str,
num_choices: int = 4,
choices: Optional[List[str]] = None
self, response: str, num_choices: int = 4, choices: Optional[List[str]] = None
) -> Tuple[Optional[str], str]:
"""
Extract the answer letter from the model's response.
Primary method: Look for <answer></answer> tags, or match against choice texts.
Fallback: Use priority-ordered regex patterns.
"""
if not response:
return None, "empty_response"
valid_letters = set(ascii_uppercase[:num_choices])
# PRIMARY: Try <answer></answer> tags first
# Also matches against choice texts if provided
letter, method = extract_letter_from_answer_tag(
@ -479,27 +469,41 @@ class AGIEvalEnv(BaseEnv):
)
if letter:
return letter, method
# FALLBACK: Try each pattern in priority order
for priority, pattern, method_name in self._extraction_patterns:
matches = pattern.findall(response)
if matches:
match = matches[-1] if method_name in ["final_answer_is", "the_answer_is", "answer_colon", "answer_space"] else matches[0]
match = (
matches[-1]
if method_name
in [
"final_answer_is",
"the_answer_is",
"answer_colon",
"answer_space",
]
else matches[0]
)
if isinstance(match, tuple):
match = match[0]
letter = match.strip("()").upper()
if letter in valid_letters:
if self.config.full_debug:
print(f" Extracted '{letter}' using fallback method '{method_name}'")
print(
f" Extracted '{letter}' using fallback method '{method_name}'"
)
return letter, f"fallback_{method_name}"
for letter in reversed(list(valid_letters)):
if letter in response.upper():
if self.config.full_debug:
print(f" Extracted '{letter}' using fallback 'last_valid_letter'")
print(
f" Extracted '{letter}' using fallback 'last_valid_letter'"
)
return letter, "fallback_last_valid_letter"
return None, "no_match"
async def get_next_item(self):
@ -521,37 +525,39 @@ class AGIEvalEnv(BaseEnv):
async def rollout_and_score_eval(self, eval_item: Dict) -> Dict:
"""Evaluate a single AGIEval question using generative mode."""
try:
query = eval_item.get('query', '')
choices = eval_item.get('choices', [])
gold_indices = eval_item.get('gold', []) # Note: gold is a list
subset = eval_item.get('_subset', 'unknown')
query = eval_item.get("query", "")
choices = eval_item.get("choices", [])
gold_indices = eval_item.get("gold", []) # Note: gold is a list
subset = eval_item.get("_subset", "unknown")
num_choices = len(choices)
# Handle gold index (can be a list)
if isinstance(gold_indices, list) and len(gold_indices) > 0:
gold_index = gold_indices[0]
else:
gold_index = gold_indices
gold_letter = ascii_uppercase[gold_index] if isinstance(gold_index, int) else None
gold_letter = (
ascii_uppercase[gold_index] if isinstance(gold_index, int) else None
)
if not query or num_choices < 2 or gold_letter is None:
return {"is_correct": None, "sample": None}
# Format the prompt (generative style like GPQA)
formatted_prompt = self._format_agieval_prompt(
query=query,
choices=choices,
)
# Build messages
messages = []
system_content = self._create_system_content()
if system_content:
messages.append({"role": "system", "content": system_content})
messages.append({"role": "user", "content": formatted_prompt})
# Get model completion with retry logic
model_response = None
finish_reason = None
@ -564,24 +570,33 @@ class AGIEvalEnv(BaseEnv):
max_tokens=self.config.eval_max_tokens,
split="eval",
)
if completion.choices and completion.choices[0].message.content:
model_response = completion.choices[0].message.content
finish_reason = getattr(completion.choices[0], 'finish_reason', None)
if len(model_response.strip()) >= self.config.min_response_length:
finish_reason = getattr(
completion.choices[0], "finish_reason", None
)
if (
len(model_response.strip())
>= self.config.min_response_length
):
break
elif attempt < self.config.max_retries - 1:
if self.config.full_debug:
print(f" Response too short, retrying...")
await asyncio.sleep(self.config.retry_delay)
except Exception as e:
# Always log API errors to help diagnose issues
print(f" API Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}")
if hasattr(e, 'response'):
print(
f" API Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}"
)
if hasattr(e, "response"):
try:
print(f" Response: {e.response.text[:500] if hasattr(e.response, 'text') else e.response}")
print(
f" Response: {e.response.text[:500] if hasattr(e.response, 'text') else e.response}"
)
except:
pass
if attempt < self.config.max_retries - 1:
@ -589,26 +604,28 @@ class AGIEvalEnv(BaseEnv):
else:
print(f" Failed after {self.config.max_retries} attempts")
return {"is_correct": None, "sample": None}
if not model_response:
return {"is_correct": None, "sample": None}
# Validate thinking format if enabled
format_valid, content_for_extraction = self._validate_thinking_format(model_response)
format_valid, content_for_extraction = self._validate_thinking_format(
model_response
)
# Extract thinking content for logging
thinking_content = None
if self.config.thinking_mode:
thinking_content = self._extract_thinking_content(model_response)
# Extract the answer (pass choices for exact text matching)
extracted_answer, extraction_method = self._extract_answer(
content_for_extraction, num_choices, choices=choices
)
# Check if correct
is_correct = extracted_answer == gold_letter if extracted_answer else False
# Build sample record
sample = {
"query": query,
@ -625,28 +642,33 @@ class AGIEvalEnv(BaseEnv):
"thinking_mode": self.config.thinking_mode,
"format_valid": format_valid,
}
if self.config.thinking_mode:
sample["thinking_content"] = thinking_content
sample["response_after_think"] = content_for_extraction if format_valid else None
sample["response_after_think"] = (
content_for_extraction if format_valid else None
)
if self.config.full_debug:
status = "" if is_correct else ""
print(f" [{status}] {subset}: gold={gold_letter}, extracted={extracted_answer}")
print(
f" [{status}] {subset}: gold={gold_letter}, extracted={extracted_answer}"
)
return {"is_correct": is_correct, "sample": sample}
except Exception as e:
if self.config.full_debug:
print(f"Error in rollout_and_score_eval: {e}")
import traceback
traceback.print_exc()
return {"is_correct": None, "sample": None}
async def evaluate(self, *args, **kwargs) -> None:
"""Run AGIEval evaluation."""
start_time = time.time()
print(f"\n{'='*60}")
print(f"Starting AGIEval Evaluation (Generative/Reasoning Mode)")
print(f"{'='*60}")
@ -654,38 +676,40 @@ class AGIEvalEnv(BaseEnv):
print(f" Max tokens (for reasoning): {self.config.eval_max_tokens}")
print(f" Thinking mode: {self.config.thinking_mode}")
print(f"{'='*60}\n")
try:
eval_tasks = [
self.rollout_and_score_eval(item) for item in self.all_eval_items
]
results = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating AGIEval")
valid_results = [
r for r in results
r
for r in results
if r and r.get("sample") is not None and r.get("is_correct") is not None
]
if not valid_results:
print("Warning: No valid evaluation results obtained")
return
except Exception as e:
print(f"Error during evaluation: {e}")
import traceback
traceback.print_exc()
return
end_time = time.time()
# Compute metrics
samples = [r["sample"] for r in valid_results]
# Overall accuracy
total_correct = sum(1 for r in valid_results if r["is_correct"])
total_count = len(valid_results)
overall_accuracy = total_correct / total_count if total_count > 0 else 0.0
# Per-subset accuracy
subset_results = {}
for sample in samples:
@ -695,7 +719,7 @@ class AGIEvalEnv(BaseEnv):
subset_results[subset]["total"] += 1
if sample["is_correct"]:
subset_results[subset]["correct"] += 1
# Extraction method statistics
extraction_methods = {}
for sample in samples:
@ -705,20 +729,22 @@ class AGIEvalEnv(BaseEnv):
extraction_methods[method]["count"] += 1
if sample["is_correct"]:
extraction_methods[method]["correct"] += 1
# Average response length
response_lengths = [s.get("response_length", 0) for s in samples]
avg_response_length = sum(response_lengths) / len(response_lengths) if response_lengths else 0
avg_response_length = (
sum(response_lengths) / len(response_lengths) if response_lengths else 0
)
# Format compliance
format_compliant = sum(1 for s in samples if s.get("format_valid", True))
format_compliance_rate = format_compliant / len(samples) if samples else 0.0
# Thinking utilization
thinking_utilization = 0
if self.config.thinking_mode:
thinking_utilization = sum(1 for s in samples if s.get("thinking_content"))
# Build metrics dictionary
eval_metrics = {
"eval/overall_accuracy": overall_accuracy,
@ -730,11 +756,13 @@ class AGIEvalEnv(BaseEnv):
"eval/format_compliance_rate": format_compliance_rate,
"eval/thinking_mode_enabled": 1.0 if self.config.thinking_mode else 0.0,
}
if self.config.thinking_mode:
thinking_utilization_rate = thinking_utilization / len(samples) if samples else 0.0
thinking_utilization_rate = (
thinking_utilization / len(samples) if samples else 0.0
)
eval_metrics["eval/thinking_utilization_rate"] = thinking_utilization_rate
# Add subset metrics
for subset, stats in subset_results.items():
if stats["total"] > 0:
@ -742,42 +770,48 @@ class AGIEvalEnv(BaseEnv):
subset_key = subset.replace("-", "_").replace(" ", "_").lower()
eval_metrics[f"eval/subset_{subset_key}_accuracy"] = subset_accuracy
eval_metrics[f"eval/subset_{subset_key}_total"] = stats["total"]
# Add extraction method metrics
for method, stats in extraction_methods.items():
if stats["count"] > 0:
method_accuracy = stats["correct"] / stats["count"]
eval_metrics[f"eval/extraction_{method}_count"] = stats["count"]
eval_metrics[f"eval/extraction_{method}_accuracy"] = method_accuracy
# Store metrics for wandb logging
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
# Print summary
print(f"\n{'='*60}")
print(f"AGIEval Evaluation Results")
print(f"{'='*60}")
print(f"Overall Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_count})")
print(
f"Overall Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_count})"
)
print(f"Evaluation Time: {end_time - start_time:.1f} seconds")
print(f"Avg Response Length: {avg_response_length:.0f} chars")
if self.config.thinking_mode:
print(f"Format Compliance: {format_compliance_rate:.4f}")
print(f"Thinking Utilization: {thinking_utilization}/{total_count}")
print(f"\nSubset Breakdown:")
for subset, stats in sorted(subset_results.items()):
if stats["total"] > 0:
subset_acc = stats["correct"] / stats["total"]
print(f" {subset}: {subset_acc:.4f} ({stats['correct']}/{stats['total']})")
print(
f" {subset}: {subset_acc:.4f} ({stats['correct']}/{stats['total']})"
)
print(f"\nExtraction Method Statistics:")
for method, stats in sorted(extraction_methods.items(), key=lambda x: -x[1]["count"]):
for method, stats in sorted(
extraction_methods.items(), key=lambda x: -x[1]["count"]
):
if stats["count"] > 0:
method_acc = stats["correct"] / stats["count"]
print(f" {method}: {stats['count']} uses, {method_acc:.4f} accuracy")
print(f"{'='*60}\n")
# Log evaluation results
try:
await self.evaluate_log(
@ -799,17 +833,18 @@ class AGIEvalEnv(BaseEnv):
"""Log metrics to wandb."""
if wandb_metrics is None:
wandb_metrics = {}
for metric_name, metric_value in self.eval_metrics:
wandb_metrics[metric_name] = metric_value
self.eval_metrics = []
wandb_metrics["config/thinking_mode"] = 1.0 if self.config.thinking_mode else 0.0
wandb_metrics["config/thinking_mode"] = (
1.0 if self.config.thinking_mode else 0.0
)
wandb_metrics["config/eval_max_tokens"] = self.config.eval_max_tokens
await super().wandb_log(wandb_metrics)
if __name__ == "__main__":
AGIEvalEnv.cli()