mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add phybench eval
This commit is contained in:
parent
d04f8c0ae7
commit
830a129655
2 changed files with 1681 additions and 0 deletions
970
environments/eval_environments/eed_score.py
Normal file
970
environments/eval_environments/eed_score.py
Normal file
|
|
@ -0,0 +1,970 @@
|
|||
"""
|
||||
Expression Edit Distance (EED) Score Module for PHYBench Evaluation.
|
||||
|
||||
This module implements the EED Score metric for evaluating mathematical expressions,
|
||||
as described in the PHYBench paper (https://arxiv.org/abs/2504.16074).
|
||||
|
||||
The EED Score measures similarity between model-generated and reference expressions
|
||||
by computing tree edit distance over their SymPy expression trees. It provides:
|
||||
- Continuous scoring (0-100) that captures partial correctness
|
||||
- 204% improved sample efficiency over binary scoring
|
||||
- Robust handling of equivalent mathematical forms
|
||||
|
||||
Key components:
|
||||
- LaTeX preprocessing and normalization
|
||||
- SymPy expression tree construction
|
||||
- Extended Zhang-Shasha tree edit distance algorithm
|
||||
- Configurable scoring with subtree discount
|
||||
|
||||
Dependencies:
|
||||
- sympy: Symbolic mathematics
|
||||
- latex2sympy2_extended: LaTeX to SymPy conversion
|
||||
- numpy: Numerical operations
|
||||
|
||||
Based on the official PHYBench implementation:
|
||||
https://github.com/phybench-official/phybench/tree/main/EED
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from numpy import ones, zeros
|
||||
|
||||
# Try to import required dependencies
|
||||
try:
|
||||
from latex2sympy2_extended import latex2sympy
|
||||
from sympy import (
|
||||
Add,
|
||||
Float,
|
||||
Function,
|
||||
Integer,
|
||||
Mul,
|
||||
Pow,
|
||||
Rational,
|
||||
Symbol,
|
||||
expand,
|
||||
posify,
|
||||
simplify,
|
||||
)
|
||||
from sympy.core.numbers import Exp1, Infinity, NegativeInfinity, Pi
|
||||
|
||||
EED_AVAILABLE = True
|
||||
except ImportError:
|
||||
EED_AVAILABLE = False
|
||||
latex2sympy = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CONFIGURATION CONSTANTS
|
||||
# =============================================================================
|
||||
|
||||
# Cost configuration for tree edit operations
|
||||
# These can be modified if different node types should have different weights
|
||||
INSERT_COST = {"number": 1, "symbol": 1, "operator": 1, "function": 1}
|
||||
DELETE_COST = {"number": 1, "symbol": 1, "operator": 1, "function": 1}
|
||||
UPDATE_COST = {"number": 1, "symbol": 1, "operator": 1, "function": 1}
|
||||
|
||||
# Cost of updating between different types (e.g., number -> symbol)
|
||||
CHANGE_TYPE_COST = 1
|
||||
|
||||
# Subtree discount configuration
|
||||
# Minimum size to trigger cluster discount for subtree operations
|
||||
BAR_SIZE = 5
|
||||
# Discount slope for subtree operations (0.6 means 40% discount)
|
||||
DISCOUNT_SLOPE = 0.6
|
||||
|
||||
# Timeout limits (in seconds)
|
||||
SIMPLIFY_TIME_LIMIT = 30
|
||||
EQUALS_TIME_LIMIT = 10
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TREE NODE CLASS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TreeNode:
    """
    A single node of an expression tree.

    Attributes:
        label: Node label (e.g., "number_2", "symbol_x", "operator_Add")
        children: List of child TreeNode objects
        subtree_size: Cached size of the subtree rooted here (0 = not yet computed)
    """

    def __init__(self, label: str, children: Optional[List["TreeNode"]] = None):
        self.label = label
        # Build a fresh list per node to avoid the shared-mutable-default pitfall.
        self.children = [] if children is None else children
        self.subtree_size = 0

    def get_children(self) -> List["TreeNode"]:
        """Return the list of child nodes."""
        return self.children

    def __str__(self) -> str:
        return self.label
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TREE EDIT DISTANCE FUNCTIONS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def calc_tree_size(node: TreeNode) -> int:
    """
    Return the size of the subtree rooted at *node*.

    The size is the sum of per-node insertion costs (INSERT_COST, keyed by the
    label's type prefix, defaulting to 1) over every node in the subtree.
    Results for non-leaf nodes are memoized in ``node.subtree_size``.

    Args:
        node: Root node of the subtree

    Returns:
        Total insertion cost of the subtree
    """
    # Serve the memoized value when this non-leaf node was already measured.
    if node.children and node.subtree_size != 0:
        return node.subtree_size

    node_type = node.label.split("_")[0]
    size = INSERT_COST.get(node_type, 1) + sum(
        calc_tree_size(child) for child in node.children
    )

    node.subtree_size = size  # cache for subsequent calls
    return size
|
||||
|
||||
|
||||
def update_func(x: TreeNode, y: TreeNode) -> float:
    """
    Cost of relabeling node *x* as node *y*.

    Args:
        x: Source node
        y: Target node

    Returns:
        0 for identical labels; the type-specific UPDATE_COST when both labels
        share a type prefix; CHANGE_TYPE_COST when the types differ.
    """
    if x.label == y.label:
        return 0

    src_type = x.label.split("_")[0]
    dst_type = y.label.split("_")[0]

    return UPDATE_COST.get(src_type, 1) if src_type == dst_type else CHANGE_TYPE_COST
|
||||
|
||||
|
||||
def remove_func(x: TreeNode) -> float:
    """Cost of deleting the single node *x*, looked up by its type prefix."""
    return DELETE_COST.get(x.label.split("_")[0], 1)
|
||||
|
||||
|
||||
def remove_tree_func(x: TreeNode) -> float:
    """
    Cost of deleting the whole subtree rooted at *x*.

    A leaf costs the same as a single-node deletion.  Larger subtrees get a
    "cluster discount": past BAR_SIZE nodes, each extra node contributes only
    DISCOUNT_SLOPE instead of its full cost.

    Args:
        x: Root of the subtree to remove

    Returns:
        Removal cost, possibly discounted
    """
    if not x.children:
        return remove_func(x)

    size = calc_tree_size(x)
    discounted = DISCOUNT_SLOPE * (size - BAR_SIZE) + BAR_SIZE
    return min(size, discounted)
|
||||
|
||||
|
||||
def insert_func(x: TreeNode) -> float:
    """Cost of inserting the single node *x*, looked up by its type prefix."""
    return INSERT_COST.get(x.label.split("_")[0], 1)
|
||||
|
||||
|
||||
def insert_tree_func(x: TreeNode) -> float:
    """Cost of inserting a whole subtree; defined symmetric to removal."""
    return remove_tree_func(x)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ANNOTATED TREE FOR ZHANG-SHASHA ALGORITHM
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AnnotatedTree:
    """
    Annotated tree structure for the Zhang-Shasha algorithm.

    Computes post-order enumeration, left-most descendants, and keyroots
    needed for efficient tree edit distance calculation.

    Attributes:
        nodes: Nodes in post-order.
        ids: Pre-order id assigned to each post-order node.
        lmds: lmds[i] is the post-order index of the left-most leaf
            descendant of nodes[i].
        keyroots: Sorted post-order indices of keyroot nodes (for each
            distinct left-most descendant, the highest node sharing it).
    """

    def __init__(self, root: TreeNode, get_children):
        self.get_children = get_children
        self.root = root
        self.nodes = []  # Post-order enumeration of nodes
        self.ids = []  # Matching list of IDs
        self.lmds = []  # Left-most descendants
        self.keyroots = None

        # Build the annotated structure
        import collections

        # First pass: depth-first walk recording each node together with the
        # deque of its ancestors' pre-order ids (nearest ancestor first).
        stack = [(root, collections.deque())]
        pstack = []
        j = 0  # pre-order id counter

        while stack:
            n, anc = stack.pop()
            nid = j
            for c in self.get_children(n):
                a = collections.deque(anc)
                a.appendleft(nid)
                stack.append((c, a))
            pstack.append(((n, nid), anc))
            j += 1

        lmds = {}  # pre-order id -> post-order index of its left-most leaf
        keyroots = {}  # left-most-descendant index -> highest post-order index
        i = 0  # post-order index counter

        # Second pass: popping pstack yields the nodes in post-order.
        while pstack:
            (n, nid), anc = pstack.pop()
            self.nodes.append(n)
            self.ids.append(nid)

            if not self.get_children(n):
                # Leaf: it is its own left-most descendant.  Propagate its
                # index up through ancestors that have none yet; stop at the
                # first ancestor already assigned (its chain is complete).
                lmd = i
                for a in anc:
                    if a not in lmds:
                        lmds[a] = i
                    else:
                        break
            else:
                # Internal node: take the left-most leaf recorded by the pass
                # above (falls back to i, which should not normally trigger).
                lmd = lmds.get(nid, i)

            self.lmds.append(lmd)
            # Higher nodes with the same lmd overwrite earlier entries, so
            # only the topmost node per lmd survives as a keyroot.
            keyroots[lmd] = i
            i += 1

        self.keyroots = sorted(keyroots.values())
|
||||
|
||||
|
||||
def ext_distance(
    a_root: TreeNode,
    b_root: TreeNode,
    get_children,
    single_insert_cost,
    insert_cost,
    single_remove_cost,
    remove_cost,
    update_cost_func,
) -> float:
    """
    Compute extended tree edit distance using modified Zhang-Shasha algorithm.

    This implementation extends the standard algorithm with subtree insertion
    and deletion operations for handling clustered changes.

    Args:
        a_root: Root of first tree
        b_root: Root of second tree
        get_children: Function to get children of a node
        single_insert_cost: Cost function for single node insertion
        insert_cost: Cost function for subtree insertion
        single_remove_cost: Cost function for single node removal
        remove_cost: Cost function for subtree removal
        update_cost_func: Cost function for updating a node

    Returns:
        Tree edit distance between the two trees
    """
    a_tree = AnnotatedTree(a_root, get_children)
    b_tree = AnnotatedTree(b_root, get_children)

    size_a = len(a_tree.nodes)
    size_b = len(b_tree.nodes)

    # treedists[i][j]: edit distance between the subtrees rooted at post-order
    # nodes i (in A) and j (in B).  fd is the forest-distance scratch table,
    # offset by one (fd[i+1][j+1] covers forest prefixes ending at i and j).
    treedists = zeros((size_a, size_b), float)
    # NOTE(review): 1000 acts as an "infinity" sentinel for uninitialized
    # cells — assumes real forest distances stay below 1000; confirm for
    # very large expression trees.
    fd = 1000 * ones((size_a + 1, size_b + 1), float)

    def treedist(x: int, y: int):
        # Compute treedists for the keyroot pair (x, y), filling fd in place.
        al = a_tree.lmds
        bl = b_tree.lmds
        an = a_tree.nodes
        bn = b_tree.nodes

        fd[al[x]][bl[y]] = 0

        # First column: delete nodes of A one subtree at a time.
        for i in range(al[x], x + 1):
            node = an[i]
            fd[i + 1][bl[y]] = fd[al[i]][bl[y]] + remove_cost(node)

        # First row: insert nodes of B one subtree at a time.
        for j in range(bl[y], y + 1):
            node = bn[j]
            fd[al[x]][j + 1] = fd[al[x]][bl[j]] + insert_cost(node)

        for i in range(al[x], x + 1):
            for j in range(bl[y], y + 1):
                node1 = an[i]
                node2 = bn[j]

                # Candidate operations: single-node delete/insert plus the
                # extended whole-subtree delete/insert (the EED extension).
                costs = [
                    fd[i][j + 1] + single_remove_cost(node1),
                    fd[i + 1][j] + single_insert_cost(node2),
                    fd[al[i]][j + 1] + remove_cost(node1),
                    fd[i + 1][bl[j]] + insert_cost(node2),
                ]
                m = min(costs)

                if al[x] == al[i] and bl[y] == bl[j]:
                    # Both prefixes are whole subtrees: a relabel is allowed
                    # and the result is a genuine tree distance — record it.
                    treedists[i][j] = min(m, fd[i][j] + update_cost_func(node1, node2))
                    fd[i + 1][j + 1] = treedists[i][j]
                else:
                    # General forests: reuse the previously computed subtree
                    # distance instead of a direct relabel.
                    fd[i + 1][j + 1] = min(m, fd[al[i]][bl[j]] + treedists[i][j])

    # Standard Zhang-Shasha driver: solve every keyroot pair; the last cell
    # (roots of both trees) is the full tree distance.
    for x in a_tree.keyroots:
        for y in b_tree.keyroots:
            treedist(x, y)

    return treedists[-1][-1]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LATEX PREPROCESSING
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def brackets_balanced(s: str) -> bool:
    """Return True when every (), [], {} in *s* is properly matched and nested."""
    closers = {")": "(", "]": "[", "}": "{"}
    openers = set(closers.values())
    pending = []

    for ch in s:
        if ch in openers:
            pending.append(ch)
        elif ch in closers:
            if not pending or pending.pop() != closers[ch]:
                return False

    return not pending
|
||||
|
||||
|
||||
def find_first_unescaped_brace(s: str) -> int:
    """Return the index of the first unescaped '{' in *s*, or -1 if none."""
    idx, length = 0, len(s)
    while idx < length:
        ch = s[idx]
        if ch == "\\":
            # A backslash escapes the following character; jump past both.
            idx += 2
            continue
        if ch == "{":
            return idx
        idx += 1
    return -1
|
||||
|
||||
|
||||
def extract_bracket_content(s: str, bracket_position: int) -> Tuple[Optional[str], int]:
    """
    Return the text between the brace at *bracket_position* and its matching
    closing brace, plus the index of that closing brace.

    Backslash-escaped characters (including escaped braces) are copied
    verbatim and do not affect nesting depth.

    Returns:
        (content, index_of_closing_brace), or (None, -1) when unbalanced.
    """
    pieces = []
    depth = 0
    i = bracket_position + 1
    n = len(s)

    while i < n:
        ch = s[i]
        if ch == "\\" and i + 1 < n:
            # Copy the escape and the escaped character untouched.
            pieces.append(ch)
            pieces.append(s[i + 1])
            i += 2
            continue
        if ch == "{":
            depth += 1
            pieces.append(ch)
        elif ch == "}":
            if depth == 0:
                return "".join(pieces), i
            depth -= 1
            pieces.append(ch)
        else:
            pieces.append(ch)
        i += 1

    return None, -1
|
||||
|
||||
|
||||
def remove_command(s: str, command: str, keep_inside: bool = False) -> str:
    """
    Strip every occurrence of a LaTeX *command* from *s*.

    Args:
        s: Input string
        command: Command to remove (e.g., "\\textbf")
        keep_inside: If True, keep the text inside the command's braces

    Returns:
        String with all occurrences of the command removed
    """
    # Iterative rewrite of the original tail recursion: repeat until the
    # command no longer occurs anywhere in the string.
    while True:
        pos = s.find(command)
        if pos < 0:
            return s

        after = pos + len(command)

        if after < len(s) and s[after] == "{":
            # Walk to the closing brace matching the command's argument.
            depth = 0
            close = after
            while close < len(s):
                if s[close] == "{":
                    depth += 1
                elif s[close] == "}":
                    depth -= 1
                    if depth == 0:
                        break
                close += 1

            if keep_inside:
                s = s[:pos] + s[after + 1 : close] + s[close + 1 :]
            else:
                s = s[:pos] + s[close + 1 :]
        else:
            # Bare command with no braced argument: drop just the command text.
            s = s[:pos] + s[after:]
|
||||
|
||||
|
||||
def extract_last_equal_content(s: str) -> str:
    """
    Return the content after the rightmost equality/comparison operator.

    Bug fix: the original loop kept the content after whichever operator came
    last in the operator *tuple*, not rightmost in the *string* (e.g. for
    "a \\le b = c" it returned "b = c" instead of "c").  This version compares
    positions so the rightmost operator in the string wins; when two operators
    start at the same index (e.g. "\\ge" inside "\\geq"), the longer one is
    preferred.

    Args:
        s: Input LaTeX string

    Returns:
        Stripped content after the rightmost operator, or the stripped input
        when no operator is present.
    """
    comparison_operators = ("=", "\\approx", "\\ge", "\\le", "\\geq", "\\leq", "<", ">")

    best_pos = -1
    best_sign = ""
    for sign in comparison_operators:
        idx = s.rfind(sign)
        if idx == -1:
            continue
        # Rightmost occurrence wins; on a positional tie, the longer operator
        # wins so "\ge" never shadows "\geq".
        if idx > best_pos or (idx == best_pos and len(sign) > len(best_sign)):
            best_pos = idx
            best_sign = sign

    if best_pos == -1:
        return s.strip()
    return s[best_pos + len(best_sign) :].strip()
|
||||
|
||||
|
||||
def convert_latex_fractions(latex_str: str) -> str:
    """
    Normalize \\frac arguments by bracing them.

    Turns shorthand forms such as ``\\frac12`` or ``\\frac\\alpha2`` into
    ``\\frac{1}{2}`` / ``\\frac{\\alpha}{2}``; already-braced arguments are
    left unchanged.
    """
    pattern = (
        r"\\frac((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))"
        r"((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))"
    )

    def ensure_braced(arg: str) -> str:
        # Leave already-braced arguments alone; wrap everything else.
        if arg.startswith("{") and arg.endswith("}"):
            return arg
        return f"{{{arg}}}"

    def rebuild(match) -> str:
        return rf"\frac{ensure_braced(match.group(1))}{ensure_braced(match.group(2))}"

    return re.sub(pattern, rebuild, latex_str)
|
||||
|
||||
|
||||
def convert_vec_syntax(text: str) -> str:
    """
    Normalize vector notation by bracing bare \\vec arguments.

    Turns ``\\vec x`` and ``\\vec\\mu`` into ``\\vec{x}`` / ``\\vec{\\mu}``
    (Greek letters included); already-braced arguments are left untouched.
    """
    return re.sub(r"\\vec(\s*)(\\?[a-zA-Zα-ωΑ-Ω]+)", r"\\vec{\2}", text)
|
||||
|
||||
|
||||
def first_preprocess(s: str, extract_box: bool = True) -> str:
    """
    First stage of LaTeX preprocessing.

    Extracts boxed content, removes outer braces, and extracts content after equals.

    Processing order matters: escaped set braces are rewritten first so they
    cannot confuse brace matching, then \\boxed is unwrapped, outer braces are
    stripped, a trailing "\\quad ..." clause is dropped, and finally only the
    content after the last comparison operator is kept (and de-braced again).

    Args:
        s: Input LaTeX string
        extract_box: Whether to extract content from \\boxed{}

    Returns:
        Preprocessed string (returned unchanged when brackets are unbalanced)
    """
    # Escaped set braces \{ \} become parentheses so they don't count as
    # grouping braces in the steps below.
    s = s.replace("\\{", "(")
    s = s.replace("\\}", ")")

    # Bail out early on unbalanced input — brace walking below would misfire.
    if not brackets_balanced(s):
        return s

    if extract_box:
        boxed_content = remove_command(s, "\\boxed", keep_inside=True)
    else:
        boxed_content = s

    # Remove overall braces
    def remove_overall_brace(text: str) -> Tuple[str, bool]:
        # Strip one level of outer "{...}" if it wraps (essentially) the whole
        # string; returns (new_text, changed).
        pos = find_first_unescaped_brace(text)
        if pos == -1:
            return text, False

        content, final = extract_bracket_content(text, pos)
        if content and (final == len(text) - 1 or "}" not in text[final + 1 :]):
            # Check if there's a command before the brace
            if pos > 0 and text[pos - 1] not in (" ", "\t", "\n"):
                return text, False
            return content, True
        return text, False

    # Remove outer braces iteratively (bounded to 10 layers as a safety cap)
    for _ in range(10):
        boxed_content, changed = remove_overall_brace(boxed_content)
        if not changed:
            break

    # Handle \\quad separator — keep only the part before it
    if "\\quad" in boxed_content:
        boxed_content = boxed_content.split("\\quad")[0]

    # Extract content after last equals sign
    last_equal_content = extract_last_equal_content(boxed_content)

    # Remove outer braces again (the equals-extraction may expose a new layer)
    for _ in range(10):
        last_equal_content, changed = remove_overall_brace(last_equal_content)
        if not changed:
            break

    return last_equal_content
|
||||
|
||||
|
||||
def second_preprocess(s: str) -> str:
    """
    Second stage of LaTeX preprocessing.

    Removes/modifies LaTeX commands and normalizes expressions.

    The table order below is load-bearing: kill/remove commands run before the
    plain string removals and replacements, and within remove_content longer
    tokens (e.g. "\\Bigr") are listed before their prefixes (e.g. "\\Big") so
    the prefix pass does not leave stray suffixes behind.

    Args:
        s: Input string from first preprocessing stage

    Returns:
        Normalized LaTeX string ready for conversion
    """
    # Commands to completely remove (including their content)
    kill_commands = ["\\begin", "\\end"]

    # Commands to remove but keep their content
    remove_commands = [
        "\\text",
        "\\mathbf",
        "\\mathrm",
        "\\pmb",
        "\\hat",
        "\\overline",
        "\\boldsymbol",
    ]

    # Content to remove entirely (delimiters, spacing, markdown residue)
    remove_content = [
        "\\,",
        "$",
        ",",
        "`",
        "latex",
        "\\left",
        "\\right",
        "\\Bigr",
        "\\Bigl",
        "\n",
        "\\]",
        "\\[",
        "\\Big",
        "\\bigl",
        "\\bigr",
        "\\biggl",
        "\\biggr",
        "\\displaystyle",
        "\\infty",
    ]

    # Content replacements.  The \bar{...} rewrites protect tokens that
    # latex2sympy would otherwise interpret specially (e.g. \times, \partial).
    # NOTE(review): ("I", "\\bar{I}") replaces EVERY capital I, including
    # inside multi-letter names — confirm this matches the upstream EED
    # implementation's intent.
    replace_content = [
        ("\\operatorname{asin}", "\\asin"),
        ("\\operatorname{sech}", "\\sech"),
        ("\\operatorname{acos}", "\\acos"),
        ("\\operatorname{sinh}", "\\sinh"),
        ("\\dfrac", "\\frac"),
        ("\\tfrac", "\\frac"),
        ("\\Exp", "\\exp"),
        ("\\times", "\\bar{times}"),
        ("\\partial", "\\bar{partial}"),
        ("\\perp", "\\bar{perp}"),
        ("\\epsilon", "\\varepsilon"),
        ("\\varOmega", "\\Omega"),
        ("I", "\\bar{I}"),
        ("_e", "_{e}"),
        ("e_", "\\bar{e}_"),
        ("E_", "\\bar{E}_"),
        ("\\pm", "+"),
        ("\\mp", "-"),
        ("{+}", "{p}"),
        ("{-}", "{m}"),
        ("_+", "_p"),
        ("_-", "_m"),
    ]

    # Apply transformations
    for command in kill_commands:
        s = remove_command(s, command, keep_inside=False)

    for command in remove_commands:
        s = remove_command(s, command, keep_inside=True)

    for content in remove_content:
        s = s.replace(content, "")

    for old, new in replace_content:
        s = s.replace(old, new)

    # Additional transformations
    s = convert_latex_fractions(s)
    s = convert_vec_syntax(s)

    # Remove trailing period
    if s and s[-1] == ".":
        s = s[:-1]

    return s
|
||||
|
||||
|
||||
class LaTeXNormalizationConfig:
    """Configuration for latex2sympy normalization.

    Field names mirror the keyword structure expected by
    latex2sympy2_extended; semantics of each flag are defined by that
    library.  "nits" is the upstream field name, not a typo here.
    """

    basic_latex: bool = True
    units: bool = False
    malformed_operators: bool = True
    nits: bool = True
    # "all" extracts every \boxed{} group (string-typed, unlike the flags).
    boxed: str = "all"
    equations: bool = False
|
||||
|
||||
|
||||
class LaTeXConversionConfig:
    """Configuration for latex2sympy conversion.

    Field names mirror the keyword structure expected by
    latex2sympy2_extended; semantics of each flag are defined by that
    library.
    """

    interpret_as_mixed_fractions: bool = False
    interpret_simple_eq_as_assignment: bool = False
    interpret_contains_as_eq: bool = True
    lowercase_symbols: bool = False
|
||||
|
||||
|
||||
def master_convert(s: str):
    """
    Convert a LaTeX string to a SymPy expression.

    Runs both preprocessing stages, then delegates to latex2sympy with this
    module's normalization and conversion configs.

    Args:
        s: LaTeX string to convert

    Returns:
        SymPy expression

    Raises:
        ImportError: When the optional EED dependencies are not installed.
        Exception: Whatever latex2sympy raises on unparseable input.
    """
    if not EED_AVAILABLE:
        raise ImportError("latex2sympy2_extended and sympy are required for EED scoring")

    cleaned = second_preprocess(first_preprocess(s))

    return latex2sympy(
        cleaned,
        normalization_config=LaTeXNormalizationConfig(),
        conversion_config=LaTeXConversionConfig(),
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SYMPY TO TREE CONVERSION
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def sympy_to_tree(expr) -> TreeNode:
    """
    Recursively convert a SymPy expression into a TreeNode tree.

    Labels are "<kind>_<payload>": numeric atoms and constants become
    "number_*", symbols "symbol_*", Add/Mul/Pow "operator_*", and function
    applications "function_*".

    Args:
        expr: SymPy expression

    Returns:
        TreeNode representing the expression

    Raises:
        ValueError: If the expression contains an unsupported node type.
    """
    numeric_atoms = (Integer, Pi, Exp1, Float, Rational, Infinity, NegativeInfinity)

    if isinstance(expr, numeric_atoms):
        return TreeNode(label=f"number_{expr}", children=[])

    if isinstance(expr, Symbol):
        return TreeNode(label=f"symbol_{expr}", children=[])

    if isinstance(expr, (Add, Mul, Pow)):
        return TreeNode(
            label=f"operator_{type(expr).__name__}",
            children=[sympy_to_tree(arg) for arg in expr.args],
        )

    if isinstance(expr, Function):
        return TreeNode(
            label=f"function_{expr.func.__name__}",
            children=[sympy_to_tree(arg) for arg in expr.args],
        )

    raise ValueError(f"Unsupported SymPy type: {type(expr).__name__}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SCORING FUNCTION
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def score_calc(tree_dist: float, tree_size: int) -> float:
    """
    Map a tree edit distance onto the 0-100 EED score.

    With r = tree_dist / tree_size and DISCOUNT_SLOPE = 0.6 this yields:
      - 100        for an exact match (distance 0)
      - 60 - 100r  for 0 < r < 0.6 (partial credit)
      - 0          for r >= 0.6 (too different)

    Args:
        tree_dist: Tree edit distance
        tree_size: Size of the ground-truth tree (must be non-zero)

    Returns:
        Score between 0 and 100
    """
    if tree_dist == 0:
        return 100.0

    # Linear falloff, clamped at zero once the relative distance reaches the
    # discount slope.
    return max(0, 100 * DISCOUNT_SLOPE - 100 * tree_dist / tree_size)
|
||||
|
||||
|
||||
def time_simplify(expr, timeout: int = SIMPLIFY_TIME_LIMIT):
    """
    Best-effort sympy.simplify that never raises.

    NOTE(review): *timeout* is currently unused — there is no actual time
    limit.  A multiprocessing-based guard would be needed for a true timeout.

    Args:
        expr: SymPy expression to simplify
        timeout: Intended time limit in seconds (currently ignored)

    Returns:
        The simplified expression, or the input unchanged on any error.
    """
    try:
        return simplify(expr)
    except Exception:
        # Simplification is opportunistic; fall back to the raw expression.
        return expr
|
||||
|
||||
|
||||
def time_equal(expr1, expr2, timeout: int = EQUALS_TIME_LIMIT) -> bool:
    """
    Check expression equality with error protection.

    Bug fix: SymPy's ``Expr.equals`` is tri-state — it returns ``None`` when
    equality cannot be decided — and the original propagated that ``None``
    despite the ``-> bool`` annotation.  Undecidable now maps to ``False``.

    NOTE(review): *timeout* is currently unused; a real time limit would need
    a separate process or signal-based guard.

    Args:
        expr1: First expression
        expr2: Second expression
        timeout: Intended time limit in seconds (currently ignored)

    Returns:
        True only if the expressions are provably equal; False on inequality,
        undecidability, or any error.
    """
    try:
        # equals() may return True, False, or None (undecidable).
        return bool(expr1.equals(expr2))
    except Exception:
        return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MAIN EED FUNCTION
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def compute_eed_score(
    answer_latex: str,
    test_latex: str,
    debug_mode: bool = False,
) -> Tuple[float, float, int, float]:
    """
    Compute the EED (Expression Edit Distance) Score between two LaTeX expressions.

    This function evaluates the similarity between a ground truth answer and
    a model-generated answer by:
    1. Converting LaTeX to SymPy expressions
    2. Simplifying and checking for equivalence
    3. Building expression trees
    4. Computing tree edit distance
    5. Converting distance to a 0-100 score

    Args:
        answer_latex: Ground truth answer in LaTeX format
        test_latex: Model-generated answer in LaTeX format
        debug_mode: If True, raise exceptions instead of returning defaults

    Returns:
        Tuple of (score, relative_distance, tree_size, distance):
        - score: EED score from 0-100 (100 = perfect match)
        - relative_distance: distance / tree_size (-1 if error)
        - tree_size: Size of ground truth tree (-1 if error)
        - distance: Raw tree edit distance (-1 if error)
    """
    if not EED_AVAILABLE:
        if debug_mode:
            raise ImportError("EED scoring requires latex2sympy2_extended and sympy")
        return 0, -1, -1, -1

    # Handle empty or invalid input
    if not test_latex:
        return 0, -1, -1, -1

    # Skip unsupported expressions (integrals, sums) — sympy_to_tree has no
    # node kinds for them, so scoring them would fail later anyway.
    if "\\int" in test_latex or "\\int" in answer_latex:
        return 0, -1, -1, -1
    if "\\sum" in test_latex or "\\sum" in answer_latex:
        return 0, -1, -1, -1

    # Quick check for exact string match (no trees built, so size is -1)
    if answer_latex == test_latex:
        return 100, 0.0, -1, 0

    # Skip if test is much longer than answer (likely wrong) — also bounds
    # the cost of the tree edit distance below.
    if len(test_latex) > 3 * len(answer_latex):
        return 0, -1, -1, -1

    # Convert LaTeX to SymPy
    try:
        answer_exp = master_convert(answer_latex)
        test_exp = master_convert(test_latex)
    except Exception as e:
        if debug_mode:
            raise ValueError(f"Failed to convert LaTeX: {e}")
        return 0, -1, -1, -1

    # Simplify and check equivalence
    try:
        # Assume all symbols are positive for simplification; posify returns
        # the rewritten expression plus a substitution map to undo it.
        answer_exp, rep1 = posify(answer_exp)
        answer_exp = time_simplify(answer_exp)

        test_exp, rep2 = posify(test_exp)
        test_exp = time_simplify(test_exp)

        # Restore original symbols
        answer_exp = answer_exp.subs(rep1)
        test_exp = test_exp.subs(rep2)

        # Check for equivalence: structural equality, or their expanded
        # difference simplifying to zero.
        zero_exp = time_simplify(expand(answer_exp - test_exp))

        if answer_exp == test_exp or zero_exp == 0:
            return 100, 0.0, 0, 0

        # Last resort: SymPy's (possibly expensive) .equals comparison.
        if time_equal(answer_exp, test_exp):
            return 100, 0.0, 0, 0

    except Exception as e:
        if debug_mode:
            raise ValueError(f"Failed during simplification: {e}")
        return 0, -1, -1, -1

    # Build expression trees
    try:
        tree_answer = sympy_to_tree(answer_exp)
        tree_test = sympy_to_tree(test_exp)
    except Exception as e:
        if debug_mode:
            raise ValueError(f"Failed to build expression tree: {e}")
        return 0, -1, -1, -1

    # Compute tree edit distance (test tree first: insert/remove costs are
    # oriented toward editing the test tree into the answer tree)
    try:
        distance = ext_distance(
            tree_test,
            tree_answer,
            get_children=lambda x: x.get_children(),
            single_insert_cost=insert_func,
            insert_cost=insert_tree_func,
            single_remove_cost=remove_func,
            remove_cost=remove_tree_func,
            update_cost_func=update_func,
        )
    except Exception as e:
        if debug_mode:
            raise ValueError(f"Failed to calculate distance: {e}")
        # The answer tree exists at this point, so report its size even on failure.
        tree_size = calc_tree_size(tree_answer)
        return 0, -1, tree_size, -1

    # Calculate final score (tree_size >= 1, so the division is safe)
    tree_size = calc_tree_size(tree_answer)
    rel_distance = distance / tree_size
    score = score_calc(distance, tree_size)

    return score, rel_distance, tree_size, distance
|
||||
|
||||
|
||||
def extract_boxed_content(latex_str: str) -> Optional[str]:
    """
    Return the content of the first \\boxed{} in *latex_str*, or None.

    The regex supports one level of nested braces inside the box
    (e.g. \\boxed{\\frac{a}{b}}); deeper nesting is not fully matched.
    """
    boxed = re.search(r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}", latex_str)
    return boxed.group(1) if boxed else None
|
||||
|
||||
|
||||
def extract_all_boxed(latex_str: str) -> List[str]:
    """
    Return the contents of every \\boxed{} occurrence in *latex_str*.

    Matches one level of nested braces per box; an input without any
    \\boxed{} yields an empty list.
    """
    return re.findall(r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}", latex_str)
|
||||
|
||||
711
environments/eval_environments/phybench_eval.py
Normal file
711
environments/eval_environments/phybench_eval.py
Normal file
|
|
@ -0,0 +1,711 @@
|
|||
"""
|
||||
PHYBench Evaluation Environment for Atropos.
|
||||
|
||||
This environment evaluates models on PHYBench - a benchmark for evaluating
|
||||
physical perception and reasoning capabilities in Large Language Models.
|
||||
|
||||
Dataset: Eureka-Lab/PHYBench
|
||||
Paper: https://arxiv.org/abs/2504.16074
|
||||
Website: https://www.phybench.cn/
|
||||
|
||||
PHYBench is a human-curated benchmark with 500 original physics problems spanning:
|
||||
- Mechanics (MECHANICS)
|
||||
- Electromagnetism (ELECTRICITY)
|
||||
- Thermodynamics (THERMODYNAMICS)
|
||||
- Optics (OPTICS)
|
||||
- Modern Physics (MODERN)
|
||||
- Advanced Physics (ADVANCED)
|
||||
|
||||
Key features:
|
||||
- Original problems to prevent data contamination
|
||||
- Symbolic expression answers in LaTeX format
|
||||
- Two evaluation metrics:
|
||||
1. Binary Accuracy: Exact match using SymPy equivalence
|
||||
2. EED Score: Expression Edit Distance for partial credit (0-100)
|
||||
|
||||
The EED Score provides:
|
||||
- 204% improved sample efficiency over binary scoring
|
||||
- Continuous scoring that captures partial correctness
|
||||
- Differentiation between minor coefficient errors and structural errors
|
||||
|
||||
Supports thinking mode with <think></think> tags for extended reasoning.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from eed_score import EED_AVAILABLE, compute_eed_score, extract_all_boxed
|
||||
from eval_helpers import (
|
||||
THINK_CONTENT_AFTER_PATTERN,
|
||||
create_system_content,
|
||||
extract_thinking_content,
|
||||
get_default_thinking_prompt,
|
||||
save_eval_results,
|
||||
validate_thinking_format,
|
||||
)
|
||||
from pydantic import Field
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
)
|
||||
|
||||
|
||||
# Physics domain tags in PHYBench
# NOTE(review): this constant is not referenced anywhere else in this module;
# it appears to document the expected values of the dataset's "tag" field
# (and thus the valid entries for the tags_filter config option) — confirm
# against the Eureka-Lab/PHYBench dataset before relying on it.
PHYBENCH_TAGS = [
    "MECHANICS",
    "ELECTRICITY",
    "THERMODYNAMICS",
    "OPTICS",
    "MODERN",
    "ADVANCED",
]
|
||||
|
||||
# Prompt template for PHYBench evaluation.
# Only {problem} is substituted via str.format(); the doubled braces in
# \\boxed{{}} survive formatting as literal \boxed{} for the model to copy.
PHYBENCH_PROMPT_TEMPLATE = """You are a physics expert. Please read the following question and provide a step-by-step solution.

Put your final answer, which must be a readable LaTeX formula, in a \\boxed{{}} environment.

Question: {problem}

Answer:"""
|
||||
|
||||
# Alternative prompt with more detailed instructions (selected via the
# use_detailed_prompt config flag). As above, {problem} is the only
# substitution; doubled braces become literal LaTeX braces after .format().
PHYBENCH_DETAILED_PROMPT_TEMPLATE = """Solve the following physics problem. Show your reasoning step by step.

Your final answer should be a single symbolic expression (e.g., $\\sqrt{{\\frac{{2g}}{{3R}}}}$).
- Equivalent forms are accepted
- No numerical approximations
- No equation chains

Put your final answer in \\boxed{{}} format.
For example: \\boxed{{2mg + \\frac{{4mv_0^2}}{{l}}}}

Problem:
{problem}

Solution:"""
|
||||
|
||||
|
||||
class PHYBenchEvalConfig(BaseEnvConfig):
    """Configuration for PHYBench evaluation environment.

    Extends BaseEnvConfig with dataset selection, generation parameters,
    thinking-mode and prompt options, scoring toggles, and retry/debug
    settings. The plain (non-Field) attributes at the bottom override
    BaseEnvConfig defaults so the environment runs as a one-shot eval.
    """

    # Dataset configuration
    dataset_name: str = Field(
        default="Eureka-Lab/PHYBench",
        description="HuggingFace dataset name",
    )
    eval_split: str = Field(
        default="train",
        description="Split to evaluate on (PHYBench only has train split)",
    )
    shuffle_seed: int = Field(
        default=42,
        description="Random seed for shuffling",
    )
    max_samples: Optional[int] = Field(
        default=None,
        description="Maximum number of samples to evaluate (None = all)",
    )
    tags_filter: Optional[List[str]] = Field(
        default=None,
        description="Filter to specific physics domains (e.g., ['MECHANICS', 'OPTICS'])",
    )

    # Generation parameters
    eval_temperature: float = Field(
        default=0.6,
        description="Temperature for evaluation generation",
    )
    eval_max_tokens: int = Field(
        default=0,
        description="Max tokens for evaluation (0 = use model default)",
    )

    # System prompt configuration
    custom_system_prompt: Optional[str] = Field(
        default=None,
        description="Optional custom system prompt",
    )

    # Thinking mode configuration
    thinking_mode: bool = Field(
        default=True,
        description="Whether to use thinking mode with <think></think> tags",
    )
    custom_thinking_prompt: Optional[str] = Field(
        default=None,
        description="Optional custom thinking prompt",
    )

    # Prompt configuration
    use_detailed_prompt: bool = Field(
        default=False,
        description="Use detailed prompt with more instructions",
    )

    # Scoring configuration
    compute_eed_score: bool = Field(
        default=True,
        description="Whether to compute EED Score (requires latex2sympy2_extended)",
    )

    # Retry and debug configuration
    max_retries: int = Field(
        default=3,
        description="Maximum retries for failed API calls",
    )
    retry_delay: float = Field(
        default=1.0,
        description="Delay between retries in seconds",
    )
    min_response_length: int = Field(
        default=1,
        description="Minimum response length to consider valid",
    )
    full_debug: bool = Field(
        default=False,
        description="Enable full debug output",
    )

    # Override defaults for eval-only environment
    # group_size=1: a single rollout per item (no RL-style grouping).
    group_size: int = 1
    max_num_workers: int = 1024
    max_eval_workers: int = 256
    max_num_workers_per_node: int = 128
    use_wandb: bool = True
    rollout_server_url: str = "http://localhost:8000"
    # total_steps=1 / steps_per_eval=1: run the evaluation exactly once.
    total_steps: int = 1
    wandb_name: str = "phybench_eval"
    steps_per_eval: int = 1
|
||||
|
||||
|
||||
class PHYBenchEvalEnv(BaseEnv):
    """
    PHYBench Evaluation Environment.

    Evaluates models on physics problems requiring symbolic expression answers.
    Uses both binary accuracy and EED Score for comprehensive evaluation.
    """

    # Registry name used by the BaseEnv CLI / tooling.
    name = "phybench_eval"

    def __init__(
        self,
        config: PHYBenchEvalConfig,
        server_configs: List[APIServerConfig],
        slurm_job_id: Optional[str] = None,
        testing: bool = False,
    ):
        """Create the environment.

        Args:
            config: Evaluation configuration (dataset, prompts, scoring).
            server_configs: Inference server configs; the first one is used
                for generation.
            slurm_job_id: Optional SLURM job id, forwarded to BaseEnv.
            testing: Test-mode flag, forwarded to BaseEnv.
        """
        super().__init__(config, server_configs, slurm_job_id, testing)
        self.config: PHYBenchEvalConfig = config
        # Populated lazily by _load_dataset() during setup().
        self.eval_items: List[Dict] = []
        self._dataset_loaded = False

        # Pre-compile regex patterns for answer extraction
        # NOTE(review): _boxed_pattern is never referenced in this module —
        # _extract_answer() uses extract_all_boxed() from eed_score instead.
        # Confirm it is unused before removing.
        self._boxed_pattern = re.compile(r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}")

        # Check EED availability
        if self.config.compute_eed_score and not EED_AVAILABLE:
            print(
                "Warning: EED Score requested but latex2sympy2_extended not available. "
                "Install with: pip install latex2sympy2_extended sympy"
            )
|
||||
    @classmethod
    def config_cls(cls) -> type:
        """Return the config class used to construct this environment."""
        return PHYBenchEvalConfig
|
||||
|
||||
    async def setup(self) -> None:
        """Initialize the environment and load the dataset.

        The dataset is loaded only on the first call (guarded by
        _dataset_loaded); afterwards a summary of the evaluation
        configuration is printed.
        """
        await super().setup()

        if not self._dataset_loaded:
            await self._load_dataset()

        print("\nPHYBench Evaluation Setup:")
        print(f" Dataset: {self.config.dataset_name}")
        print(f" Evaluation split: {self.config.eval_split}")
        print(f" Thinking mode: {self.config.thinking_mode}")
        print(f" EED Score enabled: {self.config.compute_eed_score and EED_AVAILABLE}")
        if self.config.thinking_mode:
            thinking_prompt = get_default_thinking_prompt(self.config.custom_thinking_prompt)
            # Thinking prompts can be long; show only an 80-char prefix.
            print(f" Thinking prompt: {thinking_prompt[:80]}...")
        if self.config.tags_filter:
            print(f" Tags filter: {self.config.tags_filter}")
        print(f" Loaded {len(self.eval_items)} evaluation items")
|
||||
|
||||
async def _load_dataset(self) -> None:
|
||||
"""Load and process the PHYBench dataset."""
|
||||
print(f"Loading PHYBench dataset: {self.config.dataset_name}...")
|
||||
|
||||
try:
|
||||
dataset = load_dataset(
|
||||
self.config.dataset_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error loading dataset: {e}")
|
||||
raise
|
||||
|
||||
if self.config.eval_split not in dataset:
|
||||
available_splits = list(dataset.keys())
|
||||
raise ValueError(
|
||||
f"Split '{self.config.eval_split}' not found. Available: {available_splits}"
|
||||
)
|
||||
|
||||
split_data = dataset[self.config.eval_split]
|
||||
|
||||
# Process items
|
||||
self.eval_items = []
|
||||
tag_counts: Dict[str, int] = {}
|
||||
|
||||
for item in split_data:
|
||||
problem_id = item.get("id", "")
|
||||
tag = item.get("tag", "UNKNOWN")
|
||||
content = item.get("content", "")
|
||||
solution = item.get("solution", "")
|
||||
answer = item.get("answer", "")
|
||||
|
||||
# Skip if no content or answer
|
||||
if not content or not answer:
|
||||
continue
|
||||
|
||||
# Apply tag filter if specified
|
||||
if self.config.tags_filter and tag not in self.config.tags_filter:
|
||||
continue
|
||||
|
||||
# Track tag distribution
|
||||
tag_counts[tag] = tag_counts.get(tag, 0) + 1
|
||||
|
||||
self.eval_items.append({
|
||||
"id": problem_id,
|
||||
"tag": tag,
|
||||
"content": content,
|
||||
"solution": solution,
|
||||
"answer": answer,
|
||||
})
|
||||
|
||||
# Shuffle with seed for reproducibility
|
||||
random.seed(self.config.shuffle_seed)
|
||||
random.shuffle(self.eval_items)
|
||||
|
||||
# Apply max_samples limit if specified
|
||||
if self.config.max_samples and len(self.eval_items) > self.config.max_samples:
|
||||
self.eval_items = self.eval_items[: self.config.max_samples]
|
||||
|
||||
self._dataset_loaded = True
|
||||
|
||||
# Print tag distribution
|
||||
print(f"Loaded {len(self.eval_items)} items")
|
||||
print("Tag distribution:")
|
||||
for tag, count in sorted(tag_counts.items()):
|
||||
print(f" {tag}: {count}")
|
||||
|
||||
def _format_prompt(self, item: Dict) -> str:
|
||||
"""Format the problem into a prompt."""
|
||||
if self.config.use_detailed_prompt:
|
||||
return PHYBENCH_DETAILED_PROMPT_TEMPLATE.format(problem=item["content"])
|
||||
return PHYBENCH_PROMPT_TEMPLATE.format(problem=item["content"])
|
||||
|
||||
def _create_system_content(self) -> str:
|
||||
"""Create system message content based on thinking mode."""
|
||||
return (
|
||||
create_system_content(
|
||||
self.config.thinking_mode,
|
||||
self.config.custom_thinking_prompt,
|
||||
self.config.custom_system_prompt,
|
||||
)
|
||||
or ""
|
||||
)
|
||||
|
||||
def _extract_answer(self, response: str, debug: bool = False) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
Extract the answer from the model's response.
|
||||
|
||||
Looks for \\boxed{} content. If multiple found, uses the last one.
|
||||
|
||||
Args:
|
||||
response: Model's response text
|
||||
debug: Whether to print debug info
|
||||
|
||||
Returns:
|
||||
Tuple of (extracted_answer, extraction_method)
|
||||
"""
|
||||
if not response:
|
||||
return None, "empty_response"
|
||||
|
||||
# Find all boxed answers
|
||||
boxed_answers = extract_all_boxed(response)
|
||||
|
||||
if not boxed_answers:
|
||||
if debug:
|
||||
print(" No \\boxed{} found in response")
|
||||
return None, "no_boxed"
|
||||
|
||||
if len(boxed_answers) > 1:
|
||||
if debug:
|
||||
print(f" Multiple \\boxed{{}} found ({len(boxed_answers)}), using last one")
|
||||
return boxed_answers[-1], "boxed_last"
|
||||
|
||||
return boxed_answers[0], "boxed"
|
||||
|
||||
def _check_equivalence(
|
||||
self,
|
||||
predicted: str,
|
||||
gold: str,
|
||||
debug: bool = False,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Check if predicted answer is equivalent to gold answer.
|
||||
|
||||
Uses SymPy for symbolic equivalence checking.
|
||||
|
||||
Args:
|
||||
predicted: Predicted answer in LaTeX
|
||||
gold: Gold answer in LaTeX
|
||||
debug: Whether to print debug info
|
||||
|
||||
Returns:
|
||||
Tuple of (is_correct, method)
|
||||
"""
|
||||
if not predicted:
|
||||
return False, "empty_prediction"
|
||||
|
||||
# Clean up the answers
|
||||
pred_clean = predicted.strip()
|
||||
gold_clean = gold.strip()
|
||||
|
||||
# Exact string match
|
||||
if pred_clean == gold_clean:
|
||||
return True, "exact_match"
|
||||
|
||||
# Try EED Score - if score is 100, they're equivalent
|
||||
if self.config.compute_eed_score and EED_AVAILABLE:
|
||||
try:
|
||||
score, _, _, _ = compute_eed_score(gold_clean, pred_clean, debug_mode=False)
|
||||
if score == 100:
|
||||
return True, "sympy_equivalent"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False, "not_equivalent"
|
||||
|
||||
def _compute_scores(
|
||||
self,
|
||||
predicted: str,
|
||||
gold: str,
|
||||
debug: bool = False,
|
||||
) -> Dict:
|
||||
"""
|
||||
Compute both accuracy and EED Score.
|
||||
|
||||
Args:
|
||||
predicted: Predicted answer
|
||||
gold: Gold answer
|
||||
debug: Whether to print debug info
|
||||
|
||||
Returns:
|
||||
Dictionary with scoring results
|
||||
"""
|
||||
result = {
|
||||
"is_correct": False,
|
||||
"match_method": "none",
|
||||
"eed_score": 0.0,
|
||||
"eed_rel_distance": -1,
|
||||
"eed_tree_size": -1,
|
||||
"eed_distance": -1,
|
||||
}
|
||||
|
||||
if not predicted:
|
||||
return result
|
||||
|
||||
# Check equivalence (for binary accuracy)
|
||||
is_correct, match_method = self._check_equivalence(predicted, gold, debug)
|
||||
result["is_correct"] = is_correct
|
||||
result["match_method"] = match_method
|
||||
|
||||
# Compute EED Score if enabled
|
||||
if self.config.compute_eed_score and EED_AVAILABLE:
|
||||
try:
|
||||
eed_score, rel_dist, tree_size, distance = compute_eed_score(
|
||||
gold, predicted, debug_mode=debug
|
||||
)
|
||||
result["eed_score"] = eed_score
|
||||
result["eed_rel_distance"] = rel_dist
|
||||
result["eed_tree_size"] = tree_size
|
||||
result["eed_distance"] = distance
|
||||
|
||||
# If EED score is 100, mark as correct
|
||||
if eed_score == 100 and not is_correct:
|
||||
result["is_correct"] = True
|
||||
result["match_method"] = "eed_equivalent"
|
||||
|
||||
except Exception as e:
|
||||
if debug:
|
||||
print(f" EED Score error: {e}")
|
||||
|
||||
return result
|
||||
|
||||
async def rollout_and_score_eval(
|
||||
self,
|
||||
item: Dict,
|
||||
server: APIServerConfig,
|
||||
) -> Optional[Dict]:
|
||||
"""Run evaluation on a single item and return the result."""
|
||||
prompt = self._format_prompt(item)
|
||||
system_content = self._create_system_content()
|
||||
|
||||
messages = []
|
||||
if system_content:
|
||||
messages.append({"role": "system", "content": system_content})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# Build API call parameters
|
||||
kwargs = {
|
||||
"model": server.model_name,
|
||||
"messages": messages,
|
||||
"temperature": self.config.eval_temperature,
|
||||
}
|
||||
if self.config.eval_max_tokens > 0:
|
||||
kwargs["max_tokens"] = self.config.eval_max_tokens
|
||||
|
||||
response_text = ""
|
||||
for attempt in range(self.config.max_retries):
|
||||
try:
|
||||
response = await self.server.chat_completion(**kwargs)
|
||||
response_text = response.choices[0].message.content or ""
|
||||
|
||||
if len(response_text) >= self.config.min_response_length:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
if self.config.full_debug:
|
||||
print(f" API error (attempt {attempt + 1}): {e}")
|
||||
if attempt < self.config.max_retries - 1:
|
||||
await asyncio.sleep(self.config.retry_delay)
|
||||
continue
|
||||
|
||||
if not response_text:
|
||||
return None
|
||||
|
||||
# Validate thinking format and extract content after </think>
|
||||
is_valid_format, content_for_extraction = validate_thinking_format(
|
||||
response_text, self.config.thinking_mode
|
||||
)
|
||||
|
||||
# Extract thinking content if present
|
||||
thinking_content = (
|
||||
extract_thinking_content(response_text) if self.config.thinking_mode else None
|
||||
)
|
||||
|
||||
# Get content for answer extraction
|
||||
if self.config.thinking_mode:
|
||||
match = THINK_CONTENT_AFTER_PATTERN.search(response_text)
|
||||
if match:
|
||||
answer_content = match.group(1)
|
||||
else:
|
||||
answer_content = response_text
|
||||
else:
|
||||
answer_content = response_text
|
||||
|
||||
# Extract answer
|
||||
extracted_answer, extraction_method = self._extract_answer(
|
||||
answer_content, debug=self.config.full_debug
|
||||
)
|
||||
|
||||
# Compute scores
|
||||
gold_answer = item["answer"]
|
||||
scores = self._compute_scores(extracted_answer, gold_answer, debug=self.config.full_debug)
|
||||
|
||||
if self.config.full_debug:
|
||||
status = "✓" if scores["is_correct"] else "✗"
|
||||
eed = scores["eed_score"]
|
||||
print(f" [{status}] {item['tag']}: EED={eed:.1f}, gold={gold_answer[:50]}...")
|
||||
|
||||
return {
|
||||
"item_id": item["id"],
|
||||
"tag": item["tag"],
|
||||
"content": item["content"][:200],
|
||||
"gold_answer": gold_answer,
|
||||
"extracted_answer": extracted_answer,
|
||||
"extraction_method": extraction_method,
|
||||
"is_correct": scores["is_correct"],
|
||||
"match_method": scores["match_method"],
|
||||
"eed_score": scores["eed_score"],
|
||||
"eed_rel_distance": scores["eed_rel_distance"],
|
||||
"eed_tree_size": scores["eed_tree_size"],
|
||||
"eed_distance": scores["eed_distance"],
|
||||
"format_valid": is_valid_format,
|
||||
"response": response_text,
|
||||
"response_length": len(response_text),
|
||||
"thinking_content": thinking_content,
|
||||
"has_thinking": thinking_content is not None,
|
||||
}
|
||||
|
||||
    async def evaluate(self, *args, **kwargs) -> Dict:
        """Run the full PHYBench evaluation.

        Generates and scores a response for every loaded item concurrently,
        then aggregates overall, per-tag, and extraction-method metrics,
        prints a summary, and optionally saves results to disk.

        Returns:
            Metrics dict with accuracy, avg_eed_score, per-tag breakdowns
            and extraction statistics; on total failure, an error dict with
            accuracy 0.0.
        """
        print(f"\n{'='*60}")
        print("Starting PHYBench Evaluation")
        print(f"{'='*60}")
        print(f" Total questions: {len(self.eval_items)}")
        print(f" Thinking mode: {self.config.thinking_mode}")
        print(f" EED Score: {self.config.compute_eed_score and EED_AVAILABLE}")
        print(f"{'='*60}\n")

        # Create evaluation tasks (all items use the first server config).
        async def eval_task(item):
            return await self.rollout_and_score_eval(item, self.server_configs[0])

        tasks = [eval_task(item) for item in self.eval_items]

        # Run with progress bar
        results = await tqdm_asyncio.gather(*tasks, desc="Evaluating PHYBench")

        # Filter out failed results (rollout_and_score_eval returns None on failure)
        valid_results = [r for r in results if r is not None]

        if not valid_results:
            print("Warning: No valid evaluation results obtained")
            return {"error": "No valid results", "accuracy": 0.0}

        # Calculate metrics
        total = len(valid_results)
        correct = sum(1 for r in valid_results if r["is_correct"])
        accuracy = correct / total if total > 0 else 0.0

        # Calculate average EED Score (negative scores mark "unavailable")
        eed_scores = [r["eed_score"] for r in valid_results if r["eed_score"] >= 0]
        avg_eed_score = sum(eed_scores) / len(eed_scores) if eed_scores else 0.0

        # Calculate per-tag metrics
        tag_metrics: Dict[str, Dict] = {}
        for r in valid_results:
            tag = r.get("tag", "UNKNOWN")
            if tag not in tag_metrics:
                tag_metrics[tag] = {"total": 0, "correct": 0, "eed_scores": []}
            tag_metrics[tag]["total"] += 1
            if r["is_correct"]:
                tag_metrics[tag]["correct"] += 1
            if r["eed_score"] >= 0:
                tag_metrics[tag]["eed_scores"].append(r["eed_score"])

        # Second pass: derive per-tag accuracy and mean EED score.
        for tag in tag_metrics:
            t_total = tag_metrics[tag]["total"]
            t_correct = tag_metrics[tag]["correct"]
            t_eed_scores = tag_metrics[tag]["eed_scores"]
            tag_metrics[tag]["accuracy"] = t_correct / t_total if t_total > 0 else 0.0
            tag_metrics[tag]["avg_eed_score"] = (
                sum(t_eed_scores) / len(t_eed_scores) if t_eed_scores else 0.0
            )

        # Calculate extraction method statistics
        extraction_methods: Dict[str, int] = {}
        for r in valid_results:
            method = r.get("extraction_method", "unknown")
            extraction_methods[method] = extraction_methods.get(method, 0) + 1

        # Format compliance and thinking utilization
        format_valid = sum(1 for r in valid_results if r.get("format_valid", True))
        has_thinking = sum(1 for r in valid_results if r.get("has_thinking", False))
        has_boxed = sum(1 for r in valid_results if r.get("extracted_answer") is not None)

        # Average response length
        response_lengths = [r.get("response_length", 0) for r in valid_results]
        avg_response_length = sum(response_lengths) / len(response_lengths) if response_lengths else 0

        metrics = {
            "accuracy": accuracy,
            "avg_eed_score": avg_eed_score,
            "total_evaluated": total,
            "total_correct": correct,
            "has_boxed_rate": has_boxed / total if total > 0 else 0.0,
            "format_compliance_rate": format_valid / total if total > 0 else 0.0,
            "thinking_utilization_rate": has_thinking / total if total > 0 else 0.0,
            "avg_response_length": avg_response_length,
            "tag_metrics": tag_metrics,
            "extraction_methods": extraction_methods,
        }

        # Print summary (total > 0 is guaranteed here: valid_results is non-empty)
        print(f"\n{'='*60}")
        print("PHYBench Evaluation Results")
        print(f"{'='*60}")
        print(f" Overall Accuracy: {accuracy:.2%} ({correct}/{total})")
        print(f" Average EED Score: {avg_eed_score:.1f}/100")
        print(f" Has \\boxed{{}} Rate: {has_boxed / total:.2%}")
        print(f" Avg Response Length: {avg_response_length:.0f} chars")
        if self.config.thinking_mode:
            print(f" Format Compliance: {format_valid / total:.2%}")
            print(f" Thinking Utilization: {has_thinking / total:.2%}")

        print("\n Per-Tag Breakdown:")
        for tag in sorted(tag_metrics.keys()):
            data = tag_metrics[tag]
            acc = data["accuracy"]
            eed = data["avg_eed_score"]
            cnt = data["total"]
            print(f" {tag}: Acc={acc:.2%}, EED={eed:.1f} ({cnt} items)")

        print("\n Extraction Methods:")
        # Most common extraction method first.
        for method, count in sorted(extraction_methods.items(), key=lambda x: -x[1]):
            print(f" {method}: {count}")

        print(f"{'='*60}\n")

        # Save results
        if self.config.data_dir_to_save_evals:
            self._save_results(metrics, valid_results)

        return metrics
|
||||
|
||||
def _save_results(self, metrics: Dict, results: List[Dict]) -> None:
|
||||
"""Save evaluation results to disk."""
|
||||
save_eval_results(self.config.data_dir_to_save_evals, metrics, results)
|
||||
|
||||
async def wandb_log(self, metrics: Dict, step: int = 0) -> None:
|
||||
"""Log metrics to Weights & Biases."""
|
||||
if not self.config.use_wandb:
|
||||
return
|
||||
|
||||
log_dict = {
|
||||
"phybench/accuracy": metrics.get("accuracy", 0),
|
||||
"phybench/avg_eed_score": metrics.get("avg_eed_score", 0),
|
||||
"phybench/total_evaluated": metrics.get("total_evaluated", 0),
|
||||
"phybench/has_boxed_rate": metrics.get("has_boxed_rate", 0),
|
||||
"phybench/format_compliance_rate": metrics.get("format_compliance_rate", 0),
|
||||
"phybench/thinking_utilization_rate": metrics.get("thinking_utilization_rate", 0),
|
||||
"phybench/avg_response_length": metrics.get("avg_response_length", 0),
|
||||
}
|
||||
|
||||
# Log per-tag metrics
|
||||
for tag, data in metrics.get("tag_metrics", {}).items():
|
||||
safe_tag = tag.lower()
|
||||
log_dict[f"phybench/accuracy_{safe_tag}"] = data.get("accuracy", 0)
|
||||
log_dict[f"phybench/eed_score_{safe_tag}"] = data.get("avg_eed_score", 0)
|
||||
|
||||
wandb.log(log_dict, step=step)
|
||||
|
||||
    # Required abstract method implementations
    # This environment is evaluation-only, so the training-oriented hooks
    # below are inert stubs that satisfy the abstract interface.

    async def get_next_item(self) -> Optional[Dict]:
        """Not used in evaluation mode; always returns None."""
        return None

    async def collect_trajectories(self, item) -> List:
        """Not used in evaluation mode; always returns an empty list."""
        return []

    async def score(self, rollout_group_data) -> Optional[List]:
        """Not used in evaluation mode; always returns None."""
        return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the evaluation through the class's CLI runner
    # (cli() is inherited — presumably from BaseEnv; confirm in atroposlib).
    PHYBenchEvalEnv.cli()
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue