atropos/atroposlib/envs/teacher_distillation_env.py
Jai Suphavadeeprasit a8cdb53a4d address problems
2026-03-13 16:12:05 -04:00

246 lines
9.1 KiB
Python

"""
Teacher distillation environment layer.
This module adds teacher prompt-logprob fetching on top of BaseEnv without
modifying BaseEnv transport behavior.
This implementation supports same-tokenizer distillation only. The teacher and
student must share the same tokenizer vocabulary so the student's token IDs can
be forwarded directly to the teacher and the returned teacher top-k token IDs
can be looked up directly in the student's logits.
"""
from __future__ import annotations
import asyncio
import logging
from abc import ABC
from typing import Any, List, Optional, Tuple, Union
from pydantic import Field
from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
from .server_handling.server_baseline import APIServerConfig, ServerBaseline
from .server_handling.server_manager import ServerManager
logger = logging.getLogger(__name__)
class TeacherDistillationConfig(BaseEnvConfig):
teacher_enabled: bool = Field(
default=False,
description="Whether to fetch teacher prompt logprobs for distillation.",
)
teacher_server: Optional[APIServerConfig] = Field(
default=None,
description="Fallback teacher server configuration when not provided at init.",
)
teacher_top_k: int = Field(
default=0,
ge=-1,
description=(
"Number of extra prompt logprobs to fetch beyond the selected token. "
"Use 0 for selected-token-only prompt logprobs and <= -1 to disable "
"teacher fetching."
),
)
class TeacherDistillationEnv(BaseEnv, ABC):
"""
BaseEnv subclass that enriches scored groups with teacher distillation arrays.
Distillation payload shape:
- distill_token_ids: [sequence][position][k] (student vocab IDs)
- distill_logprobs: [sequence][position][k]
"""
env_config_cls = TeacherDistillationConfig
def __init__(
self,
config: TeacherDistillationConfig,
server_configs: Union[ServerBaseline, List[APIServerConfig]],
teacher_server_configs: Optional[
Union[ServerBaseline, List[APIServerConfig]]
] = None,
slurm: bool = False,
testing: bool = False,
):
super().__init__(config, server_configs, slurm=slurm, testing=testing)
self.teacher_server: Optional[ServerManager] = None
if config.teacher_enabled:
teacher_config_source = teacher_server_configs
if teacher_config_source is None and config.teacher_server is not None:
teacher_config_source = [
config.teacher_server.model_copy(
update={
"tokenizer_name": (
config.teacher_server.model_name
if config.teacher_server.tokenizer_name in ("", "none")
else config.teacher_server.tokenizer_name
),
"timeout": 1200,
}
)
]
if teacher_config_source is None:
raise ValueError(
"teacher_enabled=True requires teacher_server_configs at init "
"or a fallback teacher_server config."
)
self.teacher_server = ServerManager(
teacher_config_source,
slurm=False,
testing=False,
)
if isinstance(teacher_config_source, list):
teacher_cfg = teacher_config_source[0]
else:
teacher_cfg = teacher_config_source
teacher_tokenizer_name = (
teacher_cfg.model_name
if getattr(teacher_cfg, "tokenizer_name", "none") in ("", "none")
else teacher_cfg.tokenizer_name
)
self._validate_teacher_tokenizer_compatibility(teacher_tokenizer_name)
# ------------------------------------------------------------------
# Core fetch
# ------------------------------------------------------------------
def _validate_teacher_tokenizer_compatibility(
self, teacher_tokenizer_name: str
) -> None:
student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
if student_tok_name == teacher_tokenizer_name:
return
try:
from transformers import AutoTokenizer
teacher_tokenizer = AutoTokenizer.from_pretrained(
teacher_tokenizer_name, use_fast=True
)
except Exception as exc:
raise ValueError(
"Cross-tokenizer distillation is not supported in this PR, and the "
f"teacher tokenizer for '{teacher_tokenizer_name}' could not be loaded to "
f"verify compatibility: {exc}"
) from exc
student_vocab = self.tokenizer.get_vocab()
teacher_vocab = teacher_tokenizer.get_vocab()
if student_vocab != teacher_vocab:
raise ValueError(
"Cross-tokenizer distillation is not supported in this PR. "
f"Student tokenizer '{student_tok_name or type(self.tokenizer).__name__}' "
f"and teacher tokenizer '{teacher_tokenizer_name}' do not match."
)
async def _fetch_teacher_for_sequence(
self, token_ids: List[int], top_k: int
) -> Tuple[List[List[int]], List[List[float]]]:
assert self.teacher_server is not None
payload = await self.teacher_server.get_logprobs(
input_ids=token_ids,
top_k=top_k,
max_tokens=1,
split="train",
)
return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]
# ------------------------------------------------------------------
# Group enrichment
# ------------------------------------------------------------------
async def _attach_teacher_distillation(
self, group: ScoredDataGroup
) -> ScoredDataGroup:
if not self.config.teacher_enabled or self.teacher_server is None:
return group
seqs = group.get("tokens", [])
if not seqs:
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
group_overrides = group.get("group_overrides") or {}
if group_overrides.get("skip_teacher_top_k", False):
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
top_k = int(
group_overrides.get("teacher_top_k", self.config.teacher_top_k)
)
if top_k <= -1:
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
tasks = [self._fetch_teacher_for_sequence(seq, top_k) for seq in seqs]
results = await asyncio.gather(*tasks, return_exceptions=True)
distill_token_ids: List[List[List[int]]] = []
distill_logprobs: List[List[List[float]]] = []
for idx, result in enumerate(results):
if isinstance(result, Exception):
logger.warning(
"Teacher logprob fetch failed for seq %s: %s. "
"Dropping distill payload for this group.",
idx,
result,
)
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
token_ids_k, logprobs_k = result
if len(token_ids_k) != len(logprobs_k):
logger.warning(
"Teacher prompt-topk length mismatch for seq %s (%s != %s). "
"Dropping distill payload for this group.",
idx,
len(token_ids_k),
len(logprobs_k),
)
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
distill_token_ids.append(token_ids_k)
distill_logprobs.append(logprobs_k)
group["distill_token_ids"] = distill_token_ids
group["distill_logprobs"] = distill_logprobs
return group
async def handle_send_to_api(
self,
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
item: Any = None,
do_send_to_api: bool = True,
abort_on_any_max_length_exceeded: bool = True,
):
groups = scored_data if isinstance(scored_data, list) else [scored_data]
enriched_groups: List[ScoredDataGroup] = []
for group in groups:
if group is None:
continue
enriched_groups.append(await self._attach_teacher_distillation(group))
payload: Union[ScoredDataGroup, List[ScoredDataGroup]]
if isinstance(scored_data, list):
payload = enriched_groups
else:
payload = enriched_groups[0] if enriched_groups else scored_data
return await super().handle_send_to_api(
payload,
item=item,
do_send_to_api=do_send_to_api,
abort_on_any_max_length_exceeded=abort_on_any_max_length_exceeded,
)