mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
logging the teacher step
This commit is contained in:
parent
ee0cc6eeac
commit
7ed3cc6d1b
1 changed files with 55 additions and 10 deletions
|
|
@ -15,6 +15,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
|
@ -362,21 +363,65 @@ class TeacherDistillationEnv(BaseEnv, ABC):
|
|||
) -> Tuple[List[List[int]], List[List[float]]]:
|
||||
assert self.teacher_server is not None
|
||||
start = time.time()
|
||||
request_id = uuid.uuid4().hex[:8]
|
||||
teacher_target = "unknown"
|
||||
teacher_model = "unknown"
|
||||
if getattr(self.teacher_server, "servers", None):
|
||||
teacher_config = self.teacher_server.servers[0].config
|
||||
teacher_target = getattr(teacher_config, "base_url", None) or "local"
|
||||
teacher_model = getattr(teacher_config, "model_name", None) or "unknown"
|
||||
|
||||
logger.info(
|
||||
"[TeacherDistill][%s] teacher request start seq_len=%s top_k=%s "
|
||||
"teacher_model=%s teacher_target=%s",
|
||||
request_id,
|
||||
len(token_ids),
|
||||
top_k,
|
||||
teacher_model,
|
||||
teacher_target,
|
||||
)
|
||||
print(
|
||||
f"[TeacherDistill] requesting teacher logprobs: "
|
||||
f"seq_len={len(token_ids)} top_k={top_k}",
|
||||
f"[TeacherDistill][{request_id}] teacher request start "
|
||||
f"seq_len={len(token_ids)} top_k={top_k} "
|
||||
f"teacher_model={teacher_model} teacher_target={teacher_target}",
|
||||
flush=True,
|
||||
)
|
||||
payload = await self.teacher_server.get_logprobs(
|
||||
input_ids=token_ids,
|
||||
top_k=top_k,
|
||||
max_tokens=1,
|
||||
split="train",
|
||||
)
|
||||
try:
|
||||
payload = await self.teacher_server.get_logprobs(
|
||||
input_ids=token_ids,
|
||||
top_k=top_k,
|
||||
max_tokens=1,
|
||||
split="train",
|
||||
)
|
||||
except Exception:
|
||||
elapsed = time.time() - start
|
||||
logger.exception(
|
||||
"[TeacherDistill][%s] teacher request failed after %.2fs",
|
||||
request_id,
|
||||
elapsed,
|
||||
)
|
||||
print(
|
||||
f"[TeacherDistill][{request_id}] teacher request failed "
|
||||
f"after {elapsed:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
raise
|
||||
|
||||
elapsed = time.time() - start
|
||||
prompt_positions = len(payload.get("prompt_topk_token_ids", []))
|
||||
logger.info(
|
||||
"[TeacherDistill][%s] teacher response received seq_len=%s top_k=%s "
|
||||
"prompt_positions=%s elapsed=%.2fs",
|
||||
request_id,
|
||||
len(token_ids),
|
||||
top_k,
|
||||
prompt_positions,
|
||||
elapsed,
|
||||
)
|
||||
print(
|
||||
f"[TeacherDistill] received teacher logprobs: "
|
||||
f"seq_len={len(token_ids)} top_k={top_k} elapsed={elapsed:.2f}s",
|
||||
f"[TeacherDistill][{request_id}] teacher response received "
|
||||
f"seq_len={len(token_ids)} top_k={top_k} "
|
||||
f"prompt_positions={prompt_positions} elapsed={elapsed:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue