mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
teacher env init
This commit is contained in:
parent
c421582b6f
commit
f44eb810bf
4 changed files with 258 additions and 27 deletions
|
|
@ -319,30 +319,6 @@ async def test_get_logprobs_messages_passthrough(mock_server):
|
|||
assert len(payload["prompt_topk_logprobs"]) == len(prompt_tokens)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_logprobs_input_ids_only_passthrough(mock_server):
    """ManagedServer.get_logprobs supports input_ids-only without requiring prompt."""
    managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer)
    input_ids = [10, 20, 30]

    async def _mock_get_logprobs(**kwargs):
        # The backend must be handed the raw token ids untouched, and no
        # prompt may be passed alongside them.
        assert "input_ids" in kwargs
        assert kwargs["input_ids"] == input_ids
        assert kwargs.get("prompt") is None
        # One top-k candidate per position: the token itself at logprob -0.1.
        return {
            "prompt_tokens": input_ids,
            "prompt_topk_token_ids": [[t] for t in input_ids],
            "prompt_topk_logprobs": [[-0.1] for _ in input_ids],
        }

    # Swap the backend implementation for the asserting stub above.
    mock_server.get_logprobs = _mock_get_logprobs
    payload = await managed.get_logprobs(input_ids=input_ids, top_k=1)

    # The stub's payload fields must come back through the managed wrapper.
    assert payload["prompt_tokens"] == input_ids
    assert payload["prompt_topk_token_ids"] == [[10], [20], [30]]
    assert payload["prompt_topk_logprobs"] == [[-0.1], [-0.1], [-0.1]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_logprobs_strict_mode_requires_backend_impl(mock_server):
|
||||
"""ManagedServer.get_logprobs requires backend get_logprobs in strict mode."""
|
||||
|
|
|
|||
|
|
@ -41,9 +41,7 @@ class _FakeAPIServer(APIServer):
|
|||
|
||||
|
||||
class _FakeRoutedServer:
|
||||
def __init__(
|
||||
self, name: str, train_slots: int, eval_slots: int, healthy: bool = True
|
||||
):
|
||||
def __init__(self, name: str, train_slots: int, eval_slots: int, healthy: bool = True):
|
||||
self.name = name
|
||||
self.server_healthy = healthy
|
||||
self.sem = AsyncSemWithAdaptiveWeight(4)
|
||||
|
|
|
|||
69
atroposlib/tests/test_teacher_distillation_env.py
Normal file
69
atroposlib/tests/test_teacher_distillation_env.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""Tests for TeacherDistillationEnv distillation enrichment."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from atroposlib.envs.teacher_distillation_env import TeacherDistillationEnv
|
||||
|
||||
|
||||
class _FakeTeacherServer:
|
||||
def __init__(self, fail_on_call: int = -1):
|
||||
self.calls = 0
|
||||
self.fail_on_call = fail_on_call
|
||||
|
||||
async def get_logprobs(self, **kwargs):
|
||||
self.calls += 1
|
||||
if self.calls == self.fail_on_call:
|
||||
raise RuntimeError("teacher backend failure")
|
||||
seq = kwargs["input_ids"]
|
||||
return {
|
||||
"prompt_tokens": seq,
|
||||
"prompt_topk_token_ids": [[tok, tok + 1] for tok in seq],
|
||||
"prompt_topk_logprobs": [[-0.1, -0.2] for _ in seq],
|
||||
}
|
||||
|
||||
|
||||
class _ConcreteTeacherEnv(TeacherDistillationEnv):
    """Concrete subclass of ``TeacherDistillationEnv`` with no-op overrides.

    ``get_next_item`` and ``evaluate`` are overridden to do nothing and
    return ``None``.
    """

    async def get_next_item(self):
        """Return ``None``; no items are produced."""
        return None

    async def evaluate(self, *args, **kwargs):
        """Accept any arguments and return ``None``."""
        return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_attach_teacher_distillation_success():
    """A healthy teacher yields per-sequence distillation ids and logprobs."""
    # Bypass __init__ and set only the attributes this test provides.
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
    env.teacher_server = _FakeTeacherServer()

    # Two sequences of different lengths (3 and 2 tokens).
    group = {
        "tokens": [[1, 2, 3], [4, 5]],
        "group_overrides": None,
        "masks": [[-100, 2, 3], [-100, 5]],
        "scores": [1.0, 0.0],
    }
    out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)

    assert out["distill_token_ids"] is not None
    assert out["distill_logprobs"] is not None
    # One entry per sequence, each as long as that sequence.
    assert len(out["distill_token_ids"]) == 2
    assert len(out["distill_token_ids"][0]) == 3
    assert len(out["distill_logprobs"][1]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_attach_teacher_distillation_failure_drops_payload():
    """A teacher failure leaves both distillation fields set to ``None``."""
    # Bypass __init__ and set only the attributes this test provides.
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
    # The fake raises on its second get_logprobs call.
    env.teacher_server = _FakeTeacherServer(fail_on_call=2)

    group = {
        "tokens": [[1, 2, 3], [4, 5]],
        "group_overrides": None,
        "masks": [[-100, 2, 3], [-100, 5]],
        "scores": [1.0, 0.0],
    }
    out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)

    # Nothing partial is kept: both fields are cleared for the whole group.
    assert out["distill_token_ids"] is None
    assert out["distill_logprobs"] is None
|
||||
Loading…
Add table
Add a link
Reference in a new issue