mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add tool call parsing based on vllm impl and an openai server endpoint
This commit is contained in:
parent
887a94374c
commit
add42a2afb
11 changed files with 3370 additions and 34 deletions
363
atroposlib/tests/test_managed_server_proxy_integration.py
Normal file
363
atroposlib/tests/test_managed_server_proxy_integration.py
Normal file
|
|
@ -0,0 +1,363 @@
|
|||
"""Integration tests for ManagedServer OpenAI proxy against real vLLM backend.
|
||||
|
||||
Spins up example_trainer/vllm_api_server.py with Qwen3-4B as a subprocess.
|
||||
Requires GPU — skipped by default. Run with:
|
||||
|
||||
pytest --run-gpu atroposlib/tests/test_managed_server_proxy_integration.py -v -s
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from atroposlib.envs.server_handling.managed_server_proxy import create_app
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
||||
from atroposlib.envs.server_handling.server_manager import ServerManager
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
VLLM_PORT = 8123
|
||||
VLLM_MODEL = "Qwen/Qwen3-4B-Thinking-2507"
|
||||
PROXY_MODEL = VLLM_MODEL
|
||||
VLLM_BASE_URL = f"http://localhost:{VLLM_PORT}/v1"
|
||||
REPO_ROOT = os.path.join(os.path.dirname(__file__), "..", "..")
|
||||
VLLM_SCRIPT = os.path.join(REPO_ROOT, "example_trainer", "vllm_api_server.py")
|
||||
VENV_PYTHON = sys.executable # use the current interpreter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vllm_backend():
|
||||
"""Start vLLM api server as a subprocess. Module-scoped so it's shared."""
|
||||
cmd = [
|
||||
VENV_PYTHON,
|
||||
VLLM_SCRIPT,
|
||||
"--model",
|
||||
VLLM_MODEL,
|
||||
"--port",
|
||||
str(VLLM_PORT),
|
||||
"--max-model-len",
|
||||
"32000",
|
||||
"--max-num-seqs",
|
||||
"32",
|
||||
]
|
||||
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
cwd=REPO_ROOT,
|
||||
)
|
||||
|
||||
# Wait for health
|
||||
deadline = time.time() + 180 # 3 min for model loading
|
||||
healthy = False
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
resp = requests.get(f"http://localhost:{VLLM_PORT}/health", timeout=2)
|
||||
if resp.status_code == 200:
|
||||
healthy = True
|
||||
break
|
||||
except (requests.ConnectionError, requests.Timeout):
|
||||
pass
|
||||
|
||||
if proc.poll() is not None:
|
||||
stdout = proc.stdout.read().decode() if proc.stdout else ""
|
||||
pytest.fail(
|
||||
f"vLLM server exited early (code={proc.returncode}):\n{stdout[-3000:]}"
|
||||
)
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
if not healthy:
|
||||
proc.kill()
|
||||
stdout = proc.stdout.read().decode() if proc.stdout else ""
|
||||
pytest.fail(f"vLLM server didn't become healthy within 180s:\n{stdout[-3000:]}")
|
||||
|
||||
yield proc
|
||||
|
||||
proc.send_signal(signal.SIGTERM)
|
||||
try:
|
||||
proc.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tokenizer_instance():
|
||||
return AutoTokenizer.from_pretrained(VLLM_MODEL)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def proxy_client(vllm_backend, tokenizer_instance):
|
||||
"""Create a test client for the proxy backed by the real vLLM server."""
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
config = APIServerConfig(
|
||||
model_name=VLLM_MODEL,
|
||||
base_url=VLLM_BASE_URL,
|
||||
api_key="",
|
||||
server_type="vllm",
|
||||
health_check=False,
|
||||
)
|
||||
server_manager = ServerManager(configs=[config])
|
||||
|
||||
app = create_app(
|
||||
server_manager=server_manager,
|
||||
tokenizer=tokenizer_instance,
|
||||
model_name=VLLM_MODEL,
|
||||
)
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
class TestRealChatCompletion:
|
||||
def test_basic_completion(self, proxy_client):
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
resp = proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Say hello in one word."}],
|
||||
"max_tokens": 30,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["choices"]) == 1
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
assert content is not None
|
||||
assert len(content) > 0
|
||||
assert data["model"] == VLLM_MODEL
|
||||
|
||||
def test_n_completions(self, proxy_client):
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
resp = proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Pick a random number"}],
|
||||
"max_tokens": 20,
|
||||
"temperature": 1.0,
|
||||
"n": 4,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["choices"]) == 4
|
||||
|
||||
# Check nodes
|
||||
resp = proxy_client.get(f"/{uuid}/nodes")
|
||||
assert len(resp.json()["nodes"]) == 4
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
class TestRealLogprobs:
|
||||
def test_logprobs_are_valid(self, proxy_client):
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
"max_tokens": 20,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
|
||||
resp = proxy_client.get(f"/{uuid}/nodes")
|
||||
nodes = resp.json()["nodes"]
|
||||
assert len(nodes) == 1
|
||||
|
||||
node = nodes[0]
|
||||
# Find where completion starts (logprobs transition from 1.0 to negative)
|
||||
prompt_end = 0
|
||||
for i, lp in enumerate(node["logprobs"]):
|
||||
if lp != 1.0:
|
||||
prompt_end = i
|
||||
break
|
||||
|
||||
# Prompt logprobs should be 1.0
|
||||
assert all(lp == 1.0 for lp in node["logprobs"][:prompt_end])
|
||||
# Completion logprobs should be negative
|
||||
completion_lps = [lp for lp in node["logprobs"][prompt_end:] if lp != 1.0]
|
||||
assert len(completion_lps) > 0
|
||||
assert all(lp < 0 for lp in completion_lps)
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
class TestRealTokenAlignment:
|
||||
def test_tokens_decode_to_full_text(self, proxy_client, tokenizer_instance):
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Say exactly: test123"}],
|
||||
"max_tokens": 30,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
|
||||
resp = proxy_client.get(f"/{uuid}/nodes")
|
||||
node = resp.json()["nodes"][0]
|
||||
|
||||
# Lengths must match
|
||||
assert len(node["tokens"]) == len(node["masked_tokens"])
|
||||
assert len(node["tokens"]) == len(node["logprobs"])
|
||||
|
||||
# Decode tokens and check they match full_text
|
||||
decoded = tokenizer_instance.decode(node["tokens"])
|
||||
# The decoded text should be close to (or contain) the full_text
|
||||
# Exact match may differ due to special token handling, but content should match
|
||||
assert len(decoded) > 0
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
class TestRealRender:
|
||||
def test_render_matches_tokenizer(self, proxy_client, tokenizer_instance):
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
]
|
||||
|
||||
resp = proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions/render",
|
||||
json={"messages": messages},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
|
||||
# Compare with direct tokenizer rendering
|
||||
expected_text = tokenizer_instance.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
assert data["prompt_text"] == expected_text
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
class TestRealSequenceExtension:
|
||||
def test_multi_turn_extends(self, proxy_client):
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
# Turn 1
|
||||
resp = proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Say hello"}],
|
||||
"max_tokens": 20,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
turn1_content = resp.json()["choices"][0]["message"]["content"]
|
||||
|
||||
# Turn 2 — extends turn 1
|
||||
resp = proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [
|
||||
{"role": "user", "content": "Say hello"},
|
||||
{"role": "assistant", "content": turn1_content},
|
||||
{"role": "user", "content": "Now say goodbye"},
|
||||
],
|
||||
"max_tokens": 20,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Should have nodes (extension behavior depends on prefix matching)
|
||||
resp = proxy_client.get(f"/{uuid}/nodes")
|
||||
nodes = resp.json()["nodes"]
|
||||
assert len(nodes) >= 1
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
class TestRealConcurrentSessions:
|
||||
def test_sessions_independent(self, proxy_client):
|
||||
"""Multiple sessions should not contaminate each other."""
|
||||
uuids = []
|
||||
for _ in range(3):
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuids.append(resp.json()["uuid"])
|
||||
|
||||
# Complete on each
|
||||
for i, uuid in enumerate(uuids):
|
||||
resp = proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": f"Count to {i+1}"}],
|
||||
"max_tokens": 30,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Each should have exactly 1 node
|
||||
for uuid in uuids:
|
||||
resp = proxy_client.get(f"/{uuid}/nodes")
|
||||
assert len(resp.json()["nodes"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
class TestRealOpenAIClientCompat:
|
||||
def test_openai_client_works(self, proxy_client):
|
||||
"""Verify the standard openai Python client can talk to our proxy."""
|
||||
resp = proxy_client.post("/sessions/create", json={})
|
||||
uuid = resp.json()["uuid"]
|
||||
|
||||
# The TestClient doesn't expose a real port, so we test the
|
||||
# response format is compatible by checking structure manually
|
||||
resp = proxy_client.post(
|
||||
f"/{uuid}/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
"max_tokens": 10,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
data = resp.json()
|
||||
|
||||
# Verify all fields the openai client expects
|
||||
assert "id" in data
|
||||
assert "object" in data
|
||||
assert data["object"] == "chat.completion"
|
||||
assert "created" in data
|
||||
assert "model" in data
|
||||
assert "choices" in data
|
||||
assert isinstance(data["choices"], list)
|
||||
for choice in data["choices"]:
|
||||
assert "index" in choice
|
||||
assert "message" in choice
|
||||
assert "finish_reason" in choice
|
||||
assert "role" in choice["message"]
|
||||
assert "content" in choice["message"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue