remove cross tokenization and fix location of configs

Jai Suphavadeeprasit 2026-03-13 13:19:28 -04:00
parent 148a4fd5eb
commit a1b545c734
6 changed files with 147 additions and 302 deletions


@@ -298,6 +298,47 @@ curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids!
- Trainers should validate alignment assumptions they require (sequence length, per-position top-k, etc.).
- Teacher-side architecture and prompt/rendering strategy are intentionally out of scope for this PR.
### TeacherDistillationEnv follow-up
The follow-up teacher environment uses a dedicated teacher server config and
attaches teacher prompt logprobs before the group is sent to the API.
Teacher config shape:
```python
TeacherDistillationConfig(
teacher_enabled=True,
teacher_server=APIServerConfig(
base_url="http://localhost:9003/v1",
model_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
api_key="",
server_type="vllm",
),
teacher_top_k=8,
)
```
CLI shape:
```bash
--env.teacher_enabled true \
--env.teacher_server.base_url "http://localhost:9003/v1" \
--env.teacher_server.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \
--env.teacher_server.server_type vllm \
--env.teacher_top_k 8
```
Tokenizer requirement:
- Teacher distillation currently requires the teacher and student to use the same tokenizer vocabulary.
- If the tokenizers do not match, `TeacherDistillationEnv` raises an error instead of attempting token conversion.
Why same-tokenizer is required:
- `distill_token_ids` are consumed as student-vocabulary IDs by the trainer.
- If the teacher uses a different vocabulary, the same integer token ID refers to different text on the teacher and student sides.
- A decode/re-tokenize/remap pipeline is not a safe drop-in fix because it changes both token positions and token identities, which breaks the exact per-position token supervision that the current distillation loss assumes.
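The aliasing problem described above can be illustrated with a minimal sketch. It works on plain `token -> id` dictionaries of the kind `get_vocab()` returns; the helper names `vocabularies_match` and `aliased_ids` and the toy vocabularies are hypothetical and are not part of `TeacherDistillationEnv`.

```python
from typing import Dict, Tuple


def vocabularies_match(
    student_vocab: Dict[str, int], teacher_vocab: Dict[str, int]
) -> bool:
    """Return True only if every token string maps to the same ID on both sides."""
    return student_vocab == teacher_vocab


def aliased_ids(
    student_vocab: Dict[str, int], teacher_vocab: Dict[str, int]
) -> Dict[int, Tuple[str, str]]:
    """Return IDs present in both vocabularies that name different text."""
    student_by_id = {idx: tok for tok, idx in student_vocab.items()}
    teacher_by_id = {idx: tok for tok, idx in teacher_vocab.items()}
    shared_ids = student_by_id.keys() & teacher_by_id.keys()
    return {
        idx: (student_by_id[idx], teacher_by_id[idx])
        for idx in shared_ids
        if student_by_id[idx] != teacher_by_id[idx]
    }


# Toy vocabularies (assumed for illustration): same size, different ID assignment.
student = {"the": 0, "cat": 1, "sat": 2}
teacher = {"the": 0, "sat": 1, "cat": 2}

match = vocabularies_match(student, teacher)
conflicts = aliased_ids(student, teacher)
```

In this toy case the two vocabularies have the same size but disagree on IDs 1 and 2, which is why comparing vocabulary sizes alone would not be a sufficient compatibility check.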
---
## Testing and Debugging Tools


@@ -4,42 +4,10 @@ Teacher distillation environment layer.
This module adds teacher prompt-logprob fetching on top of BaseEnv without
modifying BaseEnv transport behavior.
Cross-tokenizer distillation
----------------------------
When student and teacher use the same tokenizer family (e.g. both Qwen3) the
student's raw token IDs can be forwarded directly to the teacher vLLM and the
returned top-k token IDs can be used as-is in the student logit lookup.
When tokenizers differ (e.g. Llama student, Qwen teacher) two problems arise:
1. Token-ID aliasing: student token 42 = " the" in Llama, but 42 = "" in
Qwen. Sending student IDs to the teacher causes it to score garbage.
2. Vocab-space mismatch: the teacher's top-k IDs live in the teacher's
vocabulary. The student logit lookup at those IDs would access random
tokens in the student vocab.
This module fixes both problems automatically:
Re-tokenization student tokens are decoded to plain text and
re-tokenized with the teacher tokenizer before being sent to the teacher
server. The teacher therefore always scores the correct text.
Character-level position alignment after re-tokenisation the teacher
has a different number of tokens than the student. A character-offset
map is built (requires a fast HuggingFace tokenizer) to project each
teacher logprob position back onto the student token it overlaps with.
Vocabulary remapping teacher top-k token IDs (teacher vocab) are
decoded to text fragments and re-encoded with the student tokenizer so
that the final distill_token_ids live in the student vocabulary and can
be looked up directly in the student logit tensor.
Same-tokenizer fast path
------------------------
When teacher_tokenizer_name resolves to the same underlying vocabulary as the
student tokenizer the original fast path (no decode / re-tokenize / remap) is
taken automatically.
This implementation supports same-tokenizer distillation only. The teacher and
student must share the same tokenizer vocabulary so the student's token IDs can
be forwarded directly to the teacher and the returned teacher top-k token IDs
can be looked up directly in the student's logits.
"""
from __future__ import annotations
@@ -47,7 +15,7 @@ from __future__ import annotations
import asyncio
import logging
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union
from pydantic import Field
@@ -63,29 +31,9 @@ class TeacherDistillationConfig(BaseEnvConfig):
default=False,
description="Whether to fetch teacher prompt logprobs for distillation.",
)
teacher_base_url: Optional[str] = Field(
teacher_server: Optional[APIServerConfig] = Field(
default=None,
description="Teacher server base URL (OpenAI-compatible).",
)
teacher_model_name: Optional[str] = Field(
default=None,
description="Teacher model name used in teacher server requests.",
)
teacher_api_key: str = Field(
default="",
description="Teacher API key, if required by the teacher endpoint.",
)
teacher_server_type: str = Field(
default="vllm",
description="Teacher server type (e.g. vllm, sglang, trl, openai).",
)
teacher_tokenizer_name: str = Field(
default="none",
description=(
"Tokenizer name for teacher server. If 'none', teacher_model_name is used. "
"When this resolves to a different vocabulary than the student tokenizer, "
"cross-tokenizer alignment is applied automatically."
),
description="Teacher inference server configuration.",
)
teacher_top_k: int = Field(
default=1,
@@ -114,266 +62,71 @@ class TeacherDistillationEnv(BaseEnv, ABC):
):
super().__init__(config, server_configs, slurm=slurm, testing=testing)
self.teacher_server: Optional[ServerManager] = None
# Teacher tokenizer (only loaded when tokenizers may differ).
self._teacher_tokenizer = None
# True when student and teacher share the same vocabulary.
self._same_tokenizer: bool = True
# LRU-style cache: teacher_token_id -> student_token_id
self._vocab_remap_cache: Dict[int, int] = {}
if config.teacher_enabled:
if not config.teacher_base_url or not config.teacher_model_name:
if config.teacher_server is None:
raise ValueError(
"teacher_enabled=True requires teacher_base_url and teacher_model_name."
"teacher_enabled=True requires a teacher_server configuration."
)
teacher_tok_name = (
config.teacher_model_name
if config.teacher_tokenizer_name in ("none", "")
else config.teacher_tokenizer_name
)
teacher_cfg = APIServerConfig(
server_type=config.teacher_server_type, # type: ignore[arg-type]
base_url=config.teacher_base_url,
api_key=config.teacher_api_key,
model_name=config.teacher_model_name,
tokenizer_name=teacher_tok_name,
timeout=1200,
teacher_cfg = config.teacher_server.model_copy(
update={
"tokenizer_name": (
config.teacher_server.model_name
if config.teacher_server.tokenizer_name in ("", "none")
else config.teacher_server.tokenizer_name
),
"timeout": 1200,
}
)
self.teacher_server = ServerManager(
[teacher_cfg],
slurm=False,
testing=False,
)
# Detect vocabulary mismatch.
# Compare by name first; if names differ, load the teacher tokenizer
# and do a vocab-size sanity check. Same-family models (e.g. Qwen3-4B
# and Qwen3-30B) share the same vocabulary, so even though the
# name_or_path strings differ they should use the fast path.
student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
if (
student_tok_name
and teacher_tok_name
and student_tok_name != teacher_tok_name
):
try:
from transformers import AutoTokenizer
loaded = AutoTokenizer.from_pretrained(
teacher_tok_name, use_fast=True
)
student_vocab_size = getattr(self.tokenizer, "vocab_size", None)
teacher_vocab_size = getattr(loaded, "vocab_size", None)
if (
student_vocab_size is not None
and teacher_vocab_size is not None
and student_vocab_size == teacher_vocab_size
):
# Same vocab size — treat as same tokenizer (fast path).
# This covers same-family models (e.g. all Qwen3 variants).
self._same_tokenizer = True
logger.warning(
"TeacherDistillationEnv: names differ but vocab sizes match "
"(%d tokens). Using fast (same-tokenizer) path. "
"student=%s teacher=%s",
student_vocab_size,
student_tok_name,
teacher_tok_name,
)
else:
self._teacher_tokenizer = loaded
self._same_tokenizer = False
logger.warning(
"TeacherDistillationEnv: cross-tokenizer mode active. "
"student=%s (%s tokens) teacher=%s (%s tokens). "
"Token IDs will be decoded → re-tokenized → vocab-remapped.",
student_tok_name,
student_vocab_size,
teacher_tok_name,
teacher_vocab_size,
)
except Exception as exc:
logger.warning(
"TeacherDistillationEnv: could not load teacher tokenizer '%s' "
"(%s). Falling back to same-tokenizer (fast) path — only safe if "
"student and teacher share the same vocabulary.",
teacher_tok_name,
exc,
)
self._same_tokenizer = True
else:
self._same_tokenizer = True
logger.warning(
"TeacherDistillationEnv: teacher server configured at %s "
"(model=%s, top_k=%s, same_tokenizer=%s)",
config.teacher_base_url,
config.teacher_model_name,
config.teacher_top_k,
self._same_tokenizer,
)
# ------------------------------------------------------------------
# Cross-tokenizer helpers
# ------------------------------------------------------------------
def _build_student_teacher_alignment(
self,
text: str,
student_ids: List[int],
teacher_ids: List[int],
) -> List[List[int]]:
"""
For each student token position return the list of teacher token positions
whose character spans overlap with the student token's character span.
Requires fast (Rust-backed) HuggingFace tokenizers that support
return_offsets_mapping. Falls back to a proportional approximation
if offset mapping is unavailable.
"""
student_len = len(student_ids)
teacher_len = len(teacher_ids)
try:
s_enc = self.tokenizer(
text, return_offsets_mapping=True, add_special_tokens=False
)
t_enc = self._teacher_tokenizer(
text, return_offsets_mapping=True, add_special_tokens=False
)
s_offsets: List[Tuple[int, int]] = s_enc["offset_mapping"][:student_len]
t_offsets: List[Tuple[int, int]] = t_enc["offset_mapping"][:teacher_len]
alignment: List[List[int]] = []
for s_start, s_end in s_offsets:
overlapping = [
t_idx
for t_idx, (t_start, t_end) in enumerate(t_offsets)
if t_start < s_end and t_end > s_start and s_end > s_start
]
alignment.append(overlapping)
return alignment
except Exception as exc:
logger.warning(
"TeacherDistillationEnv: offset-mapping alignment failed (%s). "
"Using proportional fallback.",
exc,
)
ratio = teacher_len / max(student_len, 1)
return [[int(i * ratio)] for i in range(student_len)]
def _remap_teacher_token_to_student(self, teacher_token_id: int) -> int:
"""
Convert a teacher vocabulary token ID to the best-matching student
vocabulary token ID by decoding the teacher token to text then
re-encoding with the student tokenizer.
Results are cached to avoid repeated tokenizer calls.
"""
if teacher_token_id in self._vocab_remap_cache:
return self._vocab_remap_cache[teacher_token_id]
try:
text = self._teacher_tokenizer.decode(
[teacher_token_id], clean_up_tokenization_spaces=False
)
student_ids = self.tokenizer.encode(text, add_special_tokens=False)
# Use the first student token as the representative.
sid = int(student_ids[0]) if student_ids else teacher_token_id
except Exception:
sid = teacher_token_id
self._vocab_remap_cache[teacher_token_id] = sid
return sid
def _align_and_remap(
self,
student_ids: List[int],
teacher_topk_ids: List[List[int]],
teacher_topk_lps: List[List[float]],
alignment: List[List[int]],
) -> Tuple[List[List[int]], List[List[float]]]:
"""
Project teacher logprobs (teacher positions, teacher vocab) onto
student positions in student vocab.
For each student token position:
1. Collect all teacher top-k entries from overlapping teacher positions.
2. Remap each teacher token ID to the student vocab.
3. Merge duplicates by keeping the maximum logprob.
4. Return the top-k entries sorted by descending logprob.
"""
k = max(1, len(teacher_topk_ids[0]) if teacher_topk_ids else 1)
result_ids: List[List[int]] = []
result_lps: List[List[float]] = []
for s_idx in range(len(student_ids)):
t_positions = alignment[s_idx] if s_idx < len(alignment) else []
if not t_positions:
result_ids.append([])
result_lps.append([])
continue
# Merge all overlapping teacher positions, remap vocab.
merged: Dict[int, float] = {}
for t_idx in t_positions:
if t_idx >= len(teacher_topk_ids):
continue
for tid, tlp in zip(teacher_topk_ids[t_idx], teacher_topk_lps[t_idx]):
sid = self._remap_teacher_token_to_student(tid)
merged[sid] = max(merged.get(sid, -1e9), tlp)
sorted_items = sorted(merged.items(), key=lambda x: -x[1])
top = sorted_items[:k]
result_ids.append([int(sid) for sid, _ in top])
result_lps.append([float(lp) for _, lp in top])
return result_ids, result_lps
self._validate_teacher_tokenizer_compatibility(teacher_cfg.tokenizer_name)
# ------------------------------------------------------------------
# Core fetch
# ------------------------------------------------------------------
def _validate_teacher_tokenizer_compatibility(self, teacher_tokenizer_name: str) -> None:
student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
if student_tok_name == teacher_tokenizer_name:
return
try:
from transformers import AutoTokenizer
teacher_tokenizer = AutoTokenizer.from_pretrained(
teacher_tokenizer_name, use_fast=True
)
except Exception as exc:
raise ValueError(
"Cross-tokenizer distillation is not supported in this PR, and the "
f"teacher tokenizer for '{teacher_tokenizer_name}' could not be loaded to "
f"verify compatibility: {exc}"
) from exc
student_vocab = self.tokenizer.get_vocab()
teacher_vocab = teacher_tokenizer.get_vocab()
if student_vocab != teacher_vocab:
raise ValueError(
"Cross-tokenizer distillation is not supported in this PR. "
f"Student tokenizer '{student_tok_name or type(self.tokenizer).__name__}' "
f"and teacher tokenizer '{teacher_tokenizer_name}' do not match."
)
async def _fetch_teacher_for_sequence(
self, token_ids: List[int], top_k: int
) -> Tuple[List[List[int]], List[List[float]]]:
assert self.teacher_server is not None
if self._same_tokenizer or self._teacher_tokenizer is None:
# Fast path: same vocabulary — send student IDs directly.
payload = await self.teacher_server.get_logprobs(
input_ids=token_ids,
top_k=top_k,
max_tokens=1,
split="train",
)
return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]
# Cross-tokenizer path:
# 1. Decode student tokens → plain text.
# 2. Re-tokenize with teacher tokenizer → teacher IDs.
# 3. Send teacher IDs to teacher vLLM.
# 4. Align teacher positions → student positions.
# 5. Remap teacher vocab IDs → student vocab IDs.
text = self.tokenizer.decode(token_ids, clean_up_tokenization_spaces=False)
teacher_ids: List[int] = self._teacher_tokenizer.encode(
text, add_special_tokens=False
)
payload = await self.teacher_server.get_logprobs(
input_ids=teacher_ids,
input_ids=token_ids,
top_k=top_k,
max_tokens=1,
split="train",
)
teacher_topk_ids = payload["prompt_topk_token_ids"]
teacher_topk_lps = payload["prompt_topk_logprobs"]
alignment = self._build_student_teacher_alignment(text, token_ids, teacher_ids)
return self._align_and_remap(
token_ids, teacher_topk_ids, teacher_topk_lps, alignment
)
return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]
# ------------------------------------------------------------------
# Group enrichment


@@ -67,3 +67,29 @@ async def test_attach_teacher_distillation_failure_drops_payload():
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
assert out["distill_token_ids"] is None
assert out["distill_logprobs"] is None
def test_teacher_tokenizer_mismatch_raises(monkeypatch):
env = object.__new__(_ConcreteTeacherEnv)
class _StudentTokenizer:
name_or_path = "student-model"
def get_vocab(self):
return {"a": 1}
class _TeacherTokenizer:
def get_vocab(self):
return {"b": 1}
env.tokenizer = _StudentTokenizer()
monkeypatch.setattr(
"transformers.AutoTokenizer.from_pretrained",
lambda *args, **kwargs: _TeacherTokenizer(),
)
with pytest.raises(ValueError, match="Cross-tokenizer distillation is not supported"):
TeacherDistillationEnv._validate_teacher_tokenizer_compatibility(
env,
teacher_tokenizer_name="teacher-model",
)


@@ -32,11 +32,12 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv):
max_token_length=2048,
wandb_name="gsm8k_teacher_distill",
teacher_enabled=True,
teacher_base_url="http://localhost:8003/v1",
teacher_model_name="mock-teacher",
teacher_api_key="",
teacher_server_type="vllm",
teacher_tokenizer_name="none",
teacher_server=APIServerConfig(
base_url="http://localhost:8003/v1",
model_name="mock-teacher",
api_key="",
server_type="vllm",
),
teacher_top_k=4,
)
server_config = APIServerConfig(


@@ -304,6 +304,29 @@ environment uses the `/generate` path and includes token-level
4. Trainer extracts and aligns logprobs with training labels
5. GRPO loss uses these rollout logprobs in importance-ratio terms
### 1b. Teacher distillation requires the same tokenizer
When distillation data is attached to Atropos batches, the trainer treats
`distill_token_ids` as indices into the student's logit tensor. That only works
if the teacher and student share the same tokenizer vocabulary.
What to configure on the environment side:
```bash
--env.teacher_enabled true \
--env.teacher_server.base_url "http://localhost:9003/v1" \
--env.teacher_server.model_name "$TEACHER_MODEL" \
--env.teacher_server.server_type vllm \
--env.teacher_top_k 8
```
Why cross-tokenizer conversion is not acceptable here:
- Teacher token ID `1234` and student token ID `1234` can correspond to different text.
- Re-tokenizing the student's text with the teacher tokenizer changes token boundaries and sequence length, so teacher position `i` may no longer correspond to student position `i`.
- Remapping teacher top-k tokens back into student vocab can collapse multiple teacher candidates into one student token or expand one teacher token into multiple student tokens.
- The current distillation loss expects exact per-position supervision in student token space, so an approximate remapping would silently produce misleading targets.
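To make the per-position assumption concrete, here is a minimal pure-Python sketch of a top-k distillation term that uses `distill_token_ids` as direct indices into the student's logits. It is an illustration under stated assumptions, not the trainer's actual loss: the function names `log_softmax` and `topk_distill_loss` are hypothetical, and the teacher-probability-weighted cross-entropy form is assumed for the example.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one position's logits."""
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_sum for x in logits]


def topk_distill_loss(
    student_logits: List[List[float]],
    distill_token_ids: List[List[int]],
    distill_logprobs: List[List[float]],
) -> float:
    """Teacher-weighted cross-entropy over top-k tokens, averaged over positions.

    Position i of the teacher data is paired with position i of the student
    logits, and each teacher token ID is used as a student-vocabulary index.
    """
    total = 0.0
    positions = 0
    for pos_logits, ids, lps in zip(
        student_logits, distill_token_ids, distill_logprobs
    ):
        student_lp = log_softmax(pos_logits)
        for token_id, teacher_lp in zip(ids, lps):
            if not 0 <= token_id < len(student_lp):
                # An ID outside the student vocabulary cannot be looked up.
                raise ValueError(f"token id {token_id} is not in the student vocab")
            total += -math.exp(teacher_lp) * student_lp[token_id]
        positions += 1
    return total / max(positions, 1)


# One position, student vocabulary of 4 tokens with uniform logits (toy data).
loss = topk_distill_loss(
    student_logits=[[0.0, 0.0, 0.0, 0.0]],
    distill_token_ids=[[0, 1]],
    distill_logprobs=[[math.log(0.5), math.log(0.5)]],
)
```

Because the lookup is a plain index, an ID from a different vocabulary would either select an unrelated student token without any error or fall outside the tensor, which is the failure mode the same-tokenizer requirement avoids.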
### 2. Clipping
```bash


@@ -234,8 +234,9 @@ start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \
--env.use_wandb true \
--env.wandb_name "gsm8k-teacher-distill" \
--env.teacher_enabled true \
--env.teacher_base_url "http://localhost:${TEACHER_PORT}/v1" \
--env.teacher_model_name "$TEACHER_MODEL" \
--env.teacher_server.base_url "http://localhost:${TEACHER_PORT}/v1" \
--env.teacher_server.model_name "$TEACHER_MODEL" \
--env.teacher_server.server_type vllm \
--env.teacher_top_k "$TEACHER_TOP_K" \
--env.ensure_scores_are_not_same false \
--openai.api_key "dummy" \