mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
remove cross tokenization and fix location of configs
This commit is contained in:
parent
148a4fd5eb
commit
a1b545c734
6 changed files with 147 additions and 302 deletions
|
|
@ -67,3 +67,29 @@ async def test_attach_teacher_distillation_failure_drops_payload():
|
|||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert out["distill_token_ids"] is None
|
||||
assert out["distill_logprobs"] is None
|
||||
|
||||
|
||||
def test_teacher_tokenizer_mismatch_raises(monkeypatch):
    """Check that mismatched student/teacher vocabularies are rejected.

    The student tokenizer stub reports the vocab ``{"a": 1}`` while the
    teacher stub reports ``{"b": 1}``. With those two in place,
    ``_validate_teacher_tokenizer_compatibility`` is expected to raise a
    ``ValueError`` whose message mentions that cross-tokenizer distillation
    is not supported.
    """

    class _StudentTokenizer:
        name_or_path = "student-model"

        def get_vocab(self):
            return {"a": 1}

    class _TeacherTokenizer:
        def get_vocab(self):
            return {"b": 1}

    def _fake_from_pretrained(*args, **kwargs):
        # Accept and ignore whatever the caller passes; always hand back the
        # teacher stub so nothing is loaded from disk or the network.
        return _TeacherTokenizer()

    # Allocate the env without running __init__, then attach only the
    # tokenizer attribute this test sets up.
    env = object.__new__(_ConcreteTeacherEnv)
    env.tokenizer = _StudentTokenizer()

    # NOTE(review): presumably the validator resolves the teacher tokenizer
    # through transformers.AutoTokenizer.from_pretrained — confirm against
    # the implementation.
    monkeypatch.setattr(
        "transformers.AutoTokenizer.from_pretrained",
        _fake_from_pretrained,
    )

    expected_message = "Cross-tokenizer distillation is not supported"
    with pytest.raises(ValueError, match=expected_message):
        TeacherDistillationEnv._validate_teacher_tokenizer_compatibility(
            env,
            teacher_tokenizer_name="teacher-model",
        )
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue