"""Tests for get_logprobs wrappers and server-manager routing."""

import logging

import pytest

from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
from atroposlib.envs.server_handling.server_baseline import (
    APIServer,
    APIServerConfig,
    AsyncSemWithAdaptiveWeight,
)
from atroposlib.envs.server_handling.server_manager import ServerManager
from atroposlib.envs.server_handling.vllm_server import (
    resolve_openai_configs as resolve_vllm_configs,
)

_TEST_BASE_URL = "http://localhost:9001/v1"


class _FakeAPIServer(APIServer):
    """Minimal ``APIServer`` whose only working backend call is get_logprobs.

    The chat/completion/token wrappers raise ``NotImplementedError`` so any
    accidental use fails loudly. ``_get_logprobs_wrapper`` records how often it
    was called and with which kwargs, and echoes the prompt back as token ids
    (one id per character, via ``ord``).
    """

    def __init__(self, config: APIServerConfig):
        super().__init__(config=config, reasoning_config=None)
        self.calls = 0  # number of _get_logprobs_wrapper invocations
        self.last_kwargs = None  # kwargs of the most recent invocation

    async def check_server_status_task(self, chat_completion: bool = True):
        # Never probe a real endpoint; just report healthy.
        self.server_healthy = True

    async def _chat_completion_wrapper(self, **kwargs):
        raise NotImplementedError

    async def _completion_wrapper(self, **kwargs):
        raise NotImplementedError

    async def _tokens_and_logprobs_completion_wrapper(self, **kwargs):
        raise NotImplementedError

    async def _get_logprobs_wrapper(self, **kwargs):
        """Return a deterministic logprobs payload derived from ``prompt``."""
        self.calls += 1
        self.last_kwargs = kwargs
        prompt_tokens = [ord(char) for char in kwargs.get("prompt", "")]
        return {
            "prompt_tokens": prompt_tokens,
            "prompt_topk_token_ids": [[token] for token in prompt_tokens],
            "prompt_topk_logprobs": [[-0.1] for _ in prompt_tokens],
        }


class _FakeRoutedServer:
    """Duck-typed stand-in for a server inside ``ServerManager.servers``.

    Exposes ``server_healthy``, ``sem`` and ``eval_sem`` (the attributes the
    routing test manipulates) plus an async ``get_logprobs`` that tags its
    response with the server's name and counts invocations.
    """

    def __init__(
        self, name: str, train_slots: int, eval_slots: int, healthy: bool = True
    ):
        self.name = name
        self.server_healthy = healthy
        self.sem = AsyncSemWithAdaptiveWeight(4)
        self.eval_sem = AsyncSemWithAdaptiveWeight(4)
        # Preset the free-slot counters directly (private attribute) so the
        # test can dictate which server looks "most available" without
        # acquiring anything. NOTE(review): relies on the semaphore keeping
        # its count in ``_value``.
        self.sem._value = train_slots
        self.eval_sem._value = eval_slots
        self.calls = 0

    async def get_logprobs(self, **kwargs):
        self.calls += 1
        return {
            "server": self.name,
            "prompt_tokens": [1],
            "prompt_topk_token_ids": [[1]],
            "prompt_topk_logprobs": [[-0.1]],
        }


@pytest.mark.asyncio
async def test_apiserver_get_logprobs_train_eval_wrappers():
    """Train and eval splits both reach the wrapper but record separate stats."""
    cfg = APIServerConfig(
        model_name="test-model",
        base_url="",
        health_check=False,
    )
    server = _FakeAPIServer(cfg)

    train_out = await server.get_logprobs(prompt="hi", split="train")
    assert train_out["prompt_tokens"] == [ord("h"), ord("i")]
    assert server.calls == 1
    assert server.last_kwargs["model"] == "test-model"
    assert len(server.request_timings) == 1
    assert len(server.attempts_list) == 1
    assert len(server.eval_request_timings) == 0
    assert len(server.eval_attempts_list) == 0

    eval_out = await server.get_logprobs(prompt="ok", split="eval")
    assert eval_out["prompt_tokens"] == [ord("o"), ord("k")]
    assert server.calls == 2
    assert len(server.eval_request_timings) == 1
    assert len(server.eval_attempts_list) == 1
    # The eval request must not leak into the train bookkeeping.
    assert len(server.request_timings) == 1
    assert len(server.attempts_list) == 1


@pytest.mark.asyncio
async def test_server_manager_get_logprobs_routes_to_most_available_server():
    """Routing picks the healthy server with the most free slots for the split."""
    # s3 is unhealthy but has the most free slots for BOTH splits, so each
    # split independently checks that unhealthy servers are skipped.
    s1 = _FakeRoutedServer("s1", train_slots=1, eval_slots=3, healthy=True)
    s2 = _FakeRoutedServer("s2", train_slots=3, eval_slots=1, healthy=True)
    s3 = _FakeRoutedServer("s3", train_slots=4, eval_slots=4, healthy=False)

    # Bypass __init__ (which would build real servers) and inject the fakes.
    manager = ServerManager.__new__(ServerManager)
    manager.servers = [s1, s2, s3]

    out_train = await ServerManager.get_logprobs(manager, prompt="x", split="train")
    assert out_train["server"] == "s2"
    assert (s1.calls, s2.calls, s3.calls) == (0, 1, 0)

    out_eval = await ServerManager.get_logprobs(manager, prompt="x", split="eval")
    assert out_eval["server"] == "s1"
    assert (s1.calls, s2.calls, s3.calls) == (1, 1, 0)


def _assert_wraps_single_config(resolver, server_type: str) -> None:
    """Assert *resolver* turns a lone ``APIServerConfig`` into a 1-item list.

    ``resolver`` is one of the ``resolve_openai_configs`` variants; it is
    called with an empty YAML config and no CLI overrides.
    """
    default_server_config = APIServerConfig(
        model_name="test-model",
        base_url=_TEST_BASE_URL,
        api_key="x",
        server_type=server_type,
    )
    server_configs = resolver(
        default_server_configs=default_server_config,
        openai_config_dict=default_server_config.model_dump(),
        yaml_config={},
        cli_passed_flags={},
        logger=logging.getLogger("test"),
    )
    assert isinstance(server_configs, list)
    assert len(server_configs) == 1
    assert isinstance(server_configs[0], APIServerConfig)
    assert server_configs[0].base_url == _TEST_BASE_URL


def test_resolve_openai_configs_wraps_single_api_server_config_in_list():
    _assert_wraps_single_config(resolve_openai_configs, server_type="openai")


def test_resolve_vllm_configs_wraps_single_api_server_config_in_list():
    _assert_wraps_single_config(resolve_vllm_configs, server_type="vllm")