Merge conflict commit

This commit is contained in:
dmahan93 2026-03-09 23:13:43 -05:00
commit f198c1738e
13 changed files with 579 additions and 14 deletions

View file

@ -1,4 +1,5 @@
import subprocess
import sys
import time
import requests
@ -16,7 +17,7 @@ def launch_api_for_testing(max_wait_for_api: int = 10) -> subprocess.Popen:
# Use subprocess instead of multiprocessing to avoid inheriting pytest args
api_proc = subprocess.Popen(
[
"python",
sys.executable,
"-m",
"atroposlib.cli.run_api",
"--host",

View file

@ -5,6 +5,7 @@ import json
import os
import signal
import subprocess
import sys
import time
import pytest
@ -27,7 +28,7 @@ def wait_for_api_server(max_wait=10):
def api_server():
proc = subprocess.Popen(
[
"python",
sys.executable,
"-m",
"atroposlib.cli.run_api",
"--host",

View file

@ -5,6 +5,7 @@ Tests for API server message handling, particularly for SFT (Supervised Fine-Tun
import os
import signal
import subprocess
import sys
import time
import pytest
@ -30,7 +31,7 @@ def api_server():
# Start the API server as a subprocess
proc = subprocess.Popen(
[
"python",
sys.executable,
"-m",
"atroposlib.cli.run_api",
"--host",

View file

@ -268,6 +268,91 @@ async def test_bos_token_handling(mock_server):
assert mock_server.tokenizer.bos_token_id not in node.tokens[1:]
@pytest.mark.asyncio
async def test_get_logprobs_normalized_schema(mock_server):
"""ManagedServer.get_logprobs returns normalized prompt schema."""
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
prompt = "Hello"
prompt_tokens = mock_server.tokenizer.encode(prompt)
prompt_topk_token_ids = [[t, t + 1] for t in prompt_tokens]
prompt_topk_logprobs = [[-0.1, -0.2] for _ in prompt_tokens]
async def _mock_get_logprobs(**kwargs):
assert kwargs.get("prompt") == prompt
return {
"prompt_tokens": prompt_tokens,
"prompt_topk_token_ids": prompt_topk_token_ids,
"prompt_topk_logprobs": prompt_topk_logprobs,
}
mock_server.get_logprobs = _mock_get_logprobs
payload = await managed.get_logprobs(prompt=prompt, n=1)
assert payload["prompt_tokens"] == prompt_tokens
assert payload["prompt_topk_token_ids"] == prompt_topk_token_ids
assert payload["prompt_topk_logprobs"] == prompt_topk_logprobs
@pytest.mark.asyncio
async def test_get_logprobs_messages_passthrough(mock_server):
"""ManagedServer.get_logprobs converts messages and passes prompt through."""
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
messages = [{"role": "user", "content": "Hello"}]
expected_prompt = managed._convert_messages_to_prompt(messages)
prompt_tokens = mock_server.tokenizer.encode(expected_prompt)
async def _mock_get_logprobs(**kwargs):
assert kwargs.get("prompt") == expected_prompt
return {
"prompt_tokens": prompt_tokens,
"prompt_topk_token_ids": [[t] for t in prompt_tokens],
"prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens],
}
mock_server.get_logprobs = _mock_get_logprobs
payload = await managed.get_logprobs(messages=messages, top_k=1)
assert payload["prompt_tokens"] == prompt_tokens
assert len(payload["prompt_topk_token_ids"]) == len(prompt_tokens)
assert len(payload["prompt_topk_logprobs"]) == len(prompt_tokens)
@pytest.mark.asyncio
async def test_get_logprobs_input_ids_only_passthrough(mock_server):
"""ManagedServer.get_logprobs supports input_ids-only without requiring prompt."""
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
input_ids = [10, 20, 30]
async def _mock_get_logprobs(**kwargs):
assert "input_ids" in kwargs
assert kwargs["input_ids"] == input_ids
assert kwargs.get("prompt") is None
return {
"prompt_tokens": input_ids,
"prompt_topk_token_ids": [[t] for t in input_ids],
"prompt_topk_logprobs": [[-0.1] for _ in input_ids],
}
mock_server.get_logprobs = _mock_get_logprobs
payload = await managed.get_logprobs(input_ids=input_ids, top_k=1)
assert payload["prompt_tokens"] == input_ids
assert payload["prompt_topk_token_ids"] == [[10], [20], [30]]
assert payload["prompt_topk_logprobs"] == [[-0.1], [-0.1], [-0.1]]
@pytest.mark.asyncio
async def test_get_logprobs_strict_mode_requires_backend_impl(mock_server):
"""ManagedServer.get_logprobs requires backend get_logprobs in strict mode."""
managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
prompt = "Hello"
with pytest.raises(NotImplementedError, match="does not implement get_logprobs"):
await managed.get_logprobs(prompt=prompt, n=1)
@pytest.mark.asyncio
async def test_reset_clears_sequences(mock_server):
"""Test that reset() clears all tracked sequences."""

View file

@ -0,0 +1,105 @@
"""Tests for get_logprobs wrappers and server-manager routing."""
import pytest
from atroposlib.envs.server_handling.server_baseline import (
APIServer,
APIServerConfig,
AsyncSemWithAdaptiveWeight,
)
from atroposlib.envs.server_handling.server_manager import ServerManager
class _FakeAPIServer(APIServer):
def __init__(self, config: APIServerConfig):
super().__init__(config=config, reasoning_config=None)
self.calls = 0
self.last_kwargs = None
async def check_server_status_task(self, chat_completion: bool = True):
self.server_healthy = True
async def _chat_completion_wrapper(self, **kwargs):
raise NotImplementedError
async def _completion_wrapper(self, **kwargs):
raise NotImplementedError
async def _tokens_and_logprobs_completion_wrapper(self, **kwargs):
raise NotImplementedError
async def _get_logprobs_wrapper(self, **kwargs):
self.calls += 1
self.last_kwargs = kwargs
prompt = kwargs.get("prompt", "")
prompt_tokens = [ord(c) for c in prompt]
return {
"prompt_tokens": prompt_tokens,
"prompt_topk_token_ids": [[t] for t in prompt_tokens],
"prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens],
}
class _FakeRoutedServer:
def __init__(
self, name: str, train_slots: int, eval_slots: int, healthy: bool = True
):
self.name = name
self.server_healthy = healthy
self.sem = AsyncSemWithAdaptiveWeight(4)
self.eval_sem = AsyncSemWithAdaptiveWeight(4)
self.sem._value = train_slots
self.eval_sem._value = eval_slots
self.calls = 0
async def get_logprobs(self, **kwargs):
self.calls += 1
return {
"server": self.name,
"prompt_tokens": [1],
"prompt_topk_token_ids": [[1]],
"prompt_topk_logprobs": [[-0.1]],
}
@pytest.mark.asyncio
async def test_apiserver_get_logprobs_train_eval_wrappers():
cfg = APIServerConfig(
model_name="test-model",
base_url="",
health_check=False,
)
server = _FakeAPIServer(cfg)
train_out = await server.get_logprobs(prompt="hi", split="train")
assert train_out["prompt_tokens"] == [ord("h"), ord("i")]
assert server.calls == 1
assert server.last_kwargs["model"] == "test-model"
assert len(server.request_timings) == 1
assert len(server.attempts_list) == 1
assert len(server.eval_request_timings) == 0
assert len(server.eval_attempts_list) == 0
eval_out = await server.get_logprobs(prompt="ok", split="eval")
assert eval_out["prompt_tokens"] == [ord("o"), ord("k")]
assert server.calls == 2
assert len(server.eval_request_timings) == 1
assert len(server.eval_attempts_list) == 1
@pytest.mark.asyncio
async def test_server_manager_get_logprobs_routes_to_most_available_server():
s1 = _FakeRoutedServer("s1", train_slots=1, eval_slots=4, healthy=True)
s2 = _FakeRoutedServer("s2", train_slots=3, eval_slots=1, healthy=True)
s3 = _FakeRoutedServer("s3", train_slots=4, eval_slots=4, healthy=False)
manager = ServerManager.__new__(ServerManager)
manager.servers = [s1, s2, s3]
out_train = await ServerManager.get_logprobs(manager, prompt="x", split="train")
assert out_train["server"] == "s2"
assert s2.calls == 1
out_eval = await ServerManager.get_logprobs(manager, prompt="x", split="eval")
assert out_eval["server"] == "s1"
assert s1.calls == 1

View file

@ -0,0 +1,68 @@
"""Optional integration test for example_trainer.vllm_api_server /generate."""
from importlib import import_module
import pytest
from fastapi.testclient import TestClient
@pytest.mark.asyncio
async def test_vllm_api_server_generate_endpoint_optional():
"""
Validate /generate contract on the custom vLLM API server.
This test only runs when vLLM is installed.
"""
pytest.importorskip("vllm")
module = import_module("example_trainer.vllm_api_server")
class _FakeLogprob:
def __init__(self, value: float):
self.logprob = value
class _FakeOutput:
def __init__(self):
self.text = " world"
self.finish_reason = "stop"
self.logprobs = [{11: _FakeLogprob(-0.3)}]
self.token_ids = [11]
class _FakeRequestOutput:
def __init__(self):
self.prompt = "hello"
self.prompt_token_ids = [1, 2]
self.outputs = [_FakeOutput()]
class _FakeEngine:
tokenizer = type("Tok", (), {"decode": staticmethod(lambda _: "hello")})()
def generate(self, *_args, **_kwargs):
async def _gen():
yield _FakeRequestOutput()
return _gen()
old_engine = module.engine
module.engine = _FakeEngine()
try:
client = TestClient(module.app)
resp = client.post(
"/generate",
json={
"prompt": "hello",
"max_tokens": 1,
"temperature": 0.0,
"logprobs": 1,
},
)
assert resp.status_code == 200
body = resp.json()
assert "text" in body and body["text"] == [" world"]
assert body["prompt"] == "hello"
assert body["finish_reasons"] == ["stop"]
assert "logprobs" in body
assert "token_ids" in body
assert "prompt_token_ids" in body
finally:
module.engine = old_engine