""" Expression Edit Distance (EED) Score Module for PHYBench Evaluation. This module implements the EED Score metric for evaluating mathematical expressions, as described in the PHYBench paper (https://arxiv.org/abs/2504.16074). The EED Score measures similarity between model-generated and reference expressions by computing tree edit distance over their SymPy expression trees. It provides: - Continuous scoring (0-100) that captures partial correctness - 204% improved sample efficiency over binary scoring - Robust handling of equivalent mathematical forms Key components: - LaTeX preprocessing and normalization - SymPy expression tree construction - Extended Zhang-Shasha tree edit distance algorithm - Configurable scoring with subtree discount Dependencies: - sympy: Symbolic mathematics - latex2sympy2_extended: LaTeX to SymPy conversion - numpy: Numerical operations Based on the official PHYBench implementation: https://github.com/phybench-official/phybench/tree/main/EED """ import re from typing import List, Optional, Tuple from numpy import ones, zeros # Try to import required dependencies try: from latex2sympy2_extended import latex2sympy from sympy import ( Add, Float, Function, Integer, Mul, Pow, Rational, Symbol, expand, posify, simplify, ) from sympy.core.numbers import Exp1, Infinity, NegativeInfinity, Pi EED_AVAILABLE = True except ImportError: EED_AVAILABLE = False latex2sympy = None # ============================================================================= # CONFIGURATION CONSTANTS # ============================================================================= # Cost configuration for tree edit operations # These can be modified if different node types should have different weights INSERT_COST = {"number": 1, "symbol": 1, "operator": 1, "function": 1} DELETE_COST = {"number": 1, "symbol": 1, "operator": 1, "function": 1} UPDATE_COST = {"number": 1, "symbol": 1, "operator": 1, "function": 1} # Cost of updating between different types (e.g., number -> symbol) CHANGE_TYPE_COST = 1 # Subtree discount configuration # Minimum size to trigger cluster discount for subtree operations BAR_SIZE = 5 # Discount slope for subtree operations (0.6 means 40% discount) DISCOUNT_SLOPE = 0.6 # Timeout limits (in seconds) SIMPLIFY_TIME_LIMIT = 30 EQUALS_TIME_LIMIT = 10 # ============================================================================= # TREE NODE CLASS # ============================================================================= class TreeNode: """ A node in the expression tree representation. Attributes: label: Node label (e.g., "number_2", "symbol_x", "operator_Add") children: List of child TreeNode objects subtree_size: Cached size of the subtree rooted at this node """ def __init__(self, label: str, children: Optional[List["TreeNode"]] = None): self.label = label self.children = children if children is not None else [] self.subtree_size = 0 def get_children(self) -> List["TreeNode"]: """Return the list of child nodes.""" return self.children def __str__(self) -> str: return self.label # ============================================================================= # TREE EDIT DISTANCE FUNCTIONS # ============================================================================= def calc_tree_size(node: TreeNode) -> int: """ Calculate the size of a subtree based on total insertion cost. The size equals the sum of insertion costs of all nodes in the subtree. Results are cached in node.subtree_size for efficiency. Args: node: Root node of the subtree Returns: Total size of the subtree """ # Get insertion cost for this node type node_type = node.label.split("_")[0] total = INSERT_COST.get(node_type, 1) # Return cached value if available (for non-leaf nodes) if node.children and node.subtree_size != 0: return node.subtree_size # Recursively calculate size of children for child in node.children: total += calc_tree_size(child) # Cache the result node.subtree_size = total return total def update_func(x: TreeNode, y: TreeNode) -> float: """ Calculate the cost of updating node x to node y. Args: x: Source node y: Target node Returns: Update cost (0 if identical, type-specific cost if same type, else CHANGE_TYPE_COST) """ if x.label == y.label: return 0 x_type = x.label.split("_")[0] y_type = y.label.split("_")[0] if x_type == y_type: return UPDATE_COST.get(x_type, 1) return CHANGE_TYPE_COST def remove_func(x: TreeNode) -> float: """Calculate the cost of removing a single node.""" node_type = x.label.split("_")[0] return DELETE_COST.get(node_type, 1) def remove_tree_func(x: TreeNode) -> float: """ Calculate the cost of removing an entire subtree. Applies discount for large subtrees (cluster discount). Args: x: Root of subtree to remove Returns: Removal cost with potential discount """ if not x.children: return remove_func(x) size = calc_tree_size(x) # Apply discount for large subtrees return min(size, DISCOUNT_SLOPE * (size - BAR_SIZE) + BAR_SIZE) def insert_func(x: TreeNode) -> float: """Calculate the cost of inserting a single node.""" node_type = x.label.split("_")[0] return INSERT_COST.get(node_type, 1) def insert_tree_func(x: TreeNode) -> float: """Calculate the cost of inserting an entire subtree (same as removal).""" return remove_tree_func(x) # ============================================================================= # ANNOTATED TREE FOR ZHANG-SHASHA ALGORITHM # ============================================================================= class AnnotatedTree: """ Annotated tree structure for the Zhang-Shasha algorithm. Computes post-order enumeration, left-most descendants, and keyroots needed for efficient tree edit distance calculation. """ def __init__(self, root: TreeNode, get_children): self.get_children = get_children self.root = root self.nodes = [] # Post-order enumeration of nodes self.ids = [] # Matching list of IDs self.lmds = [] # Left-most descendants self.keyroots = None # Build the annotated structure import collections stack = [(root, collections.deque())] pstack = [] j = 0 while stack: n, anc = stack.pop() nid = j for c in self.get_children(n): a = collections.deque(anc) a.appendleft(nid) stack.append((c, a)) pstack.append(((n, nid), anc)) j += 1 lmds = {} keyroots = {} i = 0 while pstack: (n, nid), anc = pstack.pop() self.nodes.append(n) self.ids.append(nid) if not self.get_children(n): lmd = i for a in anc: if a not in lmds: lmds[a] = i else: break else: lmd = lmds.get(nid, i) self.lmds.append(lmd) keyroots[lmd] = i i += 1 self.keyroots = sorted(keyroots.values()) def ext_distance( a_root: TreeNode, b_root: TreeNode, get_children, single_insert_cost, insert_cost, single_remove_cost, remove_cost, update_cost_func, ) -> float: """ Compute extended tree edit distance using modified Zhang-Shasha algorithm. This implementation extends the standard algorithm with subtree insertion and deletion operations for handling clustered changes. Args: a_root: Root of first tree b_root: Root of second tree get_children: Function to get children of a node single_insert_cost: Cost function for single node insertion insert_cost: Cost function for subtree insertion single_remove_cost: Cost function for single node removal remove_cost: Cost function for subtree removal update_cost_func: Cost function for updating a node Returns: Tree edit distance between the two trees """ a_tree = AnnotatedTree(a_root, get_children) b_tree = AnnotatedTree(b_root, get_children) size_a = len(a_tree.nodes) size_b = len(b_tree.nodes) treedists = zeros((size_a, size_b), float) fd = 1000 * ones((size_a + 1, size_b + 1), float) def treedist(x: int, y: int): al = a_tree.lmds bl = b_tree.lmds an = a_tree.nodes bn = b_tree.nodes fd[al[x]][bl[y]] = 0 for i in range(al[x], x + 1): node = an[i] fd[i + 1][bl[y]] = fd[al[i]][bl[y]] + remove_cost(node) for j in range(bl[y], y + 1): node = bn[j] fd[al[x]][j + 1] = fd[al[x]][bl[j]] + insert_cost(node) for i in range(al[x], x + 1): for j in range(bl[y], y + 1): node1 = an[i] node2 = bn[j] costs = [ fd[i][j + 1] + single_remove_cost(node1), fd[i + 1][j] + single_insert_cost(node2), fd[al[i]][j + 1] + remove_cost(node1), fd[i + 1][bl[j]] + insert_cost(node2), ] m = min(costs) if al[x] == al[i] and bl[y] == bl[j]: treedists[i][j] = min(m, fd[i][j] + update_cost_func(node1, node2)) fd[i + 1][j + 1] = treedists[i][j] else: fd[i + 1][j + 1] = min(m, fd[al[i]][bl[j]] + treedists[i][j]) for x in a_tree.keyroots: for y in b_tree.keyroots: treedist(x, y) return treedists[-1][-1] # ============================================================================= # LATEX PREPROCESSING # ============================================================================= def brackets_balanced(s: str) -> bool: """Check if brackets in a string are balanced.""" stack = [] bracket_pairs = {")": "(", "]": "[", "}": "{"} for char in s: if char in bracket_pairs.values(): stack.append(char) elif char in bracket_pairs: if not stack or stack[-1] != bracket_pairs[char]: return False stack.pop() return len(stack) == 0 def find_first_unescaped_brace(s: str) -> int: """Find the position of the first unescaped opening brace.""" escaped = False for i, c in enumerate(s): if c == "\\" and not escaped: escaped = True continue if c == "{" and not escaped: return i escaped = False return -1 def extract_bracket_content(s: str, bracket_position: int) -> Tuple[Optional[str], int]: """Extract content inside braces starting at given position.""" brace_start = bracket_position + 1 brace_depth = 0 content = [] escaped = False for i in range(brace_start, len(s)): char = s[i] if escaped: content.append(char) escaped = False continue if char == "\\": escaped = True content.append(char) continue if char == "{": brace_depth += 1 content.append(char) elif char == "}": if brace_depth == 0: return "".join(content), i brace_depth -= 1 content.append(char) else: content.append(char) return None, -1 def remove_command(s: str, command: str, keep_inside: bool = False) -> str: """ Remove all occurrences of a LaTeX command from a string. Args: s: Input string command: Command to remove (e.g., "\\textbf") keep_inside: If True, preserve content inside braces Returns: String with command removed """ pos = s.find(command) if pos < 0: return s end_index = pos + len(command) level = 0 if end_index < len(s) and s[end_index] == "{": while end_index < len(s): if s[end_index] == "{": level += 1 elif s[end_index] == "}": level -= 1 if level == 0: break end_index += 1 if keep_inside: s1 = s[:pos] + s[pos + len(command) + 1 : end_index] + s[end_index + 1 :] else: s1 = s[:pos] + s[end_index + 1 :] else: s1 = s[:pos] + s[end_index:] if command not in s1: return s1 return remove_command(s1, command, keep_inside) def extract_last_equal_content(s: str) -> str: """Extract content after the last equality/comparison operator.""" comparison_operators = ("=", "\\approx", "\\ge", "\\le", "\\geq", "\\leq", "<", ">") content = s for sign in comparison_operators: if sign in s: rfind_index = s.rfind(sign) if rfind_index != -1: content = s[rfind_index + len(sign) :] return content.strip() def convert_latex_fractions(latex_str: str) -> str: """Convert non-standard fractions like \\frac\\alpha2 to \\frac{\\alpha}{2}.""" pattern = ( r"\\frac((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))" r"((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))" ) def replacer(match): numerator, denominator = match.group(1), match.group(2) wrap_num = ( f"{{{numerator}}}" if not (numerator.startswith("{") and numerator.endswith("}")) else numerator ) wrap_den = ( f"{{{denominator}}}" if not (denominator.startswith("{") and denominator.endswith("}")) else denominator ) return rf"\frac{wrap_num}{wrap_den}" return re.sub(pattern, replacer, latex_str) def convert_vec_syntax(text: str) -> str: """Convert \\vec x to \\vec{x}.""" pattern = r"\\vec(\s*)(\\?[a-zA-Zα-ωΑ-Ω]+)" replacement = r"\\vec{\2}" return re.sub(pattern, replacement, text) def first_preprocess(s: str, extract_box: bool = True) -> str: """ First stage of LaTeX preprocessing. Extracts boxed content, removes outer braces, and extracts content after equals. Args: s: Input LaTeX string extract_box: Whether to extract content from \\boxed{} Returns: Preprocessed string """ s = s.replace("\\{", "(") s = s.replace("\\}", ")") if not brackets_balanced(s): return s if extract_box: boxed_content = remove_command(s, "\\boxed", keep_inside=True) else: boxed_content = s # Remove overall braces def remove_overall_brace(text: str) -> Tuple[str, bool]: pos = find_first_unescaped_brace(text) if pos == -1: return text, False content, final = extract_bracket_content(text, pos) if content and (final == len(text) - 1 or "}" not in text[final + 1 :]): # Check if there's a command before the brace if pos > 0 and text[pos - 1] not in (" ", "\t", "\n"): return text, False return content, True return text, False # Remove outer braces iteratively for _ in range(10): boxed_content, changed = remove_overall_brace(boxed_content) if not changed: break # Handle \\quad separator if "\\quad" in boxed_content: boxed_content = boxed_content.split("\\quad")[0] # Extract content after last equals sign last_equal_content = extract_last_equal_content(boxed_content) # Remove outer braces again for _ in range(10): last_equal_content, changed = remove_overall_brace(last_equal_content) if not changed: break return last_equal_content def second_preprocess(s: str) -> str: """ Second stage of LaTeX preprocessing. Removes/modifies LaTeX commands and normalizes expressions. Args: s: Input string from first preprocessing stage Returns: Normalized LaTeX string ready for conversion """ # Commands to completely remove (including their content) kill_commands = ["\\begin", "\\end"] # Commands to remove but keep their content remove_commands = [ "\\text", "\\mathbf", "\\mathrm", "\\pmb", "\\hat", "\\overline", "\\boldsymbol", ] # Content to remove entirely remove_content = [ "\\,", "$", ",", "`", "latex", "\\left", "\\right", "\\Bigr", "\\Bigl", "\n", "\\]", "\\[", "\\Big", "\\bigl", "\\bigr", "\\biggl", "\\biggr", "\\displaystyle", "\\infty", ] # Content replacements replace_content = [ ("\\operatorname{asin}", "\\asin"), ("\\operatorname{sech}", "\\sech"), ("\\operatorname{acos}", "\\acos"), ("\\operatorname{sinh}", "\\sinh"), ("\\dfrac", "\\frac"), ("\\tfrac", "\\frac"), ("\\Exp", "\\exp"), ("\\times", "\\bar{times}"), ("\\partial", "\\bar{partial}"), ("\\perp", "\\bar{perp}"), ("\\epsilon", "\\varepsilon"), ("\\varOmega", "\\Omega"), ("I", "\\bar{I}"), ("_e", "_{e}"), ("e_", "\\bar{e}_"), ("E_", "\\bar{E}_"), ("\\pm", "+"), ("\\mp", "-"), ("{+}", "{p}"), ("{-}", "{m}"), ("_+", "_p"), ("_-", "_m"), ] # Apply transformations for command in kill_commands: s = remove_command(s, command, keep_inside=False) for command in remove_commands: s = remove_command(s, command, keep_inside=True) for content in remove_content: s = s.replace(content, "") for old, new in replace_content: s = s.replace(old, new) # Additional transformations s = convert_latex_fractions(s) s = convert_vec_syntax(s) # Remove trailing period if s and s[-1] == ".": s = s[:-1] return s class LaTeXNormalizationConfig: """Configuration for latex2sympy normalization.""" basic_latex: bool = True units: bool = False malformed_operators: bool = True nits: bool = True boxed = "all" equations: bool = False class LaTeXConversionConfig: """Configuration for latex2sympy conversion.""" interpret_as_mixed_fractions: bool = False interpret_simple_eq_as_assignment: bool = False interpret_contains_as_eq: bool = True lowercase_symbols: bool = False def master_convert(s: str): """ Convert a LaTeX string to a SymPy expression. This is the main conversion function that applies preprocessing and uses latex2sympy for the actual conversion. Args: s: LaTeX string to convert Returns: SymPy expression Raises: Various exceptions if conversion fails """ if not EED_AVAILABLE: raise ImportError( "latex2sympy2_extended and sympy are required for EED scoring" ) preprocessed_stage1 = first_preprocess(s) preprocessed_stage2 = second_preprocess(preprocessed_stage1) sym = latex2sympy( preprocessed_stage2, normalization_config=LaTeXNormalizationConfig(), conversion_config=LaTeXConversionConfig(), ) return sym # ============================================================================= # SYMPY TO TREE CONVERSION # ============================================================================= def sympy_to_tree(expr) -> TreeNode: """ Convert a SymPy expression to a tree structure. Args: expr: SymPy expression Returns: TreeNode representing the expression Raises: ValueError: If expression contains unsupported types """ # Numbers and constants if isinstance( expr, (Integer, Pi, Exp1, Float, Rational, Infinity, NegativeInfinity) ): return TreeNode(label=f"number_{expr}", children=[]) # Symbols if isinstance(expr, Symbol): return TreeNode(label=f"symbol_{expr}", children=[]) # Binary operators (Add, Mul, Pow) if isinstance(expr, (Add, Mul, Pow)): op_name = type(expr).__name__ children = [sympy_to_tree(arg) for arg in expr.args] return TreeNode(label=f"operator_{op_name}", children=children) # Functions if isinstance(expr, Function): func_name = expr.func.__name__ children = [sympy_to_tree(arg) for arg in expr.args] return TreeNode(label=f"function_{func_name}", children=children) raise ValueError(f"Unsupported SymPy type: {type(expr).__name__}") # ============================================================================= # SCORING FUNCTION # ============================================================================= def score_calc(tree_dist: float, tree_size: int) -> float: """ Calculate EED score from tree distance and size. The scoring function: - 100 if distance is 0 (exact match) - 60 - 100*r if 0 < r < 0.6 (partial credit) - 0 if r >= 0.6 (too different) where r = distance / tree_size Args: tree_dist: Tree edit distance tree_size: Size of the ground truth tree Returns: Score between 0 and 100 """ if tree_dist == 0: return 100.0 return max(0, 100 * DISCOUNT_SLOPE - 100 * tree_dist / tree_size) def time_simplify(expr, timeout: int = SIMPLIFY_TIME_LIMIT): """ Simplify expression with timeout protection. Args: expr: SymPy expression to simplify timeout: Timeout in seconds Returns: Simplified expression, or original if timeout/error """ try: # Note: For production use, consider using multiprocessing for true timeout return simplify(expr) except Exception: return expr def time_equal(expr1, expr2, timeout: int = EQUALS_TIME_LIMIT) -> bool: """ Check expression equality with timeout protection. Args: expr1: First expression expr2: Second expression timeout: Timeout in seconds Returns: True if expressions are equal, False otherwise """ try: return expr1.equals(expr2) except Exception: return False # ============================================================================= # MAIN EED FUNCTION # ============================================================================= def compute_eed_score( answer_latex: str, test_latex: str, debug_mode: bool = False, ) -> Tuple[float, float, int, float]: """ Compute the EED (Expression Edit Distance) Score between two LaTeX expressions. This function evaluates the similarity between a ground truth answer and a model-generated answer by: 1. Converting LaTeX to SymPy expressions 2. Simplifying and checking for equivalence 3. Building expression trees 4. Computing tree edit distance 5. Converting distance to a 0-100 score Args: answer_latex: Ground truth answer in LaTeX format test_latex: Model-generated answer in LaTeX format debug_mode: If True, raise exceptions instead of returning defaults Returns: Tuple of (score, relative_distance, tree_size, distance): - score: EED score from 0-100 (100 = perfect match) - relative_distance: distance / tree_size (-1 if error) - tree_size: Size of ground truth tree (-1 if error) - distance: Raw tree edit distance (-1 if error) """ if not EED_AVAILABLE: if debug_mode: raise ImportError("EED scoring requires latex2sympy2_extended and sympy") return 0, -1, -1, -1 # Handle empty or invalid input if not test_latex: return 0, -1, -1, -1 # Skip unsupported expressions (integrals, sums) if "\\int" in test_latex or "\\int" in answer_latex: return 0, -1, -1, -1 if "\\sum" in test_latex or "\\sum" in answer_latex: return 0, -1, -1, -1 # Quick check for exact string match if answer_latex == test_latex: return 100, 0.0, -1, 0 # Skip if test is much longer than answer (likely wrong) if len(test_latex) > 3 * len(answer_latex): return 0, -1, -1, -1 # Convert LaTeX to SymPy try: answer_exp = master_convert(answer_latex) test_exp = master_convert(test_latex) except Exception as e: if debug_mode: raise ValueError(f"Failed to convert LaTeX: {e}") return 0, -1, -1, -1 # Simplify and check equivalence try: # Assume all symbols are positive for simplification answer_exp, rep1 = posify(answer_exp) answer_exp = time_simplify(answer_exp) test_exp, rep2 = posify(test_exp) test_exp = time_simplify(test_exp) # Restore original symbols answer_exp = answer_exp.subs(rep1) test_exp = test_exp.subs(rep2) # Check for equivalence zero_exp = time_simplify(expand(answer_exp - test_exp)) if answer_exp == test_exp or zero_exp == 0: return 100, 0.0, 0, 0 if time_equal(answer_exp, test_exp): return 100, 0.0, 0, 0 except Exception as e: if debug_mode: raise ValueError(f"Failed during simplification: {e}") return 0, -1, -1, -1 # Build expression trees try: tree_answer = sympy_to_tree(answer_exp) tree_test = sympy_to_tree(test_exp) except Exception as e: if debug_mode: raise ValueError(f"Failed to build expression tree: {e}") return 0, -1, -1, -1 # Compute tree edit distance try: distance = ext_distance( tree_test, tree_answer, get_children=lambda x: x.get_children(), single_insert_cost=insert_func, insert_cost=insert_tree_func, single_remove_cost=remove_func, remove_cost=remove_tree_func, update_cost_func=update_func, ) except Exception as e: if debug_mode: raise ValueError(f"Failed to calculate distance: {e}") tree_size = calc_tree_size(tree_answer) return 0, -1, tree_size, -1 # Calculate final score tree_size = calc_tree_size(tree_answer) rel_distance = distance / tree_size score = score_calc(distance, tree_size) return score, rel_distance, tree_size, distance def extract_boxed_content(latex_str: str) -> Optional[str]: """ Extract content from \\boxed{} in a LaTeX string. Args: latex_str: LaTeX string potentially containing \\boxed{} Returns: Content inside \\boxed{}, or None if not found """ # Pattern to match \boxed{...} with nested braces pattern = r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}" match = re.search(pattern, latex_str) if match: return match.group(1) return None def extract_all_boxed(latex_str: str) -> List[str]: """ Extract all \\boxed{} contents from a LaTeX string. Handles arbitrarily nested braces by counting brace depth. Args: latex_str: LaTeX string Returns: List of contents from all \\boxed{} occurrences """ results = [] i = 0 boxed_pattern = "\\boxed{" while i < len(latex_str): # Find next \boxed{ pos = latex_str.find(boxed_pattern, i) if pos == -1: break # Start after \boxed{ start = pos + len(boxed_pattern) depth = 1 j = start # Count braces to find matching closing brace while j < len(latex_str) and depth > 0: if latex_str[j] == "{": depth += 1 elif latex_str[j] == "}": depth -= 1 j += 1 if depth == 0: # Extract content between braces content = latex_str[start : j - 1].strip() results.append(content) i = j return results