diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py
index 521c3762..12c16079 100644
--- a/atroposlib/envs/teacher_distillation_env.py
+++ b/atroposlib/envs/teacher_distillation_env.py
@@ -146,22 +146,48 @@ class TeacherDistillationEnv(BaseEnv, ABC):
         )
 
         # Detect vocabulary mismatch.
+        # Compare by name first; if names differ, load the teacher tokenizer
+        # and do a vocab-size sanity check. Same-family models (e.g. Qwen3-4B
+        # and Qwen3-30B) share the same vocabulary, so even though the
+        # name_or_path strings differ they should use the fast path.
         student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
         if student_tok_name and teacher_tok_name and student_tok_name != teacher_tok_name:
             try:
                 from transformers import AutoTokenizer
 
-                self._teacher_tokenizer = AutoTokenizer.from_pretrained(
+                loaded = AutoTokenizer.from_pretrained(
                     teacher_tok_name, use_fast=True
                 )
-                self._same_tokenizer = False
-                logger.warning(
-                    "TeacherDistillationEnv: cross-tokenizer mode active. "
-                    "student=%s teacher=%s. "
-                    "Token IDs will be decoded → re-tokenized → vocab-remapped.",
-                    student_tok_name,
-                    teacher_tok_name,
-                )
+                student_vocab_size = getattr(self.tokenizer, "vocab_size", None)
+                teacher_vocab_size = getattr(loaded, "vocab_size", None)
+                if (
+                    student_vocab_size is not None
+                    and teacher_vocab_size is not None
+                    and student_vocab_size == teacher_vocab_size
+                ):
+                    # Same vocab size — treat as same tokenizer (fast path).
+                    # This covers same-family models (e.g. all Qwen3 variants).
+                    self._same_tokenizer = True
+                    logger.warning(
+                        "TeacherDistillationEnv: names differ but vocab sizes match "
+                        "(%d tokens). Using fast (same-tokenizer) path. "
+                        "student=%s teacher=%s",
+                        student_vocab_size,
+                        student_tok_name,
+                        teacher_tok_name,
+                    )
+                else:
+                    self._teacher_tokenizer = loaded
+                    self._same_tokenizer = False
+                    logger.warning(
+                        "TeacherDistillationEnv: cross-tokenizer mode active. "
+                        "student=%s (%s tokens) teacher=%s (%s tokens). "
+                        "Token IDs will be decoded → re-tokenized → vocab-remapped.",
+                        student_tok_name,
+                        student_vocab_size,
+                        teacher_tok_name,
+                        teacher_vocab_size,
+                    )
             except Exception as exc:
                 logger.warning(
                     "TeacherDistillationEnv: could not load teacher tokenizer '%s' "