mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-26 17:13:09 +00:00
address problems
This commit is contained in:
parent
322e7e6623
commit
a8cdb53a4d
6 changed files with 99 additions and 24 deletions
|
|
@ -281,7 +281,7 @@ class VLLMServer(APIServer):
|
|||
), "Prompt or input_ids is required for get_logprobs!"
|
||||
|
||||
top_k = int(kwargs.pop("top_k", kwargs.pop("top_logprobs", 1)))
|
||||
top_k = max(1, top_k)
|
||||
top_k = max(0, top_k)
|
||||
|
||||
# Use input_ids if provided (from ManagedServer), otherwise tokenize prompt
|
||||
from_prompt_text = False
|
||||
|
|
@ -420,7 +420,7 @@ def resolve_openai_configs(
|
|||
) from e
|
||||
|
||||
if isinstance(default_server_configs, APIServerConfig):
|
||||
server_configs = final_openai_config
|
||||
server_configs = [final_openai_config]
|
||||
elif isinstance(default_server_configs, list):
|
||||
server_configs = [final_openai_config]
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -33,12 +33,16 @@ class TeacherDistillationConfig(BaseEnvConfig):
|
|||
)
|
||||
teacher_server: Optional[APIServerConfig] = Field(
|
||||
default=None,
|
||||
description="Teacher inference server configuration.",
|
||||
description="Fallback teacher server configuration when not provided at init.",
|
||||
)
|
||||
teacher_top_k: int = Field(
|
||||
default=1,
|
||||
ge=1,
|
||||
description="Top-k prompt logprobs to fetch per token position.",
|
||||
default=0,
|
||||
ge=-1,
|
||||
description=(
|
||||
"Number of extra prompt logprobs to fetch beyond the selected token. "
|
||||
"Use 0 for selected-token-only prompt logprobs and <= -1 to disable "
|
||||
"teacher fetching."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -57,6 +61,9 @@ class TeacherDistillationEnv(BaseEnv, ABC):
|
|||
self,
|
||||
config: TeacherDistillationConfig,
|
||||
server_configs: Union[ServerBaseline, List[APIServerConfig]],
|
||||
teacher_server_configs: Optional[
|
||||
Union[ServerBaseline, List[APIServerConfig]]
|
||||
] = None,
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
|
|
@ -64,26 +71,42 @@ class TeacherDistillationEnv(BaseEnv, ABC):
|
|||
self.teacher_server: Optional[ServerManager] = None
|
||||
|
||||
if config.teacher_enabled:
|
||||
if config.teacher_server is None:
|
||||
teacher_config_source = teacher_server_configs
|
||||
if teacher_config_source is None and config.teacher_server is not None:
|
||||
teacher_config_source = [
|
||||
config.teacher_server.model_copy(
|
||||
update={
|
||||
"tokenizer_name": (
|
||||
config.teacher_server.model_name
|
||||
if config.teacher_server.tokenizer_name in ("", "none")
|
||||
else config.teacher_server.tokenizer_name
|
||||
),
|
||||
"timeout": 1200,
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
if teacher_config_source is None:
|
||||
raise ValueError(
|
||||
"teacher_enabled=True requires a teacher_server configuration."
|
||||
"teacher_enabled=True requires teacher_server_configs at init "
|
||||
"or a fallback teacher_server config."
|
||||
)
|
||||
teacher_cfg = config.teacher_server.model_copy(
|
||||
update={
|
||||
"tokenizer_name": (
|
||||
config.teacher_server.model_name
|
||||
if config.teacher_server.tokenizer_name in ("", "none")
|
||||
else config.teacher_server.tokenizer_name
|
||||
),
|
||||
"timeout": 1200,
|
||||
}
|
||||
)
|
||||
self.teacher_server = ServerManager(
|
||||
[teacher_cfg],
|
||||
teacher_config_source,
|
||||
slurm=False,
|
||||
testing=False,
|
||||
)
|
||||
self._validate_teacher_tokenizer_compatibility(teacher_cfg.tokenizer_name)
|
||||
if isinstance(teacher_config_source, list):
|
||||
teacher_cfg = teacher_config_source[0]
|
||||
else:
|
||||
teacher_cfg = teacher_config_source
|
||||
|
||||
teacher_tokenizer_name = (
|
||||
teacher_cfg.model_name
|
||||
if getattr(teacher_cfg, "tokenizer_name", "none") in ("", "none")
|
||||
else teacher_cfg.tokenizer_name
|
||||
)
|
||||
self._validate_teacher_tokenizer_compatibility(teacher_tokenizer_name)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core fetch
|
||||
|
|
@ -146,12 +169,19 @@ class TeacherDistillationEnv(BaseEnv, ABC):
|
|||
group["distill_logprobs"] = None
|
||||
return group
|
||||
|
||||
group_overrides = group.get("group_overrides") or {}
|
||||
if group_overrides.get("skip_teacher_top_k", False):
|
||||
group["distill_token_ids"] = None
|
||||
group["distill_logprobs"] = None
|
||||
return group
|
||||
|
||||
top_k = int(
|
||||
(group.get("group_overrides") or {}).get(
|
||||
"teacher_top_k", self.config.teacher_top_k
|
||||
)
|
||||
group_overrides.get("teacher_top_k", self.config.teacher_top_k)
|
||||
)
|
||||
top_k = max(1, top_k)
|
||||
if top_k <= -1:
|
||||
group["distill_token_ids"] = None
|
||||
group["distill_logprobs"] = None
|
||||
return group
|
||||
|
||||
tasks = [self._fetch_teacher_for_sequence(seq, top_k) for seq in seqs]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue