mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00
linting
This commit is contained in:
parent
2eaff54351
commit
e932369777
7 changed files with 32 additions and 38 deletions
|
|
@ -140,10 +140,10 @@ def compute_grpo_loss(
|
|||
# Check if ref logprobs are negative (as they should be for generated tokens)
|
||||
# If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used
|
||||
if ref_at_generated > 0.5:
|
||||
print(f" [WARNING] ref_logprobs at generated positions avg {ref_at_generated:.3f} (should be negative!)")
|
||||
print(f" [WARNING] This suggests inference_logprobs alignment is still wrong")
|
||||
print(f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)")
|
||||
print(" [WARNING] This suggests inference_logprobs alignment is wrong")
|
||||
elif abs(ref_at_generated - train_at_generated) > 2.0:
|
||||
print(f" [DEBUG] Logprob gap (may be OK for first step): ref={ref_at_generated:.3f}, train={train_at_generated:.3f}")
|
||||
print(f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}")
|
||||
|
||||
# Compute importance sampling ratio: policy(a|s) / policy_old(a|s) = exp(log policy - log policy_old)
|
||||
log_ratio = logp_per_token - ref_logprobs
|
||||
|
|
@ -277,8 +277,6 @@ def run_training_step(
|
|||
Returns:
|
||||
Dict of training metrics for this step
|
||||
"""
|
||||
global _logprob_alignment_stats
|
||||
|
||||
total_loss = 0.0
|
||||
total_pos_logp = 0.0
|
||||
total_neg_logp = 0.0
|
||||
|
|
@ -399,8 +397,6 @@ def log_metrics(
|
|||
extra_metrics: Optional additional metrics to log
|
||||
benchmark: Whether to show timing/benchmark info
|
||||
"""
|
||||
global _logprob_alignment_stats
|
||||
|
||||
# Build timing string (only if benchmark enabled)
|
||||
timing_str = ""
|
||||
if benchmark:
|
||||
|
|
@ -462,7 +458,7 @@ def log_metrics(
|
|||
if key in metrics:
|
||||
log_dict[f"train/{key}"] = metrics[key]
|
||||
|
||||
# Add logprob alignment stats (key for shared_vllm validation!)
|
||||
# Add logprob alignment stats
|
||||
if _logprob_alignment_stats:
|
||||
log_dict.update(_logprob_alignment_stats)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue