mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-27 17:23:08 +00:00
dynamic system prompts
This commit is contained in:
parent
e615eb1f50
commit
55f7cbd091
3 changed files with 240 additions and 12 deletions
|
|
@ -248,6 +248,13 @@ class BaseEnvConfig(BaseModel):
|
|||
"this is converted to a textual prefix. For chat fallback, this is injected "
|
||||
"as a leading system message.",
|
||||
)
|
||||
teacher_prompt_template: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional template-first teacher prompt renderer. "
|
||||
"Uses Python format-style variables from runtime context/overrides "
|
||||
"(e.g., {question}, {answer}, {episodes}). If set, this is preferred over "
|
||||
"mode-specific prompt building.",
|
||||
)
|
||||
|
||||
|
||||
class BaseEnv(ABC):
|
||||
|
|
@ -360,11 +367,15 @@ class BaseEnv(ABC):
|
|||
self,
|
||||
token_sequences: List[List[int]],
|
||||
messages_list: Optional[List[List[Dict]]] = None,
|
||||
seq_overrides: Optional[List[Dict[str, Any]]] = None,
|
||||
group_overrides: Optional[Dict[str, Any]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
|
||||
return await self.teacher_client.get_teacher_logprobs(
|
||||
token_sequences=token_sequences,
|
||||
messages_list=messages_list,
|
||||
seq_overrides=seq_overrides,
|
||||
group_overrides=group_overrides,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
|
|
@ -1012,6 +1023,12 @@ class BaseEnv(ABC):
|
|||
if self.config.distillation_enabled and self.config.teacher_base_url:
|
||||
logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups")
|
||||
for group in valid_groups:
|
||||
seq_overrides = group.get("overrides") or []
|
||||
group_overrides = (
|
||||
group.get("group_overrides")
|
||||
if isinstance(group.get("group_overrides"), dict)
|
||||
else {}
|
||||
)
|
||||
has_new_format = (
|
||||
group.get("distill_token_ids") is not None
|
||||
and group.get("distill_logprobs") is not None
|
||||
|
|
@ -1021,6 +1038,8 @@ class BaseEnv(ABC):
|
|||
teacher_token_ids, teacher_logprobs = await self.get_teacher_logprobs(
|
||||
token_sequences=group["tokens"],
|
||||
messages_list=group.get("messages"),
|
||||
seq_overrides=seq_overrides,
|
||||
group_overrides=group_overrides,
|
||||
)
|
||||
if teacher_token_ids and teacher_logprobs:
|
||||
group["distill_token_ids"] = teacher_token_ids
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue