debug changes

This commit is contained in:
Jai Suphavadeeprasit 2026-02-17 08:15:07 -05:00
parent 0510ca9b72
commit c89854a350
2 changed files with 166 additions and 7 deletions

View file

@ -209,8 +209,12 @@ class BaseEnvConfig(BaseModel):
"no thinking prompt is injected. Use HERMES_REASONING_PROMPT from "
"eval_helpers for the standard Hermes reasoning prompt.",
)
# On-policy distillation settings
distillation_enabled: bool = Field(
default=False,
description="Enable on-policy distillation. When True, automatically fetches teacher logprobs "
"after scoring and includes them in data sent to trainer.",
)
teacher_base_url: Optional[str] = Field(
default=None,
description="Base URL of teacher model for distillation. Supports any OpenAI-compatible API "
@ -226,14 +230,9 @@ class BaseEnvConfig(BaseModel):
description="API key for teacher model. Can also be set via TEACHER_API_KEY env var.",
)
teacher_top_k: int = Field(
default=10,
default=20,
description="Number of top logprobs to fetch from teacher model per position.",
)
distillation_enabled: bool = Field(
default=False,
description="Enable on-policy distillation. When True, automatically fetches teacher logprobs "
"after scoring and includes them in data sent to trainer.",
)
class BaseEnv(ABC):
@ -1164,6 +1163,28 @@ class BaseEnv(ABC):
valid_groups.append(group)
if valid_groups and do_send_to_api:
# On-policy distillation: fetch teacher logprobs if enabled
if self.config.distillation_enabled and self.config.teacher_base_url:
logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups")
for group in valid_groups:
if group.get("onpolicydistill_logprobs") is None:
try:
teacher_logprobs = await self.get_teacher_logprobs(
token_sequences=group["tokens"],
messages_list=group.get("messages"),
)
if teacher_logprobs:
group["onpolicydistill_logprobs"] = teacher_logprobs
logger.info(f"[DISTILL] Added teacher logprobs for {len(teacher_logprobs)} sequences")
else:
logger.warning("[DISTILL] get_teacher_logprobs returned empty")
except Exception as e:
logger.error(f"[DISTILL] Failed to fetch teacher logprobs: {e}")
import traceback
logger.error(traceback.format_exc())
else:
logger.debug(f"[DISTILL] Skipped - enabled={self.config.distillation_enabled}, url={self.config.teacher_base_url}")
data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]]
# send single or list of scored data groups
if not original_was_list and len(valid_groups) == 1: