mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
363 lines
11 KiB
Python
363 lines
11 KiB
Python
"""Integration tests for ManagedServer OpenAI proxy against real vLLM backend.
|
|
|
|
Spins up example_trainer/vllm_api_server.py with Qwen3-4B as a subprocess.
|
|
Requires GPU — skipped by default. Run with:
|
|
|
|
pytest --run-gpu atroposlib/tests/test_managed_server_proxy_integration.py -v -s
|
|
|
|
"""
|
|
|
|
import os
|
|
import signal
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
|
|
import pytest
|
|
import requests
|
|
from transformers import AutoTokenizer
|
|
|
|
from atroposlib.envs.server_handling.managed_server_proxy import create_app
|
|
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
|
from atroposlib.envs.server_handling.server_manager import ServerManager
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Constants
|
|
# ---------------------------------------------------------------------------
|
|
|
|
VLLM_PORT = 8123
|
|
VLLM_MODEL = "Qwen/Qwen3-4B-Thinking-2507"
|
|
PROXY_MODEL = VLLM_MODEL
|
|
VLLM_BASE_URL = f"http://localhost:{VLLM_PORT}/v1"
|
|
REPO_ROOT = os.path.join(os.path.dirname(__file__), "..", "..")
|
|
VLLM_SCRIPT = os.path.join(REPO_ROOT, "example_trainer", "vllm_api_server.py")
|
|
VENV_PYTHON = sys.executable # use the current interpreter
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture(scope="module")
|
|
def vllm_backend():
|
|
"""Start vLLM api server as a subprocess. Module-scoped so it's shared."""
|
|
cmd = [
|
|
VENV_PYTHON,
|
|
VLLM_SCRIPT,
|
|
"--model",
|
|
VLLM_MODEL,
|
|
"--port",
|
|
str(VLLM_PORT),
|
|
"--max-model-len",
|
|
"32000",
|
|
"--max-num-seqs",
|
|
"32",
|
|
]
|
|
|
|
proc = subprocess.Popen(
|
|
cmd,
|
|
stdout=subprocess.PIPE,
|
|
stderr=subprocess.STDOUT,
|
|
cwd=REPO_ROOT,
|
|
)
|
|
|
|
# Wait for health
|
|
deadline = time.time() + 180 # 3 min for model loading
|
|
healthy = False
|
|
while time.time() < deadline:
|
|
try:
|
|
resp = requests.get(f"http://localhost:{VLLM_PORT}/health", timeout=2)
|
|
if resp.status_code == 200:
|
|
healthy = True
|
|
break
|
|
except (requests.ConnectionError, requests.Timeout):
|
|
pass
|
|
|
|
if proc.poll() is not None:
|
|
stdout = proc.stdout.read().decode() if proc.stdout else ""
|
|
pytest.fail(
|
|
f"vLLM server exited early (code={proc.returncode}):\n{stdout[-3000:]}"
|
|
)
|
|
|
|
time.sleep(3)
|
|
|
|
if not healthy:
|
|
proc.kill()
|
|
stdout = proc.stdout.read().decode() if proc.stdout else ""
|
|
pytest.fail(f"vLLM server didn't become healthy within 180s:\n{stdout[-3000:]}")
|
|
|
|
yield proc
|
|
|
|
proc.send_signal(signal.SIGTERM)
|
|
try:
|
|
proc.wait(timeout=15)
|
|
except subprocess.TimeoutExpired:
|
|
proc.kill()
|
|
proc.wait()
|
|
|
|
|
|
@pytest.fixture(scope="module")
|
|
def tokenizer_instance():
|
|
return AutoTokenizer.from_pretrained(VLLM_MODEL)
|
|
|
|
|
|
@pytest.fixture(scope="module")
|
|
def proxy_client(vllm_backend, tokenizer_instance):
|
|
"""Create a test client for the proxy backed by the real vLLM server."""
|
|
from fastapi.testclient import TestClient
|
|
|
|
config = APIServerConfig(
|
|
model_name=VLLM_MODEL,
|
|
base_url=VLLM_BASE_URL,
|
|
api_key="",
|
|
server_type="vllm",
|
|
health_check=False,
|
|
)
|
|
server_manager = ServerManager(configs=[config])
|
|
|
|
app = create_app(
|
|
server_manager=server_manager,
|
|
tokenizer=tokenizer_instance,
|
|
model_name=VLLM_MODEL,
|
|
)
|
|
return TestClient(app)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.gpu
|
|
class TestRealChatCompletion:
|
|
def test_basic_completion(self, proxy_client):
|
|
resp = proxy_client.post("/sessions/create", json={})
|
|
uuid = resp.json()["uuid"]
|
|
|
|
resp = proxy_client.post(
|
|
f"/{uuid}/v1/chat/completions",
|
|
json={
|
|
"messages": [{"role": "user", "content": "Say hello in one word."}],
|
|
"max_tokens": 30,
|
|
"temperature": 0.0,
|
|
},
|
|
)
|
|
assert resp.status_code == 200
|
|
data = resp.json()
|
|
assert len(data["choices"]) == 1
|
|
content = data["choices"][0]["message"]["content"]
|
|
assert content is not None
|
|
assert len(content) > 0
|
|
assert data["model"] == VLLM_MODEL
|
|
|
|
def test_n_completions(self, proxy_client):
|
|
resp = proxy_client.post("/sessions/create", json={})
|
|
uuid = resp.json()["uuid"]
|
|
|
|
resp = proxy_client.post(
|
|
f"/{uuid}/v1/chat/completions",
|
|
json={
|
|
"messages": [{"role": "user", "content": "Pick a random number"}],
|
|
"max_tokens": 20,
|
|
"temperature": 1.0,
|
|
"n": 4,
|
|
},
|
|
)
|
|
assert resp.status_code == 200
|
|
data = resp.json()
|
|
assert len(data["choices"]) == 4
|
|
|
|
# Check nodes
|
|
resp = proxy_client.get(f"/{uuid}/nodes")
|
|
assert len(resp.json()["nodes"]) == 4
|
|
|
|
|
|
@pytest.mark.gpu
|
|
class TestRealLogprobs:
|
|
def test_logprobs_are_valid(self, proxy_client):
|
|
resp = proxy_client.post("/sessions/create", json={})
|
|
uuid = resp.json()["uuid"]
|
|
|
|
proxy_client.post(
|
|
f"/{uuid}/v1/chat/completions",
|
|
json={
|
|
"messages": [{"role": "user", "content": "Hi"}],
|
|
"max_tokens": 20,
|
|
"temperature": 0.0,
|
|
},
|
|
)
|
|
|
|
resp = proxy_client.get(f"/{uuid}/nodes")
|
|
nodes = resp.json()["nodes"]
|
|
assert len(nodes) == 1
|
|
|
|
node = nodes[0]
|
|
# Find where completion starts (logprobs transition from 1.0 to negative)
|
|
prompt_end = 0
|
|
for i, lp in enumerate(node["logprobs"]):
|
|
if lp != 1.0:
|
|
prompt_end = i
|
|
break
|
|
|
|
# Prompt logprobs should be 1.0
|
|
assert all(lp == 1.0 for lp in node["logprobs"][:prompt_end])
|
|
# Completion logprobs should be negative
|
|
completion_lps = [lp for lp in node["logprobs"][prompt_end:] if lp != 1.0]
|
|
assert len(completion_lps) > 0
|
|
assert all(lp < 0 for lp in completion_lps)
|
|
|
|
|
|
@pytest.mark.gpu
|
|
class TestRealTokenAlignment:
|
|
def test_tokens_decode_to_full_text(self, proxy_client, tokenizer_instance):
|
|
resp = proxy_client.post("/sessions/create", json={})
|
|
uuid = resp.json()["uuid"]
|
|
|
|
proxy_client.post(
|
|
f"/{uuid}/v1/chat/completions",
|
|
json={
|
|
"messages": [{"role": "user", "content": "Say exactly: test123"}],
|
|
"max_tokens": 30,
|
|
"temperature": 0.0,
|
|
},
|
|
)
|
|
|
|
resp = proxy_client.get(f"/{uuid}/nodes")
|
|
node = resp.json()["nodes"][0]
|
|
|
|
# Lengths must match
|
|
assert len(node["tokens"]) == len(node["masked_tokens"])
|
|
assert len(node["tokens"]) == len(node["logprobs"])
|
|
|
|
# Decode tokens and check they match full_text
|
|
decoded = tokenizer_instance.decode(node["tokens"])
|
|
# The decoded text should be close to (or contain) the full_text
|
|
# Exact match may differ due to special token handling, but content should match
|
|
assert len(decoded) > 0
|
|
|
|
|
|
@pytest.mark.gpu
|
|
class TestRealRender:
|
|
def test_render_matches_tokenizer(self, proxy_client, tokenizer_instance):
|
|
resp = proxy_client.post("/sessions/create", json={})
|
|
uuid = resp.json()["uuid"]
|
|
|
|
messages = [
|
|
{"role": "system", "content": "You are helpful."},
|
|
{"role": "user", "content": "Hello!"},
|
|
]
|
|
|
|
resp = proxy_client.post(
|
|
f"/{uuid}/v1/chat/completions/render",
|
|
json={"messages": messages},
|
|
)
|
|
assert resp.status_code == 200
|
|
data = resp.json()
|
|
|
|
# Compare with direct tokenizer rendering
|
|
expected_text = tokenizer_instance.apply_chat_template(
|
|
messages, tokenize=False, add_generation_prompt=True
|
|
)
|
|
assert data["prompt_text"] == expected_text
|
|
|
|
|
|
@pytest.mark.gpu
|
|
class TestRealSequenceExtension:
|
|
def test_multi_turn_extends(self, proxy_client):
|
|
resp = proxy_client.post("/sessions/create", json={})
|
|
uuid = resp.json()["uuid"]
|
|
|
|
# Turn 1
|
|
resp = proxy_client.post(
|
|
f"/{uuid}/v1/chat/completions",
|
|
json={
|
|
"messages": [{"role": "user", "content": "Say hello"}],
|
|
"max_tokens": 20,
|
|
"temperature": 0.0,
|
|
},
|
|
)
|
|
assert resp.status_code == 200
|
|
turn1_content = resp.json()["choices"][0]["message"]["content"]
|
|
|
|
# Turn 2 — extends turn 1
|
|
resp = proxy_client.post(
|
|
f"/{uuid}/v1/chat/completions",
|
|
json={
|
|
"messages": [
|
|
{"role": "user", "content": "Say hello"},
|
|
{"role": "assistant", "content": turn1_content},
|
|
{"role": "user", "content": "Now say goodbye"},
|
|
],
|
|
"max_tokens": 20,
|
|
"temperature": 0.0,
|
|
},
|
|
)
|
|
assert resp.status_code == 200
|
|
|
|
# Should have nodes (extension behavior depends on prefix matching)
|
|
resp = proxy_client.get(f"/{uuid}/nodes")
|
|
nodes = resp.json()["nodes"]
|
|
assert len(nodes) >= 1
|
|
|
|
|
|
@pytest.mark.gpu
|
|
class TestRealConcurrentSessions:
|
|
def test_sessions_independent(self, proxy_client):
|
|
"""Multiple sessions should not contaminate each other."""
|
|
uuids = []
|
|
for _ in range(3):
|
|
resp = proxy_client.post("/sessions/create", json={})
|
|
uuids.append(resp.json()["uuid"])
|
|
|
|
# Complete on each
|
|
for i, uuid in enumerate(uuids):
|
|
resp = proxy_client.post(
|
|
f"/{uuid}/v1/chat/completions",
|
|
json={
|
|
"messages": [{"role": "user", "content": f"Count to {i+1}"}],
|
|
"max_tokens": 30,
|
|
"temperature": 0.0,
|
|
},
|
|
)
|
|
assert resp.status_code == 200
|
|
|
|
# Each should have exactly 1 node
|
|
for uuid in uuids:
|
|
resp = proxy_client.get(f"/{uuid}/nodes")
|
|
assert len(resp.json()["nodes"]) == 1
|
|
|
|
|
|
@pytest.mark.gpu
|
|
class TestRealOpenAIClientCompat:
|
|
def test_openai_client_works(self, proxy_client):
|
|
"""Verify the standard openai Python client can talk to our proxy."""
|
|
resp = proxy_client.post("/sessions/create", json={})
|
|
uuid = resp.json()["uuid"]
|
|
|
|
# The TestClient doesn't expose a real port, so we test the
|
|
# response format is compatible by checking structure manually
|
|
resp = proxy_client.post(
|
|
f"/{uuid}/v1/chat/completions",
|
|
json={
|
|
"messages": [{"role": "user", "content": "Hi"}],
|
|
"max_tokens": 10,
|
|
"temperature": 0.0,
|
|
},
|
|
)
|
|
data = resp.json()
|
|
|
|
# Verify all fields the openai client expects
|
|
assert "id" in data
|
|
assert "object" in data
|
|
assert data["object"] == "chat.completion"
|
|
assert "created" in data
|
|
assert "model" in data
|
|
assert "choices" in data
|
|
assert isinstance(data["choices"], list)
|
|
for choice in data["choices"]:
|
|
assert "index" in choice
|
|
assert "message" in choice
|
|
assert "finish_reason" in choice
|
|
assert "role" in choice["message"]
|
|
assert "content" in choice["message"]
|