mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
563 lines
20 KiB
Python
563 lines
20 KiB
Python
"""
|
|
AIME Evaluation Environment for Atropos (Generative Mode)
|
|
|
|
This environment evaluates models on AIME (American Invitational Mathematics Examination) -
|
|
a prestigious, invite-only mathematics competition for high-school students who perform
|
|
in the top 5% of the AMC 12 mathematics exam.
|
|
|
|
Datasets:
|
|
- AIME 2024: HuggingFaceH4/aime_2024
|
|
- AIME 2025: yentinglin/aime_2025
|
|
|
|
AIME consists of 15 questions of increasing difficulty per year, with answers being
|
|
single integers from 0 to 999. The median score is historically 4-6 questions correct.
|
|
|
|
The evaluation follows a generative approach:
|
|
- Models receive competition math problems
|
|
- Expected to provide step-by-step reasoning
|
|
- Final answer in \\boxed{} format
|
|
- Uses math_verify for robust answer verification
|
|
- Falls back to string/integer normalization if symbolic comparison fails
|
|
|
|
Supports thinking mode with <think></think> tags for extended reasoning.
|
|
"""
|
|
|
|
import asyncio
|
|
import random
|
|
import re
|
|
from concurrent.futures import ProcessPoolExecutor
|
|
from typing import Dict, List, Optional
|
|
|
|
import wandb
|
|
from datasets import load_dataset
|
|
from eval_helpers import (
|
|
THINK_CONTENT_AFTER_PATTERN,
|
|
create_system_content,
|
|
extract_boxed_answers,
|
|
extract_thinking_content,
|
|
get_default_thinking_prompt,
|
|
get_math_executor,
|
|
save_eval_results,
|
|
score_math_answer_async,
|
|
validate_thinking_format,
|
|
)
|
|
from pydantic import Field
|
|
from tqdm.asyncio import tqdm_asyncio
|
|
|
|
from atroposlib.envs.base import (
|
|
APIServerConfig,
|
|
BaseEnv,
|
|
BaseEnvConfig,
|
|
)
|
|
|
|
# Supported AIME editions: maps the year (as a string) to the Hugging Face
# dataset repository that is loaded for that year.
AIME_DATASETS: Dict[str, str] = {
    "2024": "HuggingFaceH4/aime_2024",
    "2025": "yentinglin/aime_2025",
}
|
|
|
|
|
|
# User-prompt template, modelled on lighteval's AIME prompt.
# The closing "I hope it is correct" sentence is kept on purpose: it is the
# answer format intended for math-verify.
# This is a ``str.format`` template: ``{problem}`` is the only placeholder and
# the doubled braces around ANSWER render as literal ``{ANSWER}``.
AIME_PROMPT_TEMPLATE = (
    "Solve the following math problem efficiently and clearly.\n"
    "\n"
    "The last line of your response should be of the following format:\n"
    "'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes)\n"
    "where ANSWER is just the final number or expression that solves the problem.\n"
    "\n"
    "Think step by step before answering.\n"
    "\n"
    "Note: AIME answers are always integers from 0 to 999.\n"
    "\n"
    "{problem}"
)
|
|
|
|
|
|
class AIMEEvalConfig(BaseEnvConfig):
    """Settings for the AIME evaluation environment.

    Extends ``BaseEnvConfig`` with dataset selection, generation parameters,
    prompt/thinking options, math-verification sizing and retry/debug knobs,
    and overrides several inherited defaults for an evaluation-only run.
    """

    # ------------------------------------------------------------------
    # Dataset selection
    # ------------------------------------------------------------------
    years: List[str] = Field(
        default=["2024", "2025"],
        description="List of AIME years to evaluate ('2024', '2025', or both)",
    )
    eval_split: str = Field(
        default="train",
        description="Split to evaluate on (AIME uses train split)",
    )
    shuffle_seed: int = Field(
        default=42,
        description="Random seed for shuffling",
    )

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------
    eval_temperature: float = Field(
        default=0.6,
        description="Temperature for evaluation generation",
    )
    eval_max_tokens: int = Field(
        default=0,
        description="Max tokens for evaluation (0 = use model default)",
    )

    # ------------------------------------------------------------------
    # System prompt
    # ------------------------------------------------------------------
    custom_system_prompt: Optional[str] = Field(
        default=None,
        description="Optional custom system prompt",
    )

    # ------------------------------------------------------------------
    # Thinking mode
    # ------------------------------------------------------------------
    thinking_mode: bool = Field(
        default=True,
        description="Whether to use thinking mode with <think></think> tags",
    )
    custom_thinking_prompt: Optional[str] = Field(
        default=None,
        description="Optional custom thinking prompt",
    )

    # ------------------------------------------------------------------
    # Math verification
    # ------------------------------------------------------------------
    max_math_workers: int = Field(
        default=64,
        description="Maximum workers for math verification ProcessPoolExecutor",
    )

    # ------------------------------------------------------------------
    # Retries and debugging
    # ------------------------------------------------------------------
    max_retries: int = Field(
        default=3,
        description="Maximum retries for failed API calls",
    )
    retry_delay: float = Field(
        default=1.0,
        description="Delay between retries in seconds",
    )
    min_response_length: int = Field(
        default=1,
        description="Minimum response length to consider valid",
    )
    full_debug: bool = Field(
        default=False,
        description="Enable full debug output",
    )

    # ------------------------------------------------------------------
    # Overrides of inherited defaults
    # ------------------------------------------------------------------
    group_size: int = 1
    max_num_workers: int = 1024
    max_eval_workers: int = 256
    max_num_workers_per_node: int = 128
    use_wandb: bool = True
    rollout_server_url: str = "http://localhost:8000"
    total_steps: int = 1
    wandb_name: str = "aime_eval"
    steps_per_eval: int = 1
|
|
|
|
|
|
class AIMEEvalEnv(BaseEnv):
    """
    AIME Evaluation Environment.

    Evaluates competition-level math problem solving using AIME problems.
    AIME answers are always integers from 0 to 999.
    Uses math_verify for answer verification with integer fallback.
    """

    name = "aime_eval"

    def __init__(
        self,
        config: AIMEEvalConfig,
        server_configs: List[APIServerConfig],
        slurm_job_id: Optional[str] = None,
        testing: bool = False,
    ):
        """Forward all arguments to ``BaseEnv`` and create empty eval state.

        The dataset is not loaded here; that happens in :meth:`setup`.
        """
        super().__init__(config, server_configs, slurm_job_id, testing)
        self.config: AIMEEvalConfig = config
        self.eval_items: List[Dict] = []
        self._dataset_loaded = False
        self._math_executor: Optional[ProcessPoolExecutor] = None

    @classmethod
    def config_cls(cls) -> type:
        """Return the configuration class used by this environment."""
        return AIMEEvalConfig

    async def setup(self) -> None:
        """Initialize the environment and load the dataset.

        Creates the math-verification executor, loads the AIME problems once
        (guarded by ``_dataset_loaded``) and prints a summary of the setup.
        """
        await super().setup()

        # Initialize math executor
        self._math_executor = get_math_executor(self.config.max_math_workers)

        if not self._dataset_loaded:
            await self._load_dataset()

        print("\nAIME Evaluation Setup (Generative Mode):")
        print(f" Years: {self.config.years}")
        print(f" Evaluation split: {self.config.eval_split}")
        print(f" Thinking mode: {self.config.thinking_mode}")
        if self.config.thinking_mode:
            thinking_prompt = get_default_thinking_prompt(
                self.config.custom_thinking_prompt
            )
            print(f" Thinking prompt: {thinking_prompt[:80]}...")
        print(f" Loaded {len(self.eval_items)} evaluation items")

    async def _load_dataset(self) -> None:
        """Load and process AIME datasets.

        For every configured year this loads the mapped dataset, picks the
        requested split (falling back to ``"train"`` or the first available
        split), and appends one dict per problem to ``self.eval_items`` with
        the keys ``id``, ``year``, ``problem``, ``answer`` and ``problem_idx``.
        Unknown years and datasets that fail to load are reported and skipped.
        Answers that are not integers in 0-999 only produce a warning; the
        item is still kept. Finally the items are shuffled deterministically
        using ``config.shuffle_seed``.
        """
        self.eval_items = []

        for year in self.config.years:
            if year not in AIME_DATASETS:
                print(
                    f"Warning: Unknown AIME year '{year}'. Available: {list(AIME_DATASETS.keys())}"
                )
                continue

            dataset_name = AIME_DATASETS[year]
            print(f"Loading AIME {year}: {dataset_name}...")

            try:
                # NOTE(review): trust_remote_code=True allows the dataset
                # repository to execute its own loading code locally. Left
                # unchanged here; confirm it is really required for these
                # datasets.
                dataset = load_dataset(dataset_name, trust_remote_code=True)
            except Exception as e:
                # Best-effort: a failing year must not abort the other years.
                print(f" Error loading AIME {year}: {e}")
                continue

            if self.config.eval_split not in dataset:
                available_splits = list(dataset.keys())
                print(
                    f" Split '{self.config.eval_split}' not found. Available: {available_splits}"
                )
                # AIME typically uses train split
                if "train" in available_splits:
                    split_key = "train"
                else:
                    split_key = available_splits[0]
                print(f" Using '{split_key}' instead")
            else:
                split_key = self.config.eval_split

            split_data = dataset[split_key]

            # Process items
            for idx, item in enumerate(split_data):
                problem = item.get("problem", "")
                answer = str(item.get("answer", "")).strip()

                # AIME answers should be integers 0-999
                try:
                    answer_int = int(answer)
                    if not (0 <= answer_int <= 999):
                        print(
                            f" Warning: Answer {answer_int} outside 0-999 range for item {idx}"
                        )
                except ValueError:
                    print(f" Warning: Non-integer answer '{answer}' for item {idx}")

                self.eval_items.append(
                    {
                        "id": f"aime{year}_{idx}",
                        "year": year,
                        "problem": problem,
                        "answer": answer,
                        "problem_idx": idx,
                    }
                )

            loaded_for_year = sum(
                1 for entry in self.eval_items if entry["year"] == year
            )
            print(f" Loaded {loaded_for_year} items from AIME {year}")

        # Shuffle with seed (optional for AIME since it's ordered by difficulty).
        # A private Random instance yields the same permutation as
        # random.seed(seed) + random.shuffle(...) but does not reseed the
        # process-wide generator that other code may be relying on.
        random.Random(self.config.shuffle_seed).shuffle(self.eval_items)

        self._dataset_loaded = True
        print(f"Total: Loaded {len(self.eval_items)} AIME problems")

    def _format_prompt(self, item: Dict) -> str:
        """Format the problem into a prompt.

        Substitutes ``item["problem"]`` into ``AIME_PROMPT_TEMPLATE``.
        """
        return AIME_PROMPT_TEMPLATE.format(problem=item["problem"])

    def _create_system_content(self) -> str:
        """Create system message content based on thinking mode.

        Delegates to ``create_system_content`` and maps a falsy result to an
        empty string so callers can test the return value for truthiness.
        """
        return (
            create_system_content(
                self.config.thinking_mode,
                self.config.custom_thinking_prompt,
                self.config.custom_system_prompt,
            )
            or ""
        )

    def _extract_integer_answer(self, text: str) -> Optional[int]:
        """
        Extract integer answer from text.

        AIME answers are always integers 0-999.
        Tries two strategies in order:

        1. Parse the whole (stripped, comma-free) text as a number; accept it
           when the truncated value lies in 0-999.
        2. Otherwise return the last standalone run of 1-3 digits in the text.

        Returns ``None`` for empty input or when neither strategy matches.
        """
        if not text:
            return None

        text = text.strip()

        # Try direct numeric parse. NOTE: int(float(...)) truncates, so a
        # value such as "42.7" is accepted as 42.
        try:
            val = int(float(text.replace(",", "")))
        except (ValueError, OverflowError):
            # ValueError: not a number (or NaN). OverflowError: float() parsed
            # an infinity (e.g. "inf", "1e999") which int() cannot convert;
            # previously this escaped and aborted the evaluation.
            val = None
        if val is not None and 0 <= val <= 999:
            return val

        # Look for standalone integers in the text. The pattern only matches
        # runs of 1-3 digits, so every match is already within 0-999 and the
        # last match is the answer.
        integers_found = re.findall(r"\b(\d{1,3})\b", text)
        if integers_found:
            return int(integers_found[-1])

        return None

    async def rollout_and_score_eval(
        self,
        item: Dict,
        server: APIServerConfig,
    ) -> Optional[Dict]:
        """Run evaluation on a single item and return the result.

        Args:
            item: One entry of ``self.eval_items`` (needs ``id``, ``year``,
                ``problem`` and ``answer``).
            server: Server configuration; only ``model_name`` is read here.
                The request itself goes through ``self.server``.

        Returns:
            A result dict for the item, or ``None`` when no non-empty
            response was obtained within ``config.max_retries`` attempts.
        """
        prompt = self._format_prompt(item)
        system_content = self._create_system_content()

        messages = []
        if system_content:
            messages.append({"role": "system", "content": system_content})
        messages.append({"role": "user", "content": prompt})

        # Build API call parameters
        kwargs = {
            "model": server.model_name,
            "messages": messages,
            "temperature": self.config.eval_temperature,
        }
        if self.config.eval_max_tokens > 0:
            kwargs["max_tokens"] = self.config.eval_max_tokens

        response_text = ""
        for attempt in range(self.config.max_retries):
            try:
                response = await self.server.chat_completion(**kwargs)
                response_text = response.choices[0].message.content or ""

                if len(response_text) >= self.config.min_response_length:
                    break

            except Exception as e:
                # Deliberately broad: any API failure is retried, and only
                # reported when full_debug is enabled.
                if self.config.full_debug:
                    print(f" API error (attempt {attempt + 1}): {e}")
                if attempt < self.config.max_retries - 1:
                    await asyncio.sleep(self.config.retry_delay)
                continue

        if not response_text:
            return None

        # Validate thinking format (the second element of the returned pair is
        # not needed here).
        is_valid_format, _ = validate_thinking_format(
            response_text, self.config.thinking_mode
        )

        # Extract thinking content if present
        thinking_content = (
            extract_thinking_content(response_text)
            if self.config.thinking_mode
            else None
        )

        # Score using math_verify with string fallback
        gold_answer = item["answer"]
        is_correct, method, has_multiple_boxed = await score_math_answer_async(
            gold=gold_answer,
            response=response_text,
            after_think=self.config.thinking_mode,
            wrap_gold_boxed=True,
            executor=self._math_executor,
            debug=self.config.full_debug,
        )

        # Extract the boxed answer for logging
        if self.config.thinking_mode:
            match = THINK_CONTENT_AFTER_PATTERN.search(response_text)
            score_content = match.group(1) if match else response_text
        else:
            score_content = response_text

        boxed_answers = extract_boxed_answers(score_content)
        extracted_answer = boxed_answers[0] if boxed_answers else None

        # Try integer extraction if boxed extraction worked
        extracted_int = None
        if extracted_answer:
            extracted_int = self._extract_integer_answer(extracted_answer)

        # If math_verify gave no verdict (None) but we have integer match,
        # count as correct
        if is_correct is None and extracted_int is not None:
            try:
                gold_int = int(gold_answer)
                if extracted_int == gold_int:
                    is_correct = True
                    method = "integer_fallback"
            except ValueError:
                # Gold answer is not an integer; the fallback does not apply.
                pass

        if self.config.full_debug:
            print(f"\n--- Item: {item['id']} ---")
            print(f"Year: {item['year']}, Problem #{item.get('problem_idx', 'N/A')}")
            print(f"Problem: {item['problem'][:100]}...")
            print(f"Gold answer: {gold_answer}")
            print(f"Extracted: {extracted_answer} -> {extracted_int}")
            print(f"Correct: {is_correct} (method: {method})")

        return {
            "item_id": item["id"],
            "year": item["year"],
            "problem_idx": item.get("problem_idx", -1),
            "problem": item["problem"][:200],
            "gold_answer": gold_answer,
            "extracted_answer": extracted_answer,
            "extracted_int": extracted_int,
            "verification_method": method,
            # A missing verdict is reported as incorrect.
            "is_correct": is_correct if is_correct is not None else False,
            "has_multiple_boxed": has_multiple_boxed,
            "format_valid": is_valid_format,
            "response": response_text,
            "thinking_content": thinking_content,
            "has_thinking": thinking_content is not None,
        }

    async def evaluate(self, *args, **kwargs) -> Dict:
        """Run the full AIME evaluation.

        Evaluates every item in ``self.eval_items`` concurrently against the
        first server config, aggregates overall and per-year accuracy plus
        format statistics, prints a report, and saves results when
        ``config.data_dir_to_save_evals`` is set.

        Items whose rollout returned ``None`` are excluded from all counts,
        including the accuracy denominator.

        Returns:
            The metrics dict, or ``{"error": ..., "accuracy": 0.0}`` when no
            item produced a result.
        """
        print(f"\n{'='*60}")
        print("Starting AIME Evaluation (Generative Mode)")
        print(f"{'='*60}")
        print(f" Years: {self.config.years}")
        print(f" Total problems: {len(self.eval_items)}")
        print(f" Thinking mode: {self.config.thinking_mode}")
        print(f"{'='*60}\n")

        # Create evaluation tasks
        async def eval_task(item):
            return await self.rollout_and_score_eval(item, self.server_configs[0])

        tasks = [eval_task(item) for item in self.eval_items]

        # Run with progress bar
        results = await tqdm_asyncio.gather(*tasks, desc="Evaluating AIME")

        # Filter out failed results
        valid_results = [r for r in results if r is not None]

        if not valid_results:
            print("Warning: No valid evaluation results obtained")
            return {"error": "No valid results", "accuracy": 0.0}

        # Calculate overall metrics
        total = len(valid_results)
        correct = sum(1 for r in valid_results if r["is_correct"])
        accuracy = correct / total if total > 0 else 0.0

        # Calculate per-year metrics
        year_metrics = {}
        for r in valid_results:
            year = r.get("year", "unknown")
            if year not in year_metrics:
                year_metrics[year] = {"total": 0, "correct": 0}
            year_metrics[year]["total"] += 1
            if r["is_correct"]:
                year_metrics[year]["correct"] += 1

        for year in year_metrics:
            y_total = year_metrics[year]["total"]
            y_correct = year_metrics[year]["correct"]
            year_metrics[year]["accuracy"] = y_correct / y_total if y_total > 0 else 0.0

        # Count verification methods and other stats
        method_counts = {}
        for r in valid_results:
            method = r.get("verification_method", "unknown")
            method_counts[method] = method_counts.get(method, 0) + 1

        multiple_boxed = sum(
            1 for r in valid_results if r.get("has_multiple_boxed", False)
        )
        format_valid = sum(1 for r in valid_results if r.get("format_valid", True))
        has_thinking = sum(1 for r in valid_results if r.get("has_thinking", False))
        has_boxed = sum(
            1 for r in valid_results if r.get("extracted_answer") is not None
        )

        metrics = {
            "accuracy": accuracy,
            "total_evaluated": total,
            "total_correct": correct,
            "num_years": len(year_metrics),
            "has_boxed_rate": has_boxed / total if total > 0 else 0.0,
            "multiple_boxed_rate": multiple_boxed / total if total > 0 else 0.0,
            "format_compliance_rate": format_valid / total if total > 0 else 0.0,
            "thinking_utilization_rate": has_thinking / total if total > 0 else 0.0,
            "year_metrics": year_metrics,
            "verification_methods": method_counts,
        }

        print(f"\n{'='*60}")
        print("AIME Evaluation Results")
        print(f"{'='*60}")
        print(f" Overall Accuracy: {accuracy:.2%} ({correct}/{total})")
        print(f" Has \\boxed{{}} Rate: {has_boxed / total:.2%}")
        print(f" Format Compliance: {format_valid / total:.2%}")
        if self.config.thinking_mode:
            print(f" Thinking Utilization: {has_thinking / total:.2%}")
        print("\n Per-Year Breakdown:")
        for year, data in sorted(year_metrics.items()):
            print(
                f" AIME {year}: {data['accuracy']:.2%} ({data['correct']}/{data['total']})"
            )
        print("\n Verification Methods:")
        # Most frequent method first.
        for method, count in sorted(method_counts.items(), key=lambda x: -x[1]):
            print(f" {method}: {count} ({count/total:.1%})")
        print(f"{'='*60}\n")

        # Save results
        if self.config.data_dir_to_save_evals:
            self._save_results(metrics, valid_results)

        return metrics

    def _save_results(self, metrics: Dict, results: List[Dict]) -> None:
        """Save evaluation results to disk.

        Passes the metrics and per-item results to ``save_eval_results`` with
        ``config.data_dir_to_save_evals`` as the target directory.
        """
        save_eval_results(self.config.data_dir_to_save_evals, metrics, results)

    async def wandb_log(self, metrics: Dict, step: int = 0) -> None:
        """Log metrics to Weights & Biases.

        Does nothing when ``config.use_wandb`` is false. Logs the overall
        rates under the ``aime/`` prefix plus one ``aime/accuracy_<year>``
        entry per year found in ``metrics["year_metrics"]``.
        """
        if not self.config.use_wandb:
            return

        log_dict = {
            "aime/accuracy": metrics.get("accuracy", 0),
            "aime/total_evaluated": metrics.get("total_evaluated", 0),
            "aime/has_boxed_rate": metrics.get("has_boxed_rate", 0),
            "aime/format_compliance_rate": metrics.get("format_compliance_rate", 0),
            "aime/thinking_utilization_rate": metrics.get(
                "thinking_utilization_rate", 0
            ),
        }

        # Log per-year accuracies
        for year, data in metrics.get("year_metrics", {}).items():
            log_dict[f"aime/accuracy_{year}"] = data.get("accuracy", 0)

        wandb.log(log_dict, step=step)

    # Required abstract method implementations
    async def get_next_item(self) -> Optional[Dict]:
        """Not used in evaluation mode."""
        return None

    async def collect_trajectories(self, item) -> List:
        """Not used in evaluation mode."""
        return []

    async def score(self, rollout_group_data) -> Optional[List]:
        """Not used in evaluation mode."""
        return None
|
|
|
|
|
|
# Script entry point: when executed directly, hand control to the
# environment's ``cli`` classmethod (not defined in this file; presumably
# provided by the BaseEnv parent class).
if __name__ == "__main__":
    AIMEEvalEnv.cli()
|