mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
structural changes
This commit is contained in:
parent
12ba3cc3bd
commit
a171358f2e
6 changed files with 422 additions and 50 deletions
|
|
@ -4,6 +4,7 @@ from types import SimpleNamespace
|
|||
|
||||
import pytest
|
||||
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
||||
from atroposlib.envs.teacher_distillation_env import TeacherDistillationEnv
|
||||
|
||||
|
||||
|
|
@ -32,6 +33,20 @@ class _ConcreteTeacherEnv(TeacherDistillationEnv):
|
|||
return None
|
||||
|
||||
|
||||
class _DummyTokenizer:
|
||||
name_or_path = "student-model"
|
||||
|
||||
def get_vocab(self):
|
||||
return {"a": 1}
|
||||
|
||||
|
||||
class _CapturingServerManager:
|
||||
def __init__(self, configs, slurm=False, testing=False):
|
||||
self.configs = configs
|
||||
self.slurm = slurm
|
||||
self.testing = testing
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_success():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
|
|
@ -105,6 +120,32 @@ async def test_attach_teacher_distillation_zero_topk_passthrough():
|
|||
assert out["distill_logprobs"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_attach_teacher_distillation_group_override_topk_is_used():
    """A per-group ``teacher_top_k`` override is what reaches the fetch call.

    The env config sets ``teacher_top_k=0`` while the group carries
    ``group_overrides={"teacher_top_k": 7}``; the stubbed fetch must be
    invoked once per sequence with ``top_k == 7``.
    """
    # Bypass __init__ so no real servers or tokenizers are constructed.
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)

    seen_topks = []

    async def _fake_fetch(seq, top_k):
        # Record the top_k actually passed so it can be asserted on below.
        seen_topks.append(top_k)
        return [[tok] for tok in seq], [[-0.1] for _ in seq]

    env.teacher_server = object()
    env._fetch_teacher_for_sequence = _fake_fetch

    group = {
        "tokens": [[1, 2, 3], [4, 5]],
        "group_overrides": {"teacher_top_k": 7},
        "masks": [[-100, 2, 3], [-100, 5]],
        "scores": [1.0, 0.0],
    }
    out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)

    # Two sequences in the group -> two fetches, both with the override.
    assert seen_topks == [7, 7]
    assert out["distill_token_ids"] is not None
    assert out["distill_logprobs"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_group_override_can_skip_fetch():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
|
|
@ -149,3 +190,111 @@ def test_teacher_tokenizer_mismatch_raises(monkeypatch):
|
|||
env,
|
||||
teacher_tokenizer_name="teacher-model",
|
||||
)
|
||||
|
||||
|
||||
def test_init_requires_teacher_server_source(monkeypatch):
    """Constructing with ``teacher_enabled=True`` and no teacher configs raises.

    ``BaseEnv.__init__`` is replaced with a stub that only sets ``config``
    and ``tokenizer``, so the test exercises just the subclass's own
    constructor logic.
    """
    from atroposlib.envs import teacher_distillation_env as module

    def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
        # Provide only the attributes this test needs from the base class.
        self.config = config
        self.tokenizer = _DummyTokenizer()

    monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)

    config = SimpleNamespace(
        teacher_enabled=True,
        teacher_top_k=0,
    )
    with pytest.raises(ValueError, match="teacher_enabled=True requires"):
        _ConcreteTeacherEnv(
            config=config,
            server_configs=[],
        )
|
||||
|
||||
|
||||
def test_init_uses_explicit_teacher_server_configs(monkeypatch):
    """Explicit ``teacher_server_configs`` are handed to the server manager.

    With ``BaseEnv.__init__``, ``ServerManager`` and the tokenizer
    compatibility check all stubbed out, the test verifies that:

    * ``env.teacher_server`` is built from the patched ``ServerManager``;
    * it receives exactly the explicit config list;
    * the validation hook is called with that config's ``tokenizer_name``.
    """
    from atroposlib.envs import teacher_distillation_env as module

    called = {}

    def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
        # Provide only the attributes this test needs from the base class.
        self.config = config
        self.tokenizer = _DummyTokenizer()

    def _fake_validate(self, teacher_tokenizer_name):
        # Capture the argument instead of performing any real validation.
        called["teacher_tokenizer_name"] = teacher_tokenizer_name

    monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
    monkeypatch.setattr(module, "ServerManager", _CapturingServerManager)
    monkeypatch.setattr(
        _ConcreteTeacherEnv,
        "_validate_teacher_tokenizer_compatibility",
        _fake_validate,
    )

    explicit_cfg = APIServerConfig(
        model_name="explicit-model",
        tokenizer_name="explicit-tokenizer",
        base_url="http://explicit/v1",
        api_key="x",
        server_type="vllm",
    )
    config = SimpleNamespace(
        teacher_enabled=True,
        teacher_top_k=0,
    )

    env = _ConcreteTeacherEnv(
        config=config,
        server_configs=[],
        teacher_server_configs=[explicit_cfg],
    )

    assert isinstance(env.teacher_server, _CapturingServerManager)
    assert env.teacher_server.configs == [explicit_cfg]
    assert called["teacher_tokenizer_name"] == "explicit-tokenizer"
|
||||
|
||||
|
||||
def test_resolve_teacher_server_configs_returns_none_when_unset():
    """With no defaults, YAML, or CLI flags, the resolver returns ``None``."""
    assert (
        _ConcreteTeacherEnv._resolve_teacher_server_configs(
            default_teacher_server_configs=None,
            yaml_config={},
            cli_passed_flags={},
        )
        is None
    )
|
||||
|
||||
|
||||
def test_resolve_teacher_server_configs_uses_teacher_namespace(monkeypatch):
    """``teacher``-namespaced settings are forwarded under the ``openai`` key.

    ``resolve_openai_configs`` is replaced with a stub that records its
    keyword arguments. The assertions check that the YAML ``teacher``
    section and the ``teacher.*`` CLI flag arrive re-keyed as ``openai``
    / ``openai.*``, and that both values appear in the merged
    ``openai_config_dict``.
    """
    from atroposlib.envs import teacher_distillation_env as module

    captured = {}

    def _fake_resolve(**kwargs):
        # Record everything the resolver is asked to process.
        captured.update(kwargs)
        return ["resolved"]

    monkeypatch.setattr(module, "resolve_openai_configs", _fake_resolve)

    default_cfg = APIServerConfig(
        model_name="teacher-model",
        base_url="http://teacher/v1",
        api_key="x",
        server_type="vllm",
    )

    out = _ConcreteTeacherEnv._resolve_teacher_server_configs(
        default_teacher_server_configs=default_cfg,
        yaml_config={"teacher": {"tokenizer_name": "teacher-tokenizer"}},
        cli_passed_flags={"teacher.base_url": "http://override/v1"},
    )

    # The stub's return value is passed straight back to the caller.
    assert out == ["resolved"]
    assert captured["openai_config_dict"]["base_url"] == "http://override/v1"
    assert captured["openai_config_dict"]["tokenizer_name"] == "teacher-tokenizer"
    assert captured["yaml_config"] == {
        "openai": {"tokenizer_name": "teacher-tokenizer"}
    }
    assert captured["cli_passed_flags"] == {"openai.base_url": "http://override/v1"}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue