[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-12-24 10:48:20 +00:00
parent ef9c0c3699
commit afab28dfa9
37 changed files with 4868 additions and 4052 deletions

View file

@ -25,6 +25,14 @@ from typing import Dict, List, Optional, Tuple
import wandb
from datasets import load_dataset
from eval_helpers import (
ANSWER_TAG_PATTERN,
create_system_content,
extract_thinking_content,
get_default_thinking_prompt,
save_eval_results,
validate_thinking_format,
)
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
@ -34,15 +42,6 @@ from atroposlib.envs.base import (
BaseEnvConfig,
EvalHandlingEnum,
)
from eval_helpers import (
ANSWER_TAG_PATTERN,
validate_thinking_format,
extract_thinking_content,
get_default_thinking_prompt,
create_system_content,
save_eval_results,
)
# Prompt template for HLE with answer tag instruction
HLE_PROMPT_TEMPLATE = """Answer the following challenging question. Think step by step and reason carefully before providing your answer.
@ -57,79 +56,57 @@ Question: {question}"""
class HLEEvalConfig(BaseEnvConfig):
"""Configuration for HLE evaluation environment."""
# Dataset configuration
dataset_name: str = Field(
default="cais/hle",
description="HuggingFace dataset name"
default="cais/hle", description="HuggingFace dataset name"
)
subset: str = Field(
default="default",
description="Dataset subset"
)
eval_split: str = Field(
default="test",
description="Split to evaluate on"
)
shuffle_seed: int = Field(
default=42,
description="Random seed for shuffling"
)
subset: str = Field(default="default", description="Dataset subset")
eval_split: str = Field(default="test", description="Split to evaluate on")
shuffle_seed: int = Field(default=42, description="Random seed for shuffling")
# Generation parameters
eval_temperature: float = Field(
default=0.6,
description="Temperature for evaluation generation"
default=0.6, description="Temperature for evaluation generation"
)
eval_max_tokens: int = Field(
default=0,
description="Max tokens for evaluation (0 = use model default)"
default=0, description="Max tokens for evaluation (0 = use model default)"
)
# System prompt configuration
custom_system_prompt: Optional[str] = Field(
default=None,
description="Optional custom system prompt"
default=None, description="Optional custom system prompt"
)
# Thinking mode configuration
thinking_mode: bool = Field(
default=True,
description="Whether to use thinking mode with <think></think> tags"
description="Whether to use thinking mode with <think></think> tags",
)
custom_thinking_prompt: Optional[str] = Field(
default=None,
description="Optional custom thinking prompt"
default=None, description="Optional custom thinking prompt"
)
# Matching configuration
fuzzy_match: bool = Field(
default=True,
description="Allow fuzzy matching (substring containment)"
default=True, description="Allow fuzzy matching (substring containment)"
)
case_sensitive: bool = Field(
default=False,
description="Whether matching should be case-sensitive"
default=False, description="Whether matching should be case-sensitive"
)
# Retry and debug configuration
max_retries: int = Field(
default=3,
description="Maximum retries for failed API calls"
default=3, description="Maximum retries for failed API calls"
)
retry_delay: float = Field(
default=1.0,
description="Delay between retries in seconds"
default=1.0, description="Delay between retries in seconds"
)
min_response_length: int = Field(
default=1,
description="Minimum response length to consider valid"
default=1, description="Minimum response length to consider valid"
)
full_debug: bool = Field(
default=False,
description="Enable full debug output"
)
full_debug: bool = Field(default=False, description="Enable full debug output")
# Override defaults
group_size: int = 1
max_num_workers: int = 1024
@ -145,7 +122,7 @@ class HLEEvalConfig(BaseEnvConfig):
class HLEEvalEnv(BaseEnv):
"""
HLE (Humanity's Last Exam) Evaluation Environment.
Evaluates models on extremely challenging questions from expert contributors.
Uses generative evaluation with flexible string matching.
"""
@ -171,73 +148,75 @@ class HLEEvalEnv(BaseEnv):
async def setup(self) -> None:
"""Initialize the environment and load the dataset."""
await super().setup()
if not self._dataset_loaded:
await self._load_dataset()
print(f"\nHLE Evaluation Setup (Generative Mode):")
print(f" Dataset: {self.config.dataset_name}")
print(f" Evaluation split: {self.config.eval_split}")
print(f" Thinking mode: {self.config.thinking_mode}")
print(f" Fuzzy matching: {self.config.fuzzy_match}")
if self.config.thinking_mode:
thinking_prompt = get_default_thinking_prompt(self.config.custom_thinking_prompt)
thinking_prompt = get_default_thinking_prompt(
self.config.custom_thinking_prompt
)
print(f" Thinking prompt: {thinking_prompt[:80]}...")
print(f" Loaded {len(self.eval_items)} text-only evaluation items")
async def _load_dataset(self) -> None:
"""Load and process the HLE dataset."""
print(f"Loading HLE dataset: {self.config.dataset_name}...")
try:
dataset = load_dataset(
self.config.dataset_name,
self.config.subset,
trust_remote_code=True
self.config.dataset_name, self.config.subset, trust_remote_code=True
)
except Exception as e:
print(f"Error loading dataset: {e}")
raise
if self.config.eval_split not in dataset:
available_splits = list(dataset.keys())
raise ValueError(
f"Split '{self.config.eval_split}' not found. Available: {available_splits}"
)
split_data = dataset[self.config.eval_split]
# Process items - filter to text-only questions
self.eval_items = []
skipped_image = 0
for idx, item in enumerate(split_data):
# Filter out image questions
image = item.get("image")
if image is not None and image != "":
skipped_image += 1
continue
question = item.get("question", "")
answer = item.get("answer", "")
if not question or not answer:
continue
self.eval_items.append({
"id": str(idx),
"question": question,
"answer": answer,
"category": item.get("category", ""),
"source": item.get("source", ""),
})
self.eval_items.append(
{
"id": str(idx),
"question": question,
"answer": answer,
"category": item.get("category", ""),
"source": item.get("source", ""),
}
)
print(f"Filtered out {skipped_image} image questions")
# Shuffle with seed
random.seed(self.config.shuffle_seed)
random.shuffle(self.eval_items)
self._dataset_loaded = True
print(f"Loaded {len(self.eval_items)} text-only evaluation items")
@ -247,57 +226,60 @@ class HLEEvalEnv(BaseEnv):
def _create_system_content(self) -> str:
"""Create system message content based on thinking mode."""
return create_system_content(
self.config.thinking_mode,
self.config.custom_thinking_prompt,
self.config.custom_system_prompt
) or ""
return (
create_system_content(
self.config.thinking_mode,
self.config.custom_thinking_prompt,
self.config.custom_system_prompt,
)
or ""
)
def _normalize_answer(self, answer: str) -> str:
"""Normalize an answer for comparison."""
if not answer:
return ""
normalized = answer.strip()
if not self.config.case_sensitive:
normalized = normalized.lower()
# Remove common punctuation at ends
normalized = normalized.strip(".,;:!?\"'")
# Normalize whitespace
normalized = " ".join(normalized.split())
return normalized
def _check_match(self, predicted: str, gold: str) -> Tuple[bool, str]:
"""
Check if the predicted answer matches the gold answer.
Returns:
Tuple of (is_match, match_type)
"""
pred_norm = self._normalize_answer(predicted)
gold_norm = self._normalize_answer(gold)
if not pred_norm:
return False, "empty_prediction"
# Exact match
if pred_norm == gold_norm:
return True, "exact"
# Fuzzy matching if enabled
if self.config.fuzzy_match:
# Check if gold is contained in prediction
if gold_norm in pred_norm:
return True, "gold_in_pred"
# Check if prediction is contained in gold
if pred_norm in gold_norm:
return True, "pred_in_gold"
# Check for numeric equivalence (e.g., "42" vs "42.0")
try:
pred_num = float(pred_norm.replace(",", ""))
@ -306,40 +288,46 @@ class HLEEvalEnv(BaseEnv):
return True, "numeric_equiv"
except (ValueError, TypeError):
pass
return False, "no_match"
def _extract_answer(self, response: str, debug: bool = False) -> Tuple[Optional[str], str]:
def _extract_answer(
self, response: str, debug: bool = False
) -> Tuple[Optional[str], str]:
"""
Extract the answer from the response.
Args:
response: The model's response (content after </think> in thinking mode)
debug: Whether to print debug information
Returns:
Tuple of (extracted_answer or None, extraction_method)
"""
if not response:
return None, "empty_response"
# Try <answer></answer> tags first
answer_tag_match = ANSWER_TAG_PATTERN.search(response)
if answer_tag_match:
answer_content = answer_tag_match.group(1).strip()
if answer_content:
if debug:
preview = answer_content[:50] + "..." if len(answer_content) > 50 else answer_content
preview = (
answer_content[:50] + "..."
if len(answer_content) > 50
else answer_content
)
print(f" Extracted '{preview}' using method 'answer_tag'")
return answer_content, "answer_tag"
# Fallback: Look for "Answer: X" pattern
patterns = [
r"(?:the\s+)?(?:final\s+)?answer\s+is\s*:?\s*(.+?)(?:\n|$)",
r"(?:my\s+)?answer\s*:\s*(.+?)(?:\n|$)",
r"(?:so\s+)?the\s+answer\s+is\s*:?\s*(.+?)(?:\n|$)",
]
for pattern in patterns:
match = re.search(pattern, response, re.IGNORECASE)
if match:
@ -349,7 +337,7 @@ class HLEEvalEnv(BaseEnv):
preview = answer[:50] + "..." if len(answer) > 50 else answer
print(f" Extracted '{preview}' using fallback pattern")
return answer, "fallback_pattern"
# Last resort: take the last line/sentence
lines = [l.strip() for l in response.strip().split("\n") if l.strip()]
if lines:
@ -357,13 +345,13 @@ class HLEEvalEnv(BaseEnv):
# Clean up common prefixes
for prefix in ["Therefore,", "Thus,", "So,", "Hence,"]:
if last_line.startswith(prefix):
last_line = last_line[len(prefix):].strip()
last_line = last_line[len(prefix) :].strip()
if debug:
preview = last_line[:50] + "..." if len(last_line) > 50 else last_line
print(f" Extracted '{preview}' using fallback last_line")
return last_line, "fallback_last_line"
return None, "no_match"
async def rollout_and_score_eval(
@ -376,12 +364,12 @@ class HLEEvalEnv(BaseEnv):
"""
prompt = self._format_prompt(item)
system_content = self._create_system_content()
messages = []
if system_content:
messages.append({"role": "system", "content": system_content})
messages.append({"role": "user", "content": prompt})
# Build API call parameters
kwargs = {
"model": server.model_name,
@ -390,52 +378,60 @@ class HLEEvalEnv(BaseEnv):
}
if self.config.eval_max_tokens > 0:
kwargs["max_tokens"] = self.config.eval_max_tokens
response_text = ""
for attempt in range(self.config.max_retries):
try:
response = await self.server.chat_completion(**kwargs)
response_text = response.choices[0].message.content or ""
if len(response_text) >= self.config.min_response_length:
break
except Exception as e:
if self.config.full_debug:
print(f" API error (attempt {attempt + 1}): {e}")
if attempt < self.config.max_retries - 1:
await asyncio.sleep(self.config.retry_delay)
continue
if not response_text:
return None
# Validate thinking format and extract content after </think>
is_valid_format, content_for_extraction = validate_thinking_format(
response_text,
self.config.thinking_mode
response_text, self.config.thinking_mode
)
# Extract thinking content if present
thinking_content = extract_thinking_content(response_text) if self.config.thinking_mode else None
thinking_content = (
extract_thinking_content(response_text)
if self.config.thinking_mode
else None
)
# Extract answer from appropriate content
extracted_answer, extraction_method = self._extract_answer(
content_for_extraction,
debug=self.config.full_debug
content_for_extraction, debug=self.config.full_debug
)
# Check match
gold_answer = item["answer"]
is_correct, match_type = self._check_match(extracted_answer, gold_answer) if extracted_answer else (False, "no_extraction")
is_correct, match_type = (
self._check_match(extracted_answer, gold_answer)
if extracted_answer
else (False, "no_extraction")
)
if self.config.full_debug:
print(f"\n--- Item: {item['id']} ---")
print(f"Question: {item['question'][:100]}...")
print(f"Gold answer: {gold_answer[:100]}...")
print(f"Extracted: {extracted_answer[:100] if extracted_answer else 'None'}...")
print(
f"Extracted: {extracted_answer[:100] if extracted_answer else 'None'}..."
)
print(f"Match: {is_correct} ({match_type})")
return {
"item_id": item["id"],
"question": item["question"],
@ -462,31 +458,28 @@ class HLEEvalEnv(BaseEnv):
print(f" Thinking mode: {self.config.thinking_mode}")
print(f" Fuzzy matching: {self.config.fuzzy_match}")
print(f"{'='*60}\n")
# Create evaluation tasks
async def eval_task(item):
return await self.rollout_and_score_eval(item, self.server_configs[0])
tasks = [eval_task(item) for item in self.eval_items]
# Run with progress bar
results = await tqdm_asyncio.gather(
*tasks,
desc="Evaluating HLE"
)
results = await tqdm_asyncio.gather(*tasks, desc="Evaluating HLE")
# Filter out failed results
valid_results = [r for r in results if r is not None]
if not valid_results:
print("Warning: No valid evaluation results obtained")
return {"error": "No valid results", "accuracy": 0.0}
# Calculate metrics
total = len(valid_results)
correct = sum(1 for r in valid_results if r["is_correct"])
accuracy = correct / total if total > 0 else 0.0
# Calculate per-category metrics
category_metrics = {}
for r in valid_results:
@ -496,22 +489,24 @@ class HLEEvalEnv(BaseEnv):
category_metrics[cat]["total"] += 1
if r["is_correct"]:
category_metrics[cat]["correct"] += 1
for cat in category_metrics:
cat_total = category_metrics[cat]["total"]
cat_correct = category_metrics[cat]["correct"]
category_metrics[cat]["accuracy"] = cat_correct / cat_total if cat_total > 0 else 0.0
category_metrics[cat]["accuracy"] = (
cat_correct / cat_total if cat_total > 0 else 0.0
)
# Format compliance and thinking utilization
format_valid = sum(1 for r in valid_results if r.get("format_valid", True))
has_thinking = sum(1 for r in valid_results if r.get("has_thinking", False))
# Match type breakdown
match_counts = {}
for r in valid_results:
match_type = r.get("match_type", "unknown")
match_counts[match_type] = match_counts.get(match_type, 0) + 1
metrics = {
"accuracy": accuracy,
"total_evaluated": total,
@ -521,7 +516,7 @@ class HLEEvalEnv(BaseEnv):
"category_metrics": category_metrics,
"match_types": match_counts,
}
print(f"\n{'='*60}")
print("HLE Evaluation Results")
print(f"{'='*60}")
@ -531,14 +526,18 @@ class HLEEvalEnv(BaseEnv):
print(f" Thinking Utilization: {has_thinking / total:.2%}")
if category_metrics:
print(f"\n Per-Category Breakdown:")
for cat, data in sorted(category_metrics.items(), key=lambda x: -x[1]["accuracy"]):
print(f" {cat}: {data['accuracy']:.2%} ({data['correct']}/{data['total']})")
for cat, data in sorted(
category_metrics.items(), key=lambda x: -x[1]["accuracy"]
):
print(
f" {cat}: {data['accuracy']:.2%} ({data['correct']}/{data['total']})"
)
print(f"{'='*60}\n")
# Save results
if self.config.data_dir_to_save_evals:
self._save_results(metrics, valid_results)
return metrics
def _save_results(self, metrics: Dict, results: List[Dict]) -> None:
@ -549,19 +548,21 @@ class HLEEvalEnv(BaseEnv):
"""Log metrics to Weights & Biases."""
if not self.config.use_wandb:
return
log_dict = {
"hle/accuracy": metrics.get("accuracy", 0),
"hle/total_evaluated": metrics.get("total_evaluated", 0),
"hle/format_compliance_rate": metrics.get("format_compliance_rate", 0),
"hle/thinking_utilization_rate": metrics.get("thinking_utilization_rate", 0),
"hle/thinking_utilization_rate": metrics.get(
"thinking_utilization_rate", 0
),
}
# Log per-category accuracies (top categories)
for cat, data in metrics.get("category_metrics", {}).items():
safe_cat = cat.replace("/", "_").replace(" ", "_")[:30]
log_dict[f"hle/accuracy_{safe_cat}"] = data.get("accuracy", 0)
wandb.log(log_dict, step=step)
# Required abstract method implementations
@ -580,4 +581,3 @@ class HLEEvalEnv(BaseEnv):
if __name__ == "__main__":
HLEEvalEnv.cli()