mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
1314 lines
42 KiB
Python
1314 lines
42 KiB
Python
"""
|
|
Shared helper functions for evaluation environments.
|
|
|
|
This module contains common utilities used across multiple eval environments,
|
|
making it easier to maintain consistent behavior and update logic in one place.
|
|
|
|
Includes:
|
|
- MCQA answer extraction (letter-based)
|
|
- Numbered choice extraction
|
|
- Freeform answer extraction
|
|
- Thinking mode validation
|
|
- Math answer verification (using math_verify library)
|
|
- System prompt creation
|
|
- Results saving utilities
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import re
|
|
from concurrent.futures import ProcessPoolExecutor
|
|
from string import ascii_uppercase
|
|
from typing import Dict, List, Optional, Set, Tuple
|
|
|
|
# Try to import math_verify libraries (optional dependency for math evals).
# When unavailable, the names are stubbed out so this module still imports;
# math scoring then degrades to normalized string comparison
# (see _score_math_answer_worker).
try:
    from latex2sympy2_extended import NormalizationConfig
    from math_verify import LatexExtractionConfig, parse, verify
    from math_verify.errors import TimeoutException

    MATH_VERIFY_AVAILABLE = True
except ImportError:
    MATH_VERIFY_AVAILABLE = False
    NormalizationConfig = None
    LatexExtractionConfig = None
    parse = None
    verify = None
    # Alias to Exception so `except TimeoutException:` clauses stay valid.
    TimeoutException = Exception
|
|
|
|
|
|
# Pre-compiled regex for answer tag extraction: captures the first
# <answer>...</answer> span (DOTALL so answers may span multiple lines).
ANSWER_TAG_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)

# Pre-compiled regexes for thinking mode (<think>...</think> wrappers).
THINK_OPEN_PATTERN = re.compile(r"<think>", re.IGNORECASE)  # opening tag
THINK_CLOSE_PATTERN = re.compile(r"</think>", re.IGNORECASE)  # closing tag
# Everything after the closing tag (the answer portion of a response).
THINK_CONTENT_AFTER_PATTERN = re.compile(r"</think>\s*(.*)", re.DOTALL | re.IGNORECASE)
# Everything inside the think tags (the reasoning portion).
THINK_CONTENT_INSIDE_PATTERN = re.compile(
    r"<think>(.*?)</think>", re.DOTALL | re.IGNORECASE
)
|
|
|
|
|
|
# Common prefixes that models use before stating their answer.
# These are stripped (via ANSWER_PREFIX_PATTERN) to help isolate the actual
# answer inside <answer> tags. Each entry is a regex fragment; an optional
# trailing colon and whitespace are consumed as part of the prefix.
ANSWER_PREFIXES = [
    # "Final Answer" variants
    r"(?:the\s+)?final\s+answer\s+is\s*:?\s*",
    r"(?:my\s+)?final\s+answer\s*:?\s*",
    # "Answer" variants
    r"(?:the\s+)?answer\s+is\s*:?\s*",
    r"(?:the\s+)?correct\s+answer\s+is\s*:?\s*",
    r"(?:my\s+)?answer\s*:?\s*",
    r"answer\s*:\s*",
    # "Choice/Option" variants
    r"(?:the\s+)?(?:correct\s+)?(?:choice|option)\s+is\s*:?\s*",
    r"(?:i\s+)?(?:choose|select|pick)\s*:?\s*",
    # "It is/It's" variants
    r"(?:it\s+is|it's)\s*:?\s*",
    r"(?:that\s+would\s+be|that's)\s*:?\s*",
    # "I think/believe" variants
    r"(?:i\s+)?(?:think|believe|would\s+say)\s+(?:it\s+is|it's|the\s+answer\s+is)\s*:?\s*",
]

# Compile the prefix fragments into one anchored, case-insensitive pattern
# that matches any single prefix at the start of the tag content.
ANSWER_PREFIX_PATTERN = re.compile(
    r"^(?:" + "|".join(ANSWER_PREFIXES) + r")", re.IGNORECASE
)
|
|
|
|
|
|
def extract_letter_from_answer_tag(
    response: str,
    valid_letters: Set[str],
    debug: bool = False,
    choices: Optional[List[str]] = None,
) -> Tuple[Optional[str], str]:
    """
    Extract a single letter answer from <answer></answer> tags.

    Uses multiple strategies, in order:
    1. Strict match: the tag content (after stripping common prefixes like
       "Final Answer:") is a bare letter, possibly wrapped in punctuation
    2. Exact choice text matching (if choices provided)
    3. Word boundary scan: accept only when EXACTLY ONE valid letter appears
       (first in the prefix-stripped content, then in the original content)

    Handles many common variations:
    - <answer>A</answer> -> A (single letter)
    - <answer>A.</answer> -> A (letter with punctuation)
    - <answer>(A)</answer> -> A (letter in parentheses)
    - <answer>Final Answer: B</answer> -> B (common prefix stripped)
    - <answer>The answer is C</answer> -> C (common prefix stripped)
    - <answer>Tom Holland</answer> -> A (if "Tom Holland" is choice A)
    - <answer>A) Tom Holland</answer> -> A (choice text with letter prefix)
    - <answer>A - Tom Holland</answer> -> A (choice text with letter prefix)

    Rejects ambiguous cases like:
    - <answer>A is better than B</answer> -> None (multiple valid letters)
    - <answer>Between A and C</answer> -> None (multiple valid letters)

    Args:
        response: The model's response (content after </think> in thinking mode)
        valid_letters: Set of valid answer letters (e.g., {'A', 'B', 'C', 'D'})
        debug: Whether to print debug information
        choices: Optional list of choice texts in order (A, B, C, D...)

    Returns:
        Tuple of (extracted_letter or None, extraction_method)
    """
    # Find the first <answer></answer> tag
    answer_tag_match = ANSWER_TAG_PATTERN.search(response)
    if not answer_tag_match:
        return None, "no_answer_tag"

    answer_content = answer_tag_match.group(1).strip()
    if not answer_content:
        return None, "empty_answer_tag"

    # Try stripping common prefixes to isolate the answer
    content_to_check = answer_content
    prefix_match = ANSWER_PREFIX_PATTERN.match(content_to_check)
    if prefix_match:
        content_to_check = content_to_check[prefix_match.end() :].strip()  # noqa: E203
        if debug:
            print(f"  Stripped prefix, remaining content: '{content_to_check}'")

    # STRICT CHECK: after stripping, just a letter (with optional punctuation).
    # This is the ideal case: <answer>B</answer> or <answer>Final Answer: B</answer>
    cleaned = content_to_check.strip(".,)(:;!? \t\n")
    if cleaned.upper() in valid_letters:
        if debug:
            print(
                f"  Extracted '{cleaned.upper()}' using method 'answer_tag' (strict letter match)"
            )
        return cleaned.upper(), "answer_tag"

    # CHOICE TEXT MATCHING: the tag may contain the choice text itself
    if choices:
        letter_from_choice = _match_choice_text(
            answer_content, choices, valid_letters, debug
        )
        if letter_from_choice:
            return letter_from_choice, "answer_tag_choice_match"

    # WORD BOUNDARY CHECK: accept only an unambiguous single letter, first in
    # the prefix-stripped content, then (fallback) in the original content.
    # NOTE: this fallback used to be duplicated inline in two branches.
    letters_pattern = "|".join(sorted(valid_letters))
    word_bounded_pattern = re.compile(rf"\b({letters_pattern})\b", re.IGNORECASE)

    found_stripped = word_bounded_pattern.findall(content_to_check)
    if len(found_stripped) == 1:
        letter = found_stripped[0].upper()
        if debug:
            print(
                f"  Extracted '{letter}' using method 'answer_tag' (single word-bounded letter)"
            )
        return letter, "answer_tag"

    # Zero or multiple hits after prefix stripping: retry on the raw content.
    found_original = word_bounded_pattern.findall(answer_content)
    if len(found_original) == 1:
        letter = found_original[0].upper()
        if debug:
            print(
                f"  Extracted '{letter}' using method 'answer_tag' (from original content)"
            )
        return letter, "answer_tag"

    if found_stripped:
        # More than one valid letter both before and after stripping.
        if debug:
            print(
                f"  Multiple letters found in answer tag: {found_stripped} - rejecting"
            )
        return None, "answer_tag_ambiguous"

    if debug:
        print(f"  No valid letters found in answer tag content: '{answer_content}'")
    return None, "answer_tag_no_letter"
|
|
|
|
|
|
def _match_choice_text(
|
|
answer_content: str,
|
|
choices: List[str],
|
|
valid_letters: Set[str],
|
|
debug: bool = False,
|
|
) -> Optional[str]:
|
|
"""
|
|
Match answer content against choice texts.
|
|
|
|
Handles formats like:
|
|
- "Tom Holland" (exact choice text)
|
|
- "A) Tom Holland" (letter prefix + choice text)
|
|
- "A - Tom Holland" (letter prefix + choice text)
|
|
- "A. Tom Holland" (letter prefix + choice text)
|
|
|
|
Args:
|
|
answer_content: The content inside <answer> tags
|
|
choices: List of choice texts in order (index 0 = A, 1 = B, etc.)
|
|
valid_letters: Set of valid letters
|
|
debug: Whether to print debug info
|
|
|
|
Returns:
|
|
The letter if a match is found, None otherwise
|
|
"""
|
|
answer_lower = answer_content.lower().strip()
|
|
answer_normalized = re.sub(r"\s+", " ", answer_lower) # Normalize whitespace
|
|
|
|
for i, choice_text in enumerate(choices):
|
|
if i >= len(valid_letters):
|
|
break
|
|
letter = ascii_uppercase[i]
|
|
if letter not in valid_letters:
|
|
continue
|
|
|
|
choice_lower = choice_text.lower().strip()
|
|
choice_normalized = re.sub(r"\s+", " ", choice_lower)
|
|
|
|
# Check exact match with choice text
|
|
if answer_normalized == choice_normalized:
|
|
if debug:
|
|
print(
|
|
f" Extracted '{letter}' via exact choice text match: '{choice_text}'"
|
|
)
|
|
return letter
|
|
|
|
# Check if answer contains choice text (for longer answers)
|
|
if choice_normalized and choice_normalized in answer_normalized:
|
|
# Make sure it's a substantial match (not just a single word)
|
|
if len(choice_normalized) >= 3:
|
|
if debug:
|
|
print(
|
|
f" Extracted '{letter}' via choice text containment: '{choice_text}'"
|
|
)
|
|
return letter
|
|
|
|
# Check for "A) choice text", "A - choice text", "A. choice text" patterns
|
|
prefixed_patterns = [
|
|
rf"^{letter}\s*[\)\-\.:\]]\s*", # A), A-, A., A:, A]
|
|
rf"^\({letter}\)\s*", # (A)
|
|
]
|
|
for pattern in prefixed_patterns:
|
|
stripped = re.sub(pattern, "", answer_content, flags=re.IGNORECASE).strip()
|
|
stripped_normalized = re.sub(r"\s+", " ", stripped.lower())
|
|
if stripped_normalized == choice_normalized:
|
|
if debug:
|
|
print(
|
|
f" Extracted '{letter}' via prefixed choice text match: '{choice_text}'"
|
|
)
|
|
return letter
|
|
|
|
return None
|
|
|
|
|
|
def extract_number_from_answer_tag(
    response: str, num_choices: int, debug: bool = False
) -> Tuple[Optional[int], str]:
    """
    Extract a single number answer from <answer></answer> tags.

    Mirrors extract_letter_from_answer_tag: strict match first, then a
    word-boundary scan that only accepts a result when EXACTLY ONE valid
    number (1 to num_choices) is found — first in the prefix-stripped
    content, then in the original tag content.

    Handles variations like:
    - <answer>2</answer> -> 2
    - <answer>2.</answer> -> 2
    - <answer>Choice 2</answer> -> 2
    - <answer>The answer is 3</answer> -> 3
    - <answer>Option 1</answer> -> 1

    Args:
        response: The model's response (content after </think> in thinking mode)
        num_choices: Number of valid choices (e.g., 5 means valid range is 1-5)
        debug: Whether to print debug information

    Returns:
        Tuple of (extracted_number or None, extraction_method)
    """
    # Find the first <answer></answer> tag
    answer_tag_match = ANSWER_TAG_PATTERN.search(response)
    if not answer_tag_match:
        return None, "no_answer_tag"

    answer_content = answer_tag_match.group(1).strip()
    if not answer_content:
        return None, "empty_answer_tag"

    # Try stripping common prefixes
    content_to_check = answer_content
    prefix_match = ANSWER_PREFIX_PATTERN.match(content_to_check)
    if prefix_match:
        content_to_check = content_to_check[prefix_match.end() :].strip()  # noqa: E203
        if debug:
            print(f"  Stripped prefix, remaining content: '{content_to_check}'")

    # STRICT CHECK: after stripping, the content is just a number
    cleaned = content_to_check.strip(".,)(:;!? \t\n")
    try:
        num = int(cleaned)
        if 1 <= num <= num_choices:
            if debug:
                print(
                    f"  Extracted '{num}' using method 'answer_tag' (strict match after prefix strip)"
                )
            return num, "answer_tag"
    except ValueError:
        pass

    def _in_range_numbers(text: str) -> List[int]:
        # All word-bounded integers in text that fall in [1, num_choices].
        # A \d+ match always parses as int, so no try/except is needed.
        return [
            int(n)
            for n in re.findall(r"\b(\d+)\b", text)
            if 1 <= int(n) <= num_choices
        ]

    # WORD BOUNDARY CHECK: accept only an unambiguous single number, first in
    # the prefix-stripped content, then (fallback) in the original content.
    # NOTE: this fallback used to be duplicated inline in two branches.
    valid_numbers = _in_range_numbers(content_to_check)
    if len(valid_numbers) == 1:
        if debug:
            print(
                f"  Extracted '{valid_numbers[0]}' using method 'answer_tag' (single word-bounded number)"
            )
        return valid_numbers[0], "answer_tag"

    # Zero or multiple hits after prefix stripping: retry on the raw content.
    valid_in_original = _in_range_numbers(answer_content)
    if len(valid_in_original) == 1:
        if debug:
            print(
                f"  Extracted '{valid_in_original[0]}' using method 'answer_tag' (from original)"
            )
        return valid_in_original[0], "answer_tag"

    if valid_numbers:
        # More than one valid number both before and after stripping.
        if debug:
            print(
                f"  Multiple valid numbers found in answer tag: {valid_numbers} - rejecting"
            )
        return None, "answer_tag_ambiguous"

    if debug:
        print(f"  No valid numbers found in answer tag content: '{answer_content}'")
    return None, "answer_tag_no_number"
|
|
|
|
|
|
def extract_freeform_from_answer_tag(
    response: str, debug: bool = False
) -> Tuple[Optional[str], str]:
    """
    Extract a freeform text answer from the first <answer></answer> pair.

    No post-processing beyond whitespace stripping — suitable for open-ended
    benchmarks such as DROP and SimpleQA.

    Args:
        response: The model's response (content after </think> in thinking mode)
        debug: Whether to print debug information

    Returns:
        Tuple of (extracted_text or None, extraction_method)
    """
    tag_match = ANSWER_TAG_PATTERN.search(response)
    if tag_match is None:
        return None, "no_answer_tag"

    extracted = tag_match.group(1).strip()
    if not extracted:
        return None, "empty_answer_tag"

    if debug:
        preview = extracted if len(extracted) <= 50 else extracted[:50] + "..."
        print(f"  Extracted '{preview}' using method 'answer_tag'")

    return extracted, "answer_tag"
|
|
|
|
|
|
def validate_thinking_format(
    response: str, thinking_mode: bool = True
) -> Tuple[bool, str]:
    """
    Check thinking format and return the post-</think> content.

    Thinking mode requires exactly one <think> and one </think> tag; the
    text after </think> is what answer extraction should then run on.

    Args:
        response: The model's full response
        thinking_mode: Whether thinking mode is enabled

    Returns:
        Tuple of (is_valid, content_for_extraction). When invalid, the full
        response is returned unchanged.
    """
    if not thinking_mode:
        return True, response

    # Require exactly one pair of think tags.
    open_tags = THINK_OPEN_PATTERN.findall(response)
    close_tags = THINK_CLOSE_PATTERN.findall(response)
    if not (len(open_tags) == len(close_tags) == 1):
        return False, response

    # Hand back the content after </think> for answer extraction.
    after = THINK_CONTENT_AFTER_PATTERN.search(response)
    if after is None:
        return False, response
    return True, after.group(1).strip()
|
|
|
|
|
|
def extract_thinking_content(response: str) -> Optional[str]:
    """
    Return the text inside the first <think></think> pair, if any.

    Args:
        response: The model's full response

    Returns:
        Stripped content between the think tags, or None when absent
    """
    found = THINK_CONTENT_INSIDE_PATTERN.search(response)
    return found.group(1).strip() if found else None
|
|
|
|
|
|
def get_default_thinking_prompt(custom_prompt: Optional[str] = None) -> str:
    """
    Return the system prompt used for thinking mode.

    Args:
        custom_prompt: When provided (non-empty), returned verbatim instead
            of the default prompt

    Returns:
        The thinking prompt string
    """
    default_prompt = (
        "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
        "problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
        "solution prior to answering. You should enclose your thoughts and internal monologue inside <think> "
        "</think> tags, and then provide your solution or response to the problem."
    )
    return custom_prompt if custom_prompt else default_prompt
|
|
|
|
|
|
# Fallback regex patterns for MCQA when answer tags don't work
def build_mcqa_fallback_patterns(num_choices: int = 4):
    """
    Build fallback regex patterns for extracting MCQA answers.

    These are used when <answer> tags are not present or ambiguous.
    Patterns are ordered by priority (lower number = higher priority).
    Each pattern captures either a bare letter ("B") or a parenthesized
    letter ("(B)"); the caller strips the parentheses.

    Args:
        num_choices: Number of valid choices (determines valid letters)

    Returns:
        List of (priority, pattern, method_name) tuples
    """
    letters = ascii_uppercase[:num_choices]
    # Matches "X" or "(X)" for any valid letter X.
    letter_pattern = rf"([{letters}]|\([{letters}]\))"

    patterns = [
        # Priority 0: "final answer is: X" with "I hope"
        (
            0,
            re.compile(
                rf"(?i:final\s+answer\s+is)\s*:?\s*{letter_pattern}\.?\s*I\s*hope",
                re.IGNORECASE,
            ),
            "final_answer_hope",
        ),
        # Priority 50: "final answer ... is X" (up to 100 chars in between)
        (
            50,
            re.compile(
                rf"(?i:final\s+answer).{{0,100}}?\s+is\s*:?\s*{letter_pattern}",
                re.IGNORECASE | re.DOTALL,
            ),
            "final_answer_is",
        ),
        # Priority 75: "the answer is X"
        (
            75,
            re.compile(
                rf"(?i:the\s+answer\s+is)\s*:?\s*{letter_pattern}", re.IGNORECASE
            ),
            "the_answer_is",
        ),
        # Priority 100: "answer: X" (up to 50 chars between colon and letter)
        (
            100,
            re.compile(
                rf"(?i:answer)\s*:\s*.{{0,50}}?{letter_pattern}",
                re.IGNORECASE | re.DOTALL,
            ),
            "answer_colon",
        ),
        # Priority 150: "answer X"
        (
            150,
            re.compile(rf"(?i:answer)\s+{letter_pattern}", re.IGNORECASE),
            "answer_space",
        ),
        # Priority 200: Response starts with letter (optionally bolded with *)
        (
            200,
            re.compile(rf"^\s*\**{letter_pattern}\**[\s\.\)\:]", re.IGNORECASE),
            "start",
        ),
        # Priority 210: Letter at start of any line
        (
            210,
            re.compile(rf"\n\s*\**{letter_pattern}\**[\s\.\)\:]", re.IGNORECASE),
            "line_start",
        ),
        # Priority 250: Standalone letter with word boundaries (weakest signal)
        (250, re.compile(rf"\b{letter_pattern}\b", re.IGNORECASE), "standalone"),
    ]

    return patterns
|
|
|
|
|
|
def extract_mcqa_answer_with_fallback(
    response: str,
    num_choices: int = 4,
    fallback_patterns: Optional[list] = None,
    debug: bool = False,
) -> Tuple[Optional[str], str]:
    """
    Extract MCQA answer using answer tags first, then fallback patterns.

    Args:
        response: The model's response (content after </think> in thinking mode)
        num_choices: Number of valid choices
        fallback_patterns: Pre-built patterns from build_mcqa_fallback_patterns
            (built on the fly if not provided)
        debug: Whether to print debug information

    Returns:
        Tuple of (extracted_letter or None, extraction_method)
    """
    if not response:
        return None, "empty_response"

    valid_letters = set(ascii_uppercase[:num_choices])

    # PRIMARY: Try <answer></answer> tags first
    letter, method = extract_letter_from_answer_tag(response, valid_letters, debug)
    if letter:
        return letter, method

    # FALLBACK: Use regex patterns, in priority order
    if fallback_patterns is None:
        fallback_patterns = build_mcqa_fallback_patterns(num_choices)

    for priority, pattern, method_name in fallback_patterns:
        matches = pattern.findall(response)
        if not matches:
            continue

        # For "answer is ..." style patterns the LAST occurrence is most
        # reliable (models often restate their final answer at the end);
        # positional patterns use the first occurrence.
        if method_name in (
            "final_answer_is",
            "the_answer_is",
            "answer_colon",
            "answer_space",
        ):
            match = matches[-1]
        else:
            match = matches[0]

        if isinstance(match, tuple):
            match = match[0]
        letter = match.strip("()").upper()

        if letter in valid_letters:
            if debug:
                print(
                    f"  Extracted '{letter}' using fallback method '{method_name}' (priority {priority})"
                )
            return letter, f"fallback_{method_name}"

    # LAST RESORT: any valid letter present anywhere in the response.
    # Iterate in deterministic descending order — the previous
    # reversed(list(valid_letters)) depended on set iteration order, which
    # varies across runs under hash randomization.
    upper_response = response.upper()
    for letter in sorted(valid_letters, reverse=True):
        if letter in upper_response:
            if debug:
                print(f"  Extracted '{letter}' using fallback 'last_valid_letter'")
            return letter, "fallback_last_valid_letter"

    return None, "no_match"
|
|
|
|
|
|
# =============================================================================
# MATH ANSWER VERIFICATION HELPERS
# =============================================================================
# These functions use the math_verify library for robust mathematical answer
# verification. They support \boxed{} extraction, symbolic comparison, and
# string normalization fallback.

# Regex for extracting \boxed{} content. The inner alternation tolerates one
# level of nested braces, e.g. \boxed{\frac{1}{2}}.
BOXED_PATTERN = re.compile(r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}")

# Global ProcessPoolExecutor for math verification, created lazily by
# get_math_executor(). Worker processes isolate sympy/latex parsing hangs.
_math_executor: Optional[ProcessPoolExecutor] = None
|
|
|
|
|
|
def get_math_executor(max_workers: int = 64) -> ProcessPoolExecutor:
    """
    Return the shared ProcessPoolExecutor for math verification, creating it
    lazily on first use.

    Running verification in worker processes protects against hangs from
    sympy/latex parsing.

    Args:
        max_workers: Maximum number of worker processes (only used when the
            pool is first created)

    Returns:
        The module-wide ProcessPoolExecutor instance
    """
    global _math_executor
    if _math_executor is not None:
        return _math_executor
    _math_executor = ProcessPoolExecutor(max_workers=max_workers)
    return _math_executor
|
|
|
|
|
|
def extract_boxed_answers(text: str) -> List[str]:
    """
    Return the contents of every \\boxed{} occurrence in the given text.

    Args:
        text: The text to search for boxed answers

    Returns:
        List of extracted boxed contents, in order of appearance
    """
    return [found.group(1) for found in BOXED_PATTERN.finditer(text)]
|
|
|
|
|
|
def extract_first_boxed_answer(
    response: str, after_think: bool = True, debug: bool = False
) -> Tuple[Optional[str], str, bool]:
    """
    Pull the single \\boxed{} answer out of a response.

    Only accepts responses with exactly ONE boxed answer in the searched
    region (after </think> when thinking mode); zero or multiple boxed
    answers are rejected.

    Args:
        response: The model's full response
        after_think: Whether to only look after </think> tags (falls back to
            the whole response when no tag is present)
        debug: Whether to print debug information

    Returns:
        Tuple of (extracted_answer or None, extraction_method, has_multiple_boxed)
    """
    # Decide which region of the response to search.
    search_content = response
    if after_think:
        after = THINK_CONTENT_AFTER_PATTERN.search(response)
        if after:
            search_content = after.group(1)

    boxed = extract_boxed_answers(search_content)

    if not boxed:
        if debug:
            print("  No \\boxed{} found in response")
        return None, "no_boxed", False

    if len(boxed) > 1:
        if debug:
            print(f"  Multiple \\boxed{{}} found ({len(boxed)}) - rejecting")
        return None, "multiple_boxed", True

    # Exactly one boxed answer.
    answer = boxed[0].strip()
    if debug:
        preview = answer if len(answer) <= 50 else answer[:50] + "..."
        print(f"  Extracted '{preview}' from \\boxed{{}}")

    return answer, "boxed", False
|
|
|
|
|
|
def math_normalize_string(text: str) -> str:
    """
    Normalize a math answer string for loose comparison.

    Fallback used when symbolic verification fails; based on lighteval's
    math_normalizer.

    Args:
        text: The text to normalize

    Returns:
        Normalized lowercase string
    """
    if not text:
        return ""

    text = text.strip()

    # Unwrap \boxed{...} if present.
    boxed = BOXED_PATTERN.search(text)
    if boxed:
        text = boxed.group(1)

    # Collapse runs of whitespace.
    text = " ".join(text.split())

    # Drop spacing-only LaTeX commands and unwrap text-style wrappers.
    text = re.sub(r"\\[,;:!]", "", text)  # \, \; etc.
    text = re.sub(r"\\quad|\\qquad", " ", text)
    for wrapper in ("text", "mathrm", "mathbf"):
        # \text{...} -> ... (and likewise for \mathrm, \mathbf)
        text = re.sub(rf"\\{wrapper}\{{([^}}]*)\}}", r"\1", text)

    # Map LaTeX operators onto plain equivalents.
    for latex_op, plain_op in (
        ("\\times", "*"),
        ("\\cdot", "*"),
        ("\\div", "/"),
        ("\\pm", "+-"),
    ):
        text = text.replace(latex_op, plain_op)

    # Drop $ signs, trailing punctuation, and case for comparison.
    text = text.replace("$", "").rstrip(".,;:").lower()

    return text.strip()
|
|
|
|
|
|
def compare_math_strings(pred: str, gold: str) -> bool:
    """
    Compare two math answer strings after normalization.

    Fallback for when symbolic verification fails: tries exact string match,
    then float comparison (with a tiny relative tolerance), then integer
    comparison (for AIME-style 0-999 answers).

    Args:
        pred: Predicted answer
        gold: Gold answer

    Returns:
        True if answers match, False otherwise
    """
    pred_norm = math_normalize_string(pred)
    gold_norm = math_normalize_string(gold)

    if not pred_norm:
        return False

    # Exact match after normalization.
    if pred_norm == gold_norm:
        return True

    # Numeric comparison, tolerating thousands separators and spaces.
    try:
        pred_val = float(pred_norm.replace(",", "").replace(" ", ""))
        gold_val = float(gold_norm.replace(",", "").replace(" ", ""))

        if pred_val == gold_val:
            return True
        # Allow a small relative error for floating point noise.
        if gold_val != 0 and abs(pred_val - gold_val) / abs(gold_val) < 1e-6:
            return True
    except (ValueError, TypeError):
        pass

    # Integer comparison (for AIME-style answers).
    try:
        pred_int = int(float(pred_norm.replace(",", "")))
        gold_int = int(float(gold_norm.replace(",", "")))
        if pred_int == gold_int:
            return True
    except (ValueError, TypeError):
        pass

    return False
|
|
|
|
|
|
def _score_math_answer_worker(
    gold: str, response: str, wrap_gold_boxed: bool = True
) -> Tuple[Optional[bool], str]:
    """
    Worker function for scoring math answers (runs in separate process).

    This function is designed to run in a ProcessPoolExecutor to protect
    against hangs from sympy/latex parsing. Every failure path first tries a
    normalized string comparison against the first \\boxed{} answer before
    giving up; the returned method string records which path was taken.

    Args:
        gold: The gold answer
        response: The model's response (content to extract answer from)
        wrap_gold_boxed: Whether to wrap gold in \\boxed{} if not present

    Returns:
        Tuple of (is_correct or None, method_used). is_correct is None only
        when no verdict could be reached at all.
    """
    if not MATH_VERIFY_AVAILABLE:
        # Library absent: fall back to string comparison on the boxed answer.
        boxed = extract_boxed_answers(response)
        if boxed:
            return compare_math_strings(boxed[0], gold), "string_fallback_no_lib"
        return None, "no_math_verify"

    try:
        # Prepare gold answer: math_verify extracts from \boxed{}, so wrap
        # bare gold answers unless the caller opted out.
        if wrap_gold_boxed and "\\boxed" not in gold:
            gold_text = f"\\boxed{{{gold}}}"
        else:
            gold_text = gold

        # Parse gold
        gold_parsed = parse(
            gold_text,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )

        if len(gold_parsed) == 0:
            # Gold couldn't be parsed, try string comparison
            boxed = extract_boxed_answers(response)
            if boxed:
                return (
                    compare_math_strings(boxed[0], gold),
                    "string_fallback_gold_parse",
                )
            return None, "gold_parse_failed"

        # Parse response (boxed_match_priority=0 prefers \boxed{} content)
        response_parsed = parse(
            response,
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        equations=True,
                        boxed="all",
                        units=True,
                    ),
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )

        if len(response_parsed) == 0:
            # Response couldn't be parsed, try string comparison
            boxed = extract_boxed_answers(response)
            if boxed:
                return (
                    compare_math_strings(boxed[0], gold),
                    "string_fallback_response_parse",
                )
            return None, "response_parse_failed"

        # Verify match symbolically.
        is_correct = verify(response_parsed, gold_parsed)
        return is_correct, "math_verify"

    except TimeoutException:
        # Timeout during parsing/verification, try string comparison
        boxed = extract_boxed_answers(response)
        if boxed:
            return compare_math_strings(boxed[0], gold), "string_fallback_timeout"
        return None, "timeout"
    except Exception as e:
        # Any other error, try string comparison
        boxed = extract_boxed_answers(response)
        if boxed:
            return compare_math_strings(boxed[0], gold), "string_fallback_error"
        return None, f"error_{type(e).__name__}"
|
|
|
|
|
|
def score_math_answer(
    gold: str,
    response: str,
    after_think: bool = True,
    wrap_gold_boxed: bool = True,
    executor: Optional[ProcessPoolExecutor] = None,
    debug: bool = False,
) -> Tuple[Optional[bool], str, bool]:
    """
    Score a math answer using math_verify with optional process isolation.

    This is the main sync entry point for scoring math answers. It:
    1. Extracts content after </think> if thinking mode
    2. Checks for multiple \\boxed{} (fails if multiple)
    3. Uses math_verify for symbolic comparison
    4. Falls back to string normalization if that fails

    Args:
        gold: The gold answer
        response: The model's full response
        after_think: Whether to extract content after </think>
        wrap_gold_boxed: Whether to wrap gold in \\boxed{} if not present
        executor: Optional ProcessPoolExecutor; when given, verification runs
            in a worker process with a 30s timeout (guards against sympy hangs)
        debug: Whether to print debug information

    Returns:
        Tuple of (is_correct or None, method_used, has_multiple_boxed)
    """
    # Get content to score
    if after_think:
        match = THINK_CONTENT_AFTER_PATTERN.search(response)
        score_content = match.group(1) if match else response
    else:
        score_content = response

    # Reject ambiguous responses with multiple boxed answers
    boxed_answers = extract_boxed_answers(score_content)
    if len(boxed_answers) > 1:
        if debug:
            print(f"  Multiple \\boxed{{}} found ({len(boxed_answers)}) - rejecting")
        return None, "multiple_boxed", True

    if len(boxed_answers) == 0:
        if debug:
            print("  No \\boxed{} found")
        return None, "no_boxed", False

    if executor is not None:
        try:
            # BUGFIX: use executor.submit() so we get a concurrent.futures
            # Future, whose result() supports a blocking timeout. The
            # previous code used loop.run_in_executor(), which returns an
            # asyncio.Future; its result() takes no timeout and raises
            # unless awaited, so every call silently fell through to the
            # string-comparison fallback below.
            future = executor.submit(
                _score_math_answer_worker, gold, score_content, wrap_gold_boxed
            )
            is_correct, method = future.result(timeout=30)
        except Exception as e:
            if debug:
                print(f"  Executor error: {e}")
            # Fallback to string comparison
            if boxed_answers:
                is_correct = compare_math_strings(boxed_answers[0], gold)
                method = "string_fallback_executor_error"
            else:
                return None, f"executor_error_{type(e).__name__}", False
    else:
        # No executor: run in-process (no hang protection).
        is_correct, method = _score_math_answer_worker(
            gold, score_content, wrap_gold_boxed
        )

    if debug:
        print(f"  Score: {is_correct} (method: {method})")

    return is_correct, method, False
|
|
|
|
|
|
async def score_math_answer_async(
    gold: str,
    response: str,
    after_think: bool = True,
    wrap_gold_boxed: bool = True,
    executor: Optional[ProcessPoolExecutor] = None,
    debug: bool = False,
) -> Tuple[Optional[bool], str, bool]:
    """
    Async version of score_math_answer for use in async evaluation loops.

    Uses ProcessPoolExecutor to run math verification in separate process,
    protecting against hangs.

    Args:
        gold: The gold answer
        response: The model's full response
        after_think: Whether to extract content after </think>
        wrap_gold_boxed: Whether to wrap gold in \\boxed{} if not present
        executor: Optional ProcessPoolExecutor to use
        debug: Whether to print debug information

    Returns:
        Tuple of (is_correct or None, method_used, has_multiple_boxed)
    """
    import asyncio

    # Get content to score: optionally only the text after the </think> tag.
    if after_think:
        match = THINK_CONTENT_AFTER_PATTERN.search(response)
        score_content = match.group(1) if match else response
    else:
        score_content = response

    # Reject ambiguous responses with more than one \boxed{} answer.
    boxed_answers = extract_boxed_answers(score_content)
    if len(boxed_answers) > 1:
        if debug:
            print(f" Multiple \\boxed{{}} found ({len(boxed_answers)}) - rejecting")
        return None, "multiple_boxed", True

    if len(boxed_answers) == 0:
        if debug:
            print(" No \\boxed{} found")
        return None, "no_boxed", False

    # Fall back to the shared module-level executor if none was supplied.
    if executor is None:
        executor = get_math_executor()

    try:
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated here and fails on Python 3.12+
        # when no loop is set on the thread.
        loop = asyncio.get_running_loop()
        # Bound the wait so a hung verification worker cannot stall the
        # event loop indefinitely (mirrors the sync version's
        # future.result(timeout=30)). A TimeoutError lands in the
        # exception fallback below.
        is_correct, method = await asyncio.wait_for(
            loop.run_in_executor(
                executor,
                _score_math_answer_worker,
                gold,
                score_content,
                wrap_gold_boxed,
            ),
            timeout=30,
        )
    except Exception as e:
        if debug:
            print(f" Executor error: {e}")
        # Fallback to plain string comparison of the first boxed answer.
        if boxed_answers:
            is_correct = compare_math_strings(boxed_answers[0], gold)
            method = "string_fallback_executor_error"
        else:
            return None, f"executor_error_{type(e).__name__}", False

    if debug:
        print(f" Score: {is_correct} (method: {method})")

    return is_correct, method, False
|
|
|
|
|
|
def format_math_answer_instruction(include_hope: bool = True) -> str:
    """
    Return the standard instruction telling models how to format math answers.

    Based on lighteval's AIME prompt, which works well with math_verify.

    Args:
        include_hope: Whether to include the "I hope it is correct" suffix

    Returns:
        Instruction string
    """
    if not include_hope:
        return (
            "Put your final answer in \\boxed{} format. "
            "For example: \\boxed{42} or \\boxed{\\frac{1}{2}}"
        )
    return (
        "The last line of your response should be of the following format: "
        "'Therefore, the final answer is: $\\boxed{ANSWER}$. I hope it is correct' "
        "(without quotes) where ANSWER is just the final number or expression that solves the problem."
    )
|
|
|
|
|
|
# =============================================================================
|
|
# SYSTEM PROMPT AND CONFIGURATION HELPERS
|
|
# =============================================================================
|
|
# These functions handle common system prompt creation patterns used across
|
|
# evaluation environments with thinking mode support.
|
|
|
|
|
|
def create_system_content(
    thinking_mode: bool,
    custom_thinking_prompt: Optional[str] = None,
    custom_system_prompt: Optional[str] = None,
) -> Optional[str]:
    """
    Build system message content based on thinking mode configuration.

    Standard pattern used across all eval environments:
    - thinking mode on: thinking prompt, plus the system prompt if given
    - thinking mode off: just the system prompt (or None)

    Args:
        thinking_mode: Whether thinking mode is enabled
        custom_thinking_prompt: Optional custom thinking prompt (default used if None)
        custom_system_prompt: Optional additional system prompt

    Returns:
        System content string, or None if no content is needed
    """
    if not thinking_mode:
        return custom_system_prompt

    thinking_prompt = get_default_thinking_prompt(custom_thinking_prompt)
    if not custom_system_prompt:
        return thinking_prompt
    return f"{thinking_prompt}\n\n{custom_system_prompt}"
|
|
|
|
|
|
# =============================================================================
|
|
# RESULTS SAVING UTILITIES
|
|
# =============================================================================
|
|
# Common patterns for saving evaluation results to disk.
|
|
|
|
|
|
def save_eval_results(
    save_dir: str,
    metrics: Dict,
    results: List[Dict],
    metrics_filename: str = "metrics.json",
    results_filename: str = "results.jsonl",
    print_confirmation: bool = True,
) -> Tuple[str, str]:
    """
    Persist evaluation results to disk in the standard layout.

    Two files are written:
    - metrics.json: summary metrics dict
    - results.jsonl: per-item results (one JSON object per line)

    Args:
        save_dir: Directory to write into (created if it doesn't exist)
        metrics: Dictionary of evaluation metrics
        results: List of per-item result dictionaries
        metrics_filename: Name for the metrics file
        results_filename: Name for the results file
        print_confirmation: Whether to print a confirmation message

    Returns:
        Tuple of (metrics_path, results_path)
    """
    os.makedirs(save_dir, exist_ok=True)

    metrics_path = os.path.join(save_dir, metrics_filename)
    results_path = os.path.join(save_dir, results_filename)

    # Summary metrics; default=str stringifies any non-JSON-native values.
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2, default=str)

    # Per-item results as JSON Lines.
    with open(results_path, "w") as f:
        f.writelines(json.dumps(item, default=str) + "\n" for item in results)

    if print_confirmation:
        print(f"Results saved to {save_dir}")

    return metrics_path, results_path
|
|
|
|
|
|
def load_eval_results(
    save_dir: str,
    metrics_filename: str = "metrics.json",
    results_filename: str = "results.jsonl",
) -> Tuple[Dict, List[Dict]]:
    """
    Read evaluation results previously written to disk.

    Args:
        save_dir: Directory containing the result files
        metrics_filename: Name of the metrics file
        results_filename: Name of the results file

    Returns:
        Tuple of (metrics dict, list of result dicts)
    """
    # Summary metrics (single JSON document).
    with open(os.path.join(save_dir, metrics_filename), "r") as f:
        metrics = json.load(f)

    # JSON Lines results: one object per non-blank line.
    with open(os.path.join(save_dir, results_filename), "r") as f:
        results = [json.loads(line) for line in f if line.strip()]

    return metrics, results
|
|
|
|
|
|
# =============================================================================
|
|
# COMMON EVALUATION UTILITIES
|
|
# =============================================================================
|
|
# Helper functions used across multiple eval environments.
|
|
|
|
|
|
def calculate_accuracy(
    results: List[Dict],
    score_key: str = "is_correct",
    filter_fn: Optional[callable] = None,
) -> float:
    """
    Compute accuracy over a list of result dictionaries.

    Args:
        results: List of result dictionaries
        score_key: Key holding the score/correctness flag (bool or 0/1)
        filter_fn: Optional predicate applied to results before computing

    Returns:
        Accuracy as a float between 0 and 1 (0.0 for an empty pool)
    """
    pool = [item for item in results if filter_fn(item)] if filter_fn else results
    if not pool:
        return 0.0
    n_correct = sum(1 for item in pool if item.get(score_key, False))
    return n_correct / len(pool)
|
|
|
|
|
|
def group_results_by_key(results: List[Dict], key: str) -> Dict[str, List[Dict]]:
    """
    Partition results by the value stored under ``key``.

    Useful for computing per-category or per-subset metrics. Items missing
    the key are collected under the value "unknown".

    Args:
        results: List of result dictionaries
        key: Key to group by

    Returns:
        Dictionary mapping key values to lists of results
    """
    grouped: Dict[str, List[Dict]] = {}
    for item in results:
        grouped.setdefault(item.get(key, "unknown"), []).append(item)
    return grouped
|
|
|
|
|
|
def format_percentage(value: float, decimals: int = 2) -> str:
    """
    Render a 0-1 scale float as a percentage string.

    Args:
        value: Float value on a 0-1 scale
        decimals: Number of decimal places to keep

    Returns:
        Formatted percentage string (e.g., "75.50%")
    """
    scaled = value * 100
    return f"{scaled:.{decimals}f}%"
|
|
|
|
|
|
def print_eval_summary(title: str, metrics: Dict[str, float], width: int = 60) -> None:
|
|
"""
|
|
Print a formatted evaluation summary.
|
|
|
|
Args:
|
|
title: Summary title
|
|
metrics: Dictionary of metric name -> value
|
|
width: Width of separator lines
|
|
"""
|
|
print(f"\n{'='*width}")
|
|
print(title)
|
|
print(f"{'='*width}")
|
|
for name, value in metrics.items():
|
|
if isinstance(value, float) and 0 <= value <= 1:
|
|
print(f" {name}: {format_percentage(value)}")
|
|
else:
|
|
print(f" {name}: {value}")
|
|
print(f"{'='*width}\n")
|