mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
post merge changes
This commit is contained in:
parent
c89854a350
commit
79e392c446
3 changed files with 200 additions and 88 deletions
|
|
@ -148,7 +148,9 @@ class MathEnv(BaseEnv):
|
|||
print(f"[MATH_DEBUG] distillation_enabled = {config.distillation_enabled}")
|
||||
print(f"[MATH_DEBUG] teacher_base_url = {config.teacher_base_url}")
|
||||
print(f"[MATH_DEBUG] teacher_model_name = {getattr(config, 'teacher_model_name', 'N/A')}")
|
||||
print(f"[MATH_DEBUG] teacher_top_logprobs = {getattr(config, 'teacher_top_logprobs', 'N/A')}")
|
||||
print(f"[MATH_DEBUG] teacher_top_k = {getattr(config, 'teacher_top_k', 'N/A')}")
|
||||
print(f"[MATH_DEBUG] teacher_prefix_text set = {bool(getattr(config, 'teacher_prefix_text', None))}")
|
||||
print(f"[MATH_DEBUG] teacher_system_prompt set = {bool(getattr(config, 'teacher_system_prompt', None))}")
|
||||
print("=" * 60)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -580,7 +582,10 @@ class MathEnv(BaseEnv):
|
|||
print(f"[MATH_DEBUG] Created ScoredDataGroup with {len(scores['tokens'])} sequences")
|
||||
print(f"[MATH_DEBUG] Scores: {scores['scores']}")
|
||||
print(f"[MATH_DEBUG] Token lengths: {[len(t) for t in scores['tokens']]}")
|
||||
print(f"[MATH_DEBUG] Has onpolicydistill_logprobs: {'onpolicydistill_logprobs' in scores}")
|
||||
has_new_distill = (
|
||||
"distill_token_ids" in scores and "distill_logprobs" in scores
|
||||
)
|
||||
print(f"[MATH_DEBUG] Has distill arrays: {has_new_distill}")
|
||||
|
||||
return scores
|
||||
|
||||
|
|
@ -599,10 +604,16 @@ class MathEnv(BaseEnv):
|
|||
if isinstance(scored_data, list):
|
||||
for i, group in enumerate(scored_data):
|
||||
if group:
|
||||
has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None
|
||||
has_distill = (
|
||||
group.get("distill_token_ids") is not None
|
||||
and group.get("distill_logprobs") is not None
|
||||
)
|
||||
print(f"[MATH_DEBUG] Group {i}: {len(group.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
|
||||
elif scored_data:
|
||||
has_distill = 'onpolicydistill_logprobs' in scored_data and scored_data.get('onpolicydistill_logprobs') is not None
|
||||
has_distill = (
|
||||
scored_data.get("distill_token_ids") is not None
|
||||
and scored_data.get("distill_logprobs") is not None
|
||||
)
|
||||
print(f"[MATH_DEBUG] Single group: {len(scored_data.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
|
||||
|
||||
# Call parent implementation which does the actual distillation fetch
|
||||
|
|
@ -614,10 +625,16 @@ class MathEnv(BaseEnv):
|
|||
if isinstance(scored_data, list):
|
||||
for i, group in enumerate(scored_data):
|
||||
if group:
|
||||
has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None
|
||||
has_distill = (
|
||||
group.get("distill_token_ids") is not None
|
||||
and group.get("distill_logprobs") is not None
|
||||
)
|
||||
print(f"[MATH_DEBUG] AFTER: Group {i} has_distill_logprobs={has_distill}")
|
||||
elif scored_data:
|
||||
has_distill = 'onpolicydistill_logprobs' in scored_data and scored_data.get('onpolicydistill_logprobs') is not None
|
||||
has_distill = (
|
||||
scored_data.get("distill_token_ids") is not None
|
||||
and scored_data.get("distill_logprobs") is not None
|
||||
)
|
||||
print(f"[MATH_DEBUG] AFTER: Single group has_distill_logprobs={has_distill}")
|
||||
|
||||
return result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue