diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py
index 551eb5b0..74383075 100644
--- a/atroposlib/envs/base.py
+++ b/atroposlib/envs/base.py
@@ -49,6 +49,7 @@ from .server_handling.server_manager import (
     ServerManager,
     ServerManagerConfig,
 )
+from .server_handling.teacher_client import TeacherClient
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -297,6 +298,9 @@ class BaseEnv(ABC):
         self.curr_step = 0
         self.max_token_len = -1
         self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
+        self.teacher_client = TeacherClient(
+            config=self.config, tokenizer=self.tokenizer, logger=logger
+        )
         self.completion_lengths = []
         self.max_num_workers = config.max_num_workers
         if self.max_num_workers == -1:
@@ -358,174 +362,11 @@ class BaseEnv(ABC):
         messages_list: Optional[List[List[Dict]]] = None,
         top_k: Optional[int] = None,
     ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
-        """
-        Fetch top-K logprobs from teacher model for given sequences.
-
-        Supports any OpenAI-compatible API (vLLM, OpenAI, Together, etc.).
-
-        Args:
-            token_sequences: List of token ID sequences to get logprobs for
-            messages_list: Optional list of message histories (for chat APIs).
-                If provided, uses chat/completions with logprobs.
-            top_k: Number of top logprobs to fetch (defaults to config.teacher_top_k)
-
-        Returns:
-            Tuple of (distill_token_ids, distill_logprobs), both shaped as:
-            [batch][position][top_k].
-            Returns ([], []) if teacher_base_url is not configured.
-        """
-        logger.info(f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences")
-        logger.info(f"[TEACHER] teacher_base_url={self.config.teacher_base_url}")
-
-        if not self.config.teacher_base_url:
-            logger.warning("[TEACHER] No teacher_base_url configured, returning empty")
-            return [], []
-
-        if top_k is None:
-            top_k = self.config.teacher_top_k
-
-        # Get API key from config or environment
-        api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "")
-        model_name = self.config.teacher_model_name or "default"
-
-        logger.info(f"[TEACHER] Using model={model_name}, top_k={top_k}")
-
-        headers = {"Content-Type": "application/json"}
-        if api_key:
-            headers["Authorization"] = f"Bearer {api_key}"
-
-        token_id_results: List[List[List[int]]] = []
-        logprob_results: List[List[List[float]]] = []
-
-        try:
-            async with aiohttp.ClientSession() as session:
-                for i, tokens in enumerate(token_sequences):
-                    logger.info(f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens")
-                    # Decode original sequence and optionally prepend teacher steering text.
-                    base_text = self.tokenizer.decode(tokens, skip_special_tokens=False)
-                    steering_prefix = ""
-                    if self.config.teacher_system_prompt:
-                        steering_prefix += (
-                            "System instruction:\n"
-                            f"{self.config.teacher_system_prompt.strip()}\n\n"
-                        )
-                    if self.config.teacher_prefix_text:
-                        steering_prefix += self.config.teacher_prefix_text
-                    full_text = steering_prefix + base_text
-                    prefix_token_len = (
-                        len(self.tokenizer.encode(steering_prefix, add_special_tokens=False))
-                        if steering_prefix
-                        else 0
-                    )
-
-                    # Try vLLM-style completions first (supports prompt_logprobs)
-                    # This is most efficient as it doesn't generate new tokens
-                    request_data = {
-                        "model": model_name,
-                        "prompt": full_text,
-                        "max_tokens": 1,
-                        "temperature": 1.0,
-                        "logprobs": top_k,
-                        "echo": True,  # Include prompt in response with logprobs
-                    }
-
-                    try:
-                        async with session.post(
-                            f"{self.config.teacher_base_url}/completions",
-                            json=request_data,
-                            headers=headers,
-                            timeout=aiohttp.ClientTimeout(total=120),
-                        ) as response:
-                            if response.status == 200:
-                                data = await response.json()
-                                seq_token_ids, seq_logprobs = self._parse_completion_logprobs(
-                                    data, top_k
-                                )
-                                if seq_token_ids and seq_logprobs:
-                                    aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens(
-                                        seq_token_ids,
-                                        seq_logprobs,
-                                        target_token_len=len(tokens),
-                                        prefix_token_len=prefix_token_len,
-                                    )
-                                    token_id_results.append(aligned_ids)
-                                    logprob_results.append(aligned_lps)
-                                    continue
-                    except Exception:
-                        pass  # Fall through to chat completions
-
-                    # Fallback: Use chat/completions with logprobs (OpenAI style)
-                    # This requires messages format
-                    if messages_list and i < len(messages_list):
-                        messages = list(messages_list[i])
-                        if self.config.teacher_system_prompt:
-                            messages = [
-                                {
-                                    "role": "system",
-                                    "content": self.config.teacher_system_prompt,
-                                }
-                            ] + messages
-                    else:
-                        # Convert text to simple message format
-                        messages = []
-                        if self.config.teacher_system_prompt:
-                            messages.append(
-                                {
-                                    "role": "system",
-                                    "content": self.config.teacher_system_prompt,
-                                }
-                            )
-                        messages.append({"role": "user", "content": full_text})
-
-                    chat_request = {
-                        "model": model_name,
-                        "messages": messages,
-                        "max_tokens": 1,
-                        "temperature": 1.0,
-                        "logprobs": True,
-                        "top_logprobs": top_k,
-                    }
-
-                    try:
-                        async with session.post(
-                            f"{self.config.teacher_base_url}/chat/completions",
-                            json=chat_request,
-                            headers=headers,
-                            timeout=aiohttp.ClientTimeout(total=120),
-                        ) as response:
-                            if response.status == 200:
-                                data = await response.json()
-                                seq_token_ids, seq_logprobs = self._parse_chat_logprobs(
-                                    data, top_k
-                                )
-                                # Chat fallback logprobs are for generated tokens, not prompt tokens.
-                                # To keep alignment correct for distillation, return empty per-position rows.
-                                if seq_token_ids and len(seq_token_ids) >= len(tokens):
-                                    aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens(
-                                        seq_token_ids,
-                                        seq_logprobs,
-                                        target_token_len=len(tokens),
-                                        prefix_token_len=0,
-                                    )
-                                else:
-                                    aligned_ids = [[] for _ in range(len(tokens))]
-                                    aligned_lps = [[] for _ in range(len(tokens))]
-                                token_id_results.append(aligned_ids)
-                                logprob_results.append(aligned_lps)
-                            else:
-                                logger.warning(f"Teacher API returned {response.status}")
-                                token_id_results.append([[] for _ in range(len(tokens))])
-                                logprob_results.append([[] for _ in range(len(tokens))])
-                    except Exception as e:
-                        logger.warning(f"Teacher chat request failed: {e}")
-                        token_id_results.append([[] for _ in range(len(tokens))])
-                        logprob_results.append([[] for _ in range(len(tokens))])
-
-            return token_id_results, logprob_results
-
-        except Exception as e:
-            logger.error(f"Error fetching teacher logprobs: {e}")
-            return [], []
+        return await self.teacher_client.get_teacher_logprobs(
+            token_sequences=token_sequences,
+            messages_list=messages_list,
+            top_k=top_k,
+        )
 
     def _align_teacher_topk_to_tokens(
         self,
@@ -534,109 +375,22 @@ class BaseEnv(ABC):
         target_token_len: int,
         prefix_token_len: int = 0,
     ) -> Tuple[List[List[int]], List[List[float]]]:
-        """
-        Trim teacher prefix positions and enforce exact length alignment with source tokens.
-        """
-        n = min(len(seq_token_ids), len(seq_logprobs))
-        aligned_ids = list(seq_token_ids[:n])
-        aligned_lps = list(seq_logprobs[:n])
-
-        if prefix_token_len > 0:
-            aligned_ids = aligned_ids[prefix_token_len:]
-            aligned_lps = aligned_lps[prefix_token_len:]
-
-        # Truncate any tail token (e.g., generated token when max_tokens>0 with echo=True)
-        aligned_ids = aligned_ids[:target_token_len]
-        aligned_lps = aligned_lps[:target_token_len]
-
-        # Pad missing positions to preserve strict [seq][position][top_k] shape.
-        if len(aligned_ids) < target_token_len:
-            pad_count = target_token_len - len(aligned_ids)
-            aligned_ids.extend([[] for _ in range(pad_count)])
-            aligned_lps.extend([[] for _ in range(pad_count)])
-
-        return aligned_ids, aligned_lps
+        return self.teacher_client._align_teacher_topk_to_tokens(
+            seq_token_ids=seq_token_ids,
+            seq_logprobs=seq_logprobs,
+            target_token_len=target_token_len,
+            prefix_token_len=prefix_token_len,
+        )
 
     def _parse_completion_logprobs(
         self, data: Dict, top_k: int
     ) -> Tuple[List[List[int]], List[List[float]]]:
-        """Parse token ids + logprobs from vLLM-style completion response."""
-        try:
-            choice = data.get("choices", [{}])[0]
-            logprobs_data = choice.get("logprobs", {})
-
-            # vLLM returns top_logprobs as list of dicts
-            top_logprobs = logprobs_data.get("top_logprobs", [])
-
-            if not top_logprobs:
-                return [], []
-
-            seq_token_ids: List[List[int]] = []
-            seq_logprobs: List[List[float]] = []
-            for pos_logprobs in top_logprobs:
-                if pos_logprobs is None:
-                    seq_token_ids.append([])
-                    seq_logprobs.append([])
-                elif isinstance(pos_logprobs, dict):
-                    # Format: {token_str: logprob, ...}
-                    sorted_items = sorted(
-                        pos_logprobs.items(),
-                        key=lambda x: x[1],
-                        reverse=True
-                    )[:top_k]
-                    pos_ids: List[int] = []
-                    pos_lps: List[float] = []
-                    for token_str, logprob in sorted_items:
-                        # Convert token string to ID
-                        token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
-                        if token_ids:
-                            pos_ids.append(int(token_ids[0]))
-                            pos_lps.append(float(logprob))
-                    seq_token_ids.append(pos_ids)
-                    seq_logprobs.append(pos_lps)
-                else:
-                    seq_token_ids.append([])
-                    seq_logprobs.append([])
-
-            return seq_token_ids, seq_logprobs
-        except Exception as e:
-            logger.warning(f"Error parsing completion logprobs: {e}")
-            return [], []
+        return self.teacher_client._parse_completion_logprobs(data=data, top_k=top_k)
 
     def _parse_chat_logprobs(
         self, data: Dict, top_k: int
     ) -> Tuple[List[List[int]], List[List[float]]]:
-        """Parse token ids + logprobs from OpenAI-style chat completion response."""
-        try:
-            choice = data.get("choices", [{}])[0]
-            logprobs_data = choice.get("logprobs", {})
-
-            if not logprobs_data:
-                return [], []
-
-            content = logprobs_data.get("content", [])
-            seq_token_ids: List[List[int]] = []
-            seq_logprobs: List[List[float]] = []
-
-            for token_data in content:
-                top_logprobs = token_data.get("top_logprobs", [])
-                pos_ids: List[int] = []
-                pos_lps: List[float] = []
-                for item in top_logprobs[:top_k]:
-                    token_str = item.get("token", "")
-                    logprob = item.get("logprob", 0.0)
-                    # Convert token string to ID
-                    token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
-                    if token_ids:
-                        pos_ids.append(int(token_ids[0]))
-                        pos_lps.append(float(logprob))
-                seq_token_ids.append(pos_ids)
-                seq_logprobs.append(pos_lps)
-
-            return seq_token_ids, seq_logprobs
-        except Exception as e:
-            logger.warning(f"Error parsing chat logprobs: {e}")
-            return [], []
+        return self.teacher_client._parse_chat_logprobs(data=data, top_k=top_k)
 
     @classmethod
     def config_init(
diff --git a/atroposlib/envs/server_handling/teacher_client.py b/atroposlib/envs/server_handling/teacher_client.py
new file mode 100644
index 00000000..a929d7dc
--- /dev/null
+++ b/atroposlib/envs/server_handling/teacher_client.py
@@ -0,0 +1,273 @@
+import os
+from typing import Dict, List, Optional, Tuple
+
+import aiohttp
+
+
+class TeacherClient:
+    """
+    Transport/parsing client for teacher top-k logprobs.
+
+    This keeps distillation HTTP and parsing logic out of BaseEnv.
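+
+    Supports any OpenAI-compatible API (vLLM, OpenAI, Together, etc.).
+    get_teacher_logprobs returns (distill_token_ids, distill_logprobs), both
+    shaped [batch][position][top_k]; it returns ([], []) when teacher_base_url
+    is not configured.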
+ """ + + def __init__(self, config, tokenizer, logger): + self.config = config + self.tokenizer = tokenizer + self.logger = logger + + async def get_teacher_logprobs( + self, + token_sequences: List[List[int]], + messages_list: Optional[List[List[Dict]]] = None, + top_k: Optional[int] = None, + ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]: + self.logger.info( + "[TEACHER] get_teacher_logprobs called with %s sequences", + len(token_sequences), + ) + self.logger.info("[TEACHER] teacher_base_url=%s", self.config.teacher_base_url) + + if not self.config.teacher_base_url: + self.logger.warning("[TEACHER] No teacher_base_url configured, returning empty") + return [], [] + + if top_k is None: + top_k = self.config.teacher_top_k + + api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "") + model_name = self.config.teacher_model_name or "default" + self.logger.info("[TEACHER] Using model=%s, top_k=%s", model_name, top_k) + + headers = {"Content-Type": "application/json"} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + token_id_results: List[List[List[int]]] = [] + logprob_results: List[List[List[float]]] = [] + + try: + async with aiohttp.ClientSession() as session: + for i, tokens in enumerate(token_sequences): + self.logger.info( + "[TEACHER] Processing sequence %s/%s, %s tokens", + i + 1, + len(token_sequences), + len(tokens), + ) + base_text = self.tokenizer.decode(tokens, skip_special_tokens=False) + steering_prefix = "" + if self.config.teacher_system_prompt: + steering_prefix += ( + "System instruction:\n" + f"{self.config.teacher_system_prompt.strip()}\n\n" + ) + if self.config.teacher_prefix_text: + steering_prefix += self.config.teacher_prefix_text + full_text = steering_prefix + base_text + prefix_token_len = ( + len( + self.tokenizer.encode( + steering_prefix, add_special_tokens=False + ) + ) + if steering_prefix + else 0 + ) + + request_data = { + "model": model_name, + "prompt": full_text, + "max_tokens": 1, + "temperature": 1.0, + "logprobs": top_k, + "echo": True, + } + try: + async with session.post( + f"{self.config.teacher_base_url}/completions", + json=request_data, + headers=headers, + timeout=aiohttp.ClientTimeout(total=120), + ) as response: + if response.status == 200: + data = await response.json() + seq_token_ids, seq_logprobs = self._parse_completion_logprobs( + data, top_k + ) + if seq_token_ids and seq_logprobs: + aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens( + seq_token_ids, + seq_logprobs, + target_token_len=len(tokens), + prefix_token_len=prefix_token_len, + ) + token_id_results.append(aligned_ids) + logprob_results.append(aligned_lps) + continue + except Exception: + pass + + if messages_list and i < len(messages_list): + messages = list(messages_list[i]) + if self.config.teacher_system_prompt: + messages = [ + { + "role": "system", + "content": self.config.teacher_system_prompt, + } + ] + messages + else: + messages = [] + if self.config.teacher_system_prompt: + messages.append( + { + "role": "system", + "content": self.config.teacher_system_prompt, + } + ) + messages.append({"role": "user", "content": full_text}) + + chat_request = { + "model": model_name, + "messages": messages, + "max_tokens": 1, + "temperature": 1.0, + "logprobs": True, + "top_logprobs": top_k, + } + try: + async with session.post( + f"{self.config.teacher_base_url}/chat/completions", + json=chat_request, + headers=headers, + timeout=aiohttp.ClientTimeout(total=120), + ) as response: + if response.status == 200: + data = 
await response.json() + seq_token_ids, seq_logprobs = self._parse_chat_logprobs( + data, top_k + ) + if seq_token_ids and len(seq_token_ids) >= len(tokens): + aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens( + seq_token_ids, + seq_logprobs, + target_token_len=len(tokens), + prefix_token_len=0, + ) + else: + aligned_ids = [[] for _ in range(len(tokens))] + aligned_lps = [[] for _ in range(len(tokens))] + token_id_results.append(aligned_ids) + logprob_results.append(aligned_lps) + else: + self.logger.warning( + "Teacher API returned %s", response.status + ) + token_id_results.append([[] for _ in range(len(tokens))]) + logprob_results.append([[] for _ in range(len(tokens))]) + except Exception as e: + self.logger.warning("Teacher chat request failed: %s", e) + token_id_results.append([[] for _ in range(len(tokens))]) + logprob_results.append([[] for _ in range(len(tokens))]) + + return token_id_results, logprob_results + except Exception as e: + self.logger.error("Error fetching teacher logprobs: %s", e) + return [], [] + + def _align_teacher_topk_to_tokens( + self, + seq_token_ids: List[List[int]], + seq_logprobs: List[List[float]], + target_token_len: int, + prefix_token_len: int = 0, + ) -> Tuple[List[List[int]], List[List[float]]]: + n = min(len(seq_token_ids), len(seq_logprobs)) + aligned_ids = list(seq_token_ids[:n]) + aligned_lps = list(seq_logprobs[:n]) + + if prefix_token_len > 0: + aligned_ids = aligned_ids[prefix_token_len:] + aligned_lps = aligned_lps[prefix_token_len:] + + aligned_ids = aligned_ids[:target_token_len] + aligned_lps = aligned_lps[:target_token_len] + + if len(aligned_ids) < target_token_len: + pad_count = target_token_len - len(aligned_ids) + aligned_ids.extend([[] for _ in range(pad_count)]) + aligned_lps.extend([[] for _ in range(pad_count)]) + + return aligned_ids, aligned_lps + + def _parse_completion_logprobs( + self, data: Dict, top_k: int + ) -> Tuple[List[List[int]], List[List[float]]]: + try: + choice = data.get("choices", [{}])[0] + logprobs_data = choice.get("logprobs", {}) + top_logprobs = logprobs_data.get("top_logprobs", []) + if not top_logprobs: + return [], [] + + seq_token_ids: List[List[int]] = [] + seq_logprobs: List[List[float]] = [] + for pos_logprobs in top_logprobs: + if pos_logprobs is None: + seq_token_ids.append([]) + seq_logprobs.append([]) + elif isinstance(pos_logprobs, dict): + sorted_items = sorted( + pos_logprobs.items(), key=lambda x: x[1], reverse=True + )[:top_k] + pos_ids: List[int] = [] + pos_lps: List[float] = [] + for token_str, logprob in sorted_items: + token_ids = self.tokenizer.encode( + token_str, add_special_tokens=False + ) + if token_ids: + pos_ids.append(int(token_ids[0])) + pos_lps.append(float(logprob)) + seq_token_ids.append(pos_ids) + seq_logprobs.append(pos_lps) + else: + seq_token_ids.append([]) + seq_logprobs.append([]) + return seq_token_ids, seq_logprobs + except Exception as e: + self.logger.warning("Error parsing completion logprobs: %s", e) + return [], [] + + def _parse_chat_logprobs( + self, data: Dict, top_k: int + ) -> Tuple[List[List[int]], List[List[float]]]: + try: + choice = data.get("choices", [{}])[0] + logprobs_data = choice.get("logprobs", {}) + if not logprobs_data: + return [], [] + + content = logprobs_data.get("content", []) + seq_token_ids: List[List[int]] = [] + seq_logprobs: List[List[float]] = [] + for token_data in content: + top_logprobs = token_data.get("top_logprobs", []) + pos_ids: List[int] = [] + pos_lps: List[float] = [] + for item in top_logprobs[:top_k]: + 
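+                    # Map the teacher's token string back to an id with the
+                    # student tokenizer; only the first id is kept, which is
+                    # lossy for strings that encode to multiple tokens.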
+                    token_str = item.get("token", "")
+                    logprob = item.get("logprob", 0.0)
+                    token_ids = self.tokenizer.encode(
+                        token_str, add_special_tokens=False
+                    )
+                    if token_ids:
+                        pos_ids.append(int(token_ids[0]))
+                        pos_lps.append(float(logprob))
+                seq_token_ids.append(pos_ids)
+                seq_logprobs.append(pos_lps)
+            return seq_token_ids, seq_logprobs
+        except Exception as e:
+            self.logger.warning("Error parsing chat logprobs: %s", e)
+            return [], []
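
A minimal usage sketch for the extracted client follows. It is illustrative only: StubConfig is a hypothetical stand-in for the environment config whose teacher_* fields TeacherClient reads (in BaseEnv the real config, tokenizer, and logger are passed, as the hunk above shows), and the endpoint URL and tokenizer name are assumptions.

    import asyncio
    import logging

    from transformers import AutoTokenizer

    from atroposlib.envs.server_handling.teacher_client import TeacherClient


    class StubConfig:
        # Hypothetical stand-in for the env config fields TeacherClient reads.
        teacher_base_url = "http://localhost:8000/v1"  # assumed OpenAI-compatible server
        teacher_api_key = ""
        teacher_model_name = "default"
        teacher_top_k = 20
        teacher_system_prompt = ""
        teacher_prefix_text = ""


    async def main() -> None:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any student tokenizer
        client = TeacherClient(
            config=StubConfig(),
            tokenizer=tokenizer,
            logger=logging.getLogger(__name__),
        )
        tokens = tokenizer.encode("Hello world", add_special_tokens=False)
        distill_ids, distill_lps = await client.get_teacher_logprobs(
            token_sequences=[tokens]
        )
        # Results are shaped [batch][position][top_k]; a position with no teacher
        # data is an empty list, and ([], []) means teacher_base_url was unset.
        print(len(distill_ids), len(distill_ids[0]) if distill_ids else 0)


    asyncio.run(main())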