atropos/environments/t1_tool_planning/t1_scoring.py
2026-03-03 19:51:30 -06:00

300 lines
8.8 KiB
Python

"""
Scoring for T1 tool planning environment.
Parses ground truth Python code with AST to extract tool calls,
then compares against structured tool_calls from the model response.
"""
import ast
import json
import logging
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# Tools that T1 considers "real" (not just print or cache ops)
SEARCH_FILTER_TOOLS = {
    "search_hotels",
    "filter_hotels",
    "search_flights",
    "filter_flights",
    "search_restaurants",
    "filter_restaurants",
    "search_attractions",
    "filter_attractions",
    "search_nearest",
    "sort_results",
    "adjust_date",
    "seek_information",
}
# Cache read/write helpers; recognized in plans but not search/filter work.
CACHE_TOOLS = {"save_to_cache", "get_results_from_cache"}
# Every callable name the ground-truth parser will record as a tool call.
ALL_TOOLS = SEARCH_FILTER_TOOLS | CACHE_TOOLS
def parse_ground_truth_code(code: str) -> List[Dict[str, Any]]:
    """Parse ground truth Python code into a list of tool calls.

    Args:
        code: Python code string from Filled_Plan column.

    Returns:
        List of {"name": str, "arguments": dict} dicts. Empty list for
        blank input, the "no planning needed" sentinel, or unparseable code.
    """
    if not code or not isinstance(code, str):
        return []
    code = code.strip()
    if not code or code == 'print("No planning needed")':
        return []
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError: source with null bytes raises ValueError (not
        # SyntaxError) on Python < 3.12 — don't let bad data crash scoring.
        logger.warning("Failed to parse ground truth code: %s", code[:100])
        return []
    calls = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.Call):
            continue
        # Resolve the function name for plain calls and attribute calls.
        if isinstance(node.func, ast.Name):
            func_name = node.func.id
        elif isinstance(node.func, ast.Attribute):
            func_name = node.func.attr
        else:
            continue
        if func_name not in ALL_TOOLS:
            continue
        # Extract keyword arguments. Positional args are ignored — assumed
        # absent in the ground-truth plans (TODO confirm against dataset).
        args = {}
        for kw in node.keywords:
            if kw.arg is None:
                # **kwargs expansion — no concrete key to record.
                continue
            try:
                args[kw.arg] = ast.literal_eval(kw.value)
            except (ValueError, TypeError):
                # Non-literal values (e.g. prior_results=hotels):
                # store the variable name, else a structural AST dump.
                if isinstance(kw.value, ast.Name):
                    args[kw.arg] = kw.value.id
                else:
                    args[kw.arg] = ast.dump(kw.value)
        calls.append({"name": func_name, "arguments": args})
    return calls
def parse_model_tool_calls(tool_calls: Optional[List[dict]]) -> List[Dict[str, Any]]:
    """Normalize model's structured tool_calls into comparable format.

    Args:
        tool_calls: List of tool call dicts from ChatCompletion response.

    Returns:
        List of {"name": str, "arguments": dict} dicts. ``arguments`` is
        always a dict: unparseable JSON or JSON that decodes to a
        non-object (list/str/number) collapses to {}.
    """
    if not tool_calls:
        return []
    result = []
    for tc in tool_calls:
        func = tc.get("function", {})
        if not isinstance(func, dict):
            # Malformed entry — keep a placeholder so call counts line up.
            func = {}
        name = func.get("name", "")
        args_str = func.get("arguments", "{}")
        try:
            args = json.loads(args_str) if isinstance(args_str, str) else args_str
        except (json.JSONDecodeError, TypeError):
            args = {}
        # Downstream scoring does len(args) and `key in args`; a decoded
        # list/str/number would silently corrupt the parameter counts.
        if not isinstance(args, dict):
            args = {}
        result.append({"name": name, "arguments": args})
    return result
def tool_call_f1(
    ground_truth: List[Dict[str, Any]],
    generated: List[Dict[str, Any]],
) -> Tuple[float, float, float]:
    """Compute precision, recall, F1 on tool names.

    Compares the multiset of tool names called.
    """
    expected = Counter(call["name"] for call in ground_truth)
    produced = Counter(call["name"] for call in generated)
    # Drop "print" from either side whenever it is not the only name there.
    for counts in (expected, produced):
        if len(counts) > 1:
            counts.pop("print", None)
    # Both sides empty counts as a perfect match.
    if not expected and not produced:
        return 1.0, 1.0, 1.0
    # Counter intersection gives the per-name minimum: true positives.
    overlap = sum((expected & produced).values())
    precision = overlap / sum(produced.values()) if produced else 0.0
    recall = overlap / sum(expected.values()) if expected else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return precision, recall, f1
def tool_param_f1(
    ground_truth: List[Dict[str, Any]],
    generated: List[Dict[str, Any]],
) -> Tuple[float, float, float]:
    """Compute precision, recall, F1 on tool parameters.

    For each matching tool name pair, compares the argument keys and values.
    """
    if not ground_truth and not generated:
        return 1.0, 1.0, 1.0
    # Match tool calls by name (greedy, first-come matching).
    gt_remaining = list(ground_truth)
    matched_pairs = []
    for gen_call in generated:
        for i, gt_call in enumerate(gt_remaining):
            if gt_call["name"] == gen_call["name"]:
                matched_pairs.append((gt_call, gen_call))
                gt_remaining.pop(i)
                break
    if not matched_pairs:
        return 0.0, 0.0, 0.0
    total_tp = 0
    total_gt_params = 0
    total_gen_params = 0
    for gt_call, gen_call in matched_pairs:
        gt_args = gt_call.get("arguments", {})
        gen_args = gen_call.get("arguments", {})
        total_gt_params += len(gt_args)
        total_gen_params += len(gen_args)
        # Count a true positive per key present on both sides with a
        # loosely-matching value (type/case-tolerant).
        for key in gt_args:
            if key in gen_args and _values_match(gt_args[key], gen_args[key]):
                total_tp += 1
    # All matched calls are parameterless on both sides: that is a perfect
    # parameter match, not a zero score (0/0 must not penalize the model).
    if total_gt_params == 0 and total_gen_params == 0:
        return 1.0, 1.0, 1.0
    precision = total_tp / total_gen_params if total_gen_params > 0 else 0.0
    recall = total_tp / total_gt_params if total_gt_params > 0 else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return precision, recall, f1
def _values_match(gt_val: Any, gen_val: Any) -> bool:
"""Loose comparison of argument values."""
if gt_val == gen_val:
return True
# String comparison (case-insensitive)
if isinstance(gt_val, str) and isinstance(gen_val, str):
return gt_val.lower().strip() == gen_val.lower().strip()
# List comparison (order-insensitive for some)
if isinstance(gt_val, list) and isinstance(gen_val, list):
if len(gt_val) == len(gen_val):
return sorted(str(v) for v in gt_val) == sorted(str(v) for v in gen_val)
# Number comparison
try:
if float(gt_val) == float(gen_val):
return True
except (ValueError, TypeError):
pass
return False
def score_turn(
    ground_truth_code: str,
    model_tool_calls: Optional[List[dict]],
    model_content: Optional[str] = None,
) -> Dict[str, float]:
    """Score a single turn.

    Args:
        ground_truth_code: Python code from Filled_Plan column.
        model_tool_calls: Structured tool_calls from model response.
        model_content: Text content from model response. Currently unused;
            kept for seek_information text scoring (TODO: wire up or drop).

    Returns:
        Dict with a consistent key set in every branch:
        tool_call_precision/recall/f1, tool_param_precision/recall/f1,
        correct_no_op, is_seek_info, reward.
    """
    gt_calls = parse_ground_truth_code(ground_truth_code)
    gen_calls = parse_model_tool_calls(model_tool_calls)

    gt_is_noop = len(gt_calls) == 0
    gen_is_noop = len(gen_calls) == 0
    # Flag turns whose ground truth includes a seek_information call.
    gt_is_seek = any(c["name"] == "seek_information" for c in gt_calls)

    def _result(tc_p, tc_r, tc_f1, tp_p, tp_r, tp_f1, no_op, reward):
        # Single construction point so every branch emits the same keys —
        # per-key metric aggregation breaks if branches disagree.
        return {
            "tool_call_precision": tc_p,
            "tool_call_recall": tc_r,
            "tool_call_f1": tc_f1,
            "tool_param_precision": tp_p,
            "tool_param_recall": tp_r,
            "tool_param_f1": tp_f1,
            "correct_no_op": no_op,
            "is_seek_info": float(gt_is_seek),
            "reward": reward,
        }

    # Handle "no planning needed" cases first.
    if gt_is_noop and gen_is_noop:
        # Correctly abstained from calling any tool.
        return _result(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)
    if gt_is_noop:
        # Called tools when none were needed.
        return _result(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)

    tc_p, tc_r, tc_f1 = tool_call_f1(gt_calls, gen_calls)
    tp_p, tp_r, tp_f1 = tool_param_f1(gt_calls, gen_calls)

    # Composite reward — graduated so GRPO gets signal even with weak models
    #
    # GT expects tools but model produced none → 0.0 (worst)
    # GT expects tools, model called tools but wrong ones → 0.1 (format credit)
    # GT expects tools, model called some right ones → 0.1 + 0.5*tc_f1 + 0.3*tp_f1
    # Perfect match → 0.1 + 0.5 + 0.3 + 0.1 = 1.0
    if gen_is_noop:
        # GT expects tool calls but model produced none.
        reward = 0.0
    else:
        reward = 0.1  # format credit: produced valid tool call structure
        reward += 0.5 * tc_f1
        reward += 0.3 * tp_f1
        if tc_f1 == 1.0:
            # Bonus for getting the full tool set exactly right.
            reward += 0.1
    return _result(tc_p, tc_r, tc_f1, tp_p, tp_r, tp_f1, 0.0, reward)