proper fallback

This commit is contained in:
Jai Suphavadeeprasit 2026-02-20 01:22:50 -05:00
parent 3910a58f9b
commit 559d649a26
3 changed files with 61 additions and 4 deletions

View file

@ -990,7 +990,12 @@ class BaseEnv(ABC):
if self.config.include_messages and group.get("messages") is None:
group["messages"] = [
self.tokenizer.decode(group["tokens"][i])
[
{
"role": "user",
"content": self.tokenizer.decode(group["tokens"][i]),
}
]
for i in range(len(group["tokens"]))
]

View file

@ -1,5 +1,5 @@
import os
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import aiohttp
@ -15,6 +15,28 @@ class TeacherClient:
self.config = config
self.tokenizer = tokenizer
self.logger = logger
if self.config.distillation_enabled:
self._validate_distillation_config()
def _validate_distillation_config(self) -> None:
if not self.config.teacher_base_url:
raise ValueError("Distillation requires `teacher_base_url` to be set.")
if self.config.teacher_top_k <= 0:
raise ValueError(
f"Distillation requires `teacher_top_k > 0`, got {self.config.teacher_top_k}."
)
student_model_name = getattr(self.config, "model_name", None)
if (
self.config.teacher_model_name
and student_model_name
and self.config.teacher_model_name != student_model_name
):
self.logger.warning(
"Cross-model distillation configured (teacher=%s, student=%s). "
"Token-level alignment quality depends on tokenizer compatibility.",
self.config.teacher_model_name,
student_model_name,
)
async def get_teacher_logprobs(
self,
@ -109,7 +131,7 @@ class TeacherClient:
pass
if messages_list and i < len(messages_list):
messages = list(messages_list[i])
messages = self._normalize_messages(messages_list[i], full_text)
if self.config.teacher_system_prompt:
messages = [
{
@ -176,6 +198,36 @@ class TeacherClient:
self.logger.error("Error fetching teacher logprobs: %s", e)
return [], []
def _normalize_messages(
self, raw_messages: Any, fallback_text: str
) -> List[Dict[str, str]]:
"""
Normalize environment message payloads for chat/completions teacher fallback.
Accepts already-structured message lists, plain strings, or unknown structures.
"""
if isinstance(raw_messages, str):
return [{"role": "user", "content": raw_messages}]
if isinstance(raw_messages, list):
normalized: List[Dict[str, str]] = []
for msg in raw_messages:
if (
isinstance(msg, dict)
and "role" in msg
and "content" in msg
and isinstance(msg["content"], str)
):
normalized.append(
{"role": str(msg["role"]), "content": msg["content"]}
)
elif isinstance(msg, str):
normalized.append({"role": "user", "content": msg})
if normalized:
return normalized
return [{"role": "user", "content": fallback_text}]
def _align_teacher_topk_to_tokens(
self,
seq_token_ids: List[List[int]],