mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add env using the tool api stuff
This commit is contained in:
parent
c8eb63f33d
commit
12d61d197f
15 changed files with 2632 additions and 21 deletions
19
environments/t1_tool_planning/configs/default.yaml
Normal file
19
environments/t1_tool_planning/configs/default.yaml
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
# Default configuration for the T1 tool-planning environment.
env:
  # HF tokenizer name — must match the served model in the `openai` section.
  tokenizer_name: "Qwen/Qwen3-1.7B"
  # Atropos rollout API endpoint.
  rollout_server_url: "http://localhost:8002"
  max_token_length: 4096
  start_tok_length: 4096
  # GRPO group size: number of independent trajectories per conversation.
  group_size: 2
  batch_size: 8
  total_steps: 200
  steps_per_eval: 25
  use_wandb: false
  wandb_name: "t1-tool-planning-env"
  # Fraction of eval conversations actually used per eval pass.
  eval_limit_ratio: 0.1
  max_num_workers_per_node: 8

# Inference backend (vLLM OpenAI-compatible server).
openai:
  model_name: "Qwen/Qwen3-1.7B"
  base_url: "http://localhost:9001/v1"
  api_key: "x"
  server_type: "vllm"
|
||||
164
environments/t1_tool_planning/inspect_nodes.py
Normal file
164
environments/t1_tool_planning/inspect_nodes.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
#!/usr/bin/env python3
"""Inspect multi-step node output to verify extending works correctly."""

import asyncio
import logging
import os
import signal
import subprocess
import sys
import time

import requests

# Make sibling modules (t1_core, t1_tools, ...) importable when this file is
# run directly as a script rather than as part of a package.
sys.path.insert(0, os.path.dirname(__file__))
logging.basicConfig(level=logging.WARNING)

# Repo root — used as cwd for the vLLM subprocess so
# `-m example_trainer.vllm_api_server` resolves.
REPO_ROOT = os.path.join(os.path.dirname(__file__), "..", "..")
MODEL = "Qwen/Qwen3-1.7B"
# Non-default port to avoid colliding with other local servers (9001/8002).
PORT = 8123
|
||||
|
||||
|
||||
def start_vllm():
    """Launch the example vLLM API server as a subprocess and wait for health.

    Polls ``/health`` every 3 seconds for up to 180 seconds. Exits the whole
    script (status 1) if the server process dies or never becomes healthy.

    Returns:
        subprocess.Popen: the running server process handle.
    """
    launch_args = [
        sys.executable,
        "-m",
        "example_trainer.vllm_api_server",
        "--model",
        MODEL,
        "--port",
        str(PORT),
        "--gpu-memory-utilization",
        "0.45",
        "--enforce-eager",
    ]
    server_proc = subprocess.Popen(
        launch_args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=REPO_ROOT
    )

    give_up_at = time.time() + 180
    while time.time() < give_up_at:
        # Health probe — any connection error just means "not ready yet".
        resp = None
        try:
            resp = requests.get(f"http://localhost:{PORT}/health", timeout=2)
        except Exception:
            pass
        if resp is not None and resp.status_code == 200:
            print("vLLM ready")
            return server_proc

        # If the process already exited, dump the tail of its output and bail.
        if server_proc.poll() is not None:
            tail = server_proc.stdout.read().decode()[-2000:]
            print(f"vLLM died:\n{tail}")
            sys.exit(1)

        time.sleep(3)

    # Deadline passed without a healthy response.
    server_proc.kill()
    print("vLLM timeout")
    sys.exit(1)
|
||||
|
||||
|
||||
async def main():
    """Run one hand-written multi-turn trajectory against the local vLLM
    server and print the resulting SequenceNode's tokens/masks/logprobs for
    manual inspection of the node-extending behavior."""
    # Deferred imports: these pull in heavy deps (transformers, atroposlib)
    # and sibling modules that require the sys.path hack at module top.
    from t1_core import collect_multistep_trajectory
    from t1_tools import T1_TOOLS
    from transformers import AutoTokenizer

    from atroposlib.envs.server_handling.server_baseline import APIServerConfig
    from atroposlib.envs.server_handling.server_manager import ServerManager

    config = APIServerConfig(
        model_name=MODEL,
        base_url=f"http://localhost:{PORT}/v1",
        api_key="x",
        server_type="vllm",
    )
    server = ServerManager(
        configs=[config], slurm=False, testing=False, tool_parser="hermes"
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL)

    # Minimal 2-user-turn conversation in T1 row format
    # (Role / Filled_Template / Filled_Plan).
    convo = [
        {
            "Role": "assistant",
            "Filled_Template": "Hello! How can I help?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "Find hotels in Austin, check-in May 10, check-out May 15, 2025.",
            "Filled_Plan": 'search_hotels(city="Austin", checkin_date=["May 10, 2025"], checkout_date=["May 15, 2025"])',  # noqa: E501
        },
        {
            "Role": "assistant",
            "Filled_Template": "Found some. Filter?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "Yes, free wifi.",
            "Filled_Plan": "filter_hotels(prior_results=hotels, free_wifi_included=True)",
        },
    ]

    # temperature=0.0 for a deterministic, reproducible inspection run.
    turn_results, nodes = await collect_multistep_trajectory(
        server=server,
        tokenizer=tokenizer,
        conversation=convo,
        tools=T1_TOOLS,
        max_tokens=500,
        temperature=0.0,
        tool_choice="auto",
    )

    print(f"\nNodes: {len(nodes)}")
    # One node per trajectory — each turn extends the same node, so nodes[0]
    # carries the full multi-turn token sequence.
    node = nodes[0]

    # Partition token positions by mask value (-100 marks prompt tokens).
    unmasked_idx = [i for i, t in enumerate(node.masked_tokens) if t != -100]
    masked_idx = [i for i, t in enumerate(node.masked_tokens) if t == -100]
    first_u = unmasked_idx[0] if unmasked_idx else 0

    print(
        f"Total: {len(node.tokens)} | Masked: {len(masked_idx)} | Unmasked: {len(unmasked_idx)}"
    )

    # Check contiguity: completion tokens should form gapless runs.
    gaps = []
    for j in range(1, len(unmasked_idx)):
        if unmasked_idx[j] != unmasked_idx[j - 1] + 1:
            gaps.append((unmasked_idx[j - 1], unmasked_idx[j]))
    print(f"Unmasked contiguous: {not gaps}  Gaps: {gaps}")

    # Decode the prompt prefix and the unmasked (completion) tokens separately.
    prompt_text = tokenizer.decode(node.tokens[:first_u], skip_special_tokens=False)
    comp_tokens = [node.tokens[i] for i in unmasked_idx]
    comp_text = tokenizer.decode(comp_tokens, skip_special_tokens=False)

    print("\n--- PROMPT TAIL (last 150 chars) ---")
    print(prompt_text[-150:])
    print("\n--- COMPLETION (unmasked, first 400 chars) ---")
    print(comp_text[:400])
    print("\n--- COMPLETION (unmasked, last 200 chars) ---")
    print(comp_text[-200:])

    # Prompt positions are expected to carry placeholder logprobs (1.0);
    # completion positions carry real inference logprobs.
    print(f"\nPrompt logprobs sample (should be 1.0): {node.logprobs[:3]}")
    print(f"Completion logprobs sample: {[node.logprobs[i] for i in unmasked_idx[:5]]}")

    # Per-turn summary: tool-call names/args or the raw text reply.
    for tr in turn_results:
        tc = len(tr["tool_calls"]) if tr["tool_calls"] else 0
        print(
            f"\nTurn {tr['turn_idx']}: {tc} tool_calls, reward={tr['scores']['reward']:.2f}"
        )
        if tr["tool_calls"]:
            for t in tr["tool_calls"]:
                print(f"  {t['function']['name']}({t['function']['arguments'][:80]})")
        else:
            print(f"  text: {(tr['content'] or '')[:80]}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Start vLLM, run the inspection, and always tear the server down —
    # SIGTERM first for a graceful stop, escalating to SIGKILL after 10s.
    proc = start_vllm()
    try:
        asyncio.run(main())
    finally:
        proc.send_signal(signal.SIGTERM)
        try:
            proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
|
||||
297
environments/t1_tool_planning/t1_core.py
Normal file
297
environments/t1_tool_planning/t1_core.py
Normal file
|
|
@ -0,0 +1,297 @@
|
|||
"""
|
||||
Core T1 tool planning logic — extracted for testability.
|
||||
|
||||
These functions do the actual work: generating tool-calling completions
|
||||
via ManagedServer and scoring them. The env class just orchestrates.
|
||||
|
||||
Two modes:
|
||||
- Single-turn: generate_tool_completions + score_completions
|
||||
- Multi-step: collect_multistep_trajectory walks a full conversation,
|
||||
feeding the model's actual responses back at each turn.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from t1_prompts import SYSTEM_PROMPT
|
||||
from t1_scoring import score_turn
|
||||
from t1_tools import T1_TOOLS
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.managed_server import SequenceNode
|
||||
from atroposlib.envs.server_handling.server_manager import ServerManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def generate_tool_completions(
    server: ServerManager,
    tokenizer: Any,
    messages: List[Dict[str, str]],
    tools: Optional[List[dict]] = None,
    n: int = 4,
    max_tokens: int = 512,
    temperature: float = 1.0,
    tool_choice: str = "auto",
    split: str = "train",
    tool_parser: str = "hermes",
) -> Tuple[Any, List[SequenceNode]]:
    """Generate tool-calling completions and return result + tracked nodes.

    Args:
        server: ServerManager with backends configured
        tokenizer: Tokenizer for the model
        messages: Chat messages to complete
        tools: OpenAI function tool definitions (defaults to T1_TOOLS)
        n: Number of completions to generate
        max_tokens: Max tokens per completion
        temperature: Sampling temperature
        tool_choice: "auto", "none", or "required"
        split: "train" or "eval" (for server load balancing)
        tool_parser: vLLM tool parser name. NOTE: accepted for interface
            compatibility but currently NOT forwarded anywhere — the parser
            actually used is whatever the ServerManager was configured with.

    Returns:
        (ChatCompletion, list of SequenceNodes)
    """
    if tools is None:
        tools = T1_TOOLS

    logger.info(
        f"generate_tool_completions: n={n}, max_tokens={max_tokens}, "
        f"temp={temperature}, tool_choice={tool_choice}, "
        f"num_messages={len(messages)}"
    )

    async with server.managed_server(tokenizer=tokenizer) as managed:
        # _tool_parser_name is a private ManagedServer attribute; use getattr
        # so an upstream rename degrades this log line instead of raising
        # AttributeError mid-generation.
        logger.debug(
            f" ManagedServer opened (tool_parser={getattr(managed, '_tool_parser_name', '<unknown>')})"
        )

        result = await managed.chat_completion(
            messages=messages,
            tools=tools,
            tool_choice=tool_choice,
            n=n,
            max_tokens=max_tokens,
            temperature=temperature,
            split=split,
        )

        logger.debug(f" Got {len(result.choices)} choices")
        for i, c in enumerate(result.choices):
            tc_count = len(c.message.tool_calls) if c.message.tool_calls else 0
            content_preview = (c.message.content or "")[:80]
            logger.debug(
                f" choice[{i}]: {tc_count} tool_calls, content={content_preview!r}"
            )

        # Nodes must be read INSIDE the context — the managed server tracks
        # token sequences only for the lifetime of the session.
        state = managed.get_state()
        nodes = state["nodes"]
        logger.debug(f" Got {len(nodes)} tracked nodes")

    return result, nodes
|
||||
|
||||
|
||||
def score_completions(
    result: Any,
    nodes: List[SequenceNode],
    gt_code: str,
    min_unmasked_tokens: int = 5,
) -> Tuple[Optional[ScoredDataGroup], List[Dict[str, float]]]:
    """Score completions against ground truth and build a ScoredDataGroup.

    Args:
        result: ChatCompletion from generate_tool_completions
        nodes: SequenceNodes from generate_tool_completions
        gt_code: Ground truth Python code (Filled_Plan)
        min_unmasked_tokens: Skip choices with fewer unmasked tokens

    Returns:
        (ScoredDataGroup or None, list of per-choice score dicts).
        None is returned when no choice survives the token-count filter,
        or when every reward is identical (no GRPO advantage signal).
    """
    logger.debug(
        f"score_completions: {len(result.choices)} choices, "
        f"{len(nodes)} nodes, gt_code={gt_code[:60]}..."
    )

    # zip() below silently truncates to the shorter sequence; a mismatch
    # means generation and node tracking disagree — surface it loudly.
    if len(result.choices) != len(nodes):
        logger.warning(
            "score_completions: %d choices vs %d nodes — pairing will truncate",
            len(result.choices),
            len(nodes),
        )

    all_scores = []
    scores = ScoredDataGroup()
    scores["tokens"] = []
    scores["masks"] = []
    scores["scores"] = []
    scores["inference_logprobs"] = []

    for i, (choice, node) in enumerate(zip(result.choices, nodes)):
        turn_scores = score_turn(
            gt_code, choice.message.tool_calls, choice.message.content
        )
        all_scores.append(turn_scores)
        logger.debug(f" choice[{i}] scores: {turn_scores}")

        # Drop near-empty completions: too few trainable tokens to matter.
        unmasked = len([t for t in node.masked_tokens if t != -100])
        if unmasked < min_unmasked_tokens:
            logger.debug(f" choice[{i}] skipped: only {unmasked} unmasked tokens")
            continue

        scores["tokens"].append(node.tokens)
        scores["masks"].append(node.masked_tokens)
        scores["inference_logprobs"].append(node.logprobs)
        scores["scores"].append(turn_scores["reward"])

    if not scores["tokens"]:
        logger.debug(" -> None (no valid tokens)")
        return None, all_scores

    # GRPO needs reward variance within the group; all-identical rewards
    # produce zero advantage, so the group is useless for training.
    if all(s == scores["scores"][0] for s in scores["scores"]):
        logger.debug(f" -> None (all scores identical: {scores['scores'][0]})")
        return None, all_scores

    logger.debug(f" -> valid group, scores={scores['scores']}")
    return scores, all_scores
|
||||
|
||||
|
||||
async def collect_multistep_trajectory(
    server: ServerManager,
    tokenizer: Any,
    conversation: List[Dict[str, str]],
    tools: Optional[List[dict]] = None,
    max_tokens: int = 512,
    temperature: float = 0.7,
    tool_choice: str = "auto",
    tool_parser: str = "hermes",
) -> Tuple[List[Dict[str, Any]], List[SequenceNode]]:
    """Walk through a full conversation in ONE managed_server session.

    Uses a single ManagedServer context across all turns so sequence tracking
    works properly — each turn extends the previous node, building up a full
    multi-turn trajectory with aligned tokens and logprobs.

    At each user turn:
      1. Add the user message to the running conversation
      2. Generate a model response (n=1) via the SAME managed server
      3. Score against ground truth
      4. Add the model's ACTUAL response (not GT) to conversation history
      5. Continue to next turn regardless of quality

    Nodes are collected ONCE at the end from managed.get_state().

    Args:
        server: ServerManager with backends configured
        tokenizer: Tokenizer for the model
        conversation: List of turn dicts with Role, Filled_Template, Filled_Plan
        tools: Tool definitions (defaults to T1_TOOLS)
        max_tokens: Max tokens per completion
        temperature: Sampling temperature
        tool_choice: "auto", "none", or "required"
        tool_parser: vLLM tool parser name. NOTE(review): this parameter is
            never read inside the function — the parser in effect is whatever
            the ServerManager was configured with; confirm before relying on it.

    Returns:
        (turn_results, nodes) where:
          turn_results: list of per-turn dicts with scores, tool_calls, content
          nodes: list of SequenceNodes from the managed server (one per turn,
            each extending the previous — full trajectory with tokens/logprobs)
    """
    if tools is None:
        tools = T1_TOOLS

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    turn_results = []

    logger.info(
        f"collect_multistep_trajectory: {len(conversation)} turns, temp={temperature}"
    )

    async with server.managed_server(tokenizer=tokenizer) as managed:
        for i, turn in enumerate(conversation):
            # Roles in the dataset may vary in case/whitespace — normalize.
            role = turn["Role"].strip().lower()

            if role == "assistant":
                if not turn_results:
                    # First assistant turn (greeting) — use GT since model hasn't spoken yet
                    messages.append(
                        {"role": "assistant", "content": turn["Filled_Template"]}
                    )
                    logger.debug(
                        f" turn[{i}] assistant (GT greeting): {turn['Filled_Template'][:60]}"
                    )
                # Otherwise skip — we already added the model's response after the previous user turn
                continue

            # Anything that is neither user nor assistant is ignored.
            if role != "user":
                continue

            # User turn — add to conversation and generate model response
            messages.append({"role": "user", "content": turn["Filled_Template"]})
            gt_code = turn.get("Filled_Plan", "")

            logger.info(f" turn[{i}] user: {turn['Filled_Template'][:60]}...")
            logger.debug(f" turn[{i}] gt_code: {gt_code[:80]}")

            # Generate within the SAME managed server session so this turn's
            # tokens extend the trajectory's existing node.
            result = await managed.chat_completion(
                messages=messages,
                tools=tools,
                tool_choice=tool_choice,
                n=1,
                max_tokens=max_tokens,
                temperature=temperature,
            )

            choice = result.choices[0]

            # Score this turn against the ground-truth plan.
            turn_scores = score_turn(
                gt_code, choice.message.tool_calls, choice.message.content
            )

            tc_count = (
                len(choice.message.tool_calls) if choice.message.tool_calls else 0
            )
            logger.info(
                f" turn[{i}] result: {tc_count} tool_calls, "
                f"reward={turn_scores['reward']:.2f}, tc_f1={turn_scores['tool_call_f1']:.2f}"
            )

            # messages_so_far is deep-ish copied (per-message .copy()) so later
            # mutation of `messages` doesn't alias into stored results.
            turn_results.append(
                {
                    "turn_idx": i,
                    "user_message": turn["Filled_Template"],
                    "gt_code": gt_code,
                    "content": choice.message.content,
                    "tool_calls": choice.message.tool_calls,
                    "scores": turn_scores,
                    "messages_so_far": [m.copy() for m in messages],
                }
            )

            # Add model's ACTUAL response to conversation for next turn
            assistant_msg = {"role": "assistant"}
            if choice.message.tool_calls:
                assistant_msg["tool_calls"] = choice.message.tool_calls
                assistant_msg["content"] = choice.message.content or ""
            else:
                assistant_msg["content"] = choice.message.content or ""
            messages.append(assistant_msg)

            logger.debug(
                f" turn[{i}] added assistant msg to conversation (total: {len(messages)})"
            )

        # Get nodes ONCE at the end — the managed server tracked extending sequences
        state = managed.get_state()
        nodes = state["nodes"]
        logger.info(
            f" trajectory complete: {len(turn_results)} turns, {len(nodes)} nodes"
        )

    # Summary
    if turn_results:
        avg_reward = sum(r["scores"]["reward"] for r in turn_results) / len(
            turn_results
        )
        avg_tc_f1 = sum(r["scores"]["tool_call_f1"] for r in turn_results) / len(
            turn_results
        )
        logger.info(f" avg_reward={avg_reward:.3f}, avg_tc_f1={avg_tc_f1:.3f}")

    return turn_results, nodes
|
||||
168
environments/t1_tool_planning/t1_data.py
Normal file
168
environments/t1_tool_planning/t1_data.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
"""
|
||||
T1 dataset loader — downloads and parses the capitalone/T1 HuggingFace dataset.
|
||||
|
||||
Conversations are returned in the format expected by t1_core:
|
||||
[{"Role": "assistant"|"user", "Filled_Template": str, "Filled_Plan": str}, ...]
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import pandas as pd
|
||||
from huggingface_hub import hf_hub_download, list_repo_tree
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REPO_ID = "capitalone/T1"
|
||||
|
||||
# Single-domain only for now (simpler tool defs, shorter conversations)
|
||||
SINGLE_DOMAINS = ["hotel", "flight", "restaurant", "attraction"]
|
||||
MULTI_DOMAINS = [
|
||||
"flighthotel",
|
||||
"hotelrestaurant",
|
||||
"hotelattraction",
|
||||
"flighthotelrestaurant",
|
||||
"flighthotelattraction",
|
||||
]
|
||||
ALL_DOMAINS = SINGLE_DOMAINS + MULTI_DOMAINS
|
||||
|
||||
|
||||
def _parse_role(filled_template: str) -> Tuple[str, str]:
|
||||
"""Extract role and content from 'role: content' format."""
|
||||
if filled_template.startswith("assistant:"):
|
||||
return "assistant", filled_template[len("assistant:") :].strip()
|
||||
elif filled_template.startswith("user:"):
|
||||
return "user", filled_template[len("user:") :].strip()
|
||||
else:
|
||||
# Fallback — try to guess
|
||||
return "assistant", filled_template.strip()
|
||||
|
||||
|
||||
def _csv_to_conversations(path: str) -> Dict[int, List[dict]]:
|
||||
"""Parse a T1 CSV file into a dict of conversation_id → turns."""
|
||||
df = pd.read_csv(path)
|
||||
conversations = {}
|
||||
|
||||
for conv_id, group in df.groupby("ID"):
|
||||
turns = []
|
||||
for _, row in group.iterrows():
|
||||
template = row["Filled_Template"]
|
||||
plan = row["Filled_Plan"]
|
||||
|
||||
# Skip NaN/empty templates
|
||||
if pd.isna(template) or str(template).strip() in ("", "nan"):
|
||||
continue
|
||||
|
||||
template = str(template)
|
||||
role, content = _parse_role(template)
|
||||
|
||||
# Skip empty content after role extraction
|
||||
if not content.strip():
|
||||
continue
|
||||
|
||||
# NaN plans become empty string
|
||||
if pd.isna(plan) or str(plan).strip() == "nan":
|
||||
plan = ""
|
||||
else:
|
||||
plan = str(plan)
|
||||
|
||||
turns.append(
|
||||
{
|
||||
"Role": role,
|
||||
"Filled_Template": content,
|
||||
"Filled_Plan": plan,
|
||||
}
|
||||
)
|
||||
|
||||
# Only keep conversations with at least one user turn
|
||||
if turns and any(t["Role"] == "user" for t in turns):
|
||||
conversations[int(conv_id)] = turns
|
||||
|
||||
return conversations
|
||||
|
||||
|
||||
def load_t1_dataset(
    domains: Optional[List[str]] = None,
    split: str = "train",
    max_files_per_domain: Optional[int] = None,
    cache_dir: Optional[str] = None,
) -> List[List[dict]]:
    """Load T1 conversations from HuggingFace.

    Args:
        domains: List of domains to load (default: single-domain only)
        split: "train", "test", or "validation"
        max_files_per_domain: Limit files per domain (each has 25, ~15 convos each)
        cache_dir: HF cache directory

    Returns:
        List of conversations, each a list of turn dicts
    """
    if domains is None:
        domains = SINGLE_DOMAINS

    # Enumerate every CSV in the dataset repo once, then select per domain.
    # hasattr(f, "size") filters directory entries out of the tree listing.
    tree = list(list_repo_tree(REPO_ID, repo_type="dataset", recursive=True))
    csv_paths = [
        entry.path
        for entry in tree
        if hasattr(entry, "size") and entry.path.endswith(".csv")
    ]

    download_kwargs = {}
    if cache_dir:
        download_kwargs["cache_dir"] = cache_dir

    loaded: List[List[dict]] = []
    for domain in domains:
        wanted_prefix = f"{domain}/{split}/"
        selected = sorted(p for p in csv_paths if p.startswith(wanted_prefix))
        if max_files_per_domain:
            selected = selected[:max_files_per_domain]

        logger.info(f"Loading {len(selected)} files from {domain}/{split}")

        for remote_path in selected:
            local_path = hf_hub_download(
                REPO_ID, remote_path, repo_type="dataset", **download_kwargs
            )
            loaded.extend(_csv_to_conversations(local_path).values())

    logger.info(f"Loaded {len(loaded)} conversations total")
    return loaded
|
||||
|
||||
|
||||
def load_t1_split(
    domains: Optional[List[str]] = None,
    max_files_per_domain: Optional[int] = None,
    eval_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[List[List[dict]], List[List[dict]]]:
    """Load T1 train conversations and split into train/eval.

    Args:
        domains: Domains to load
        max_files_per_domain: Limit files per domain
        eval_ratio: Fraction of conversations for eval
        seed: Random seed for split

    Returns:
        (train_conversations, eval_conversations)
    """
    all_convos = load_t1_dataset(
        domains=domains,
        split="train",
        max_files_per_domain=max_files_per_domain,
    )

    # Seeded local RNG: the split is reproducible and doesn't perturb the
    # global random state.
    random.Random(seed).shuffle(all_convos)

    # Always hold out at least one conversation for eval.
    eval_count = max(1, int(len(all_convos) * eval_ratio))
    eval_convos, train_convos = all_convos[:eval_count], all_convos[eval_count:]

    logger.info(f"Split: {len(train_convos)} train, {len(eval_convos)} eval")
    return train_convos, eval_convos
|
||||
237
environments/t1_tool_planning/t1_env.py
Normal file
237
environments/t1_tool_planning/t1_env.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
"""
|
||||
T1 Tool-Integrated Reasoning environment for Atropos.
|
||||
|
||||
Multi-step tool calling with full trajectory tracking:
|
||||
- Loads the capitalone/T1 dataset from HuggingFace
|
||||
- Walks through complete conversations, feeding model's actual responses back
|
||||
- One ManagedServer session per trajectory → one extending node with aligned tokens
|
||||
- GRPO over group_size independent trajectories per conversation
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from t1_core import collect_multistep_trajectory # noqa: E402
|
||||
from t1_data import SINGLE_DOMAINS, load_t1_split # noqa: E402
|
||||
from t1_tools import T1_TOOLS # noqa: E402
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
# NOTE(review): module-level basicConfig at DEBUG configures the process-wide
# root logger as a side effect of importing this module — handy while
# developing the env, but consider moving it under the __main__ guard (or
# raising the level) before importing this module from other code.
logging.basicConfig(
    level=logging.DEBUG, format="%(asctime)s [%(name)s] %(levelname)s: %(message)s"
)
|
||||
|
||||
|
||||
class T1ToolPlanningEnv(BaseEnv):
    """T1 Tool-Integrated Reasoning environment — multi-step trajectories.

    Each trajectory walks a full conversation:
      - Model generates responses at each user turn
      - Model's actual output is fed back (not GT) for the next turn
      - One ManagedServer session → one extending node per trajectory
      - GRPO compares group_size independent trajectories on the same conversation
    """

    name = "t1_tool_planning"

    def __init__(self, config, server_configs, slurm=True, testing=False):
        super().__init__(config, server_configs, slurm, testing)
        # BaseEnv doesn't pass tool_parser to ServerManager — set it here
        # so ManagedServer creates a ToolCallTranslator for hermes-style tool calls
        self.server.tool_parser = "hermes"
        # Rolling metric buffers; filled by collect_trajectories and drained
        # (reset) on every wandb_log call.
        self.reward_buffer = []
        self.tc_f1_buffer = []
        self.tp_f1_buffer = []
        # (key, value) pairs appended by evaluate(), flushed by wandb_log().
        self.eval_metrics = []
        # Round-robin cursor over train_conversations (see get_next_item).
        self.iter = 0

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Default env + inference-server configuration for CLI runs."""
        model_name = "Qwen/Qwen3-1.7B"
        env_config = BaseEnvConfig(
            tokenizer_name=model_name,
            group_size=4,
            use_wandb=True,
            rollout_server_url="http://localhost:8002",
            total_steps=200,
            batch_size=16,
            steps_per_eval=25,
            max_token_length=4096,
            start_tok_length=4096,
            wandb_name="t1-tool-planning",
            eval_limit_ratio=0.1,
            max_num_workers_per_node=8,
        )
        server_config = APIServerConfig(
            model_name=model_name,
            base_url="http://localhost:9001/v1",
            api_key="x",
            server_type="vllm",
        )
        # MUST return as a list — single APIServerConfig (not in list) causes
        # ServerManager to ignore base_url and auto-generate ports 9004-9007
        return env_config, [server_config]

    async def setup(self):
        """Download the T1 dataset and split it into train/eval conversations."""
        logger.info("=== T1ToolPlanningEnv.setup() starting ===")

        # Load real T1 dataset from HuggingFace
        # Start with single-domain, 2 files per domain (~30 convos each = ~120 total)
        # Increase max_files_per_domain for more data
        self.train_conversations, self.eval_conversations = load_t1_split(
            domains=SINGLE_DOMAINS,
            eval_ratio=0.1,
        )

        logger.info(
            f"Setup complete: {len(self.train_conversations)} train conversations, "
            f"{len(self.eval_conversations)} eval conversations"
        )

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Flush buffered train metrics (as means) and eval metrics to wandb."""
        if wandb_metrics is None:
            wandb_metrics = {}
        if self.reward_buffer:
            wandb_metrics["train/avg_reward"] = sum(self.reward_buffer) / len(
                self.reward_buffer
            )
        if self.tc_f1_buffer:
            wandb_metrics["train/tool_call_f1"] = sum(self.tc_f1_buffer) / len(
                self.tc_f1_buffer
            )
        if self.tp_f1_buffer:
            wandb_metrics["train/tool_param_f1"] = sum(self.tp_f1_buffer) / len(
                self.tp_f1_buffer
            )
        # Buffers are per-log-window: reset after reading.
        self.reward_buffer = []
        self.tc_f1_buffer = []
        self.tp_f1_buffer = []
        for k, v in self.eval_metrics:
            wandb_metrics[k] = v
        self.eval_metrics = []
        await super().wandb_log(wandb_metrics)

    async def evaluate(self, *args, **kwargs):
        """Run deterministic (temperature=0) trajectories over the eval split
        and stash per-turn averages into eval_metrics for the next wandb_log."""
        logger.info(
            f"=== evaluate() starting ({len(self.eval_conversations)} conversations) ==="
        )
        all_rewards = []
        all_tc_f1 = []
        all_tp_f1 = []

        for convo in self.eval_conversations:
            turn_results, nodes = await collect_multistep_trajectory(
                server=self.server,
                tokenizer=self.tokenizer,
                conversation=convo,
                tools=T1_TOOLS,
                max_tokens=512,
                temperature=0.0,
                tool_choice="auto",
            )
            for tr in turn_results:
                all_rewards.append(tr["scores"]["reward"])
                all_tc_f1.append(tr["scores"]["tool_call_f1"])
                all_tp_f1.append(tr["scores"]["tool_param_f1"])

        # all_rewards/all_tc_f1/all_tp_f1 grow in lockstep, so one guard covers
        # all three divisions.
        if all_rewards:
            self.eval_metrics.append(
                ("eval/avg_reward", sum(all_rewards) / len(all_rewards))
            )
            self.eval_metrics.append(
                ("eval/tool_call_f1", sum(all_tc_f1) / len(all_tc_f1))
            )
            self.eval_metrics.append(
                ("eval/tool_param_f1", sum(all_tp_f1) / len(all_tp_f1))
            )

    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
        """Run group_size independent trajectories over one conversation and
        package them as a GRPO group. Returns (None, []) when fewer than one
        trajectory survives filtering or all rewards are identical."""
        convo = item
        user_turns = sum(1 for t in convo if t["Role"].strip().lower() == "user")
        logger.info(
            f"collect_trajectories: {len(convo)} turns ({user_turns} user), group_size={self.config.group_size}"
        )

        scored = ScoredDataGroup()
        scored["tokens"] = []
        scored["masks"] = []
        scored["scores"] = []
        scored["inference_logprobs"] = []

        # Run group_size independent trajectories on the same conversation
        for g in range(self.config.group_size):
            turn_results, nodes = await collect_multistep_trajectory(
                server=self.server,
                tokenizer=self.tokenizer,
                conversation=convo,
                tools=T1_TOOLS,
                max_tokens=512,
                temperature=1.0,
                tool_choice="auto",
            )

            if not nodes or not turn_results:
                logger.debug(f" trajectory[{g}]: no nodes/results, skipping")
                continue

            # One node per trajectory (extending across all turns)
            node = nodes[0]
            unmasked = len([t for t in node.masked_tokens if t != -100])
            if unmasked < 5:
                logger.debug(
                    f" trajectory[{g}]: only {unmasked} unmasked tokens, skipping"
                )
                continue

            # Trajectory reward = average across all turns
            avg_reward = sum(tr["scores"]["reward"] for tr in turn_results) / len(
                turn_results
            )
            avg_tc_f1 = sum(tr["scores"]["tool_call_f1"] for tr in turn_results) / len(
                turn_results
            )
            avg_tp_f1 = sum(tr["scores"]["tool_param_f1"] for tr in turn_results) / len(
                turn_results
            )

            scored["tokens"].append(node.tokens)
            scored["masks"].append(node.masked_tokens)
            scored["inference_logprobs"].append(node.logprobs)
            scored["scores"].append(avg_reward)

            self.reward_buffer.append(avg_reward)
            self.tc_f1_buffer.append(avg_tc_f1)
            self.tp_f1_buffer.append(avg_tp_f1)

            logger.info(
                f" trajectory[{g}]: {len(turn_results)} turns, "
                f"{len(node.tokens)} tokens, reward={avg_reward:.3f}"
            )

        if not scored["tokens"]:
            logger.info(" -> None (no valid trajectories)")
            return None, []
        # Identical rewards carry no GRPO advantage signal — drop the group.
        if all(s == scored["scores"][0] for s in scored["scores"]):
            logger.info(f" -> None (all scores identical: {scored['scores'][0]:.3f})")
            return None, []

        logger.info(
            f" -> valid group: {len(scored['tokens'])} trajectories, scores={[f'{s:.3f}' for s in scored['scores']]}"
        )
        return scored, []

    async def get_next_item(self):
        """Round-robin over the train split; wraps around via modulo."""
        convo = self.train_conversations[self.iter % len(self.train_conversations)]
        self.iter += 1
        logger.debug(f"get_next_item: iter={self.iter}")
        return convo
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Standard Atropos entry point: parses CLI args and runs the env loop.
    T1ToolPlanningEnv.cli()
|
||||
20
environments/t1_tool_planning/t1_prompts.py
Normal file
20
environments/t1_tool_planning/t1_prompts.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
"""
|
||||
System prompt and few-shot examples for T1 tool planning.
|
||||
"""
|
||||
|
||||
SYSTEM_PROMPT = """\
|
||||
You are an expert travel planning assistant. \
|
||||
You help users search for flights, hotels, restaurants, and attractions.
|
||||
|
||||
You have access to tools for searching, filtering, caching results, and seeking information from the user.
|
||||
|
||||
Important rules:
|
||||
- Only call tools when the user provides enough mandatory information
|
||||
- If mandatory parameters are missing, use seek_information to ask the user
|
||||
- Use save_to_cache after searching to store results for later use
|
||||
- Use get_results_from_cache when you need previously found results
|
||||
- Use filter_* tools to narrow down cached results instead of re-searching
|
||||
- If no new action is needed, respond with text only (no tool calls)
|
||||
- Preserve entity values exactly as the user states them (don't modify case or format)
|
||||
- When the user mentions dates, pass them as-is to the tools
|
||||
"""
|
||||
300
environments/t1_tool_planning/t1_scoring.py
Normal file
300
environments/t1_tool_planning/t1_scoring.py
Normal file
|
|
@ -0,0 +1,300 @@
|
|||
"""
|
||||
Scoring for T1 tool planning environment.
|
||||
|
||||
Parses ground truth Python code with AST to extract tool calls,
|
||||
then compares against structured tool_calls from the model response.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
from collections import Counter
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools that T1 considers "real" (not just print or cache ops)
SEARCH_FILTER_TOOLS = {
    "search_hotels",
    "filter_hotels",
    "search_flights",
    "filter_flights",
    "search_restaurants",
    "filter_restaurants",
    "search_attractions",
    "filter_attractions",
    "search_nearest",
    "sort_results",
    "adjust_date",
    "seek_information",
}

# Cache bookkeeping tools; still extracted from ground truth and scored.
CACHE_TOOLS = {"save_to_cache", "get_results_from_cache"}
# Every tool name parse_ground_truth_code will recognize as a tool call.
ALL_TOOLS = SEARCH_FILTER_TOOLS | CACHE_TOOLS
|
||||
|
||||
|
||||
def parse_ground_truth_code(code: str) -> List[Dict[str, Any]]:
    """Extract tool calls from a ground-truth Python snippet.

    Args:
        code: Python code string from Filled_Plan column.

    Returns:
        List of {"name": str, "arguments": dict} dicts, in ``ast.walk`` order.
    """
    # Empty / non-string plans carry no calls.
    if not isinstance(code, str) or not code:
        return []

    code = code.strip()
    # The dataset marks no-op turns with a literal print statement.
    if not code or code == 'print("No planning needed")':
        return []

    try:
        tree = ast.parse(code)
    except SyntaxError:
        logger.warning("Failed to parse ground truth code: %s", code[:100])
        return []

    extracted: List[Dict[str, Any]] = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.Call):
            continue

        # Resolve the callee name for plain and attribute-style calls.
        if isinstance(node.func, ast.Name):
            callee = node.func.id
        elif isinstance(node.func, ast.Attribute):
            callee = node.func.attr
        else:
            continue

        if callee not in ALL_TOOLS:
            continue

        kwargs: Dict[str, Any] = {}
        for keyword in node.keywords:
            if keyword.arg is None:
                # Skip **kwargs expansions — no stable key to record.
                continue
            try:
                kwargs[keyword.arg] = ast.literal_eval(keyword.value)
            except (ValueError, TypeError):
                # Variable references (e.g. prior_results=hotels) keep their
                # identifier; any other expression falls back to ast.dump.
                if isinstance(keyword.value, ast.Name):
                    kwargs[keyword.arg] = keyword.value.id
                else:
                    kwargs[keyword.arg] = ast.dump(keyword.value)

        extracted.append({"name": callee, "arguments": kwargs})

    return extracted
|
||||
|
||||
|
||||
def parse_model_tool_calls(tool_calls: Optional[List[Any]]) -> List[Dict[str, Any]]:
    """Normalize model tool_calls into a comparable format.

    Accepts either plain dicts (``{"function": {"name": ..., "arguments": ...}}``)
    or OpenAI-SDK-style objects exposing ``.function.name`` / ``.function.arguments``
    — the live tests pass ``choice.message.tool_calls`` straight in, which may be
    attribute-style objects rather than dicts.

    Args:
        tool_calls: List of tool call dicts/objects from a ChatCompletion response.

    Returns:
        List of {"name": str, "arguments": dict} dicts. Unparseable argument
        payloads degrade to an empty dict rather than raising.
    """
    if not tool_calls:
        return []

    result: List[Dict[str, Any]] = []
    for tc in tool_calls:
        if isinstance(tc, dict):
            # `or {}` guards against an explicit "function": None entry.
            func = tc.get("function", {}) or {}
            name = func.get("name", "") or ""
            args_raw = func.get("arguments", "{}")
        else:
            func = getattr(tc, "function", None)
            name = getattr(func, "name", "") or ""
            args_raw = getattr(func, "arguments", "{}")

        try:
            if isinstance(args_raw, str):
                args = json.loads(args_raw)
            else:
                # Already-parsed dict, or None → empty dict.
                args = args_raw or {}
        except (json.JSONDecodeError, TypeError):
            args = {}

        result.append({"name": name, "arguments": args})
    return result
|
||||
|
||||
|
||||
def tool_call_f1(
    ground_truth: List[Dict[str, Any]],
    generated: List[Dict[str, Any]],
) -> Tuple[float, float, float]:
    """Precision, recall, and F1 over the multiset of tool names called."""
    gt_counts = Counter(call["name"] for call in ground_truth)
    gen_counts = Counter(call["name"] for call in generated)

    # Drop "print" on either side whenever real tools are also present.
    for counts in (gt_counts, gen_counts):
        if len(counts) > 1:
            counts.pop("print", None)

    # Both sides empty means the turn needed nothing — count as correct.
    if not gt_counts and not gen_counts:
        return 1.0, 1.0, 1.0

    overlap = sum((gt_counts & gen_counts).values())
    precision = overlap / sum(gen_counts.values()) if gen_counts else 0.0
    recall = overlap / sum(gt_counts.values()) if gt_counts else 0.0

    denom = precision + recall
    f1 = (2 * precision * recall / denom) if denom > 0 else 0.0
    return precision, recall, f1
|
||||
|
||||
|
||||
def tool_param_f1(
|
||||
ground_truth: List[Dict[str, Any]],
|
||||
generated: List[Dict[str, Any]],
|
||||
) -> Tuple[float, float, float]:
|
||||
"""Compute precision, recall, F1 on tool parameters.
|
||||
|
||||
For each matching tool name pair, compares the argument keys and values.
|
||||
"""
|
||||
if not ground_truth and not generated:
|
||||
return 1.0, 1.0, 1.0
|
||||
|
||||
# Match tool calls by name (greedy matching)
|
||||
gt_remaining = list(ground_truth)
|
||||
matched_pairs = []
|
||||
|
||||
for gen_call in generated:
|
||||
for i, gt_call in enumerate(gt_remaining):
|
||||
if gt_call["name"] == gen_call["name"]:
|
||||
matched_pairs.append((gt_call, gen_call))
|
||||
gt_remaining.pop(i)
|
||||
break
|
||||
|
||||
if not matched_pairs:
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
total_tp = 0
|
||||
total_gt_params = 0
|
||||
total_gen_params = 0
|
||||
|
||||
for gt_call, gen_call in matched_pairs:
|
||||
gt_args = gt_call.get("arguments", {})
|
||||
gen_args = gen_call.get("arguments", {})
|
||||
|
||||
total_gt_params += len(gt_args)
|
||||
total_gen_params += len(gen_args)
|
||||
|
||||
for key in gt_args:
|
||||
if key in gen_args:
|
||||
# Loose comparison — normalize types
|
||||
gt_val = gt_args[key]
|
||||
gen_val = gen_args[key]
|
||||
if _values_match(gt_val, gen_val):
|
||||
total_tp += 1
|
||||
|
||||
precision = total_tp / total_gen_params if total_gen_params > 0 else 0.0
|
||||
recall = total_tp / total_gt_params if total_gt_params > 0 else 0.0
|
||||
f1 = (
|
||||
2 * precision * recall / (precision + recall)
|
||||
if (precision + recall) > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
return precision, recall, f1
|
||||
|
||||
|
||||
def _values_match(gt_val: Any, gen_val: Any) -> bool:
|
||||
"""Loose comparison of argument values."""
|
||||
if gt_val == gen_val:
|
||||
return True
|
||||
|
||||
# String comparison (case-insensitive)
|
||||
if isinstance(gt_val, str) and isinstance(gen_val, str):
|
||||
return gt_val.lower().strip() == gen_val.lower().strip()
|
||||
|
||||
# List comparison (order-insensitive for some)
|
||||
if isinstance(gt_val, list) and isinstance(gen_val, list):
|
||||
if len(gt_val) == len(gen_val):
|
||||
return sorted(str(v) for v in gt_val) == sorted(str(v) for v in gen_val)
|
||||
|
||||
# Number comparison
|
||||
try:
|
||||
if float(gt_val) == float(gen_val):
|
||||
return True
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def score_turn(
    ground_truth_code: str,
    model_tool_calls: Optional[List[dict]],
    model_content: Optional[str] = None,
) -> Dict[str, float]:
    """Score a single turn against the T1 ground-truth plan.

    Args:
        ground_truth_code: Python code from Filled_Plan column.
        model_tool_calls: Structured tool_calls from model response.
        model_content: Text content from model response (currently unused;
            kept for future seek_information text scoring).

    Returns:
        Dict with a consistent key set in EVERY branch (the original no-op
        branches omitted the precision/recall/is_seek_info keys, which broke
        uniform metric aggregation): tool_call_precision/recall/f1,
        tool_param_precision/recall/f1, correct_no_op, is_seek_info, reward.
    """
    gt_calls = parse_ground_truth_code(ground_truth_code)
    gen_calls = parse_model_tool_calls(model_tool_calls)

    gt_is_noop = not gt_calls
    gen_is_noop = not gen_calls
    # Whether the ground truth expects the model to ask the user a question.
    gt_is_seek = any(c["name"] == "seek_information" for c in gt_calls)

    def _metrics(tc, tp, correct_no_op: float, reward: float) -> Dict[str, float]:
        # Single constructor so every branch returns the same key set.
        return {
            "tool_call_precision": tc[0],
            "tool_call_recall": tc[1],
            "tool_call_f1": tc[2],
            "tool_param_precision": tp[0],
            "tool_param_recall": tp[1],
            "tool_param_f1": tp[2],
            "correct_no_op": correct_no_op,
            "is_seek_info": float(gt_is_seek),
            "reward": reward,
        }

    # Both sides agree no action is needed — full credit.
    if gt_is_noop and gen_is_noop:
        return _metrics((1.0, 1.0, 1.0), (1.0, 1.0, 1.0), 1.0, 1.0)

    # Model called tools when none were expected.
    if gt_is_noop:
        return _metrics((0.0, 0.0, 0.0), (0.0, 0.0, 0.0), 0.0, 0.0)

    tc_scores = tool_call_f1(gt_calls, gen_calls)
    tp_scores = tool_param_f1(gt_calls, gen_calls)
    tc_f1 = tc_scores[2]
    tp_f1 = tp_scores[2]

    # Composite reward — graduated so GRPO gets signal even with weak models:
    #   GT expects tools but model produced none -> 0.0 (worst)
    #   model called tools but all wrong         -> 0.1 (format credit)
    #   some right tools                         -> 0.1 + 0.5*tc_f1 + 0.3*tp_f1
    #   perfect tool selection                   -> +0.1 bonus (max total 1.0)
    if gen_is_noop:
        reward = 0.0
    else:
        reward = 0.1 + 0.5 * tc_f1 + 0.3 * tp_f1
        if tc_f1 == 1.0:
            reward += 0.1

    return _metrics(tc_scores, tp_scores, 0.0, reward)
|
||||
288
environments/t1_tool_planning/t1_tools.py
Normal file
288
environments/t1_tool_planning/t1_tools.py
Normal file
|
|
@ -0,0 +1,288 @@
|
|||
"""
|
||||
OpenAI function definitions for T1's 14 travel planning tools.
|
||||
|
||||
These are passed to managed_server.chat_completion(tools=T1_TOOLS)
|
||||
so the model uses proper tool calling instead of raw code generation.
|
||||
"""
|
||||
|
||||
# OpenAI-format schemas for T1's 14 travel planning tools, grouped by domain:
# search/filter pairs for hotels, flights, restaurants, attractions, then
# cache, sorting, user-interaction, date, and proximity helpers.
T1_TOOLS = [
    # --- hotels ---
    {
        "type": "function",
        "function": {
            "name": "search_hotels",
            "description": "Search for hotels in a city. Requires city, checkin_date, checkout_date.",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "checkin_date": {"type": "array", "items": {"type": "string"}},
                    "checkout_date": {"type": "array", "items": {"type": "string"}},
                    "num_rooms": {"type": "integer"},
                    "num_people": {"type": "integer"},
                    "neighborhood": {"type": "array", "items": {"type": "string"}},
                    "hotel_name": {"type": "array", "items": {"type": "string"}},
                    "budget": {"type": "integer"},
                    "rating": {"type": "array", "items": {"type": "number"}},
                    "stars": {"type": "array", "items": {"type": "integer"}},
                    "free_wifi_included": {"type": "boolean"},
                    "breakfast_included": {"type": "boolean"},
                    "gym_present": {"type": "boolean"},
                    "pool_present": {"type": "boolean"},
                    "is_pet_friendly": {"type": "boolean"},
                    "has_spa_services": {"type": "boolean"},
                    "smoking_allowed": {"type": "boolean"},
                    "is_wheelchair_accessible": {"type": "boolean"},
                    "has_free_parking": {"type": "boolean"},
                    "airport_shuttle_present": {"type": "boolean"},
                },
                "required": ["city", "checkin_date", "checkout_date"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "filter_hotels",
            "description": "Filter previously searched hotel results by additional criteria.",
            "parameters": {
                "type": "object",
                "properties": {
                    "prior_results": {
                        "type": "string",
                        "description": "Variable name of prior results",
                    },
                    "neighborhood": {"type": "array", "items": {"type": "string"}},
                    "budget": {"type": "integer"},
                    "rating": {"type": "array", "items": {"type": "number"}},
                    "stars": {"type": "array", "items": {"type": "integer"}},
                    "free_wifi_included": {"type": "boolean"},
                    "breakfast_included": {"type": "boolean"},
                    "gym_present": {"type": "boolean"},
                    "pool_present": {"type": "boolean"},
                    "is_pet_friendly": {"type": "boolean"},
                },
                "required": ["prior_results"],
            },
        },
    },
    # --- flights ---
    {
        "type": "function",
        "function": {
            "name": "search_flights",
            "description": "Search for flights. Requires departure_date and origin/destination.",
            "parameters": {
                "type": "object",
                "properties": {
                    "start_airport_city": {"type": "string"},
                    "end_airport_city": {"type": "string"},
                    "departure_date": {"type": "array", "items": {"type": "string"}},
                    "arrival_date": {"type": "array", "items": {"type": "string"}},
                    "airline": {"type": "array", "items": {"type": "string"}},
                    "budget": {"type": "integer"},
                    "flight_class": {"type": "array", "items": {"type": "string"}},
                    "num_layovers": {"type": "array", "items": {"type": "integer"}},
                },
                "required": ["departure_date"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "filter_flights",
            "description": "Filter previously searched flight results.",
            "parameters": {
                "type": "object",
                "properties": {
                    "prior_results": {"type": "string"},
                    "airline": {"type": "array", "items": {"type": "string"}},
                    "budget": {"type": "integer"},
                    "flight_class": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["prior_results"],
            },
        },
    },
    # --- restaurants ---
    {
        "type": "function",
        "function": {
            "name": "search_restaurants",
            "description": "Search for restaurants in a city.",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "cuisine": {"type": "array", "items": {"type": "string"}},
                    "rating": {"type": "array", "items": {"type": "number"}},
                    "neighborhood": {"type": "array", "items": {"type": "string"}},
                    "price_range": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["city"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "filter_restaurants",
            "description": "Filter previously searched restaurant results.",
            "parameters": {
                "type": "object",
                "properties": {
                    "prior_results": {"type": "string"},
                    "cuisine": {"type": "array", "items": {"type": "string"}},
                    "rating": {"type": "array", "items": {"type": "number"}},
                    "neighborhood": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["prior_results"],
            },
        },
    },
    # --- attractions ---
    {
        "type": "function",
        "function": {
            "name": "search_attractions",
            "description": "Search for attractions in a city.",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "type": {"type": "array", "items": {"type": "string"}},
                    "neighborhood": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["city"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "filter_attractions",
            "description": "Filter previously searched attraction results.",
            "parameters": {
                "type": "object",
                "properties": {
                    "prior_results": {"type": "string"},
                    "type": {"type": "array", "items": {"type": "string"}},
                    "neighborhood": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["prior_results"],
            },
        },
    },
    # --- cache helpers ---
    {
        "type": "function",
        "function": {
            "name": "save_to_cache",
            "description": "Save results to cache with a unique key for later retrieval.",
            "parameters": {
                "type": "object",
                "properties": {
                    "key": {"type": "string", "description": "Unique cache key"},
                    "value": {
                        "type": "string",
                        "description": "Variable name of results to cache",
                    },
                },
                "required": ["key", "value"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_results_from_cache",
            "description": "Retrieve previously cached results by key.",
            "parameters": {
                "type": "object",
                "properties": {
                    "key": {"type": "string", "description": "Cache key to retrieve"},
                },
                "required": ["key"],
            },
        },
    },
    # --- utility tools ---
    {
        "type": "function",
        "function": {
            "name": "sort_results",
            "description": "Sort results by a specific field.",
            "parameters": {
                "type": "object",
                "properties": {
                    "results": {
                        "type": "string",
                        "description": "Variable name of results to sort",
                    },
                    "sort_by": {
                        "type": "string",
                        "description": "Field to sort by (e.g. price, rating)",
                    },
                    "ascending": {"type": "boolean"},
                },
                "required": ["results", "sort_by"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "seek_information",
            "description": "Ask the user for missing mandatory information before calling a tool.",
            "parameters": {
                "type": "object",
                "properties": {
                    "question": {
                        "type": "string",
                        "description": "Question to ask the user",
                    },
                },
                "required": ["question"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "adjust_date",
            "description": "Adjust a date by a number of days.",
            "parameters": {
                "type": "object",
                "properties": {
                    "date": {"type": "string"},
                    "days": {"type": "integer"},
                },
                "required": ["date", "days"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "search_nearest",
            "description": "Find nearest locations between two sets of results (e.g. hotels near restaurants).",
            "parameters": {
                "type": "object",
                "properties": {
                    "hotels": {
                        "type": "string",
                        "description": "Variable name of hotel results",
                    },
                    "restaurants": {
                        "type": "string",
                        "description": "Variable name of restaurant results",
                    },
                    "attractions": {
                        "type": "string",
                        "description": "Variable name of attraction results",
                    },
                    "groupBy": {
                        "type": "string",
                        "description": "Group results by this entity type",
                    },
                },
                "required": [],
            },
        },
    },
]
|
||||
294
environments/t1_tool_planning/test_t1_live.py
Normal file
294
environments/t1_tool_planning/test_t1_live.py
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Live test for T1 tool planning — runs against an already-running vLLM server.
|
||||
|
||||
No pytest fixtures, no subprocess spawning. Just creates a ServerManager
|
||||
pointed at localhost:9001, calls generate_tool_completions, and prints results.
|
||||
|
||||
Usage:
|
||||
# With vLLM already running on port 9001:
|
||||
python environments/t1_tool_planning/test_t1_live.py
|
||||
|
||||
# Custom port:
|
||||
python environments/t1_tool_planning/test_t1_live.py --port 8123
|
||||
|
||||
# Custom model:
|
||||
python environments/t1_tool_planning/test_t1_live.py --model Qwen/Qwen3-4B
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Ensure t1 modules are importable
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from t1_core import generate_tool_completions, score_completions # noqa: E402
|
||||
from t1_prompts import SYSTEM_PROMPT # noqa: E402
|
||||
from t1_scoring import score_turn # noqa: E402
|
||||
from t1_tools import T1_TOOLS # noqa: E402
|
||||
|
||||
# Verbose DEBUG logging so tool-call translation problems surface in the run log.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
)
logger = logging.getLogger("test_t1_live")
|
||||
|
||||
|
||||
def make_server_manager(model_name: str, base_url: str):
    """Create a ServerManager pointed at an existing vLLM server.

    Imports stay function-local so the module can be loaded without the
    atroposlib dependency chain.
    """
    from atroposlib.envs.server_handling.server_baseline import APIServerConfig
    from atroposlib.envs.server_handling.server_manager import ServerManager

    return ServerManager(
        configs=[
            APIServerConfig(
                model_name=model_name,
                base_url=base_url,
                api_key="x",
                server_type="vllm",
            )
        ],
        slurm=False,
        testing=False,
        tool_parser="hermes",
    )
|
||||
|
||||
|
||||
def make_tokenizer(model_name: str):
    """Load the HF tokenizer matching the served model."""
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer
|
||||
|
||||
|
||||
# Miniature T1-style conversations keyed by id, mirroring the dataset schema
# (Role / Filled_Template / Filled_Plan). Each covers one scoring regime:
#   1 -> straightforward search + cache
#   2 -> missing mandatory info (seek_information expected)
#   3 -> no-op turn (no tool calls expected)
SAMPLE_CONVERSATIONS = {
    1: [
        {
            "Role": "assistant",
            "Filled_Template": "Hello! I'm your travel assistant. How can I help you today?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "I'm looking for hotels in Austin with check-in on May 10, 2025 and check-out on May 15, 2025.",  # noqa: E501
            "Filled_Plan": 'hotels = search_hotels(city="Austin", checkin_date=["May 10, 2025"], checkout_date=["May 15, 2025"])\nsave_to_cache(key="hotels", value=hotels)',  # noqa: E501
        },
    ],
    2: [
        {
            "Role": "assistant",
            "Filled_Template": "Welcome! What can I help you plan?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "I need a hotel in New York but I'm not sure about dates yet.",
            "Filled_Plan": 'seek_information("We need to ask for the check-in and check-out dates")',
        },
    ],
    3: [
        {
            "Role": "assistant",
            "Filled_Template": "Hi there! Looking for travel help?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "No that's perfect, thanks!",
            "Filled_Plan": 'print("No planning needed")',
        },
    ],
}
|
||||
|
||||
|
||||
def build_messages(conversation: list, turn_index: int) -> list:
    """Build the chat history up to and including turn ``turn_index``.

    Prepends the system prompt; roles are lowercased from the dataset's
    "Role" column.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for position, turn in enumerate(conversation):
        if position > turn_index:
            break
        messages.append(
            {"role": turn["Role"].strip().lower(), "content": turn["Filled_Template"]}
        )
    return messages
|
||||
|
||||
|
||||
async def test_single_completion(server, tokenizer):
    """Test 1: Single completion with tool calling.

    Sends one hotel-search request, prints the structured tool_calls and the
    tracked trajectory node, then scores against a known ground-truth plan.
    Always returns True; failures surface as exceptions in the caller.
    """
    print("\n" + "=" * 60)
    print("TEST 1: Single tool-calling completion")
    print("=" * 60)

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": "Find me hotels in Austin, checking in May 10 and out May 15, 2025.",
        },
    ]

    # temperature=0.0 for a deterministic single sample.
    result, nodes = await generate_tool_completions(
        server=server,
        tokenizer=tokenizer,
        messages=messages,
        tools=T1_TOOLS,
        n=1,
        max_tokens=500,
        temperature=0.0,
        tool_choice="auto",
    )

    choice = result.choices[0]
    print(f"\nContent: {choice.message.content}")
    print(f"Tool calls: {choice.message.tool_calls}")
    print(f"Finish reason: {choice.finish_reason}")
    print(f"Nodes tracked: {len(nodes)}")

    if nodes:
        node = nodes[0]
        print(f"Token count: {len(node.tokens)}")
        # -100 marks masked (non-trainable) positions in the token mask.
        unmasked = len([t for t in node.masked_tokens if t != -100])
        print(f"Unmasked tokens: {unmasked}")
        print(f"Logprobs sample: {node.logprobs[-5:]}")

    # Score against ground truth
    gt_code = 'hotels = search_hotels(city="Austin", checkin_date=["May 10, 2025"], checkout_date=["May 15, 2025"])\nsave_to_cache(key="hotels", value=hotels)'  # noqa: E501
    scores = score_turn(gt_code, choice.message.tool_calls, choice.message.content)
    print(f"\nScores: {json.dumps(scores, indent=2)}")

    return True
|
||||
|
||||
|
||||
async def test_group_completions(server, tokenizer):
    """Test 2: Multiple completions (group_size=4) for GRPO.

    Samples four completions at temperature 1.0 for the same turn, scores
    each, and checks that score_completions builds a valid ScoredDataGroup
    (it returns None when the group is degenerate).
    """
    print("\n" + "=" * 60)
    print("TEST 2: Group completions (n=4) for GRPO")
    print("=" * 60)

    convo = SAMPLE_CONVERSATIONS[1]
    messages = build_messages(convo, turn_index=1)
    gt_code = convo[1]["Filled_Plan"]

    result, nodes = await generate_tool_completions(
        server=server,
        tokenizer=tokenizer,
        messages=messages,
        tools=T1_TOOLS,
        n=4,
        max_tokens=500,
        temperature=1.0,
        tool_choice="auto",
    )

    print(f"\nGot {len(result.choices)} choices, {len(nodes)} nodes")

    for i, choice in enumerate(result.choices):
        tc_count = len(choice.message.tool_calls) if choice.message.tool_calls else 0
        content = (choice.message.content or "")[:60]
        print(f" choice[{i}]: {tc_count} tool_calls, content={content!r}")

    # Score and build ScoredDataGroup
    scored, all_scores = score_completions(result, nodes, gt_code)

    print("\nPer-choice scores:")
    for i, s in enumerate(all_scores):
        print(
            f" [{i}] reward={s['reward']:.2f} tc_f1={s['tool_call_f1']:.2f} tp_f1={s['tool_param_f1']:.2f}"
        )

    if scored:
        print(f"\nScoredDataGroup valid: {len(scored['tokens'])} items")
        print(f" scores: {scored['scores']}")
    else:
        print("\nScoredDataGroup: None (discarded)")

    return True
|
||||
|
||||
|
||||
async def test_noop_turn(server, tokenizer):
    """Test 3: No-op turn (model should NOT call tools).

    Uses the "No that's perfect, thanks!" conversation whose ground truth is
    print("No planning needed") — score_turn rewards 1.0 only when the model
    also produces no tool calls.
    """
    print("\n" + "=" * 60)
    print("TEST 3: No-op turn")
    print("=" * 60)

    convo = SAMPLE_CONVERSATIONS[3]
    messages = build_messages(convo, turn_index=1)
    gt_code = convo[1]["Filled_Plan"]

    result, nodes = await generate_tool_completions(
        server=server,
        tokenizer=tokenizer,
        messages=messages,
        tools=T1_TOOLS,
        n=1,
        max_tokens=300,
        temperature=0.0,
        tool_choice="auto",
    )

    choice = result.choices[0]
    print(f"\nContent: {(choice.message.content or '')[:100]}")
    print(f"Tool calls: {choice.message.tool_calls}")

    scores = score_turn(gt_code, choice.message.tool_calls, choice.message.content)
    print(f"Scores: {json.dumps(scores, indent=2)}")

    return True
|
||||
|
||||
|
||||
async def main():
    """Run the three live tests against an already-running vLLM server.

    Checks /health first and bails out with instructions if the server is
    unreachable; a test "passes" when it returns True without raising.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=9001)
    parser.add_argument("--model", type=str, default="Qwen/Qwen3-1.7B")
    args = parser.parse_args()

    base_url = f"http://localhost:{args.port}/v1"
    print(f"Connecting to vLLM at {base_url} (model={args.model})")

    # Check health first
    import requests

    try:
        resp = requests.get(f"http://localhost:{args.port}/health", timeout=5)
        print(f"vLLM health: {resp.status_code}")
        if resp.status_code != 200:
            print("ERROR: vLLM not healthy!")
            return
    except Exception as e:
        print(f"ERROR: Can't reach vLLM: {e}")
        print(
            "Make sure vLLM is running: bash environments/t1_tool_planning/run_vllm.sh"
        )
        return

    server = make_server_manager(args.model, base_url)
    tokenizer = make_tokenizer(args.model)

    print(f"ServerManager created with {len(server.servers)} server(s)")
    print(f"Server type: {type(server.servers[0]).__name__}")
    print(f"Tool parser: {server.tool_parser}")

    passed = 0
    failed = 0

    # Each test is isolated: one failure doesn't stop the remaining tests.
    for test_fn in [test_single_completion, test_group_completions, test_noop_turn]:
        try:
            ok = await test_fn(server, tokenizer)
            if ok:
                passed += 1
                print("\n ✓ PASSED")
        except Exception as e:
            failed += 1
            print(f"\n ✗ FAILED: {e}")
            import traceback

            traceback.print_exc()

    print(f"\n{'=' * 60}")
    print(f"RESULTS: {passed} passed, {failed} failed")
    print(f"{'=' * 60}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run all live tests on the default event loop.
    asyncio.run(main())
|
||||
743
environments/t1_tool_planning/test_t1_standalone.py
Normal file
743
environments/t1_tool_planning/test_t1_standalone.py
Normal file
|
|
@ -0,0 +1,743 @@
|
|||
"""
|
||||
Standalone test for T1 tool planning environment.
|
||||
|
||||
Spins up vllm_api_server, creates a ServerManager with tool_parser="hermes",
|
||||
and runs through T1 conversations end-to-end using the extracted t1_core functions.
|
||||
|
||||
Tests the full tool calling infrastructure:
|
||||
ServerManager → ManagedServer → ToolCallTranslator → vLLM hermes parser
|
||||
→ structured tool_calls → scoring against T1 ground truth
|
||||
|
||||
Usage:
|
||||
# With GPU (spins up vLLM):
|
||||
pytest --run-gpu environments/t1_tool_planning/test_t1_standalone.py -v -s
|
||||
|
||||
# Scoring logic only (no GPU):
|
||||
pytest environments/t1_tool_planning/test_t1_standalone.py -v -k "not gpu"
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
# -- T1 env imports --
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from t1_core import ( # noqa: E402
|
||||
collect_multistep_trajectory,
|
||||
generate_tool_completions,
|
||||
score_completions,
|
||||
)
|
||||
from t1_prompts import SYSTEM_PROMPT # noqa: E402
|
||||
from t1_scoring import ( # noqa: E402
|
||||
parse_ground_truth_code,
|
||||
parse_model_tool_calls,
|
||||
score_turn,
|
||||
tool_call_f1,
|
||||
tool_param_f1,
|
||||
)
|
||||
from t1_tools import T1_TOOLS # noqa: E402
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# vLLM server settings for the GPU-backed standalone tests.
VLLM_PORT = 8123
VLLM_MODEL = "Qwen/Qwen3-1.7B"
VLLM_BASE_URL = f"http://localhost:{VLLM_PORT}/v1"
# Repo root is two directories up from this test file.
REPO_ROOT = os.path.join(os.path.dirname(__file__), "..", "..")
VLLM_SCRIPT = os.path.join(REPO_ROOT, "example_trainer", "vllm_api_server.py")
|
||||
|
||||
# Sample T1 data — small hotel conversations for testing
# Schema mirrors the dataset columns: Role / Filled_Template / Filled_Plan.
# Conversation 1 exercises search -> filter -> no-op; conversation 2
# exercises seek_information followed by a search once dates are known.
SAMPLE_T1_CONVERSATIONS = {
    1: [
        {
            "Role": "assistant",
            "Filled_Template": "Hello! I'm your travel assistant. How can I help you today?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "I'm looking for hotels in Austin with check-in on May 10, 2025 and check-out on May 15, 2025.",  # noqa: E501
            "Filled_Plan": 'hotels = search_hotels(city="Austin", checkin_date=["May 10, 2025"], checkout_date=["May 15, 2025"])\nsave_to_cache(key="hotels", value=hotels)',  # noqa: E501
        },
        {
            "Role": "assistant",
            "Filled_Template": "I found several hotels in Austin for those dates. Would you like to filter by any specific amenities?",  # noqa: E501
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "Yes, I need one with free wifi and a gym.",
            "Filled_Plan": 'hotels = get_results_from_cache(key="hotels")\nfiltered_hotels = filter_hotels(prior_results=hotels, free_wifi_included=True, gym_present=True)\nsave_to_cache(key="filtered_hotels", value=filtered_hotels)',  # noqa: E501
        },
        {
            "Role": "assistant",
            "Filled_Template": "Here are hotels with free wifi and gym. Anything else?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "No that's perfect, thanks!",
            "Filled_Plan": 'print("No planning needed")',
        },
    ],
    2: [
        {
            "Role": "assistant",
            "Filled_Template": "Welcome! What can I help you plan?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "I need a hotel in New York but I'm not sure about dates yet.",
            "Filled_Plan": 'seek_information("We need to ask for the check-in and check-out dates")',
        },
        {
            "Role": "assistant",
            "Filled_Template": "Sure! When would you like to check in and check out?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "Check in June 1 and check out June 5, 2025.",
            "Filled_Plan": 'hotels = search_hotels(city="New York", checkin_date=["June 1, 2025"], checkout_date=["June 5, 2025"])\nsave_to_cache(key="hotels", value=hotels)',  # noqa: E501
        },
    ],
}
|
||||
|
||||
|
||||
def load_sample_data() -> dict:
    """Return the built-in sample T1 conversations, keyed by conversation id."""
    data = SAMPLE_T1_CONVERSATIONS
    return data
|
||||
|
||||
|
||||
def build_messages_for_turn(conversation: list, turn_index: int) -> list:
    """Assemble the chat history up to (and including) the given user turn.

    Prior assistant turns use the ground-truth responses from the dataset;
    turns with any other role are silently skipped.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    # Only turns at index <= turn_index contribute to the prompt.
    for turn in conversation[: turn_index + 1]:
        speaker = turn["Role"].strip().lower()
        if speaker in ("assistant", "user"):
            messages.append({"role": speaker, "content": turn["Filled_Template"]})

    return messages
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scoring-only tests (no GPU needed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestT1Scoring:
    """Unit tests for the T1 scoring helpers — pure logic, no GPU required."""

    def test_parse_ground_truth_search(self):
        code = 'hotels = search_hotels(city="Austin", checkin_date=["May 10"], checkout_date=["May 15"])\nsave_to_cache(key="hotels", value=hotels)'  # noqa: E501
        parsed = parse_ground_truth_code(code)
        assert len(parsed) == 2
        first, second = parsed[0], parsed[1]
        assert first["name"] == "search_hotels"
        assert first["arguments"]["city"] == "Austin"
        assert second["name"] == "save_to_cache"

    def test_parse_ground_truth_filter(self):
        code = 'hotels = get_results_from_cache(key="hotels")\nfiltered = filter_hotels(prior_results=hotels, free_wifi_included=True)'  # noqa: E501
        parsed = parse_ground_truth_code(code)
        assert len(parsed) == 2
        assert parsed[0]["name"] == "get_results_from_cache"
        assert parsed[1]["name"] == "filter_hotels"
        assert parsed[1]["arguments"]["free_wifi_included"] is True

    def test_parse_ground_truth_noop(self):
        # A bare print() plan means "no tools needed" and parses to nothing.
        parsed = parse_ground_truth_code('print("No planning needed")')
        assert len(parsed) == 0

    def test_parse_ground_truth_seek(self):
        parsed = parse_ground_truth_code('seek_information("We need check-in dates")')
        assert len(parsed) == 1
        assert parsed[0]["name"] == "seek_information"

    def test_parse_empty(self):
        # Both the empty string and None degrade to an empty call list.
        assert parse_ground_truth_code("") == []
        assert parse_ground_truth_code(None) == []

    def test_tool_call_f1_perfect(self):
        expected = [
            {"name": "search_hotels", "arguments": {}},
            {"name": "save_to_cache", "arguments": {}},
        ]
        predicted = [
            {"name": "search_hotels", "arguments": {}},
            {"name": "save_to_cache", "arguments": {}},
        ]
        _, _, f1 = tool_call_f1(expected, predicted)
        assert f1 == 1.0

    def test_tool_call_f1_partial(self):
        expected = [
            {"name": "search_hotels", "arguments": {}},
            {"name": "save_to_cache", "arguments": {}},
        ]
        predicted = [{"name": "search_hotels", "arguments": {}}]
        precision, recall, f1 = tool_call_f1(expected, predicted)
        assert precision == 1.0
        assert recall == 0.5
        assert 0 < f1 < 1

    def test_tool_call_f1_wrong(self):
        _, _, f1 = tool_call_f1(
            [{"name": "search_hotels", "arguments": {}}],
            [{"name": "search_flights", "arguments": {}}],
        )
        assert f1 == 0.0

    def test_tool_param_f1_matching(self):
        expected = [
            {
                "name": "search_hotels",
                "arguments": {"city": "Austin", "checkin_date": ["May 10"]},
            }
        ]
        predicted = [
            {
                "name": "search_hotels",
                "arguments": {"city": "Austin", "checkin_date": ["May 10"]},
            }
        ]
        _, _, f1 = tool_param_f1(expected, predicted)
        assert f1 == 1.0

    def test_tool_param_f1_partial(self):
        expected = [
            {
                "name": "search_hotels",
                "arguments": {"city": "Austin", "checkin_date": ["May 10"]},
            }
        ]
        predicted = [{"name": "search_hotels", "arguments": {"city": "Austin"}}]
        precision, recall, _ = tool_param_f1(expected, predicted)
        assert precision == 1.0  # everything generated was correct
        assert recall == 0.5  # but one expected param was missed

    def test_score_turn_noop_correct(self):
        result = score_turn('print("No planning needed")', None)
        assert result["reward"] == 1.0

    def test_score_turn_noop_wrong(self):
        # Ground truth is a no-op, yet the model invoked a tool — zero reward.
        bogus_calls = [
            {"function": {"name": "search_hotels", "arguments": '{"city": "X"}'}}
        ]
        result = score_turn('print("No planning needed")', bogus_calls)
        assert result["reward"] == 0.0

    def test_score_turn_tools_expected_none_produced(self):
        # Ground truth expects tool calls but the model produced none → 0.0.
        gt = 'hotels = search_hotels(city="Austin", checkin_date=["May 10"], checkout_date=["May 15"])'
        result = score_turn(gt, None)
        assert result["reward"] == 0.0

    def test_score_turn_wrong_tool_gets_format_credit(self):
        # A wrong tool name still earns the small formatting credit (0.1).
        gt = 'hotels = search_hotels(city="Austin", checkin_date=["May 10"], checkout_date=["May 15"])'
        mismatched = [
            {
                "function": {
                    "name": "search_flights",
                    "arguments": '{"start_airport_city": "X"}',
                }
            }
        ]
        result = score_turn(gt, mismatched)
        assert result["reward"] == 0.1  # format credit only, no f1 match
        assert result["tool_call_f1"] == 0.0

    def test_score_turn_right_tool_higher_than_wrong(self):
        # Calling the correct tool must always beat calling the wrong one.
        gt = 'hotels = search_hotels(city="Austin", checkin_date=["May 10"], checkout_date=["May 15"])'
        correct = [
            {"function": {"name": "search_hotels", "arguments": '{"city": "Austin"}'}}
        ]
        incorrect = [
            {
                "function": {
                    "name": "search_flights",
                    "arguments": '{"start_airport_city": "X"}',
                }
            }
        ]
        assert score_turn(gt, correct)["reward"] > score_turn(gt, incorrect)["reward"]

    def test_parse_model_tool_calls(self):
        raw_calls = [
            {"function": {"name": "search_hotels", "arguments": '{"city": "Austin"}'}},
            {
                "function": {
                    "name": "save_to_cache",
                    "arguments": '{"key": "hotels", "value": "hotels"}',
                }
            },
        ]
        parsed = parse_model_tool_calls(raw_calls)
        assert len(parsed) == 2
        assert parsed[0]["name"] == "search_hotels"
        assert parsed[0]["arguments"]["city"] == "Austin"

    def test_sample_data_loads(self):
        convos = load_sample_data()
        assert len(convos) == 2
        assert len(convos[1]) == 6  # 3 assistant + 3 user turns
        assert len(convos[2]) == 4

    def test_build_messages(self):
        convos = load_sample_data()
        # Index 0 is the assistant greeting, so index 1 is the first user turn.
        msgs = build_messages_for_turn(convos[1], turn_index=1)
        assert msgs[0]["role"] == "system"
        assert msgs[1]["role"] == "assistant"
        assert msgs[2]["role"] == "user"
        assert "Austin" in msgs[2]["content"]

    def test_t1_tools_valid(self):
        """Every tool definition must carry the required schema fields."""
        for tool in T1_TOOLS:
            assert tool["type"] == "function"
            assert "name" in tool["function"]
            assert "parameters" in tool["function"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GPU integration test — full pipeline with vLLM
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def vllm_backend():
    """Launch the vLLM API server as a subprocess; kill it on teardown."""
    launch_cmd = [
        sys.executable,
        VLLM_SCRIPT,
        "--model",
        VLLM_MODEL,
        "--port",
        str(VLLM_PORT),
        "--gpu-memory-utilization",
        "0.45",
        "--enforce-eager",
    ]

    proc = subprocess.Popen(
        launch_cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        cwd=REPO_ROOT,
    )

    # Poll /health until the server answers 200, the process dies,
    # or the 180s deadline expires.
    deadline = time.time() + 180
    healthy = False
    while time.time() < deadline:
        try:
            status = requests.get(
                f"http://localhost:{VLLM_PORT}/health", timeout=2
            ).status_code
            if status == 200:
                healthy = True
                break
        except (requests.ConnectionError, requests.Timeout):
            pass
        if proc.poll() is not None:
            log = proc.stdout.read().decode() if proc.stdout else ""
            pytest.fail(f"vLLM exited early:\n{log[-3000:]}")
        time.sleep(3)

    if not healthy:
        proc.kill()
        log = proc.stdout.read().decode() if proc.stdout else ""
        pytest.fail(f"vLLM didn't start within 180s:\n{log[-3000:]}")

    yield proc

    # Graceful shutdown first; escalate to SIGKILL if it hangs.
    proc.send_signal(signal.SIGTERM)
    try:
        proc.wait(timeout=15)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def server_and_tokenizer(vllm_backend):
    """Build a ServerManager + tokenizer pair targeting the vLLM backend."""
    from transformers import AutoTokenizer

    from atroposlib.envs.server_handling.server_baseline import APIServerConfig
    from atroposlib.envs.server_handling.server_manager import ServerManager

    manager = ServerManager(
        configs=[
            APIServerConfig(
                model_name=VLLM_MODEL,
                base_url=VLLM_BASE_URL,
                api_key="x",
                server_type="vllm",
            )
        ],
        slurm=False,
        testing=False,
        tool_parser="hermes",
    )
    return manager, AutoTokenizer.from_pretrained(VLLM_MODEL)
|
||||
|
||||
|
||||
@pytest.mark.gpu
class TestT1FullPipeline:
    """End-to-end test: vLLM → ServerManager → ManagedServer → tool calls → scoring.

    NOTE(review): relies on module-level helpers (``generate_tool_completions``,
    ``collect_multistep_trajectory``, ``score_turn``, ``score_completions``)
    defined elsewhere in this file — assumed to return OpenAI-style completion
    results plus token-trajectory nodes; confirm against their definitions.
    """

    async def test_single_turn_tool_call(self, server_and_tokenizer):
        """Model should call search_hotels when given enough info."""
        server, tokenizer = server_and_tokenizer

        # Single user turn containing city + both dates: enough to search.
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": "Find me hotels in Austin, checking in May 10 and out May 15, 2025.",
            },
        ]

        # temperature=0.0 keeps the completion deterministic (greedy decode).
        result, nodes = await generate_tool_completions(
            server=server,
            tokenizer=tokenizer,
            messages=messages,
            tools=T1_TOOLS,
            n=1,
            max_tokens=500,
            temperature=0.0,
            tool_choice="auto",
        )

        choice = result.choices[0]
        print(f"\nContent: {choice.message.content}")
        print(f"Tool calls: {choice.message.tool_calls}")
        print(f"Finish reason: {choice.finish_reason}")

        # Model should have called at least search_hotels
        # (seek_information is also acceptable — asking for clarification).
        if choice.message.tool_calls:
            names = [tc["function"]["name"] for tc in choice.message.tool_calls]
            print(f"Tool names called: {names}")
            assert "search_hotels" in names or "seek_information" in names

        # Score against ground truth
        gt_code = 'hotels = search_hotels(city="Austin", checkin_date=["May 10, 2025"], checkout_date=["May 15, 2025"])\nsave_to_cache(key="hotels", value=hotels)'  # noqa: E501
        scores = score_turn(gt_code, choice.message.tool_calls)
        print(f"Scores: {json.dumps(scores, indent=2)}")
        assert scores["reward"] > 0

        # Verify nodes are tracked: one node per completion, with
        # a logprob recorded for every token.
        assert len(nodes) == 1
        assert len(nodes[0].tokens) > 0
        assert len(nodes[0].logprobs) == len(nodes[0].tokens)

    async def test_group_completions(self, server_and_tokenizer):
        """Generate n=4 completions (GRPO-style) and score them."""
        server, tokenizer = server_and_tokenizer

        convo = SAMPLE_T1_CONVERSATIONS[1]
        # turn_index=1 is the first user turn (index 0 is the greeting).
        messages = build_messages_for_turn(convo, turn_index=1)
        gt_code = convo[1]["Filled_Plan"]

        # temperature=1.0 so the group has diversity to rank.
        result, nodes = await generate_tool_completions(
            server=server,
            tokenizer=tokenizer,
            messages=messages,
            tools=T1_TOOLS,
            n=4,
            max_tokens=500,
            temperature=1.0,
            tool_choice="auto",
        )

        assert len(result.choices) == 4
        assert len(nodes) == 4

        print(f"\nGot {len(result.choices)} choices:")
        for i, choice in enumerate(result.choices):
            tc_count = (
                len(choice.message.tool_calls) if choice.message.tool_calls else 0
            )
            content = (choice.message.content or "")[:60]
            print(f" choice[{i}]: {tc_count} tool_calls, content={content!r}")

        # Score and build ScoredDataGroup
        scored, all_scores = score_completions(result, nodes, gt_code)

        print("\nPer-choice scores:")
        for i, s in enumerate(all_scores):
            print(
                f" [{i}] reward={s['reward']:.2f} tc_f1={s['tool_call_f1']:.2f} tp_f1={s['tool_param_f1']:.2f}"
            )

        # At least some choices should have tokens
        assert len(all_scores) == 4

        # score_completions may return a falsy group if no choice was
        # usable — only validate the structure when one exists.
        if scored:
            print(
                f"\nScoredDataGroup: {len(scored['tokens'])} valid items, scores={scored['scores']}"
            )
            # Verify structure: tokens/masks/scores/logprobs stay aligned.
            assert len(scored["tokens"]) == len(scored["masks"])
            assert len(scored["tokens"]) == len(scored["scores"])
            assert len(scored["tokens"]) == len(scored["inference_logprobs"])

    async def test_seek_information(self, server_and_tokenizer):
        """Model should ask for missing info when dates aren't provided."""
        server, tokenizer = server_and_tokenizer

        # Deliberately omit dates so the model must seek clarification.
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": "I need a hotel in New York but I'm not sure about dates.",
            },
        ]

        result, nodes = await generate_tool_completions(
            server=server,
            tokenizer=tokenizer,
            messages=messages,
            tools=T1_TOOLS,
            n=1,
            max_tokens=300,
            temperature=0.0,
            tool_choice="auto",
        )

        choice = result.choices[0]
        print(f"\nContent: {choice.message.content}")
        print(f"Tool calls: {choice.message.tool_calls}")

        gt_code = (
            'seek_information("We need to ask for the check-in and check-out dates")'
        )
        scores = score_turn(gt_code, choice.message.tool_calls)
        print(f"Scores: {json.dumps(scores, indent=2)}")

        assert len(nodes) == 1

    async def test_noop_turn(self, server_and_tokenizer):
        """No-op turn — model should just respond without tools."""
        server, tokenizer = server_and_tokenizer

        convo = SAMPLE_T1_CONVERSATIONS[1]
        # Turn 5 = "No that's perfect, thanks!" → print("No planning needed")
        messages = build_messages_for_turn(convo, turn_index=5)
        gt_code = convo[5]["Filled_Plan"]

        result, nodes = await generate_tool_completions(
            server=server,
            tokenizer=tokenizer,
            messages=messages,
            tools=T1_TOOLS,
            n=1,
            max_tokens=300,
            temperature=0.0,
            tool_choice="auto",
        )

        choice = result.choices[0]
        print(f"\nContent: {(choice.message.content or '')[:100]}")
        print(f"Tool calls: {choice.message.tool_calls}")

        # Purely informational — no assertion on the noop score here.
        scores = score_turn(gt_code, choice.message.tool_calls, choice.message.content)
        print(f"Scores: {json.dumps(scores, indent=2)}")

    async def test_conversation_walkthrough(self, server_and_tokenizer):
        """Walk through a full T1 conversation using GT context, scoring each user turn."""
        server, tokenizer = server_and_tokenizer
        convos = load_sample_data()
        convo = convos[1]  # 3-turn hotel conversation

        all_scores = []
        # Only user turns trigger planning; assistant turns are GT context.
        user_turn_indices = [
            i for i, t in enumerate(convo) if t["Role"].strip().lower() == "user"
        ]

        for turn_idx in user_turn_indices:
            turn = convo[turn_idx]
            messages = build_messages_for_turn(convo, turn_idx)

            result, nodes = await generate_tool_completions(
                server=server,
                tokenizer=tokenizer,
                messages=messages,
                tools=T1_TOOLS,
                n=1,
                max_tokens=500,
                temperature=0.0,
                tool_choice="auto",
            )

            choice = result.choices[0]
            gt_code = turn["Filled_Plan"]
            scores = score_turn(
                gt_code, choice.message.tool_calls, choice.message.content
            )

            print(f"\n--- Turn {turn_idx} ---")
            print(f"User: {turn['Filled_Template'][:80]}...")
            print(f"GT code: {gt_code[:80]}...")
            if choice.message.tool_calls:
                names = [tc["function"]["name"] for tc in choice.message.tool_calls]
                print(f"Model called: {names}")
            else:
                print(f"Model text: {(choice.message.content or '')[:80]}...")
            print(
                f"Scores: tc_f1={scores['tool_call_f1']:.2f} tp_f1={scores['tool_param_f1']:.2f} reward={scores['reward']:.2f}"  # noqa: E501
            )

            all_scores.append(scores)

        # Aggregate
        avg_tc_f1 = sum(s["tool_call_f1"] for s in all_scores) / len(all_scores)
        avg_tp_f1 = sum(s["tool_param_f1"] for s in all_scores) / len(all_scores)
        avg_reward = sum(s["reward"] for s in all_scores) / len(all_scores)

        print(f"\n=== AGGREGATE ({len(all_scores)} turns) ===")
        print(f"Tool Call F1: {avg_tc_f1:.3f}")
        print(f"Tool Param F1: {avg_tp_f1:.3f}")
        print(f"Avg Reward: {avg_reward:.3f}")

    async def test_multistep_trajectory(self, server_and_tokenizer):
        """Multi-step: walk full conversation feeding model's OWN responses back.

        This is the real end-to-end test. At each turn:
        1. Model generates a response (possibly with tool_calls)
        2. That response is fed back as conversation history
        3. ToolCallTranslator reconstructs raw text from tool_calls
        4. Chat template re-tokenizes with reconstructed text
        5. Next turn uses the model's actual history, not GT

        Tests the full bidirectional tool call pipeline.
        """
        server, tokenizer = server_and_tokenizer
        convos = load_sample_data()
        convo = convos[1]  # 3-turn hotel conversation (6 entries: 3 asst + 3 user)

        turn_results, nodes = await collect_multistep_trajectory(
            server=server,
            tokenizer=tokenizer,
            conversation=convo,
            tools=T1_TOOLS,
            max_tokens=500,
            temperature=0.0,
            tool_choice="auto",
        )

        assert len(turn_results) > 0, "Should have at least one turn result"
        assert len(nodes) > 0, "Should have tracked nodes"

        print(
            f"\n=== MULTI-STEP TRAJECTORY ({len(turn_results)} turns, {len(nodes)} nodes) ==="
        )
        for tr in turn_results:
            print(f"\n--- Turn {tr['turn_idx']} ---")
            print(f"User: {tr['user_message'][:80]}...")
            print(f"GT: {tr['gt_code'][:80]}...")
            if tr["tool_calls"]:
                names = [tc["function"]["name"] for tc in tr["tool_calls"]]
                print(f"Model called: {names}")
            else:
                print(f"Model text: {(tr['content'] or '')[:80]}...")
            s = tr["scores"]
            print(
                f"Scores: tc_f1={s['tool_call_f1']:.2f} tp_f1={s['tool_param_f1']:.2f} reward={s['reward']:.2f}"
            )
            print(f"Messages in context: {len(tr['messages_so_far'])}")

        # Verify nodes — each turn should extend the previous
        # (-100 is the standard ignore-index for masked-out tokens).
        print("\nNodes from managed server:")
        for i, node in enumerate(nodes):
            unmasked = len([t for t in node.masked_tokens if t != -100])
            print(f" node[{i}]: {len(node.tokens)} tokens, {unmasked} unmasked")

        # Verify conversation grew correctly
        user_turns = len(turn_results)
        last_msg_count = len(turn_results[-1]["messages_so_far"])
        expected_min = (
            2 + user_turns
        )  # system + greeting + at least 1 msg per user turn
        print(
            f"\nFinal conversation: {last_msg_count} messages (expected >= {expected_min})"
        )
        assert last_msg_count >= expected_min

        # Aggregate
        avg_reward = sum(r["scores"]["reward"] for r in turn_results) / len(
            turn_results
        )
        avg_tc_f1 = sum(r["scores"]["tool_call_f1"] for r in turn_results) / len(
            turn_results
        )
        print(f"\nAvg Reward: {avg_reward:.3f}")
        print(f"Avg TC F1: {avg_tc_f1:.3f}")

    async def test_multistep_with_tool_history(self, server_and_tokenizer):
        """Verify tool_calls in history are properly reconstructed for subsequent turns.

        The critical path: turn N produces tool_calls → turn N+1's prompt must
        contain those tool_calls in raw text form (e.g. <tool_call> tags) so the
        tokenizer produces correct tokens.
        """
        server, tokenizer = server_and_tokenizer
        convos = load_sample_data()
        convo = convos[2]  # 2-turn: seek_information → search_hotels

        turn_results, nodes = await collect_multistep_trajectory(
            server=server,
            tokenizer=tokenizer,
            conversation=convo,
            tools=T1_TOOLS,
            max_tokens=500,
            temperature=0.0,
            tool_choice="auto",
        )

        print(
            f"\n=== TOOL HISTORY TEST ({len(turn_results)} turns, {len(nodes)} nodes) ==="
        )
        for tr in turn_results:
            tc_count = len(tr["tool_calls"]) if tr["tool_calls"] else 0
            print(
                f"Turn {tr['turn_idx']}: {tc_count} tool_calls, reward={tr['scores']['reward']:.2f}"
            )

        # Nodes should show extending — later nodes have more tokens
        for i, node in enumerate(nodes):
            print(f" node[{i}]: {len(node.tokens)} total tokens")

        # If turn 0 produced tool_calls, verify turn 1's messages contain them
        # NOTE(review): only prints the result — no assertion on has_tool_history;
        # consider asserting it once the behavior is stable.
        if len(turn_results) >= 2 and turn_results[0]["tool_calls"]:
            turn1_messages = turn_results[1]["messages_so_far"]
            assistant_msgs = [m for m in turn1_messages if m.get("role") == "assistant"]
            has_tool_history = any(m.get("tool_calls") for m in assistant_msgs)
            print(f"\nTurn 1 has tool_call history in context: {has_tool_history}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this module directly: verbose output, stdout not captured.
    pytest.main([__file__, "-v", "-s"])
|
||||
Loading…
Add table
Add a link
Reference in a new issue