mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
294 lines · 8.9 KiB · Python
#!/usr/bin/env python3
"""
Live test for T1 tool planning — runs against an already-running vLLM server.

No pytest fixtures, no subprocess spawning. Just creates a ServerManager
pointed at localhost:9001, calls generate_tool_completions, and prints results.

Usage:
    # With vLLM already running on port 9001:
    python environments/t1_tool_planning/test_t1_live.py

    # Custom port:
    python environments/t1_tool_planning/test_t1_live.py --port 8123

    # Custom model:
    python environments/t1_tool_planning/test_t1_live.py --model Qwen/Qwen3-4B
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
|
|
# Ensure t1 modules are importable
|
|
sys.path.insert(0, os.path.dirname(__file__))
|
|
|
|
from t1_core import generate_tool_completions, score_completions # noqa: E402
|
|
from t1_prompts import SYSTEM_PROMPT # noqa: E402
|
|
from t1_scoring import score_turn # noqa: E402
|
|
from t1_tools import T1_TOOLS # noqa: E402
|
|
|
|
# Verbose (DEBUG) logging so every server round-trip is visible when
# debugging against a live vLLM instance.
_LOG_FORMAT = "%(asctime)s [%(name)s] %(levelname)s: %(message)s"
logging.basicConfig(level=logging.DEBUG, format=_LOG_FORMAT)
logger = logging.getLogger("test_t1_live")
|
|
|
|
|
|
def make_server_manager(model_name: str, base_url: str):
    """Build a ServerManager that talks to a vLLM server already running at *base_url*.

    Imports are deferred so the module can be inspected without atroposlib
    installed; only calling this function requires it.
    """
    from atroposlib.envs.server_handling.server_baseline import APIServerConfig
    from atroposlib.envs.server_handling.server_manager import ServerManager

    vllm_config = APIServerConfig(
        model_name=model_name,
        base_url=base_url,
        api_key="x",
        server_type="vllm",
    )
    # Single-server setup: no slurm scheduling, not in testing mode, and the
    # "hermes" parser extracts tool calls from model output.
    return ServerManager(
        configs=[vllm_config],
        slurm=False,
        testing=False,
        tool_parser="hermes",
    )
|
|
|
|
|
|
def make_tokenizer(model_name: str):
    """Load the Hugging Face tokenizer matching the served model."""
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer
|
|
|
|
|
|
# Hand-written fixtures mirroring the T1 dataset row schema.  Keys are
# scenario ids; each value is a list of turns with fields:
#   "Role"            - speaker for the turn ("assistant" or "user")
#   "Filled_Template" - the utterance used as the chat message content
#   "Filled_Plan"     - ground-truth plan code used for scoring ("" when the
#                       turn carries no plan)
SAMPLE_CONVERSATIONS = {
    # Scenario 1: fully-specified hotel search -- the model should emit
    # search_hotels + save_to_cache tool calls.
    1: [
        {
            "Role": "assistant",
            "Filled_Template": "Hello! I'm your travel assistant. How can I help you today?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "I'm looking for hotels in Austin with check-in on May 10, 2025 and check-out on May 15, 2025.",  # noqa: E501
            "Filled_Plan": 'hotels = search_hotels(city="Austin", checkin_date=["May 10, 2025"], checkout_date=["May 15, 2025"])\nsave_to_cache(key="hotels", value=hotels)',  # noqa: E501
        },
    ],
    # Scenario 2: dates missing -- ground truth asks for more information
    # instead of searching.
    2: [
        {
            "Role": "assistant",
            "Filled_Template": "Welcome! What can I help you plan?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "I need a hotel in New York but I'm not sure about dates yet.",
            "Filled_Plan": 'seek_information("We need to ask for the check-in and check-out dates")',
        },
    ],
    # Scenario 3: no-op turn -- the correct behavior is to call no tools.
    3: [
        {
            "Role": "assistant",
            "Filled_Template": "Hi there! Looking for travel help?",
            "Filled_Plan": "",
        },
        {
            "Role": "user",
            "Filled_Template": "No that's perfect, thanks!",
            "Filled_Plan": 'print("No planning needed")',
        },
    ],
}
|
|
|
|
|
|
def build_messages(conversation: list, turn_index: int) -> list:
    """Build a chat-message list from dataset turns up to and including *turn_index*.

    The system prompt always comes first; each included turn contributes one
    message whose role is the turn's ``Role`` field, stripped and lowercased
    (dataset rows capitalize it, e.g. "Role": "assistant").

    Args:
        conversation: List of turn dicts with "Role" and "Filled_Template" keys.
        turn_index: Index of the last turn to include (inclusive).

    Returns:
        A list of ``{"role", "content"}`` dicts ready for a chat completion call.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    # Idiom fix: slice replaces the original enumerate-and-break loop.
    # conversation[:turn_index + 1] includes turns 0..turn_index, handles
    # turn_index beyond the end, and yields no turns for turn_index < 0 --
    # identical behavior in all cases.
    for turn in conversation[: turn_index + 1]:
        messages.append(
            {"role": turn["Role"].strip().lower(), "content": turn["Filled_Template"]}
        )
    return messages
|
|
|
|
|
|
async def test_single_completion(server, tokenizer):
    """Test 1: Single completion with tool calling."""
    banner = "=" * 60
    print("\n" + banner)
    print("TEST 1: Single tool-calling completion")
    print(banner)

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": "Find me hotels in Austin, checking in May 10 and out May 15, 2025.",
        },
    ]

    # Deterministic single rollout (n=1, temperature=0.0).
    completion, nodes = await generate_tool_completions(
        server=server,
        tokenizer=tokenizer,
        messages=messages,
        tools=T1_TOOLS,
        n=1,
        max_tokens=500,
        temperature=0.0,
        tool_choice="auto",
    )

    top = completion.choices[0]
    print(f"\nContent: {top.message.content}")
    print(f"Tool calls: {top.message.tool_calls}")
    print(f"Finish reason: {top.finish_reason}")
    print(f"Nodes tracked: {len(nodes)}")

    if nodes:
        first = nodes[0]
        print(f"Token count: {len(first.tokens)}")
        unmasked_count = sum(1 for t in first.masked_tokens if t != -100)
        print(f"Unmasked tokens: {unmasked_count}")
        print(f"Logprobs sample: {first.logprobs[-5:]}")

    # Score against ground truth
    gt_code = 'hotels = search_hotels(city="Austin", checkin_date=["May 10, 2025"], checkout_date=["May 15, 2025"])\nsave_to_cache(key="hotels", value=hotels)'  # noqa: E501
    scores = score_turn(gt_code, top.message.tool_calls, top.message.content)
    print(f"\nScores: {json.dumps(scores, indent=2)}")

    return True
|
|
|
|
|
|
async def test_group_completions(server, tokenizer):
    """Test 2: Multiple completions (group_size=4) for GRPO."""
    banner = "=" * 60
    print("\n" + banner)
    print("TEST 2: Group completions (n=4) for GRPO")
    print(banner)

    conversation = SAMPLE_CONVERSATIONS[1]
    messages = build_messages(conversation, turn_index=1)
    gt_code = conversation[1]["Filled_Plan"]

    # Sampled group rollout (n=4, temperature=1.0) as GRPO would request.
    completion, nodes = await generate_tool_completions(
        server=server,
        tokenizer=tokenizer,
        messages=messages,
        tools=T1_TOOLS,
        n=4,
        max_tokens=500,
        temperature=1.0,
        tool_choice="auto",
    )

    print(f"\nGot {len(completion.choices)} choices, {len(nodes)} nodes")

    for i, choice in enumerate(completion.choices):
        tc_count = 0 if not choice.message.tool_calls else len(choice.message.tool_calls)
        content = (choice.message.content or "")[:60]
        print(f" choice[{i}]: {tc_count} tool_calls, content={content!r}")

    # Score and build ScoredDataGroup
    scored, all_scores = score_completions(completion, nodes, gt_code)

    print("\nPer-choice scores:")
    for i, s in enumerate(all_scores):
        print(
            f" [{i}] reward={s['reward']:.2f} tc_f1={s['tool_call_f1']:.2f} tp_f1={s['tool_param_f1']:.2f}"
        )

    if not scored:
        print("\nScoredDataGroup: None (discarded)")
    else:
        print(f"\nScoredDataGroup valid: {len(scored['tokens'])} items")
        print(f" scores: {scored['scores']}")

    return True
|
|
|
|
|
|
async def test_noop_turn(server, tokenizer):
    """Test 3: No-op turn (model should NOT call tools)."""
    banner = "=" * 60
    print("\n" + banner)
    print("TEST 3: No-op turn")
    print(banner)

    conversation = SAMPLE_CONVERSATIONS[3]
    messages = build_messages(conversation, turn_index=1)
    gt_code = conversation[1]["Filled_Plan"]

    # Deterministic single rollout; the ground truth expects no tool calls.
    completion, nodes = await generate_tool_completions(
        server=server,
        tokenizer=tokenizer,
        messages=messages,
        tools=T1_TOOLS,
        n=1,
        max_tokens=300,
        temperature=0.0,
        tool_choice="auto",
    )

    top = completion.choices[0]
    print(f"\nContent: {(top.message.content or '')[:100]}")
    print(f"Tool calls: {top.message.tool_calls}")

    scores = score_turn(gt_code, top.message.tool_calls, top.message.content)
    print(f"Scores: {json.dumps(scores, indent=2)}")

    return True
|
|
|
|
|
|
async def main():
    """Parse CLI args, verify the vLLM server is healthy, then run all live tests.

    Exits early (with a diagnostic) if the server's /health endpoint is
    unreachable or unhealthy. Otherwise runs each test function in sequence,
    tallying passes and failures, and prints a summary.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=9001)
    parser.add_argument("--model", type=str, default="Qwen/Qwen3-1.7B")
    args = parser.parse_args()

    base_url = f"http://localhost:{args.port}/v1"
    print(f"Connecting to vLLM at {base_url} (model={args.model})")

    # Check health first so test failures aren't confused with a down server.
    import requests

    try:
        resp = requests.get(f"http://localhost:{args.port}/health", timeout=5)
        print(f"vLLM health: {resp.status_code}")
        if resp.status_code != 200:
            print("ERROR: vLLM not healthy!")
            return
    except Exception as e:
        print(f"ERROR: Can't reach vLLM: {e}")
        print(
            "Make sure vLLM is running: bash environments/t1_tool_planning/run_vllm.sh"
        )
        return

    server = make_server_manager(args.model, base_url)
    tokenizer = make_tokenizer(args.model)

    print(f"ServerManager created with {len(server.servers)} server(s)")
    print(f"Server type: {type(server.servers[0]).__name__}")
    print(f"Tool parser: {server.tool_parser}")

    passed = 0
    failed = 0

    for test_fn in [test_single_completion, test_group_completions, test_noop_turn]:
        try:
            ok = await test_fn(server, tokenizer)
            if ok:
                passed += 1
                print("\n ✓ PASSED")
            else:
                # Bug fix: a test returning a falsy value (without raising)
                # previously vanished from the tally, so passed + failed could
                # be less than the number of tests run. Count it as a failure.
                failed += 1
                print("\n ✗ FAILED: test returned falsy")
        except Exception as e:
            failed += 1
            print(f"\n ✗ FAILED: {e}")
            import traceback

            traceback.print_exc()

    print(f"\n{'=' * 60}")
    print(f"RESULTS: {passed} passed, {failed} failed")
    print(f"{'=' * 60}")
|
|
|
|
|
|
# Script entry point: drive the async live-test suite on the current event loop.
if __name__ == "__main__":
    asyncio.run(main())
|