add phybench eval

This commit is contained in:
teknium 2025-12-28 01:44:20 +00:00
parent d04f8c0ae7
commit 830a129655
2 changed files with 1681 additions and 0 deletions

View file

@ -0,0 +1,970 @@
"""
Expression Edit Distance (EED) Score Module for PHYBench Evaluation.
This module implements the EED Score metric for evaluating mathematical expressions,
as described in the PHYBench paper (https://arxiv.org/abs/2504.16074).
The EED Score measures similarity between model-generated and reference expressions
by computing tree edit distance over their SymPy expression trees. It provides:
- Continuous scoring (0-100) that captures partial correctness
- 204% improved sample efficiency over binary scoring
- Robust handling of equivalent mathematical forms
Key components:
- LaTeX preprocessing and normalization
- SymPy expression tree construction
- Extended Zhang-Shasha tree edit distance algorithm
- Configurable scoring with subtree discount
Dependencies:
- sympy: Symbolic mathematics
- latex2sympy2_extended: LaTeX to SymPy conversion
- numpy: Numerical operations
Based on the official PHYBench implementation:
https://github.com/phybench-official/phybench/tree/main/EED
"""
import re
from typing import List, Optional, Tuple
from numpy import ones, zeros
# Try to import required dependencies
try:
from latex2sympy2_extended import latex2sympy
from sympy import (
Add,
Float,
Function,
Integer,
Mul,
Pow,
Rational,
Symbol,
expand,
posify,
simplify,
)
from sympy.core.numbers import Exp1, Infinity, NegativeInfinity, Pi
EED_AVAILABLE = True
except ImportError:
EED_AVAILABLE = False
latex2sympy = None
# =============================================================================
# CONFIGURATION CONSTANTS
# =============================================================================
# Cost configuration for tree edit operations
# These can be modified if different node types should have different weights
INSERT_COST = {"number": 1, "symbol": 1, "operator": 1, "function": 1}
DELETE_COST = {"number": 1, "symbol": 1, "operator": 1, "function": 1}
UPDATE_COST = {"number": 1, "symbol": 1, "operator": 1, "function": 1}
# Cost of updating between different types (e.g., number -> symbol)
CHANGE_TYPE_COST = 1
# Subtree discount configuration
# Minimum size to trigger cluster discount for subtree operations
BAR_SIZE = 5
# Discount slope for subtree operations (0.6 means 40% discount)
DISCOUNT_SLOPE = 0.6
# Timeout limits (in seconds)
SIMPLIFY_TIME_LIMIT = 30
EQUALS_TIME_LIMIT = 10
# =============================================================================
# TREE NODE CLASS
# =============================================================================
class TreeNode:
"""
A node in the expression tree representation.
Attributes:
label: Node label (e.g., "number_2", "symbol_x", "operator_Add")
children: List of child TreeNode objects
subtree_size: Cached size of the subtree rooted at this node
"""
def __init__(self, label: str, children: Optional[List["TreeNode"]] = None):
self.label = label
self.children = children if children is not None else []
self.subtree_size = 0
def get_children(self) -> List["TreeNode"]:
"""Return the list of child nodes."""
return self.children
def __str__(self) -> str:
return self.label
# =============================================================================
# TREE EDIT DISTANCE FUNCTIONS
# =============================================================================
def calc_tree_size(node: TreeNode) -> int:
"""
Calculate the size of a subtree based on total insertion cost.
The size equals the sum of insertion costs of all nodes in the subtree.
Results are cached in node.subtree_size for efficiency.
Args:
node: Root node of the subtree
Returns:
Total size of the subtree
"""
# Get insertion cost for this node type
node_type = node.label.split("_")[0]
total = INSERT_COST.get(node_type, 1)
# Return cached value if available (for non-leaf nodes)
if node.children and node.subtree_size != 0:
return node.subtree_size
# Recursively calculate size of children
for child in node.children:
total += calc_tree_size(child)
# Cache the result
node.subtree_size = total
return total
def update_func(x: TreeNode, y: TreeNode) -> float:
"""
Calculate the cost of updating node x to node y.
Args:
x: Source node
y: Target node
Returns:
Update cost (0 if identical, type-specific cost if same type, else CHANGE_TYPE_COST)
"""
if x.label == y.label:
return 0
x_type = x.label.split("_")[0]
y_type = y.label.split("_")[0]
if x_type == y_type:
return UPDATE_COST.get(x_type, 1)
return CHANGE_TYPE_COST
def remove_func(x: TreeNode) -> float:
"""Calculate the cost of removing a single node."""
node_type = x.label.split("_")[0]
return DELETE_COST.get(node_type, 1)
def remove_tree_func(x: TreeNode) -> float:
"""
Calculate the cost of removing an entire subtree.
Applies discount for large subtrees (cluster discount).
Args:
x: Root of subtree to remove
Returns:
Removal cost with potential discount
"""
if not x.children:
return remove_func(x)
size = calc_tree_size(x)
# Apply discount for large subtrees
return min(size, DISCOUNT_SLOPE * (size - BAR_SIZE) + BAR_SIZE)
def insert_func(x: TreeNode) -> float:
"""Calculate the cost of inserting a single node."""
node_type = x.label.split("_")[0]
return INSERT_COST.get(node_type, 1)
def insert_tree_func(x: TreeNode) -> float:
"""Calculate the cost of inserting an entire subtree (same as removal)."""
return remove_tree_func(x)
# =============================================================================
# ANNOTATED TREE FOR ZHANG-SHASHA ALGORITHM
# =============================================================================
class AnnotatedTree:
"""
Annotated tree structure for the Zhang-Shasha algorithm.
Computes post-order enumeration, left-most descendants, and keyroots
needed for efficient tree edit distance calculation.
"""
def __init__(self, root: TreeNode, get_children):
self.get_children = get_children
self.root = root
self.nodes = [] # Post-order enumeration of nodes
self.ids = [] # Matching list of IDs
self.lmds = [] # Left-most descendants
self.keyroots = None
# Build the annotated structure
import collections
stack = [(root, collections.deque())]
pstack = []
j = 0
while stack:
n, anc = stack.pop()
nid = j
for c in self.get_children(n):
a = collections.deque(anc)
a.appendleft(nid)
stack.append((c, a))
pstack.append(((n, nid), anc))
j += 1
lmds = {}
keyroots = {}
i = 0
while pstack:
(n, nid), anc = pstack.pop()
self.nodes.append(n)
self.ids.append(nid)
if not self.get_children(n):
lmd = i
for a in anc:
if a not in lmds:
lmds[a] = i
else:
break
else:
lmd = lmds.get(nid, i)
self.lmds.append(lmd)
keyroots[lmd] = i
i += 1
self.keyroots = sorted(keyroots.values())
def ext_distance(
a_root: TreeNode,
b_root: TreeNode,
get_children,
single_insert_cost,
insert_cost,
single_remove_cost,
remove_cost,
update_cost_func,
) -> float:
"""
Compute extended tree edit distance using modified Zhang-Shasha algorithm.
This implementation extends the standard algorithm with subtree insertion
and deletion operations for handling clustered changes.
Args:
a_root: Root of first tree
b_root: Root of second tree
get_children: Function to get children of a node
single_insert_cost: Cost function for single node insertion
insert_cost: Cost function for subtree insertion
single_remove_cost: Cost function for single node removal
remove_cost: Cost function for subtree removal
update_cost_func: Cost function for updating a node
Returns:
Tree edit distance between the two trees
"""
a_tree = AnnotatedTree(a_root, get_children)
b_tree = AnnotatedTree(b_root, get_children)
size_a = len(a_tree.nodes)
size_b = len(b_tree.nodes)
treedists = zeros((size_a, size_b), float)
fd = 1000 * ones((size_a + 1, size_b + 1), float)
def treedist(x: int, y: int):
al = a_tree.lmds
bl = b_tree.lmds
an = a_tree.nodes
bn = b_tree.nodes
fd[al[x]][bl[y]] = 0
for i in range(al[x], x + 1):
node = an[i]
fd[i + 1][bl[y]] = fd[al[i]][bl[y]] + remove_cost(node)
for j in range(bl[y], y + 1):
node = bn[j]
fd[al[x]][j + 1] = fd[al[x]][bl[j]] + insert_cost(node)
for i in range(al[x], x + 1):
for j in range(bl[y], y + 1):
node1 = an[i]
node2 = bn[j]
costs = [
fd[i][j + 1] + single_remove_cost(node1),
fd[i + 1][j] + single_insert_cost(node2),
fd[al[i]][j + 1] + remove_cost(node1),
fd[i + 1][bl[j]] + insert_cost(node2),
]
m = min(costs)
if al[x] == al[i] and bl[y] == bl[j]:
treedists[i][j] = min(m, fd[i][j] + update_cost_func(node1, node2))
fd[i + 1][j + 1] = treedists[i][j]
else:
fd[i + 1][j + 1] = min(m, fd[al[i]][bl[j]] + treedists[i][j])
for x in a_tree.keyroots:
for y in b_tree.keyroots:
treedist(x, y)
return treedists[-1][-1]
# =============================================================================
# LATEX PREPROCESSING
# =============================================================================
def brackets_balanced(s: str) -> bool:
"""Check if brackets in a string are balanced."""
stack = []
bracket_pairs = {")": "(", "]": "[", "}": "{"}
for char in s:
if char in bracket_pairs.values():
stack.append(char)
elif char in bracket_pairs:
if not stack or stack[-1] != bracket_pairs[char]:
return False
stack.pop()
return len(stack) == 0
def find_first_unescaped_brace(s: str) -> int:
"""Find the position of the first unescaped opening brace."""
escaped = False
for i, c in enumerate(s):
if c == "\\" and not escaped:
escaped = True
continue
if c == "{" and not escaped:
return i
escaped = False
return -1
def extract_bracket_content(s: str, bracket_position: int) -> Tuple[Optional[str], int]:
"""Extract content inside braces starting at given position."""
brace_start = bracket_position + 1
brace_depth = 0
content = []
escaped = False
for i in range(brace_start, len(s)):
char = s[i]
if escaped:
content.append(char)
escaped = False
continue
if char == "\\":
escaped = True
content.append(char)
continue
if char == "{":
brace_depth += 1
content.append(char)
elif char == "}":
if brace_depth == 0:
return "".join(content), i
brace_depth -= 1
content.append(char)
else:
content.append(char)
return None, -1
def remove_command(s: str, command: str, keep_inside: bool = False) -> str:
"""
Remove all occurrences of a LaTeX command from a string.
Args:
s: Input string
command: Command to remove (e.g., "\\textbf")
keep_inside: If True, preserve content inside braces
Returns:
String with command removed
"""
pos = s.find(command)
if pos < 0:
return s
end_index = pos + len(command)
level = 0
if end_index < len(s) and s[end_index] == "{":
while end_index < len(s):
if s[end_index] == "{":
level += 1
elif s[end_index] == "}":
level -= 1
if level == 0:
break
end_index += 1
if keep_inside:
s1 = s[:pos] + s[pos + len(command) + 1 : end_index] + s[end_index + 1 :]
else:
s1 = s[:pos] + s[end_index + 1 :]
else:
s1 = s[:pos] + s[end_index:]
if command not in s1:
return s1
return remove_command(s1, command, keep_inside)
def extract_last_equal_content(s: str) -> str:
"""Extract content after the last equality/comparison operator."""
comparison_operators = ("=", "\\approx", "\\ge", "\\le", "\\geq", "\\leq", "<", ">")
content = s
for sign in comparison_operators:
if sign in s:
rfind_index = s.rfind(sign)
if rfind_index != -1:
content = s[rfind_index + len(sign) :]
return content.strip()
def convert_latex_fractions(latex_str: str) -> str:
"""Convert non-standard fractions like \\frac\\alpha2 to \\frac{\\alpha}{2}."""
pattern = (
r"\\frac((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))"
r"((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))"
)
def replacer(match):
numerator, denominator = match.group(1), match.group(2)
wrap_num = (
f"{{{numerator}}}"
if not (numerator.startswith("{") and numerator.endswith("}"))
else numerator
)
wrap_den = (
f"{{{denominator}}}"
if not (denominator.startswith("{") and denominator.endswith("}"))
else denominator
)
return rf"\frac{wrap_num}{wrap_den}"
return re.sub(pattern, replacer, latex_str)
def convert_vec_syntax(text: str) -> str:
"""Convert \\vec x to \\vec{x}."""
pattern = r"\\vec(\s*)(\\?[a-zA-Zα-ωΑ-Ω]+)"
replacement = r"\\vec{\2}"
return re.sub(pattern, replacement, text)
def first_preprocess(s: str, extract_box: bool = True) -> str:
"""
First stage of LaTeX preprocessing.
Extracts boxed content, removes outer braces, and extracts content after equals.
Args:
s: Input LaTeX string
extract_box: Whether to extract content from \\boxed{}
Returns:
Preprocessed string
"""
s = s.replace("\\{", "(")
s = s.replace("\\}", ")")
if not brackets_balanced(s):
return s
if extract_box:
boxed_content = remove_command(s, "\\boxed", keep_inside=True)
else:
boxed_content = s
# Remove overall braces
def remove_overall_brace(text: str) -> Tuple[str, bool]:
pos = find_first_unescaped_brace(text)
if pos == -1:
return text, False
content, final = extract_bracket_content(text, pos)
if content and (final == len(text) - 1 or "}" not in text[final + 1 :]):
# Check if there's a command before the brace
if pos > 0 and text[pos - 1] not in (" ", "\t", "\n"):
return text, False
return content, True
return text, False
# Remove outer braces iteratively
for _ in range(10):
boxed_content, changed = remove_overall_brace(boxed_content)
if not changed:
break
# Handle \\quad separator
if "\\quad" in boxed_content:
boxed_content = boxed_content.split("\\quad")[0]
# Extract content after last equals sign
last_equal_content = extract_last_equal_content(boxed_content)
# Remove outer braces again
for _ in range(10):
last_equal_content, changed = remove_overall_brace(last_equal_content)
if not changed:
break
return last_equal_content
def second_preprocess(s: str) -> str:
"""
Second stage of LaTeX preprocessing.
Removes/modifies LaTeX commands and normalizes expressions.
Args:
s: Input string from first preprocessing stage
Returns:
Normalized LaTeX string ready for conversion
"""
# Commands to completely remove (including their content)
kill_commands = ["\\begin", "\\end"]
# Commands to remove but keep their content
remove_commands = [
"\\text",
"\\mathbf",
"\\mathrm",
"\\pmb",
"\\hat",
"\\overline",
"\\boldsymbol",
]
# Content to remove entirely
remove_content = [
"\\,",
"$",
",",
"`",
"latex",
"\\left",
"\\right",
"\\Bigr",
"\\Bigl",
"\n",
"\\]",
"\\[",
"\\Big",
"\\bigl",
"\\bigr",
"\\biggl",
"\\biggr",
"\\displaystyle",
"\\infty",
]
# Content replacements
replace_content = [
("\\operatorname{asin}", "\\asin"),
("\\operatorname{sech}", "\\sech"),
("\\operatorname{acos}", "\\acos"),
("\\operatorname{sinh}", "\\sinh"),
("\\dfrac", "\\frac"),
("\\tfrac", "\\frac"),
("\\Exp", "\\exp"),
("\\times", "\\bar{times}"),
("\\partial", "\\bar{partial}"),
("\\perp", "\\bar{perp}"),
("\\epsilon", "\\varepsilon"),
("\\varOmega", "\\Omega"),
("I", "\\bar{I}"),
("_e", "_{e}"),
("e_", "\\bar{e}_"),
("E_", "\\bar{E}_"),
("\\pm", "+"),
("\\mp", "-"),
("{+}", "{p}"),
("{-}", "{m}"),
("_+", "_p"),
("_-", "_m"),
]
# Apply transformations
for command in kill_commands:
s = remove_command(s, command, keep_inside=False)
for command in remove_commands:
s = remove_command(s, command, keep_inside=True)
for content in remove_content:
s = s.replace(content, "")
for old, new in replace_content:
s = s.replace(old, new)
# Additional transformations
s = convert_latex_fractions(s)
s = convert_vec_syntax(s)
# Remove trailing period
if s and s[-1] == ".":
s = s[:-1]
return s
class LaTeXNormalizationConfig:
"""Configuration for latex2sympy normalization."""
basic_latex: bool = True
units: bool = False
malformed_operators: bool = True
nits: bool = True
boxed = "all"
equations: bool = False
class LaTeXConversionConfig:
"""Configuration for latex2sympy conversion."""
interpret_as_mixed_fractions: bool = False
interpret_simple_eq_as_assignment: bool = False
interpret_contains_as_eq: bool = True
lowercase_symbols: bool = False
def master_convert(s: str):
"""
Convert a LaTeX string to a SymPy expression.
This is the main conversion function that applies preprocessing
and uses latex2sympy for the actual conversion.
Args:
s: LaTeX string to convert
Returns:
SymPy expression
Raises:
Various exceptions if conversion fails
"""
if not EED_AVAILABLE:
raise ImportError("latex2sympy2_extended and sympy are required for EED scoring")
preprocessed_stage1 = first_preprocess(s)
preprocessed_stage2 = second_preprocess(preprocessed_stage1)
sym = latex2sympy(
preprocessed_stage2,
normalization_config=LaTeXNormalizationConfig(),
conversion_config=LaTeXConversionConfig(),
)
return sym
# =============================================================================
# SYMPY TO TREE CONVERSION
# =============================================================================
def sympy_to_tree(expr) -> TreeNode:
"""
Convert a SymPy expression to a tree structure.
Args:
expr: SymPy expression
Returns:
TreeNode representing the expression
Raises:
ValueError: If expression contains unsupported types
"""
# Numbers and constants
if isinstance(expr, (Integer, Pi, Exp1, Float, Rational, Infinity, NegativeInfinity)):
return TreeNode(label=f"number_{expr}", children=[])
# Symbols
if isinstance(expr, Symbol):
return TreeNode(label=f"symbol_{expr}", children=[])
# Binary operators (Add, Mul, Pow)
if isinstance(expr, (Add, Mul, Pow)):
op_name = type(expr).__name__
children = [sympy_to_tree(arg) for arg in expr.args]
return TreeNode(label=f"operator_{op_name}", children=children)
# Functions
if isinstance(expr, Function):
func_name = expr.func.__name__
children = [sympy_to_tree(arg) for arg in expr.args]
return TreeNode(label=f"function_{func_name}", children=children)
raise ValueError(f"Unsupported SymPy type: {type(expr).__name__}")
# =============================================================================
# SCORING FUNCTION
# =============================================================================
def score_calc(tree_dist: float, tree_size: int) -> float:
"""
Calculate EED score from tree distance and size.
The scoring function:
- 100 if distance is 0 (exact match)
- 60 - 100*r if 0 < r < 0.6 (partial credit)
- 0 if r >= 0.6 (too different)
where r = distance / tree_size
Args:
tree_dist: Tree edit distance
tree_size: Size of the ground truth tree
Returns:
Score between 0 and 100
"""
if tree_dist == 0:
return 100.0
return max(0, 100 * DISCOUNT_SLOPE - 100 * tree_dist / tree_size)
def time_simplify(expr, timeout: int = SIMPLIFY_TIME_LIMIT):
"""
Simplify expression with timeout protection.
Args:
expr: SymPy expression to simplify
timeout: Timeout in seconds
Returns:
Simplified expression, or original if timeout/error
"""
try:
# Note: For production use, consider using multiprocessing for true timeout
return simplify(expr)
except Exception:
return expr
def time_equal(expr1, expr2, timeout: int = EQUALS_TIME_LIMIT) -> bool:
"""
Check expression equality with timeout protection.
Args:
expr1: First expression
expr2: Second expression
timeout: Timeout in seconds
Returns:
True if expressions are equal, False otherwise
"""
try:
return expr1.equals(expr2)
except Exception:
return False
# =============================================================================
# MAIN EED FUNCTION
# =============================================================================
def compute_eed_score(
answer_latex: str,
test_latex: str,
debug_mode: bool = False,
) -> Tuple[float, float, int, float]:
"""
Compute the EED (Expression Edit Distance) Score between two LaTeX expressions.
This function evaluates the similarity between a ground truth answer and
a model-generated answer by:
1. Converting LaTeX to SymPy expressions
2. Simplifying and checking for equivalence
3. Building expression trees
4. Computing tree edit distance
5. Converting distance to a 0-100 score
Args:
answer_latex: Ground truth answer in LaTeX format
test_latex: Model-generated answer in LaTeX format
debug_mode: If True, raise exceptions instead of returning defaults
Returns:
Tuple of (score, relative_distance, tree_size, distance):
- score: EED score from 0-100 (100 = perfect match)
- relative_distance: distance / tree_size (-1 if error)
- tree_size: Size of ground truth tree (-1 if error)
- distance: Raw tree edit distance (-1 if error)
"""
if not EED_AVAILABLE:
if debug_mode:
raise ImportError("EED scoring requires latex2sympy2_extended and sympy")
return 0, -1, -1, -1
# Handle empty or invalid input
if not test_latex:
return 0, -1, -1, -1
# Skip unsupported expressions (integrals, sums)
if "\\int" in test_latex or "\\int" in answer_latex:
return 0, -1, -1, -1
if "\\sum" in test_latex or "\\sum" in answer_latex:
return 0, -1, -1, -1
# Quick check for exact string match
if answer_latex == test_latex:
return 100, 0.0, -1, 0
# Skip if test is much longer than answer (likely wrong)
if len(test_latex) > 3 * len(answer_latex):
return 0, -1, -1, -1
# Convert LaTeX to SymPy
try:
answer_exp = master_convert(answer_latex)
test_exp = master_convert(test_latex)
except Exception as e:
if debug_mode:
raise ValueError(f"Failed to convert LaTeX: {e}")
return 0, -1, -1, -1
# Simplify and check equivalence
try:
# Assume all symbols are positive for simplification
answer_exp, rep1 = posify(answer_exp)
answer_exp = time_simplify(answer_exp)
test_exp, rep2 = posify(test_exp)
test_exp = time_simplify(test_exp)
# Restore original symbols
answer_exp = answer_exp.subs(rep1)
test_exp = test_exp.subs(rep2)
# Check for equivalence
zero_exp = time_simplify(expand(answer_exp - test_exp))
if answer_exp == test_exp or zero_exp == 0:
return 100, 0.0, 0, 0
if time_equal(answer_exp, test_exp):
return 100, 0.0, 0, 0
except Exception as e:
if debug_mode:
raise ValueError(f"Failed during simplification: {e}")
return 0, -1, -1, -1
# Build expression trees
try:
tree_answer = sympy_to_tree(answer_exp)
tree_test = sympy_to_tree(test_exp)
except Exception as e:
if debug_mode:
raise ValueError(f"Failed to build expression tree: {e}")
return 0, -1, -1, -1
# Compute tree edit distance
try:
distance = ext_distance(
tree_test,
tree_answer,
get_children=lambda x: x.get_children(),
single_insert_cost=insert_func,
insert_cost=insert_tree_func,
single_remove_cost=remove_func,
remove_cost=remove_tree_func,
update_cost_func=update_func,
)
except Exception as e:
if debug_mode:
raise ValueError(f"Failed to calculate distance: {e}")
tree_size = calc_tree_size(tree_answer)
return 0, -1, tree_size, -1
# Calculate final score
tree_size = calc_tree_size(tree_answer)
rel_distance = distance / tree_size
score = score_calc(distance, tree_size)
return score, rel_distance, tree_size, distance
def extract_boxed_content(latex_str: str) -> Optional[str]:
"""
Extract content from \\boxed{} in a LaTeX string.
Args:
latex_str: LaTeX string potentially containing \\boxed{}
Returns:
Content inside \\boxed{}, or None if not found
"""
# Pattern to match \boxed{...} with nested braces
pattern = r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}"
match = re.search(pattern, latex_str)
if match:
return match.group(1)
return None
def extract_all_boxed(latex_str: str) -> List[str]:
"""
Extract all \\boxed{} contents from a LaTeX string.
Args:
latex_str: LaTeX string
Returns:
List of contents from all \\boxed{} occurrences
"""
pattern = r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}"
return re.findall(pattern, latex_str)