add phybench eval

This commit is contained in:
teknium 2025-12-28 01:44:20 +00:00
parent d04f8c0ae7
commit 830a129655
2 changed files with 1681 additions and 0 deletions

View file

@ -0,0 +1,970 @@
"""
Expression Edit Distance (EED) Score Module for PHYBench Evaluation.
This module implements the EED Score metric for evaluating mathematical expressions,
as described in the PHYBench paper (https://arxiv.org/abs/2504.16074).
The EED Score measures similarity between model-generated and reference expressions
by computing tree edit distance over their SymPy expression trees. It provides:
- Continuous scoring (0-100) that captures partial correctness
- 204% improved sample efficiency over binary scoring
- Robust handling of equivalent mathematical forms
Key components:
- LaTeX preprocessing and normalization
- SymPy expression tree construction
- Extended Zhang-Shasha tree edit distance algorithm
- Configurable scoring with subtree discount
Dependencies:
- sympy: Symbolic mathematics
- latex2sympy2_extended: LaTeX to SymPy conversion
- numpy: Numerical operations
Based on the official PHYBench implementation:
https://github.com/phybench-official/phybench/tree/main/EED
"""
import re
from typing import List, Optional, Tuple
from numpy import ones, zeros
# Try to import required dependencies
try:
from latex2sympy2_extended import latex2sympy
from sympy import (
Add,
Float,
Function,
Integer,
Mul,
Pow,
Rational,
Symbol,
expand,
posify,
simplify,
)
from sympy.core.numbers import Exp1, Infinity, NegativeInfinity, Pi
EED_AVAILABLE = True
except ImportError:
EED_AVAILABLE = False
latex2sympy = None
# =============================================================================
# CONFIGURATION CONSTANTS
# =============================================================================
# Cost configuration for tree edit operations.
# Keys are node-type prefixes ("number", "symbol", "operator", "function");
# these can be modified if different node types should have different weights.
INSERT_COST = {"number": 1, "symbol": 1, "operator": 1, "function": 1}
DELETE_COST = {"number": 1, "symbol": 1, "operator": 1, "function": 1}
UPDATE_COST = {"number": 1, "symbol": 1, "operator": 1, "function": 1}
# Cost of updating between different types (e.g., number -> symbol)
CHANGE_TYPE_COST = 1
# Subtree discount configuration
# Minimum size to trigger cluster discount for subtree operations
BAR_SIZE = 5
# Discount slope for subtree operations (0.6 means 40% discount);
# also sets the score cutoff in score_calc (r >= DISCOUNT_SLOPE scores 0)
DISCOUNT_SLOPE = 0.6
# Timeout limits (in seconds).
# NOTE(review): currently advisory only — time_simplify/time_equal accept a
# timeout argument but do not enforce it (no subprocess/signal mechanism).
SIMPLIFY_TIME_LIMIT = 30
EQUALS_TIME_LIMIT = 10
# =============================================================================
# TREE NODE CLASS
# =============================================================================
class TreeNode:
    """
    Node in an expression tree.

    Attributes:
        label: Node label, a "<type>_<value>" string such as "number_2",
            "symbol_x", or "operator_Add".
        children: Child TreeNode objects (empty for leaves).
        subtree_size: Cached subtree size (0 until calc_tree_size runs).
    """

    def __init__(self, label: str, children: Optional[List["TreeNode"]] = None):
        self.label = label
        self.children = [] if children is None else children
        self.subtree_size = 0

    def get_children(self) -> List["TreeNode"]:
        """Return this node's children."""
        return self.children

    def __str__(self) -> str:
        return self.label
# =============================================================================
# TREE EDIT DISTANCE FUNCTIONS
# =============================================================================
def calc_tree_size(node: TreeNode) -> int:
    """
    Calculate the size of a subtree as the total insertion cost of its nodes.

    Results are cached in ``node.subtree_size``; repeated calls on a non-leaf
    node return the cached value without re-walking the subtree.

    Args:
        node: Root node of the subtree

    Returns:
        Total size of the subtree (>= 1)
    """
    # Return cached value first, before doing any work. (Cache is only
    # trusted for non-leaf nodes: a freshly built leaf also has
    # subtree_size == 0, which is indistinguishable from "not computed".)
    if node.children and node.subtree_size != 0:
        return node.subtree_size
    # Insertion cost for this node, keyed by its type prefix
    node_type = node.label.split("_")[0]
    total = INSERT_COST.get(node_type, 1)
    # Recursively add the sizes of all children
    for child in node.children:
        total += calc_tree_size(child)
    # Cache the result
    node.subtree_size = total
    return total
def update_func(x: TreeNode, y: TreeNode) -> float:
    """
    Cost of relabeling node x as node y.

    Returns:
        0 for identical labels; the type-specific UPDATE_COST when both
        labels share a type prefix; CHANGE_TYPE_COST otherwise.
    """
    if x.label == y.label:
        return 0
    x_type = x.label.split("_")[0]
    y_type = y.label.split("_")[0]
    if x_type != y_type:
        return CHANGE_TYPE_COST
    return UPDATE_COST.get(x_type, 1)
def remove_func(x: TreeNode) -> float:
    """Cost of deleting a single node, keyed by its type prefix."""
    return DELETE_COST.get(x.label.split("_")[0], 1)
def remove_tree_func(x: TreeNode) -> float:
    """
    Cost of deleting an entire subtree, with a cluster discount.

    A leaf costs the same as a single-node removal. A larger subtree pays
    full price up to BAR_SIZE and the discounted DISCOUNT_SLOPE rate for
    everything beyond that, capped at the undiscounted size.

    Args:
        x: Root of the subtree to remove

    Returns:
        (Possibly discounted) removal cost
    """
    if not x.children:
        return remove_func(x)
    size = calc_tree_size(x)
    discounted = DISCOUNT_SLOPE * (size - BAR_SIZE) + BAR_SIZE
    return min(size, discounted)
def insert_func(x: TreeNode) -> float:
    """Cost of inserting a single node, keyed by its type prefix."""
    return INSERT_COST.get(x.label.split("_")[0], 1)
def insert_tree_func(x: TreeNode) -> float:
    """Cost of inserting an entire subtree; symmetric with subtree removal."""
    return remove_tree_func(x)
# =============================================================================
# ANNOTATED TREE FOR ZHANG-SHASHA ALGORITHM
# =============================================================================
class AnnotatedTree:
    """
    Annotated tree structure for the Zhang-Shasha algorithm.

    Computes post-order enumeration, left-most descendants, and keyroots
    needed for efficient tree edit distance calculation.
    """

    def __init__(self, root: TreeNode, get_children):
        self.get_children = get_children
        self.root = root
        self.nodes = []  # Post-order enumeration of nodes
        self.ids = []  # Matching list of IDs
        self.lmds = []  # Left-most descendants
        self.keyroots = None
        # Build the annotated structure
        import collections

        # First pass: depth-first walk assigning IDs; each stack entry keeps
        # a deque of ancestor IDs (nearest ancestor first).
        stack = [(root, collections.deque())]
        pstack = []
        j = 0
        while stack:
            n, anc = stack.pop()
            nid = j
            for c in self.get_children(n):
                a = collections.deque(anc)
                a.appendleft(nid)
                stack.append((c, a))
            pstack.append(((n, nid), anc))
            j += 1
        lmds = {}  # first-pass ID -> post-order index of left-most descendant
        keyroots = {}  # left-most descendant index -> largest post-order index
        i = 0
        # Second pass: popping pstack yields nodes in post-order.
        while pstack:
            (n, nid), anc = pstack.pop()
            self.nodes.append(n)
            self.ids.append(nid)
            if not self.get_children(n):
                # Leaf: it is its own left-most descendant; propagate the
                # index up through ancestors that don't have one yet.
                lmd = i
                for a in anc:
                    if a not in lmds:
                        lmds[a] = i
                    else:
                        break
            else:
                lmd = lmds.get(nid, i)
            self.lmds.append(lmd)
            # The keyroot for a given left-most descendant is the deepest
            # (largest post-order index) node that shares it.
            keyroots[lmd] = i
            i += 1
        self.keyroots = sorted(keyroots.values())
def ext_distance(
    a_root: TreeNode,
    b_root: TreeNode,
    get_children,
    single_insert_cost,
    insert_cost,
    single_remove_cost,
    remove_cost,
    update_cost_func,
) -> float:
    """
    Compute extended tree edit distance using modified Zhang-Shasha algorithm.

    This implementation extends the standard algorithm with subtree insertion
    and deletion operations for handling clustered changes.

    Args:
        a_root: Root of first tree
        b_root: Root of second tree
        get_children: Function to get children of a node
        single_insert_cost: Cost function for single node insertion
        insert_cost: Cost function for subtree insertion
        single_remove_cost: Cost function for single node removal
        remove_cost: Cost function for subtree removal
        update_cost_func: Cost function for updating a node

    Returns:
        Tree edit distance between the two trees
    """
    a_tree = AnnotatedTree(a_root, get_children)
    b_tree = AnnotatedTree(b_root, get_children)
    size_a = len(a_tree.nodes)
    size_b = len(b_tree.nodes)
    # treedists[i][j]: distance between the subtrees rooted at post-order
    # nodes i (tree A) and j (tree B); filled in keyroot order below.
    treedists = zeros((size_a, size_b), float)
    # Forest-distance table, reused by every treedist() call; 1000 acts as
    # an "infinity" sentinel for cells not reached in a given sub-problem.
    fd = 1000 * ones((size_a + 1, size_b + 1), float)

    def treedist(x: int, y: int):
        # al/bl: left-most descendant indices; an/bn: post-order node lists.
        al = a_tree.lmds
        bl = b_tree.lmds
        an = a_tree.nodes
        bn = b_tree.nodes
        # Empty forest vs empty forest costs nothing.
        fd[al[x]][bl[y]] = 0
        # First column/row: delete (resp. insert) whole subtrees — this is
        # the extension over vanilla Zhang-Shasha, which charges node by
        # node here.
        for i in range(al[x], x + 1):
            node = an[i]
            fd[i + 1][bl[y]] = fd[al[i]][bl[y]] + remove_cost(node)
        for j in range(bl[y], y + 1):
            node = bn[j]
            fd[al[x]][j + 1] = fd[al[x]][bl[j]] + insert_cost(node)
        for i in range(al[x], x + 1):
            for j in range(bl[y], y + 1):
                node1 = an[i]
                node2 = bn[j]
                # Candidate operations: single-node remove/insert and
                # whole-subtree remove/insert (the latter with discount).
                costs = [
                    fd[i][j + 1] + single_remove_cost(node1),
                    fd[i + 1][j] + single_insert_cost(node2),
                    fd[al[i]][j + 1] + remove_cost(node1),
                    fd[i + 1][bl[j]] + insert_cost(node2),
                ]
                m = min(costs)
                if al[x] == al[i] and bl[y] == bl[j]:
                    # Both prefixes are complete subtrees: this cell is also
                    # a tree-distance entry, so a direct node update applies.
                    treedists[i][j] = min(m, fd[i][j] + update_cost_func(node1, node2))
                    fd[i + 1][j + 1] = treedists[i][j]
                else:
                    # General forests: splice in the previously computed
                    # subtree distance for the (i, j) pair.
                    fd[i + 1][j + 1] = min(m, fd[al[i]][bl[j]] + treedists[i][j])

    # Zhang-Shasha main loop: solve a sub-problem for every keyroot pair.
    for x in a_tree.keyroots:
        for y in b_tree.keyroots:
            treedist(x, y)
    # Distance between the full trees lands in the last cell.
    return treedists[-1][-1]
# =============================================================================
# LATEX PREPROCESSING
# =============================================================================
def brackets_balanced(s: str) -> bool:
    """Return True if all (), [], {} brackets in *s* are balanced."""
    closer_to_opener = {")": "(", "]": "[", "}": "{"}
    openers = set(closer_to_opener.values())
    pending = []
    for ch in s:
        if ch in openers:
            pending.append(ch)
        elif ch in closer_to_opener:
            # A closer must match the most recent unmatched opener.
            if not pending or pending.pop() != closer_to_opener[ch]:
                return False
    return not pending
def find_first_unescaped_brace(s: str) -> int:
    """Return the index of the first '{' not preceded by a backslash, or -1."""
    idx = 0
    length = len(s)
    while idx < length:
        ch = s[idx]
        if ch == "\\":
            # Skip the escaped character that follows the backslash.
            idx += 2
            continue
        if ch == "{":
            return idx
        idx += 1
    return -1
def extract_bracket_content(s: str, bracket_position: int) -> Tuple[Optional[str], int]:
    """
    Return (content, index) for the brace group opening at *bracket_position*.

    *content* is the text between the opening brace and its matching closing
    brace (escapes and nested braces preserved verbatim); *index* is the
    position of that closing brace. Returns (None, -1) when the group is
    never closed.
    """
    collected: List[str] = []
    depth = 0
    pending_escape = False
    for pos in range(bracket_position + 1, len(s)):
        ch = s[pos]
        if pending_escape:
            # Character after a backslash is taken literally.
            collected.append(ch)
            pending_escape = False
        elif ch == "\\":
            pending_escape = True
            collected.append(ch)
        elif ch == "{":
            depth += 1
            collected.append(ch)
        elif ch == "}":
            if depth == 0:
                return "".join(collected), pos
            depth -= 1
            collected.append(ch)
        else:
            collected.append(ch)
    return None, -1
def remove_command(s: str, command: str, keep_inside: bool = False) -> str:
    """
    Remove all occurrences of a LaTeX command from a string.

    Recurses until no occurrence of *command* remains. When the command is
    followed by a brace group, the braces are removed too; the group's
    content survives only if *keep_inside* is True.

    Args:
        s: Input string
        command: Command to remove (e.g., "\\textbf")
        keep_inside: If True, preserve content inside braces

    Returns:
        String with command removed
    """
    pos = s.find(command)
    if pos < 0:
        return s
    end_index = pos + len(command)
    level = 0
    if end_index < len(s) and s[end_index] == "{":
        # Scan forward to the matching closing brace of the argument group.
        while end_index < len(s):
            if s[end_index] == "{":
                level += 1
            elif s[end_index] == "}":
                level -= 1
            if level == 0:
                break
            end_index += 1
        if keep_inside:
            # Drop the command and its braces but keep the inner content.
            s1 = s[:pos] + s[pos + len(command) + 1 : end_index] + s[end_index + 1 :]
        else:
            # Drop the command together with its whole brace group.
            s1 = s[:pos] + s[end_index + 1 :]
    else:
        # Bare command with no brace group: drop just the command token.
        s1 = s[:pos] + s[end_index:]
    if command not in s1:
        return s1
    # Recurse to strip any remaining occurrences.
    return remove_command(s1, command, keep_inside)
def extract_last_equal_content(s: str) -> str:
    """
    Extract the content after the last equality/comparison operator.

    Finds the rightmost occurrence (by string position) of any comparison
    operator and returns the stripped text following it; returns the whole
    string stripped when no operator is present. On ties at the same
    position, the longer operator wins so that e.g. "\\geq" is not treated
    as "\\ge" followed by a literal "q".

    Args:
        s: Input string (e.g., "E = mc^2")

    Returns:
        Stripped content after the last comparison operator
    """
    comparison_operators = ("=", "\\approx", "\\ge", "\\le", "\\geq", "\\leq", "<", ">")
    best_start = -1
    best_end = 0
    for sign in comparison_operators:
        idx = s.rfind(sign)
        if idx == -1:
            continue
        end = idx + len(sign)
        # Prefer the rightmost occurrence in the string (not the last
        # operator in the tuple); on equal start, prefer the longer sign.
        if idx > best_start or (idx == best_start and end > best_end):
            best_start = idx
            best_end = end
    if best_start == -1:
        return s.strip()
    return s[best_end:].strip()
def convert_latex_fractions(latex_str: str) -> str:
    """Normalize bare fractions such as \\frac\\alpha2 into \\frac{\\alpha}{2}."""
    # Each \frac argument may be a command, a digit, a letter, or an
    # already-braced group.
    pattern = (
        r"\\frac((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))"
        r"((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))"
    )

    def _braced(part: str) -> str:
        # Leave arguments that already carry braces untouched.
        if part.startswith("{") and part.endswith("}"):
            return part
        return f"{{{part}}}"

    def replacer(match):
        return rf"\frac{_braced(match.group(1))}{_braced(match.group(2))}"

    return re.sub(pattern, replacer, latex_str)
def convert_vec_syntax(text: str) -> str:
    """Wrap the argument of \\vec in braces, e.g. \\vec x -> \\vec{x}."""
    # The argument may be an optional command backslash followed by Latin or
    # Greek letters; optional whitespace after \vec is dropped.
    return re.sub(r"\\vec(\s*)(\\?[a-zA-Zα-ωΑ-Ω]+)", r"\\vec{\2}", text)
def first_preprocess(s: str, extract_box: bool = True) -> str:
    """
    First stage of LaTeX preprocessing.

    Extracts boxed content, removes outer braces, and extracts content after equals.

    Args:
        s: Input LaTeX string
        extract_box: Whether to extract content from \\boxed{}

    Returns:
        Preprocessed string
    """
    # Literal set braces would confuse brace matching below; use parens.
    s = s.replace("\\{", "(")
    s = s.replace("\\}", ")")
    # Later brace surgery assumes balanced input; bail out otherwise.
    if not brackets_balanced(s):
        return s
    if extract_box:
        boxed_content = remove_command(s, "\\boxed", keep_inside=True)
    else:
        boxed_content = s

    # Remove overall braces
    def remove_overall_brace(text: str) -> Tuple[str, bool]:
        """Strip one level of all-enclosing braces; return (text, changed)."""
        pos = find_first_unescaped_brace(text)
        if pos == -1:
            return text, False
        content, final = extract_bracket_content(text, pos)
        # Only strip when the group reaches (effectively) the end of the
        # string, i.e. it encloses the whole remaining expression.
        if content and (final == len(text) - 1 or "}" not in text[final + 1 :]):
            # Check if there's a command before the brace
            if pos > 0 and text[pos - 1] not in (" ", "\t", "\n"):
                return text, False
            return content, True
        return text, False

    # Remove outer braces iteratively (bounded to avoid pathological loops)
    for _ in range(10):
        boxed_content, changed = remove_overall_brace(boxed_content)
        if not changed:
            break
    # Handle \\quad separator: keep only the first clause
    if "\\quad" in boxed_content:
        boxed_content = boxed_content.split("\\quad")[0]
    # Extract content after last equals sign
    last_equal_content = extract_last_equal_content(boxed_content)
    # Remove outer braces again
    for _ in range(10):
        last_equal_content, changed = remove_overall_brace(last_equal_content)
        if not changed:
            break
    return last_equal_content
def second_preprocess(s: str) -> str:
    """
    Second stage of LaTeX preprocessing.

    Removes/modifies LaTeX commands and normalizes expressions.

    Args:
        s: Input string from first preprocessing stage

    Returns:
        Normalized LaTeX string ready for conversion
    """
    # Commands to completely remove (including their content)
    kill_commands = ["\\begin", "\\end"]
    # Commands to remove but keep their content
    remove_commands = [
        "\\text",
        "\\mathbf",
        "\\mathrm",
        "\\pmb",
        "\\hat",
        "\\overline",
        "\\boldsymbol",
    ]
    # Content to remove entirely (formatting noise and sizing commands)
    remove_content = [
        "\\,",
        "$",
        ",",
        "`",
        "latex",
        "\\left",
        "\\right",
        "\\Bigr",
        "\\Bigl",
        "\n",
        "\\]",
        "\\[",
        "\\Big",
        "\\bigl",
        "\\bigr",
        "\\biggl",
        "\\biggr",
        "\\displaystyle",
        "\\infty",
    ]
    # Content replacements. NOTE(review): several entries wrap tokens in
    # \bar{...} (\times, \partial, \perp, I, e_, E_) — presumably to force
    # them to parse as plain symbols rather than operators/constants in
    # latex2sympy; confirm against the upstream PHYBench EED implementation
    # before changing any of them. Order matters: pairs are applied top to
    # bottom with plain str.replace.
    replace_content = [
        ("\\operatorname{asin}", "\\asin"),
        ("\\operatorname{sech}", "\\sech"),
        ("\\operatorname{acos}", "\\acos"),
        ("\\operatorname{sinh}", "\\sinh"),
        ("\\dfrac", "\\frac"),
        ("\\tfrac", "\\frac"),
        ("\\Exp", "\\exp"),
        ("\\times", "\\bar{times}"),
        ("\\partial", "\\bar{partial}"),
        ("\\perp", "\\bar{perp}"),
        ("\\epsilon", "\\varepsilon"),
        ("\\varOmega", "\\Omega"),
        ("I", "\\bar{I}"),
        ("_e", "_{e}"),
        ("e_", "\\bar{e}_"),
        ("E_", "\\bar{E}_"),
        ("\\pm", "+"),
        ("\\mp", "-"),
        ("{+}", "{p}"),
        ("{-}", "{m}"),
        ("_+", "_p"),
        ("_-", "_m"),
    ]
    # Apply transformations in a fixed order: kill, unwrap, delete, replace
    for command in kill_commands:
        s = remove_command(s, command, keep_inside=False)
    for command in remove_commands:
        s = remove_command(s, command, keep_inside=True)
    for content in remove_content:
        s = s.replace(content, "")
    for old, new in replace_content:
        s = s.replace(old, new)
    # Additional transformations
    s = convert_latex_fractions(s)
    s = convert_vec_syntax(s)
    # Remove trailing period
    if s and s[-1] == ".":
        s = s[:-1]
    return s
class LaTeXNormalizationConfig:
    """Configuration for latex2sympy normalization.

    NOTE(review): presumably mirrors the normalization options of
    latex2sympy2_extended; an instance is passed as ``normalization_config``
    in master_convert — verify field names against that library's API.
    """

    basic_latex: bool = True  # apply basic LaTeX normalizations
    units: bool = False  # do not strip/interpret physical units
    malformed_operators: bool = True  # repair malformed operators
    nits: bool = True  # apply small cleanup fixes
    boxed: str = "all"  # normalize every \boxed{} occurrence
    equations: bool = False  # do not treat input as an equation
class LaTeXConversionConfig:
    """Configuration for latex2sympy conversion.

    NOTE(review): presumably mirrors the conversion options of
    latex2sympy2_extended; an instance is passed as ``conversion_config``
    in master_convert — verify field names against that library's API.
    """

    interpret_as_mixed_fractions: bool = False  # "1 1/2" is not 1.5
    interpret_simple_eq_as_assignment: bool = False  # "a = b" stays an equation
    interpret_contains_as_eq: bool = True
    lowercase_symbols: bool = False  # preserve symbol case
def master_convert(s: str):
    """
    Convert a LaTeX string to a SymPy expression.

    Runs both preprocessing stages, then delegates to latex2sympy with the
    module's normalization and conversion configuration.

    Args:
        s: LaTeX string to convert

    Returns:
        SymPy expression

    Raises:
        ImportError: If the optional EED dependencies are missing.
        Exception: Propagated from latex2sympy when parsing fails.
    """
    if not EED_AVAILABLE:
        raise ImportError("latex2sympy2_extended and sympy are required for EED scoring")
    cleaned = second_preprocess(first_preprocess(s))
    return latex2sympy(
        cleaned,
        normalization_config=LaTeXNormalizationConfig(),
        conversion_config=LaTeXConversionConfig(),
    )
# =============================================================================
# SYMPY TO TREE CONVERSION
# =============================================================================
def sympy_to_tree(expr) -> TreeNode:
    """
    Recursively convert a SymPy expression into a TreeNode hierarchy.

    Leaves are labeled "number_<value>" or "symbol_<name>"; internal nodes
    are labeled "operator_<Op>" for Add/Mul/Pow and "function_<name>" for
    function applications.

    Args:
        expr: SymPy expression

    Returns:
        Root TreeNode of the equivalent expression tree

    Raises:
        ValueError: If the expression contains an unsupported node type
    """
    numeric_types = (Integer, Pi, Exp1, Float, Rational, Infinity, NegativeInfinity)
    if isinstance(expr, numeric_types):
        # Numeric literal or named constant: a leaf node.
        return TreeNode(label=f"number_{expr}", children=[])
    if isinstance(expr, Symbol):
        return TreeNode(label=f"symbol_{expr}", children=[])
    if isinstance(expr, (Add, Mul, Pow)):
        branches = [sympy_to_tree(arg) for arg in expr.args]
        return TreeNode(label=f"operator_{type(expr).__name__}", children=branches)
    if isinstance(expr, Function):
        branches = [sympy_to_tree(arg) for arg in expr.args]
        return TreeNode(label=f"function_{expr.func.__name__}", children=branches)
    raise ValueError(f"Unsupported SymPy type: {type(expr).__name__}")
# =============================================================================
# SCORING FUNCTION
# =============================================================================
def score_calc(tree_dist: float, tree_size: int) -> float:
    """
    Map a tree edit distance to an EED score in [0, 100].

    With r = tree_dist / tree_size:
    - 100 if distance is 0 (exact match)
    - 60 - 100*r if 0 < r < 0.6 (partial credit)
    - 0 if r >= 0.6 (too different)

    Args:
        tree_dist: Tree edit distance
        tree_size: Size of the ground truth tree

    Returns:
        Score between 0 and 100
    """
    if tree_dist == 0:
        return 100.0
    partial = 100 * DISCOUNT_SLOPE - 100 * tree_dist / tree_size
    return partial if partial > 0 else 0
def time_simplify(expr, timeout: int = SIMPLIFY_TIME_LIMIT):
    """
    Simplify expression with best-effort error protection.

    NOTE(review): despite the name, the ``timeout`` parameter is currently
    NOT enforced — simplify() runs to completion regardless.

    Args:
        expr: SymPy expression to simplify
        timeout: Timeout in seconds (advisory only; not enforced)

    Returns:
        Simplified expression, or original if an error occurs
    """
    try:
        # Note: For production use, consider using multiprocessing for true timeout
        return simplify(expr)
    except Exception:
        # Best-effort: a failed simplification falls back to the input.
        return expr
def time_equal(expr1, expr2, timeout: int = EQUALS_TIME_LIMIT) -> bool:
    """
    Check expression equality with best-effort error protection.

    SymPy's ``Expr.equals`` may return ``None`` when equality cannot be
    decided; that case is reported as ``False`` here so the declared
    ``bool`` return type actually holds.

    NOTE(review): the ``timeout`` parameter is currently NOT enforced
    (kept for API parity with the original implementation).

    Args:
        expr1: First expression
        expr2: Second expression
        timeout: Timeout in seconds (advisory only; not enforced)

    Returns:
        True only if the expressions are verifiably equal
    """
    try:
        # equals() can return True, False, or None (undecidable);
        # coerce the undecidable case to False.
        return expr1.equals(expr2) is True
    except Exception:
        return False
# =============================================================================
# MAIN EED FUNCTION
# =============================================================================
def compute_eed_score(
    answer_latex: str,
    test_latex: str,
    debug_mode: bool = False,
) -> Tuple[float, float, int, float]:
    """
    Compute the EED (Expression Edit Distance) Score between two LaTeX expressions.

    This function evaluates the similarity between a ground truth answer and
    a model-generated answer by:
    1. Converting LaTeX to SymPy expressions
    2. Simplifying and checking for equivalence
    3. Building expression trees
    4. Computing tree edit distance
    5. Converting distance to a 0-100 score

    Args:
        answer_latex: Ground truth answer in LaTeX format
        test_latex: Model-generated answer in LaTeX format
        debug_mode: If True, raise exceptions instead of returning defaults

    Returns:
        Tuple of (score, relative_distance, tree_size, distance):
        - score: EED score from 0-100 (100 = perfect match)
        - relative_distance: distance / tree_size (-1 if error)
        - tree_size: Size of ground truth tree (-1 if error)
        - distance: Raw tree edit distance (-1 if error)
    """
    if not EED_AVAILABLE:
        if debug_mode:
            raise ImportError("EED scoring requires latex2sympy2_extended and sympy")
        return 0, -1, -1, -1
    # Handle empty or invalid input
    if not test_latex:
        return 0, -1, -1, -1
    # Skip unsupported expressions (integrals, sums)
    if "\\int" in test_latex or "\\int" in answer_latex:
        return 0, -1, -1, -1
    if "\\sum" in test_latex or "\\sum" in answer_latex:
        return 0, -1, -1, -1
    # Quick check for exact string match
    if answer_latex == test_latex:
        return 100, 0.0, -1, 0
    # Skip if test is much longer than answer (likely wrong)
    if len(test_latex) > 3 * len(answer_latex):
        return 0, -1, -1, -1
    # Convert LaTeX to SymPy
    try:
        answer_exp = master_convert(answer_latex)
        test_exp = master_convert(test_latex)
    except Exception as e:
        if debug_mode:
            raise ValueError(f"Failed to convert LaTeX: {e}")
        return 0, -1, -1, -1
    # Simplify and check equivalence
    try:
        # Assume all symbols are positive for simplification
        # (posify returns a substitution map used to restore symbols below)
        answer_exp, rep1 = posify(answer_exp)
        answer_exp = time_simplify(answer_exp)
        test_exp, rep2 = posify(test_exp)
        test_exp = time_simplify(test_exp)
        # Restore original symbols
        answer_exp = answer_exp.subs(rep1)
        test_exp = test_exp.subs(rep2)
        # Check for equivalence: structural equality, difference expanding
        # to zero, or sympy's equals() as a last resort
        zero_exp = time_simplify(expand(answer_exp - test_exp))
        if answer_exp == test_exp or zero_exp == 0:
            return 100, 0.0, 0, 0
        if time_equal(answer_exp, test_exp):
            return 100, 0.0, 0, 0
    except Exception as e:
        if debug_mode:
            raise ValueError(f"Failed during simplification: {e}")
        return 0, -1, -1, -1
    # Build expression trees
    try:
        tree_answer = sympy_to_tree(answer_exp)
        tree_test = sympy_to_tree(test_exp)
    except Exception as e:
        if debug_mode:
            raise ValueError(f"Failed to build expression tree: {e}")
        return 0, -1, -1, -1
    # Compute tree edit distance (test tree edited into answer tree)
    try:
        distance = ext_distance(
            tree_test,
            tree_answer,
            get_children=lambda x: x.get_children(),
            single_insert_cost=insert_func,
            insert_cost=insert_tree_func,
            single_remove_cost=remove_func,
            remove_cost=remove_tree_func,
            update_cost_func=update_func,
        )
    except Exception as e:
        if debug_mode:
            raise ValueError(f"Failed to calculate distance: {e}")
        tree_size = calc_tree_size(tree_answer)
        return 0, -1, tree_size, -1
    # Calculate final score relative to the ground-truth tree size
    tree_size = calc_tree_size(tree_answer)
    rel_distance = distance / tree_size
    score = score_calc(distance, tree_size)
    return score, rel_distance, tree_size, distance
def extract_boxed_content(latex_str: str) -> Optional[str]:
    """
    Return the content of the first \\boxed{} in *latex_str*, or None.

    Handles one level of nested braces inside the box
    (e.g. \\boxed{\\frac{a}{b}}).
    """
    boxed = re.search(r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}", latex_str)
    return boxed.group(1) if boxed else None
def extract_all_boxed(latex_str: str) -> List[str]:
    """
    Return the contents of every \\boxed{} occurrence in *latex_str*.

    Handles one level of nested braces inside each box; returns an empty
    list when no \\boxed{} is present.
    """
    boxed_pattern = re.compile(r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}")
    return boxed_pattern.findall(latex_str)

View file

@ -0,0 +1,711 @@
"""
PHYBench Evaluation Environment for Atropos.
This environment evaluates models on PHYBench - a benchmark for evaluating
physical perception and reasoning capabilities in Large Language Models.
Dataset: Eureka-Lab/PHYBench
Paper: https://arxiv.org/abs/2504.16074
Website: https://www.phybench.cn/
PHYBench is a human-curated benchmark with 500 original physics problems spanning:
- Mechanics (MECHANICS)
- Electromagnetism (ELECTRICITY)
- Thermodynamics (THERMODYNAMICS)
- Optics (OPTICS)
- Modern Physics (MODERN)
- Advanced Physics (ADVANCED)
Key features:
- Original problems to prevent data contamination
- Symbolic expression answers in LaTeX format
- Two evaluation metrics:
1. Binary Accuracy: Exact match using SymPy equivalence
2. EED Score: Expression Edit Distance for partial credit (0-100)
The EED Score provides:
- 204% improved sample efficiency over binary scoring
- Continuous scoring that captures partial correctness
- Differentiation between minor coefficient errors and structural errors
Supports thinking mode with <think></think> tags for extended reasoning.
"""
import asyncio
import random
import re
from typing import Dict, List, Optional, Tuple
import wandb
from datasets import load_dataset
from eed_score import EED_AVAILABLE, compute_eed_score, extract_all_boxed
from eval_helpers import (
THINK_CONTENT_AFTER_PATTERN,
create_system_content,
extract_thinking_content,
get_default_thinking_prompt,
save_eval_results,
validate_thinking_format,
)
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
)
# Physics domain tags in PHYBench (used by PHYBenchEvalConfig.tags_filter)
PHYBENCH_TAGS = [
    "MECHANICS",
    "ELECTRICITY",
    "THERMODYNAMICS",
    "OPTICS",
    "MODERN",
    "ADVANCED",
]

# Prompt template for PHYBench evaluation.
# NOTE: doubled braces ({{}}) render as literal braces after .format().
PHYBENCH_PROMPT_TEMPLATE = """You are a physics expert. Please read the following question and provide a step-by-step solution.
Put your final answer, which must be a readable LaTeX formula, in a \\boxed{{}} environment.
Question: {problem}
Answer:"""

