mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
debug changes
This commit is contained in:
parent
0510ca9b72
commit
c89854a350
2 changed files with 166 additions and 7 deletions
|
|
@ -6,11 +6,17 @@ Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero
|
|||
import asyncio
|
||||
import random
|
||||
import re
|
||||
import logging
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
|
||||
# Set up logging for debug
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
from math_verify.errors import TimeoutException
|
||||
|
|
@ -135,6 +141,15 @@ class MathEnv(BaseEnv):
|
|||
self.normal_rollouts = list()
|
||||
self.pass_at_groupsize = list()
|
||||
self.iter = 0
|
||||
|
||||
# Debug: Print distillation config
|
||||
print("=" * 60)
|
||||
print("[MATH_DEBUG] DISTILLATION CONFIGURATION:")
|
||||
print(f"[MATH_DEBUG] distillation_enabled = {config.distillation_enabled}")
|
||||
print(f"[MATH_DEBUG] teacher_base_url = {config.teacher_base_url}")
|
||||
print(f"[MATH_DEBUG] teacher_model_name = {getattr(config, 'teacher_model_name', 'N/A')}")
|
||||
print(f"[MATH_DEBUG] teacher_top_logprobs = {getattr(config, 'teacher_top_logprobs', 'N/A')}")
|
||||
print("=" * 60)
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[RSConfig, ServerBaseline]:
|
||||
|
|
@ -252,7 +267,85 @@ class MathEnv(BaseEnv):
|
|||
name,
|
||||
)
|
||||
)
|
||||
|
||||
# Debug: Test teacher connectivity if distillation is enabled
|
||||
if self.config.distillation_enabled and self.config.teacher_base_url:
|
||||
await self._test_teacher_connectivity()
|
||||
|
||||
return
|
||||
|
||||
async def _test_teacher_connectivity(self):
|
||||
"""Test if the teacher model API is reachable."""
|
||||
print("=" * 60)
|
||||
print("[MATH_DEBUG] TESTING TEACHER CONNECTIVITY...")
|
||||
print(f"[MATH_DEBUG] Teacher URL: {self.config.teacher_base_url}")
|
||||
print(f"[MATH_DEBUG] Teacher Model: {getattr(self.config, 'teacher_model_name', 'default')}")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Test 1: Health check
|
||||
health_url = self.config.teacher_base_url.replace("/v1", "") + "/health"
|
||||
print(f"[MATH_DEBUG] Testing health endpoint: {health_url}")
|
||||
try:
|
||||
async with session.get(health_url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
|
||||
print(f"[MATH_DEBUG] Health check status: {resp.status}")
|
||||
if resp.status == 200:
|
||||
print("[MATH_DEBUG] ✓ Teacher health check PASSED")
|
||||
else:
|
||||
print(f"[MATH_DEBUG] ✗ Teacher health check FAILED: {await resp.text()}")
|
||||
except Exception as e:
|
||||
print(f"[MATH_DEBUG] ✗ Teacher health check ERROR: {e}")
|
||||
|
||||
# Test 2: Models endpoint
|
||||
models_url = f"{self.config.teacher_base_url}/models"
|
||||
print(f"[MATH_DEBUG] Testing models endpoint: {models_url}")
|
||||
try:
|
||||
async with session.get(models_url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
|
||||
print(f"[MATH_DEBUG] Models endpoint status: {resp.status}")
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
models = [m.get("id", m) for m in data.get("data", [])]
|
||||
print(f"[MATH_DEBUG] ✓ Available models: {models}")
|
||||
else:
|
||||
print(f"[MATH_DEBUG] ✗ Models endpoint FAILED: {await resp.text()}")
|
||||
except Exception as e:
|
||||
print(f"[MATH_DEBUG] ✗ Models endpoint ERROR: {e}")
|
||||
|
||||
# Test 3: Simple completion test
|
||||
completions_url = f"{self.config.teacher_base_url}/completions"
|
||||
teacher_model = getattr(self.config, 'teacher_model_name', 'default')
|
||||
test_payload = {
|
||||
"model": teacher_model,
|
||||
"prompt": "Hello",
|
||||
"max_tokens": 5,
|
||||
"logprobs": 5,
|
||||
"echo": True,
|
||||
}
|
||||
print(f"[MATH_DEBUG] Testing completions endpoint: {completions_url}")
|
||||
print(f"[MATH_DEBUG] Test payload: {test_payload}")
|
||||
try:
|
||||
async with session.post(
|
||||
completions_url,
|
||||
json=test_payload,
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
) as resp:
|
||||
print(f"[MATH_DEBUG] Completions status: {resp.status}")
|
||||
resp_text = await resp.text()
|
||||
if resp.status == 200:
|
||||
print(f"[MATH_DEBUG] ✓ Teacher completions WORKING!")
|
||||
print(f"[MATH_DEBUG] Response preview: {resp_text[:500]}")
|
||||
else:
|
||||
print(f"[MATH_DEBUG] ✗ Teacher completions FAILED: {resp_text[:500]}")
|
||||
except Exception as e:
|
||||
print(f"[MATH_DEBUG] ✗ Teacher completions ERROR: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[MATH_DEBUG] ✗ Teacher connectivity test FAILED: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
async def rollout_and_score_eval(self, question, answer, subset):
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
|
|
@ -482,7 +575,52 @@ class MathEnv(BaseEnv):
|
|||
and (not scores["overrides"][i].get("set_advantage_to_zero", False))
|
||||
]
|
||||
)
|
||||
|
||||
# Debug: Log scored group creation
|
||||
print(f"[MATH_DEBUG] Created ScoredDataGroup with {len(scores['tokens'])} sequences")
|
||||
print(f"[MATH_DEBUG] Scores: {scores['scores']}")
|
||||
print(f"[MATH_DEBUG] Token lengths: {[len(t) for t in scores['tokens']]}")
|
||||
print(f"[MATH_DEBUG] Has onpolicydistill_logprobs: {'onpolicydistill_logprobs' in scores}")
|
||||
|
||||
return scores
|
||||
|
||||
async def handle_send_to_api(
|
||||
self,
|
||||
scored_data,
|
||||
item=None,
|
||||
do_send_to_api: bool = True,
|
||||
abort_on_any_max_length_exceeded: bool = True,
|
||||
):
|
||||
"""Override to add debugging for distillation."""
|
||||
print(f"[MATH_DEBUG] handle_send_to_api called")
|
||||
print(f"[MATH_DEBUG] distillation_enabled: {self.config.distillation_enabled}")
|
||||
print(f"[MATH_DEBUG] teacher_base_url: {self.config.teacher_base_url}")
|
||||
|
||||
if isinstance(scored_data, list):
|
||||
for i, group in enumerate(scored_data):
|
||||
if group:
|
||||
has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None
|
||||
print(f"[MATH_DEBUG] Group {i}: {len(group.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
|
||||
elif scored_data:
|
||||
has_distill = 'onpolicydistill_logprobs' in scored_data and scored_data.get('onpolicydistill_logprobs') is not None
|
||||
print(f"[MATH_DEBUG] Single group: {len(scored_data.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
|
||||
|
||||
# Call parent implementation which does the actual distillation fetch
|
||||
result = await super().handle_send_to_api(
|
||||
scored_data, item, do_send_to_api, abort_on_any_max_length_exceeded
|
||||
)
|
||||
|
||||
# Debug: Check if distillation was added after parent call
|
||||
if isinstance(scored_data, list):
|
||||
for i, group in enumerate(scored_data):
|
||||
if group:
|
||||
has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None
|
||||
print(f"[MATH_DEBUG] AFTER: Group {i} has_distill_logprobs={has_distill}")
|
||||
elif scored_data:
|
||||
has_distill = 'onpolicydistill_logprobs' in scored_data and scored_data.get('onpolicydistill_logprobs') is not None
|
||||
print(f"[MATH_DEBUG] AFTER: Single group has_distill_logprobs={has_distill}")
|
||||
|
||||
return result
|
||||
|
||||
async def get_next_item(self):
|
||||
while True:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue