add env using the tool api stuff

This commit is contained in:
dmahan93 2026-03-03 19:51:30 -06:00
parent c8eb63f33d
commit 12d61d197f
15 changed files with 2632 additions and 21 deletions

View file

@ -0,0 +1,19 @@
# Config for the T1 tool-planning environment.
env:
  # Tokenizer must match the model served at openai.base_url below.
  tokenizer_name: "Qwen/Qwen3-1.7B"
  rollout_server_url: "http://localhost:8002"
  max_token_length: 4096
  start_tok_length: 4096
  group_size: 2
  batch_size: 8
  total_steps: 200
  steps_per_eval: 25
  use_wandb: false
  wandb_name: "t1-tool-planning-env"
  eval_limit_ratio: 0.1
  max_num_workers_per_node: 8
# Inference backend (vLLM-compatible OpenAI endpoint) the env generates against.
openai:
  model_name: "Qwen/Qwen3-1.7B"
  base_url: "http://localhost:9001/v1"
  api_key: "x"
  server_type: "vllm"

View file

@ -0,0 +1,164 @@
#!/usr/bin/env python3
"""Inspect multi-step node output to verify extending works correctly."""
import asyncio
import logging
import os
import signal
import subprocess
import sys
import time
import requests
sys.path.insert(0, os.path.dirname(__file__))
logging.basicConfig(level=logging.WARNING)
REPO_ROOT = os.path.join(os.path.dirname(__file__), "..", "..")
MODEL = "Qwen/Qwen3-1.7B"
PORT = 8123
def start_vllm():
    """Launch the vLLM API server as a subprocess and block until it is healthy.

    Polls the /health endpoint for up to 180 seconds. If the subprocess exits
    early, its combined stdout/stderr tail is printed and the script exits;
    if the deadline passes, the subprocess is killed and the script exits.
    """
    launch_args = [
        sys.executable,
        "-m",
        "example_trainer.vllm_api_server",
        "--model",
        MODEL,
        "--port",
        str(PORT),
        "--gpu-memory-utilization",
        "0.45",
        "--enforce-eager",
    ]
    server = subprocess.Popen(
        launch_args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=REPO_ROOT
    )
    give_up_at = time.time() + 180
    while time.time() < give_up_at:
        resp = None
        try:
            resp = requests.get(f"http://localhost:{PORT}/health", timeout=2)
        except Exception:
            pass
        if resp is not None and resp.status_code == 200:
            print("vLLM ready")
            return server
        # Bail out early if the subprocess already died.
        if server.poll() is not None:
            tail = server.stdout.read().decode()[-2000:]
            print(f"vLLM died:\n{tail}")
            sys.exit(1)
        time.sleep(3)
    server.kill()
    print("vLLM timeout")
    sys.exit(1)
