diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py index 7874f466..ee2ab8b4 100644 --- a/atroposlib/envs/teacher_distillation_env.py +++ b/atroposlib/envs/teacher_distillation_env.py @@ -15,6 +15,7 @@ from __future__ import annotations import asyncio import logging import time +import uuid from abc import ABC from typing import Any, Dict, List, Optional, Tuple, Union @@ -362,21 +363,65 @@ class TeacherDistillationEnv(BaseEnv, ABC): ) -> Tuple[List[List[int]], List[List[float]]]: assert self.teacher_server is not None start = time.time() + request_id = uuid.uuid4().hex[:8] + teacher_target = "unknown" + teacher_model = "unknown" + if getattr(self.teacher_server, "servers", None): + teacher_config = self.teacher_server.servers[0].config + teacher_target = getattr(teacher_config, "base_url", None) or "local" + teacher_model = getattr(teacher_config, "model_name", None) or "unknown" + + logger.info( + "[TeacherDistill][%s] teacher request start seq_len=%s top_k=%s " + "teacher_model=%s teacher_target=%s", + request_id, + len(token_ids), + top_k, + teacher_model, + teacher_target, + ) print( - f"[TeacherDistill] requesting teacher logprobs: " - f"seq_len={len(token_ids)} top_k={top_k}", + f"[TeacherDistill][{request_id}] teacher request start " + f"seq_len={len(token_ids)} top_k={top_k} " + f"teacher_model={teacher_model} teacher_target={teacher_target}", flush=True, ) - payload = await self.teacher_server.get_logprobs( - input_ids=token_ids, - top_k=top_k, - max_tokens=1, - split="train", - ) + try: + payload = await self.teacher_server.get_logprobs( + input_ids=token_ids, + top_k=top_k, + max_tokens=1, + split="train", + ) + except Exception: + elapsed = time.time() - start + logger.exception( + "[TeacherDistill][%s] teacher request failed after %.2fs", + request_id, + elapsed, + ) + print( + f"[TeacherDistill][{request_id}] teacher request failed " + f"after {elapsed:.2f}s", + flush=True, + ) + raise + elapsed = time.time() - start + prompt_positions = len(payload.get("prompt_topk_token_ids", [])) + logger.info( + "[TeacherDistill][%s] teacher response received seq_len=%s top_k=%s " + "prompt_positions=%s elapsed=%.2fs", + request_id, + len(token_ids), + top_k, + prompt_positions, + elapsed, + ) print( - f"[TeacherDistill] received teacher logprobs: " - f"seq_len={len(token_ids)} top_k={top_k} elapsed={elapsed:.2f}s", + f"[TeacherDistill][{request_id}] teacher response received " + f"seq_len={len(token_ids)} top_k={top_k} " + f"prompt_positions={prompt_positions} elapsed={elapsed:.2f}s", flush=True, ) return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]