This commit is contained in:
Jai Suphavadeeprasit 2026-03-17 12:02:48 -04:00
parent c421582b6f
commit 01e25707b0
3 changed files with 480 additions and 0 deletions

View file

@ -0,0 +1,205 @@
"""
Student/self-distillation environment layer.
This module adds prompt-logprob fetching from the student rollout server itself
before the scored group is sent to the API.
By default, the student server scores the exact token IDs already present in the
group. Override-driven prompt/message scoring is supported only when the
resulting prompt tokenization matches the original token sequence exactly.
"""
from __future__ import annotations
import asyncio
import logging
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import Field
from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
logger = logging.getLogger(__name__)
class StudentDistillationConfig(BaseEnvConfig):
student_distill_enabled: bool = Field(
default=False,
description="Whether to fetch prompt logprobs from the student server itself.",
)
student_top_k: int = Field(
default=0,
ge=-1,
description=(
"Number of extra prompt logprobs to fetch beyond the selected token. "
"Use 0 for selected-token-only prompt logprobs and <= -1 to disable "
"student distillation fetching."
),
)
class StudentDistillationEnv(BaseEnv, ABC):
"""
BaseEnv subclass that enriches scored groups with self-distillation arrays.
Distillation payload shape:
- distill_token_ids: [sequence][position][k]
- distill_logprobs: [sequence][position][k]
"""
env_config_cls = StudentDistillationConfig
def _get_student_logprob_overrides(
self, group: ScoredDataGroup, seq_idx: int
) -> Dict[str, Any]:
merged: Dict[str, Any] = {}
group_overrides = group.get("group_overrides") or {}
group_kwargs = group_overrides.get(
"student_logprob_kwargs",
group_overrides.get("student_distill_kwargs"),
)
if isinstance(group_kwargs, dict):
merged.update(group_kwargs)
overrides = group.get("overrides") or []
if seq_idx < len(overrides):
seq_overrides = overrides[seq_idx] or {}
seq_kwargs = seq_overrides.get(
"student_logprob_kwargs",
seq_overrides.get("student_distill_kwargs"),
)
if isinstance(seq_kwargs, dict):
merged.update(seq_kwargs)
return merged
async def _fetch_student_for_sequence(
self,
token_ids: List[int],
top_k: int,
extra_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[List[int]], List[List[float]]]:
request_kwargs: Dict[str, Any] = {
"input_ids": token_ids,
"top_k": top_k,
"max_tokens": 1,
"split": "train",
}
if extra_kwargs:
request_kwargs.update(extra_kwargs)
if extra_kwargs.get("messages") is not None or extra_kwargs.get(
"prompt"
) is not None:
# Let message/prompt overrides drive tokenization instead of the
# original input_ids when callers explicitly request it.
request_kwargs.pop("input_ids", None)
if request_kwargs.get("messages") is not None:
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
payload = await managed.get_logprobs(**request_kwargs)
else:
payload = await self.server.get_logprobs(**request_kwargs)
if payload.get("prompt_tokens") != token_ids:
raise ValueError(
"Student distillation request did not align to the original token "
"sequence. Override-driven prompt/messages are only supported when "
"they reproduce the exact same prompt tokens."
)
return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]
async def _attach_student_distillation(
self, group: ScoredDataGroup
) -> ScoredDataGroup:
if not self.config.student_distill_enabled:
return group
seqs = group.get("tokens", [])
if not seqs:
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
group_overrides = group.get("group_overrides") or {}
if group_overrides.get("skip_student_top_k", False):
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
top_k = int(group_overrides.get("student_top_k", self.config.student_top_k))
if top_k <= -1:
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
tasks = [
self._fetch_student_for_sequence(
seq,
top_k,
self._get_student_logprob_overrides(group, idx),
)
for idx, seq in enumerate(seqs)
]
results = await asyncio.gather(*tasks, return_exceptions=True)
distill_token_ids: List[List[List[int]]] = []
distill_logprobs: List[List[List[float]]] = []
for idx, result in enumerate(results):
if isinstance(result, Exception):
logger.warning(
"Student logprob fetch failed for seq %s: %s. "
"Dropping distill payload for this group.",
idx,
result,
)
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
token_ids_k, logprobs_k = result
if len(token_ids_k) != len(logprobs_k):
logger.warning(
"Student prompt-topk length mismatch for seq %s (%s != %s). "
"Dropping distill payload for this group.",
idx,
len(token_ids_k),
len(logprobs_k),
)
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
distill_token_ids.append(token_ids_k)
distill_logprobs.append(logprobs_k)
group["distill_token_ids"] = distill_token_ids
group["distill_logprobs"] = distill_logprobs
return group
async def handle_send_to_api(
self,
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
item: Any = None,
do_send_to_api: bool = True,
abort_on_any_max_length_exceeded: bool = True,
):
groups = scored_data if isinstance(scored_data, list) else [scored_data]
enriched_groups: List[ScoredDataGroup] = []
for group in groups:
if group is None:
continue
enriched_groups.append(await self._attach_student_distillation(group))
payload: Union[ScoredDataGroup, List[ScoredDataGroup]]
if isinstance(scored_data, list):
payload = enriched_groups
else:
payload = enriched_groups[0] if enriched_groups else scored_data
return await super().handle_send_to_api(
payload,
item=item,
do_send_to_api=do_send_to_api,
abort_on_any_max_length_exceeded=abort_on_any_max_length_exceeded,
)

View file

@ -0,0 +1,234 @@
"""Tests for StudentDistillationEnv distillation enrichment."""
from contextlib import asynccontextmanager
from types import SimpleNamespace
import pytest
from atroposlib.envs.student_distillation_env import StudentDistillationEnv
class _FakeManagedServer:
def __init__(self, prompt_tokens):
self.prompt_tokens = prompt_tokens
self.calls = 0
self.kwargs = []
async def get_logprobs(self, **kwargs):
self.calls += 1
self.kwargs.append(kwargs)
return {
"prompt_tokens": self.prompt_tokens,
"prompt_topk_token_ids": [[tok, tok + 1] for tok in self.prompt_tokens],
"prompt_topk_logprobs": [[-0.1, -0.2] for _ in self.prompt_tokens],
}
class _FakeStudentServer:
def __init__(self, fail_on_call: int = -1, managed_prompt_tokens=None):
self.calls = 0
self.fail_on_call = fail_on_call
self.kwargs = []
self.managed_calls = 0
self.managed = _FakeManagedServer(
prompt_tokens=managed_prompt_tokens if managed_prompt_tokens is not None else []
)
async def get_logprobs(self, **kwargs):
self.calls += 1
self.kwargs.append(kwargs)
if self.calls == self.fail_on_call:
raise RuntimeError("student backend failure")
seq = kwargs["input_ids"]
return {
"prompt_tokens": seq,
"prompt_topk_token_ids": [[tok, tok + 1] for tok in seq],
"prompt_topk_logprobs": [[-0.1, -0.2] for _ in seq],
}
@asynccontextmanager
async def managed_server(self, tokenizer=None):
self.managed_calls += 1
yield self.managed
class _ConcreteStudentEnv(StudentDistillationEnv):
async def get_next_item(self):
return None
async def evaluate(self, *args, **kwargs):
return None
@pytest.mark.asyncio
async def test_attach_student_distillation_success():
env = object.__new__(_ConcreteStudentEnv)
env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)
env.server = _FakeStudentServer()
group = {
"tokens": [[1, 2, 3], [4, 5]],
"group_overrides": None,
"overrides": None,
"masks": [[-100, 2, 3], [-100, 5]],
"scores": [1.0, 0.0],
}
out = await StudentDistillationEnv._attach_student_distillation(env, group)
assert out["distill_token_ids"] is not None
assert out["distill_logprobs"] is not None
assert len(out["distill_token_ids"]) == 2
assert len(out["distill_token_ids"][0]) == 3
assert len(out["distill_logprobs"][1]) == 2
assert env.server.calls == 2
@pytest.mark.asyncio
async def test_attach_student_distillation_failure_drops_payload():
env = object.__new__(_ConcreteStudentEnv)
env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)
env.server = _FakeStudentServer(fail_on_call=2)
group = {
"tokens": [[1, 2, 3], [4, 5]],
"group_overrides": None,
"overrides": None,
"masks": [[-100, 2, 3], [-100, 5]],
"scores": [1.0, 0.0],
}
out = await StudentDistillationEnv._attach_student_distillation(env, group)
assert out["distill_token_ids"] is None
assert out["distill_logprobs"] is None
@pytest.mark.asyncio
async def test_attach_student_distillation_negative_topk_skips_fetch():
env = object.__new__(_ConcreteStudentEnv)
env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=-1)
env.server = _FakeStudentServer()
group = {
"tokens": [[1, 2, 3]],
"group_overrides": None,
"overrides": None,
"masks": [[-100, 2, 3]],
"scores": [1.0],
}
out = await StudentDistillationEnv._attach_student_distillation(env, group)
assert env.server.calls == 0
assert out["distill_token_ids"] is None
assert out["distill_logprobs"] is None
@pytest.mark.asyncio
async def test_attach_student_distillation_zero_topk_passthrough():
env = object.__new__(_ConcreteStudentEnv)
env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=0)
env.server = _FakeStudentServer()
group = {
"tokens": [[1, 2, 3]],
"group_overrides": None,
"overrides": None,
"masks": [[-100, 2, 3]],
"scores": [1.0],
}
out = await StudentDistillationEnv._attach_student_distillation(env, group)
assert env.server.calls == 1
assert out["distill_token_ids"] is not None
assert out["distill_logprobs"] is not None
@pytest.mark.asyncio
async def test_attach_student_distillation_group_override_topk_is_used():
env = object.__new__(_ConcreteStudentEnv)
env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=0)
env.server = _FakeStudentServer()
group = {
"tokens": [[1, 2, 3], [4, 5]],
"group_overrides": {"student_top_k": 7},
"overrides": None,
"masks": [[-100, 2, 3], [-100, 5]],
"scores": [1.0, 0.0],
}
out = await StudentDistillationEnv._attach_student_distillation(env, group)
assert env.server.kwargs[0]["top_k"] == 7
assert env.server.kwargs[1]["top_k"] == 7
assert out["distill_token_ids"] is not None
assert out["distill_logprobs"] is not None
@pytest.mark.asyncio
async def test_attach_student_distillation_group_override_can_skip_fetch():
env = object.__new__(_ConcreteStudentEnv)
env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)
env.server = _FakeStudentServer()
group = {
"tokens": [[1, 2, 3]],
"group_overrides": {"skip_student_top_k": True},
"overrides": None,
"masks": [[-100, 2, 3]],
"scores": [1.0],
}
out = await StudentDistillationEnv._attach_student_distillation(env, group)
assert env.server.calls == 0
assert out["distill_token_ids"] is None
assert out["distill_logprobs"] is None
def test_get_student_logprob_overrides_merges_group_and_sequence():
env = object.__new__(_ConcreteStudentEnv)
group = {
"group_overrides": {
"student_logprob_kwargs": {"temperature": 0.0, "prompt": "group"}
},
"overrides": [
{"student_logprob_kwargs": {"prompt": "seq", "top_p": 1.0}},
],
}
out = StudentDistillationEnv._get_student_logprob_overrides(env, group, 0)
assert out == {"temperature": 0.0, "prompt": "seq", "top_p": 1.0}
@pytest.mark.asyncio
async def test_fetch_student_for_sequence_uses_managed_server_for_messages():
env = object.__new__(_ConcreteStudentEnv)
env.server = _FakeStudentServer(managed_prompt_tokens=[1, 2, 3])
env.tokenizer = object()
out = await StudentDistillationEnv._fetch_student_for_sequence(
env,
token_ids=[1, 2, 3],
top_k=2,
extra_kwargs={"messages": [{"role": "user", "content": "hi"}]},
)
assert env.server.calls == 0
assert env.server.managed_calls == 1
assert env.server.managed.calls == 1
assert "input_ids" not in env.server.managed.kwargs[0]
assert out[0] == [[1, 2], [2, 3], [3, 4]]
@pytest.mark.asyncio
async def test_fetch_student_for_sequence_mismatch_drops_payload():
env = object.__new__(_ConcreteStudentEnv)
env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)
env.server = _FakeStudentServer(managed_prompt_tokens=[9, 9])
env.tokenizer = object()
group = {
"tokens": [[1, 2, 3]],
"group_overrides": {
"student_logprob_kwargs": {"messages": [{"role": "user", "content": "hi"}]}
},
"overrides": None,
"masks": [[-100, 2, 3]],
"scores": [1.0],
}
out = await StudentDistillationEnv._attach_student_distillation(env, group)
assert out["distill_token_ids"] is None
assert out["distill_logprobs"] is None