mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
testing
This commit is contained in:
parent
efc90bfb1b
commit
1a3d9ee664
1 changed files with 68 additions and 0 deletions
68
atroposlib/tests/test_vllm_api_server_generate.py
Normal file
68
atroposlib/tests/test_vllm_api_server_generate.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
"""Optional integration test for example_trainer.vllm_api_server /generate."""
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_api_server_generate_endpoint_optional():
|
||||
"""
|
||||
Validate /generate contract on the custom vLLM API server.
|
||||
|
||||
This test only runs when vLLM is installed.
|
||||
"""
|
||||
pytest.importorskip("vllm")
|
||||
|
||||
module = import_module("example_trainer.vllm_api_server")
|
||||
|
||||
class _FakeLogprob:
|
||||
def __init__(self, value: float):
|
||||
self.logprob = value
|
||||
|
||||
class _FakeOutput:
|
||||
def __init__(self):
|
||||
self.text = " world"
|
||||
self.finish_reason = "stop"
|
||||
self.logprobs = [{11: _FakeLogprob(-0.3)}]
|
||||
self.token_ids = [11]
|
||||
|
||||
class _FakeRequestOutput:
|
||||
def __init__(self):
|
||||
self.prompt = "hello"
|
||||
self.prompt_token_ids = [1, 2]
|
||||
self.outputs = [_FakeOutput()]
|
||||
|
||||
class _FakeEngine:
|
||||
tokenizer = type("Tok", (), {"decode": staticmethod(lambda _: "hello")})()
|
||||
|
||||
def generate(self, *_args, **_kwargs):
|
||||
async def _gen():
|
||||
yield _FakeRequestOutput()
|
||||
|
||||
return _gen()
|
||||
|
||||
old_engine = module.engine
|
||||
module.engine = _FakeEngine()
|
||||
try:
|
||||
client = TestClient(module.app)
|
||||
resp = client.post(
|
||||
"/generate",
|
||||
json={
|
||||
"prompt": "hello",
|
||||
"max_tokens": 1,
|
||||
"temperature": 0.0,
|
||||
"logprobs": 1,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "text" in body and body["text"] == [" world"]
|
||||
assert body["prompt"] == "hello"
|
||||
assert body["finish_reasons"] == ["stop"]
|
||||
assert "logprobs" in body
|
||||
assert "token_ids" in body
|
||||
assert "prompt_token_ids" in body
|
||||
finally:
|
||||
module.engine = old_engine
|
||||
Loading…
Add table
Add a link
Reference in a new issue