mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
base env debugging
This commit is contained in:
parent
0e81c62e90
commit
b0658f6327
1 changed file with 19 additions and 3 deletions
|
|
@@ -361,7 +361,11 @@ class BaseEnv(ABC):
|
|||
Structure: [batch][position][top_k] = [token_id, logprob]
|
||||
Returns empty list if teacher_base_url is not configured.
|
||||
"""
|
||||
logger.info(f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences")
|
||||
logger.info(f"[TEACHER] teacher_base_url={self.config.teacher_base_url}")
|
||||
|
||||
if not self.config.teacher_base_url:
|
||||
logger.warning("[TEACHER] No teacher_base_url configured, returning empty")
|
||||
return []
|
||||
|
||||
if top_k is None:
|
||||
|
|
@@ -371,6 +375,8 @@ class BaseEnv(ABC):
|
|||
api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "")
|
||||
model_name = self.config.teacher_model_name or "default"
|
||||
|
||||
logger.info(f"[TEACHER] Using model={model_name}, top_k={top_k}")
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
|
@@ -380,6 +386,7 @@ class BaseEnv(ABC):
|
|||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for i, tokens in enumerate(token_sequences):
|
||||
logger.info(f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens")
|
||||
# Decode tokens to text
|
||||
full_text = self.tokenizer.decode(tokens, skip_special_tokens=False)
|
||||
|
||||
|
|
@@ -1124,8 +1131,11 @@ class BaseEnv(ABC):
|
|||
]
|
||||
|
||||
# Automatic on-policy distillation: fetch teacher logprobs if enabled
|
||||
logger.info(f"[DISTILL DEBUG] distillation_enabled={self.config.distillation_enabled}, teacher_base_url={self.config.teacher_base_url}")
|
||||
if self.config.distillation_enabled and self.config.teacher_base_url:
|
||||
logger.info(f"[DISTILL DEBUG] Distillation is enabled! Checking for existing logprobs...")
|
||||
if group.get("onpolicydistill_logprobs") is None:
|
||||
logger.info(f"[DISTILL DEBUG] No existing logprobs, fetching from teacher...")
|
||||
try:
|
||||
teacher_logprobs = await self.get_teacher_logprobs(
|
||||
token_sequences=group["tokens"],
|
||||
|
|
@@ -1133,11 +1143,17 @@ class BaseEnv(ABC):
|
|||
)
|
||||
if teacher_logprobs:
|
||||
group["onpolicydistill_logprobs"] = teacher_logprobs
|
||||
logger.debug(
|
||||
f"Added teacher logprobs for {len(teacher_logprobs)} sequences"
|
||||
logger.info(
|
||||
f"[DISTILL DEBUG] Added teacher logprobs for {len(teacher_logprobs)} sequences"
|
||||
)
|
||||
else:
|
||||
logger.warning("[DISTILL DEBUG] get_teacher_logprobs returned empty!")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch teacher logprobs: {e}")
|
||||
logger.error(f"[DISTILL DEBUG] Failed to fetch teacher logprobs: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
else:
|
||||
logger.debug(f"[DISTILL DEBUG] Distillation skipped - not enabled or no teacher URL")
|
||||
|
||||
await self.add_rollouts_for_wandb(group, item)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue