mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
adding tests
This commit is contained in:
parent
6c564799bc
commit
1b8ff075c4
3 changed files with 84 additions and 3 deletions
|
|
@ -1,7 +1,13 @@
|
|||
"""Tests for get_logprobs wrappers and server-manager routing."""
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
|
||||
from atroposlib.envs.server_handling.vllm_server import (
|
||||
resolve_openai_configs as resolve_vllm_configs,
|
||||
)
|
||||
from atroposlib.envs.server_handling.server_baseline import (
|
||||
APIServer,
|
||||
APIServerConfig,
|
||||
|
|
@ -103,3 +109,49 @@ async def test_server_manager_get_logprobs_routes_to_most_available_server():
|
|||
out_eval = await ServerManager.get_logprobs(manager, prompt="x", split="eval")
|
||||
assert out_eval["server"] == "s1"
|
||||
assert s1.calls == 1
|
||||
|
||||
|
||||
def test_resolve_openai_configs_wraps_single_api_server_config_in_list():
|
||||
default_server_config = APIServerConfig(
|
||||
model_name="test-model",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
server_type="openai",
|
||||
)
|
||||
merged_config = default_server_config.model_dump()
|
||||
|
||||
server_configs = resolve_openai_configs(
|
||||
default_server_configs=default_server_config,
|
||||
openai_config_dict=merged_config,
|
||||
yaml_config={},
|
||||
cli_passed_flags={},
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
assert isinstance(server_configs, list)
|
||||
assert len(server_configs) == 1
|
||||
assert isinstance(server_configs[0], APIServerConfig)
|
||||
assert server_configs[0].base_url == "http://localhost:9001/v1"
|
||||
|
||||
|
||||
def test_resolve_vllm_configs_wraps_single_api_server_config_in_list():
|
||||
default_server_config = APIServerConfig(
|
||||
model_name="test-model",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
server_type="vllm",
|
||||
)
|
||||
merged_config = default_server_config.model_dump()
|
||||
|
||||
server_configs = resolve_vllm_configs(
|
||||
default_server_configs=default_server_config,
|
||||
openai_config_dict=merged_config,
|
||||
yaml_config={},
|
||||
cli_passed_flags={},
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
assert isinstance(server_configs, list)
|
||||
assert len(server_configs) == 1
|
||||
assert isinstance(server_configs[0], APIServerConfig)
|
||||
assert server_configs[0].base_url == "http://localhost:9001/v1"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue