more linter nonsense

teknium 2025-12-24 11:04:33 +00:00
parent f18d46549d
commit abdda3978a
29 changed files with 113 additions and 151 deletions
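The diff below repeats three mechanical fixes: f-strings with no placeholders lose their `f` prefix, bare `except:` clauses become `except Exception:`, and long inline expressions move into named locals before the `print`. A minimal sketch of the before/after shapes, assuming ruff-style rules (F541 for placeholder-less f-strings, E722 for bare except); the rule names and the `Config`/`report_*` helpers are illustrative stand-ins, not taken from this repo:

from dataclasses import dataclass


@dataclass
class Config:
    max_retries: int = 3


def report_before(config: Config) -> None:
    # Before: the patterns the linter flags.
    print(f"\nResults:")  # F541: f-string with no placeholders
    try:
        raise ValueError("boom")
    except:  # E722: bare except also traps SystemExit and KeyboardInterrupt
        pass
    print(f"Failed after {config.max_retries} attempts")


def report_after(config: Config) -> None:
    # After: the shape this commit applies throughout.
    print("\nResults:")  # plain string once the f-prefix does nothing
    try:
        raise ValueError("boom")
    except Exception:  # still broad, but process-level exits propagate
        pass
    retries = config.max_retries  # named local keeps long log lines short
    print(f"Failed after {retries} attempts")


report_before(Config())
report_after(Config())

Note that the `except Exception:` change is behavioral as well as cosmetic: a bare `except` would also swallow `KeyboardInterrupt` raised during the retry loop.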


@@ -36,16 +36,11 @@ import time
 from string import ascii_uppercase
 from typing import Dict, List, Optional, Tuple

 import wandb
 from datasets import load_dataset
 from eval_helpers import (
     build_mcqa_fallback_patterns,
     create_system_content,
     extract_letter_from_answer_tag,
     extract_thinking_content,
     get_default_thinking_prompt,
     save_eval_results,
     validate_thinking_format,
 )
 from pydantic import Field
 from tqdm.asyncio import tqdm_asyncio
@@ -333,12 +328,13 @@ class AGIEvalEnv(BaseEnv):

     async def setup(self) -> None:
         """Load the AGIEval dataset and prepare for evaluation."""
-        print(f"\nAGIEval Evaluation Setup (Generative Mode):")
+        print("\nAGIEval Evaluation Setup (Generative Mode):")
         print(f" Max tokens for reasoning: {self.config.eval_max_tokens}")
         print(f" Evaluation split: {self.config.eval_split}")
         print(f" Thinking mode: {self.config.thinking_mode}")
         if self.config.thinking_mode:
-            print(f" Thinking prompt: {self._get_thinking_prompt()[:100]}...")
+            prompt_preview = self._get_thinking_prompt()[:100]
+            print(f" Thinking prompt: {prompt_preview}...")

         # Determine which subsets to use
         if self.config.subsets:
@@ -379,7 +375,7 @@ class AGIEvalEnv(BaseEnv):
         print(f"\n Total evaluation items: {len(self.eval_data)}")

         # Print subset distribution
-        print(f"\n Subset distribution:")
+        print("\n Subset distribution:")
         for subset, count in sorted(subset_counts.items()):
             print(f" {subset}: {count} questions")
@@ -584,7 +580,7 @@ class AGIEvalEnv(BaseEnv):
                     break
                 elif attempt < self.config.max_retries - 1:
                     if self.config.full_debug:
-                        print(f" Response too short, retrying...")
+                        print(" Response too short, retrying...")
                     await asyncio.sleep(self.config.retry_delay)

             except Exception as e:
@@ -594,15 +590,15 @@ class AGIEvalEnv(BaseEnv):
                 )
                 if hasattr(e, "response"):
                     try:
-                        print(
-                            f" Response: {e.response.text[:500] if hasattr(e.response, 'text') else e.response}"
-                        )
-                    except:
+                        resp_text = e.response.text[:500] if hasattr(e.response, "text") else str(e.response)
+                        print(f" Response: {resp_text}")
+                    except Exception:
                         pass
                 if attempt < self.config.max_retries - 1:
                     await asyncio.sleep(self.config.retry_delay)
                 else:
-                    print(f" Failed after {self.config.max_retries} attempts")
+                    retries = self.config.max_retries
+                    print(f" Failed after {retries} attempts")
                     return {"is_correct": None, "sample": None}

         if not model_response:
@@ -669,9 +665,9 @@ class AGIEvalEnv(BaseEnv):
         """Run AGIEval evaluation."""
         start_time = time.time()

-        print(f"\n{'='*60}")
-        print(f"Starting AGIEval Evaluation (Generative/Reasoning Mode)")
-        print(f"{'='*60}")
+        print("\n" + "=" * 60)
+        print("Starting AGIEval Evaluation (Generative/Reasoning Mode)")
+        print("=" * 60)
         print(f" Total questions: {len(self.all_eval_items)}")
         print(f" Max tokens (for reasoning): {self.config.eval_max_tokens}")
         print(f" Thinking mode: {self.config.thinking_mode}")
@@ -782,9 +778,9 @@ class AGIEvalEnv(BaseEnv):
         self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]

         # Print summary
-        print(f"\n{'='*60}")
-        print(f"AGIEval Evaluation Results")
-        print(f"{'='*60}")
+        print("\n" + "=" * 60)
+        print("AGIEval Evaluation Results")
+        print("=" * 60)
         print(
             f"Overall Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_count})"
         )
@@ -794,7 +790,7 @@ class AGIEvalEnv(BaseEnv):
         print(f"Format Compliance: {format_compliance_rate:.4f}")
         print(f"Thinking Utilization: {thinking_utilization}/{total_count}")

-        print(f"\nSubset Breakdown:")
+        print("\nSubset Breakdown:")
         for subset, stats in sorted(subset_results.items()):
             if stats["total"] > 0:
                 subset_acc = stats["correct"] / stats["total"]
@@ -802,7 +798,7 @@ class AGIEvalEnv(BaseEnv):
                     f" {subset}: {subset_acc:.4f} ({stats['correct']}/{stats['total']})"
                 )

-        print(f"\nExtraction Method Statistics:")
+        print("\nExtraction Method Statistics:")
         for method, stats in sorted(
             extraction_methods.items(), key=lambda x: -x[1]["count"]
         ):
@@ -810,7 +806,7 @@ class AGIEvalEnv(BaseEnv):
                 method_acc = stats["correct"] / stats["count"]
                 print(f" {method}: {stats['count']} uses, {method_acc:.4f} accuracy")

-        print(f"{'='*60}\n")
+        print("=" * 60 + "\n")

         # Log evaluation results
         try: