base env debugging

This commit is contained in:
Jai Suphavadeeprasit 2026-02-16 21:05:57 -05:00
parent 0e81c62e90
commit b0658f6327

View file

@ -361,7 +361,11 @@ class BaseEnv(ABC):
Structure: [batch][position][top_k] = [token_id, logprob]
Returns empty list if teacher_base_url is not configured.
"""
logger.info(f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences")
logger.info(f"[TEACHER] teacher_base_url={self.config.teacher_base_url}")
if not self.config.teacher_base_url:
logger.warning("[TEACHER] No teacher_base_url configured, returning empty")
return []
if top_k is None:
@ -371,6 +375,8 @@ class BaseEnv(ABC):
api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "")
model_name = self.config.teacher_model_name or "default"
logger.info(f"[TEACHER] Using model={model_name}, top_k={top_k}")
headers = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
@ -380,6 +386,7 @@ class BaseEnv(ABC):
try:
async with aiohttp.ClientSession() as session:
for i, tokens in enumerate(token_sequences):
logger.info(f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens")
# Decode tokens to text
full_text = self.tokenizer.decode(tokens, skip_special_tokens=False)
@ -1124,8 +1131,11 @@ class BaseEnv(ABC):
]
# Automatic on-policy distillation: fetch teacher logprobs if enabled
logger.info(f"[DISTILL DEBUG] distillation_enabled={self.config.distillation_enabled}, teacher_base_url={self.config.teacher_base_url}")
if self.config.distillation_enabled and self.config.teacher_base_url:
logger.info(f"[DISTILL DEBUG] Distillation is enabled! Checking for existing logprobs...")
if group.get("onpolicydistill_logprobs") is None:
logger.info(f"[DISTILL DEBUG] No existing logprobs, fetching from teacher...")
try:
teacher_logprobs = await self.get_teacher_logprobs(
token_sequences=group["tokens"],
@ -1133,11 +1143,17 @@ class BaseEnv(ABC):
)
if teacher_logprobs:
group["onpolicydistill_logprobs"] = teacher_logprobs
logger.debug(
f"Added teacher logprobs for {len(teacher_logprobs)} sequences"
logger.info(
f"[DISTILL DEBUG] Added teacher logprobs for {len(teacher_logprobs)} sequences"
)
else:
logger.warning("[DISTILL DEBUG] get_teacher_logprobs returned empty!")
except Exception as e:
logger.warning(f"Failed to fetch teacher logprobs: {e}")
logger.error(f"[DISTILL DEBUG] Failed to fetch teacher logprobs: {e}")
import traceback
logger.error(traceback.format_exc())
else:
logger.debug(f"[DISTILL DEBUG] Distillation skipped - not enabled or no teacher URL")
await self.add_rollouts_for_wandb(group, item)