mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
proper fallback
This commit is contained in:
parent
3910a58f9b
commit
559d649a26
3 changed files with 61 additions and 4 deletions
|
|
@ -990,7 +990,12 @@ class BaseEnv(ABC):
|
|||
|
||||
if self.config.include_messages and group.get("messages") is None:
|
||||
group["messages"] = [
|
||||
self.tokenizer.decode(group["tokens"][i])
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.tokenizer.decode(group["tokens"][i]),
|
||||
}
|
||||
]
|
||||
for i in range(len(group["tokens"]))
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
|
@ -15,6 +15,28 @@ class TeacherClient:
|
|||
self.config = config
|
||||
self.tokenizer = tokenizer
|
||||
self.logger = logger
|
||||
if self.config.distillation_enabled:
|
||||
self._validate_distillation_config()
|
||||
|
||||
def _validate_distillation_config(self) -> None:
    """
    Check the distillation settings on ``self.config`` and fail fast on bad values.

    Raises:
        ValueError: if ``teacher_base_url`` is falsy, or if ``teacher_top_k``
            is unset (``None``) or not strictly positive.

    Also logs a warning through ``self.logger`` when both a teacher model name
    and a student model name are configured and they differ.
    """
    config = self.config

    if not config.teacher_base_url:
        raise ValueError("Distillation requires `teacher_base_url` to be set.")

    top_k = config.teacher_top_k
    # `None <= 0` raises TypeError; report an unset value through the same
    # descriptive ValueError as any other non-positive setting.
    if top_k is None or top_k <= 0:
        raise ValueError(
            f"Distillation requires `teacher_top_k > 0`, got {top_k}."
        )

    # The student model name is optional on the config, hence the getattr.
    student_model_name = getattr(config, "model_name", None)
    teacher_model_name = config.teacher_model_name
    if (
        teacher_model_name
        and student_model_name
        and teacher_model_name != student_model_name
    ):
        self.logger.warning(
            "Cross-model distillation configured (teacher=%s, student=%s). "
            "Token-level alignment quality depends on tokenizer compatibility.",
            teacher_model_name,
            student_model_name,
        )
|
||||
|
||||
async def get_teacher_logprobs(
|
||||
self,
|
||||
|
|
@ -109,7 +131,7 @@ class TeacherClient:
|
|||
pass
|
||||
|
||||
if messages_list and i < len(messages_list):
|
||||
messages = list(messages_list[i])
|
||||
messages = self._normalize_messages(messages_list[i], full_text)
|
||||
if self.config.teacher_system_prompt:
|
||||
messages = [
|
||||
{
|
||||
|
|
@ -176,6 +198,36 @@ class TeacherClient:
|
|||
self.logger.error("Error fetching teacher logprobs: %s", e)
|
||||
return [], []
|
||||
|
||||
def _normalize_messages(
    self, raw_messages: Any, fallback_text: str
) -> List[Dict[str, str]]:
    """
    Normalize environment message payloads for chat/completions teacher fallback.

    Accepts already-structured message lists (or tuples), plain strings, or
    unknown structures.

    Args:
        raw_messages: Payload to normalize. A non-empty string becomes a single
            user message. A list/tuple is scanned item by item: dicts carrying
            ``role`` and a string ``content`` are kept (role coerced to ``str``),
            bare strings become user messages, and anything else is dropped.
        fallback_text: Content of the single user message returned when
            nothing usable can be extracted from ``raw_messages``.

    Returns:
        A non-empty list of ``{"role": ..., "content": ...}`` dicts.
    """
    fallback = [{"role": "user", "content": fallback_text}]

    if isinstance(raw_messages, str):
        # An empty string carries no prompt; use the fallback text instead of
        # emitting an empty user message.
        if raw_messages:
            return [{"role": "user", "content": raw_messages}]
        return fallback

    # Tuples are accepted alongside lists so an immutable message sequence is
    # not discarded in favour of the fallback.
    if isinstance(raw_messages, (list, tuple)):
        normalized: List[Dict[str, str]] = []
        for msg in raw_messages:
            if isinstance(msg, str):
                normalized.append({"role": "user", "content": msg})
            elif (
                isinstance(msg, dict)
                and "role" in msg
                and isinstance(msg.get("content"), str)
            ):
                normalized.append(
                    {"role": str(msg["role"]), "content": msg["content"]}
                )
            # Any other item shape (non-dict, missing keys, non-string
            # content) is skipped.
        if normalized:
            return normalized

    return fallback
|
||||
|
||||
def _align_teacher_topk_to_tokens(
|
||||
self,
|
||||
seq_token_ids: List[List[int]],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue