mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
student
This commit is contained in:
parent
c421582b6f
commit
01e25707b0
3 changed files with 480 additions and 0 deletions
205
atroposlib/envs/student_distillation_env.py
Normal file
205
atroposlib/envs/student_distillation_env.py
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
"""
|
||||
Student/self-distillation environment layer.
|
||||
|
||||
This module adds prompt-logprob fetching from the student rollout server itself
|
||||
before the scored group is sent to the API.
|
||||
|
||||
By default, the student server scores the exact token IDs already present in the
|
||||
group. Override-driven prompt/message scoring is supported only when the
|
||||
resulting prompt tokenization matches the original token sequence exactly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
|
||||
|
||||
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StudentDistillationConfig(BaseEnvConfig):
    # Environment config extended with the two knobs that control student
    # self-distillation fetching.

    # Master switch: prompt logprobs are only fetched when this is True.
    student_distill_enabled: bool = Field(
        description="Whether to fetch prompt logprobs from the student server itself.",
        default=False,
    )

    # Validation accepts any integer >= -1; the default of 0 requests only the
    # selected token's logprob.
    student_top_k: int = Field(
        description=(
            "Number of extra prompt logprobs to fetch beyond the selected token. Use 0 "
            "for selected-token-only prompt logprobs and <= -1 to disable student "
            "distillation fetching."
        ),
        ge=-1,
        default=0,
    )
|
||||
|
||||
|
||||
class StudentDistillationEnv(BaseEnv, ABC):
    """
    BaseEnv subclass that enriches scored groups with self-distillation arrays.

    ``handle_send_to_api`` passes every non-``None`` group through
    ``_attach_student_distillation`` before delegating to the parent
    implementation.

    Distillation payload shape:
    - distill_token_ids: [sequence][position][k]
    - distill_logprobs: [sequence][position][k]

    When enrichment is enabled but cannot be completed for a group (no
    tokens, skipped via override, negative ``top_k``, or any per-sequence
    failure), both keys are set to ``None`` so a group never carries a partial
    payload.
    """

    env_config_cls = StudentDistillationConfig

    @staticmethod
    def _drop_distill_payload(group: ScoredDataGroup) -> ScoredDataGroup:
        """Set both distillation arrays on ``group`` to ``None`` and return it."""
        group["distill_token_ids"] = None
        group["distill_logprobs"] = None
        return group

    def _get_student_logprob_overrides(
        self, group: ScoredDataGroup, seq_idx: int
    ) -> Dict[str, Any]:
        """
        Build the extra request kwargs for one sequence of ``group``.

        Kwargs are read from ``group["group_overrides"]`` first and then from
        ``group["overrides"][seq_idx]`` (when that entry exists), so
        per-sequence values win on key collisions. In each source the key
        ``"student_logprob_kwargs"`` is looked up, with
        ``"student_distill_kwargs"`` used only when the first key is absent.
        Values that are not dicts are ignored.

        Args:
            group: Scored group that may carry override mappings.
            seq_idx: Index of the sequence inside the group.

        Returns:
            A new dict holding the merged kwargs (empty when nothing applies).
        """
        per_sequence = group.get("overrides") or []
        seq_overrides = per_sequence[seq_idx] if seq_idx < len(per_sequence) else None

        merged: Dict[str, Any] = {}
        # Group-level first, sequence-level second: later updates take priority.
        for source in (group.get("group_overrides"), seq_overrides):
            source = source or {}
            kwargs = source.get(
                "student_logprob_kwargs",
                source.get("student_distill_kwargs"),
            )
            if isinstance(kwargs, dict):
                merged.update(kwargs)
        return merged

    async def _fetch_student_for_sequence(
        self,
        token_ids: List[int],
        top_k: int,
        extra_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[List[int]], List[List[float]]]:
        """
        Request prompt top-k logprobs for a single token sequence.

        Args:
            token_ids: Token IDs of the sequence to score.
            top_k: Value sent as the request's ``top_k``.
            extra_kwargs: Optional kwargs merged over the default request; they
                may replace any default, including ``top_k``.

        Returns:
            ``(prompt_topk_token_ids, prompt_topk_logprobs)`` taken from the
            response payload.

        Raises:
            ValueError: If the payload's ``prompt_tokens`` differ from
                ``token_ids``.
        """
        request_kwargs: Dict[str, Any] = {
            "input_ids": token_ids,
            "top_k": top_k,
            "max_tokens": 1,
            "split": "train",
        }
        if extra_kwargs:
            request_kwargs.update(extra_kwargs)
            wants_retokenization = (
                extra_kwargs.get("messages") is not None
                or extra_kwargs.get("prompt") is not None
            )
            if wants_retokenization:
                # Let message/prompt overrides drive tokenization instead of the
                # original input_ids when callers explicitly request it.
                request_kwargs.pop("input_ids", None)

        if request_kwargs.get("messages") is not None:
            # Message-based requests go through the managed server, which is
            # handed this env's tokenizer.
            # NOTE(review): presumably it applies the chat template - confirm.
            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                payload = await managed.get_logprobs(**request_kwargs)
        else:
            payload = await self.server.get_logprobs(**request_kwargs)

        # Overrides are only acceptable when they reproduce the same tokens;
        # otherwise the per-position arrays would not line up with the group.
        if payload.get("prompt_tokens") != token_ids:
            raise ValueError(
                "Student distillation request did not align to the original token "
                "sequence. Override-driven prompt/messages are only supported when "
                "they reproduce the exact same prompt tokens."
            )

        return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]

    async def _attach_student_distillation(
        self, group: ScoredDataGroup
    ) -> ScoredDataGroup:
        """
        Attach ``distill_token_ids`` / ``distill_logprobs`` to ``group``.

        The group is returned untouched when
        ``self.config.student_distill_enabled`` is false. Otherwise all
        sequences in ``group["tokens"]`` are fetched concurrently; if the group
        has no tokens, the ``skip_student_top_k`` group override is truthy, the
        effective ``top_k`` is negative, or any sequence fails, both keys are
        set to ``None``.

        The effective ``top_k`` is the ``student_top_k`` group override when it
        is present and not ``None``, else ``self.config.student_top_k``.

        Args:
            group: Scored group to enrich; it is mutated in place.

        Returns:
            The same ``group`` object.
        """
        if not self.config.student_distill_enabled:
            return group

        seqs = group.get("tokens", [])
        if not seqs:
            return self._drop_distill_payload(group)

        group_overrides = group.get("group_overrides") or {}
        if group_overrides.get("skip_student_top_k", False):
            return self._drop_distill_payload(group)

        # An explicit ``None`` override means "no override"; ``int(None)``
        # would otherwise raise TypeError.
        top_k_value = group_overrides.get("student_top_k")
        if top_k_value is None:
            top_k_value = self.config.student_top_k
        top_k = int(top_k_value)
        if top_k <= -1:
            return self._drop_distill_payload(group)

        tasks = [
            self._fetch_student_for_sequence(
                seq,
                top_k,
                self._get_student_logprob_overrides(group, idx),
            )
            for idx, seq in enumerate(seqs)
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        distill_token_ids: List[List[List[int]]] = []
        distill_logprobs: List[List[List[float]]] = []
        for idx, result in enumerate(results):
            # gather() can return asyncio.CancelledError for a cancelled child;
            # it derives from BaseException, so checking only ``Exception``
            # would let it reach the tuple-unpack below and raise TypeError.
            if isinstance(result, BaseException):
                logger.warning(
                    "Student logprob fetch failed for seq %s: %s. "
                    "Dropping distill payload for this group.",
                    idx,
                    result,
                )
                return self._drop_distill_payload(group)

            token_ids_k, logprobs_k = result
            if len(token_ids_k) != len(logprobs_k):
                logger.warning(
                    "Student prompt-topk length mismatch for seq %s (%s != %s). "
                    "Dropping distill payload for this group.",
                    idx,
                    len(token_ids_k),
                    len(logprobs_k),
                )
                return self._drop_distill_payload(group)

            distill_token_ids.append(token_ids_k)
            distill_logprobs.append(logprobs_k)

        group["distill_token_ids"] = distill_token_ids
        group["distill_logprobs"] = distill_logprobs
        return group

    async def handle_send_to_api(
        self,
        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
        item: Any = None,
        do_send_to_api: bool = True,
        abort_on_any_max_length_exceeded: bool = True,
    ):
        """
        Enrich the scored data, then delegate to the parent implementation.

        Args:
            scored_data: A single group or a list of groups. ``None`` entries
                are skipped and therefore absent from a forwarded list.
            item: Forwarded unchanged to the parent call.
            do_send_to_api: Forwarded unchanged to the parent call.
            abort_on_any_max_length_exceeded: Forwarded unchanged to the parent
                call.

        Returns:
            Whatever the parent ``handle_send_to_api`` returns.
        """
        groups = scored_data if isinstance(scored_data, list) else [scored_data]
        enriched_groups: List[ScoredDataGroup] = []
        for group in groups:
            if group is None:
                continue
            enriched_groups.append(await self._attach_student_distillation(group))

        # Preserve the caller's shape: a list stays a list, a single group stays
        # single (the original value is passed through when it was ``None``).
        payload: Union[ScoredDataGroup, List[ScoredDataGroup]]
        if isinstance(scored_data, list):
            payload = enriched_groups
        else:
            payload = enriched_groups[0] if enriched_groups else scored_data

        return await super().handle_send_to_api(
            payload,
            item=item,
            do_send_to_api=do_send_to_api,
            abort_on_any_max_length_exceeded=abort_on_any_max_length_exceeded,
        )
|
||||
234
atroposlib/tests/test_student_distillation_env.py
Normal file
234
atroposlib/tests/test_student_distillation_env.py
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
"""Tests for StudentDistillationEnv distillation enrichment."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from atroposlib.envs.student_distillation_env import StudentDistillationEnv
|
||||
|
||||
|
||||
class _FakeManagedServer:
    """Test double for a managed server that always reports fixed prompt tokens."""

    def __init__(self, prompt_tokens):
        self.calls = 0
        self.kwargs = []
        self.prompt_tokens = prompt_tokens

    async def get_logprobs(self, **kwargs):
        """Record the call and return a payload built from ``prompt_tokens``."""
        self.kwargs.append(kwargs)
        self.calls += 1

        tokens = self.prompt_tokens
        topk_ids = [[tok, tok + 1] for tok in tokens]
        topk_logprobs = [[-0.1, -0.2] for _ in tokens]
        return {
            "prompt_tokens": tokens,
            "prompt_topk_token_ids": topk_ids,
            "prompt_topk_logprobs": topk_logprobs,
        }
|
||||
|
||||
|
||||
class _FakeStudentServer:
    """Test double for the student server, with an optional failing call."""

    def __init__(self, fail_on_call: int = -1, managed_prompt_tokens=None):
        if managed_prompt_tokens is None:
            managed_prompt_tokens = []
        self.managed = _FakeManagedServer(prompt_tokens=managed_prompt_tokens)
        self.managed_calls = 0
        self.kwargs = []
        self.fail_on_call = fail_on_call
        self.calls = 0

    async def get_logprobs(self, **kwargs):
        """Record the call; echo ``input_ids`` back or fail on the chosen call."""
        self.calls += 1
        self.kwargs.append(kwargs)
        if self.fail_on_call == self.calls:
            raise RuntimeError("student backend failure")

        token_ids = kwargs["input_ids"]
        topk_ids = [[tok, tok + 1] for tok in token_ids]
        topk_logprobs = [[-0.1, -0.2] for _ in token_ids]
        return {
            "prompt_tokens": token_ids,
            "prompt_topk_token_ids": topk_ids,
            "prompt_topk_logprobs": topk_logprobs,
        }

    @asynccontextmanager
    async def managed_server(self, tokenizer=None):
        """Count the entry and yield the shared fake managed server."""
        self.managed_calls += 1
        yield self.managed
|
||||
|
||||
|
||||
class _ConcreteStudentEnv(StudentDistillationEnv):
    """Minimal subclass defining ``get_next_item`` and ``evaluate`` as no-ops."""

    async def evaluate(self, *args, **kwargs):
        """No-op: the tests never evaluate."""
        return None

    async def get_next_item(self):
        """No-op: the tests never pull items."""
        return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_attach_student_distillation_success():
    """Two sequences are fetched and both distill arrays are populated."""
    student_env = object.__new__(_ConcreteStudentEnv)
    student_env.server = _FakeStudentServer()
    student_env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)

    scored_group = dict(
        tokens=[[1, 2, 3], [4, 5]],
        group_overrides=None,
        overrides=None,
        masks=[[-100, 2, 3], [-100, 5]],
        scores=[1.0, 0.0],
    )
    enriched = await StudentDistillationEnv._attach_student_distillation(
        student_env, scored_group
    )

    assert enriched["distill_token_ids"] is not None
    assert enriched["distill_logprobs"] is not None
    # One entry per sequence, one row per token position.
    assert len(enriched["distill_token_ids"]) == 2
    assert len(enriched["distill_token_ids"][0]) == 3
    assert len(enriched["distill_logprobs"][1]) == 2
    assert student_env.server.calls == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_attach_student_distillation_failure_drops_payload():
    """A backend error on one sequence nulls out both distill arrays."""
    student_env = object.__new__(_ConcreteStudentEnv)
    student_env.server = _FakeStudentServer(fail_on_call=2)
    student_env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)

    scored_group = dict(
        tokens=[[1, 2, 3], [4, 5]],
        group_overrides=None,
        overrides=None,
        masks=[[-100, 2, 3], [-100, 5]],
        scores=[1.0, 0.0],
    )
    enriched = await StudentDistillationEnv._attach_student_distillation(
        student_env, scored_group
    )

    assert enriched["distill_token_ids"] is None
    assert enriched["distill_logprobs"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_attach_student_distillation_negative_topk_skips_fetch():
    """A configured top_k of -1 disables fetching entirely."""
    student_env = object.__new__(_ConcreteStudentEnv)
    student_env.server = _FakeStudentServer()
    student_env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=-1)

    scored_group = dict(
        tokens=[[1, 2, 3]],
        group_overrides=None,
        overrides=None,
        masks=[[-100, 2, 3]],
        scores=[1.0],
    )
    enriched = await StudentDistillationEnv._attach_student_distillation(
        student_env, scored_group
    )

    # The server must never be contacted when fetching is disabled.
    assert student_env.server.calls == 0
    assert enriched["distill_token_ids"] is None
    assert enriched["distill_logprobs"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_attach_student_distillation_zero_topk_passthrough():
    """A top_k of 0 still triggers one fetch and yields a payload."""
    student_env = object.__new__(_ConcreteStudentEnv)
    student_env.server = _FakeStudentServer()
    student_env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=0)

    scored_group = dict(
        tokens=[[1, 2, 3]],
        group_overrides=None,
        overrides=None,
        masks=[[-100, 2, 3]],
        scores=[1.0],
    )
    enriched = await StudentDistillationEnv._attach_student_distillation(
        student_env, scored_group
    )

    assert student_env.server.calls == 1
    assert enriched["distill_token_ids"] is not None
    assert enriched["distill_logprobs"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_attach_student_distillation_group_override_topk_is_used():
    """The group-level ``student_top_k`` override replaces the config value."""
    student_env = object.__new__(_ConcreteStudentEnv)
    student_env.server = _FakeStudentServer()
    student_env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=0)

    scored_group = dict(
        tokens=[[1, 2, 3], [4, 5]],
        group_overrides={"student_top_k": 7},
        overrides=None,
        masks=[[-100, 2, 3], [-100, 5]],
        scores=[1.0, 0.0],
    )
    enriched = await StudentDistillationEnv._attach_student_distillation(
        student_env, scored_group
    )

    # Both per-sequence requests must carry the overridden value.
    recorded = student_env.server.kwargs
    assert recorded[0]["top_k"] == 7
    assert recorded[1]["top_k"] == 7
    assert enriched["distill_token_ids"] is not None
    assert enriched["distill_logprobs"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_attach_student_distillation_group_override_can_skip_fetch():
    """``skip_student_top_k`` in the group overrides suppresses the fetch."""
    student_env = object.__new__(_ConcreteStudentEnv)
    student_env.server = _FakeStudentServer()
    student_env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)

    scored_group = dict(
        tokens=[[1, 2, 3]],
        group_overrides={"skip_student_top_k": True},
        overrides=None,
        masks=[[-100, 2, 3]],
        scores=[1.0],
    )
    enriched = await StudentDistillationEnv._attach_student_distillation(
        student_env, scored_group
    )

    assert student_env.server.calls == 0
    assert enriched["distill_token_ids"] is None
    assert enriched["distill_logprobs"] is None
|
||||
|
||||
|
||||
def test_get_student_logprob_overrides_merges_group_and_sequence():
    """Sequence-level kwargs are layered over group-level kwargs."""
    student_env = object.__new__(_ConcreteStudentEnv)

    group_level = {"student_logprob_kwargs": {"temperature": 0.0, "prompt": "group"}}
    sequence_level = {"student_logprob_kwargs": {"prompt": "seq", "top_p": 1.0}}
    scored_group = {
        "group_overrides": group_level,
        "overrides": [sequence_level],
    }

    merged = StudentDistillationEnv._get_student_logprob_overrides(
        student_env, scored_group, 0
    )

    # "prompt" appears at both levels; the per-sequence value must win.
    assert merged == {"temperature": 0.0, "prompt": "seq", "top_p": 1.0}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fetch_student_for_sequence_uses_managed_server_for_messages():
    """A ``messages`` override routes the request through the managed server."""
    student_env = object.__new__(_ConcreteStudentEnv)
    student_env.tokenizer = object()
    student_env.server = _FakeStudentServer(managed_prompt_tokens=[1, 2, 3])

    chat = [{"role": "user", "content": "hi"}]
    fetched = await StudentDistillationEnv._fetch_student_for_sequence(
        student_env,
        token_ids=[1, 2, 3],
        top_k=2,
        extra_kwargs={"messages": chat},
    )

    fake_server = student_env.server
    # The plain endpoint is bypassed; the managed one is entered exactly once.
    assert fake_server.calls == 0
    assert fake_server.managed_calls == 1
    assert fake_server.managed.calls == 1
    assert "input_ids" not in fake_server.managed.kwargs[0]
    assert fetched[0] == [[1, 2], [2, 3], [3, 4]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fetch_student_for_sequence_mismatch_drops_payload():
    """Prompt tokens that differ from the group's tokens drop the payload."""
    student_env = object.__new__(_ConcreteStudentEnv)
    student_env.tokenizer = object()
    # The managed server reports [9, 9], which cannot match [1, 2, 3].
    student_env.server = _FakeStudentServer(managed_prompt_tokens=[9, 9])
    student_env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)

    chat = [{"role": "user", "content": "hi"}]
    scored_group = dict(
        tokens=[[1, 2, 3]],
        group_overrides={"student_logprob_kwargs": {"messages": chat}},
        overrides=None,
        masks=[[-100, 2, 3]],
        scores=[1.0],
    )

    enriched = await StudentDistillationEnv._attach_student_distillation(
        student_env, scored_group
    )

    assert enriched["distill_token_ids"] is None
    assert enriched["distill_logprobs"] is None
|
||||
Loading…
Add table
Add a link
Reference in a new issue