async def main():
    """Run one greedy multi-step trajectory against the local vLLM server and
    print diagnostics: node mask contiguity, decoded prompt/completion text,
    logprob samples, and per-turn tool-call scores."""
    # Deferred imports: keep script startup cheap until the server is up.
    from t1_core import collect_multistep_trajectory
    from t1_tools import T1_TOOLS
    from transformers import AutoTokenizer

    from atroposlib.envs.server_handling.server_baseline import APIServerConfig
    from atroposlib.envs.server_handling.server_manager import ServerManager

    config = APIServerConfig(
        model_name=MODEL,
        base_url=f"http://localhost:{PORT}/v1",
        api_key="x",
        server_type="vllm",
    )
    server = ServerManager(
        configs=[config], slurm=False, testing=False, tool_parser="hermes"
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    # Fixed two-user-turn conversation in the T1 row format
    # (Role / Filled_Template / Filled_Plan).
    convo = [
        {
            "Role": "assistant",
            "Filled_Template": "Hello! How can I help?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "Find hotels in Austin, check-in May 10, check-out May 15, 2025.",
            "Filled_Plan": 'search_hotels(city="Austin", checkin_date=["May 10, 2025"], checkout_date=["May 15, 2025"])',  # noqa: E501
        },
        {
            "Role": "assistant",
            "Filled_Template": "Found some. Filter?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "Yes, free wifi.",
            "Filled_Plan": "filter_hotels(prior_results=hotels, free_wifi_included=True)",
        },
    ]
    # temperature=0.0 so the inspection output is reproducible run-to-run.
    turn_results, nodes = await collect_multistep_trajectory(
        server=server,
        tokenizer=tokenizer,
        conversation=convo,
        tools=T1_TOOLS,
        max_tokens=500,
        temperature=0.0,
        tool_choice="auto",
    )
    print(f"\nNodes: {len(nodes)}")
    node = nodes[0]
    # masked_tokens uses -100 at prompt positions; any other value marks a
    # trainable completion token.
    unmasked_idx = [i for i, t in enumerate(node.masked_tokens) if t != -100]
    masked_idx = [i for i, t in enumerate(node.masked_tokens) if t == -100]
    first_u = unmasked_idx[0] if unmasked_idx else 0
    print(
        f"Total: {len(node.tokens)} | Masked: {len(masked_idx)} | Unmasked: {len(unmasked_idx)}"
    )
    # Check contiguity — in a multi-turn node the unmasked (generated) spans
    # may legitimately be separated by masked user-turn tokens.
    gaps = []
    for j in range(1, len(unmasked_idx)):
        if unmasked_idx[j] != unmasked_idx[j - 1] + 1:
            gaps.append((unmasked_idx[j - 1], unmasked_idx[j]))
    print(f"Unmasked contiguous: {not gaps} Gaps: {gaps}")
    # Decode prompt prefix and the unmasked completion tokens separately.
    prompt_text = tokenizer.decode(node.tokens[:first_u], skip_special_tokens=False)
    comp_tokens = [node.tokens[i] for i in unmasked_idx]
    comp_text = tokenizer.decode(comp_tokens, skip_special_tokens=False)
    print("\n--- PROMPT TAIL (last 150 chars) ---")
    print(prompt_text[-150:])
    print("\n--- COMPLETION (unmasked, first 400 chars) ---")
    print(comp_text[:400])
    print("\n--- COMPLETION (unmasked, last 200 chars) ---")
    print(comp_text[-200:])
    print(f"\nPrompt logprobs sample (should be 1.0): {node.logprobs[:3]}")
    print(f"Completion logprobs sample: {[node.logprobs[i] for i in unmasked_idx[:5]]}")
    # Per-turn summary: tool calls made (or text fallback) and reward.
    for tr in turn_results:
        tc = len(tr["tool_calls"]) if tr["tool_calls"] else 0
        print(
            f"\nTurn {tr['turn_idx']}: {tc} tool_calls, reward={tr['scores']['reward']:.2f}"
        )
        if tr["tool_calls"]:
            for t in tr["tool_calls"]:
                print(f" {t['function']['name']}({t['function']['arguments'][:80]})")
        else:
            print(f" text: {(tr['content'] or '')[:80]}")
if __name__ == "__main__":
    # Bring up a throwaway vLLM server, run the inspection, and always tear
    # the server down afterwards — SIGTERM first, SIGKILL if it hangs.
    proc = start_vllm()
    try:
        asyncio.run(main())
    finally:
        proc.send_signal(signal.SIGTERM)
        try:
            proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()

View file

@ -0,0 +1,297 @@
"""
Core T1 tool planning logic extracted for testability.
These functions do the actual work: generating tool-calling completions
via ManagedServer and scoring them. The env class just orchestrates.
Two modes:
- Single-turn: generate_tool_completions + score_completions
- Multi-step: collect_multistep_trajectory walks a full conversation,
feeding the model's actual responses back at each turn.
"""
import logging
from typing import Any, Dict, List, Optional, Tuple
from t1_prompts import SYSTEM_PROMPT
from t1_scoring import score_turn
from t1_tools import T1_TOOLS
from atroposlib.envs.base import ScoredDataGroup
from atroposlib.envs.server_handling.managed_server import SequenceNode
from atroposlib.envs.server_handling.server_manager import ServerManager
logger = logging.getLogger(__name__)
async def generate_tool_completions(
    server: ServerManager,
    tokenizer: Any,
    messages: List[Dict[str, str]],
    tools: Optional[List[dict]] = None,
    n: int = 4,
    max_tokens: int = 512,
    temperature: float = 1.0,
    tool_choice: str = "auto",
    split: str = "train",
    tool_parser: str = "hermes",
) -> Tuple[Any, List[SequenceNode]]:
    """Generate tool-calling completions and return result + tracked nodes.

    Args:
        server: ServerManager with backends configured
        tokenizer: Tokenizer for the model
        messages: Chat messages to complete
        tools: OpenAI function tool definitions (defaults to T1_TOOLS)
        n: Number of completions to generate
        max_tokens: Max tokens per completion
        temperature: Sampling temperature
        tool_choice: "auto", "none", or "required"
        split: "train" or "eval" (for server load balancing)
        tool_parser: vLLM tool parser name.
            NOTE(review): accepted but never referenced in this body — the
            parser actually in effect comes from the ServerManager; confirm
            whether this parameter should be forwarded or removed.

    Returns:
        (ChatCompletion, list of SequenceNodes)
    """
    if tools is None:
        tools = T1_TOOLS
    logger.info(
        f"generate_tool_completions: n={n}, max_tokens={max_tokens}, "
        f"temp={temperature}, tool_choice={tool_choice}, "
        f"num_messages={len(messages)}"
    )
    async with server.managed_server(tokenizer=tokenizer) as managed:
        logger.debug(
            f" ManagedServer opened (tool_parser={managed._tool_parser_name})"
        )
        result = await managed.chat_completion(
            messages=messages,
            tools=tools,
            tool_choice=tool_choice,
            n=n,
            max_tokens=max_tokens,
            temperature=temperature,
            split=split,
        )
        logger.debug(f" Got {len(result.choices)} choices")
        for i, c in enumerate(result.choices):
            tc_count = len(c.message.tool_calls) if c.message.tool_calls else 0
            content_preview = (c.message.content or "")[:80]
            logger.debug(
                f" choice[{i}]: {tc_count} tool_calls, content={content_preview!r}"
            )
        # Token/mask/logprob tracking is harvested from the managed server
        # while the session is still open.
        state = managed.get_state()
        nodes = state["nodes"]
        logger.debug(f" Got {len(nodes)} tracked nodes")
        return result, nodes
def score_completions(
    result: Any,
    nodes: List[SequenceNode],
    gt_code: str,
    min_unmasked_tokens: int = 5,
) -> Tuple[Optional[ScoredDataGroup], List[Dict[str, float]]]:
    """Score each completion against the ground-truth plan and build a group.

    Args:
        result: ChatCompletion from generate_tool_completions
        nodes: SequenceNodes from generate_tool_completions
        gt_code: Ground truth Python code (Filled_Plan)
        min_unmasked_tokens: Skip choices with fewer unmasked tokens

    Returns:
        (ScoredDataGroup or None, list of per-choice score dicts). The group
        is None when no choice produced usable tokens, or when every reward is
        identical (no training signal).
    """
    logger.debug(
        f"score_completions: {len(result.choices)} choices, "
        f"{len(nodes)} nodes, gt_code={gt_code[:60]}..."
    )
    all_scores = []
    group = ScoredDataGroup()
    group["tokens"] = []
    group["masks"] = []
    group["scores"] = []
    group["inference_logprobs"] = []
    for idx, (choice, node) in enumerate(zip(result.choices, nodes)):
        per_choice = score_turn(
            gt_code, choice.message.tool_calls, choice.message.content
        )
        all_scores.append(per_choice)
        logger.debug(f" choice[{idx}] scores: {per_choice}")
        # Choices with almost no trainable tokens carry no useful signal.
        n_unmasked = sum(1 for t in node.masked_tokens if t != -100)
        if n_unmasked < min_unmasked_tokens:
            logger.debug(f" choice[{idx}] skipped: only {n_unmasked} unmasked tokens")
            continue
        group["tokens"].append(node.tokens)
        group["masks"].append(node.masked_tokens)
        group["inference_logprobs"].append(node.logprobs)
        group["scores"].append(per_choice["reward"])
    if not group["tokens"]:
        logger.debug(" -> None (no valid tokens)")
        return None, all_scores
    first_score = group["scores"][0]
    if all(s == first_score for s in group["scores"]):
        logger.debug(f" -> None (all scores identical: {first_score})")
        return None, all_scores
    logger.debug(f" -> valid group, scores={group['scores']}")
    return group, all_scores
async def collect_multistep_trajectory(
    server: ServerManager,
    tokenizer: Any,
    conversation: List[Dict[str, str]],
    tools: Optional[List[dict]] = None,
    max_tokens: int = 512,
    temperature: float = 0.7,
    tool_choice: str = "auto",
    tool_parser: str = "hermes",
) -> Tuple[List[Dict[str, Any]], List[SequenceNode]]:
    """Walk through a full conversation in ONE managed_server session.

    Uses a single ManagedServer context across all turns so sequence tracking
    works properly — each turn extends the previous node, building up a full
    multi-turn trajectory with aligned tokens and logprobs.

    At each user turn:
      1. Add the user message to the running conversation
      2. Generate a model response (n=1) via the SAME managed server
      3. Score against ground truth
      4. Add the model's ACTUAL response (not GT) to conversation history
      5. Continue to next turn regardless of quality

    Nodes are collected ONCE at the end from managed.get_state().

    Args:
        server: ServerManager with backends configured
        tokenizer: Tokenizer for the model
        conversation: List of turn dicts with Role, Filled_Template, Filled_Plan
        tools: Tool definitions (defaults to T1_TOOLS)
        max_tokens: Max tokens per completion
        temperature: Sampling temperature
        tool_choice: "auto", "none", or "required"
        tool_parser: vLLM tool parser name.
            NOTE(review): accepted but not referenced in this body — confirm
            whether it should be forwarded to the managed server.

    Returns:
        (turn_results, nodes) where:
          turn_results: list of per-turn dicts with scores, tool_calls, content
          nodes: list of SequenceNodes from the managed server (one per turn,
            each extending the previous full trajectory with tokens/logprobs)
    """
    if tools is None:
        tools = T1_TOOLS
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    turn_results = []
    logger.info(
        f"collect_multistep_trajectory: {len(conversation)} turns, temp={temperature}"
    )
    async with server.managed_server(tokenizer=tokenizer) as managed:
        for i, turn in enumerate(conversation):
            role = turn["Role"].strip().lower()
            if role == "assistant":
                if not turn_results:
                    # First assistant turn (greeting) — use GT since model hasn't spoken yet
                    messages.append(
                        {"role": "assistant", "content": turn["Filled_Template"]}
                    )
                    logger.debug(
                        f" turn[{i}] assistant (GT greeting): {turn['Filled_Template'][:60]}"
                    )
                # Otherwise skip — we already added the model's response after the previous user turn
                continue
            if role != "user":
                continue
            # User turn — add to conversation and generate model response
            messages.append({"role": "user", "content": turn["Filled_Template"]})
            gt_code = turn.get("Filled_Plan", "")
            logger.info(f" turn[{i}] user: {turn['Filled_Template'][:60]}...")
            logger.debug(f" turn[{i}] gt_code: {gt_code[:80]}")
            # Generate within the SAME managed server session
            result = await managed.chat_completion(
                messages=messages,
                tools=tools,
                tool_choice=tool_choice,
                n=1,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            choice = result.choices[0]
            # Score this turn
            turn_scores = score_turn(
                gt_code, choice.message.tool_calls, choice.message.content
            )
            tc_count = (
                len(choice.message.tool_calls) if choice.message.tool_calls else 0
            )
            logger.info(
                f" turn[{i}] result: {tc_count} tool_calls, "
                f"reward={turn_scores['reward']:.2f}, tc_f1={turn_scores['tool_call_f1']:.2f}"
            )
            turn_results.append(
                {
                    "turn_idx": i,
                    "user_message": turn["Filled_Template"],
                    "gt_code": gt_code,
                    "content": choice.message.content,
                    "tool_calls": choice.message.tool_calls,
                    "scores": turn_scores,
                    # copy each message so later mutations don't alias history
                    "messages_so_far": [m.copy() for m in messages],
                }
            )
            # Add model's ACTUAL response to conversation for next turn
            assistant_msg = {"role": "assistant"}
            if choice.message.tool_calls:
                assistant_msg["tool_calls"] = choice.message.tool_calls
                assistant_msg["content"] = choice.message.content or ""
            else:
                assistant_msg["content"] = choice.message.content or ""
            messages.append(assistant_msg)
            logger.debug(
                f" turn[{i}] added assistant msg to conversation (total: {len(messages)})"
            )
        # Get nodes ONCE at the end — the managed server tracked extending sequences
        state = managed.get_state()
        nodes = state["nodes"]
        logger.info(
            f" trajectory complete: {len(turn_results)} turns, {len(nodes)} nodes"
        )
    # Summary
    if turn_results:
        avg_reward = sum(r["scores"]["reward"] for r in turn_results) / len(
            turn_results
        )
        avg_tc_f1 = sum(r["scores"]["tool_call_f1"] for r in turn_results) / len(
            turn_results
        )
        logger.info(f" avg_reward={avg_reward:.3f}, avg_tc_f1={avg_tc_f1:.3f}")
    return turn_results, nodes

View file

@ -0,0 +1,168 @@
"""
T1 dataset loader downloads and parses the capitalone/T1 HuggingFace dataset.
Conversations are returned in the format expected by t1_core:
[{"Role": "assistant"|"user", "Filled_Template": str, "Filled_Plan": str}, ...]
"""
import logging
import random
from typing import Dict, List, Optional, Tuple
import pandas as pd
from huggingface_hub import hf_hub_download, list_repo_tree
logger = logging.getLogger(__name__)
REPO_ID = "capitalone/T1"
# Single-domain only for now (simpler tool defs, shorter conversations)
SINGLE_DOMAINS = ["hotel", "flight", "restaurant", "attraction"]
MULTI_DOMAINS = [
"flighthotel",
"hotelrestaurant",
"hotelattraction",
"flighthotelrestaurant",
"flighthotelattraction",
]
ALL_DOMAINS = SINGLE_DOMAINS + MULTI_DOMAINS
def _parse_role(filled_template: str) -> Tuple[str, str]:
"""Extract role and content from 'role: content' format."""
if filled_template.startswith("assistant:"):
return "assistant", filled_template[len("assistant:") :].strip()
elif filled_template.startswith("user:"):
return "user", filled_template[len("user:") :].strip()
else:
# Fallback — try to guess
return "assistant", filled_template.strip()
def _csv_to_conversations(path: str) -> Dict[int, List[dict]]:
    """Read one T1 CSV and group its rows into conversations keyed by ID.

    Rows with missing/empty templates are dropped; conversations with no
    user turn are discarded entirely.
    """
    frame = pd.read_csv(path)
    result: Dict[int, List[dict]] = {}
    for conv_id, rows in frame.groupby("ID"):
        parsed_turns = []
        for _, record in rows.iterrows():
            raw_template = record["Filled_Template"]
            raw_plan = record["Filled_Plan"]
            # Drop rows whose template is NaN or effectively empty.
            if pd.isna(raw_template) or str(raw_template).strip() in ("", "nan"):
                continue
            role, content = _parse_role(str(raw_template))
            if not content.strip():
                continue
            # Missing plans normalize to the empty string.
            if pd.isna(raw_plan) or str(raw_plan).strip() == "nan":
                plan = ""
            else:
                plan = str(raw_plan)
            parsed_turns.append(
                {
                    "Role": role,
                    "Filled_Template": content,
                    "Filled_Plan": plan,
                }
            )
        # A conversation is only useful if the user actually speaks.
        if parsed_turns and any(t["Role"] == "user" for t in parsed_turns):
            result[int(conv_id)] = parsed_turns
    return result
def load_t1_dataset(
    domains: Optional[List[str]] = None,
    split: str = "train",
    max_files_per_domain: Optional[int] = None,
    cache_dir: Optional[str] = None,
) -> List[List[dict]]:
    """Load T1 conversations from HuggingFace.

    Args:
        domains: List of domains to load (default: single-domain only)
        split: "train", "test", or "validation"
        max_files_per_domain: Limit files per domain (each has 25, ~15 convos each)
        cache_dir: HF cache directory

    Returns:
        List of conversations, each a list of turn dicts
    """
    if domains is None:
        domains = SINGLE_DOMAINS
    # Enumerate every CSV in the dataset repo once, then filter per domain.
    repo_entries = list(list_repo_tree(REPO_ID, repo_type="dataset", recursive=True))
    csv_paths = [
        entry.path
        for entry in repo_entries
        if hasattr(entry, "size") and entry.path.endswith(".csv")
    ]
    download_kwargs = {"cache_dir": cache_dir} if cache_dir else {}
    loaded: List[List[dict]] = []
    for domain in domains:
        wanted_prefix = f"{domain}/{split}/"
        selected = sorted(p for p in csv_paths if p.startswith(wanted_prefix))
        if max_files_per_domain:
            selected = selected[:max_files_per_domain]
        logger.info(f"Loading {len(selected)} files from {domain}/{split}")
        for remote_path in selected:
            local_path = hf_hub_download(
                REPO_ID, remote_path, repo_type="dataset", **download_kwargs
            )
            loaded.extend(_csv_to_conversations(local_path).values())
    logger.info(f"Loaded {len(loaded)} conversations total")
    return loaded
def load_t1_split(
    domains: Optional[List[str]] = None,
    max_files_per_domain: Optional[int] = None,
    eval_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[List[List[dict]], List[List[dict]]]:
    """Load T1 train conversations and split into train/eval.

    Args:
        domains: Domains to load
        max_files_per_domain: Limit files per domain
        eval_ratio: Fraction of conversations for eval
        seed: Random seed for split

    Returns:
        (train_conversations, eval_conversations)
    """
    convos = load_t1_dataset(
        domains=domains,
        split="train",
        max_files_per_domain=max_files_per_domain,
    )
    # Deterministic shuffle: the same seed always produces the same split.
    random.Random(seed).shuffle(convos)
    # Always hold out at least one conversation for eval.
    n_eval = max(1, int(len(convos) * eval_ratio))
    eval_convos = convos[:n_eval]
    train_convos = convos[n_eval:]
    logger.info(f"Split: {len(train_convos)} train, {len(eval_convos)} eval")
    return train_convos, eval_convos

View file

@ -0,0 +1,237 @@
"""
T1 Tool-Integrated Reasoning environment for Atropos.
Multi-step tool calling with full trajectory tracking:
- Loads the capitalone/T1 dataset from HuggingFace
- Walks through complete conversations, feeding model's actual responses back
- One ManagedServer session per trajectory — one extending node with aligned tokens
- GRPO over group_size independent trajectories per conversation
"""
import logging
from typing import Dict, List, Optional, Tuple
from t1_core import collect_multistep_trajectory # noqa: E402
from t1_data import SINGLE_DOMAINS, load_t1_split # noqa: E402
from t1_tools import T1_TOOLS # noqa: E402
from atroposlib.envs.base import (
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
logger = logging.getLogger(__name__)
logging.basicConfig(
level=logging.DEBUG, format="%(asctime)s [%(name)s] %(levelname)s: %(message)s"
)
class T1ToolPlanningEnv(BaseEnv):
    """T1 Tool-Integrated Reasoning environment — multi-step trajectories.

    Each trajectory walks a full conversation:
    - Model generates responses at each user turn
    - Model's actual output is fed back (not GT) for the next turn
    - One ManagedServer session — one extending node per trajectory
    - GRPO compares group_size independent trajectories on the same conversation
    """

    name = "t1_tool_planning"

    def __init__(self, config, server_configs, slurm=True, testing=False):
        super().__init__(config, server_configs, slurm, testing)
        # BaseEnv doesn't pass tool_parser to ServerManager — set it here
        # so ManagedServer creates a ToolCallTranslator for hermes-style tool calls
        self.server.tool_parser = "hermes"
        # Rolling per-trajectory metric buffers, flushed on each wandb_log().
        self.reward_buffer = []
        self.tc_f1_buffer = []
        self.tp_f1_buffer = []
        # (key, value) pairs queued by evaluate() for the next wandb_log().
        self.eval_metrics = []
        # Round-robin cursor over train_conversations (see get_next_item).
        self.iter = 0

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Default environment + inference-server configuration for local runs."""
        model_name = "Qwen/Qwen3-1.7B"
        env_config = BaseEnvConfig(
            tokenizer_name=model_name,
            group_size=4,
            use_wandb=True,
            rollout_server_url="http://localhost:8002",
            total_steps=200,
            batch_size=16,
            steps_per_eval=25,
            max_token_length=4096,
            start_tok_length=4096,
            wandb_name="t1-tool-planning",
            eval_limit_ratio=0.1,
            max_num_workers_per_node=8,
        )
        server_config = APIServerConfig(
            model_name=model_name,
            base_url="http://localhost:9001/v1",
            api_key="x",
            server_type="vllm",
        )
        # MUST return as a list — single APIServerConfig (not in list) causes
        # ServerManager to ignore base_url and auto-generate ports 9004-9007
        return env_config, [server_config]

    async def setup(self):
        """Load the T1 dataset and split it into train/eval conversations."""
        logger.info("=== T1ToolPlanningEnv.setup() starting ===")
        # Load real T1 dataset from HuggingFace
        # Start with single-domain, 2 files per domain (~30 convos each = ~120 total)
        # Increase max_files_per_domain for more data
        self.train_conversations, self.eval_conversations = load_t1_split(
            domains=SINGLE_DOMAINS,
            eval_ratio=0.1,
        )
        logger.info(
            f"Setup complete: {len(self.train_conversations)} train conversations, "
            f"{len(self.eval_conversations)} eval conversations"
        )

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Flush buffered train metrics and queued eval metrics to wandb."""
        if wandb_metrics is None:
            wandb_metrics = {}
        if self.reward_buffer:
            wandb_metrics["train/avg_reward"] = sum(self.reward_buffer) / len(
                self.reward_buffer
            )
        if self.tc_f1_buffer:
            wandb_metrics["train/tool_call_f1"] = sum(self.tc_f1_buffer) / len(
                self.tc_f1_buffer
            )
        if self.tp_f1_buffer:
            wandb_metrics["train/tool_param_f1"] = sum(self.tp_f1_buffer) / len(
                self.tp_f1_buffer
            )
        # Clear buffers so each log call reports one window of trajectories.
        self.reward_buffer = []
        self.tc_f1_buffer = []
        self.tp_f1_buffer = []
        for k, v in self.eval_metrics:
            wandb_metrics[k] = v
        self.eval_metrics = []
        await super().wandb_log(wandb_metrics)

    async def evaluate(self, *args, **kwargs):
        """Run greedy (temperature=0) trajectories over every eval conversation
        and queue aggregate metrics for the next wandb_log()."""
        logger.info(
            f"=== evaluate() starting ({len(self.eval_conversations)} conversations) ==="
        )
        all_rewards = []
        all_tc_f1 = []
        all_tp_f1 = []
        for convo in self.eval_conversations:
            turn_results, nodes = await collect_multistep_trajectory(
                server=self.server,
                tokenizer=self.tokenizer,
                conversation=convo,
                tools=T1_TOOLS,
                max_tokens=512,
                temperature=0.0,
                tool_choice="auto",
            )
            for tr in turn_results:
                all_rewards.append(tr["scores"]["reward"])
                all_tc_f1.append(tr["scores"]["tool_call_f1"])
                all_tp_f1.append(tr["scores"]["tool_param_f1"])
        if all_rewards:
            self.eval_metrics.append(
                ("eval/avg_reward", sum(all_rewards) / len(all_rewards))
            )
            self.eval_metrics.append(
                ("eval/tool_call_f1", sum(all_tc_f1) / len(all_tc_f1))
            )
            self.eval_metrics.append(
                ("eval/tool_param_f1", sum(all_tp_f1) / len(all_tp_f1))
            )

    async def collect_trajectories(self, item) -> Tuple[ScoredDataGroup, list]:
        """Run group_size independent trajectories over one conversation and
        package them as a ScoredDataGroup; returns (None, []) when there is
        no usable data or no score variance for GRPO."""
        convo = item
        user_turns = sum(1 for t in convo if t["Role"].strip().lower() == "user")
        logger.info(
            f"collect_trajectories: {len(convo)} turns ({user_turns} user), group_size={self.config.group_size}"
        )
        scored = ScoredDataGroup()
        scored["tokens"] = []
        scored["masks"] = []
        scored["scores"] = []
        scored["inference_logprobs"] = []
        # Run group_size independent trajectories on the same conversation
        for g in range(self.config.group_size):
            turn_results, nodes = await collect_multistep_trajectory(
                server=self.server,
                tokenizer=self.tokenizer,
                conversation=convo,
                tools=T1_TOOLS,
                max_tokens=512,
                temperature=1.0,
                tool_choice="auto",
            )
            if not nodes or not turn_results:
                logger.debug(f" trajectory[{g}]: no nodes/results, skipping")
                continue
            # One node per trajectory (extending across all turns)
            node = nodes[0]
            unmasked = len([t for t in node.masked_tokens if t != -100])
            if unmasked < 5:
                logger.debug(
                    f" trajectory[{g}]: only {unmasked} unmasked tokens, skipping"
                )
                continue
            # Trajectory reward = average across all turns
            avg_reward = sum(tr["scores"]["reward"] for tr in turn_results) / len(
                turn_results
            )
            avg_tc_f1 = sum(tr["scores"]["tool_call_f1"] for tr in turn_results) / len(
                turn_results
            )
            avg_tp_f1 = sum(tr["scores"]["tool_param_f1"] for tr in turn_results) / len(
                turn_results
            )
            scored["tokens"].append(node.tokens)
            scored["masks"].append(node.masked_tokens)
            scored["inference_logprobs"].append(node.logprobs)
            scored["scores"].append(avg_reward)
            self.reward_buffer.append(avg_reward)
            self.tc_f1_buffer.append(avg_tc_f1)
            self.tp_f1_buffer.append(avg_tp_f1)
            logger.info(
                f" trajectory[{g}]: {len(turn_results)} turns, "
                f"{len(node.tokens)} tokens, reward={avg_reward:.3f}"
            )
        if not scored["tokens"]:
            logger.info(" -> None (no valid trajectories)")
            return None, []
        # GRPO needs variance inside the group — identical scores are useless.
        if all(s == scored["scores"][0] for s in scored["scores"]):
            logger.info(f" -> None (all scores identical: {scored['scores'][0]:.3f})")
            return None, []
        logger.info(
            f" -> valid group: {len(scored['tokens'])} trajectories, scores={[f'{s:.3f}' for s in scored['scores']]}"
        )
        return scored, []

    async def get_next_item(self):
        """Return the next training conversation, cycling round-robin."""
        convo = self.train_conversations[self.iter % len(self.train_conversations)]
        self.iter += 1
        logger.debug(f"get_next_item: iter={self.iter}")
        return convo
if __name__ == "__main__":
    # Run the environment through the class's cli() entry point.
    T1ToolPlanningEnv.cli()

View file

@ -0,0 +1,20 @@
"""
System prompt and few-shot examples for T1 tool planning.
"""
# Injected as the first message of every trajectory. Do not edit casually:
# the wording constrains when the model may call tools vs. answer in text.
SYSTEM_PROMPT = """\
You are an expert travel planning assistant. \
You help users search for flights, hotels, restaurants, and attractions.
You have access to tools for searching, filtering, caching results, and seeking information from the user.
Important rules:
- Only call tools when the user provides enough mandatory information
- If mandatory parameters are missing, use seek_information to ask the user
- Use save_to_cache after searching to store results for later use
- Use get_results_from_cache when you need previously found results
- Use filter_* tools to narrow down cached results instead of re-searching
- If no new action is needed, respond with text only (no tool calls)
- Preserve entity values exactly as the user states them (don't modify case or format)
- When the user mentions dates, pass them as-is to the tools
"""

View file

@ -0,0 +1,300 @@
"""
Scoring for T1 tool planning environment.
Parses ground truth Python code with AST to extract tool calls,
then compares against structured tool_calls from the model response.
"""
import ast
import json
import logging
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# Tools that T1 considers "real" (not just print or cache ops)
SEARCH_FILTER_TOOLS = {
"search_hotels",
"filter_hotels",
"search_flights",
"filter_flights",
"search_restaurants",
"filter_restaurants",
"search_attractions",
"filter_attractions",
"search_nearest",
"sort_results",
"adjust_date",
"seek_information",
}
CACHE_TOOLS = {"save_to_cache", "get_results_from_cache"}
ALL_TOOLS = SEARCH_FILTER_TOOLS | CACHE_TOOLS
def parse_ground_truth_code(code: str) -> List[Dict[str, Any]]:
    """Parse ground truth Python code into a list of tool calls.

    Walks the AST of the Filled_Plan snippet and collects every call to a
    known T1 tool together with its keyword arguments.

    Args:
        code: Python code string from Filled_Plan column.

    Returns:
        List of {"name": str, "arguments": dict} dicts.
    """
    if not code or not isinstance(code, str):
        return []
    code = code.strip()
    # An explicit no-op plan (or an empty one) means "no tools expected".
    if not code or code == 'print("No planning needed")':
        return []
    try:
        module = ast.parse(code)
    except SyntaxError:
        logger.warning("Failed to parse ground truth code: %s", code[:100])
        return []
    extracted = []
    for ast_node in ast.walk(module):
        if not isinstance(ast_node, ast.Call):
            continue
        # Resolve the callee name for both foo(...) and obj.foo(...) forms.
        callee = ast_node.func
        if isinstance(callee, ast.Name):
            tool_name = callee.id
        elif isinstance(callee, ast.Attribute):
            tool_name = callee.attr
        else:
            continue
        if tool_name not in ALL_TOOLS:
            continue
        kwargs = {}
        for keyword in ast_node.keywords:
            if keyword.arg is None:
                # Skip **kwargs expansions — there is no literal key to record.
                continue
            try:
                kwargs[keyword.arg] = ast.literal_eval(keyword.value)
            except (ValueError, TypeError):
                # Bare variable references (like prior_results=hotels) are
                # kept as their name; anything else as an AST dump.
                if isinstance(keyword.value, ast.Name):
                    kwargs[keyword.arg] = keyword.value.id
                else:
                    kwargs[keyword.arg] = ast.dump(keyword.value)
        extracted.append({"name": tool_name, "arguments": kwargs})
    return extracted
def parse_model_tool_calls(tool_calls: Optional[List[dict]]) -> List[Dict[str, Any]]:
    """Normalize model's structured tool_calls into comparable format.

    Args:
        tool_calls: List of tool call dicts from ChatCompletion response.

    Returns:
        List of {"name": str, "arguments": dict} dicts. Unparseable JSON
        argument strings become empty dicts.
    """
    if not tool_calls:
        return []
    normalized = []
    for call in tool_calls:
        fn = call.get("function", {})
        raw_args = fn.get("arguments", "{}")
        if isinstance(raw_args, str):
            try:
                parsed_args = json.loads(raw_args)
            except (json.JSONDecodeError, TypeError):
                parsed_args = {}
        else:
            # Already a dict (some backends pre-parse arguments).
            parsed_args = raw_args
        normalized.append({"name": fn.get("name", ""), "arguments": parsed_args})
    return normalized
def tool_call_f1(
    ground_truth: List[Dict[str, Any]],
    generated: List[Dict[str, Any]],
) -> Tuple[float, float, float]:
    """Compute precision, recall, F1 on tool names.

    Compares the multiset of tool names called; two empty sides count as a
    perfect match.
    """
    gt_counts = Counter(call["name"] for call in ground_truth)
    gen_counts = Counter(call["name"] for call in generated)
    # "print" is noise — drop it whenever any other tool name is present.
    for counts in (gt_counts, gen_counts):
        if len(counts) > 1:
            counts.pop("print", None)
    if not gt_counts and not gen_counts:
        return 1.0, 1.0, 1.0  # both empty = correct
    overlap = sum((gt_counts & gen_counts).values())
    precision = overlap / sum(gen_counts.values()) if gen_counts else 0.0
    recall = overlap / sum(gt_counts.values()) if gt_counts else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return precision, recall, f1
def tool_param_f1(
    ground_truth: List[Dict[str, Any]],
    generated: List[Dict[str, Any]],
) -> Tuple[float, float, float]:
    """Compute precision, recall, F1 on tool parameters.

    Generated calls are greedily paired with same-named ground-truth calls;
    precision/recall are then computed over matching argument key/value pairs
    across all paired calls.
    """
    if not ground_truth and not generated:
        return 1.0, 1.0, 1.0
    # Greedy matching by tool name: each GT call is consumed at most once.
    unmatched_gt = list(ground_truth)
    pairs = []
    for candidate in generated:
        match_idx = next(
            (
                k
                for k, gt_call in enumerate(unmatched_gt)
                if gt_call["name"] == candidate["name"]
            ),
            None,
        )
        if match_idx is not None:
            pairs.append((unmatched_gt.pop(match_idx), candidate))
    if not pairs:
        return 0.0, 0.0, 0.0
    hits = 0
    gt_param_total = 0
    gen_param_total = 0
    for gt_call, gen_call in pairs:
        gt_args = gt_call.get("arguments", {})
        gen_args = gen_call.get("arguments", {})
        gt_param_total += len(gt_args)
        gen_param_total += len(gen_args)
        hits += sum(
            1
            for key, val in gt_args.items()
            if key in gen_args and _values_match(val, gen_args[key])
        )
    precision = hits / gen_param_total if gen_param_total > 0 else 0.0
    recall = hits / gt_param_total if gt_param_total > 0 else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return precision, recall, f1
def _values_match(gt_val: Any, gen_val: Any) -> bool:
"""Loose comparison of argument values."""
if gt_val == gen_val:
return True
# String comparison (case-insensitive)
if isinstance(gt_val, str) and isinstance(gen_val, str):
return gt_val.lower().strip() == gen_val.lower().strip()
# List comparison (order-insensitive for some)
if isinstance(gt_val, list) and isinstance(gen_val, list):
if len(gt_val) == len(gen_val):
return sorted(str(v) for v in gt_val) == sorted(str(v) for v in gen_val)
# Number comparison
try:
if float(gt_val) == float(gen_val):
return True
except (ValueError, TypeError):
pass
return False
def score_turn(
    ground_truth_code: str,
    model_tool_calls: Optional[List[dict]],
    model_content: Optional[str] = None,
) -> Dict[str, float]:
    """Score a single turn.

    Args:
        ground_truth_code: Python code from Filled_Plan column.
        model_tool_calls: Structured tool_calls from model response.
        model_content: Text content from model response (currently unused;
            reserved for seek_information scoring).

    Returns:
        Dict with scores: tool_call_f1, tool_param_f1, correct_no_op, reward
    """
    gt_calls = parse_ground_truth_code(ground_truth_code)
    gen_calls = parse_model_tool_calls(model_tool_calls)
    gt_empty = not gt_calls
    gen_empty = not gen_calls
    # "No planning needed" ground truth: full credit iff the model abstained.
    if gt_empty:
        hit = 1.0 if gen_empty else 0.0
        return {
            "tool_call_f1": hit,
            "tool_param_f1": hit,
            "correct_no_op": hit,
            "reward": hit,
        }
    # Check if this is a seek_information turn
    gt_seeks = any(call["name"] == "seek_information" for call in gt_calls)
    tc_p, tc_r, tc_f1 = tool_call_f1(gt_calls, gen_calls)
    tp_p, tp_r, tp_f1 = tool_param_f1(gt_calls, gen_calls)
    # Composite reward — graduated so GRPO gets signal even with weak models:
    #   no calls at all            -> 0.0
    #   wrong tools                -> 0.1 (format credit)
    #   partially right            -> 0.1 + 0.5*tc_f1 + 0.3*tp_f1
    #   perfect                    -> 0.1 + 0.5 + 0.3 + 0.1 bonus = 1.0
    if gen_empty:
        reward = 0.0
    else:
        reward = 0.1 + 0.5 * tc_f1 + 0.3 * tp_f1
        if tc_f1 == 1.0:
            reward += 0.1
    return {
        "tool_call_precision": tc_p,
        "tool_call_recall": tc_r,
        "tool_call_f1": tc_f1,
        "tool_param_precision": tp_p,
        "tool_param_recall": tp_r,
        "tool_param_f1": tp_f1,
        "correct_no_op": 0.0,
        "is_seek_info": float(gt_seeks),
        "reward": reward,
    }

View file

@ -0,0 +1,288 @@
"""
OpenAI function definitions for T1's 14 travel planning tools.
These are passed to managed_server.chat_completion(tools=T1_TOOLS)
so the model uses proper tool calling instead of raw code generation.
"""
T1_TOOLS = [
{
"type": "function",
"function": {
"name": "search_hotels",
"description": "Search for hotels in a city. Requires city, checkin_date, checkout_date.",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string"},
"checkin_date": {"type": "array", "items": {"type": "string"}},
"checkout_date": {"type": "array", "items": {"type": "string"}},
"num_rooms": {"type": "integer"},
"num_people": {"type": "integer"},
"neighborhood": {"type": "array", "items": {"type": "string"}},
"hotel_name": {"type": "array", "items": {"type": "string"}},
"budget": {"type": "integer"},
"rating": {"type": "array", "items": {"type": "number"}},
"stars": {"type": "array", "items": {"type": "integer"}},
"free_wifi_included": {"type": "boolean"},
"breakfast_included": {"type": "boolean"},
"gym_present": {"type": "boolean"},
"pool_present": {"type": "boolean"},
"is_pet_friendly": {"type": "boolean"},
"has_spa_services": {"type": "boolean"},
"smoking_allowed": {"type": "boolean"},
"is_wheelchair_accessible": {"type": "boolean"},
"has_free_parking": {"type": "boolean"},
"airport_shuttle_present": {"type": "boolean"},
},
"required": ["city", "checkin_date", "checkout_date"],
},
},
},
{
"type": "function",
"function": {
"name": "filter_hotels",
"description": "Filter previously searched hotel results by additional criteria.",
"parameters": {
"type": "object",
"properties": {
"prior_results": {
"type": "string",
"description": "Variable name of prior results",
},
"neighborhood": {"type": "array", "items": {"type": "string"}},
"budget": {"type": "integer"},
"rating": {"type": "array", "items": {"type": "number"}},
"stars": {"type": "array", "items": {"type": "integer"}},
"free_wifi_included": {"type": "boolean"},
"breakfast_included": {"type": "boolean"},
"gym_present": {"type": "boolean"},
"pool_present": {"type": "boolean"},
"is_pet_friendly": {"type": "boolean"},
},
"required": ["prior_results"],
},
},
},
{
"type": "function",
"function": {
"name": "search_flights",
"description": "Search for flights. Requires departure_date and origin/destination.",
"parameters": {
"type": "object",
"properties": {
"start_airport_city": {"type": "string"},
"end_airport_city": {"type": "string"},
"departure_date": {"type": "array", "items": {"type": "string"}},
"arrival_date": {"type": "array", "items": {"type": "string"}},
"airline": {"type": "array", "items": {"type": "string"}},
"budget": {"type": "integer"},
"flight_class": {"type": "array", "items": {"type": "string"}},
"num_layovers": {"type": "array", "items": {"type": "integer"}},
},
"required": ["departure_date"],
},
},
},
{
"type": "function",
"function": {
"name": "filter_flights",
"description": "Filter previously searched flight results.",
"parameters": {
"type": "object",
"properties": {
"prior_results": {"type": "string"},
"airline": {"type": "array", "items": {"type": "string"}},
"budget": {"type": "integer"},
"flight_class": {"type": "array", "items": {"type": "string"}},
},
"required": ["prior_results"],
},
},
},
{
"type": "function",
"function": {
"name": "search_restaurants",
"description": "Search for restaurants in a city.",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string"},
"cuisine": {"type": "array", "items": {"type": "string"}},
"rating": {"type": "array", "items": {"type": "number"}},
"neighborhood": {"type": "array", "items": {"type": "string"}},
"price_range": {"type": "array", "items": {"type": "string"}},
},
"required": ["city"],
},
},
},
{
"type": "function",
"function": {
"name": "filter_restaurants",
"description": "Filter previously searched restaurant results.",
"parameters": {
"type": "object",
"properties": {
"prior_results": {"type": "string"},
"cuisine": {"type": "array", "items": {"type": "string"}},
"rating": {"type": "array", "items": {"type": "number"}},
"neighborhood": {"type": "array", "items": {"type": "string"}},
},
"required": ["prior_results"],
},
},
},
{
"type": "function",
"function": {
"name": "search_attractions",
"description": "Search for attractions in a city.",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string"},
"type": {"type": "array", "items": {"type": "string"}},
"neighborhood": {"type": "array", "items": {"type": "string"}},
},
"required": ["city"],
},
},
},
{
"type": "function",
"function": {
"name": "filter_attractions",
"description": "Filter previously searched attraction results.",
"parameters": {
"type": "object",
"properties": {
"prior_results": {"type": "string"},
"type": {"type": "array", "items": {"type": "string"}},
"neighborhood": {"type": "array", "items": {"type": "string"}},
},
"required": ["prior_results"],
},
},
},
{
"type": "function",
"function": {
"name": "save_to_cache",
"description": "Save results to cache with a unique key for later retrieval.",
"parameters": {
"type": "object",
"properties": {
"key": {"type": "string", "description": "Unique cache key"},
"value": {
"type": "string",
"description": "Variable name of results to cache",
},
},
"required": ["key", "value"],
},
},
},
{
"type": "function",
"function": {
"name": "get_results_from_cache",
"description": "Retrieve previously cached results by key.",
"parameters": {
"type": "object",
"properties": {
"key": {"type": "string", "description": "Cache key to retrieve"},
},
"required": ["key"],
},
},
},
{
"type": "function",
"function": {
"name": "sort_results",
"description": "Sort results by a specific field.",
"parameters": {
"type": "object",
"properties": {
"results": {
"type": "string",
"description": "Variable name of results to sort",
},
"sort_by": {
"type": "string",
"description": "Field to sort by (e.g. price, rating)",
},
"ascending": {"type": "boolean"},
},
"required": ["results", "sort_by"],
},
},
},
{
"type": "function",
"function": {
"name": "seek_information",
"description": "Ask the user for missing mandatory information before calling a tool.",
"parameters": {
"type": "object",
"properties": {
"question": {
"type": "string",
"description": "Question to ask the user",
},
},
"required": ["question"],
},
},
},
{
"type": "function",
"function": {
"name": "adjust_date",
"description": "Adjust a date by a number of days.",
"parameters": {
"type": "object",
"properties": {
"date": {"type": "string"},
"days": {"type": "integer"},
},
"required": ["date", "days"],
},
},
},
{
"type": "function",
"function": {
"name": "search_nearest",
"description": "Find nearest locations between two sets of results (e.g. hotels near restaurants).",
"parameters": {
"type": "object",
"properties": {
"hotels": {
"type": "string",
"description": "Variable name of hotel results",
},
"restaurants": {
"type": "string",
"description": "Variable name of restaurant results",
},
"attractions": {
"type": "string",
"description": "Variable name of attraction results",
},
"groupBy": {
"type": "string",
"description": "Group results by this entity type",
},
},
"required": [],
},
},
},
]

View file

@ -0,0 +1,294 @@
#!/usr/bin/env python3
"""
Live test for T1 tool planning runs against an already-running vLLM server.
No pytest fixtures, no subprocess spawning. Just creates a ServerManager
pointed at localhost:9001, calls generate_tool_completions, and prints results.
Usage:
# With vLLM already running on port 9001:
python environments/t1_tool_planning/test_t1_live.py
# Custom port:
python environments/t1_tool_planning/test_t1_live.py --port 8123
# Custom model:
python environments/t1_tool_planning/test_t1_live.py --model Qwen/Qwen3-4B
"""
import argparse
import asyncio
import json
import logging
import os
import sys
# Ensure t1 modules are importable
sys.path.insert(0, os.path.dirname(__file__))
from t1_core import generate_tool_completions, score_completions # noqa: E402
from t1_prompts import SYSTEM_PROMPT # noqa: E402
from t1_scoring import score_turn # noqa: E402
from t1_tools import T1_TOOLS # noqa: E402
# DEBUG level on the root logger so every request/response exchanged with
# the live vLLM server is visible while the script runs.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
)
logger = logging.getLogger("test_t1_live")
def make_server_manager(model_name: str, base_url: str):
    """Create a ServerManager pointed at an existing vLLM server.

    Single-server configuration: no slurm, not in testing mode, and the
    "hermes" tool parser so responses carry structured tool_calls.
    """
    from atroposlib.envs.server_handling.server_baseline import APIServerConfig
    from atroposlib.envs.server_handling.server_manager import ServerManager

    cfg = APIServerConfig(
        model_name=model_name,
        base_url=base_url,
        api_key="x",
        server_type="vllm",
    )
    return ServerManager(
        configs=[cfg],
        slurm=False,
        testing=False,
        tool_parser="hermes",
    )
def make_tokenizer(model_name: str):
    """Load the HuggingFace tokenizer matching the served model."""
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer
# Three minimal conversations, one per scoring regime:
#   1 -> normal tool-calling turn, 2 -> seek_information turn, 3 -> no-op turn.
# Turn dicts mirror the T1 dataset columns: "Role", "Filled_Template"
# (surface text), and "Filled_Plan" (ground-truth planning code; empty on
# assistant turns).
SAMPLE_CONVERSATIONS = {
    1: [
        {
            "Role": "assistant",
            "Filled_Template": "Hello! I'm your travel assistant. How can I help you today?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "I'm looking for hotels in Austin with check-in on May 10, 2025 and check-out on May 15, 2025.",  # noqa: E501
            "Filled_Plan": 'hotels = search_hotels(city="Austin", checkin_date=["May 10, 2025"], checkout_date=["May 15, 2025"])\nsave_to_cache(key="hotels", value=hotels)',  # noqa: E501
        },
    ],
    2: [
        {
            "Role": "assistant",
            "Filled_Template": "Welcome! What can I help you plan?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "I need a hotel in New York but I'm not sure about dates yet.",
            "Filled_Plan": 'seek_information("We need to ask for the check-in and check-out dates")',
        },
    ],
    3: [
        {
            "Role": "assistant",
            "Filled_Template": "Hi there! Looking for travel help?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "No that's perfect, thanks!",
            "Filled_Plan": 'print("No planning needed")',
        },
    ],
}
def build_messages(conversation: list, turn_index: int) -> list:
    """Assemble chat messages: system prompt plus turns 0..turn_index."""
    msgs = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in conversation[: turn_index + 1]:
        msgs.append(
            {"role": turn["Role"].strip().lower(), "content": turn["Filled_Template"]}
        )
    return msgs
async def test_single_completion(server, tokenizer):
    """Test 1: Single completion with tool calling.

    Sends one fully-specified hotel request, prints the raw choice plus
    token/mask/logprob stats of the tracked node, then scores the
    tool_calls against the hand-written ground-truth plan.
    """
    print("\n" + "=" * 60)
    print("TEST 1: Single tool-calling completion")
    print("=" * 60)
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": "Find me hotels in Austin, checking in May 10 and out May 15, 2025.",
        },
    ]
    result, nodes = await generate_tool_completions(
        server=server,
        tokenizer=tokenizer,
        messages=messages,
        tools=T1_TOOLS,
        n=1,
        max_tokens=500,
        temperature=0.0,  # greedy so the run is reproducible
        tool_choice="auto",
    )
    choice = result.choices[0]
    print(f"\nContent: {choice.message.content}")
    print(f"Tool calls: {choice.message.tool_calls}")
    print(f"Finish reason: {choice.finish_reason}")
    print(f"Nodes tracked: {len(nodes)}")
    if nodes:
        node = nodes[0]
        print(f"Token count: {len(node.tokens)}")
        # -100 marks masked (non-trainable) positions in masked_tokens
        unmasked = len([t for t in node.masked_tokens if t != -100])
        print(f"Unmasked tokens: {unmasked}")
        print(f"Logprobs sample: {node.logprobs[-5:]}")
    # Score against ground truth
    gt_code = 'hotels = search_hotels(city="Austin", checkin_date=["May 10, 2025"], checkout_date=["May 15, 2025"])\nsave_to_cache(key="hotels", value=hotels)'  # noqa: E501
    scores = score_turn(gt_code, choice.message.tool_calls, choice.message.content)
    print(f"\nScores: {json.dumps(scores, indent=2)}")
    return True
async def test_group_completions(server, tokenizer):
    """Test 2: Multiple completions (group_size=4) for GRPO.

    Samples n=4 completions at temperature 1.0 for one conversation turn,
    scores each against the GT plan, and builds a ScoredDataGroup.
    """
    print("\n" + "=" * 60)
    print("TEST 2: Group completions (n=4) for GRPO")
    print("=" * 60)
    convo = SAMPLE_CONVERSATIONS[1]
    # turn_index=1 is the first user turn (index 0 is the greeting)
    messages = build_messages(convo, turn_index=1)
    gt_code = convo[1]["Filled_Plan"]
    result, nodes = await generate_tool_completions(
        server=server,
        tokenizer=tokenizer,
        messages=messages,
        tools=T1_TOOLS,
        n=4,
        max_tokens=500,
        temperature=1.0,  # sampled, so the group has diversity for GRPO
        tool_choice="auto",
    )
    print(f"\nGot {len(result.choices)} choices, {len(nodes)} nodes")
    for i, choice in enumerate(result.choices):
        tc_count = len(choice.message.tool_calls) if choice.message.tool_calls else 0
        content = (choice.message.content or "")[:60]
        print(f" choice[{i}]: {tc_count} tool_calls, content={content!r}")
    # Score and build ScoredDataGroup
    scored, all_scores = score_completions(result, nodes, gt_code)
    print("\nPer-choice scores:")
    for i, s in enumerate(all_scores):
        print(
            f" [{i}] reward={s['reward']:.2f} tc_f1={s['tool_call_f1']:.2f} tp_f1={s['tool_param_f1']:.2f}"
        )
    # score_completions may return None when the group is discarded
    if scored:
        print(f"\nScoredDataGroup valid: {len(scored['tokens'])} items")
        print(f" scores: {scored['scores']}")
    else:
        print("\nScoredDataGroup: None (discarded)")
    return True
async def test_noop_turn(server, tokenizer):
    """Test 3: No-op turn (model should NOT call tools).

    Uses sample conversation 3 whose GT plan is print("No planning
    needed"), i.e. the correct behavior is a plain text reply.
    """
    print("\n" + "=" * 60)
    print("TEST 3: No-op turn")
    print("=" * 60)
    convo = SAMPLE_CONVERSATIONS[3]
    messages = build_messages(convo, turn_index=1)
    gt_code = convo[1]["Filled_Plan"]
    result, nodes = await generate_tool_completions(
        server=server,
        tokenizer=tokenizer,
        messages=messages,
        tools=T1_TOOLS,
        n=1,
        max_tokens=300,
        temperature=0.0,
        tool_choice="auto",
    )
    choice = result.choices[0]
    print(f"\nContent: {(choice.message.content or '')[:100]}")
    print(f"Tool calls: {choice.message.tool_calls}")
    scores = score_turn(gt_code, choice.message.tool_calls, choice.message.content)
    print(f"Scores: {json.dumps(scores, indent=2)}")
    return True
async def main():
    """CLI entry point: connect to a running vLLM server and run the live tests.

    Exits early (with setup guidance) when the server health check fails.
    Each test counts as passed only when it returns truthy; a falsy return
    or an exception counts as a failure.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=9001)
    parser.add_argument("--model", type=str, default="Qwen/Qwen3-1.7B")
    args = parser.parse_args()
    base_url = f"http://localhost:{args.port}/v1"
    print(f"Connecting to vLLM at {base_url} (model={args.model})")
    # Check health first
    import requests

    try:
        resp = requests.get(f"http://localhost:{args.port}/health", timeout=5)
        print(f"vLLM health: {resp.status_code}")
        if resp.status_code != 200:
            print("ERROR: vLLM not healthy!")
            return
    except Exception as e:
        print(f"ERROR: Can't reach vLLM: {e}")
        print(
            "Make sure vLLM is running: bash environments/t1_tool_planning/run_vllm.sh"
        )
        return
    server = make_server_manager(args.model, base_url)
    tokenizer = make_tokenizer(args.model)
    print(f"ServerManager created with {len(server.servers)} server(s)")
    print(f"Server type: {type(server.servers[0]).__name__}")
    print(f"Tool parser: {server.tool_parser}")
    passed = 0
    failed = 0
    for test_fn in [test_single_completion, test_group_completions, test_noop_turn]:
        try:
            ok = await test_fn(server, tokenizer)
        except Exception as e:
            failed += 1
            print(f"\n ✗ FAILED: {e}")
            import traceback

            traceback.print_exc()
        else:
            # BUG FIX: a test that returned falsy without raising was
            # previously counted as neither passed nor failed; it now
            # counts as a failure.
            if ok:
                passed += 1
                print("\n ✓ PASSED")
            else:
                failed += 1
                print("\n ✗ FAILED: test returned falsy")
    print(f"\n{'=' * 60}")
    print(f"RESULTS: {passed} passed, {failed} failed")
    print(f"{'=' * 60}")
# Script entry point: run all live tests against the configured server.
if __name__ == "__main__":
    asyncio.run(main())

View file

@ -0,0 +1,743 @@
"""
Standalone test for T1 tool planning environment.
Spins up vllm_api_server, creates a ServerManager with tool_parser="hermes",
and runs through T1 conversations end-to-end using the extracted t1_core functions.
Tests the full tool calling infrastructure:
ServerManager ManagedServer ToolCallTranslator vLLM hermes parser
structured tool_calls scoring against T1 ground truth
Usage:
# With GPU (spins up vLLM):
pytest --run-gpu environments/t1_tool_planning/test_t1_standalone.py -v -s
# Scoring logic only (no GPU):
pytest environments/t1_tool_planning/test_t1_standalone.py -v -k "not gpu"
"""
import json
import os
import signal
import subprocess
import sys
import time
import pytest
import requests
# -- T1 env imports --
sys.path.insert(0, os.path.dirname(__file__))
from t1_core import ( # noqa: E402
collect_multistep_trajectory,
generate_tool_completions,
score_completions,
)
from t1_prompts import SYSTEM_PROMPT # noqa: E402
from t1_scoring import ( # noqa: E402
parse_ground_truth_code,
parse_model_tool_calls,
score_turn,
tool_call_f1,
tool_param_f1,
)
from t1_tools import T1_TOOLS # noqa: E402
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
VLLM_PORT = 8123  # non-default port so the fixture never collides with a dev server
VLLM_MODEL = "Qwen/Qwen3-1.7B"  # small model so the GPU fixture starts quickly
VLLM_BASE_URL = f"http://localhost:{VLLM_PORT}/v1"  # OpenAI-compatible endpoint
# Repo root resolved relative to this file (environments/t1_tool_planning/)
REPO_ROOT = os.path.join(os.path.dirname(__file__), "..", "..")
VLLM_SCRIPT = os.path.join(REPO_ROOT, "example_trainer", "vllm_api_server.py")
# Sample T1 data — small hotel conversations for testing
# Two multi-turn hotel conversations in the T1 dataset's column format
# (Role / Filled_Template / Filled_Plan). Conversation 1 covers
# search -> filter -> no-op; conversation 2 covers seek_information
# followed by the resolved search once the user supplies dates.
SAMPLE_T1_CONVERSATIONS = {
    1: [
        {
            "Role": "assistant",
            "Filled_Template": "Hello! I'm your travel assistant. How can I help you today?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "I'm looking for hotels in Austin with check-in on May 10, 2025 and check-out on May 15, 2025.",  # noqa: E501
            "Filled_Plan": 'hotels = search_hotels(city="Austin", checkin_date=["May 10, 2025"], checkout_date=["May 15, 2025"])\nsave_to_cache(key="hotels", value=hotels)',  # noqa: E501
        },
        {
            "Role": "assistant",
            "Filled_Template": "I found several hotels in Austin for those dates. Would you like to filter by any specific amenities?",  # noqa: E501
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "Yes, I need one with free wifi and a gym.",
            "Filled_Plan": 'hotels = get_results_from_cache(key="hotels")\nfiltered_hotels = filter_hotels(prior_results=hotels, free_wifi_included=True, gym_present=True)\nsave_to_cache(key="filtered_hotels", value=filtered_hotels)',  # noqa: E501
        },
        {
            "Role": "assistant",
            "Filled_Template": "Here are hotels with free wifi and gym. Anything else?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "No that's perfect, thanks!",
            "Filled_Plan": 'print("No planning needed")',
        },
    ],
    2: [
        {
            "Role": "assistant",
            "Filled_Template": "Welcome! What can I help you plan?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "I need a hotel in New York but I'm not sure about dates yet.",
            "Filled_Plan": 'seek_information("We need to ask for the check-in and check-out dates")',
        },
        {
            "Role": "assistant",
            "Filled_Template": "Sure! When would you like to check in and check out?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "Check in June 1 and check out June 5, 2025.",
            "Filled_Plan": 'hotels = search_hotels(city="New York", checkin_date=["June 1, 2025"], checkout_date=["June 5, 2025"])\nsave_to_cache(key="hotels", value=hotels)',  # noqa: E501
        },
    ],
}
def load_sample_data() -> dict:
    """Load the sample T1 conversations (the shared module-level dict)."""
    conversations = SAMPLE_T1_CONVERSATIONS
    return conversations
def build_messages_for_turn(conversation: list, turn_index: int) -> list:
    """Build chat messages up to (and including) the given user turn.

    Uses ground-truth assistant responses for prior turns. Any turn whose
    role is neither assistant nor user is skipped.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in conversation[: turn_index + 1]:
        role = turn["Role"].strip().lower()
        if role in ("assistant", "user"):
            messages.append({"role": role, "content": turn["Filled_Template"]})
    return messages
# ---------------------------------------------------------------------------
# Scoring-only tests (no GPU needed)
# ---------------------------------------------------------------------------
class TestT1Scoring:
    """Test the scoring logic with known inputs.

    These are pure-Python tests (no GPU/vLLM): they exercise the GT-code
    parser, the model tool_calls parser, the two F1 metrics, and the
    composite reward of score_turn.
    """

    def test_parse_ground_truth_search(self):
        """A search + cache GT snippet parses into two ordered calls."""
        code = 'hotels = search_hotels(city="Austin", checkin_date=["May 10"], checkout_date=["May 15"])\nsave_to_cache(key="hotels", value=hotels)'  # noqa: E501
        calls = parse_ground_truth_code(code)
        assert len(calls) == 2
        assert calls[0]["name"] == "search_hotels"
        assert calls[0]["arguments"]["city"] == "Austin"
        assert calls[1]["name"] == "save_to_cache"

    def test_parse_ground_truth_filter(self):
        """A cache-read + filter chain parses with keyword values preserved."""
        code = 'hotels = get_results_from_cache(key="hotels")\nfiltered = filter_hotels(prior_results=hotels, free_wifi_included=True)'  # noqa: E501
        calls = parse_ground_truth_code(code)
        assert len(calls) == 2
        assert calls[0]["name"] == "get_results_from_cache"
        assert calls[1]["name"] == "filter_hotels"
        assert calls[1]["arguments"]["free_wifi_included"] is True

    def test_parse_ground_truth_noop(self):
        """The dataset's no-op marker yields zero parsed calls."""
        code = 'print("No planning needed")'
        calls = parse_ground_truth_code(code)
        assert len(calls) == 0

    def test_parse_ground_truth_seek(self):
        """seek_information parses as an ordinary single call."""
        code = 'seek_information("We need check-in dates")'
        calls = parse_ground_truth_code(code)
        assert len(calls) == 1
        assert calls[0]["name"] == "seek_information"

    def test_parse_empty(self):
        """Empty or missing GT code yields an empty call list."""
        assert parse_ground_truth_code("") == []
        assert parse_ground_truth_code(None) == []

    def test_tool_call_f1_perfect(self):
        """Identical call-name multisets give F1 == 1."""
        gt = [
            {"name": "search_hotels", "arguments": {}},
            {"name": "save_to_cache", "arguments": {}},
        ]
        gen = [
            {"name": "search_hotels", "arguments": {}},
            {"name": "save_to_cache", "arguments": {}},
        ]
        p, r, f1 = tool_call_f1(gt, gen)
        assert f1 == 1.0

    def test_tool_call_f1_partial(self):
        """Generating only one of two expected calls halves recall."""
        gt = [
            {"name": "search_hotels", "arguments": {}},
            {"name": "save_to_cache", "arguments": {}},
        ]
        gen = [{"name": "search_hotels", "arguments": {}}]
        p, r, f1 = tool_call_f1(gt, gen)
        assert p == 1.0
        assert r == 0.5
        assert 0 < f1 < 1

    def test_tool_call_f1_wrong(self):
        """A completely wrong tool name gives F1 == 0."""
        gt = [{"name": "search_hotels", "arguments": {}}]
        gen = [{"name": "search_flights", "arguments": {}}]
        p, r, f1 = tool_call_f1(gt, gen)
        assert f1 == 0.0

    def test_tool_param_f1_matching(self):
        """Exactly matching arguments give param F1 == 1."""
        gt = [
            {
                "name": "search_hotels",
                "arguments": {"city": "Austin", "checkin_date": ["May 10"]},
            }
        ]
        gen = [
            {
                "name": "search_hotels",
                "arguments": {"city": "Austin", "checkin_date": ["May 10"]},
            }
        ]
        p, r, f1 = tool_param_f1(gt, gen)
        assert f1 == 1.0

    def test_tool_param_f1_partial(self):
        """Omitting a GT argument lowers recall but not precision."""
        gt = [
            {
                "name": "search_hotels",
                "arguments": {"city": "Austin", "checkin_date": ["May 10"]},
            }
        ]
        gen = [{"name": "search_hotels", "arguments": {"city": "Austin"}}]
        p, r, f1 = tool_param_f1(gt, gen)
        assert p == 1.0  # what we generated is correct
        assert r == 0.5  # but we missed one param

    def test_score_turn_noop_correct(self):
        """Correct no-op (no tool calls when none are expected) scores 1.0."""
        scores = score_turn('print("No planning needed")', None)
        assert scores["reward"] == 1.0

    def test_score_turn_noop_wrong(self):
        # GT says no-op but model called tools
        fake_calls = [
            {"function": {"name": "search_hotels", "arguments": '{"city": "X"}'}}
        ]
        scores = score_turn('print("No planning needed")', fake_calls)
        assert scores["reward"] == 0.0

    def test_score_turn_tools_expected_none_produced(self):
        # GT expects tools but model produced none → 0.0
        gt = 'hotels = search_hotels(city="Austin", checkin_date=["May 10"], checkout_date=["May 15"])'
        scores = score_turn(gt, None)
        assert scores["reward"] == 0.0

    def test_score_turn_wrong_tool_gets_format_credit(self):
        # GT expects search_hotels, model called search_flights → 0.1 (format credit only)
        gt = 'hotels = search_hotels(city="Austin", checkin_date=["May 10"], checkout_date=["May 15"])'
        wrong_calls = [
            {
                "function": {
                    "name": "search_flights",
                    "arguments": '{"start_airport_city": "X"}',
                }
            }
        ]
        scores = score_turn(gt, wrong_calls)
        assert scores["reward"] == 0.1  # format credit, no f1 match
        assert scores["tool_call_f1"] == 0.0

    def test_score_turn_right_tool_higher_than_wrong(self):
        # Right tool should score higher than wrong tool
        gt = 'hotels = search_hotels(city="Austin", checkin_date=["May 10"], checkout_date=["May 15"])'
        right_calls = [
            {"function": {"name": "search_hotels", "arguments": '{"city": "Austin"}'}}
        ]
        wrong_calls = [
            {
                "function": {
                    "name": "search_flights",
                    "arguments": '{"start_airport_city": "X"}',
                }
            }
        ]
        right_scores = score_turn(gt, right_calls)
        wrong_scores = score_turn(gt, wrong_calls)
        assert right_scores["reward"] > wrong_scores["reward"]

    def test_parse_model_tool_calls(self):
        """OpenAI-shaped tool_calls (JSON-string arguments) parse correctly."""
        calls = [
            {"function": {"name": "search_hotels", "arguments": '{"city": "Austin"}'}},
            {
                "function": {
                    "name": "save_to_cache",
                    "arguments": '{"key": "hotels", "value": "hotels"}',
                }
            },
        ]
        parsed = parse_model_tool_calls(calls)
        assert len(parsed) == 2
        assert parsed[0]["name"] == "search_hotels"
        assert parsed[0]["arguments"]["city"] == "Austin"

    def test_sample_data_loads(self):
        """Sanity-check the in-file sample conversations' shape."""
        convos = load_sample_data()
        assert len(convos) == 2
        assert len(convos[1]) == 6  # 3 assistant + 3 user turns
        assert len(convos[2]) == 4

    def test_build_messages(self):
        """Messages include the system prompt plus turns up to turn_index."""
        convos = load_sample_data()
        # Turn index 1 = first user turn (index 0 is assistant)
        msgs = build_messages_for_turn(convos[1], turn_index=1)
        assert msgs[0]["role"] == "system"
        assert msgs[1]["role"] == "assistant"
        assert msgs[2]["role"] == "user"
        assert "Austin" in msgs[2]["content"]

    def test_t1_tools_valid(self):
        """Verify all tool definitions have required fields."""
        for tool in T1_TOOLS:
            assert tool["type"] == "function"
            assert "name" in tool["function"]
            assert "parameters" in tool["function"]
# ---------------------------------------------------------------------------
# GPU integration test — full pipeline with vLLM
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def vllm_backend():
    """Start vLLM api server as a subprocess.

    Polls the /health endpoint for up to 180s; fails the test module early
    if the process dies or never becomes healthy. Yields the live process
    and tears it down with SIGTERM (escalating to kill) afterwards.
    """
    cmd = [
        sys.executable,
        VLLM_SCRIPT,
        "--model",
        VLLM_MODEL,
        "--port",
        str(VLLM_PORT),
        "--gpu-memory-utilization",
        "0.45",
        "--enforce-eager",
    ]
    # stderr merged into stdout so failure logs can be dumped from one pipe
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        cwd=REPO_ROOT,
    )
    deadline = time.time() + 180
    healthy = False
    while time.time() < deadline:
        try:
            resp = requests.get(f"http://localhost:{VLLM_PORT}/health", timeout=2)
            if resp.status_code == 200:
                healthy = True
                break
        except (requests.ConnectionError, requests.Timeout):
            pass  # server not up yet; keep polling
        # Detect an early crash instead of waiting out the full deadline
        if proc.poll() is not None:
            stdout = proc.stdout.read().decode() if proc.stdout else ""
            pytest.fail(f"vLLM exited early:\n{stdout[-3000:]}")
        time.sleep(3)
    if not healthy:
        proc.kill()
        stdout = proc.stdout.read().decode() if proc.stdout else ""
        pytest.fail(f"vLLM didn't start within 180s:\n{stdout[-3000:]}")
    yield proc
    # Teardown: graceful SIGTERM first, hard kill if it doesn't exit in 15s
    proc.send_signal(signal.SIGTERM)
    try:
        proc.wait(timeout=15)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()
@pytest.fixture(scope="module")
def server_and_tokenizer(vllm_backend):
    """Create a ServerManager + tokenizer pointed at the vLLM backend."""
    from transformers import AutoTokenizer

    from atroposlib.envs.server_handling.server_baseline import APIServerConfig
    from atroposlib.envs.server_handling.server_manager import ServerManager

    manager = ServerManager(
        configs=[
            APIServerConfig(
                model_name=VLLM_MODEL,
                base_url=VLLM_BASE_URL,
                api_key="x",
                server_type="vllm",
            )
        ],
        slurm=False,
        testing=False,
        tool_parser="hermes",
    )
    return manager, AutoTokenizer.from_pretrained(VLLM_MODEL)
@pytest.mark.gpu
class TestT1FullPipeline:
"""End-to-end test: vLLM → ServerManager → ManagedServer → tool calls → scoring."""
    async def test_single_turn_tool_call(self, server_and_tokenizer):
        """Model should call search_hotels when given enough info."""
        server, tokenizer = server_and_tokenizer
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": "Find me hotels in Austin, checking in May 10 and out May 15, 2025.",
            },
        ]
        result, nodes = await generate_tool_completions(
            server=server,
            tokenizer=tokenizer,
            messages=messages,
            tools=T1_TOOLS,
            n=1,
            max_tokens=500,
            temperature=0.0,  # greedy for reproducibility
            tool_choice="auto",
        )
        choice = result.choices[0]
        print(f"\nContent: {choice.message.content}")
        print(f"Tool calls: {choice.message.tool_calls}")
        print(f"Finish reason: {choice.finish_reason}")
        # Model should have called at least search_hotels
        # (seek_information is also acceptable for a weak model)
        if choice.message.tool_calls:
            names = [tc["function"]["name"] for tc in choice.message.tool_calls]
            print(f"Tool names called: {names}")
            assert "search_hotels" in names or "seek_information" in names
        # Score against ground truth
        gt_code = 'hotels = search_hotels(city="Austin", checkin_date=["May 10, 2025"], checkout_date=["May 15, 2025"])\nsave_to_cache(key="hotels", value=hotels)'  # noqa: E501
        scores = score_turn(gt_code, choice.message.tool_calls)
        print(f"Scores: {json.dumps(scores, indent=2)}")
        assert scores["reward"] > 0
        # Verify nodes are tracked
        assert len(nodes) == 1
        assert len(nodes[0].tokens) > 0
        assert len(nodes[0].logprobs) == len(nodes[0].tokens)
    async def test_group_completions(self, server_and_tokenizer):
        """Generate n=4 completions (GRPO-style) and score them."""
        server, tokenizer = server_and_tokenizer
        convo = SAMPLE_T1_CONVERSATIONS[1]
        # turn_index=1 is the first user turn (index 0 is the greeting)
        messages = build_messages_for_turn(convo, turn_index=1)
        gt_code = convo[1]["Filled_Plan"]
        result, nodes = await generate_tool_completions(
            server=server,
            tokenizer=tokenizer,
            messages=messages,
            tools=T1_TOOLS,
            n=4,
            max_tokens=500,
            temperature=1.0,  # sampled so the group has diversity
            tool_choice="auto",
        )
        assert len(result.choices) == 4
        assert len(nodes) == 4
        print(f"\nGot {len(result.choices)} choices:")
        for i, choice in enumerate(result.choices):
            tc_count = (
                len(choice.message.tool_calls) if choice.message.tool_calls else 0
            )
            content = (choice.message.content or "")[:60]
            print(f" choice[{i}]: {tc_count} tool_calls, content={content!r}")
        # Score and build ScoredDataGroup
        scored, all_scores = score_completions(result, nodes, gt_code)
        print("\nPer-choice scores:")
        for i, s in enumerate(all_scores):
            print(
                f" [{i}] reward={s['reward']:.2f} tc_f1={s['tool_call_f1']:.2f} tp_f1={s['tool_param_f1']:.2f}"
            )
        # At least some choices should have tokens
        assert len(all_scores) == 4
        # scored may be None when the whole group is discarded
        if scored:
            print(
                f"\nScoredDataGroup: {len(scored['tokens'])} valid items, scores={scored['scores']}"
            )
            # Verify structure
            assert len(scored["tokens"]) == len(scored["masks"])
            assert len(scored["tokens"]) == len(scored["scores"])
            assert len(scored["tokens"]) == len(scored["inference_logprobs"])
    async def test_seek_information(self, server_and_tokenizer):
        """Model should ask for missing info when dates aren't provided."""
        server, tokenizer = server_and_tokenizer
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": "I need a hotel in New York but I'm not sure about dates.",
            },
        ]
        result, nodes = await generate_tool_completions(
            server=server,
            tokenizer=tokenizer,
            messages=messages,
            tools=T1_TOOLS,
            n=1,
            max_tokens=300,
            temperature=0.0,
            tool_choice="auto",
        )
        choice = result.choices[0]
        print(f"\nContent: {choice.message.content}")
        print(f"Tool calls: {choice.message.tool_calls}")
        gt_code = (
            'seek_information("We need to ask for the check-in and check-out dates")'
        )
        # Scores are printed for inspection only; no reward assertion here
        # since a weak model may legitimately not call seek_information.
        scores = score_turn(gt_code, choice.message.tool_calls)
        print(f"Scores: {json.dumps(scores, indent=2)}")
        assert len(nodes) == 1
    async def test_noop_turn(self, server_and_tokenizer):
        """No-op turn — model should just respond without tools."""
        server, tokenizer = server_and_tokenizer
        convo = SAMPLE_T1_CONVERSATIONS[1]
        # Turn 5 = "No that's perfect, thanks!" → print("No planning needed")
        messages = build_messages_for_turn(convo, turn_index=5)
        gt_code = convo[5]["Filled_Plan"]
        result, nodes = await generate_tool_completions(
            server=server,
            tokenizer=tokenizer,
            messages=messages,
            tools=T1_TOOLS,
            n=1,
            max_tokens=300,
            temperature=0.0,
            tool_choice="auto",
        )
        choice = result.choices[0]
        print(f"\nContent: {(choice.message.content or '')[:100]}")
        print(f"Tool calls: {choice.message.tool_calls}")
        # Printed for inspection; reward is 1.0 only when no tools are called
        scores = score_turn(gt_code, choice.message.tool_calls, choice.message.content)
        print(f"Scores: {json.dumps(scores, indent=2)}")
    async def test_conversation_walkthrough(self, server_and_tokenizer):
        """Walk through a full T1 conversation using GT context, scoring each user turn.

        Prior turns always come from the ground truth (not the model's own
        outputs), so each user turn is scored independently.
        """
        server, tokenizer = server_and_tokenizer
        convos = load_sample_data()
        convo = convos[1]  # 3-turn hotel conversation
        all_scores = []
        user_turn_indices = [
            i for i, t in enumerate(convo) if t["Role"].strip().lower() == "user"
        ]
        for turn_idx in user_turn_indices:
            turn = convo[turn_idx]
            messages = build_messages_for_turn(convo, turn_idx)
            result, nodes = await generate_tool_completions(
                server=server,
                tokenizer=tokenizer,
                messages=messages,
                tools=T1_TOOLS,
                n=1,
                max_tokens=500,
                temperature=0.0,
                tool_choice="auto",
            )
            choice = result.choices[0]
            gt_code = turn["Filled_Plan"]
            scores = score_turn(
                gt_code, choice.message.tool_calls, choice.message.content
            )
            print(f"\n--- Turn {turn_idx} ---")
            print(f"User: {turn['Filled_Template'][:80]}...")
            print(f"GT code: {gt_code[:80]}...")
            if choice.message.tool_calls:
                names = [tc["function"]["name"] for tc in choice.message.tool_calls]
                print(f"Model called: {names}")
            else:
                print(f"Model text: {(choice.message.content or '')[:80]}...")
            print(
                f"Scores: tc_f1={scores['tool_call_f1']:.2f} tp_f1={scores['tool_param_f1']:.2f} reward={scores['reward']:.2f}"  # noqa: E501
            )
            all_scores.append(scores)
        # Aggregate
        avg_tc_f1 = sum(s["tool_call_f1"] for s in all_scores) / len(all_scores)
        avg_tp_f1 = sum(s["tool_param_f1"] for s in all_scores) / len(all_scores)
        avg_reward = sum(s["reward"] for s in all_scores) / len(all_scores)
        print(f"\n=== AGGREGATE ({len(all_scores)} turns) ===")
        print(f"Tool Call F1: {avg_tc_f1:.3f}")
        print(f"Tool Param F1: {avg_tp_f1:.3f}")
        print(f"Avg Reward: {avg_reward:.3f}")
async def test_multistep_trajectory(self, server_and_tokenizer):
"""Multi-step: walk full conversation feeding model's OWN responses back.
This is the real end-to-end test. At each turn:
1. Model generates a response (possibly with tool_calls)
2. That response is fed back as conversation history
3. ToolCallTranslator reconstructs raw text from tool_calls
4. Chat template re-tokenizes with reconstructed text
5. Next turn uses the model's actual history, not GT
Tests the full bidirectional tool call pipeline.
"""
server, tokenizer = server_and_tokenizer
convos = load_sample_data()
convo = convos[1] # 3-turn hotel conversation (6 entries: 3 asst + 3 user)
turn_results, nodes = await collect_multistep_trajectory(
server=server,
tokenizer=tokenizer,
conversation=convo,
tools=T1_TOOLS,
max_tokens=500,
temperature=0.0,
tool_choice="auto",
)
assert len(turn_results) > 0, "Should have at least one turn result"
assert len(nodes) > 0, "Should have tracked nodes"
print(
f"\n=== MULTI-STEP TRAJECTORY ({len(turn_results)} turns, {len(nodes)} nodes) ==="
)
for tr in turn_results:
print(f"\n--- Turn {tr['turn_idx']} ---")
print(f"User: {tr['user_message'][:80]}...")
print(f"GT: {tr['gt_code'][:80]}...")
if tr["tool_calls"]:
names = [tc["function"]["name"] for tc in tr["tool_calls"]]
print(f"Model called: {names}")
else:
print(f"Model text: {(tr['content'] or '')[:80]}...")
s = tr["scores"]
print(
f"Scores: tc_f1={s['tool_call_f1']:.2f} tp_f1={s['tool_param_f1']:.2f} reward={s['reward']:.2f}"
)
print(f"Messages in context: {len(tr['messages_so_far'])}")
# Verify nodes — each turn should extend the previous
print("\nNodes from managed server:")
for i, node in enumerate(nodes):
unmasked = len([t for t in node.masked_tokens if t != -100])
print(f" node[{i}]: {len(node.tokens)} tokens, {unmasked} unmasked")
# Verify conversation grew correctly
user_turns = len(turn_results)
last_msg_count = len(turn_results[-1]["messages_so_far"])
expected_min = (
2 + user_turns
) # system + greeting + at least 1 msg per user turn
print(
f"\nFinal conversation: {last_msg_count} messages (expected >= {expected_min})"
)
assert last_msg_count >= expected_min
# Aggregate
avg_reward = sum(r["scores"]["reward"] for r in turn_results) / len(
turn_results
)
avg_tc_f1 = sum(r["scores"]["tool_call_f1"] for r in turn_results) / len(
turn_results
)
print(f"\nAvg Reward: {avg_reward:.3f}")
print(f"Avg TC F1: {avg_tc_f1:.3f}")
async def test_multistep_with_tool_history(self, server_and_tokenizer):
"""Verify tool_calls in history are properly reconstructed for subsequent turns.
The critical path: turn N produces tool_calls turn N+1's prompt must
contain those tool_calls in raw text form (e.g. <tool_call> tags) so the
tokenizer produces correct tokens.
"""
server, tokenizer = server_and_tokenizer
convos = load_sample_data()
convo = convos[2] # 2-turn: seek_information → search_hotels
turn_results, nodes = await collect_multistep_trajectory(
server=server,
tokenizer=tokenizer,
conversation=convo,
tools=T1_TOOLS,
max_tokens=500,
temperature=0.0,
tool_choice="auto",
)
print(
f"\n=== TOOL HISTORY TEST ({len(turn_results)} turns, {len(nodes)} nodes) ==="
)
for tr in turn_results:
tc_count = len(tr["tool_calls"]) if tr["tool_calls"] else 0
print(
f"Turn {tr['turn_idx']}: {tc_count} tool_calls, reward={tr['scores']['reward']:.2f}"
)
# Nodes should show extending — later nodes have more tokens
for i, node in enumerate(nodes):
print(f" node[{i}]: {len(node.tokens)} total tokens")
# If turn 0 produced tool_calls, verify turn 1's messages contain them
if len(turn_results) >= 2 and turn_results[0]["tool_calls"]:
turn1_messages = turn_results[1]["messages_so_far"]
assistant_msgs = [m for m in turn1_messages if m.get("role") == "assistant"]
has_tool_history = any(m.get("tool_calls") for m in assistant_msgs)
print(f"\nTurn 1 has tool_call history in context: {has_tool_history}")
# Allow running this test module directly (python <file>); "-s" keeps the
# diagnostic print output visible instead of letting pytest capture it.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])