diff --git a/README.md b/README.md
index 6def2fdf..076f886a 100644
--- a/README.md
+++ b/README.md
@@ -318,6 +318,10 @@ TeacherDistillationConfig(
 )
 ```
 
+If `teacher_server.model_name` is a deployment alias rather than a tokenizer
+identifier, set `teacher_server.tokenizer_name` explicitly so the env can
+validate tokenizer compatibility.
+
 CLI shape:
 
 ```bash
diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py
index aaee28d7..cc5bf9a5 100644
--- a/atroposlib/envs/server_handling/vllm_server.py
+++ b/atroposlib/envs/server_handling/vllm_server.py
@@ -281,7 +281,7 @@ class VLLMServer(APIServer):
         ), "Prompt or input_ids is required for get_logprobs!"
 
         top_k = int(kwargs.pop("top_k", kwargs.pop("top_logprobs", 1)))
-        top_k = max(1, top_k)
+        top_k = max(0, top_k)
 
         # Use input_ids if provided (from ManagedServer), otherwise tokenize prompt
         from_prompt_text = False
@@ -420,7 +420,7 @@ def resolve_openai_configs(
             ) from e
 
     if isinstance(default_server_configs, APIServerConfig):
-        server_configs = final_openai_config
+        server_configs = [final_openai_config]
     elif isinstance(default_server_configs, list):
         server_configs = [final_openai_config]
     else:
diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py
index 85e040c7..4c8bfa75 100644
--- a/atroposlib/envs/teacher_distillation_env.py
+++ b/atroposlib/envs/teacher_distillation_env.py
@@ -33,12 +33,16 @@ class TeacherDistillationConfig(BaseEnvConfig):
     )
     teacher_server: Optional[APIServerConfig] = Field(
         default=None,
-        description="Teacher inference server configuration.",
+        description="Fallback teacher server configuration when not provided at init.",
     )
     teacher_top_k: int = Field(
-        default=1,
-        ge=1,
-        description="Top-k prompt logprobs to fetch per token position.",
+        default=0,
+        ge=-1,
+        description=(
+            "Number of extra prompt logprobs to fetch beyond the selected token. "
+            "Use 0 for selected-token-only prompt logprobs and <= -1 to disable "
+            "teacher fetching."
+        ),
     )
 
 
@@ -57,6 +61,9 @@ class TeacherDistillationEnv(BaseEnv, ABC):
         self,
         config: TeacherDistillationConfig,
         server_configs: Union[ServerBaseline, List[APIServerConfig]],
+        teacher_server_configs: Optional[
+            Union[ServerBaseline, List[APIServerConfig]]
+        ] = None,
         slurm: bool = False,
         testing: bool = False,
     ):
@@ -64,26 +71,42 @@ class TeacherDistillationEnv(BaseEnv, ABC):
 
         self.teacher_server: Optional[ServerManager] = None
         if config.teacher_enabled:
-            if config.teacher_server is None:
+            teacher_config_source = teacher_server_configs
+            if teacher_config_source is None and config.teacher_server is not None:
+                teacher_config_source = [
+                    config.teacher_server.model_copy(
+                        update={
+                            "tokenizer_name": (
+                                config.teacher_server.model_name
+                                if config.teacher_server.tokenizer_name in ("", "none")
+                                else config.teacher_server.tokenizer_name
+                            ),
+                            "timeout": 1200,
+                        }
+                    )
+                ]
+
+            if teacher_config_source is None:
                 raise ValueError(
-                    "teacher_enabled=True requires a teacher_server configuration."
+                    "teacher_enabled=True requires teacher_server_configs at init "
+                    "or a fallback teacher_server config."
                 )
-            teacher_cfg = config.teacher_server.model_copy(
-                update={
-                    "tokenizer_name": (
-                        config.teacher_server.model_name
-                        if config.teacher_server.tokenizer_name in ("", "none")
-                        else config.teacher_server.tokenizer_name
-                    ),
-                    "timeout": 1200,
-                }
-            )
             self.teacher_server = ServerManager(
-                [teacher_cfg],
+                teacher_config_source,
                 slurm=False,
                 testing=False,
             )
-            self._validate_teacher_tokenizer_compatibility(teacher_cfg.tokenizer_name)
+            if isinstance(teacher_config_source, list):
+                teacher_cfg = teacher_config_source[0]
+            else:
+                teacher_cfg = teacher_config_source
+
+            teacher_tokenizer_name = (
+                teacher_cfg.model_name
+                if getattr(teacher_cfg, "tokenizer_name", "none") in ("", "none")
+                else teacher_cfg.tokenizer_name
+            )
+            self._validate_teacher_tokenizer_compatibility(teacher_tokenizer_name)
 
     # ------------------------------------------------------------------
     # Core fetch
@@ -146,12 +169,19 @@ class TeacherDistillationEnv(BaseEnv, ABC):
             group["distill_logprobs"] = None
             return group
 
+        group_overrides = group.get("group_overrides") or {}
+        if group_overrides.get("skip_teacher_top_k", False):
+            group["distill_token_ids"] = None
+            group["distill_logprobs"] = None
+            return group
+
         top_k = int(
-            (group.get("group_overrides") or {}).get(
-                "teacher_top_k", self.config.teacher_top_k
-            )
+            group_overrides.get("teacher_top_k", self.config.teacher_top_k)
         )
-        top_k = max(1, top_k)
+        if top_k <= -1:
+            group["distill_token_ids"] = None
+            group["distill_logprobs"] = None
+            return group
 
         tasks = [self._fetch_teacher_for_sequence(seq, top_k) for seq in seqs]
         results = await asyncio.gather(*tasks, return_exceptions=True)
diff --git a/atroposlib/tests/test_teacher_distillation_env.py b/atroposlib/tests/test_teacher_distillation_env.py
index 65262984..2c0ddf17 100644
--- a/atroposlib/tests/test_teacher_distillation_env.py
+++ b/atroposlib/tests/test_teacher_distillation_env.py
@@ -69,6 +69,42 @@ async def test_attach_teacher_distillation_failure_drops_payload():
     assert out["distill_logprobs"] is None
 
 
+@pytest.mark.asyncio
+async def test_attach_teacher_distillation_negative_topk_skips_fetch():
+    env = object.__new__(_ConcreteTeacherEnv)
+    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=-1)
+    env.teacher_server = _FakeTeacherServer()
+
+    group = {
+        "tokens": [[1, 2, 3]],
+        "group_overrides": None,
+        "masks": [[-100, 2, 3]],
+        "scores": [1.0],
+    }
+    out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
+    assert env.teacher_server.calls == 0
+    assert out["distill_token_ids"] is None
+    assert out["distill_logprobs"] is None
+
+
+@pytest.mark.asyncio
+async def test_attach_teacher_distillation_group_override_can_skip_fetch():
+    env = object.__new__(_ConcreteTeacherEnv)
+    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
+    env.teacher_server = _FakeTeacherServer()
+
+    group = {
+        "tokens": [[1, 2, 3]],
+        "group_overrides": {"skip_teacher_top_k": True},
+        "masks": [[-100, 2, 3]],
+        "scores": [1.0],
+    }
+    out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
+    assert env.teacher_server.calls == 0
+    assert out["distill_token_ids"] is None
+    assert out["distill_logprobs"] is None
+
+
 def test_teacher_tokenizer_mismatch_raises(monkeypatch):
     env = object.__new__(_ConcreteTeacherEnv)
 
diff --git a/environments/gsm8k_server_teacher_distill.py b/environments/gsm8k_server_teacher_distill.py
index 49caabec..5aa33a01 100644
--- a/environments/gsm8k_server_teacher_distill.py
+++ b/environments/gsm8k_server_teacher_distill.py
@@ -37,6 +37,7 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv):
             model_name="mock-teacher",
             api_key="",
             server_type="vllm",
+            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
         ),
         teacher_top_k=4,
     )
diff --git a/example_trainer/README.md b/example_trainer/README.md
index 8023ab73..8596a849 100644
--- a/example_trainer/README.md
+++ b/example_trainer/README.md
@@ -320,6 +320,10 @@ What to configure on the environment side:
     --env.teacher_top_k 8
 ```
 
+If `$TEACHER_MODEL` is a deployment alias instead of a tokenizer identifier,
+also set `--env.teacher_server.tokenizer_name ...` so the env can validate
+tokenizer compatibility.
+
 Why cross-tokenizer conversion is not acceptable here:
 - Teacher token ID `1234` and student token ID `1234` can correspond to different text.
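
Taken together, the new `teacher_server_configs` init parameter, the `config.teacher_server` fallback, and the reworked `teacher_top_k` semantics change how callers wire up a teacher. A minimal wiring sketch, assuming a concrete subclass `MyDistillEnv`, a student-side `student_cfg`, and import paths inferred from the file layout in this diff (none of these are part of the patch):

```python
# Import paths are assumptions based on the files touched above.
from atroposlib.envs.teacher_distillation_env import TeacherDistillationConfig

# Only APIServerConfig fields visible in this diff are used.
teacher_cfg = APIServerConfig(
    model_name="teacher-deployment-alias",  # deployment alias, not a tokenizer id
    tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",  # explicit, so compat validation can run
    api_key="",
    server_type="vllm",
)

env = MyDistillEnv(  # hypothetical concrete subclass of TeacherDistillationEnv
    config=TeacherDistillationConfig(
        teacher_enabled=True,
        teacher_top_k=0,  # 0 = selected-token-only prompt logprobs; <= -1 disables fetching
    ),
    server_configs=[student_cfg],  # hypothetical student APIServerConfig
    teacher_server_configs=[teacher_cfg],  # takes precedence over config.teacher_server
)
```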
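
Per-group control mirrors the two new early-return branches in `_attach_teacher_distillation` and matches the shapes exercised by the new tests; a sketch:

```python
# Either override short-circuits the teacher fetch for this group only.
group["group_overrides"] = {"skip_teacher_top_k": True}

# Equivalently, any resolved top_k <= -1 skips the fetch:
group["group_overrides"] = {"teacher_top_k": -1}
```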