mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
adding tests
This commit is contained in:
parent
6c564799bc
commit
1b8ff075c4
3 changed files with 84 additions and 3 deletions
|
|
@ -401,6 +401,19 @@ def resolve_openai_configs(
|
||||||
raise FailedExecutionException(
|
raise FailedExecutionException(
|
||||||
f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}"
|
f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}"
|
||||||
) from e
|
) from e
|
||||||
|
elif isinstance(default_server_configs, APIServerConfig):
|
||||||
|
# Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline
|
||||||
|
logger.info(
|
||||||
|
"Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
final_openai_config = APIServerConfig(**openai_config_dict)
|
||||||
|
except Exception as e:
|
||||||
|
raise FailedExecutionException(
|
||||||
|
f"Error creating final OpenAI configuration from merged settings: {e}\n"
|
||||||
|
f"Merged Dict: {openai_config_dict}"
|
||||||
|
) from e
|
||||||
|
server_configs = [final_openai_config]
|
||||||
elif isinstance(default_server_configs, ServerBaseline):
|
elif isinstance(default_server_configs, ServerBaseline):
|
||||||
logger.info("Using ServerBaseline configuration.")
|
logger.info("Using ServerBaseline configuration.")
|
||||||
server_configs = default_server_configs
|
server_configs = default_server_configs
|
||||||
|
|
@ -419,9 +432,7 @@ def resolve_openai_configs(
|
||||||
f"Merged Dict: {openai_config_dict}"
|
f"Merged Dict: {openai_config_dict}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
if isinstance(default_server_configs, APIServerConfig):
|
if isinstance(default_server_configs, list):
|
||||||
server_configs = [final_openai_config]
|
|
||||||
elif isinstance(default_server_configs, list):
|
|
||||||
server_configs = [final_openai_config]
|
server_configs = [final_openai_config]
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,13 @@
|
||||||
"""Tests for get_logprobs wrappers and server-manager routing."""
|
"""Tests for get_logprobs wrappers and server-manager routing."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
|
||||||
|
from atroposlib.envs.server_handling.vllm_server import (
|
||||||
|
resolve_openai_configs as resolve_vllm_configs,
|
||||||
|
)
|
||||||
from atroposlib.envs.server_handling.server_baseline import (
|
from atroposlib.envs.server_handling.server_baseline import (
|
||||||
APIServer,
|
APIServer,
|
||||||
APIServerConfig,
|
APIServerConfig,
|
||||||
|
|
@ -103,3 +109,49 @@ async def test_server_manager_get_logprobs_routes_to_most_available_server():
|
||||||
out_eval = await ServerManager.get_logprobs(manager, prompt="x", split="eval")
|
out_eval = await ServerManager.get_logprobs(manager, prompt="x", split="eval")
|
||||||
assert out_eval["server"] == "s1"
|
assert out_eval["server"] == "s1"
|
||||||
assert s1.calls == 1
|
assert s1.calls == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_openai_configs_wraps_single_api_server_config_in_list():
|
||||||
|
default_server_config = APIServerConfig(
|
||||||
|
model_name="test-model",
|
||||||
|
base_url="http://localhost:9001/v1",
|
||||||
|
api_key="x",
|
||||||
|
server_type="openai",
|
||||||
|
)
|
||||||
|
merged_config = default_server_config.model_dump()
|
||||||
|
|
||||||
|
server_configs = resolve_openai_configs(
|
||||||
|
default_server_configs=default_server_config,
|
||||||
|
openai_config_dict=merged_config,
|
||||||
|
yaml_config={},
|
||||||
|
cli_passed_flags={},
|
||||||
|
logger=logging.getLogger("test"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(server_configs, list)
|
||||||
|
assert len(server_configs) == 1
|
||||||
|
assert isinstance(server_configs[0], APIServerConfig)
|
||||||
|
assert server_configs[0].base_url == "http://localhost:9001/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_vllm_configs_wraps_single_api_server_config_in_list():
|
||||||
|
default_server_config = APIServerConfig(
|
||||||
|
model_name="test-model",
|
||||||
|
base_url="http://localhost:9001/v1",
|
||||||
|
api_key="x",
|
||||||
|
server_type="vllm",
|
||||||
|
)
|
||||||
|
merged_config = default_server_config.model_dump()
|
||||||
|
|
||||||
|
server_configs = resolve_vllm_configs(
|
||||||
|
default_server_configs=default_server_config,
|
||||||
|
openai_config_dict=merged_config,
|
||||||
|
yaml_config={},
|
||||||
|
cli_passed_flags={},
|
||||||
|
logger=logging.getLogger("test"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(server_configs, list)
|
||||||
|
assert len(server_configs) == 1
|
||||||
|
assert isinstance(server_configs[0], APIServerConfig)
|
||||||
|
assert server_configs[0].base_url == "http://localhost:9001/v1"
|
||||||
|
|
|
||||||
|
|
@ -87,6 +87,24 @@ async def test_attach_teacher_distillation_negative_topk_skips_fetch():
|
||||||
assert out["distill_logprobs"] is None
|
assert out["distill_logprobs"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_attach_teacher_distillation_zero_topk_passthrough():
|
||||||
|
env = object.__new__(_ConcreteTeacherEnv)
|
||||||
|
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)
|
||||||
|
env.teacher_server = _FakeTeacherServer()
|
||||||
|
|
||||||
|
group = {
|
||||||
|
"tokens": [[1, 2, 3]],
|
||||||
|
"group_overrides": None,
|
||||||
|
"masks": [[-100, 2, 3]],
|
||||||
|
"scores": [1.0],
|
||||||
|
}
|
||||||
|
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||||
|
assert env.teacher_server.calls == 1
|
||||||
|
assert out["distill_token_ids"] is not None
|
||||||
|
assert out["distill_logprobs"] is not None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_attach_teacher_distillation_group_override_can_skip_fetch():
|
async def test_attach_teacher_distillation_group_override_can_skip_fetch():
|
||||||
env = object.__new__(_ConcreteTeacherEnv)
|
env = object.__new__(_ConcreteTeacherEnv)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue