mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
address problems
This commit is contained in:
parent
322e7e6623
commit
a8cdb53a4d
6 changed files with 99 additions and 24 deletions
|
|
@ -69,6 +69,42 @@ async def test_attach_teacher_distillation_failure_drops_payload():
|
|||
assert out["distill_logprobs"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_negative_topk_skips_fetch():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=-1)
|
||||
env.teacher_server = _FakeTeacherServer()
|
||||
|
||||
group = {
|
||||
"tokens": [[1, 2, 3]],
|
||||
"group_overrides": None,
|
||||
"masks": [[-100, 2, 3]],
|
||||
"scores": [1.0],
|
||||
}
|
||||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert env.teacher_server.calls == 0
|
||||
assert out["distill_token_ids"] is None
|
||||
assert out["distill_logprobs"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_group_override_can_skip_fetch():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
|
||||
env.teacher_server = _FakeTeacherServer()
|
||||
|
||||
group = {
|
||||
"tokens": [[1, 2, 3]],
|
||||
"group_overrides": {"skip_teacher_top_k": True},
|
||||
"masks": [[-100, 2, 3]],
|
||||
"scores": [1.0],
|
||||
}
|
||||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert env.teacher_server.calls == 0
|
||||
assert out["distill_token_ids"] is None
|
||||
assert out["distill_logprobs"] is None
|
||||
|
||||
|
||||
def test_teacher_tokenizer_mismatch_raises(monkeypatch):
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue