atropos/atroposlib/envs/teacher_distillation_env.py
Jai Suphavadeeprasit 78c0a6d082 tokenizer bug
2026-03-13 11:06:02 -04:00

429 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Teacher distillation environment layer.
This module adds teacher prompt-logprob fetching on top of BaseEnv without
modifying BaseEnv transport behavior.
Cross-tokenizer distillation
----------------------------
When student and teacher use the same tokenizer family (e.g. both Qwen3) the
student's raw token IDs can be forwarded directly to the teacher vLLM and the
returned top-k token IDs can be used as-is in the student logit lookup.
When tokenizers differ (e.g. Llama student, Qwen teacher) two problems arise:
1. Token-ID aliasing: student token 42 = " the" in Llama, but 42 = "" in
Qwen. Sending student IDs to the teacher causes it to score garbage.
2. Vocab-space mismatch: the teacher's top-k IDs live in the teacher's
vocabulary. The student logit lookup at those IDs would access random
tokens in the student vocab.
This module fixes both problems automatically:
• Re-tokenization: student tokens are decoded to plain text and
re-tokenized with the teacher tokenizer before being sent to the teacher
server. The teacher therefore always scores the correct text.
• Character-level position alignment: after re-tokenization the teacher
sequence may have a different number of tokens than the student sequence.
A character-offset map is built (requires a fast HuggingFace tokenizer) to
project each teacher logprob position back onto the student token(s) it
overlaps with.
• Vocabulary remapping: teacher top-k token IDs (teacher vocab) are
decoded to text fragments and re-encoded with the student tokenizer so
that the final distill_token_ids live in the student vocabulary and can
be looked up directly in the student logit tensor.
Same-tokenizer fast path
------------------------
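When the resolved teacher tokenizer name is identical to the student
tokenizer's name (the check compares names, not vocabulary contents), the
original fast path (no decode / re-tokenize / remap) is taken automatically.
The fast path is also used when the student tokenizer exposes no name or the
teacher tokenizer cannot be loaded.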
"""
from __future__ import annotations
import asyncio
import logging
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import Field
from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
from .server_handling.server_baseline import APIServerConfig, ServerBaseline
from .server_handling.server_manager import ServerManager
logger = logging.getLogger(__name__)
class TeacherDistillationConfig(BaseEnvConfig):
    # Settings for the optional teacher-distillation layer.
    #
    # All fields are declared with pydantic ``Field`` so they carry a default
    # and a human-readable description. TeacherDistillationEnv.__init__ rejects
    # teacher_enabled=True unless teacher_base_url and teacher_model_name are
    # both set.

    # Master switch: when False no teacher server is created and scored groups
    # pass through unchanged.
    teacher_enabled: bool = Field(
        False,
        description="Whether to fetch teacher prompt logprobs for distillation.",
    )

    # Connection details for the teacher endpoint.
    teacher_base_url: Optional[str] = Field(
        None,
        description="Teacher server base URL (OpenAI-compatible).",
    )
    teacher_model_name: Optional[str] = Field(
        None,
        description="Teacher model name used in teacher server requests.",
    )
    teacher_api_key: str = Field(
        "",
        description="Teacher API key, if required by the teacher endpoint.",
    )
    teacher_server_type: str = Field(
        "vllm",
        description="Teacher server type (e.g. vllm, sglang, trl, openai).",
    )

    # "none" (or an empty string) means "reuse teacher_model_name as the
    # tokenizer name".
    teacher_tokenizer_name: str = Field(
        "none",
        description=(
            "Tokenizer name for teacher server. If 'none', teacher_model_name is used. "
            "When this resolves to a different vocabulary than the student tokenizer, "
            "cross-tokenizer alignment is applied automatically."
        ),
    )

    # Number of teacher candidates requested per token position (at least 1).
    teacher_top_k: int = Field(
        1,
        ge=1,
        description="Top-k prompt logprobs to fetch per token position.",
    )
class TeacherDistillationEnv(BaseEnv, ABC):
    """
    BaseEnv subclass that enriches scored groups with teacher distillation arrays.

    Distillation payload shape:
    - distill_token_ids: [sequence][position][k] (student vocab IDs)
    - distill_logprobs: [sequence][position][k]

    In cross-tokenizer mode a position's inner list may hold fewer than k
    entries, or be empty, when no teacher token overlaps that student token or
    when a teacher candidate has no student-vocabulary counterpart.
    Both keys are set to None when the group has no "tokens" or when fetching
    fails for any sequence in the group.
    """

    env_config_cls = TeacherDistillationConfig

    # Sentinel returned by _remap_teacher_token_to_student when a teacher token
    # cannot be expressed in the student vocabulary. Negative so it can never
    # collide with a real token ID.
    _UNMAPPED_TOKEN_ID = -1

    def __init__(
        self,
        config: TeacherDistillationConfig,
        server_configs: Union[ServerBaseline, List[APIServerConfig]],
        slurm: bool = False,
        testing: bool = False,
    ):
        """
        Initialise the base environment and, if enabled, the teacher server.

        Args:
            config: Environment configuration including the teacher_* fields.
            server_configs: Forwarded unchanged to BaseEnv.__init__.
            slurm: Forwarded unchanged to BaseEnv.__init__.
            testing: Forwarded unchanged to BaseEnv.__init__.

        Raises:
            ValueError: If teacher_enabled is True but teacher_base_url or
                teacher_model_name is missing/empty.
        """
        super().__init__(config, server_configs, slurm=slurm, testing=testing)
        self.teacher_server: Optional[ServerManager] = None
        # Teacher tokenizer (only loaded when tokenizers may differ).
        self._teacher_tokenizer = None
        # True when student and teacher are treated as sharing a vocabulary.
        self._same_tokenizer: bool = True
        # Unbounded memo: teacher_token_id -> student_token_id (or
        # _UNMAPPED_TOKEN_ID). Entries are never evicted, so its size is bounded
        # only by the number of distinct teacher IDs seen.
        self._vocab_remap_cache: Dict[int, int] = {}

        if not config.teacher_enabled:
            return

        if not config.teacher_base_url or not config.teacher_model_name:
            raise ValueError(
                "teacher_enabled=True requires teacher_base_url and teacher_model_name."
            )

        # "none"/"" means: use the teacher model name as the tokenizer name.
        teacher_tok_name = (
            config.teacher_model_name
            if config.teacher_tokenizer_name in ("none", "")
            else config.teacher_tokenizer_name
        )
        teacher_cfg = APIServerConfig(
            server_type=config.teacher_server_type,  # type: ignore[arg-type]
            base_url=config.teacher_base_url,
            api_key=config.teacher_api_key,
            model_name=config.teacher_model_name,
            tokenizer_name=teacher_tok_name,
            timeout=1200,
        )
        self.teacher_server = ServerManager(
            [teacher_cfg],
            slurm=False,
            testing=False,
        )

        # Detect vocabulary mismatch. This is a name comparison only: two
        # different names that share a vocabulary still take the slow path, and
        # a student tokenizer without a name always takes the fast path.
        # NOTE(review): assumes BaseEnv.__init__ has set self.tokenizer — confirm.
        student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
        if student_tok_name and teacher_tok_name and student_tok_name != teacher_tok_name:
            try:
                # Imported lazily so transformers is only needed in
                # cross-tokenizer mode.
                from transformers import AutoTokenizer

                teacher_tokenizer = AutoTokenizer.from_pretrained(
                    teacher_tok_name, use_fast=True
                )
            except Exception as exc:
                # Deliberate best-effort: keep running on the fast path rather
                # than failing environment start-up.
                logger.warning(
                    "TeacherDistillationEnv: could not load teacher tokenizer '%s' "
                    "(%s). Falling back to same-tokenizer (fast) path — only safe if "
                    "student and teacher share the same vocabulary.",
                    teacher_tok_name,
                    exc,
                )
            else:
                self._teacher_tokenizer = teacher_tokenizer
                self._same_tokenizer = False
                logger.warning(
                    "TeacherDistillationEnv: cross-tokenizer mode active. "
                    "student=%s teacher=%s. "
                    "Token IDs will be decoded → re-tokenized → vocab-remapped.",
                    student_tok_name,
                    teacher_tok_name,
                )

        logger.warning(
            "TeacherDistillationEnv: teacher server configured at %s "
            "(model=%s, top_k=%s, same_tokenizer=%s)",
            config.teacher_base_url,
            config.teacher_model_name,
            config.teacher_top_k,
            self._same_tokenizer,
        )

    # ------------------------------------------------------------------
    # Cross-tokenizer helpers
    # ------------------------------------------------------------------
    def _build_student_teacher_alignment(
        self,
        text: str,
        student_ids: List[int],
        teacher_ids: List[int],
    ) -> List[List[int]]:
        """
        For each student token position return the list of teacher token positions
        whose character spans overlap with the student token's character span.

        Both tokenizers are re-run on ``text`` with return_offsets_mapping=True;
        the resulting offsets are truncated to len(student_ids) and
        len(teacher_ids) respectively. If re-tokenizing yields fewer student
        tokens than len(student_ids), the returned list is correspondingly
        shorter. Student tokens with an empty character span get an empty list.

        If anything in the offset-based path raises (e.g. the tokenizer does
        not support offset mapping), a proportional approximation is returned
        instead: student position i maps to teacher position
        int(i * teacher_len / student_len).

        Args:
            text: The decoded text both token sequences represent.
            student_ids: Student token IDs (only the length is used).
            teacher_ids: Teacher token IDs (only the length is used).

        Returns:
            One list of teacher positions per student position.
        """
        student_len = len(student_ids)
        teacher_len = len(teacher_ids)
        try:
            s_enc = self.tokenizer(
                text, return_offsets_mapping=True, add_special_tokens=False
            )
            t_enc = self._teacher_tokenizer(
                text, return_offsets_mapping=True, add_special_tokens=False
            )
            s_offsets: List[Tuple[int, int]] = s_enc["offset_mapping"][:student_len]
            t_offsets: List[Tuple[int, int]] = t_enc["offset_mapping"][:teacher_len]

            alignment: List[List[int]] = []
            for s_start, s_end in s_offsets:
                if s_end <= s_start:
                    # Zero-width student span: nothing can overlap it.
                    alignment.append([])
                    continue
                alignment.append(
                    [
                        t_idx
                        for t_idx, (t_start, t_end) in enumerate(t_offsets)
                        if t_start < s_end and t_end > s_start
                    ]
                )
            return alignment
        except Exception as exc:
            logger.warning(
                "TeacherDistillationEnv: offset-mapping alignment failed (%s). "
                "Using proportional fallback.",
                exc,
            )
            # max(..., 1) avoids division by zero for an empty student sequence.
            ratio = teacher_len / max(student_len, 1)
            return [[int(i * ratio)] for i in range(student_len)]

    def _remap_teacher_token_to_student(self, teacher_token_id: int) -> int:
        """
        Convert a teacher vocabulary token ID to a student vocabulary token ID.

        The teacher token is decoded to text and re-encoded with the student
        tokenizer; the first resulting student token is used as the
        representative. Results (including failures) are memoised in
        self._vocab_remap_cache.

        Returns:
            The student token ID, or _UNMAPPED_TOKEN_ID when decoding/encoding
            raises or the student encoding is empty. The raw teacher ID is
            never returned as a fallback: it belongs to a different vocabulary
            and would alias an unrelated (or out-of-range) student token.
        """
        cached = self._vocab_remap_cache.get(teacher_token_id)
        if cached is not None:
            return cached

        try:
            text = self._teacher_tokenizer.decode(
                [teacher_token_id], clean_up_tokenization_spaces=False
            )
            student_ids = self.tokenizer.encode(text, add_special_tokens=False)
            # Use the first student token as the representative.
            sid = int(student_ids[0]) if student_ids else self._UNMAPPED_TOKEN_ID
        except Exception as exc:
            logger.debug(
                "TeacherDistillationEnv: could not remap teacher token %s (%s).",
                teacher_token_id,
                exc,
            )
            sid = self._UNMAPPED_TOKEN_ID

        self._vocab_remap_cache[teacher_token_id] = sid
        return sid

    def _align_and_remap(
        self,
        student_ids: List[int],
        teacher_topk_ids: List[List[int]],
        teacher_topk_lps: List[List[float]],
        alignment: List[List[int]],
    ) -> Tuple[List[List[int]], List[List[float]]]:
        """
        Project teacher logprobs (teacher positions, teacher vocab) onto
        student positions in student vocab.

        For each student token position:
        1. Collect all teacher top-k entries from overlapping teacher positions.
        2. Remap each teacher token ID to the student vocab, dropping entries
           that have no student counterpart.
        3. Merge duplicates by keeping the maximum logprob.
        4. Return the top-k entries sorted by descending logprob.

        k is the widest teacher row in the payload (at least 1). Row 0 alone is
        not used because it may be empty while later rows are full.

        Args:
            student_ids: Student token IDs (only the length is used).
            teacher_topk_ids: Per teacher position, candidate teacher token IDs.
            teacher_topk_lps: Per teacher position, logprobs parallel to the IDs.
            alignment: Per student position, overlapping teacher positions.

        Returns:
            (ids, logprobs), each with one inner list per student position.
            Positions with no aligned teacher tokens get empty lists.
        """
        k = max(1, max((len(row) for row in teacher_topk_ids if row), default=0))
        n_teacher = len(teacher_topk_ids)
        n_aligned = len(alignment)
        unmapped = self._UNMAPPED_TOKEN_ID

        result_ids: List[List[int]] = []
        result_lps: List[List[float]] = []
        for s_idx in range(len(student_ids)):
            t_positions = alignment[s_idx] if s_idx < n_aligned else []
            if not t_positions:
                result_ids.append([])
                result_lps.append([])
                continue

            # Merge all overlapping teacher positions, remap vocab.
            merged: Dict[int, float] = {}
            for t_idx in t_positions:
                if t_idx >= n_teacher:
                    continue
                # `or ()` tolerates a None row for positions without logprobs.
                row_ids = teacher_topk_ids[t_idx] or ()
                row_lps = teacher_topk_lps[t_idx] or ()
                for tid, tlp in zip(row_ids, row_lps):
                    sid = self._remap_teacher_token_to_student(tid)
                    if sid == unmapped:
                        continue
                    merged[sid] = max(merged.get(sid, -1e9), tlp)

            top = sorted(merged.items(), key=lambda item: item[1], reverse=True)[:k]
            result_ids.append([int(sid) for sid, _ in top])
            result_lps.append([float(lp) for _, lp in top])
        return result_ids, result_lps

    # ------------------------------------------------------------------
    # Core fetch
    # ------------------------------------------------------------------
    async def _fetch_teacher_for_sequence(
        self, token_ids: List[int], top_k: int
    ) -> Tuple[List[List[int]], List[List[float]]]:
        """
        Fetch teacher top-k prompt logprobs for one student token sequence.

        Args:
            token_ids: Student token IDs for the sequence.
            top_k: Number of candidates requested per position.

        Returns:
            (token_ids, logprobs). On the fast path these are the teacher
            payload's "prompt_topk_token_ids" / "prompt_topk_logprobs" as-is;
            on the cross-tokenizer path they are aligned to student positions
            and remapped to the student vocabulary.

        Raises:
            RuntimeError: If no teacher server has been configured.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.teacher_server is None:
            raise RuntimeError(
                "Teacher server is not configured; cannot fetch teacher logprobs."
            )

        if self._same_tokenizer or self._teacher_tokenizer is None:
            # Fast path: same vocabulary — send student IDs directly.
            payload = await self.teacher_server.get_logprobs(
                input_ids=token_ids,
                top_k=top_k,
                max_tokens=1,
                split="train",
            )
            return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]

        # Cross-tokenizer path:
        # 1. Decode student tokens → plain text.
        # 2. Re-tokenize with teacher tokenizer → teacher IDs.
        # 3. Send teacher IDs to the teacher server.
        # 4. Align teacher positions → student positions.
        # 5. Remap teacher vocab IDs → student vocab IDs.
        text = self.tokenizer.decode(token_ids, clean_up_tokenization_spaces=False)
        teacher_ids: List[int] = self._teacher_tokenizer.encode(
            text, add_special_tokens=False
        )
        payload = await self.teacher_server.get_logprobs(
            input_ids=teacher_ids,
            top_k=top_k,
            max_tokens=1,
            split="train",
        )
        teacher_topk_ids = payload["prompt_topk_token_ids"]
        teacher_topk_lps = payload["prompt_topk_logprobs"]
        alignment = self._build_student_teacher_alignment(text, token_ids, teacher_ids)
        return self._align_and_remap(
            token_ids, teacher_topk_ids, teacher_topk_lps, alignment
        )

    # ------------------------------------------------------------------
    # Group enrichment
    # ------------------------------------------------------------------
    async def _attach_teacher_distillation(
        self, group: ScoredDataGroup
    ) -> ScoredDataGroup:
        """
        Add "distill_token_ids" / "distill_logprobs" to ``group`` in place.

        All sequences in group["tokens"] are fetched concurrently. The payload
        is all-or-nothing per group: if any sequence fails, or its two arrays
        differ in length, both keys are set to None. The group is returned
        untouched when the teacher is disabled or not configured.

        A per-group "teacher_top_k" entry in group["group_overrides"] takes
        precedence over config.teacher_top_k; the value is clamped to >= 1.

        Args:
            group: The scored group to enrich (mutated and returned).

        Returns:
            The same ``group`` object.
        """
        if not self.config.teacher_enabled or self.teacher_server is None:
            return group

        seqs = group.get("tokens", [])
        if not seqs:
            group["distill_token_ids"] = None
            group["distill_logprobs"] = None
            return group

        overrides = group.get("group_overrides") or {}
        top_k = max(1, int(overrides.get("teacher_top_k", self.config.teacher_top_k)))

        results = await asyncio.gather(
            *(self._fetch_teacher_for_sequence(seq, top_k) for seq in seqs),
            return_exceptions=True,
        )

        distill_token_ids: List[List[List[int]]] = []
        distill_logprobs: List[List[List[float]]] = []
        for idx, result in enumerate(results):
            # BaseException, not Exception: a cancelled child comes back as
            # asyncio.CancelledError, which is not an Exception subclass and
            # would otherwise blow up on the tuple-unpack below.
            if isinstance(result, BaseException):
                logger.warning(
                    "Teacher logprob fetch failed for seq %s: %s. "
                    "Dropping distill payload for this group.",
                    idx,
                    result,
                )
                group["distill_token_ids"] = None
                group["distill_logprobs"] = None
                return group

            token_ids_k, logprobs_k = result
            if len(token_ids_k) != len(logprobs_k):
                logger.warning(
                    "Teacher prompt-topk length mismatch for seq %s (%s != %s). "
                    "Dropping distill payload for this group.",
                    idx,
                    len(token_ids_k),
                    len(logprobs_k),
                )
                group["distill_token_ids"] = None
                group["distill_logprobs"] = None
                return group

            distill_token_ids.append(token_ids_k)
            distill_logprobs.append(logprobs_k)

        group["distill_token_ids"] = distill_token_ids
        group["distill_logprobs"] = distill_logprobs
        return group

    async def handle_send_to_api(
        self,
        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
        item: Any = None,
        do_send_to_api: bool = True,
        abort_on_any_max_length_exceeded: bool = True,
    ):
        """
        Enrich scored data with teacher distillation arrays, then delegate to
        the parent implementation.

        Groups are enriched one after another; None entries are skipped. A list
        input is forwarded as a list of the enriched groups; a single group is
        forwarded as a single group (or unchanged if it was None).

        Args:
            scored_data: One scored group or a list of them.
            item: Forwarded unchanged to the parent.
            do_send_to_api: Forwarded unchanged to the parent.
            abort_on_any_max_length_exceeded: Forwarded unchanged to the parent.

        Returns:
            Whatever the parent handle_send_to_api returns.
        """
        is_list = isinstance(scored_data, list)
        groups = scored_data if is_list else [scored_data]

        # Sequential awaits: one group's teacher requests finish before the
        # next group's begin.
        enriched_groups: List[ScoredDataGroup] = [
            await self._attach_teacher_distillation(group)
            for group in groups
            if group is not None
        ]

        payload: Union[ScoredDataGroup, List[ScoredDataGroup]]
        if is_list:
            payload = enriched_groups
        else:
            payload = enriched_groups[0] if enriched_groups else scored_data

        return await super().handle_send_to_api(
            payload,
            item=item,
            do_send_to_api=do_send_to_api,
            abort_on_any_max_length_exceeded=abort_on_any_max_length_exceeded,
        )