From 1c90fc71b0a0d79e6041ec858a3bbf6733044a31 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 20 Feb 2026 00:35:29 -0500 Subject: [PATCH] on policy clean up --- README.md | 127 ++++++++++++++ atroposlib/envs/base.py | 12 +- .../envs/server_handling/managed_server.py | 18 ++ .../envs/server_handling/openai_server.py | 10 +- .../envs/server_handling/vllm_server.py | 25 +++ environments/gsm8k_server.py | 77 ++++++++- environments/math_server_zero.py | 161 +----------------- 7 files changed, 262 insertions(+), 168 deletions(-) diff --git a/README.md b/README.md index b7ed312e..ddb2bac7 100644 --- a/README.md +++ b/README.md @@ -256,6 +256,133 @@ Atropos repo contains an example trainer that should primarily be used as a refe To use the example trainer, see this page: [training example guide](example_trainer/README.md) +## On-Policy Distillation (Environment + API Flow) + +Atropos supports on-policy distillation by fetching teacher top-k token distributions for the same trajectories that are used for RL, then attaching that teacher data to the batches consumed by your trainer. + +### How the flow works + +1. **Student rollouts are generated by the environment** + - Example: `environments/gsm8k_server.py` samples `group_size` completions from your student inference server. +2. **Environment scores and validates the group** + - Normal Atropos filtering still applies (group structure, optional score-equality checks, max token checks, etc.). +3. **Teacher logprobs are fetched in `BaseEnv.handle_send_to_api`** + - If `distillation_enabled=true` and `teacher_base_url` is set, Atropos calls the teacher endpoint and builds: + - `distill_token_ids` with shape `[sequence][position][top_k]` + - `distill_logprobs` with shape `[sequence][position][top_k]` +4. **Distillation arrays are attached to each scored group** + - Added as `group["distill_token_ids"]` and `group["distill_logprobs"]`. +5. 
**Atropos API stores and serves these fields unchanged** + - `/scored_data` and `/batch` include the distillation arrays. +6. **Trainer consumes both RL and distillation signals** + - Example trainer computes GRPO + distillation loss from the same batch. + +### Configuration knobs in environments + +Distillation is configured in `BaseEnvConfig` and available via CLI under `--env.*`: + +- `--env.distillation_enabled true` +- `--env.teacher_base_url http://localhost:8003/v1` +- `--env.teacher_model_name <model_name>` +- `--env.teacher_api_key <api_key>` (or `TEACHER_API_KEY`) +- `--env.teacher_top_k 20` +- Optional steering controls: + - `--env.teacher_prefix_text "..."` + - `--env.teacher_system_prompt "..."` + +### Self-distillation vs cross-model distillation + +Both setups are supported: + +- **Self-distillation (same model family for teacher and student)** + Point `teacher_base_url` to a server running the same model (or equivalent checkpoint family) as the student. This is the most stable setup for token-level alignment. + +- **Cross-model distillation (different teacher and student models)** + Also supported, but tokenization compatibility becomes more important. If token vocabularies/template behavior differ significantly, alignment quality may degrade. + +In practice, self-distillation is usually easiest to bring up first, then cross-model can be layered in once your pipeline is stable. + +### Tokenization and alignment details + +Atropos handles tokenization in two places: + +1. **Student rollout path (`server_type=vllm`)** + - The `/generate` request is built via the vLLM server handler and uses the server-side tokenizer configured by: + - `--openai.tokenizer_name` (or falls back to `--openai.model_name`) + - Recommendation: set `--openai.tokenizer_name` explicitly to match the student serving model. + +2. **Teacher top-k parsing path** + - Teacher responses are parsed into token ids/logprobs in `BaseEnv.get_teacher_logprobs`. 
+ - The parser maps teacher token strings into ids using the environment tokenizer (`self.tokenizer`) and then aligns to student sequence length. + +Because distillation is token-position based, keeping tokenizer families compatible is strongly recommended, especially for cross-model distillation. + +### Minimal bring-up example + +Run each command in a separate terminal. + +1. **Start Atropos API** +```bash +run-api --port 8002 +``` + +2. **Start teacher server (OpenAI-compatible endpoint)** +```bash +python -m vllm.entrypoints.openai.api_server \ + --model "$TEACHER_MODEL" \ + --host 0.0.0.0 \ + --port 8003 +``` + +3. **Start student server for environments (`/generate` endpoint)** +```bash +python -m example_trainer.vllm_api_server \ + --model "$STUDENT_MODEL" \ + --port 9001 +``` + +4. **Start environment with distillation enabled** +```bash +python environments/gsm8k_server.py serve \ + --env.rollout_server_url "http://localhost:8002" \ + --env.distillation_enabled true \ + --env.teacher_base_url "http://localhost:8003/v1" \ + --env.teacher_model_name "$TEACHER_MODEL" \ + --env.teacher_top_k 20 \ + --openai.server_type vllm \ + --openai.base_url "http://localhost:9001/v1" \ + --openai.model_name "$STUDENT_MODEL" \ + --openai.tokenizer_name "$STUDENT_MODEL" +``` + +5. **Start trainer with distillation flags** +```bash +python -m example_trainer.grpo \ + --atropos-url "http://localhost:8002" \ + --distillation-enabled \ + --distillation-coef 0.1 \ + --distillation-loss-type kl \ + --distillation-temperature 1.0 +``` + +### Verification checklist + +- Environment logs show distillation fetch: + - `[DISTILL] Fetching teacher logprobs ...` + - `[DISTILL] Added teacher distill arrays ...` +- Teacher logs show completion/chat requests from the environment. 
+- API contains distill fields in latest example: +```bash +curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids!=null), has_lps:(.distill_logprobs!=null)}' +``` +- Trainer logs report distillation metrics (example trainer): + - `Distill: loss=...` + +### Important notes + +- For `server_type=vllm`, the environment expects a server exposing `/generate` (the custom server in `example_trainer/vllm_api_server.py`), not only `/v1/chat/completions`. +- Prefer explicitly setting `--openai.tokenizer_name` to your student tokenizer to avoid prompt token-ID mismatch. + --- ## Testing and Debugging Tools diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 25054e0c..551eb5b0 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -1276,7 +1276,11 @@ class BaseEnv(ABC): import traceback logger.error(traceback.format_exc()) else: - logger.debug(f"[DISTILL] Skipped - enabled={self.config.distillation_enabled}, url={self.config.teacher_base_url}") + logger.debug( + "[DISTILL] Skipped - enabled=%s, url=%s", + self.config.distillation_enabled, + self.config.teacher_base_url, + ) data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]] # send single or list of scored data groups @@ -1826,7 +1830,11 @@ class BaseEnv(ABC): ) print(f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}") else: - print(f"[CLI DEBUG] Not merging: default_openai_config_ type={type(default_openai_config_)}, yaml_oai_config type={type(yaml_oai_config)}") + print( + "[CLI DEBUG] Not merging: default_openai_config_ " + f"type={type(default_openai_config_)}, " + f"yaml_oai_config type={type(yaml_oai_config)}" + ) openai_config_dict = {} # 3. 
Server Manager Configuration (slurm, testing - not namespaced) diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index 0918c325..cb14d210 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -7,6 +7,7 @@ This wrapper maintains a tree structure of sequences, where: - Branching occurs organically from different contexts and n > 1 completions """ +import os import time import uuid import warnings @@ -131,6 +132,10 @@ class ManagedServer: # Fallback for tokenizers without chat template return "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + def _debug_requests_enabled(self) -> bool: + """Enable verbose request construction logs with ATROPOS_DEBUG_REQUESTS=1.""" + return os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1" + def _find_extending_node(self, input_text: str) -> Optional[SequenceNode]: """ Find a node that this input extends (default mode). 
@@ -284,6 +289,19 @@ class ManagedServer: completion_kwargs = kwargs.copy() completion_kwargs["prompt"] = prompt completion_kwargs.pop("messages", None) + if self._debug_requests_enabled(): + msg_count = len(messages) + prompt_preview = prompt.replace("\n", "\\n")[:600] + print( + f"[ATROPOS_REQ_DEBUG] chat_completion messages={msg_count} " + f"n={completion_kwargs.get('n')} max_tokens={completion_kwargs.get('max_tokens')} " + f"temperature={completion_kwargs.get('temperature')}", + flush=True, + ) + print( + f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}", + flush=True, + ) # Set model name if not provided if "model" not in completion_kwargs: diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index bbf0c15b..f84558c4 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -185,8 +185,14 @@ def resolve_openai_configs( and len(default_server_configs) >= 2 ) - print(f"[RESOLVE DEBUG] is_multi_server_yaml={is_multi_server_yaml}, is_multi_server_default={is_multi_server_default}") - print(f"[RESOLVE DEBUG] isinstance(default_server_configs, ServerBaseline) = {isinstance(default_server_configs, ServerBaseline)}") + print( + "[RESOLVE DEBUG] is_multi_server_yaml=" + f"{is_multi_server_yaml}, is_multi_server_default={is_multi_server_default}" + ) + print( + "[RESOLVE DEBUG] isinstance(default_server_configs, ServerBaseline) = " + f"{isinstance(default_server_configs, ServerBaseline)}" + ) if (is_multi_server_yaml or is_multi_server_default) and openai_cli_config: raise FailedExecutionException( diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 96242754..d3e7d2ea 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -2,6 +2,7 @@ # see example_trainer/vllm_api_server.py for an example import asyncio +import os import 
warnings import aiohttp @@ -189,6 +190,30 @@ class VLLMServer(APIServer): # Prepare request for VLLM native API request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0} request_data.update(kwargs) + debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1" + if debug_requests: + base = self.config.base_url.replace("/v1", "") + prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace("\n", "\\n") + print( + f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate " + f"prompt_token_len={len(prompt_tokens)}", + flush=True, + ) + print( + f"[ATROPOS_REQ_DEBUG] request_meta=" + f"{{'n': {request_data.get('n')}, 'max_tokens': {request_data.get('max_tokens')}, " + f"'temperature': {request_data.get('temperature')}, 'top_p': {request_data.get('top_p')}}}", + flush=True, + ) + print( + f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}", + flush=True, + ) + print( + f"[ATROPOS_REQ_DEBUG] curl_base=curl -s -X POST {base}/generate " + '-H "Content-Type: application/json" -d \'\'', + flush=True, + ) # Make async request to VLLM /generate endpoint async with aiohttp.ClientSession() as session: diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 6ae5285b..0fa53431 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -123,7 +123,11 @@ class GSM8kEnv(BaseEnv): async def rollout_and_score_eval(self, question: str, answer: str) -> dict: """Rollout and score evaluation with detailed sample data collection.""" - async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + # Important: use ManagedServer's default tokenizer resolution so it uses + # the underlying inference server tokenizer (e.g., Qwen) instead of the + # environment tokenizer. Passing self.tokenizer here can cause token-ID + # mismatch and gibberish generations when model/tokenizer families differ. 
+ async with self.server.managed_server() as managed: completion = await managed.chat_completion( messages=[ {"role": "system", "content": system_prompt}, @@ -231,8 +235,16 @@ class GSM8kEnv(BaseEnv): gold_answer = ( "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}" ) + question_preview = item["question"].replace("\n", " ")[:120] + print( + f"[GSM8K_DEBUG] collect_start group_size={self.config.group_size} " + f"q={question_preview!r}", + flush=True, + ) - async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + # Important: do not force env tokenizer into ManagedServer for rollout. + # Let ManagedServer use the server's tokenizer to keep prompt token IDs aligned. + async with self.server.managed_server() as managed: chat_completions = await managed.chat_completion( messages=[{"role": "system", "content": system_prompt}, user_message], @@ -243,10 +255,24 @@ class GSM8kEnv(BaseEnv): state = managed.get_state() nodes = state["nodes"] + print( + f"[GSM8K_DEBUG] completion_batch_received choices={len(chat_completions.choices)} " + f"nodes={len(nodes)}", + flush=True, + ) to_score = list() to_backlog = list() for i, chat_completion in enumerate(chat_completions.choices): + response_text = chat_completion.message.content or "" + response_preview = response_text.replace("\n", " ")[:220] + valid_mask_count = sum(1 for m in nodes[i].masked_tokens if m != -100) + print( + f"[GSM8K_DEBUG] response_received idx={i} finish={chat_completion.finish_reason} " + f"tokens={len(nodes[i].tokens)} valid_masked={valid_mask_count} " + f"text={response_preview!r}", + flush=True, + ) messages = ( {"role": "system", "content": system_prompt}, user_message, @@ -263,6 +289,11 @@ class GSM8kEnv(BaseEnv): } ) to_postprocess = await self.score(to_score) + accepted = 0 if to_postprocess is None else len(to_postprocess.get("tokens", [])) + print( + f"[GSM8K_DEBUG] collect_done accepted={accepted} submitted={len(to_score)}", + flush=True, + ) return 
to_postprocess, to_backlog async def score( @@ -278,10 +309,15 @@ class GSM8kEnv(BaseEnv): extraction_mode="first_match", extraction_config=[LatexExtractionConfig()], ) + print( + f"[GSM8K_DEBUG] score_start candidates={len(rollout_group_data)} " + f"gold_parsed_len={len(gold_parsed)}", + flush=True, + ) if len(gold_parsed) != 0: # We require the answer to be provided in correct latex (no malformed operators) random.shuffle(rollout_group_data) - for item in rollout_group_data: + for idx, item in enumerate(rollout_group_data): # print(item[0][-1]["content"]) answer_parsed = parse( item["messages"][-1]["content"].split("")[-1], @@ -310,7 +346,19 @@ class GSM8kEnv(BaseEnv): logprobs = item["logprobs"] # remove obviously bad examples - if len([1 for i in masks if i != -100]) < 10: + valid_mask_count = len([1 for i in masks if i != -100]) + print( + f"[GSM8K_DEBUG] score_candidate idx={idx} parsed_len={len(answer_parsed)} " + f"reward={bool(reward)} valid_masked={valid_mask_count} " + f"tokens={len(tokens)}", + flush=True, + ) + if valid_mask_count < 10: + print( + f"[GSM8K_DEBUG] drop_candidate idx={idx} reason=valid_masked_lt_10 " + f"value={valid_mask_count}", + flush=True, + ) continue scores["tokens"].append(tokens) scores["masks"].append(masks) @@ -323,6 +371,13 @@ class GSM8kEnv(BaseEnv): for score in scores["scores"]: self.percent_correct_buffer.append(max(score, 0)) + if len(scores["scores"]) == 0: + print( + "[GSM8K_DEBUG] drop_group reason=no_valid_candidates_after_filtering", + flush=True, + ) + return None + # check if all the same # print(scores['scores']) if all([score == 1 for score in scores["scores"]]): @@ -330,6 +385,10 @@ class GSM8kEnv(BaseEnv): token_lengths = [len(token) for token in scores["tokens"]] if max(token_lengths) == 0: # What? But don't want to crash a run so just in case... 
+ print( + "[GSM8K_DEBUG] drop_group reason=zero_token_length_after_penalty_branch", + flush=True, + ) return None # Get max allowed token length from config @@ -353,10 +412,20 @@ class GSM8kEnv(BaseEnv): # Apply linear penalty scaling from 1.0 down to 0.0 scores["scores"].append(1.0 - percentage_of_range) if all([scores["scores"][0] == score for score in scores["scores"]]): + print( + f"[GSM8K_DEBUG] drop_group reason=all_scores_identical scores={scores['scores']}", + flush=True, + ) return None # If all the same, we return None + print( + f"[GSM8K_DEBUG] score_done accepted={len(scores['scores'])} " + f"scores={scores['scores']}", + flush=True, + ) return scores else: # If the gold solution is not parseable, we return None + print("[GSM8K_DEBUG] drop_group reason=gold_unparseable", flush=True) return None async def get_next_item(self) -> GSM8kRow: diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index 93856ea3..12076dab 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -6,17 +6,12 @@ Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero import asyncio import random import re -import logging from concurrent.futures import ProcessPoolExecutor from typing import Dict, List, Optional, Tuple -import aiohttp import wandb from datasets import load_dataset -# Set up logging for debug -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) from latex2sympy2_extended import NormalizationConfig from math_verify import LatexExtractionConfig, parse, verify from math_verify.errors import TimeoutException @@ -129,8 +124,6 @@ class MathEnv(BaseEnv): slurm=True, testing=False, ): - print("Initializing MathEnv") - print(f"Slurm: {slurm}, Testing: {testing}") super().__init__(config, server_configs, slurm, testing) self.percent_correct_buffer = list() self.eval_metrics = list() @@ -141,17 +134,6 @@ class MathEnv(BaseEnv): self.normal_rollouts = list() 
self.pass_at_groupsize = list() self.iter = 0 - - # Debug: Print distillation config - print("=" * 60) - print("[MATH_DEBUG] DISTILLATION CONFIGURATION:") - print(f"[MATH_DEBUG] distillation_enabled = {config.distillation_enabled}") - print(f"[MATH_DEBUG] teacher_base_url = {config.teacher_base_url}") - print(f"[MATH_DEBUG] teacher_model_name = {getattr(config, 'teacher_model_name', 'N/A')}") - print(f"[MATH_DEBUG] teacher_top_k = {getattr(config, 'teacher_top_k', 'N/A')}") - print(f"[MATH_DEBUG] teacher_prefix_text set = {bool(getattr(config, 'teacher_prefix_text', None))}") - print(f"[MATH_DEBUG] teacher_system_prompt set = {bool(getattr(config, 'teacher_system_prompt', None))}") - print("=" * 60) @classmethod def config_init(cls) -> Tuple[RSConfig, ServerBaseline]: @@ -269,85 +251,7 @@ class MathEnv(BaseEnv): name, ) ) - - # Debug: Test teacher connectivity if distillation is enabled - if self.config.distillation_enabled and self.config.teacher_base_url: - await self._test_teacher_connectivity() - return - - async def _test_teacher_connectivity(self): - """Test if the teacher model API is reachable.""" - print("=" * 60) - print("[MATH_DEBUG] TESTING TEACHER CONNECTIVITY...") - print(f"[MATH_DEBUG] Teacher URL: {self.config.teacher_base_url}") - print(f"[MATH_DEBUG] Teacher Model: {getattr(self.config, 'teacher_model_name', 'default')}") - - try: - async with aiohttp.ClientSession() as session: - # Test 1: Health check - health_url = self.config.teacher_base_url.replace("/v1", "") + "/health" - print(f"[MATH_DEBUG] Testing health endpoint: {health_url}") - try: - async with session.get(health_url, timeout=aiohttp.ClientTimeout(total=10)) as resp: - print(f"[MATH_DEBUG] Health check status: {resp.status}") - if resp.status == 200: - print("[MATH_DEBUG] ✓ Teacher health check PASSED") - else: - print(f"[MATH_DEBUG] ✗ Teacher health check FAILED: {await resp.text()}") - except Exception as e: - print(f"[MATH_DEBUG] ✗ Teacher health check ERROR: {e}") - - # Test 2: 
Models endpoint - models_url = f"{self.config.teacher_base_url}/models" - print(f"[MATH_DEBUG] Testing models endpoint: {models_url}") - try: - async with session.get(models_url, timeout=aiohttp.ClientTimeout(total=10)) as resp: - print(f"[MATH_DEBUG] Models endpoint status: {resp.status}") - if resp.status == 200: - data = await resp.json() - models = [m.get("id", m) for m in data.get("data", [])] - print(f"[MATH_DEBUG] ✓ Available models: {models}") - else: - print(f"[MATH_DEBUG] ✗ Models endpoint FAILED: {await resp.text()}") - except Exception as e: - print(f"[MATH_DEBUG] ✗ Models endpoint ERROR: {e}") - - # Test 3: Simple completion test - completions_url = f"{self.config.teacher_base_url}/completions" - teacher_model = getattr(self.config, 'teacher_model_name', 'default') - test_payload = { - "model": teacher_model, - "prompt": "Hello", - "max_tokens": 5, - "logprobs": 5, - "echo": True, - } - print(f"[MATH_DEBUG] Testing completions endpoint: {completions_url}") - print(f"[MATH_DEBUG] Test payload: {test_payload}") - try: - async with session.post( - completions_url, - json=test_payload, - headers={"Content-Type": "application/json"}, - timeout=aiohttp.ClientTimeout(total=30), - ) as resp: - print(f"[MATH_DEBUG] Completions status: {resp.status}") - resp_text = await resp.text() - if resp.status == 200: - print(f"[MATH_DEBUG] ✓ Teacher completions WORKING!") - print(f"[MATH_DEBUG] Response preview: {resp_text[:500]}") - else: - print(f"[MATH_DEBUG] ✗ Teacher completions FAILED: {resp_text[:500]}") - except Exception as e: - print(f"[MATH_DEBUG] ✗ Teacher completions ERROR: {e}") - - except Exception as e: - print(f"[MATH_DEBUG] ✗ Teacher connectivity test FAILED: {e}") - import traceback - traceback.print_exc() - - print("=" * 60) async def rollout_and_score_eval(self, question, answer, subset): async with self.server.managed_server(tokenizer=self.tokenizer) as managed: @@ -492,7 +396,6 @@ class MathEnv(BaseEnv): ) if len(self.normal_rollouts) > 
self.config.num_rollouts_to_keep: self.normal_rollouts.pop(0) - print(f"Collected {len(to_postprocess['scores'])} trajectories") return to_postprocess, to_backlog async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]: @@ -578,66 +481,7 @@ class MathEnv(BaseEnv): ] ) - # Debug: Log scored group creation - print(f"[MATH_DEBUG] Created ScoredDataGroup with {len(scores['tokens'])} sequences") - print(f"[MATH_DEBUG] Scores: {scores['scores']}") - print(f"[MATH_DEBUG] Token lengths: {[len(t) for t in scores['tokens']]}") - has_new_distill = ( - "distill_token_ids" in scores and "distill_logprobs" in scores - ) - print(f"[MATH_DEBUG] Has distill arrays: {has_new_distill}") - return scores - - async def handle_send_to_api( - self, - scored_data, - item=None, - do_send_to_api: bool = True, - abort_on_any_max_length_exceeded: bool = True, - ): - """Override to add debugging for distillation.""" - print(f"[MATH_DEBUG] handle_send_to_api called") - print(f"[MATH_DEBUG] distillation_enabled: {self.config.distillation_enabled}") - print(f"[MATH_DEBUG] teacher_base_url: {self.config.teacher_base_url}") - - if isinstance(scored_data, list): - for i, group in enumerate(scored_data): - if group: - has_distill = ( - group.get("distill_token_ids") is not None - and group.get("distill_logprobs") is not None - ) - print(f"[MATH_DEBUG] Group {i}: {len(group.get('tokens', []))} seqs, has_distill_logprobs={has_distill}") - elif scored_data: - has_distill = ( - scored_data.get("distill_token_ids") is not None - and scored_data.get("distill_logprobs") is not None - ) - print(f"[MATH_DEBUG] Single group: {len(scored_data.get('tokens', []))} seqs, has_distill_logprobs={has_distill}") - - # Call parent implementation which does the actual distillation fetch - result = await super().handle_send_to_api( - scored_data, item, do_send_to_api, abort_on_any_max_length_exceeded - ) - - # Debug: Check if distillation was added after parent call - if isinstance(scored_data, list): 
- for i, group in enumerate(scored_data): - if group: - has_distill = ( - group.get("distill_token_ids") is not None - and group.get("distill_logprobs") is not None - ) - print(f"[MATH_DEBUG] AFTER: Group {i} has_distill_logprobs={has_distill}") - elif scored_data: - has_distill = ( - scored_data.get("distill_token_ids") is not None - and scored_data.get("distill_logprobs") is not None - ) - print(f"[MATH_DEBUG] AFTER: Single group has_distill_logprobs={has_distill}") - - return result async def get_next_item(self): while True: @@ -652,10 +496,7 @@ class MathEnv(BaseEnv): ) break except TypeError: - print( - f"Error in getting next item, trying again, " - f"data: {next_item['question']} -> {next_item['final_answer']}" - ) + continue return (prompt, answer, "normal")