adding tests

This commit is contained in:
Jai Suphavadeeprasit 2026-03-13 17:23:40 -04:00
parent 6c564799bc
commit 1b8ff075c4
3 changed files with 84 additions and 3 deletions

View file

@ -87,6 +87,24 @@ async def test_attach_teacher_distillation_negative_topk_skips_fetch():
assert out["distill_logprobs"] is None
@pytest.mark.asyncio
async def test_attach_teacher_distillation_zero_topk_passthrough():
env = object.__new__(_ConcreteTeacherEnv)
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)
env.teacher_server = _FakeTeacherServer()
group = {
"tokens": [[1, 2, 3]],
"group_overrides": None,
"masks": [[-100, 2, 3]],
"scores": [1.0],
}
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
assert env.teacher_server.calls == 1
assert out["distill_token_ids"] is not None
assert out["distill_logprobs"] is not None
@pytest.mark.asyncio
async def test_attach_teacher_distillation_group_override_can_skip_fetch():
env = object.__new__(_ConcreteTeacherEnv)