diff --git a/README.md b/README.md
index ddb2bac7..ef90b729 100644
--- a/README.md
+++ b/README.md
@@ -312,7 +312,7 @@ Atropos handles tokenization in two places:
    - Recommendation: set `--openai.tokenizer_name` explicitly to match the student serving model.
 2. **Teacher top-k parsing path**
-   - Teacher responses are parsed into token ids/logprobs in `BaseEnv.get_teacher_logprobs`.
+   - Teacher responses are fetched/parsed in `TeacherClient.get_teacher_logprobs` (called by `BaseEnv`).
    - The parser maps teacher token strings into ids using the environment tokenizer (`self.tokenizer`) and then aligns to student sequence length.
 
 Because distillation is token-position based, keeping tokenizer families compatible is strongly recommended, especially for cross-model distillation.
 
diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py
index 74383075..b816ce66 100644
--- a/atroposlib/envs/base.py
+++ b/atroposlib/envs/base.py
@@ -990,7 +990,12 @@ class BaseEnv(ABC):
 
         if self.config.include_messages and group.get("messages") is None:
             group["messages"] = [
-                self.tokenizer.decode(group["tokens"][i])
+                [
+                    {
+                        "role": "user",
+                        "content": self.tokenizer.decode(group["tokens"][i]),
+                    }
+                ]
                 for i in range(len(group["tokens"]))
             ]
 
diff --git a/atroposlib/envs/server_handling/teacher_client.py b/atroposlib/envs/server_handling/teacher_client.py
index a929d7dc..af5897b4 100644
--- a/atroposlib/envs/server_handling/teacher_client.py
+++ b/atroposlib/envs/server_handling/teacher_client.py
@@ -1,5 +1,5 @@
 import os
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import aiohttp
 
@@ -15,6 +15,28 @@ class TeacherClient:
         self.config = config
         self.tokenizer = tokenizer
         self.logger = logger
+        if self.config.distillation_enabled:
+            self._validate_distillation_config()
+
+    def _validate_distillation_config(self) -> None:
+        if not self.config.teacher_base_url:
+            raise ValueError("Distillation requires `teacher_base_url` to be set.")
+        if self.config.teacher_top_k <= 0:
+            raise ValueError(
+                f"Distillation requires `teacher_top_k > 0`, got {self.config.teacher_top_k}."
+            )
+        student_model_name = getattr(self.config, "model_name", None)
+        if (
+            self.config.teacher_model_name
+            and student_model_name
+            and self.config.teacher_model_name != student_model_name
+        ):
+            self.logger.warning(
+                "Cross-model distillation configured (teacher=%s, student=%s). "
+                "Token-level alignment quality depends on tokenizer compatibility.",
+                self.config.teacher_model_name,
+                student_model_name,
+            )
 
     async def get_teacher_logprobs(
         self,
@@ -109,7 +131,7 @@
                 pass
 
             if messages_list and i < len(messages_list):
-                messages = list(messages_list[i])
+                messages = self._normalize_messages(messages_list[i], full_text)
                 if self.config.teacher_system_prompt:
                     messages = [
                         {
@@ -176,6 +198,36 @@
             self.logger.error("Error fetching teacher logprobs: %s", e)
             return [], []
 
+    def _normalize_messages(
+        self, raw_messages: Any, fallback_text: str
+    ) -> List[Dict[str, str]]:
+        """
+        Normalize environment message payloads for chat/completions teacher fallback.
+
+        Accepts already-structured message lists, plain strings, or unknown structures.
+ """ + if isinstance(raw_messages, str): + return [{"role": "user", "content": raw_messages}] + + if isinstance(raw_messages, list): + normalized: List[Dict[str, str]] = [] + for msg in raw_messages: + if ( + isinstance(msg, dict) + and "role" in msg + and "content" in msg + and isinstance(msg["content"], str) + ): + normalized.append( + {"role": str(msg["role"]), "content": msg["content"]} + ) + elif isinstance(msg, str): + normalized.append({"role": "user", "content": msg}) + if normalized: + return normalized + + return [{"role": "user", "content": fallback_text}] + def _align_teacher_topk_to_tokens( self, seq_token_ids: List[List[int]],