mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
training kernel
This commit is contained in:
parent
62ef2fcc2e
commit
c26432b963
1 changed files with 35 additions and 9 deletions
|
|
@ -146,22 +146,48 @@ class TeacherDistillationEnv(BaseEnv, ABC):
|
|||
)
|
||||
|
||||
# Detect vocabulary mismatch.
|
||||
# Compare by name first; if names differ, load the teacher tokenizer
|
||||
# and do a vocab-size sanity check. Same-family models (e.g. Qwen3-4B
|
||||
# and Qwen3-30B) share the same vocabulary, so even though the
|
||||
# name_or_path strings differ they should use the fast path.
|
||||
student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
|
||||
if student_tok_name and teacher_tok_name and student_tok_name != teacher_tok_name:
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
self._teacher_tokenizer = AutoTokenizer.from_pretrained(
|
||||
loaded = AutoTokenizer.from_pretrained(
|
||||
teacher_tok_name, use_fast=True
|
||||
)
|
||||
self._same_tokenizer = False
|
||||
logger.warning(
|
||||
"TeacherDistillationEnv: cross-tokenizer mode active. "
|
||||
"student=%s teacher=%s. "
|
||||
"Token IDs will be decoded → re-tokenized → vocab-remapped.",
|
||||
student_tok_name,
|
||||
teacher_tok_name,
|
||||
)
|
||||
student_vocab_size = getattr(self.tokenizer, "vocab_size", None)
|
||||
teacher_vocab_size = getattr(loaded, "vocab_size", None)
|
||||
if (
|
||||
student_vocab_size is not None
|
||||
and teacher_vocab_size is not None
|
||||
and student_vocab_size == teacher_vocab_size
|
||||
):
|
||||
# Same vocab size — treat as same tokenizer (fast path).
|
||||
# This covers same-family models (e.g. all Qwen3 variants).
|
||||
self._same_tokenizer = True
|
||||
logger.warning(
|
||||
"TeacherDistillationEnv: names differ but vocab sizes match "
|
||||
"(%d tokens). Using fast (same-tokenizer) path. "
|
||||
"student=%s teacher=%s",
|
||||
student_vocab_size,
|
||||
student_tok_name,
|
||||
teacher_tok_name,
|
||||
)
|
||||
else:
|
||||
self._teacher_tokenizer = loaded
|
||||
self._same_tokenizer = False
|
||||
logger.warning(
|
||||
"TeacherDistillationEnv: cross-tokenizer mode active. "
|
||||
"student=%s (%s tokens) teacher=%s (%s tokens). "
|
||||
"Token IDs will be decoded → re-tokenized → vocab-remapped.",
|
||||
student_tok_name,
|
||||
student_vocab_size,
|
||||
teacher_tok_name,
|
||||
teacher_vocab_size,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"TeacherDistillationEnv: could not load teacher tokenizer '%s' "
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue