logging the teacher step

This commit is contained in:
Jai Suphavadeeprasit 2026-03-23 11:52:51 -07:00
parent ee0cc6eeac
commit 7ed3cc6d1b

View file

@ -15,6 +15,7 @@ from __future__ import annotations
import asyncio
import logging
import time
import uuid
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple, Union
@ -362,21 +363,65 @@ class TeacherDistillationEnv(BaseEnv, ABC):
) -> Tuple[List[List[int]], List[List[float]]]:
assert self.teacher_server is not None
start = time.time()
request_id = uuid.uuid4().hex[:8]
teacher_target = "unknown"
teacher_model = "unknown"
if getattr(self.teacher_server, "servers", None):
teacher_config = self.teacher_server.servers[0].config
teacher_target = getattr(teacher_config, "base_url", None) or "local"
teacher_model = getattr(teacher_config, "model_name", None) or "unknown"
logger.info(
"[TeacherDistill][%s] teacher request start seq_len=%s top_k=%s "
"teacher_model=%s teacher_target=%s",
request_id,
len(token_ids),
top_k,
teacher_model,
teacher_target,
)
print(
f"[TeacherDistill] requesting teacher logprobs: "
f"seq_len={len(token_ids)} top_k={top_k}",
f"[TeacherDistill][{request_id}] teacher request start "
f"seq_len={len(token_ids)} top_k={top_k} "
f"teacher_model={teacher_model} teacher_target={teacher_target}",
flush=True,
)
payload = await self.teacher_server.get_logprobs(
input_ids=token_ids,
top_k=top_k,
max_tokens=1,
split="train",
)
try:
payload = await self.teacher_server.get_logprobs(
input_ids=token_ids,
top_k=top_k,
max_tokens=1,
split="train",
)
except Exception:
elapsed = time.time() - start
logger.exception(
"[TeacherDistill][%s] teacher request failed after %.2fs",
request_id,
elapsed,
)
print(
f"[TeacherDistill][{request_id}] teacher request failed "
f"after {elapsed:.2f}s",
flush=True,
)
raise
elapsed = time.time() - start
prompt_positions = len(payload.get("prompt_topk_token_ids", []))
logger.info(
"[TeacherDistill][%s] teacher response received seq_len=%s top_k=%s "
"prompt_positions=%s elapsed=%.2fs",
request_id,
len(token_ids),
top_k,
prompt_positions,
elapsed,
)
print(
f"[TeacherDistill] received teacher logprobs: "
f"seq_len={len(token_ids)} top_k={top_k} elapsed={elapsed:.2f}s",
f"[TeacherDistill][{request_id}] teacher response received "
f"seq_len={len(token_ids)} top_k={top_k} "
f"prompt_positions={prompt_positions} elapsed={elapsed:.2f}s",
flush=True,
)
return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]