Add tests

This commit is contained in:
Jai Suphavadeeprasit 2026-03-13 17:23:40 -04:00
parent 6c564799bc
commit 1b8ff075c4
3 changed files with 84 additions and 3 deletions

View file

@ -401,6 +401,19 @@ def resolve_openai_configs(
raise FailedExecutionException(
f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}"
) from e
elif isinstance(default_server_configs, APIServerConfig):
# Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline
logger.info(
"Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
)
try:
final_openai_config = APIServerConfig(**openai_config_dict)
except Exception as e:
raise FailedExecutionException(
f"Error creating final OpenAI configuration from merged settings: {e}\n"
f"Merged Dict: {openai_config_dict}"
) from e
server_configs = [final_openai_config]
elif isinstance(default_server_configs, ServerBaseline):
logger.info("Using ServerBaseline configuration.")
server_configs = default_server_configs
@ -419,9 +432,7 @@ def resolve_openai_configs(
f"Merged Dict: {openai_config_dict}"
) from e
if isinstance(default_server_configs, APIServerConfig):
server_configs = [final_openai_config]
elif isinstance(default_server_configs, list):
if isinstance(default_server_configs, list):
server_configs = [final_openai_config]
else:
logger.warning(

View file

@ -1,7 +1,13 @@
"""Tests for get_logprobs wrappers and server-manager routing."""
import logging
import pytest
from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
from atroposlib.envs.server_handling.vllm_server import (
resolve_openai_configs as resolve_vllm_configs,
)
from atroposlib.envs.server_handling.server_baseline import (
APIServer,
APIServerConfig,
@ -103,3 +109,49 @@ async def test_server_manager_get_logprobs_routes_to_most_available_server():
out_eval = await ServerManager.get_logprobs(manager, prompt="x", split="eval")
assert out_eval["server"] == "s1"
assert s1.calls == 1
def test_resolve_openai_configs_wraps_single_api_server_config_in_list():
    """A lone APIServerConfig passed as the default must be returned wrapped in a list.

    Guards the isinstance-ordering fix: APIServerConfig subclasses ServerBaseline,
    so the APIServerConfig branch has to be checked first.
    """
    single_config = APIServerConfig(
        model_name="test-model",
        base_url="http://localhost:9001/v1",
        api_key="x",
        server_type="openai",
    )
    resolved = resolve_openai_configs(
        default_server_configs=single_config,
        # Merged settings are just the config itself serialized back to a dict.
        openai_config_dict=single_config.model_dump(),
        yaml_config={},
        cli_passed_flags={},
        logger=logging.getLogger("test"),
    )
    # Result is a single-element list holding an APIServerConfig, fields intact.
    assert isinstance(resolved, list)
    assert len(resolved) == 1
    assert isinstance(resolved[0], APIServerConfig)
    assert resolved[0].base_url == "http://localhost:9001/v1"
def test_resolve_vllm_configs_wraps_single_api_server_config_in_list():
    """The vLLM resolver must also wrap a single APIServerConfig default in a list.

    Mirrors the OpenAI-resolver test to keep both server backends consistent.
    """
    single_config = APIServerConfig(
        model_name="test-model",
        base_url="http://localhost:9001/v1",
        api_key="x",
        server_type="vllm",
    )
    resolved = resolve_vllm_configs(
        default_server_configs=single_config,
        # Merged settings are the config serialized back to a plain dict.
        openai_config_dict=single_config.model_dump(),
        yaml_config={},
        cli_passed_flags={},
        logger=logging.getLogger("test"),
    )
    # Expect a one-element list of APIServerConfig with the original base_url.
    assert isinstance(resolved, list)
    assert len(resolved) == 1
    assert isinstance(resolved[0], APIServerConfig)
    assert resolved[0].base_url == "http://localhost:9001/v1"

View file

@ -87,6 +87,24 @@ async def test_attach_teacher_distillation_negative_topk_skips_fetch():
assert out["distill_logprobs"] is None
@pytest.mark.asyncio
async def test_attach_teacher_distillation_zero_topk_passthrough():
    """teacher_top_k == 0 still performs a teacher fetch and fills distill fields.

    Bypasses __init__ via object.__new__ so no real env/server setup runs;
    only the attributes _attach_teacher_distillation reads are populated.
    """
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)
    env.teacher_server = _FakeTeacherServer()
    group = dict(
        tokens=[[1, 2, 3]],
        group_overrides=None,
        masks=[[-100, 2, 3]],
        scores=[1.0],
    )
    result = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
    # Exactly one teacher call, and both distillation fields are populated.
    assert env.teacher_server.calls == 1
    assert result["distill_token_ids"] is not None
    assert result["distill_logprobs"] is not None
@pytest.mark.asyncio
async def test_attach_teacher_distillation_group_override_can_skip_fetch():
env = object.__new__(_ConcreteTeacherEnv)