mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
234 lines
7.7 KiB
Python
234 lines
7.7 KiB
Python
"""Tests for StudentDistillationEnv distillation enrichment."""
|
|
|
|
from contextlib import asynccontextmanager
|
|
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
|
|
from atroposlib.envs.student_distillation_env import StudentDistillationEnv
|
|
|
|
|
|
class _FakeManagedServer:
|
|
def __init__(self, prompt_tokens):
|
|
self.prompt_tokens = prompt_tokens
|
|
self.calls = 0
|
|
self.kwargs = []
|
|
|
|
async def get_logprobs(self, **kwargs):
|
|
self.calls += 1
|
|
self.kwargs.append(kwargs)
|
|
return {
|
|
"prompt_tokens": self.prompt_tokens,
|
|
"prompt_topk_token_ids": [[tok, tok + 1] for tok in self.prompt_tokens],
|
|
"prompt_topk_logprobs": [[-0.1, -0.2] for _ in self.prompt_tokens],
|
|
}
|
|
|
|
|
|
class _FakeStudentServer:
|
|
def __init__(self, fail_on_call: int = -1, managed_prompt_tokens=None):
|
|
self.calls = 0
|
|
self.fail_on_call = fail_on_call
|
|
self.kwargs = []
|
|
self.managed_calls = 0
|
|
self.managed = _FakeManagedServer(
|
|
prompt_tokens=managed_prompt_tokens if managed_prompt_tokens is not None else []
|
|
)
|
|
|
|
async def get_logprobs(self, **kwargs):
|
|
self.calls += 1
|
|
self.kwargs.append(kwargs)
|
|
if self.calls == self.fail_on_call:
|
|
raise RuntimeError("student backend failure")
|
|
seq = kwargs["input_ids"]
|
|
return {
|
|
"prompt_tokens": seq,
|
|
"prompt_topk_token_ids": [[tok, tok + 1] for tok in seq],
|
|
"prompt_topk_logprobs": [[-0.1, -0.2] for _ in seq],
|
|
}
|
|
|
|
@asynccontextmanager
|
|
async def managed_server(self, tokenizer=None):
|
|
self.managed_calls += 1
|
|
yield self.managed
|
|
|
|
|
|
class _ConcreteStudentEnv(StudentDistillationEnv):
|
|
async def get_next_item(self):
|
|
return None
|
|
|
|
async def evaluate(self, *args, **kwargs):
|
|
return None
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_attach_student_distillation_success():
|
|
env = object.__new__(_ConcreteStudentEnv)
|
|
env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)
|
|
env.server = _FakeStudentServer()
|
|
|
|
group = {
|
|
"tokens": [[1, 2, 3], [4, 5]],
|
|
"group_overrides": None,
|
|
"overrides": None,
|
|
"masks": [[-100, 2, 3], [-100, 5]],
|
|
"scores": [1.0, 0.0],
|
|
}
|
|
out = await StudentDistillationEnv._attach_student_distillation(env, group)
|
|
assert out["distill_token_ids"] is not None
|
|
assert out["distill_logprobs"] is not None
|
|
assert len(out["distill_token_ids"]) == 2
|
|
assert len(out["distill_token_ids"][0]) == 3
|
|
assert len(out["distill_logprobs"][1]) == 2
|
|
assert env.server.calls == 2
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_attach_student_distillation_failure_drops_payload():
|
|
env = object.__new__(_ConcreteStudentEnv)
|
|
env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)
|
|
env.server = _FakeStudentServer(fail_on_call=2)
|
|
|
|
group = {
|
|
"tokens": [[1, 2, 3], [4, 5]],
|
|
"group_overrides": None,
|
|
"overrides": None,
|
|
"masks": [[-100, 2, 3], [-100, 5]],
|
|
"scores": [1.0, 0.0],
|
|
}
|
|
out = await StudentDistillationEnv._attach_student_distillation(env, group)
|
|
assert out["distill_token_ids"] is None
|
|
assert out["distill_logprobs"] is None
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_attach_student_distillation_negative_topk_skips_fetch():
|
|
env = object.__new__(_ConcreteStudentEnv)
|
|
env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=-1)
|
|
env.server = _FakeStudentServer()
|
|
|
|
group = {
|
|
"tokens": [[1, 2, 3]],
|
|
"group_overrides": None,
|
|
"overrides": None,
|
|
"masks": [[-100, 2, 3]],
|
|
"scores": [1.0],
|
|
}
|
|
out = await StudentDistillationEnv._attach_student_distillation(env, group)
|
|
assert env.server.calls == 0
|
|
assert out["distill_token_ids"] is None
|
|
assert out["distill_logprobs"] is None
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_attach_student_distillation_zero_topk_passthrough():
|
|
env = object.__new__(_ConcreteStudentEnv)
|
|
env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=0)
|
|
env.server = _FakeStudentServer()
|
|
|
|
group = {
|
|
"tokens": [[1, 2, 3]],
|
|
"group_overrides": None,
|
|
"overrides": None,
|
|
"masks": [[-100, 2, 3]],
|
|
"scores": [1.0],
|
|
}
|
|
out = await StudentDistillationEnv._attach_student_distillation(env, group)
|
|
assert env.server.calls == 1
|
|
assert out["distill_token_ids"] is not None
|
|
assert out["distill_logprobs"] is not None
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_attach_student_distillation_group_override_topk_is_used():
|
|
env = object.__new__(_ConcreteStudentEnv)
|
|
env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=0)
|
|
env.server = _FakeStudentServer()
|
|
|
|
group = {
|
|
"tokens": [[1, 2, 3], [4, 5]],
|
|
"group_overrides": {"student_top_k": 7},
|
|
"overrides": None,
|
|
"masks": [[-100, 2, 3], [-100, 5]],
|
|
"scores": [1.0, 0.0],
|
|
}
|
|
out = await StudentDistillationEnv._attach_student_distillation(env, group)
|
|
assert env.server.kwargs[0]["top_k"] == 7
|
|
assert env.server.kwargs[1]["top_k"] == 7
|
|
assert out["distill_token_ids"] is not None
|
|
assert out["distill_logprobs"] is not None
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_attach_student_distillation_group_override_can_skip_fetch():
|
|
env = object.__new__(_ConcreteStudentEnv)
|
|
env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)
|
|
env.server = _FakeStudentServer()
|
|
|
|
group = {
|
|
"tokens": [[1, 2, 3]],
|
|
"group_overrides": {"skip_student_top_k": True},
|
|
"overrides": None,
|
|
"masks": [[-100, 2, 3]],
|
|
"scores": [1.0],
|
|
}
|
|
out = await StudentDistillationEnv._attach_student_distillation(env, group)
|
|
assert env.server.calls == 0
|
|
assert out["distill_token_ids"] is None
|
|
assert out["distill_logprobs"] is None
|
|
|
|
|
|
def test_get_student_logprob_overrides_merges_group_and_sequence():
|
|
env = object.__new__(_ConcreteStudentEnv)
|
|
group = {
|
|
"group_overrides": {
|
|
"student_logprob_kwargs": {"temperature": 0.0, "prompt": "group"}
|
|
},
|
|
"overrides": [
|
|
{"student_logprob_kwargs": {"prompt": "seq", "top_p": 1.0}},
|
|
],
|
|
}
|
|
|
|
out = StudentDistillationEnv._get_student_logprob_overrides(env, group, 0)
|
|
assert out == {"temperature": 0.0, "prompt": "seq", "top_p": 1.0}
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_fetch_student_for_sequence_uses_managed_server_for_messages():
|
|
env = object.__new__(_ConcreteStudentEnv)
|
|
env.server = _FakeStudentServer(managed_prompt_tokens=[1, 2, 3])
|
|
env.tokenizer = object()
|
|
|
|
out = await StudentDistillationEnv._fetch_student_for_sequence(
|
|
env,
|
|
token_ids=[1, 2, 3],
|
|
top_k=2,
|
|
extra_kwargs={"messages": [{"role": "user", "content": "hi"}]},
|
|
)
|
|
|
|
assert env.server.calls == 0
|
|
assert env.server.managed_calls == 1
|
|
assert env.server.managed.calls == 1
|
|
assert "input_ids" not in env.server.managed.kwargs[0]
|
|
assert out[0] == [[1, 2], [2, 3], [3, 4]]
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_fetch_student_for_sequence_mismatch_drops_payload():
|
|
env = object.__new__(_ConcreteStudentEnv)
|
|
env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)
|
|
env.server = _FakeStudentServer(managed_prompt_tokens=[9, 9])
|
|
env.tokenizer = object()
|
|
|
|
group = {
|
|
"tokens": [[1, 2, 3]],
|
|
"group_overrides": {
|
|
"student_logprob_kwargs": {"messages": [{"role": "user", "content": "hi"}]}
|
|
},
|
|
"overrides": None,
|
|
"masks": [[-100, 2, 3]],
|
|
"scores": [1.0],
|
|
}
|
|
|
|
out = await StudentDistillationEnv._attach_student_distillation(env, group)
|
|
assert out["distill_token_ids"] is None
|
|
assert out["distill_logprobs"] is None
|