diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py
index 4a94c6d8..f6b1d4d1 100644
--- a/atroposlib/api/server.py
+++ b/atroposlib/api/server.py
@@ -145,6 +145,9 @@ class ScoredData(BaseModel):
     group_overrides: Optional[dict] = None
     images: Optional[Any] = None
     env_id: Optional[int] = None  # ID of the environment that generated this data
+    # On-policy distillation: top-K logprobs from the teacher model
+    # Structure: [sequence][position][top_k] = [token_id, logprob]
+    onpolicydistill_logprobs: Optional[List[List[List[List]]]] = None
 
     @field_validator("messages", mode="before")
     @classmethod
@@ -182,6 +185,7 @@ def _scored_data_to_dict(scored_data: ScoredData) -> Dict[str, Any]:
         "group_overrides": scored_data.group_overrides,
         "images": scored_data.images,
         "env_id": scored_data.env_id,
+        "onpolicydistill_logprobs": scored_data.onpolicydistill_logprobs,
     }
 
 
diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py
index e9b672a6..81569f0d 100644
--- a/atroposlib/envs/base.py
+++ b/atroposlib/envs/base.py
@@ -66,6 +66,9 @@ class ScoredDataGroup(TypedDict):
     group_overrides: Optional[Dict]
     overrides: Optional[List[Dict]]
     images: Optional[Any]
+    # On-policy distillation: top-K logprobs from the teacher model
+    # Structure: [sequence][position][top_k] = [token_id, logprob]
+    onpolicydistill_logprobs: Optional[List[List[List[List]]]]
 
 
 class ScoredDataItem(TypedDict):
@@ -78,6 +81,8 @@ class ScoredDataItem(TypedDict):
     group_overrides: Optional[Dict]
     overrides: Optional[Dict]
     images: Optional[Any]
+    # On-policy distillation: top-K logprobs from the teacher model per position
+    onpolicydistill_logprobs: Optional[List[List[List]]]
 
 
 class EvalHandlingEnum(Enum):
@@ -205,6 +210,31 @@ class BaseEnvConfig(BaseModel):
         "eval_helpers for the standard Hermes reasoning prompt.",
     )
 
+    # On-policy distillation settings
+    teacher_base_url: Optional[str] = Field(
+        default=None,
+        description="Base URL of the teacher model for distillation. Supports any OpenAI-compatible API "
+        "(vLLM, OpenAI, Together, etc.). Examples: 'http://localhost:8001/v1', 'https://api.openai.com/v1'.",
+    )
+    teacher_model_name: Optional[str] = Field(
+        default=None,
+        description="Model name for teacher API calls (e.g., 'gpt-4o', 'meta-llama/Llama-3-70b'). "
+        "If None, uses 'default', which works for single-model vLLM servers.",
+    )
+    teacher_api_key: Optional[str] = Field(
+        default=None,
+        description="API key for the teacher model. Can also be set via the TEACHER_API_KEY env var.",
+    )
+    teacher_top_k: int = Field(
+        default=10,
+        description="Number of top logprobs to fetch from the teacher model per position.",
+    )
+    distillation_enabled: bool = Field(
+        default=False,
+        description="Enable on-policy distillation. When True, automatically fetches teacher logprobs "
+        "after scoring and includes them in the data sent to the trainer.",
+    )
+
 
 class BaseEnv(ABC):
     name: Optional[str] = None
@@ -309,6 +339,190 @@ class BaseEnv(ABC):
         # Calculate derived batch size
         return int(self.config.batch_size * effective_fraction)
 
+    async def get_teacher_logprobs(
+        self,
+        token_sequences: List[List[int]],
+        messages_list: Optional[List[List[Dict]]] = None,
+        top_k: Optional[int] = None,
+    ) -> List[List[List[List]]]:
+        """
+        Fetch top-K logprobs from the teacher model for the given sequences.
+
+        Supports any OpenAI-compatible API (vLLM, OpenAI, Together, etc.).
+
+        Args:
+            token_sequences: List of token ID sequences to get logprobs for.
+            messages_list: Optional list of message histories (for chat APIs).
+                If provided, uses chat/completions with logprobs.
+            top_k: Number of top logprobs to fetch (defaults to config.teacher_top_k).
+
+        Returns:
+            Top-K logprobs per position per sequence.
+            Structure: [batch][position][top_k] = [token_id, logprob].
+            Returns an empty list if teacher_base_url is not configured.
+        """
+        if not self.config.teacher_base_url:
+            return []
+
+        if top_k is None:
+            top_k = self.config.teacher_top_k
+
+        # Get the API key from config or environment
+        api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "")
+        model_name = self.config.teacher_model_name or "default"
+
+        headers = {"Content-Type": "application/json"}
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+
+        results = []
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                for i, tokens in enumerate(token_sequences):
+                    # Decode tokens to text
+                    full_text = self.tokenizer.decode(tokens, skip_special_tokens=False)
+
+                    # Try vLLM-style completions first (prompt logprobs via echo).
+                    # This is the most efficient path: it scores the prompt without generating new tokens.
+                    request_data = {
+                        "model": model_name,
+                        "prompt": full_text,
+                        "max_tokens": 1,
+                        "temperature": 1.0,
+                        "logprobs": top_k,
+                        "echo": True,  # Include the prompt in the response with logprobs
+                    }
+
+                    try:
+                        async with session.post(
+                            f"{self.config.teacher_base_url}/completions",
+                            json=request_data,
+                            headers=headers,
+                            timeout=aiohttp.ClientTimeout(total=120),
+                        ) as response:
+                            if response.status == 200:
+                                data = await response.json()
+                                seq_result = self._parse_completion_logprobs(data, top_k)
+                                if seq_result:
+                                    results.append(seq_result)
+                                    continue
+                    except Exception:
+                        pass  # Fall through to chat completions
+
+                    # Fallback: chat/completions with logprobs (OpenAI style). Requires
+                    # messages format; note it only returns logprobs for generated tokens.
+                    if messages_list and i < len(messages_list):
+                        messages = messages_list[i]
+                    else:
+                        # Convert the text to a simple message format
+                        messages = [{"role": "user", "content": full_text}]
+
+                    chat_request = {
+                        "model": model_name,
+                        "messages": messages,
+                        "max_tokens": 1,
+                        "temperature": 1.0,
+                        "logprobs": True,
+                        "top_logprobs": top_k,
+                    }
+
+                    try:
+                        async with session.post(
+                            f"{self.config.teacher_base_url}/chat/completions",
+                            json=chat_request,
+                            headers=headers,
+                            timeout=aiohttp.ClientTimeout(total=120),
+                        ) as response:
+                            if response.status == 200:
+                                data = await response.json()
+                                seq_result = self._parse_chat_logprobs(data, top_k)
+                                results.append(seq_result)
+                            else:
+                                logger.warning(f"Teacher API returned {response.status}")
+                                results.append([])
+                    except Exception as e:
+                        logger.warning(f"Teacher chat request failed: {e}")
+                        results.append([])
+
+            return results
+
+        except Exception as e:
+            logger.error(f"Error fetching teacher logprobs: {e}")
+            return []
+
+    def _parse_completion_logprobs(
+        self, data: Dict, top_k: int
+    ) -> List[List[List]]:
+        """Parse logprobs from a vLLM-style completion response."""
+        try:
+            choice = data.get("choices", [{}])[0]
+            logprobs_data = choice.get("logprobs", {})
+
+            # vLLM returns top_logprobs as a list of dicts
+            top_logprobs = logprobs_data.get("top_logprobs", [])
+            tokens = logprobs_data.get("tokens", [])
+
+            if not top_logprobs:
+                return []
+
+            seq_result = []
+            for pos_logprobs in top_logprobs:
+                if pos_logprobs is None:
+                    seq_result.append([])
+                elif isinstance(pos_logprobs, dict):
+                    # Format: {token_str: logprob, ...}
+                    sorted_items = sorted(
+                        pos_logprobs.items(),
+                        key=lambda x: x[1],
+                        reverse=True,
+                    )[:top_k]
+                    pos_result = []
+                    for token_str, logprob in sorted_items:
+                        # Convert the token string back to an ID (may be lossy for byte-level tokenizers)
+                        token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
+                        if token_ids:
+                            pos_result.append([token_ids[0], float(logprob)])
+                    seq_result.append(pos_result)
+                else:
+                    seq_result.append([])
+
+            return seq_result
+        except Exception as e:
+            logger.warning(f"Error parsing completion logprobs: {e}")
+            return []
+
+    def _parse_chat_logprobs(
+        self, data: Dict, top_k: int
+    ) -> List[List[List]]:
+        """Parse logprobs from an OpenAI-style chat completion response."""
+        try:
+            choice = data.get("choices", [{}])[0]
+            logprobs_data = choice.get("logprobs", {})
+
+            if not logprobs_data:
+                return []
+
+            content = logprobs_data.get("content", [])
+            seq_result = []
+
+            for token_data in content:
+                top_logprobs = token_data.get("top_logprobs", [])
+                pos_result = []
+                for item in top_logprobs[:top_k]:
+                    token_str = item.get("token", "")
+                    logprob = item.get("logprob", 0.0)
+                    # Convert the token string back to an ID (may be lossy for byte-level tokenizers)
+                    token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
+                    if token_ids:
+                        pos_result.append([token_ids[0], float(logprob)])
+                seq_result.append(pos_result)
+
+            return seq_result
+        except Exception as e:
+            logger.warning(f"Error parsing chat logprobs: {e}")
+            return []
+
     @classmethod
     def config_init(
         cls,
@@ -909,6 +1123,22 @@ class BaseEnv(ABC):
                 for i in range(len(group["tokens"]))
             ]
 
+        # Automatic on-policy distillation: fetch teacher logprobs if enabled
+        if self.config.distillation_enabled and self.config.teacher_base_url:
+            if group.get("onpolicydistill_logprobs") is None:
+                try:
+                    teacher_logprobs = await self.get_teacher_logprobs(
+                        token_sequences=group["tokens"],
+                        messages_list=group.get("messages"),
+                    )
+                    if teacher_logprobs:
+                        group["onpolicydistill_logprobs"] = teacher_logprobs
+                        logger.debug(
+                            f"Added teacher logprobs for {len(teacher_logprobs)} sequences"
+                        )
+                except Exception as e:
+                    logger.warning(f"Failed to fetch teacher logprobs: {e}")
+
         await self.add_rollouts_for_wandb(group, item)
 
         if self.jsonl_writer is not None:
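
To enable the feature, only the new config fields are needed. A minimal sketch, assuming the remaining `BaseEnvConfig` fields carry defaults in your version; the teacher URL and model name below are placeholders, not part of this diff:

```python
from atroposlib.envs.base import BaseEnvConfig

config = BaseEnvConfig(
    distillation_enabled=True,
    teacher_base_url="http://localhost:8001/v1",  # e.g. a local vLLM server (placeholder)
    teacher_model_name="meta-llama/Llama-3-70b",  # placeholder; None -> "default"
    teacher_top_k=10,
    # teacher_api_key omitted: get_teacher_logprobs falls back to the
    # TEACHER_API_KEY environment variable
)
```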
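With `distillation_enabled` set, the last hunk attaches teacher logprobs automatically after scoring, but `get_teacher_logprobs` can also be called directly. A hedged sketch of a hypothetical helper on a `BaseEnv` subclass, mirroring what the automatic path does:

```python
# Hypothetical helper (not part of this PR); `group` is a ScoredDataGroup.
async def attach_teacher_logprobs(self, group) -> None:
    teacher_logprobs = await self.get_teacher_logprobs(
        token_sequences=group["tokens"],
        messages_list=group.get("messages"),
    )
    # An empty list means no teacher is configured or every request failed
    if teacher_logprobs:
        group["onpolicydistill_logprobs"] = teacher_logprobs
```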
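On the consumer side, the nested `[sequence][position][top_k] = [token_id, logprob]` structure is enough to reconstruct a truncated teacher distribution per position. An illustrative trainer-side sketch (not part of this PR; the torch dependency and function name are assumptions) of a top-K distillation loss:

```python
import torch

def topk_distill_loss(student_logprobs: torch.Tensor, teacher_topk: list) -> torch.Tensor:
    """Cross-entropy of the student against the renormalized teacher top-K.

    student_logprobs: [num_positions, vocab_size] student log-probabilities.
    teacher_topk: one sequence's entry, [position][k] = [token_id, logprob].
    """
    losses = []
    for pos, entries in enumerate(teacher_topk):
        if not entries:  # positions with no teacher data are skipped
            continue
        ids = torch.tensor([int(tok) for tok, _ in entries])
        teacher_logps = torch.tensor([float(lp) for _, lp in entries])
        teacher_probs = torch.softmax(teacher_logps, dim=0)  # renormalize over the top-K
        losses.append(-(teacher_probs * student_logprobs[pos, ids]).sum())
    return torch.stack(losses).mean() if losses else student_logprobs.new_zeros(())
```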