atropos/atroposlib/envs/student_distillation_env.py
Jai Suphavadeeprasit 01e25707b0 student
2026-03-17 12:02:48 -04:00

205 lines
7.2 KiB
Python

"""
Student/self-distillation environment layer.
This module adds prompt-logprob fetching from the student rollout server itself
before the scored group is sent to the API.
By default, the student server scores the exact token IDs already present in the
group. Override-driven prompt/message scoring is supported only when the
resulting prompt tokenization matches the original token sequence exactly.
"""
from __future__ import annotations
import asyncio
import logging
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import Field
from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
# Module-level logger; the env uses it to warn when a student logprob fetch
# fails or returns mismatched arrays.
logger = logging.getLogger(__name__)
class StudentDistillationConfig(BaseEnvConfig):
    """Environment config extended with student self-distillation options.

    ``student_top_k`` is only the default; a scored group may supply its own
    value through ``group_overrides["student_top_k"]``.
    """

    # Master switch: when False, scored groups are forwarded without any
    # student logprob request being made.
    student_distill_enabled: bool = Field(
        default=False,
        description="Whether to fetch prompt logprobs from the student server itself.",
    )
    # ``ge=-1`` makes -1 the only accepted "disabled" value at config level,
    # so the description below states exactly -1 rather than "<= -1".
    student_top_k: int = Field(
        default=0,
        ge=-1,
        description=(
            "Number of extra prompt logprobs to fetch beyond the selected token. "
            "Use 0 for selected-token-only prompt logprobs and -1 to disable "
            "student distillation fetching."
        ),
    )
class StudentDistillationEnv(BaseEnv, ABC):
    """
    BaseEnv subclass that enriches scored groups with self-distillation arrays.

    Before a scored group is forwarded to the API, every sequence in
    ``group["tokens"]`` is re-scored by the student server (``self.server``)
    and the returned prompt top-k token IDs / logprobs are attached to the
    group.

    Distillation payload shape:
    - distill_token_ids: [sequence][position][k]
    - distill_logprobs: [sequence][position][k]

    Any per-sequence failure drops the payload for the whole group (both keys
    set to ``None``) instead of failing the send.
    """

    env_config_cls = StudentDistillationConfig

    @staticmethod
    def _student_kwargs_from(overrides: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return the student request kwargs stored in ``overrides``, if any.

        ``student_logprob_kwargs`` takes precedence whenever that key is
        present; ``student_distill_kwargs`` is consulted only when it is absent.
        Non-dict values are treated as "no kwargs".
        """
        kwargs = overrides.get(
            "student_logprob_kwargs",
            overrides.get("student_distill_kwargs"),
        )
        return kwargs if isinstance(kwargs, dict) else None

    @staticmethod
    def _drop_distill_payload(group: ScoredDataGroup) -> ScoredDataGroup:
        """Mark ``group`` as carrying no distillation payload and return it."""
        group["distill_token_ids"] = None
        group["distill_logprobs"] = None
        return group

    def _get_student_logprob_overrides(
        self, group: ScoredDataGroup, seq_idx: int
    ) -> Dict[str, Any]:
        """Merge group-level and per-sequence student request kwargs.

        Args:
            group: Scored group; reads ``group_overrides`` and ``overrides``.
            seq_idx: Index of the sequence whose ``overrides`` entry to use.

        Returns:
            A new dict. Per-sequence kwargs win over group-level kwargs on key
            collisions.
        """
        merged: Dict[str, Any] = {}
        group_kwargs = self._student_kwargs_from(group.get("group_overrides") or {})
        if group_kwargs is not None:
            merged.update(group_kwargs)

        overrides = group.get("overrides") or []
        if seq_idx < len(overrides):
            seq_kwargs = self._student_kwargs_from(overrides[seq_idx] or {})
            if seq_kwargs is not None:
                merged.update(seq_kwargs)
        return merged

    async def _fetch_student_for_sequence(
        self,
        token_ids: List[int],
        top_k: int,
        extra_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[List[int]], List[List[float]]]:
        """Score ``token_ids`` with the student server and return prompt top-k.

        Args:
            token_ids: The exact token sequence already present in the group.
            top_k: Forwarded as ``top_k`` to ``get_logprobs``.
            extra_kwargs: Optional request overrides merged over the defaults.
                If they supply ``messages`` or ``prompt``, ``input_ids`` is
                removed so the override drives tokenization.

        Returns:
            ``(prompt_topk_token_ids, prompt_topk_logprobs)`` from the payload.

        Raises:
            ValueError: If the payload's ``prompt_tokens`` are missing or do
                not equal ``token_ids``.
            KeyError: Presumably, if the payload (assumed dict-like) lacks the
                prompt top-k fields.
        """
        request_kwargs: Dict[str, Any] = {
            "input_ids": token_ids,
            "top_k": top_k,
            "max_tokens": 1,
            "split": "train",
        }
        if extra_kwargs:
            request_kwargs.update(extra_kwargs)
            if (
                extra_kwargs.get("messages") is not None
                or extra_kwargs.get("prompt") is not None
            ):
                # Let message/prompt overrides drive tokenization instead of the
                # original input_ids when callers explicitly request it.
                request_kwargs.pop("input_ids", None)

        if request_kwargs.get("messages") is not None:
            # Message-based requests go through a managed server constructed
            # with this env's tokenizer.
            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                payload = await managed.get_logprobs(**request_kwargs)
        else:
            payload = await self.server.get_logprobs(**request_kwargs)

        prompt_tokens = payload.get("prompt_tokens")
        # Compare as lists so an equal tuple/sequence from the server is not
        # reported as a misalignment.
        if prompt_tokens is None or list(prompt_tokens) != list(token_ids):
            raise ValueError(
                "Student distillation request did not align to the original token "
                "sequence. Override-driven prompt/messages are only supported when "
                "they reproduce the exact same prompt tokens."
            )
        return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]

    async def _attach_student_distillation(
        self, group: ScoredDataGroup
    ) -> ScoredDataGroup:
        """Fetch student prompt top-k for every sequence and attach it to ``group``.

        The group is mutated in place and returned. When distillation is
        disabled in config, the group is returned untouched. Otherwise both
        ``distill_*`` keys are set to ``None`` if:
        - there are no sequences;
        - ``skip_student_top_k`` is set;
        - the effective top-k is <= -1;
        - any sequence fetch fails or returns mismatched arrays.
        """
        if not self.config.student_distill_enabled:
            return group

        seqs = group.get("tokens", [])
        if not seqs:
            return self._drop_distill_payload(group)

        group_overrides = group.get("group_overrides") or {}
        if group_overrides.get("skip_student_top_k", False):
            return self._drop_distill_payload(group)

        # An explicit ``None`` override falls back to the config default
        # instead of crashing in ``int(None)``.
        raw_top_k = group_overrides.get("student_top_k")
        top_k = int(self.config.student_top_k if raw_top_k is None else raw_top_k)
        if top_k <= -1:
            return self._drop_distill_payload(group)

        results = await asyncio.gather(
            *(
                self._fetch_student_for_sequence(
                    seq,
                    top_k,
                    self._get_student_logprob_overrides(group, idx),
                )
                for idx, seq in enumerate(seqs)
            ),
            return_exceptions=True,
        )

        distill_token_ids: List[List[List[int]]] = []
        distill_logprobs: List[List[List[float]]] = []
        for idx, result in enumerate(results):
            # gather(return_exceptions=True) can also yield BaseException
            # instances such as asyncio.CancelledError, which are not
            # Exception subclasses and would otherwise fail to unpack below.
            if isinstance(result, BaseException):
                logger.warning(
                    "Student logprob fetch failed for seq %s: %s. "
                    "Dropping distill payload for this group.",
                    idx,
                    result,
                )
                return self._drop_distill_payload(group)

            token_ids_k, logprobs_k = result
            if len(token_ids_k) != len(logprobs_k):
                logger.warning(
                    "Student prompt-topk length mismatch for seq %s (%s != %s). "
                    "Dropping distill payload for this group.",
                    idx,
                    len(token_ids_k),
                    len(logprobs_k),
                )
                return self._drop_distill_payload(group)

            distill_token_ids.append(token_ids_k)
            distill_logprobs.append(logprobs_k)

        group["distill_token_ids"] = distill_token_ids
        group["distill_logprobs"] = distill_logprobs
        return group

    async def handle_send_to_api(
        self,
        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
        item: Any = None,
        do_send_to_api: bool = True,
        abort_on_any_max_length_exceeded: bool = True,
    ):
        """Attach student distillation arrays, then delegate to ``BaseEnv``.

        For list input, ``None`` entries are dropped and the remaining groups
        are enriched one after another. For a single group, the enriched group
        is forwarded, or ``scored_data`` itself when it is ``None``.
        """
        is_batch = isinstance(scored_data, list)
        groups = scored_data if is_batch else [scored_data]

        enriched_groups: List[ScoredDataGroup] = []
        for group in groups:
            if group is not None:
                enriched_groups.append(await self._attach_student_distillation(group))

        payload: Union[ScoredDataGroup, List[ScoredDataGroup]]
        if is_batch:
            payload = enriched_groups
        else:
            payload = enriched_groups[0] if enriched_groups else scored_data

        return await super().handle_send_to_api(
            payload,
            item=item,
            do_send_to_api=do_send_to_api,
            abort_on_any_max_length_exceeded=abort_on_any_max_length_exceeded,
        )