dynamic system prompts

This commit is contained in:
Jai Suphavadeeprasit 2026-02-20 03:14:05 -05:00
parent e615eb1f50
commit 55f7cbd091
3 changed files with 240 additions and 12 deletions

View file

@ -248,6 +248,13 @@ class BaseEnvConfig(BaseModel):
"this is converted to a textual prefix. For chat fallback, this is injected "
"as a leading system message.",
)
teacher_prompt_template: Optional[str] = Field(
default=None,
description="Optional template-first teacher prompt renderer. "
"Uses Python format-style variables from runtime context/overrides "
"(e.g., {question}, {answer}, {episodes}). If set, this is preferred over "
"mode-specific prompt building.",
)
class BaseEnv(ABC):
@ -360,11 +367,15 @@ class BaseEnv(ABC):
self,
token_sequences: List[List[int]],
messages_list: Optional[List[List[Dict]]] = None,
seq_overrides: Optional[List[Dict[str, Any]]] = None,
group_overrides: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
return await self.teacher_client.get_teacher_logprobs(
token_sequences=token_sequences,
messages_list=messages_list,
seq_overrides=seq_overrides,
group_overrides=group_overrides,
top_k=top_k,
)
@ -1012,6 +1023,12 @@ class BaseEnv(ABC):
if self.config.distillation_enabled and self.config.teacher_base_url:
logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups")
for group in valid_groups:
seq_overrides = group.get("overrides") or []
group_overrides = (
group.get("group_overrides")
if isinstance(group.get("group_overrides"), dict)
else {}
)
has_new_format = (
group.get("distill_token_ids") is not None
and group.get("distill_logprobs") is not None
@ -1021,6 +1038,8 @@ class BaseEnv(ABC):
teacher_token_ids, teacher_logprobs = await self.get_teacher_logprobs(
token_sequences=group["tokens"],
messages_list=group.get("messages"),
seq_overrides=seq_overrides,
group_overrides=group_overrides,
)
if teacher_token_ids and teacher_logprobs:
group["distill_token_ids"] = teacher_token_ids