# Alternative prompt with more detailed instructions
# (selected via PHYBenchEvalConfig.use_detailed_prompt).
PHYBENCH_DETAILED_PROMPT_TEMPLATE = """Solve the following physics problem. Show your reasoning step by step.
Your final answer should be a single symbolic expression (e.g., $\\sqrt{{\\frac{{2g}}{{3R}}}}$).
- Equivalent forms are accepted
- No numerical approximations
- No equation chains
Put your final answer in \\boxed{{}} format.
For example: \\boxed{{2mg + \\frac{{4mv_0^2}}{{l}}}}
Problem:
{problem}
Solution:"""
class PHYBenchEvalConfig(BaseEnvConfig):
    """Configuration for PHYBench evaluation environment.

    Fields are grouped as: dataset selection, generation parameters,
    prompt/thinking-mode options, scoring, retry/debug behavior, and
    fixed overrides for an eval-only environment.
    """

    # Dataset configuration
    dataset_name: str = Field(
        default="Eureka-Lab/PHYBench",
        description="HuggingFace dataset name",
    )
    eval_split: str = Field(
        default="train",
        description="Split to evaluate on (PHYBench only has train split)",
    )
    shuffle_seed: int = Field(
        default=42,
        description="Random seed for shuffling",
    )
    max_samples: Optional[int] = Field(
        default=None,
        description="Maximum number of samples to evaluate (None = all)",
    )
    tags_filter: Optional[List[str]] = Field(
        default=None,
        description="Filter to specific physics domains (e.g., ['MECHANICS', 'OPTICS'])",
    )
    # Generation parameters
    eval_temperature: float = Field(
        default=0.6,
        description="Temperature for evaluation generation",
    )
    eval_max_tokens: int = Field(
        default=0,
        description="Max tokens for evaluation (0 = use model default)",
    )
    # System prompt configuration
    custom_system_prompt: Optional[str] = Field(
        default=None,
        description="Optional custom system prompt",
    )
    # Thinking mode configuration
    thinking_mode: bool = Field(
        default=True,
        description="Whether to use thinking mode with <think></think> tags",
    )
    custom_thinking_prompt: Optional[str] = Field(
        default=None,
        description="Optional custom thinking prompt",
    )
    # Prompt configuration
    use_detailed_prompt: bool = Field(
        default=False,
        description="Use detailed prompt with more instructions",
    )
    # Scoring configuration
    compute_eed_score: bool = Field(
        default=True,
        description="Whether to compute EED Score (requires latex2sympy2_extended)",
    )
    # Retry and debug configuration
    max_retries: int = Field(
        default=3,
        description="Maximum retries for failed API calls",
    )
    retry_delay: float = Field(
        default=1.0,
        description="Delay between retries in seconds",
    )
    min_response_length: int = Field(
        default=1,
        description="Minimum response length to consider valid",
    )
    full_debug: bool = Field(
        default=False,
        description="Enable full debug output",
    )
    # Override defaults for eval-only environment
    # (single rollout per item, one evaluation step)
    group_size: int = 1
    max_num_workers: int = 1024
    max_eval_workers: int = 256
    max_num_workers_per_node: int = 128
    use_wandb: bool = True
    rollout_server_url: str = "http://localhost:8000"
    total_steps: int = 1
    wandb_name: str = "phybench_eval"
    steps_per_eval: int = 1
class PHYBenchEvalEnv(BaseEnv):
"""
PHYBench Evaluation Environment.
Evaluates models on physics problems requiring symbolic expression answers.
Uses both binary accuracy and EED Score for comprehensive evaluation.
"""
name = "phybench_eval"
    def __init__(
        self,
        config: PHYBenchEvalConfig,
        server_configs: List[APIServerConfig],
        slurm_job_id: Optional[str] = None,
        testing: bool = False,
    ):
        """
        Initialize the evaluation environment.

        Args:
            config: Environment configuration.
            server_configs: API server configurations for generation.
            slurm_job_id: Optional SLURM job identifier.
            testing: Whether the environment runs in testing mode.
        """
        super().__init__(config, server_configs, slurm_job_id, testing)
        self.config: PHYBenchEvalConfig = config
        self.eval_items: List[Dict] = []  # populated lazily by _load_dataset()
        self._dataset_loaded = False
        # Pre-compile regex patterns for answer extraction
        # (matches \boxed{...} with one level of nested braces)
        self._boxed_pattern = re.compile(r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}")
        # Check EED availability
        if self.config.compute_eed_score and not EED_AVAILABLE:
            print(
                "Warning: EED Score requested but latex2sympy2_extended not available. "
                "Install with: pip install latex2sympy2_extended sympy"
            )
    @classmethod
    def config_cls(cls) -> type:
        """Return the config class used to parse this environment's settings."""
        return PHYBenchEvalConfig
    async def setup(self) -> None:
        """Initialize the environment and load the dataset."""
        await super().setup()
        # Load lazily so repeated setup() calls don't re-fetch the dataset.
        if not self._dataset_loaded:
            await self._load_dataset()
        print("\nPHYBench Evaluation Setup:")
        print(f" Dataset: {self.config.dataset_name}")
        print(f" Evaluation split: {self.config.eval_split}")
        print(f" Thinking mode: {self.config.thinking_mode}")
        print(f" EED Score enabled: {self.config.compute_eed_score and EED_AVAILABLE}")
        if self.config.thinking_mode:
            thinking_prompt = get_default_thinking_prompt(self.config.custom_thinking_prompt)
            print(f" Thinking prompt: {thinking_prompt[:80]}...")
        if self.config.tags_filter:
            print(f" Tags filter: {self.config.tags_filter}")
        print(f" Loaded {len(self.eval_items)} evaluation items")
    async def _load_dataset(self) -> None:
        """Load and process the PHYBench dataset.

        Drops items missing content/answer, applies the optional tag filter,
        shuffles deterministically with shuffle_seed, and truncates to
        max_samples. Sets self._dataset_loaded on success.
        """
        print(f"Loading PHYBench dataset: {self.config.dataset_name}...")
        try:
            dataset = load_dataset(
                self.config.dataset_name,
                trust_remote_code=True,
            )
        except Exception as e:
            print(f"Error loading dataset: {e}")
            raise
        if self.config.eval_split not in dataset:
            available_splits = list(dataset.keys())
            raise ValueError(
                f"Split '{self.config.eval_split}' not found. Available: {available_splits}"
            )
        split_data = dataset[self.config.eval_split]
        # Process items
        self.eval_items = []
        tag_counts: Dict[str, int] = {}
        for item in split_data:
            problem_id = item.get("id", "")
            tag = item.get("tag", "UNKNOWN")
            content = item.get("content", "")
            solution = item.get("solution", "")
            answer = item.get("answer", "")
            # Skip if no content or answer
            if not content or not answer:
                continue
            # Apply tag filter if specified
            if self.config.tags_filter and tag not in self.config.tags_filter:
                continue
            # Track tag distribution
            tag_counts[tag] = tag_counts.get(tag, 0) + 1
            self.eval_items.append({
                "id": problem_id,
                "tag": tag,
                "content": content,
                "solution": solution,
                "answer": answer,
            })
        # Shuffle with seed for reproducibility
        random.seed(self.config.shuffle_seed)
        random.shuffle(self.eval_items)
        # Apply max_samples limit if specified
        if self.config.max_samples and len(self.eval_items) > self.config.max_samples:
            self.eval_items = self.eval_items[: self.config.max_samples]
        self._dataset_loaded = True
        # Print tag distribution
        print(f"Loaded {len(self.eval_items)} items")
        print("Tag distribution:")
        for tag, count in sorted(tag_counts.items()):
            print(f" {tag}: {count}")
def _format_prompt(self, item: Dict) -> str:
"""Format the problem into a prompt."""
if self.config.use_detailed_prompt:
return PHYBENCH_DETAILED_PROMPT_TEMPLATE.format(problem=item["content"])
return PHYBENCH_PROMPT_TEMPLATE.format(problem=item["content"])
def _create_system_content(self) -> str:
"""Create system message content based on thinking mode."""
return (
create_system_content(
self.config.thinking_mode,
self.config.custom_thinking_prompt,
self.config.custom_system_prompt,
)
or ""
)
def _extract_answer(self, response: str, debug: bool = False) -> Tuple[Optional[str], str]:
"""
Extract the answer from the model's response.
Looks for \\boxed{} content. If multiple found, uses the last one.
Args:
response: Model's response text
debug: Whether to print debug info
Returns:
Tuple of (extracted_answer, extraction_method)
"""
if not response:
return None, "empty_response"
# Find all boxed answers
boxed_answers = extract_all_boxed(response)
if not boxed_answers:
if debug:
print(" No \\boxed{} found in response")
return None, "no_boxed"
if len(boxed_answers) > 1:
if debug:
print(f" Multiple \\boxed{{}} found ({len(boxed_answers)}), using last one")
return boxed_answers[-1], "boxed_last"
return boxed_answers[0], "boxed"
def _check_equivalence(
self,
predicted: str,
gold: str,
debug: bool = False,
) -> Tuple[bool, str]:
"""
Check if predicted answer is equivalent to gold answer.
Uses SymPy for symbolic equivalence checking.
Args:
predicted: Predicted answer in LaTeX
gold: Gold answer in LaTeX
debug: Whether to print debug info
Returns:
Tuple of (is_correct, method)
"""
if not predicted:
return False, "empty_prediction"
# Clean up the answers
pred_clean = predicted.strip()
gold_clean = gold.strip()
# Exact string match
if pred_clean == gold_clean:
return True, "exact_match"
# Try EED Score - if score is 100, they're equivalent
if self.config.compute_eed_score and EED_AVAILABLE:
try:
score, _, _, _ = compute_eed_score(gold_clean, pred_clean, debug_mode=False)
if score == 100:
return True, "sympy_equivalent"
except Exception:
pass
return False, "not_equivalent"
def _compute_scores(
self,
predicted: str,
gold: str,
debug: bool = False,
) -> Dict:
"""
Compute both accuracy and EED Score.
Args:
predicted: Predicted answer
gold: Gold answer
debug: Whether to print debug info
Returns:
Dictionary with scoring results
"""
result = {
"is_correct": False,
"match_method": "none",
"eed_score": 0.0,
"eed_rel_distance": -1,
"eed_tree_size": -1,
"eed_distance": -1,
}
if not predicted:
return result
# Check equivalence (for binary accuracy)
is_correct, match_method = self._check_equivalence(predicted, gold, debug)
result["is_correct"] = is_correct
result["match_method"] = match_method
# Compute EED Score if enabled
if self.config.compute_eed_score and EED_AVAILABLE:
try:
eed_score, rel_dist, tree_size, distance = compute_eed_score(
gold, predicted, debug_mode=debug
)
result["eed_score"] = eed_score
result["eed_rel_distance"] = rel_dist
result["eed_tree_size"] = tree_size
result["eed_distance"] = distance
# If EED score is 100, mark as correct
if eed_score == 100 and not is_correct:
result["is_correct"] = True
result["match_method"] = "eed_equivalent"
except Exception as e:
if debug:
print(f" EED Score error: {e}")
return result
    async def rollout_and_score_eval(
        self,
        item: Dict,
        server: APIServerConfig,
    ) -> Optional[Dict]:
        """Run one evaluation rollout for a single dataset item and score it.

        Sends the formatted problem to the model (retrying on API errors and
        on responses shorter than min_response_length), extracts the
        \\boxed{} answer, and computes accuracy plus EED-based scores
        against the item's gold answer.

        Args:
            item: Dataset record; reads "answer", "id", "tag", "content".
            server: API server config supplying the model name.

        Returns:
            Result dict with extraction, scoring, and format fields, or
            None when no response text was obtained after all retries.
        """
        prompt = self._format_prompt(item)
        system_content = self._create_system_content()
        messages = []
        if system_content:
            messages.append({"role": "system", "content": system_content})
        messages.append({"role": "user", "content": prompt})
        # Build API call parameters
        kwargs = {
            "model": server.model_name,
            "messages": messages,
            "temperature": self.config.eval_temperature,
        }
        # eval_max_tokens <= 0 means "no explicit limit": omit the parameter.
        if self.config.eval_max_tokens > 0:
            kwargs["max_tokens"] = self.config.eval_max_tokens
        # Retry loop: keep the first response meeting the minimum length.
        # On exception, sleep retry_delay before the next attempt; a final
        # failed attempt leaves response_text as obtained so far (possibly
        # a too-short response from an earlier attempt).
        response_text = ""
        for attempt in range(self.config.max_retries):
            try:
                response = await self.server.chat_completion(**kwargs)
                response_text = response.choices[0].message.content or ""
                if len(response_text) >= self.config.min_response_length:
                    break
            except Exception as e:
                if self.config.full_debug:
                    print(f" API error (attempt {attempt + 1}): {e}")
                if attempt < self.config.max_retries - 1:
                    await asyncio.sleep(self.config.retry_delay)
                continue
        if not response_text:
            return None
        # Validate thinking format and extract content after </think>
        is_valid_format, content_for_extraction = validate_thinking_format(
            response_text, self.config.thinking_mode
        )
        # Extract thinking content if present
        thinking_content = (
            extract_thinking_content(response_text) if self.config.thinking_mode else None
        )
        # Get content for answer extraction: in thinking mode, prefer the text
        # after the think block; fall back to the whole response otherwise.
        if self.config.thinking_mode:
            match = THINK_CONTENT_AFTER_PATTERN.search(response_text)
            if match:
                answer_content = match.group(1)
            else:
                answer_content = response_text
        else:
            answer_content = response_text
        # Extract answer
        extracted_answer, extraction_method = self._extract_answer(
            answer_content, debug=self.config.full_debug
        )
        # Compute scores
        gold_answer = item["answer"]
        scores = self._compute_scores(extracted_answer, gold_answer, debug=self.config.full_debug)
        if self.config.full_debug:
            status = "" if scores["is_correct"] else ""
            eed = scores["eed_score"]
            print(f" [{status}] {item['tag']}: EED={eed:.1f}, gold={gold_answer[:50]}...")
        # The content field is truncated to 200 chars to keep saved results small.
        return {
            "item_id": item["id"],
            "tag": item["tag"],
            "content": item["content"][:200],
            "gold_answer": gold_answer,
            "extracted_answer": extracted_answer,
            "extraction_method": extraction_method,
            "is_correct": scores["is_correct"],
            "match_method": scores["match_method"],
            "eed_score": scores["eed_score"],
            "eed_rel_distance": scores["eed_rel_distance"],
            "eed_tree_size": scores["eed_tree_size"],
            "eed_distance": scores["eed_distance"],
            "format_valid": is_valid_format,
            "response": response_text,
            "response_length": len(response_text),
            "thinking_content": thinking_content,
            "has_thinking": thinking_content is not None,
        }
    async def evaluate(self, *args, **kwargs) -> Dict:
        """Run the full PHYBench evaluation.

        Fans out one async rollout per loaded item, then aggregates overall
        accuracy, average EED score, per-tag breakdowns, extraction-method
        counts, and format/thinking statistics. Prints a human-readable
        summary and optionally saves results (data_dir_to_save_evals).

        Returns:
            Metrics dict; on total failure, {"error": ..., "accuracy": 0.0}.
        """
        print(f"\n{'='*60}")
        print("Starting PHYBench Evaluation")
        print(f"{'='*60}")
        print(f" Total questions: {len(self.eval_items)}")
        print(f" Thinking mode: {self.config.thinking_mode}")
        print(f" EED Score: {self.config.compute_eed_score and EED_AVAILABLE}")
        print(f"{'='*60}\n")
        # Create evaluation tasks (all items are launched concurrently)
        async def eval_task(item):
            return await self.rollout_and_score_eval(item, self.server_configs[0])
        tasks = [eval_task(item) for item in self.eval_items]
        # Run with progress bar
        results = await tqdm_asyncio.gather(*tasks, desc="Evaluating PHYBench")
        # Filter out failed results (rollouts that returned None)
        valid_results = [r for r in results if r is not None]
        if not valid_results:
            print("Warning: No valid evaluation results obtained")
            return {"error": "No valid results", "accuracy": 0.0}
        # Calculate metrics
        total = len(valid_results)
        correct = sum(1 for r in valid_results if r["is_correct"])
        accuracy = correct / total if total > 0 else 0.0
        # Calculate average EED Score; negative scores mean "not computed"
        eed_scores = [r["eed_score"] for r in valid_results if r["eed_score"] >= 0]
        avg_eed_score = sum(eed_scores) / len(eed_scores) if eed_scores else 0.0
        # Calculate per-tag metrics
        tag_metrics: Dict[str, Dict] = {}
        for r in valid_results:
            tag = r.get("tag", "UNKNOWN")
            if tag not in tag_metrics:
                tag_metrics[tag] = {"total": 0, "correct": 0, "eed_scores": []}
            tag_metrics[tag]["total"] += 1
            if r["is_correct"]:
                tag_metrics[tag]["correct"] += 1
            if r["eed_score"] >= 0:
                tag_metrics[tag]["eed_scores"].append(r["eed_score"])
        # Second pass: derive per-tag accuracy and mean EED score
        for tag in tag_metrics:
            t_total = tag_metrics[tag]["total"]
            t_correct = tag_metrics[tag]["correct"]
            t_eed_scores = tag_metrics[tag]["eed_scores"]
            tag_metrics[tag]["accuracy"] = t_correct / t_total if t_total > 0 else 0.0
            tag_metrics[tag]["avg_eed_score"] = (
                sum(t_eed_scores) / len(t_eed_scores) if t_eed_scores else 0.0
            )
        # Calculate extraction method statistics
        extraction_methods: Dict[str, int] = {}
        for r in valid_results:
            method = r.get("extraction_method", "unknown")
            extraction_methods[method] = extraction_methods.get(method, 0) + 1
        # Format compliance and thinking utilization
        format_valid = sum(1 for r in valid_results if r.get("format_valid", True))
        has_thinking = sum(1 for r in valid_results if r.get("has_thinking", False))
        has_boxed = sum(1 for r in valid_results if r.get("extracted_answer") is not None)
        # Average response length
        response_lengths = [r.get("response_length", 0) for r in valid_results]
        avg_response_length = sum(response_lengths) / len(response_lengths) if response_lengths else 0
        metrics = {
            "accuracy": accuracy,
            "avg_eed_score": avg_eed_score,
            "total_evaluated": total,
            "total_correct": correct,
            "has_boxed_rate": has_boxed / total if total > 0 else 0.0,
            "format_compliance_rate": format_valid / total if total > 0 else 0.0,
            "thinking_utilization_rate": has_thinking / total if total > 0 else 0.0,
            "avg_response_length": avg_response_length,
            "tag_metrics": tag_metrics,
            "extraction_methods": extraction_methods,
        }
        # Print summary (total > 0 is guaranteed here: valid_results is non-empty)
        print(f"\n{'='*60}")
        print("PHYBench Evaluation Results")
        print(f"{'='*60}")
        print(f" Overall Accuracy: {accuracy:.2%} ({correct}/{total})")
        print(f" Average EED Score: {avg_eed_score:.1f}/100")
        print(f" Has \\boxed{{}} Rate: {has_boxed / total:.2%}")
        print(f" Avg Response Length: {avg_response_length:.0f} chars")
        if self.config.thinking_mode:
            print(f" Format Compliance: {format_valid / total:.2%}")
            print(f" Thinking Utilization: {has_thinking / total:.2%}")
        print("\n Per-Tag Breakdown:")
        for tag in sorted(tag_metrics.keys()):
            data = tag_metrics[tag]
            acc = data["accuracy"]
            eed = data["avg_eed_score"]
            cnt = data["total"]
            print(f" {tag}: Acc={acc:.2%}, EED={eed:.1f} ({cnt} items)")
        print("\n Extraction Methods:")
        for method, count in sorted(extraction_methods.items(), key=lambda x: -x[1]):
            print(f" {method}: {count}")
        print(f"{'='*60}\n")
        # Save results
        if self.config.data_dir_to_save_evals:
            self._save_results(metrics, valid_results)
        return metrics
def _save_results(self, metrics: Dict, results: List[Dict]) -> None:
"""Save evaluation results to disk."""
save_eval_results(self.config.data_dir_to_save_evals, metrics, results)
async def wandb_log(self, metrics: Dict, step: int = 0) -> None:
"""Log metrics to Weights & Biases."""
if not self.config.use_wandb:
return
log_dict = {
"phybench/accuracy": metrics.get("accuracy", 0),
"phybench/avg_eed_score": metrics.get("avg_eed_score", 0),
"phybench/total_evaluated": metrics.get("total_evaluated", 0),
"phybench/has_boxed_rate": metrics.get("has_boxed_rate", 0),
"phybench/format_compliance_rate": metrics.get("format_compliance_rate", 0),
"phybench/thinking_utilization_rate": metrics.get("thinking_utilization_rate", 0),
"phybench/avg_response_length": metrics.get("avg_response_length", 0),
}
# Log per-tag metrics
for tag, data in metrics.get("tag_metrics", {}).items():
safe_tag = tag.lower()
log_dict[f"phybench/accuracy_{safe_tag}"] = data.get("accuracy", 0)
log_dict[f"phybench/eed_score_{safe_tag}"] = data.get("avg_eed_score", 0)
wandb.log(log_dict, step=step)
# Required abstract method implementations
async def get_next_item(self) -> Optional[Dict]:
"""Not used in evaluation mode."""
return None
async def collect_trajectories(self, item) -> List:
"""Not used in evaluation mode."""
return []
async def score(self, rollout_group_data) -> Optional[List]:
"""Not used in evaluation mode."""
return None
if __name__ == "__main__":
    # Script entry point: launch the evaluation environment via its CLI helper.
    PHYBenchEvalEnv.cli()