diff --git a/README.md b/README.md
index 3b533a9b..6def2fdf 100644
--- a/README.md
+++ b/README.md
@@ -298,6 +298,47 @@ curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids!
 - Trainers should validate alignment assumptions they require (sequence length, per-position top-k, etc.).
 - Teacher-side architecture and prompt/rendering strategy are intentionally out of scope for this PR.
+### TeacherDistillationEnv follow-up
+
+The follow-up teacher environment uses a dedicated teacher server config and
+attaches teacher prompt logprobs before the group is sent to the API.
+
+Teacher config shape:
+
+```python
+TeacherDistillationConfig(
+    teacher_enabled=True,
+    teacher_server=APIServerConfig(
+        base_url="http://localhost:9003/v1",
+        model_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
+        api_key="",
+        server_type="vllm",
+    ),
+    teacher_top_k=8,
+)
+```
+
+CLI shape:
+
+```bash
+--env.teacher_enabled true \
+--env.teacher_server.base_url "http://localhost:9003/v1" \
+--env.teacher_server.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \
+--env.teacher_server.server_type vllm \
+--env.teacher_top_k 8
+```
+
+Tokenizer requirement:
+
+- Teacher distillation currently requires the teacher and student to use the same tokenizer vocabulary.
+- If the tokenizers do not match, `TeacherDistillationEnv` raises an error instead of attempting token conversion.
+
+Why same-tokenizer is required:
+
+- `distill_token_ids` are consumed as student-vocabulary IDs by the trainer.
+- If the teacher uses a different vocabulary, the same integer token ID refers to different text on the teacher and student sides.
+- A decode/re-tokenize/remap pipeline is not a safe drop-in fix because it changes both token positions and token identities, which breaks the exact per-position token supervision that the current distillation loss assumes.
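As an aside to the bullets above, the token-ID aliasing can be sketched with toy stand-in vocabularies (illustrative dictionaries only, not real tokenizers; the ID-42 example echoes the one used elsewhere in this PR):

```python
# Hypothetical stand-in vocabularies: the same integer ID maps to different
# text on each side (e.g. ID 42 is " the" in a Llama vocab but "ท" in Qwen).
student_vocab = {42: " the", 77: " cat"}
teacher_vocab = {42: "ท", 77: " dog"}


def decode(vocab, token_ids):
    # Minimal decode: concatenate each ID's text fragment.
    return "".join(vocab[t] for t in token_ids)


ids = [42, 77]
student_text = decode(student_vocab, ids)  # " the cat"
teacher_text = decode(teacher_vocab, ids)  # "ท dog"

# Forwarding student IDs to a teacher with a different vocabulary would make
# the teacher score unrelated text, so the env raises instead of converting.
assert student_text != teacher_text
```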
+
 ---
 
 ## Testing and Debugging Tools
diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py
index 1c88ab62..d8284335 100644
--- a/atroposlib/envs/teacher_distillation_env.py
+++ b/atroposlib/envs/teacher_distillation_env.py
@@ -4,42 +4,10 @@ Teacher distillation environment layer.
 This module adds teacher prompt-logprob fetching on top of BaseEnv without
 modifying BaseEnv transport behavior.
 
-Cross-tokenizer distillation
-----------------------------
-When student and teacher use the same tokenizer family (e.g. both Qwen3) the
-student's raw token IDs can be forwarded directly to the teacher vLLM and the
-returned top-k token IDs can be used as-is in the student logit lookup.
-
-When tokenizers differ (e.g. Llama student, Qwen teacher) two problems arise:
-
-  1. Token-ID aliasing: student token 42 = " the" in Llama, but 42 = "ท" in
-     Qwen. Sending student IDs to the teacher causes it to score garbage.
-
-  2. Vocab-space mismatch: the teacher's top-k IDs live in the teacher's
-     vocabulary. The student logit lookup at those IDs would access random
-     tokens in the student vocab.
-
-This module fixes both problems automatically:
-
-  • Re-tokenization – student tokens are decoded to plain text and
-    re-tokenized with the teacher tokenizer before being sent to the teacher
-    server. The teacher therefore always scores the correct text.
-
-  • Character-level position alignment – after re-tokenisation the teacher
-    has a different number of tokens than the student. A character-offset
-    map is built (requires a fast HuggingFace tokenizer) to project each
-    teacher logprob position back onto the student token it overlaps with.
-
-  • Vocabulary remapping – teacher top-k token IDs (teacher vocab) are
-    decoded to text fragments and re-encoded with the student tokenizer so
-    that the final distill_token_ids live in the student vocabulary and can
-    be looked up directly in the student logit tensor.
-
-Same-tokenizer fast path
-------------------------
-When teacher_tokenizer_name resolves to the same underlying vocabulary as the
-student tokenizer the original fast path (no decode / re-tokenize / remap) is
-taken automatically.
+This implementation supports same-tokenizer distillation only. The teacher and
+student must share the same tokenizer vocabulary so the student's token IDs can
+be forwarded directly to the teacher and the returned teacher top-k token IDs
+can be looked up directly in the student's logits.
 """
 
 from __future__ import annotations
@@ -47,7 +15,7 @@ from __future__ import annotations
 
 import asyncio
 import logging
 from abc import ABC
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 from pydantic import Field
@@ -63,29 +31,9 @@ class TeacherDistillationConfig(BaseEnvConfig):
         default=False,
         description="Whether to fetch teacher prompt logprobs for distillation.",
     )
-    teacher_base_url: Optional[str] = Field(
+    teacher_server: Optional[APIServerConfig] = Field(
         default=None,
-        description="Teacher server base URL (OpenAI-compatible).",
-    )
-    teacher_model_name: Optional[str] = Field(
-        default=None,
-        description="Teacher model name used in teacher server requests.",
-    )
-    teacher_api_key: str = Field(
-        default="",
-        description="Teacher API key, if required by the teacher endpoint.",
-    )
-    teacher_server_type: str = Field(
-        default="vllm",
-        description="Teacher server type (e.g. vllm, sglang, trl, openai).",
-    )
-    teacher_tokenizer_name: str = Field(
-        default="none",
-        description=(
-            "Tokenizer name for teacher server. If 'none', teacher_model_name is used. "
-            "When this resolves to a different vocabulary than the student tokenizer, "
-            "cross-tokenizer alignment is applied automatically."
-        ),
+        description="Teacher inference server configuration.",
     )
     teacher_top_k: int = Field(
         default=1,
@@ -114,266 +62,71 @@ class TeacherDistillationEnv(BaseEnv, ABC):
     ):
         super().__init__(config, server_configs, slurm=slurm, testing=testing)
         self.teacher_server: Optional[ServerManager] = None
-        # Teacher tokenizer (only loaded when tokenizers may differ).
-        self._teacher_tokenizer = None
-        # True when student and teacher share the same vocabulary.
-        self._same_tokenizer: bool = True
-        # LRU-style cache: teacher_token_id -> student_token_id
-        self._vocab_remap_cache: Dict[int, int] = {}
 
         if config.teacher_enabled:
-            if not config.teacher_base_url or not config.teacher_model_name:
+            if config.teacher_server is None:
                 raise ValueError(
-                    "teacher_enabled=True requires teacher_base_url and teacher_model_name."
+                    "teacher_enabled=True requires a teacher_server configuration."
                 )
-            teacher_tok_name = (
-                config.teacher_model_name
-                if config.teacher_tokenizer_name in ("none", "")
-                else config.teacher_tokenizer_name
-            )
-            teacher_cfg = APIServerConfig(
-                server_type=config.teacher_server_type,  # type: ignore[arg-type]
-                base_url=config.teacher_base_url,
-                api_key=config.teacher_api_key,
-                model_name=config.teacher_model_name,
-                tokenizer_name=teacher_tok_name,
-                timeout=1200,
+            teacher_cfg = config.teacher_server.model_copy(
+                update={
+                    "tokenizer_name": (
+                        config.teacher_server.model_name
+                        if config.teacher_server.tokenizer_name in ("", "none")
+                        else config.teacher_server.tokenizer_name
+                    ),
+                    "timeout": 1200,
+                }
             )
             self.teacher_server = ServerManager(
                 [teacher_cfg],
                 slurm=False,
                 testing=False,
             )
-
-            # Detect vocabulary mismatch.
-            # Compare by name first; if names differ, load the teacher tokenizer
-            # and do a vocab-size sanity check. Same-family models (e.g. Qwen3-4B
-            # and Qwen3-30B) share the same vocabulary, so even though the
-            # name_or_path strings differ they should use the fast path.
-            student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
-            if (
-                student_tok_name
-                and teacher_tok_name
-                and student_tok_name != teacher_tok_name
-            ):
-                try:
-                    from transformers import AutoTokenizer
-
-                    loaded = AutoTokenizer.from_pretrained(
-                        teacher_tok_name, use_fast=True
-                    )
-                    student_vocab_size = getattr(self.tokenizer, "vocab_size", None)
-                    teacher_vocab_size = getattr(loaded, "vocab_size", None)
-                    if (
-                        student_vocab_size is not None
-                        and teacher_vocab_size is not None
-                        and student_vocab_size == teacher_vocab_size
-                    ):
-                        # Same vocab size — treat as same tokenizer (fast path).
-                        # This covers same-family models (e.g. all Qwen3 variants).
-                        self._same_tokenizer = True
-                        logger.warning(
-                            "TeacherDistillationEnv: names differ but vocab sizes match "
-                            "(%d tokens). Using fast (same-tokenizer) path. "
-                            "student=%s teacher=%s",
-                            student_vocab_size,
-                            student_tok_name,
-                            teacher_tok_name,
-                        )
-                    else:
-                        self._teacher_tokenizer = loaded
-                        self._same_tokenizer = False
-                        logger.warning(
-                            "TeacherDistillationEnv: cross-tokenizer mode active. "
-                            "student=%s (%s tokens) teacher=%s (%s tokens). "
-                            "Token IDs will be decoded → re-tokenized → vocab-remapped.",
-                            student_tok_name,
-                            student_vocab_size,
-                            teacher_tok_name,
-                            teacher_vocab_size,
-                        )
-                except Exception as exc:
-                    logger.warning(
-                        "TeacherDistillationEnv: could not load teacher tokenizer '%s' "
-                        "(%s). Falling back to same-tokenizer (fast) path — only safe if "
-                        "student and teacher share the same vocabulary.",
-                        teacher_tok_name,
-                        exc,
-                    )
-                    self._same_tokenizer = True
-            else:
-                self._same_tokenizer = True
-
-            logger.warning(
-                "TeacherDistillationEnv: teacher server configured at %s "
-                "(model=%s, top_k=%s, same_tokenizer=%s)",
-                config.teacher_base_url,
-                config.teacher_model_name,
-                config.teacher_top_k,
-                self._same_tokenizer,
-            )
-
-    # ------------------------------------------------------------------
-    # Cross-tokenizer helpers
-    # ------------------------------------------------------------------
-
-    def _build_student_teacher_alignment(
-        self,
-        text: str,
-        student_ids: List[int],
-        teacher_ids: List[int],
-    ) -> List[List[int]]:
-        """
-        For each student token position return the list of teacher token positions
-        whose character spans overlap with the student token's character span.
-
-        Requires fast (Rust-backed) HuggingFace tokenizers that support
-        return_offsets_mapping. Falls back to a proportional approximation
-        if offset mapping is unavailable.
-        """
-        student_len = len(student_ids)
-        teacher_len = len(teacher_ids)
-
-        try:
-            s_enc = self.tokenizer(
-                text, return_offsets_mapping=True, add_special_tokens=False
-            )
-            t_enc = self._teacher_tokenizer(
-                text, return_offsets_mapping=True, add_special_tokens=False
-            )
-            s_offsets: List[Tuple[int, int]] = s_enc["offset_mapping"][:student_len]
-            t_offsets: List[Tuple[int, int]] = t_enc["offset_mapping"][:teacher_len]
-
-            alignment: List[List[int]] = []
-            for s_start, s_end in s_offsets:
-                overlapping = [
-                    t_idx
-                    for t_idx, (t_start, t_end) in enumerate(t_offsets)
-                    if t_start < s_end and t_end > s_start and s_end > s_start
-                ]
-                alignment.append(overlapping)
-            return alignment
-
-        except Exception as exc:
-            logger.warning(
-                "TeacherDistillationEnv: offset-mapping alignment failed (%s). "
-                "Using proportional fallback.",
-                exc,
-            )
-            ratio = teacher_len / max(student_len, 1)
-            return [[int(i * ratio)] for i in range(student_len)]
-
-    def _remap_teacher_token_to_student(self, teacher_token_id: int) -> int:
-        """
-        Convert a teacher vocabulary token ID to the best-matching student
-        vocabulary token ID by decoding the teacher token to text then
-        re-encoding with the student tokenizer.
-
-        Results are cached to avoid repeated tokenizer calls.
-        """
-        if teacher_token_id in self._vocab_remap_cache:
-            return self._vocab_remap_cache[teacher_token_id]
-
-        try:
-            text = self._teacher_tokenizer.decode(
-                [teacher_token_id], clean_up_tokenization_spaces=False
-            )
-            student_ids = self.tokenizer.encode(text, add_special_tokens=False)
-            # Use the first student token as the representative.
-            sid = int(student_ids[0]) if student_ids else teacher_token_id
-        except Exception:
-            sid = teacher_token_id
-
-        self._vocab_remap_cache[teacher_token_id] = sid
-        return sid
-
-    def _align_and_remap(
-        self,
-        student_ids: List[int],
-        teacher_topk_ids: List[List[int]],
-        teacher_topk_lps: List[List[float]],
-        alignment: List[List[int]],
-    ) -> Tuple[List[List[int]], List[List[float]]]:
-        """
-        Project teacher logprobs (teacher positions, teacher vocab) onto
-        student positions in student vocab.
-
-        For each student token position:
-          1. Collect all teacher top-k entries from overlapping teacher positions.
-          2. Remap each teacher token ID to the student vocab.
-          3. Merge duplicates by keeping the maximum logprob.
-          4. Return the top-k entries sorted by descending logprob.
-        """
-        k = max(1, len(teacher_topk_ids[0]) if teacher_topk_ids else 1)
-        result_ids: List[List[int]] = []
-        result_lps: List[List[float]] = []
-
-        for s_idx in range(len(student_ids)):
-            t_positions = alignment[s_idx] if s_idx < len(alignment) else []
-            if not t_positions:
-                result_ids.append([])
-                result_lps.append([])
-                continue
-
-            # Merge all overlapping teacher positions, remap vocab.
-            merged: Dict[int, float] = {}
-            for t_idx in t_positions:
-                if t_idx >= len(teacher_topk_ids):
-                    continue
-                for tid, tlp in zip(teacher_topk_ids[t_idx], teacher_topk_lps[t_idx]):
-                    sid = self._remap_teacher_token_to_student(tid)
-                    merged[sid] = max(merged.get(sid, -1e9), tlp)
-
-            sorted_items = sorted(merged.items(), key=lambda x: -x[1])
-            top = sorted_items[:k]
-            result_ids.append([int(sid) for sid, _ in top])
-            result_lps.append([float(lp) for _, lp in top])
-
-        return result_ids, result_lps
+            self._validate_teacher_tokenizer_compatibility(teacher_cfg.tokenizer_name)
 
     # ------------------------------------------------------------------
     # Core fetch
     # ------------------------------------------------------------------
 
+    def _validate_teacher_tokenizer_compatibility(self, teacher_tokenizer_name: str) -> None:
+        student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
+        if student_tok_name == teacher_tokenizer_name:
+            return
+
+        try:
+            from transformers import AutoTokenizer
+
+            teacher_tokenizer = AutoTokenizer.from_pretrained(
+                teacher_tokenizer_name, use_fast=True
+            )
+        except Exception as exc:
+            raise ValueError(
+                "Cross-tokenizer distillation is not supported in this PR, and the "
+                f"teacher tokenizer for '{teacher_tokenizer_name}' could not be loaded to "
+                f"verify compatibility: {exc}"
+            ) from exc
+
+        student_vocab = self.tokenizer.get_vocab()
+        teacher_vocab = teacher_tokenizer.get_vocab()
+        if student_vocab != teacher_vocab:
+            raise ValueError(
+                "Cross-tokenizer distillation is not supported in this PR. "
+                f"Student tokenizer '{student_tok_name or type(self.tokenizer).__name__}' "
+                f"and teacher tokenizer '{teacher_tokenizer_name}' do not match."
+            )
+
     async def _fetch_teacher_for_sequence(
         self, token_ids: List[int], top_k: int
     ) -> Tuple[List[List[int]], List[List[float]]]:
         assert self.teacher_server is not None
-
-        if self._same_tokenizer or self._teacher_tokenizer is None:
-            # Fast path: same vocabulary — send student IDs directly.
-            payload = await self.teacher_server.get_logprobs(
-                input_ids=token_ids,
-                top_k=top_k,
-                max_tokens=1,
-                split="train",
-            )
-            return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]
-
-        # Cross-tokenizer path:
-        #   1. Decode student tokens → plain text.
-        #   2. Re-tokenize with teacher tokenizer → teacher IDs.
-        #   3. Send teacher IDs to teacher vLLM.
-        #   4. Align teacher positions → student positions.
-        #   5. Remap teacher vocab IDs → student vocab IDs.
-        text = self.tokenizer.decode(token_ids, clean_up_tokenization_spaces=False)
-        teacher_ids: List[int] = self._teacher_tokenizer.encode(
-            text, add_special_tokens=False
-        )
-
         payload = await self.teacher_server.get_logprobs(
-            input_ids=teacher_ids,
+            input_ids=token_ids,
             top_k=top_k,
             max_tokens=1,
             split="train",
         )
-        teacher_topk_ids = payload["prompt_topk_token_ids"]
-        teacher_topk_lps = payload["prompt_topk_logprobs"]
-
-        alignment = self._build_student_teacher_alignment(text, token_ids, teacher_ids)
-        return self._align_and_remap(
-            token_ids, teacher_topk_ids, teacher_topk_lps, alignment
-        )
+        return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]
 
     # ------------------------------------------------------------------
     # Group enrichment
diff --git a/atroposlib/tests/test_teacher_distillation_env.py b/atroposlib/tests/test_teacher_distillation_env.py
index 199f1453..7f5262e7 100644
--- a/atroposlib/tests/test_teacher_distillation_env.py
+++ b/atroposlib/tests/test_teacher_distillation_env.py
@@ -67,3 +67,29 @@ async def test_attach_teacher_distillation_failure_drops_payload():
     out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
     assert out["distill_token_ids"] is None
     assert out["distill_logprobs"] is None
+
+
+def test_teacher_tokenizer_mismatch_raises(monkeypatch):
+    env = object.__new__(_ConcreteTeacherEnv)
+
+    class _StudentTokenizer:
+        name_or_path = "student-model"
+
+        def get_vocab(self):
+            return {"a": 1}
+
+    class _TeacherTokenizer:
+        def get_vocab(self):
+            return {"b": 1}
+
+    env.tokenizer = _StudentTokenizer()
+    monkeypatch.setattr(
+        "transformers.AutoTokenizer.from_pretrained",
+        lambda *args, **kwargs: _TeacherTokenizer(),
+    )
+
+    with pytest.raises(ValueError, match="Cross-tokenizer distillation is not supported"):
+        TeacherDistillationEnv._validate_teacher_tokenizer_compatibility(
+            env,
+            teacher_tokenizer_name="teacher-model",
+        )
diff --git a/environments/gsm8k_server_teacher_distill.py b/environments/gsm8k_server_teacher_distill.py
index 8276436b..49caabec 100644
--- a/environments/gsm8k_server_teacher_distill.py
+++ b/environments/gsm8k_server_teacher_distill.py
@@ -32,11 +32,12 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv):
             max_token_length=2048,
             wandb_name="gsm8k_teacher_distill",
             teacher_enabled=True,
-            teacher_base_url="http://localhost:8003/v1",
-            teacher_model_name="mock-teacher",
-            teacher_api_key="",
-            teacher_server_type="vllm",
-            teacher_tokenizer_name="none",
+            teacher_server=APIServerConfig(
+                base_url="http://localhost:8003/v1",
+                model_name="mock-teacher",
+                api_key="",
+                server_type="vllm",
+            ),
             teacher_top_k=4,
         )
         server_config = APIServerConfig(
diff --git a/example_trainer/README.md b/example_trainer/README.md
index ddb96b8a..b889f440 100644
--- a/example_trainer/README.md
+++ b/example_trainer/README.md
@@ -304,6 +304,29 @@ environment uses the `/generate` path and includes token-level
 4. Trainer extracts and aligns logprobs with training labels
 5. GRPO loss uses these rollout logprobs in importance-ratio terms
 
+### 1b. Teacher distillation requires the same tokenizer
+
+When distillation data is attached to Atropos batches, the trainer treats
+`distill_token_ids` as indices into the student's logit tensor. That only works
+if the teacher and student share the same tokenizer vocabulary.
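A minimal sketch of the lookup described above (toy Python lists and hypothetical shapes, not the trainer's actual code, which operates on logit tensors):

```python
# Toy student "logits": one row per token position, one score per vocab ID.
student_logits = [
    [0.1, 0.9, 0.0, 0.2],  # position 0 (vocab size 4)
    [0.3, 0.1, 0.8, 0.0],  # position 1
]

# Teacher top-k token IDs per position, as attached in distill_token_ids.
distill_token_ids = [[1, 3], [2, 0]]

# The trainer indexes the student logits directly with these IDs, which is
# only meaningful when they are student-vocabulary IDs.
gathered = [
    [student_logits[pos][tid] for tid in ids]
    for pos, ids in enumerate(distill_token_ids)
]
assert gathered == [[0.9, 0.2], [0.8, 0.3]]
```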
+
+What to configure on the environment side:
+
+```bash
+--env.teacher_enabled true \
+--env.teacher_server.base_url "http://localhost:9003/v1" \
+--env.teacher_server.model_name "$TEACHER_MODEL" \
+--env.teacher_server.server_type vllm \
+--env.teacher_top_k 8
+```
+
+Why cross-tokenizer conversion is not acceptable here:
+
+- Teacher token ID `1234` and student token ID `1234` can correspond to different text.
+- Re-tokenizing teacher text changes token boundaries, so teacher position `i` may no longer correspond to student position `i`.
+- Remapping teacher top-k tokens back into student vocab can collapse multiple teacher candidates into one student token or expand one teacher token into multiple student tokens.
+- The current distillation loss expects exact per-position supervision in student token space, so an approximate remapping would silently produce misleading targets.
+
 ### 2. Clipping
 
 ```bash
diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
index 418a87ea..91cecf8a 100755
--- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
+++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
@@ -234,8 +234,9 @@ start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \
     --env.use_wandb true \
     --env.wandb_name "gsm8k-teacher-distill" \
     --env.teacher_enabled true \
-    --env.teacher_base_url "http://localhost:${TEACHER_PORT}/v1" \
-    --env.teacher_model_name "$TEACHER_MODEL" \
+    --env.teacher_server.base_url "http://localhost:${TEACHER_PORT}/v1" \
+    --env.teacher_server.model_name "$TEACHER_MODEL" \
+    --env.teacher_server.server_type vllm \
     --env.teacher_top_k "$TEACHER_TOP_K" \
     --env.ensure_scores_are_not_same false \
     --openai.api_key "dummy" \
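As a closing note, the position-drift problem called out in the `example_trainer/README.md` hunk above can be made concrete with a toy re-tokenization (hypothetical token splits, not real tokenizers):

```python
# Hypothetical token splits for the same text under two different tokenizers.
text = "unhappiness"
student_tokens = ["un", "happiness"]      # 2 student positions
teacher_tokens = ["un", "happi", "ness"]  # 3 teacher positions

# Both tokenizations cover the same text...
assert "".join(student_tokens) == text
assert "".join(teacher_tokens) == text

# ...but per-position supervision breaks: teacher position 1 ("happi") covers
# only part of student position 1 ("happiness"), and teacher position 2 has
# no student counterpart at all.
assert len(student_tokens) != len(teacher_tokens)
```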