diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index cf95d3f8..2a39b267 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -1124,7 +1124,17 @@ def log_metrics( if "gpu_memory_gb" in metrics: timing_str += f", GPU mem: {metrics['gpu_memory_gb']:.2f}GB" - print(f" Loss: {metrics['loss']:.4f}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}") + # Show loss with more precision since GRPO loss is often very small + loss_str = f"{metrics['loss']:.6f}" if abs(metrics['loss']) < 0.01 else f"{metrics['loss']:.4f}" + print(f" Loss: {loss_str}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}") + + # Show GRPO-specific metrics if available + if "pos_count" in metrics or "neg_count" in metrics: + pos_count = metrics.get('pos_count', 0) + neg_count = metrics.get('neg_count', 0) + pos_logp = metrics.get('pos_logp', 0) + neg_logp = metrics.get('neg_logp', 0) + print(f" Advantages: +{int(pos_count)} / -{int(neg_count)}, LogP: pos={pos_logp:.3f}, neg={neg_logp:.3f}") if use_wandb: log_dict = {