mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
ef9c0c3699
commit
afab28dfa9
37 changed files with 4868 additions and 4052 deletions
|
|
@ -27,6 +27,18 @@ from typing import Dict, List, Optional, Tuple
|
|||
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from eval_helpers import (
|
||||
THINK_CONTENT_AFTER_PATTERN,
|
||||
create_system_content,
|
||||
extract_boxed_answers,
|
||||
extract_thinking_content,
|
||||
format_math_answer_instruction,
|
||||
get_default_thinking_prompt,
|
||||
get_math_executor,
|
||||
save_eval_results,
|
||||
score_math_answer_async,
|
||||
validate_thinking_format,
|
||||
)
|
||||
from pydantic import Field
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
|
|
@ -36,19 +48,6 @@ from atroposlib.envs.base import (
|
|||
BaseEnvConfig,
|
||||
EvalHandlingEnum,
|
||||
)
|
||||
from eval_helpers import (
|
||||
validate_thinking_format,
|
||||
extract_thinking_content,
|
||||
get_default_thinking_prompt,
|
||||
create_system_content,
|
||||
save_eval_results,
|
||||
score_math_answer_async,
|
||||
get_math_executor,
|
||||
extract_boxed_answers,
|
||||
format_math_answer_instruction,
|
||||
THINK_CONTENT_AFTER_PATTERN,
|
||||
)
|
||||
|
||||
|
||||
# Prompt template following lighteval's MATH-500 structure
|
||||
MATH500_PROMPT_TEMPLATE = """Solve the following problem. The final line of your response MUST be of the following format:
|
||||
|
|
@ -67,83 +66,63 @@ MATH500_BOXED_PROMPT_TEMPLATE = """Solve the following math problem. {answer_ins
|
|||
|
||||
class MATH500EvalConfig(BaseEnvConfig):
|
||||
"""Configuration for MATH-500 evaluation environment."""
|
||||
|
||||
|
||||
# Dataset configuration
|
||||
dataset_name: str = Field(
|
||||
default="HuggingFaceH4/MATH-500",
|
||||
description="HuggingFace dataset name"
|
||||
default="HuggingFaceH4/MATH-500", description="HuggingFace dataset name"
|
||||
)
|
||||
subset: str = Field(
|
||||
default="default",
|
||||
description="Dataset subset"
|
||||
)
|
||||
eval_split: str = Field(
|
||||
default="test",
|
||||
description="Split to evaluate on"
|
||||
)
|
||||
shuffle_seed: int = Field(
|
||||
default=42,
|
||||
description="Random seed for shuffling"
|
||||
)
|
||||
|
||||
subset: str = Field(default="default", description="Dataset subset")
|
||||
eval_split: str = Field(default="test", description="Split to evaluate on")
|
||||
shuffle_seed: int = Field(default=42, description="Random seed for shuffling")
|
||||
|
||||
# Generation parameters
|
||||
eval_temperature: float = Field(
|
||||
default=0.6,
|
||||
description="Temperature for evaluation generation"
|
||||
default=0.6, description="Temperature for evaluation generation"
|
||||
)
|
||||
eval_max_tokens: int = Field(
|
||||
default=0,
|
||||
description="Max tokens for evaluation (0 = use model default)"
|
||||
default=0, description="Max tokens for evaluation (0 = use model default)"
|
||||
)
|
||||
|
||||
|
||||
# System prompt configuration
|
||||
custom_system_prompt: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional custom system prompt"
|
||||
default=None, description="Optional custom system prompt"
|
||||
)
|
||||
|
||||
|
||||
# Thinking mode configuration
|
||||
thinking_mode: bool = Field(
|
||||
default=True,
|
||||
description="Whether to use thinking mode with <think></think> tags"
|
||||
description="Whether to use thinking mode with <think></think> tags",
|
||||
)
|
||||
custom_thinking_prompt: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional custom thinking prompt"
|
||||
default=None, description="Optional custom thinking prompt"
|
||||
)
|
||||
|
||||
|
||||
# Math verification configuration
|
||||
include_hope_suffix: bool = Field(
|
||||
default=True,
|
||||
description="Whether to include 'I hope it is correct' in answer instruction"
|
||||
description="Whether to include 'I hope it is correct' in answer instruction",
|
||||
)
|
||||
use_original_prompt: bool = Field(
|
||||
default=False,
|
||||
description="Use lighteval's original prompt format (ANSWER: format)"
|
||||
description="Use lighteval's original prompt format (ANSWER: format)",
|
||||
)
|
||||
max_math_workers: int = Field(
|
||||
default=64,
|
||||
description="Maximum workers for math verification ProcessPoolExecutor"
|
||||
description="Maximum workers for math verification ProcessPoolExecutor",
|
||||
)
|
||||
|
||||
|
||||
# Retry and debug configuration
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Maximum retries for failed API calls"
|
||||
default=3, description="Maximum retries for failed API calls"
|
||||
)
|
||||
retry_delay: float = Field(
|
||||
default=1.0,
|
||||
description="Delay between retries in seconds"
|
||||
default=1.0, description="Delay between retries in seconds"
|
||||
)
|
||||
min_response_length: int = Field(
|
||||
default=1,
|
||||
description="Minimum response length to consider valid"
|
||||
default=1, description="Minimum response length to consider valid"
|
||||
)
|
||||
full_debug: bool = Field(
|
||||
default=False,
|
||||
description="Enable full debug output"
|
||||
)
|
||||
|
||||
full_debug: bool = Field(default=False, description="Enable full debug output")
|
||||
|
||||
# Override defaults
|
||||
group_size: int = 1
|
||||
max_num_workers: int = 1024
|
||||
|
|
@ -159,7 +138,7 @@ class MATH500EvalConfig(BaseEnvConfig):
|
|||
class MATH500EvalEnv(BaseEnv):
|
||||
"""
|
||||
MATH-500 Evaluation Environment.
|
||||
|
||||
|
||||
Evaluates challenging math problem solving using the MATH-500 dataset.
|
||||
Uses math_verify for robust answer verification with string fallback.
|
||||
"""
|
||||
|
|
@ -186,51 +165,53 @@ class MATH500EvalEnv(BaseEnv):
|
|||
async def setup(self) -> None:
|
||||
"""Initialize the environment and load the dataset."""
|
||||
await super().setup()
|
||||
|
||||
|
||||
# Initialize math executor
|
||||
self._math_executor = get_math_executor(self.config.max_math_workers)
|
||||
|
||||
|
||||
if not self._dataset_loaded:
|
||||
await self._load_dataset()
|
||||
|
||||
|
||||
print(f"\nMATH-500 Evaluation Setup (Generative Mode):")
|
||||
print(f" Dataset: {self.config.dataset_name}")
|
||||
print(f" Evaluation split: {self.config.eval_split}")
|
||||
print(f" Thinking mode: {self.config.thinking_mode}")
|
||||
if self.config.thinking_mode:
|
||||
thinking_prompt = get_default_thinking_prompt(self.config.custom_thinking_prompt)
|
||||
thinking_prompt = get_default_thinking_prompt(
|
||||
self.config.custom_thinking_prompt
|
||||
)
|
||||
print(f" Thinking prompt: {thinking_prompt[:80]}...")
|
||||
print(f" Loaded {len(self.eval_items)} evaluation items")
|
||||
|
||||
async def _load_dataset(self) -> None:
|
||||
"""Load and process the MATH-500 dataset."""
|
||||
print(f"Loading MATH-500 dataset: {self.config.dataset_name}...")
|
||||
|
||||
|
||||
try:
|
||||
dataset = load_dataset(
|
||||
self.config.dataset_name,
|
||||
self.config.subset if self.config.subset != "default" else None,
|
||||
trust_remote_code=True
|
||||
trust_remote_code=True,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error loading dataset: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if self.config.eval_split not in dataset:
|
||||
available_splits = list(dataset.keys())
|
||||
raise ValueError(
|
||||
f"Split '{self.config.eval_split}' not found. Available: {available_splits}"
|
||||
)
|
||||
|
||||
|
||||
split_data = dataset[self.config.eval_split]
|
||||
|
||||
|
||||
# Process items
|
||||
self.eval_items = []
|
||||
for idx, item in enumerate(split_data):
|
||||
problem = item.get("problem", "")
|
||||
answer = item.get("answer", "") # MATH-500 has 'answer' field
|
||||
solution = item.get("solution", "") # May also have solution
|
||||
|
||||
|
||||
# Extract final answer
|
||||
if answer:
|
||||
final_answer = answer.strip()
|
||||
|
|
@ -240,24 +221,26 @@ class MATH500EvalEnv(BaseEnv):
|
|||
final_answer = boxed[-1] if boxed else solution.strip()
|
||||
else:
|
||||
final_answer = ""
|
||||
|
||||
|
||||
subject = item.get("subject", "unknown")
|
||||
level = item.get("level", "")
|
||||
unique_id = item.get("unique_id", str(idx))
|
||||
|
||||
self.eval_items.append({
|
||||
"id": unique_id,
|
||||
"problem": problem,
|
||||
"answer": final_answer,
|
||||
"solution": solution,
|
||||
"subject": subject,
|
||||
"level": level,
|
||||
})
|
||||
|
||||
|
||||
self.eval_items.append(
|
||||
{
|
||||
"id": unique_id,
|
||||
"problem": problem,
|
||||
"answer": final_answer,
|
||||
"solution": solution,
|
||||
"subject": subject,
|
||||
"level": level,
|
||||
}
|
||||
)
|
||||
|
||||
# Shuffle with seed
|
||||
random.seed(self.config.shuffle_seed)
|
||||
random.shuffle(self.eval_items)
|
||||
|
||||
|
||||
self._dataset_loaded = True
|
||||
print(f"Loaded {len(self.eval_items)} evaluation items")
|
||||
|
||||
|
|
@ -270,17 +253,19 @@ class MATH500EvalEnv(BaseEnv):
|
|||
include_hope=self.config.include_hope_suffix
|
||||
)
|
||||
return MATH500_BOXED_PROMPT_TEMPLATE.format(
|
||||
answer_instruction=answer_instruction,
|
||||
problem=item["problem"]
|
||||
answer_instruction=answer_instruction, problem=item["problem"]
|
||||
)
|
||||
|
||||
def _create_system_content(self) -> str:
|
||||
"""Create system message content based on thinking mode."""
|
||||
return create_system_content(
|
||||
self.config.thinking_mode,
|
||||
self.config.custom_thinking_prompt,
|
||||
self.config.custom_system_prompt
|
||||
) or ""
|
||||
return (
|
||||
create_system_content(
|
||||
self.config.thinking_mode,
|
||||
self.config.custom_thinking_prompt,
|
||||
self.config.custom_system_prompt,
|
||||
)
|
||||
or ""
|
||||
)
|
||||
|
||||
async def rollout_and_score_eval(
|
||||
self,
|
||||
|
|
@ -290,12 +275,12 @@ class MATH500EvalEnv(BaseEnv):
|
|||
"""Run evaluation on a single item and return the result."""
|
||||
prompt = self._format_prompt(item)
|
||||
system_content = self._create_system_content()
|
||||
|
||||
|
||||
messages = []
|
||||
if system_content:
|
||||
messages.append({"role": "system", "content": system_content})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
|
||||
# Build API call parameters
|
||||
kwargs = {
|
||||
"model": server.model_name,
|
||||
|
|
@ -304,35 +289,38 @@ class MATH500EvalEnv(BaseEnv):
|
|||
}
|
||||
if self.config.eval_max_tokens > 0:
|
||||
kwargs["max_tokens"] = self.config.eval_max_tokens
|
||||
|
||||
|
||||
response_text = ""
|
||||
for attempt in range(self.config.max_retries):
|
||||
try:
|
||||
response = await self.server.chat_completion(**kwargs)
|
||||
response_text = response.choices[0].message.content or ""
|
||||
|
||||
|
||||
if len(response_text) >= self.config.min_response_length:
|
||||
break
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if self.config.full_debug:
|
||||
print(f" API error (attempt {attempt + 1}): {e}")
|
||||
if attempt < self.config.max_retries - 1:
|
||||
await asyncio.sleep(self.config.retry_delay)
|
||||
continue
|
||||
|
||||
|
||||
if not response_text:
|
||||
return None
|
||||
|
||||
|
||||
# Validate thinking format
|
||||
is_valid_format, content_for_extraction = validate_thinking_format(
|
||||
response_text,
|
||||
self.config.thinking_mode
|
||||
response_text, self.config.thinking_mode
|
||||
)
|
||||
|
||||
|
||||
# Extract thinking content if present
|
||||
thinking_content = extract_thinking_content(response_text) if self.config.thinking_mode else None
|
||||
|
||||
thinking_content = (
|
||||
extract_thinking_content(response_text)
|
||||
if self.config.thinking_mode
|
||||
else None
|
||||
)
|
||||
|
||||
# Score using math_verify with string fallback
|
||||
gold_answer = item["answer"]
|
||||
is_correct, method, has_multiple_boxed = await score_math_answer_async(
|
||||
|
|
@ -341,27 +329,29 @@ class MATH500EvalEnv(BaseEnv):
|
|||
after_think=self.config.thinking_mode,
|
||||
wrap_gold_boxed=True,
|
||||
executor=self._math_executor,
|
||||
debug=self.config.full_debug
|
||||
debug=self.config.full_debug,
|
||||
)
|
||||
|
||||
|
||||
# Extract the boxed answer for logging
|
||||
if self.config.thinking_mode:
|
||||
match = THINK_CONTENT_AFTER_PATTERN.search(response_text)
|
||||
score_content = match.group(1) if match else response_text
|
||||
else:
|
||||
score_content = response_text
|
||||
|
||||
|
||||
boxed_answers = extract_boxed_answers(score_content)
|
||||
extracted_answer = boxed_answers[0] if boxed_answers else None
|
||||
|
||||
|
||||
if self.config.full_debug:
|
||||
print(f"\n--- Item: {item['id']} ---")
|
||||
print(f"Subject: {item.get('subject', 'N/A')}, Level: {item.get('level', 'N/A')}")
|
||||
print(
|
||||
f"Subject: {item.get('subject', 'N/A')}, Level: {item.get('level', 'N/A')}"
|
||||
)
|
||||
print(f"Problem: {item['problem'][:100]}...")
|
||||
print(f"Gold answer: {gold_answer}")
|
||||
print(f"Extracted: {extracted_answer}")
|
||||
print(f"Correct: {is_correct} (method: {method})")
|
||||
|
||||
|
||||
return {
|
||||
"item_id": item["id"],
|
||||
"subject": item.get("subject", "unknown"),
|
||||
|
|
@ -386,31 +376,28 @@ class MATH500EvalEnv(BaseEnv):
|
|||
print(f" Total questions: {len(self.eval_items)}")
|
||||
print(f" Thinking mode: {self.config.thinking_mode}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
# Create evaluation tasks
|
||||
async def eval_task(item):
|
||||
return await self.rollout_and_score_eval(item, self.server_configs[0])
|
||||
|
||||
|
||||
tasks = [eval_task(item) for item in self.eval_items]
|
||||
|
||||
|
||||
# Run with progress bar
|
||||
results = await tqdm_asyncio.gather(
|
||||
*tasks,
|
||||
desc="Evaluating MATH-500"
|
||||
)
|
||||
|
||||
results = await tqdm_asyncio.gather(*tasks, desc="Evaluating MATH-500")
|
||||
|
||||
# Filter out failed results
|
||||
valid_results = [r for r in results if r is not None]
|
||||
|
||||
|
||||
if not valid_results:
|
||||
print("Warning: No valid evaluation results obtained")
|
||||
return {"error": "No valid results", "accuracy": 0.0}
|
||||
|
||||
|
||||
# Calculate overall metrics
|
||||
total = len(valid_results)
|
||||
correct = sum(1 for r in valid_results if r["is_correct"])
|
||||
accuracy = correct / total if total > 0 else 0.0
|
||||
|
||||
|
||||
# Calculate per-subject metrics
|
||||
subject_metrics = {}
|
||||
for r in valid_results:
|
||||
|
|
@ -420,12 +407,14 @@ class MATH500EvalEnv(BaseEnv):
|
|||
subject_metrics[subject]["total"] += 1
|
||||
if r["is_correct"]:
|
||||
subject_metrics[subject]["correct"] += 1
|
||||
|
||||
|
||||
for subject in subject_metrics:
|
||||
s_total = subject_metrics[subject]["total"]
|
||||
s_correct = subject_metrics[subject]["correct"]
|
||||
subject_metrics[subject]["accuracy"] = s_correct / s_total if s_total > 0 else 0.0
|
||||
|
||||
subject_metrics[subject]["accuracy"] = (
|
||||
s_correct / s_total if s_total > 0 else 0.0
|
||||
)
|
||||
|
||||
# Calculate per-level metrics
|
||||
level_metrics = {}
|
||||
for r in valid_results:
|
||||
|
|
@ -435,23 +424,29 @@ class MATH500EvalEnv(BaseEnv):
|
|||
level_metrics[level]["total"] += 1
|
||||
if r["is_correct"]:
|
||||
level_metrics[level]["correct"] += 1
|
||||
|
||||
|
||||
for level in level_metrics:
|
||||
l_total = level_metrics[level]["total"]
|
||||
l_correct = level_metrics[level]["correct"]
|
||||
level_metrics[level]["accuracy"] = l_correct / l_total if l_total > 0 else 0.0
|
||||
|
||||
level_metrics[level]["accuracy"] = (
|
||||
l_correct / l_total if l_total > 0 else 0.0
|
||||
)
|
||||
|
||||
# Count verification methods and other stats
|
||||
method_counts = {}
|
||||
for r in valid_results:
|
||||
method = r.get("verification_method", "unknown")
|
||||
method_counts[method] = method_counts.get(method, 0) + 1
|
||||
|
||||
multiple_boxed = sum(1 for r in valid_results if r.get("has_multiple_boxed", False))
|
||||
|
||||
multiple_boxed = sum(
|
||||
1 for r in valid_results if r.get("has_multiple_boxed", False)
|
||||
)
|
||||
format_valid = sum(1 for r in valid_results if r.get("format_valid", True))
|
||||
has_thinking = sum(1 for r in valid_results if r.get("has_thinking", False))
|
||||
has_boxed = sum(1 for r in valid_results if r.get("extracted_answer") is not None)
|
||||
|
||||
has_boxed = sum(
|
||||
1 for r in valid_results if r.get("extracted_answer") is not None
|
||||
)
|
||||
|
||||
metrics = {
|
||||
"accuracy": accuracy,
|
||||
"total_evaluated": total,
|
||||
|
|
@ -465,7 +460,7 @@ class MATH500EvalEnv(BaseEnv):
|
|||
"level_metrics": level_metrics,
|
||||
"verification_methods": method_counts,
|
||||
}
|
||||
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("MATH-500 Evaluation Results")
|
||||
print(f"{'='*60}")
|
||||
|
|
@ -476,18 +471,24 @@ class MATH500EvalEnv(BaseEnv):
|
|||
print(f" Thinking Utilization: {has_thinking / total:.2%}")
|
||||
if subject_metrics and len(subject_metrics) > 1:
|
||||
print(f"\n Per-Subject Breakdown:")
|
||||
for subject, data in sorted(subject_metrics.items(), key=lambda x: -x[1]["accuracy"]):
|
||||
print(f" {subject}: {data['accuracy']:.2%} ({data['correct']}/{data['total']})")
|
||||
for subject, data in sorted(
|
||||
subject_metrics.items(), key=lambda x: -x[1]["accuracy"]
|
||||
):
|
||||
print(
|
||||
f" {subject}: {data['accuracy']:.2%} ({data['correct']}/{data['total']})"
|
||||
)
|
||||
if level_metrics and len(level_metrics) > 1:
|
||||
print(f"\n Per-Level Breakdown:")
|
||||
for level, data in sorted(level_metrics.items()):
|
||||
print(f" Level {level}: {data['accuracy']:.2%} ({data['correct']}/{data['total']})")
|
||||
print(
|
||||
f" Level {level}: {data['accuracy']:.2%} ({data['correct']}/{data['total']})"
|
||||
)
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
# Save results
|
||||
if self.config.data_dir_to_save_evals:
|
||||
self._save_results(metrics, valid_results)
|
||||
|
||||
|
||||
return metrics
|
||||
|
||||
def _save_results(self, metrics: Dict, results: List[Dict]) -> None:
|
||||
|
|
@ -498,20 +499,22 @@ class MATH500EvalEnv(BaseEnv):
|
|||
"""Log metrics to Weights & Biases."""
|
||||
if not self.config.use_wandb:
|
||||
return
|
||||
|
||||
|
||||
log_dict = {
|
||||
"math500/accuracy": metrics.get("accuracy", 0),
|
||||
"math500/total_evaluated": metrics.get("total_evaluated", 0),
|
||||
"math500/has_boxed_rate": metrics.get("has_boxed_rate", 0),
|
||||
"math500/format_compliance_rate": metrics.get("format_compliance_rate", 0),
|
||||
"math500/thinking_utilization_rate": metrics.get("thinking_utilization_rate", 0),
|
||||
"math500/thinking_utilization_rate": metrics.get(
|
||||
"thinking_utilization_rate", 0
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# Log per-subject accuracies
|
||||
for subject, data in metrics.get("subject_metrics", {}).items():
|
||||
safe_name = subject.replace(" ", "_")[:30]
|
||||
log_dict[f"math500/accuracy_{safe_name}"] = data.get("accuracy", 0)
|
||||
|
||||
|
||||
wandb.log(log_dict, step=step)
|
||||
|
||||
# Required abstract method implementations
|
||||
|
|
@ -530,4 +533,3 @@ class MATH500EvalEnv(BaseEnv):
|
|||
|
||||
if __name__ == "__main__":
|
||||
MATH500EvalEnv.cli()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue