post merge changes

This commit is contained in:
Jai Suphavadeeprasit 2026-02-20 00:32:47 -05:00
parent c89854a350
commit 79e392c446
3 changed files with 200 additions and 88 deletions

View file

@ -148,7 +148,9 @@ class MathEnv(BaseEnv):
print(f"[MATH_DEBUG] distillation_enabled = {config.distillation_enabled}")
print(f"[MATH_DEBUG] teacher_base_url = {config.teacher_base_url}")
print(f"[MATH_DEBUG] teacher_model_name = {getattr(config, 'teacher_model_name', 'N/A')}")
print(f"[MATH_DEBUG] teacher_top_logprobs = {getattr(config, 'teacher_top_logprobs', 'N/A')}")
print(f"[MATH_DEBUG] teacher_top_k = {getattr(config, 'teacher_top_k', 'N/A')}")
print(f"[MATH_DEBUG] teacher_prefix_text set = {bool(getattr(config, 'teacher_prefix_text', None))}")
print(f"[MATH_DEBUG] teacher_system_prompt set = {bool(getattr(config, 'teacher_system_prompt', None))}")
print("=" * 60)
@classmethod
@ -580,7 +582,10 @@ class MathEnv(BaseEnv):
print(f"[MATH_DEBUG] Created ScoredDataGroup with {len(scores['tokens'])} sequences")
print(f"[MATH_DEBUG] Scores: {scores['scores']}")
print(f"[MATH_DEBUG] Token lengths: {[len(t) for t in scores['tokens']]}")
print(f"[MATH_DEBUG] Has onpolicydistill_logprobs: {'onpolicydistill_logprobs' in scores}")
has_new_distill = (
"distill_token_ids" in scores and "distill_logprobs" in scores
)
print(f"[MATH_DEBUG] Has distill arrays: {has_new_distill}")
return scores
@ -599,10 +604,16 @@ class MathEnv(BaseEnv):
if isinstance(scored_data, list):
for i, group in enumerate(scored_data):
if group:
has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None
has_distill = (
group.get("distill_token_ids") is not None
and group.get("distill_logprobs") is not None
)
print(f"[MATH_DEBUG] Group {i}: {len(group.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
elif scored_data:
has_distill = 'onpolicydistill_logprobs' in scored_data and scored_data.get('onpolicydistill_logprobs') is not None
has_distill = (
scored_data.get("distill_token_ids") is not None
and scored_data.get("distill_logprobs") is not None
)
print(f"[MATH_DEBUG] Single group: {len(scored_data.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
# Call parent implementation which does the actual distillation fetch
@ -614,10 +625,16 @@ class MathEnv(BaseEnv):
if isinstance(scored_data, list):
for i, group in enumerate(scored_data):
if group:
has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None
has_distill = (
group.get("distill_token_ids") is not None
and group.get("distill_logprobs") is not None
)
print(f"[MATH_DEBUG] AFTER: Group {i} has_distill_logprobs={has_distill}")
elif scored_data:
has_distill = 'onpolicydistill_logprobs' in scored_data and scored_data.get('onpolicydistill_logprobs') is not None
has_distill = (
scored_data.get("distill_token_ids") is not None
and scored_data.get("distill_logprobs") is not None
)
print(f"[MATH_DEBUG] AFTER: Single group has_distill_logprobs={has_distill}")
return result