This commit is contained in:
Jai Suphavadeeprasit 2026-02-19 23:57:47 -05:00
parent c1a80205cc
commit ccdd5a1ca6
3 changed files with 19 additions and 177 deletions

View file

@ -7,17 +7,12 @@ import asyncio
import os
import random
import re
import logging
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, List, Optional, Tuple, Union
import aiohttp
import wandb
from datasets import load_dataset
# Set up logging for debug
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from math_verify.errors import TimeoutException
@ -131,8 +126,6 @@ class MathEnv(BaseEnv):
slurm=True,
testing=False,
):
print("Initializing MathEnv")
print(f"Slurm: {slurm}, Testing: {testing}")
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer = list()
self.eval_metrics = list()
@ -143,17 +136,6 @@ class MathEnv(BaseEnv):
self.normal_rollouts = list()
self.pass_at_groupsize = list()
self.iter = 0
# Debug: Print distillation config
print("=" * 60)
print("[MATH_DEBUG] DISTILLATION CONFIGURATION:")
print(f"[MATH_DEBUG] distillation_enabled = {config.distillation_enabled}")
print(f"[MATH_DEBUG] teacher_base_url = {config.teacher_base_url}")
print(f"[MATH_DEBUG] teacher_model_name = {getattr(config, 'teacher_model_name', 'N/A')}")
print(f"[MATH_DEBUG] teacher_top_k = {getattr(config, 'teacher_top_k', 'N/A')}")
print(f"[MATH_DEBUG] teacher_prefix_text set = {bool(getattr(config, 'teacher_prefix_text', None))}")
print(f"[MATH_DEBUG] teacher_system_prompt set = {bool(getattr(config, 'teacher_system_prompt', None))}")
print("=" * 60)
@classmethod
def config_init(cls) -> Tuple[RSConfig, List[APIServerConfig]]:
@ -286,85 +268,7 @@ class MathEnv(BaseEnv):
name,
)
)
# Debug: Test teacher connectivity if distillation is enabled
if self.config.distillation_enabled and self.config.teacher_base_url:
await self._test_teacher_connectivity()
return
async def _test_teacher_connectivity(self):
"""Test if the teacher model API is reachable."""
print("=" * 60)
print("[MATH_DEBUG] TESTING TEACHER CONNECTIVITY...")
print(f"[MATH_DEBUG] Teacher URL: {self.config.teacher_base_url}")
print(f"[MATH_DEBUG] Teacher Model: {getattr(self.config, 'teacher_model_name', 'default')}")
try:
async with aiohttp.ClientSession() as session:
# Test 1: Health check
health_url = self.config.teacher_base_url.replace("/v1", "") + "/health"
print(f"[MATH_DEBUG] Testing health endpoint: {health_url}")
try:
async with session.get(health_url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
print(f"[MATH_DEBUG] Health check status: {resp.status}")
if resp.status == 200:
print("[MATH_DEBUG] ✓ Teacher health check PASSED")
else:
print(f"[MATH_DEBUG] ✗ Teacher health check FAILED: {await resp.text()}")
except Exception as e:
print(f"[MATH_DEBUG] ✗ Teacher health check ERROR: {e}")
# Test 2: Models endpoint
models_url = f"{self.config.teacher_base_url}/models"
print(f"[MATH_DEBUG] Testing models endpoint: {models_url}")
try:
async with session.get(models_url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
print(f"[MATH_DEBUG] Models endpoint status: {resp.status}")
if resp.status == 200:
data = await resp.json()
models = [m.get("id", m) for m in data.get("data", [])]
print(f"[MATH_DEBUG] ✓ Available models: {models}")
else:
print(f"[MATH_DEBUG] ✗ Models endpoint FAILED: {await resp.text()}")
except Exception as e:
print(f"[MATH_DEBUG] ✗ Models endpoint ERROR: {e}")
# Test 3: Simple completion test
completions_url = f"{self.config.teacher_base_url}/completions"
teacher_model = getattr(self.config, 'teacher_model_name', 'default')
test_payload = {
"model": teacher_model,
"prompt": "Hello",
"max_tokens": 5,
"logprobs": 5,
"echo": True,
}
print(f"[MATH_DEBUG] Testing completions endpoint: {completions_url}")
print(f"[MATH_DEBUG] Test payload: {test_payload}")
try:
async with session.post(
completions_url,
json=test_payload,
headers={"Content-Type": "application/json"},
timeout=aiohttp.ClientTimeout(total=30),
) as resp:
print(f"[MATH_DEBUG] Completions status: {resp.status}")
resp_text = await resp.text()
if resp.status == 200:
print(f"[MATH_DEBUG] ✓ Teacher completions WORKING!")
print(f"[MATH_DEBUG] Response preview: {resp_text[:500]}")
else:
print(f"[MATH_DEBUG] ✗ Teacher completions FAILED: {resp_text[:500]}")
except Exception as e:
print(f"[MATH_DEBUG] ✗ Teacher completions ERROR: {e}")
except Exception as e:
print(f"[MATH_DEBUG] ✗ Teacher connectivity test FAILED: {e}")
import traceback
traceback.print_exc()
print("=" * 60)
async def rollout_and_score_eval(self, question, answer, subset):
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
@ -432,19 +336,6 @@ class MathEnv(BaseEnv):
end_time = time.time()
# Print results to console
print("\n" + "=" * 60)
print("Math Zero Evaluation Results")
print("=" * 60)
print(
f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})"
)
print("\nPer-subset breakdown:")
for subset, scores in sorted(task_lists.items()):
acc = sum(scores) / len(scores)
print(f" {subset}: {acc:.2%} ({sum(scores)}/{len(scores)})")
print("=" * 60 + "\n")
# Save results to disk
await self.evaluate_log(
metrics=metrics,
@ -549,7 +440,6 @@ class MathEnv(BaseEnv):
)
if len(self.normal_rollouts) > self.config.num_rollouts_to_keep:
self.normal_rollouts.pop(0)
print(f"Collected {len(to_postprocess['scores'])} trajectories")
return to_postprocess, to_backlog
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
@ -635,66 +525,7 @@ class MathEnv(BaseEnv):
]
)
# Debug: Log scored group creation
print(f"[MATH_DEBUG] Created ScoredDataGroup with {len(scores['tokens'])} sequences")
print(f"[MATH_DEBUG] Scores: {scores['scores']}")
print(f"[MATH_DEBUG] Token lengths: {[len(t) for t in scores['tokens']]}")
has_new_distill = (
"distill_token_ids" in scores and "distill_logprobs" in scores
)
print(f"[MATH_DEBUG] Has distill arrays: {has_new_distill}")
return scores
async def handle_send_to_api(
self,
scored_data,
item=None,
do_send_to_api: bool = True,
abort_on_any_max_length_exceeded: bool = True,
):
"""Override to add debugging for distillation."""
print(f"[MATH_DEBUG] handle_send_to_api called")
print(f"[MATH_DEBUG] distillation_enabled: {self.config.distillation_enabled}")
print(f"[MATH_DEBUG] teacher_base_url: {self.config.teacher_base_url}")
if isinstance(scored_data, list):
for i, group in enumerate(scored_data):
if group:
has_distill = (
group.get("distill_token_ids") is not None
and group.get("distill_logprobs") is not None
)
print(f"[MATH_DEBUG] Group {i}: {len(group.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
elif scored_data:
has_distill = (
scored_data.get("distill_token_ids") is not None
and scored_data.get("distill_logprobs") is not None
)
print(f"[MATH_DEBUG] Single group: {len(scored_data.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
# Call parent implementation which does the actual distillation fetch
result = await super().handle_send_to_api(
scored_data, item, do_send_to_api, abort_on_any_max_length_exceeded
)
# Debug: Check if distillation was added after parent call
if isinstance(scored_data, list):
for i, group in enumerate(scored_data):
if group:
has_distill = (
group.get("distill_token_ids") is not None
and group.get("distill_logprobs") is not None
)
print(f"[MATH_DEBUG] AFTER: Group {i} has_distill_logprobs={has_distill}")
elif scored_data:
has_distill = (
scored_data.get("distill_token_ids") is not None
and scored_data.get("distill_logprobs") is not None
)
print(f"[MATH_DEBUG] AFTER: Single group has_distill_logprobs={has_distill}")
return result
async def get_next_item(self):
while True:
@ -709,10 +540,7 @@ class MathEnv(BaseEnv):
)
break
except TypeError:
print(
f"Error in getting next item, trying again, "
f"data: {next_item['question']} -> {next_item['final_answer']}"
)
continue
return (prompt, answer, "normal")