adding tests

This commit is contained in:
Jai Suphavadeeprasit 2026-03-13 17:23:40 -04:00
parent 6c564799bc
commit 1b8ff075c4
3 changed files with 84 additions and 3 deletions

View file

@ -1,7 +1,13 @@
"""Tests for get_logprobs wrappers and server-manager routing."""
import logging
import pytest
from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
from atroposlib.envs.server_handling.vllm_server import (
resolve_openai_configs as resolve_vllm_configs,
)
from atroposlib.envs.server_handling.server_baseline import (
APIServer,
APIServerConfig,
@ -103,3 +109,49 @@ async def test_server_manager_get_logprobs_routes_to_most_available_server():
out_eval = await ServerManager.get_logprobs(manager, prompt="x", split="eval")
assert out_eval["server"] == "s1"
assert s1.calls == 1
def test_resolve_openai_configs_wraps_single_api_server_config_in_list():
default_server_config = APIServerConfig(
model_name="test-model",
base_url="http://localhost:9001/v1",
api_key="x",
server_type="openai",
)
merged_config = default_server_config.model_dump()
server_configs = resolve_openai_configs(
default_server_configs=default_server_config,
openai_config_dict=merged_config,
yaml_config={},
cli_passed_flags={},
logger=logging.getLogger("test"),
)
assert isinstance(server_configs, list)
assert len(server_configs) == 1
assert isinstance(server_configs[0], APIServerConfig)
assert server_configs[0].base_url == "http://localhost:9001/v1"
def test_resolve_vllm_configs_wraps_single_api_server_config_in_list():
default_server_config = APIServerConfig(
model_name="test-model",
base_url="http://localhost:9001/v1",
api_key="x",
server_type="vllm",
)
merged_config = default_server_config.model_dump()
server_configs = resolve_vllm_configs(
default_server_configs=default_server_config,
openai_config_dict=merged_config,
yaml_config={},
cli_passed_flags={},
logger=logging.getLogger("test"),
)
assert isinstance(server_configs, list)
assert len(server_configs) == 1
assert isinstance(server_configs[0], APIServerConfig)
assert server_configs[0].base_url == "http://localhost:9001/v1"