address problems

This commit is contained in:
Jai Suphavadeeprasit 2026-03-13 16:12:05 -04:00
parent 322e7e6623
commit a8cdb53a4d
6 changed files with 99 additions and 24 deletions

View file

@@ -281,7 +281,7 @@ class VLLMServer(APIServer):
), "Prompt or input_ids is required for get_logprobs!"
top_k = int(kwargs.pop("top_k", kwargs.pop("top_logprobs", 1)))
top_k = max(1, top_k)
top_k = max(0, top_k)
# Use input_ids if provided (from ManagedServer), otherwise tokenize prompt
from_prompt_text = False
@@ -420,7 +420,7 @@ def resolve_openai_configs(
) from e
if isinstance(default_server_configs, APIServerConfig):
server_configs = final_openai_config
server_configs = [final_openai_config]
elif isinstance(default_server_configs, list):
server_configs = [final_openai_config]
else:

View file

@@ -33,12 +33,16 @@ class TeacherDistillationConfig(BaseEnvConfig):
)
teacher_server: Optional[APIServerConfig] = Field(
default=None,
description="Teacher inference server configuration.",
description="Fallback teacher server configuration when not provided at init.",
)
teacher_top_k: int = Field(
default=1,
ge=1,
description="Top-k prompt logprobs to fetch per token position.",
default=0,
ge=-1,
description=(
"Number of extra prompt logprobs to fetch beyond the selected token. "
"Use 0 for selected-token-only prompt logprobs and <= -1 to disable "
"teacher fetching."
),
)
@@ -57,6 +61,9 @@ class TeacherDistillationEnv(BaseEnv, ABC):
self,
config: TeacherDistillationConfig,
server_configs: Union[ServerBaseline, List[APIServerConfig]],
teacher_server_configs: Optional[
Union[ServerBaseline, List[APIServerConfig]]
] = None,
slurm: bool = False,
testing: bool = False,
):
@@ -64,26 +71,42 @@ class TeacherDistillationEnv(BaseEnv, ABC):
self.teacher_server: Optional[ServerManager] = None
if config.teacher_enabled:
if config.teacher_server is None:
teacher_config_source = teacher_server_configs
if teacher_config_source is None and config.teacher_server is not None:
teacher_config_source = [
config.teacher_server.model_copy(
update={
"tokenizer_name": (
config.teacher_server.model_name
if config.teacher_server.tokenizer_name in ("", "none")
else config.teacher_server.tokenizer_name
),
"timeout": 1200,
}
)
]
if teacher_config_source is None:
raise ValueError(
"teacher_enabled=True requires a teacher_server configuration."
"teacher_enabled=True requires teacher_server_configs at init "
"or a fallback teacher_server config."
)
teacher_cfg = config.teacher_server.model_copy(
update={
"tokenizer_name": (
config.teacher_server.model_name
if config.teacher_server.tokenizer_name in ("", "none")
else config.teacher_server.tokenizer_name
),
"timeout": 1200,
}
)
self.teacher_server = ServerManager(
[teacher_cfg],
teacher_config_source,
slurm=False,
testing=False,
)
self._validate_teacher_tokenizer_compatibility(teacher_cfg.tokenizer_name)
if isinstance(teacher_config_source, list):
teacher_cfg = teacher_config_source[0]
else:
teacher_cfg = teacher_config_source
teacher_tokenizer_name = (
teacher_cfg.model_name
if getattr(teacher_cfg, "tokenizer_name", "none") in ("", "none")
else teacher_cfg.tokenizer_name
)
self._validate_teacher_tokenizer_compatibility(teacher_tokenizer_name)
# ------------------------------------------------------------------
# Core fetch
@@ -146,12 +169,19 @@ class TeacherDistillationEnv(BaseEnv, ABC):
group["distill_logprobs"] = None
return group
group_overrides = group.get("group_overrides") or {}
if group_overrides.get("skip_teacher_top_k", False):
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
top_k = int(
(group.get("group_overrides") or {}).get(
"teacher_top_k", self.config.teacher_top_k
)
group_overrides.get("teacher_top_k", self.config.teacher_top_k)
)
top_k = max(1, top_k)
if top_k <= -1:
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
tasks = [self._fetch_teacher_for_sequence(seq, top_k) for seq in seqs]
results = await asyncio.gather(*tasks, return_exceptions=True)