mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-26 17:13:09 +00:00
change OPD style
This commit is contained in:
parent
33f5696171
commit
527433b5bc
10 changed files with 452 additions and 90 deletions
|
|
@ -150,7 +150,9 @@ class MathEnv(BaseEnv):
|
|||
print(f"[MATH_DEBUG] distillation_enabled = {config.distillation_enabled}")
|
||||
print(f"[MATH_DEBUG] teacher_base_url = {config.teacher_base_url}")
|
||||
print(f"[MATH_DEBUG] teacher_model_name = {getattr(config, 'teacher_model_name', 'N/A')}")
|
||||
print(f"[MATH_DEBUG] teacher_top_logprobs = {getattr(config, 'teacher_top_logprobs', 'N/A')}")
|
||||
print(f"[MATH_DEBUG] teacher_top_k = {getattr(config, 'teacher_top_k', 'N/A')}")
|
||||
print(f"[MATH_DEBUG] teacher_prefix_text set = {bool(getattr(config, 'teacher_prefix_text', None))}")
|
||||
print(f"[MATH_DEBUG] teacher_system_prompt set = {bool(getattr(config, 'teacher_system_prompt', None))}")
|
||||
print("=" * 60)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -637,7 +639,10 @@ class MathEnv(BaseEnv):
|
|||
print(f"[MATH_DEBUG] Created ScoredDataGroup with {len(scores['tokens'])} sequences")
|
||||
print(f"[MATH_DEBUG] Scores: {scores['scores']}")
|
||||
print(f"[MATH_DEBUG] Token lengths: {[len(t) for t in scores['tokens']]}")
|
||||
print(f"[MATH_DEBUG] Has onpolicydistill_logprobs: {'onpolicydistill_logprobs' in scores}")
|
||||
has_new_distill = (
|
||||
"distill_token_ids" in scores and "distill_logprobs" in scores
|
||||
)
|
||||
print(f"[MATH_DEBUG] Has distill arrays: {has_new_distill}")
|
||||
|
||||
return scores
|
||||
|
||||
|
|
@ -656,10 +661,16 @@ class MathEnv(BaseEnv):
|
|||
if isinstance(scored_data, list):
|
||||
for i, group in enumerate(scored_data):
|
||||
if group:
|
||||
has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None
|
||||
has_distill = (
|
||||
group.get("distill_token_ids") is not None
|
||||
and group.get("distill_logprobs") is not None
|
||||
)
|
||||
print(f"[MATH_DEBUG] Group {i}: {len(group.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
|
||||
elif scored_data:
|
||||
has_distill = 'onpolicydistill_logprobs' in scored_data and scored_data.get('onpolicydistill_logprobs') is not None
|
||||
has_distill = (
|
||||
scored_data.get("distill_token_ids") is not None
|
||||
and scored_data.get("distill_logprobs") is not None
|
||||
)
|
||||
print(f"[MATH_DEBUG] Single group: {len(scored_data.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
|
||||
|
||||
# Call parent implementation which does the actual distillation fetch
|
||||
|
|
@ -671,10 +682,16 @@ class MathEnv(BaseEnv):
|
|||
if isinstance(scored_data, list):
|
||||
for i, group in enumerate(scored_data):
|
||||
if group:
|
||||
has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None
|
||||
has_distill = (
|
||||
group.get("distill_token_ids") is not None
|
||||
and group.get("distill_logprobs") is not None
|
||||
)
|
||||
print(f"[MATH_DEBUG] AFTER: Group {i} has_distill_logprobs={has_distill}")
|
||||
elif scored_data:
|
||||
has_distill = 'onpolicydistill_logprobs' in scored_data and scored_data.get('onpolicydistill_logprobs') is not None
|
||||
has_distill = (
|
||||
scored_data.get("distill_token_ids") is not None
|
||||
and scored_data.get("distill_logprobs") is not None
|
||||
)
|
||||
print(f"[MATH_DEBUG] AFTER: Single group has_distill_logprobs={has_distill}")
|
||||
|
||||
return result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue