debug changes

This commit is contained in:
Jai Suphavadeeprasit 2026-02-17 08:15:07 -05:00
parent 0510ca9b72
commit c89854a350
2 changed files with 166 additions and 7 deletions

View file

@ -6,11 +6,17 @@ Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero
import asyncio
import random
import re
import logging
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, List, Optional, Tuple
import aiohttp
import wandb
from datasets import load_dataset
# Set up logging for debug
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from math_verify.errors import TimeoutException
@ -135,6 +141,15 @@ class MathEnv(BaseEnv):
self.normal_rollouts = list()
self.pass_at_groupsize = list()
self.iter = 0
# Debug: Print distillation config
print("=" * 60)
print("[MATH_DEBUG] DISTILLATION CONFIGURATION:")
print(f"[MATH_DEBUG] distillation_enabled = {config.distillation_enabled}")
print(f"[MATH_DEBUG] teacher_base_url = {config.teacher_base_url}")
print(f"[MATH_DEBUG] teacher_model_name = {getattr(config, 'teacher_model_name', 'N/A')}")
print(f"[MATH_DEBUG] teacher_top_logprobs = {getattr(config, 'teacher_top_logprobs', 'N/A')}")
print("=" * 60)
@classmethod
def config_init(cls) -> Tuple[RSConfig, ServerBaseline]:
@ -252,7 +267,85 @@ class MathEnv(BaseEnv):
name,
)
)
# Debug: Test teacher connectivity if distillation is enabled
if self.config.distillation_enabled and self.config.teacher_base_url:
await self._test_teacher_connectivity()
return
async def _test_teacher_connectivity(self):
"""Test if the teacher model API is reachable."""
print("=" * 60)
print("[MATH_DEBUG] TESTING TEACHER CONNECTIVITY...")
print(f"[MATH_DEBUG] Teacher URL: {self.config.teacher_base_url}")
print(f"[MATH_DEBUG] Teacher Model: {getattr(self.config, 'teacher_model_name', 'default')}")
try:
async with aiohttp.ClientSession() as session:
# Test 1: Health check
health_url = self.config.teacher_base_url.replace("/v1", "") + "/health"
print(f"[MATH_DEBUG] Testing health endpoint: {health_url}")
try:
async with session.get(health_url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
print(f"[MATH_DEBUG] Health check status: {resp.status}")
if resp.status == 200:
print("[MATH_DEBUG] ✓ Teacher health check PASSED")
else:
print(f"[MATH_DEBUG] ✗ Teacher health check FAILED: {await resp.text()}")
except Exception as e:
print(f"[MATH_DEBUG] ✗ Teacher health check ERROR: {e}")
# Test 2: Models endpoint
models_url = f"{self.config.teacher_base_url}/models"
print(f"[MATH_DEBUG] Testing models endpoint: {models_url}")
try:
async with session.get(models_url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
print(f"[MATH_DEBUG] Models endpoint status: {resp.status}")
if resp.status == 200:
data = await resp.json()
models = [m.get("id", m) for m in data.get("data", [])]
print(f"[MATH_DEBUG] ✓ Available models: {models}")
else:
print(f"[MATH_DEBUG] ✗ Models endpoint FAILED: {await resp.text()}")
except Exception as e:
print(f"[MATH_DEBUG] ✗ Models endpoint ERROR: {e}")
# Test 3: Simple completion test
completions_url = f"{self.config.teacher_base_url}/completions"
teacher_model = getattr(self.config, 'teacher_model_name', 'default')
test_payload = {
"model": teacher_model,
"prompt": "Hello",
"max_tokens": 5,
"logprobs": 5,
"echo": True,
}
print(f"[MATH_DEBUG] Testing completions endpoint: {completions_url}")
print(f"[MATH_DEBUG] Test payload: {test_payload}")
try:
async with session.post(
completions_url,
json=test_payload,
headers={"Content-Type": "application/json"},
timeout=aiohttp.ClientTimeout(total=30),
) as resp:
print(f"[MATH_DEBUG] Completions status: {resp.status}")
resp_text = await resp.text()
if resp.status == 200:
print(f"[MATH_DEBUG] ✓ Teacher completions WORKING!")
print(f"[MATH_DEBUG] Response preview: {resp_text[:500]}")
else:
print(f"[MATH_DEBUG] ✗ Teacher completions FAILED: {resp_text[:500]}")
except Exception as e:
print(f"[MATH_DEBUG] ✗ Teacher completions ERROR: {e}")
except Exception as e:
print(f"[MATH_DEBUG] ✗ Teacher connectivity test FAILED: {e}")
import traceback
traceback.print_exc()
print("=" * 60)
async def rollout_and_score_eval(self, question, answer, subset):
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
@ -482,7 +575,52 @@ class MathEnv(BaseEnv):
and (not scores["overrides"][i].get("set_advantage_to_zero", False))
]
)
# Debug: Log scored group creation
print(f"[MATH_DEBUG] Created ScoredDataGroup with {len(scores['tokens'])} sequences")
print(f"[MATH_DEBUG] Scores: {scores['scores']}")
print(f"[MATH_DEBUG] Token lengths: {[len(t) for t in scores['tokens']]}")
print(f"[MATH_DEBUG] Has onpolicydistill_logprobs: {'onpolicydistill_logprobs' in scores}")
return scores
async def handle_send_to_api(
self,
scored_data,
item=None,
do_send_to_api: bool = True,
abort_on_any_max_length_exceeded: bool = True,
):
"""Override to add debugging for distillation."""
print(f"[MATH_DEBUG] handle_send_to_api called")
print(f"[MATH_DEBUG] distillation_enabled: {self.config.distillation_enabled}")
print(f"[MATH_DEBUG] teacher_base_url: {self.config.teacher_base_url}")
if isinstance(scored_data, list):
for i, group in enumerate(scored_data):
if group:
has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None
print(f"[MATH_DEBUG] Group {i}: {len(group.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
elif scored_data:
has_distill = 'onpolicydistill_logprobs' in scored_data and scored_data.get('onpolicydistill_logprobs') is not None
print(f"[MATH_DEBUG] Single group: {len(scored_data.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
# Call parent implementation which does the actual distillation fetch
result = await super().handle_send_to_api(
scored_data, item, do_send_to_api, abort_on_any_max_length_exceeded
)
# Debug: Check if distillation was added after parent call
if isinstance(scored_data, list):
for i, group in enumerate(scored_data):
if group:
has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None
print(f"[MATH_DEBUG] AFTER: Group {i} has_distill_logprobs={has_distill}")
elif scored_data:
has_distill = 'onpolicydistill_logprobs' in scored_data and scored_data.get('onpolicydistill_logprobs') is not None
print(f"[MATH_DEBUG] AFTER: Single group has_distill_logprobs={has_distill}")
return result
async def get_next_item(self):
while True: