mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
debug changes
This commit is contained in:
parent
0510ca9b72
commit
c89854a350
2 changed files with 166 additions and 7 deletions
|
|
@@ -209,8 +209,12 @@ class BaseEnvConfig(BaseModel):
|
|||
"no thinking prompt is injected. Use HERMES_REASONING_PROMPT from "
|
||||
"eval_helpers for the standard Hermes reasoning prompt.",
|
||||
)
|
||||
|
||||
# On-policy distillation settings
|
||||
distillation_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Enable on-policy distillation. When True, automatically fetches teacher logprobs "
|
||||
"after scoring and includes them in data sent to trainer.",
|
||||
)
|
||||
teacher_base_url: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Base URL of teacher model for distillation. Supports any OpenAI-compatible API "
|
||||
|
|
@@ -226,14 +230,9 @@ class BaseEnvConfig(BaseModel):
|
|||
description="API key for teacher model. Can also be set via TEACHER_API_KEY env var.",
|
||||
)
|
||||
teacher_top_k: int = Field(
|
||||
default=10,
|
||||
default=20,
|
||||
description="Number of top logprobs to fetch from teacher model per position.",
|
||||
)
|
||||
distillation_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Enable on-policy distillation. When True, automatically fetches teacher logprobs "
|
||||
"after scoring and includes them in data sent to trainer.",
|
||||
)
|
||||
|
||||
|
||||
class BaseEnv(ABC):
|
||||
|
|
@@ -1164,6 +1163,28 @@ class BaseEnv(ABC):
|
|||
valid_groups.append(group)
|
||||
|
||||
if valid_groups and do_send_to_api:
|
||||
# On-policy distillation: fetch teacher logprobs if enabled
|
||||
if self.config.distillation_enabled and self.config.teacher_base_url:
|
||||
logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups")
|
||||
for group in valid_groups:
|
||||
if group.get("onpolicydistill_logprobs") is None:
|
||||
try:
|
||||
teacher_logprobs = await self.get_teacher_logprobs(
|
||||
token_sequences=group["tokens"],
|
||||
messages_list=group.get("messages"),
|
||||
)
|
||||
if teacher_logprobs:
|
||||
group["onpolicydistill_logprobs"] = teacher_logprobs
|
||||
logger.info(f"[DISTILL] Added teacher logprobs for {len(teacher_logprobs)} sequences")
|
||||
else:
|
||||
logger.warning("[DISTILL] get_teacher_logprobs returned empty")
|
||||
except Exception as e:
|
||||
logger.error(f"[DISTILL] Failed to fetch teacher logprobs: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
else:
|
||||
logger.debug(f"[DISTILL] Skipped - enabled={self.config.distillation_enabled}, url={self.config.teacher_base_url}")
|
||||
|
||||
data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]]
|
||||
# send single or list of scored data groups
|
||||
if not original_was_list and len(valid_groups) == 1:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue