diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index ebf27eeb..a3ac2117 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -209,8 +209,12 @@ class BaseEnvConfig(BaseModel): "no thinking prompt is injected. Use HERMES_REASONING_PROMPT from " "eval_helpers for the standard Hermes reasoning prompt.", ) - # On-policy distillation settings + distillation_enabled: bool = Field( + default=False, + description="Enable on-policy distillation. When True, automatically fetches teacher logprobs " + "after scoring and includes them in data sent to trainer.", + ) teacher_base_url: Optional[str] = Field( default=None, description="Base URL of teacher model for distillation. Supports any OpenAI-compatible API " @@ -226,14 +230,9 @@ class BaseEnvConfig(BaseModel): description="API key for teacher model. Can also be set via TEACHER_API_KEY env var.", ) teacher_top_k: int = Field( - default=10, + default=20, description="Number of top logprobs to fetch from teacher model per position.", ) - distillation_enabled: bool = Field( - default=False, - description="Enable on-policy distillation. When True, automatically fetches teacher logprobs " - "after scoring and includes them in data sent to trainer.", - ) class BaseEnv(ABC): @@ -1164,6 +1163,28 @@ class BaseEnv(ABC): valid_groups.append(group) if valid_groups and do_send_to_api: + # On-policy distillation: fetch teacher logprobs if enabled + if self.config.distillation_enabled and self.config.teacher_base_url: + logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups") + for group in valid_groups: + if group.get("onpolicydistill_logprobs") is None: + try: + teacher_logprobs = await self.get_teacher_logprobs( + token_sequences=group["tokens"], + messages_list=group.get("messages"), + ) + if teacher_logprobs: + group["onpolicydistill_logprobs"] = teacher_logprobs + logger.info(f"[DISTILL] Added teacher logprobs for {len(teacher_logprobs)} sequences") + else: + logger.warning("[DISTILL] get_teacher_logprobs returned empty") + except Exception as e: + logger.error(f"[DISTILL] Failed to fetch teacher logprobs: {e}") + import traceback + logger.error(traceback.format_exc()) + else: + logger.debug(f"[DISTILL] Skipped - enabled={self.config.distillation_enabled}, url={self.config.teacher_base_url}") + data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]] # send single or list of scored data groups if not original_was_list and len(valid_groups) == 1: diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index 1432ab4d..932df9dc 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -6,11 +6,17 @@ Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero import asyncio import random import re +import logging from concurrent.futures import ProcessPoolExecutor from typing import Dict, List, Optional, Tuple +import aiohttp import wandb from datasets import load_dataset + +# Set up logging for debug +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) from latex2sympy2_extended import NormalizationConfig from math_verify import LatexExtractionConfig, parse, verify from math_verify.errors import TimeoutException @@ -135,6 +141,15 @@ class MathEnv(BaseEnv): self.normal_rollouts = list() self.pass_at_groupsize = list() self.iter = 0 + + # Debug: Print distillation config + print("=" * 60) + print("[MATH_DEBUG] DISTILLATION CONFIGURATION:") + print(f"[MATH_DEBUG] distillation_enabled = {config.distillation_enabled}") + print(f"[MATH_DEBUG] teacher_base_url = {config.teacher_base_url}") + print(f"[MATH_DEBUG] teacher_model_name = {getattr(config, 'teacher_model_name', 'N/A')}") + print(f"[MATH_DEBUG] teacher_top_logprobs = {getattr(config, 'teacher_top_logprobs', 'N/A')}") + print("=" * 60) @classmethod def config_init(cls) -> Tuple[RSConfig, ServerBaseline]: @@ -252,7 +267,85 @@ class MathEnv(BaseEnv): name, ) ) + + # Debug: Test teacher connectivity if distillation is enabled + if self.config.distillation_enabled and self.config.teacher_base_url: + await self._test_teacher_connectivity() + return + + async def _test_teacher_connectivity(self): + """Test if the teacher model API is reachable.""" + print("=" * 60) + print("[MATH_DEBUG] TESTING TEACHER CONNECTIVITY...") + print(f"[MATH_DEBUG] Teacher URL: {self.config.teacher_base_url}") + print(f"[MATH_DEBUG] Teacher Model: {getattr(self.config, 'teacher_model_name', 'default')}") + + try: + async with aiohttp.ClientSession() as session: + # Test 1: Health check + health_url = self.config.teacher_base_url.replace("/v1", "") + "/health" + print(f"[MATH_DEBUG] Testing health endpoint: {health_url}") + try: + async with session.get(health_url, timeout=aiohttp.ClientTimeout(total=10)) as resp: + print(f"[MATH_DEBUG] Health check status: {resp.status}") + if resp.status == 200: + print("[MATH_DEBUG] ✓ Teacher health check PASSED") + else: + print(f"[MATH_DEBUG] ✗ Teacher health check FAILED: {await resp.text()}") + except Exception as e: + print(f"[MATH_DEBUG] ✗ Teacher health check ERROR: {e}") + + # Test 2: Models endpoint + models_url = f"{self.config.teacher_base_url}/models" + print(f"[MATH_DEBUG] Testing models endpoint: {models_url}") + try: + async with session.get(models_url, timeout=aiohttp.ClientTimeout(total=10)) as resp: + print(f"[MATH_DEBUG] Models endpoint status: {resp.status}") + if resp.status == 200: + data = await resp.json() + models = [m.get("id", m) for m in data.get("data", [])] + print(f"[MATH_DEBUG] ✓ Available models: {models}") + else: + print(f"[MATH_DEBUG] ✗ Models endpoint FAILED: {await resp.text()}") + except Exception as e: + print(f"[MATH_DEBUG] ✗ Models endpoint ERROR: {e}") + + # Test 3: Simple completion test + completions_url = f"{self.config.teacher_base_url}/completions" + teacher_model = getattr(self.config, 'teacher_model_name', 'default') + test_payload = { + "model": teacher_model, + "prompt": "Hello", + "max_tokens": 5, + "logprobs": 5, + "echo": True, + } + print(f"[MATH_DEBUG] Testing completions endpoint: {completions_url}") + print(f"[MATH_DEBUG] Test payload: {test_payload}") + try: + async with session.post( + completions_url, + json=test_payload, + headers={"Content-Type": "application/json"}, + timeout=aiohttp.ClientTimeout(total=30), + ) as resp: + print(f"[MATH_DEBUG] Completions status: {resp.status}") + resp_text = await resp.text() + if resp.status == 200: + print(f"[MATH_DEBUG] ✓ Teacher completions WORKING!") + print(f"[MATH_DEBUG] Response preview: {resp_text[:500]}") + else: + print(f"[MATH_DEBUG] ✗ Teacher completions FAILED: {resp_text[:500]}") + except Exception as e: + print(f"[MATH_DEBUG] ✗ Teacher completions ERROR: {e}") + + except Exception as e: + print(f"[MATH_DEBUG] ✗ Teacher connectivity test FAILED: {e}") + import traceback + traceback.print_exc() + + print("=" * 60) async def rollout_and_score_eval(self, question, answer, subset): async with self.server.managed_server(tokenizer=self.tokenizer) as managed: @@ -482,7 +575,52 @@ class MathEnv(BaseEnv): and (not scores["overrides"][i].get("set_advantage_to_zero", False)) ] ) + + # Debug: Log scored group creation + print(f"[MATH_DEBUG] Created ScoredDataGroup with {len(scores['tokens'])} sequences") + print(f"[MATH_DEBUG] Scores: {scores['scores']}") + print(f"[MATH_DEBUG] Token lengths: {[len(t) for t in scores['tokens']]}") + print(f"[MATH_DEBUG] Has onpolicydistill_logprobs: {'onpolicydistill_logprobs' in scores}") + return scores + + async def handle_send_to_api( + self, + scored_data, + item=None, + do_send_to_api: bool = True, + abort_on_any_max_length_exceeded: bool = True, + ): + """Override to add debugging for distillation.""" + print(f"[MATH_DEBUG] handle_send_to_api called") + print(f"[MATH_DEBUG] distillation_enabled: {self.config.distillation_enabled}") + print(f"[MATH_DEBUG] teacher_base_url: {self.config.teacher_base_url}") + + if isinstance(scored_data, list): + for i, group in enumerate(scored_data): + if group: + has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None + print(f"[MATH_DEBUG] Group {i}: {len(group.get('tokens', []))} seqs, has_distill_logprobs={has_distill}") + elif scored_data: + has_distill = 'onpolicydistill_logprobs' in scored_data and scored_data.get('onpolicydistill_logprobs') is not None + print(f"[MATH_DEBUG] Single group: {len(scored_data.get('tokens', []))} seqs, has_distill_logprobs={has_distill}") + + # Call parent implementation which does the actual distillation fetch + result = await super().handle_send_to_api( + scored_data, item, do_send_to_api, abort_on_any_max_length_exceeded + ) + + # Debug: Check if distillation was added after parent call + if isinstance(scored_data, list): + for i, group in enumerate(scored_data): + if group: + has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None + print(f"[MATH_DEBUG] AFTER: Group {i} has_distill_logprobs={has_distill}") + elif scored_data: + has_distill = 'onpolicydistill_logprobs' in scored_data and scored_data.get('onpolicydistill_logprobs') is not None + print(f"[MATH_DEBUG] AFTER: Single group has_distill_logprobs={has_distill}") + + return result async def get_next_item(self): while True: