mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
205 lines
7.2 KiB
Python
205 lines
7.2 KiB
Python
"""
Student/self-distillation environment layer.

This module adds prompt-logprob fetching from the student rollout server itself
before the scored group is sent to the API.

By default, the student server scores the exact token IDs already present in the
group. Override-driven prompt/message scoring is supported only when the
resulting prompt tokenization matches the original token sequence exactly.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from abc import ABC
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
from pydantic import Field
|
|
|
|
from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class StudentDistillationConfig(BaseEnvConfig):
    # Configuration knobs for fetching prompt logprobs from the student server.
    # (Documented with comments rather than a docstring so the model's
    # ``__doc__`` stays exactly as it was.)

    # Master switch: distillation fetching is off unless this is set to True.
    student_distill_enabled: bool = Field(
        description="Whether to fetch prompt logprobs from the student server itself.",
        default=False,
    )

    # Defaults to 0; field validation rejects anything below -1 (``ge=-1``).
    student_top_k: int = Field(
        description=(
            "Number of extra prompt logprobs to fetch beyond the selected token. Use 0 "
            "for selected-token-only prompt logprobs and <= -1 to disable student "
            "distillation fetching."
        ),
        ge=-1,
        default=0,
    )
|
|
|
|
|
|
class StudentDistillationEnv(BaseEnv, ABC):
    """
    BaseEnv subclass that enriches scored groups with self-distillation arrays.

    Distillation payload shape:
    - distill_token_ids: [sequence][position][k]
    - distill_logprobs: [sequence][position][k]
    """

    env_config_cls = StudentDistillationConfig

    @staticmethod
    def _drop_distill_payload(group: ScoredDataGroup) -> ScoredDataGroup:
        """Set both distillation keys on ``group`` to ``None`` and return it."""
        group["distill_token_ids"] = None
        group["distill_logprobs"] = None
        return group

    def _get_student_logprob_overrides(
        self, group: ScoredDataGroup, seq_idx: int
    ) -> Dict[str, Any]:
        """
        Collect extra logprob-request kwargs for one sequence of ``group``.

        Group-level kwargs (from ``group["group_overrides"]``) are applied
        first, then the per-sequence kwargs found at
        ``group["overrides"][seq_idx]`` are layered on top, so per-sequence
        values win on key collisions.

        At each level the kwargs are read from ``"student_logprob_kwargs"``;
        ``"student_distill_kwargs"`` is consulted only when the former key is
        absent. Values that are not dicts are ignored.

        Args:
            group: Scored group to read overrides from.
            seq_idx: Index of the sequence within the group.

        Returns:
            A new dict with the merged kwargs (empty when nothing applies).
        """
        merged: Dict[str, Any] = {}

        group_overrides = group.get("group_overrides") or {}
        group_kwargs = group_overrides.get(
            "student_logprob_kwargs",
            group_overrides.get("student_distill_kwargs"),
        )
        if isinstance(group_kwargs, dict):
            merged.update(group_kwargs)

        overrides = group.get("overrides") or []
        # Sequences beyond the end of the overrides list simply have none.
        if seq_idx < len(overrides):
            seq_overrides = overrides[seq_idx] or {}
            seq_kwargs = seq_overrides.get(
                "student_logprob_kwargs",
                seq_overrides.get("student_distill_kwargs"),
            )
            if isinstance(seq_kwargs, dict):
                merged.update(seq_kwargs)

        return merged

    async def _fetch_student_for_sequence(
        self,
        token_ids: List[int],
        top_k: int,
        extra_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[List[int]], List[List[float]]]:
        """
        Request prompt logprobs for a single token sequence.

        The base request carries ``input_ids``, ``top_k``, ``max_tokens=1`` and
        ``split="train"``; ``extra_kwargs`` may override any of these. When
        ``extra_kwargs`` supplies a non-None ``messages`` or ``prompt``,
        ``input_ids`` is removed from the request so that the override drives
        tokenization.

        Args:
            token_ids: Token IDs the response must align to.
            top_k: Value sent as the request's ``top_k``.
            extra_kwargs: Optional request overrides.

        Returns:
            The payload's ``prompt_topk_token_ids`` and
            ``prompt_topk_logprobs`` entries, in that order.

        Raises:
            ValueError: If the payload's ``prompt_tokens`` differ from
                ``token_ids``.
        """
        request_kwargs: Dict[str, Any] = {
            "input_ids": token_ids,
            "top_k": top_k,
            "max_tokens": 1,
            "split": "train",
        }
        if extra_kwargs:
            request_kwargs.update(extra_kwargs)
            if (
                extra_kwargs.get("messages") is not None
                or extra_kwargs.get("prompt") is not None
            ):
                # Let message/prompt overrides drive tokenization instead of the
                # original input_ids when callers explicitly request it.
                request_kwargs.pop("input_ids", None)

        if request_kwargs.get("messages") is not None:
            # Message-based requests go through a managed server built with
            # this environment's tokenizer.
            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                payload = await managed.get_logprobs(**request_kwargs)
        else:
            payload = await self.server.get_logprobs(**request_kwargs)

        # The returned arrays are only usable if they line up token-for-token
        # with the sequence already stored in the group.
        if payload.get("prompt_tokens") != token_ids:
            raise ValueError(
                "Student distillation request did not align to the original token "
                "sequence. Override-driven prompt/messages are only supported when "
                "they reproduce the exact same prompt tokens."
            )

        return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]

    async def _attach_student_distillation(
        self, group: ScoredDataGroup
    ) -> ScoredDataGroup:
        """
        Populate ``distill_token_ids`` / ``distill_logprobs`` on ``group``.

        The group is mutated in place and returned. When
        ``student_distill_enabled`` is false the group is returned untouched.
        Both keys are set to ``None`` when the group has no ``tokens``, when
        ``group_overrides["skip_student_top_k"]`` is truthy, when the effective
        top-k is ``<= -1``, or when any per-sequence fetch fails or returns
        token-id and logprob arrays of different lengths.

        Args:
            group: Scored group whose ``tokens`` sequences are to be scored.

        Returns:
            The same ``group`` object.
        """
        if not self.config.student_distill_enabled:
            return group

        seqs = group.get("tokens", [])
        if not seqs:
            return self._drop_distill_payload(group)

        group_overrides = group.get("group_overrides") or {}
        if group_overrides.get("skip_student_top_k", False):
            return self._drop_distill_payload(group)

        # A group-level override takes precedence over the configured top-k.
        top_k = int(group_overrides.get("student_top_k", self.config.student_top_k))
        if top_k <= -1:
            return self._drop_distill_payload(group)

        tasks = [
            self._fetch_student_for_sequence(
                seq,
                top_k,
                self._get_student_logprob_overrides(group, idx),
            )
            for idx, seq in enumerate(seqs)
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        distill_token_ids: List[List[List[int]]] = []
        distill_logprobs: List[List[List[float]]] = []
        for idx, result in enumerate(results):
            # Check BaseException, not Exception: a cancelled child fetch is
            # reported by gather() as an asyncio.CancelledError instance, which
            # does not derive from Exception and would otherwise reach the
            # tuple-unpacking below and fail with an unrelated TypeError.
            if isinstance(result, BaseException):
                logger.warning(
                    "Student logprob fetch failed for seq %s: %s. "
                    "Dropping distill payload for this group.",
                    idx,
                    result,
                )
                return self._drop_distill_payload(group)

            token_ids_k, logprobs_k = result
            if len(token_ids_k) != len(logprobs_k):
                logger.warning(
                    "Student prompt-topk length mismatch for seq %s (%s != %s). "
                    "Dropping distill payload for this group.",
                    idx,
                    len(token_ids_k),
                    len(logprobs_k),
                )
                return self._drop_distill_payload(group)

            distill_token_ids.append(token_ids_k)
            distill_logprobs.append(logprobs_k)

        group["distill_token_ids"] = distill_token_ids
        group["distill_logprobs"] = distill_logprobs
        return group

    async def handle_send_to_api(
        self,
        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
        item: Any = None,
        do_send_to_api: bool = True,
        abort_on_any_max_length_exceeded: bool = True,
    ):
        """
        Attach student distillation data, then defer to the parent handler.

        Each non-None group is passed through
        ``_attach_student_distillation`` one after another. If ``scored_data``
        was a list, the parent receives the list of enriched groups (``None``
        entries removed); otherwise it receives the single enriched group, or
        the original ``scored_data`` when that single value was ``None``.

        Args:
            scored_data: One scored group or a list of them.
            item: Forwarded unchanged to the parent implementation.
            do_send_to_api: Forwarded unchanged to the parent implementation.
            abort_on_any_max_length_exceeded: Forwarded unchanged to the parent
                implementation.

        Returns:
            Whatever the parent ``handle_send_to_api`` returns.
        """
        groups = scored_data if isinstance(scored_data, list) else [scored_data]
        enriched_groups: List[ScoredDataGroup] = []
        for group in groups:
            if group is None:
                continue
            enriched_groups.append(await self._attach_student_distillation(group))

        payload: Union[ScoredDataGroup, List[ScoredDataGroup]]
        if isinstance(scored_data, list):
            payload = enriched_groups
        else:
            payload = enriched_groups[0] if enriched_groups else scored_data

        return await super().handle_send_to_api(
            payload,
            item=item,
            do_send_to_api=do_send_to_api,
            abort_on_any_max_length_exceeded=abort_on_any_max_length_exceeded,
        )
|