training kernel

This commit is contained in:
Jai Suphavadeeprasit 2026-03-12 12:31:09 -04:00
parent 62ef2fcc2e
commit c26432b963

View file

@@ -146,22 +146,48 @@ class TeacherDistillationEnv(BaseEnv, ABC):
)
# Detect vocabulary mismatch.
# Compare by name first; if names differ, load the teacher tokenizer
# and do a vocab-size sanity check. Same-family models (e.g. Qwen3-4B
# and Qwen3-30B) share the same vocabulary, so even though the
# name_or_path strings differ they should use the fast path.
student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
if student_tok_name and teacher_tok_name and student_tok_name != teacher_tok_name:
try:
from transformers import AutoTokenizer
self._teacher_tokenizer = AutoTokenizer.from_pretrained(
loaded = AutoTokenizer.from_pretrained(
teacher_tok_name, use_fast=True
)
self._same_tokenizer = False
logger.warning(
"TeacherDistillationEnv: cross-tokenizer mode active. "
"student=%s teacher=%s. "
"Token IDs will be decoded → re-tokenized → vocab-remapped.",
student_tok_name,
teacher_tok_name,
)
student_vocab_size = getattr(self.tokenizer, "vocab_size", None)
teacher_vocab_size = getattr(loaded, "vocab_size", None)
if (
student_vocab_size is not None
and teacher_vocab_size is not None
and student_vocab_size == teacher_vocab_size
):
# Same vocab size — treat as same tokenizer (fast path).
# This covers same-family models (e.g. all Qwen3 variants).
self._same_tokenizer = True
logger.warning(
"TeacherDistillationEnv: names differ but vocab sizes match "
"(%d tokens). Using fast (same-tokenizer) path. "
"student=%s teacher=%s",
student_vocab_size,
student_tok_name,
teacher_tok_name,
)
else:
self._teacher_tokenizer = loaded
self._same_tokenizer = False
logger.warning(
"TeacherDistillationEnv: cross-tokenizer mode active. "
"student=%s (%s tokens) teacher=%s (%s tokens). "
"Token IDs will be decoded → re-tokenized → vocab-remapped.",
student_tok_name,
student_vocab_size,
teacher_tok_name,
teacher_vocab_size,
)
except Exception as exc:
logger.warning(
"TeacherDistillationEnv: could not load teacher tokenizer '%s' "