on policy clean up

Jai Suphavadeeprasit 2026-02-20 00:35:29 -05:00
parent 79e392c446
commit 1c90fc71b0
7 changed files with 262 additions and 168 deletions

README.md
@@ -256,6 +256,133 @@ Atropos repo contains an example trainer that should primarily be used as a refe
To use the example trainer, see this page: [training example guide](example_trainer/README.md)
## On-Policy Distillation (Environment + API Flow)
Atropos supports on-policy distillation by fetching teacher top-k token distributions for the same trajectories that are used for RL, then attaching that teacher data to the batches consumed by your trainer.
### How the flow works
1. **Student rollouts are generated by the environment**
- Example: `environments/gsm8k_server.py` samples `group_size` completions from your student inference server.
2. **Environment scores and validates the group**
- Normal Atropos filtering still applies (group structure, optional score-equality checks, max token checks, etc.).
3. **Teacher logprobs are fetched in `BaseEnv.handle_send_to_api`**
- If `distillation_enabled=true` and `teacher_base_url` is set, Atropos calls the teacher endpoint and builds:
- `distill_token_ids` with shape `[sequence][position][top_k]`
- `distill_logprobs` with shape `[sequence][position][top_k]`
4. **Distillation arrays are attached to each scored group**
- Added as `group["distill_token_ids"]` and `group["distill_logprobs"]`.
5. **Atropos API stores and serves these fields unchanged**
- `/scored_data` and `/batch` include the distillation arrays.
6. **Trainer consumes both RL and distillation signals**
- Example trainer computes GRPO + distillation loss from the same batch.
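The arrays produced in steps 3-4 are plain nested lists attached to the scored group dict. A minimal sketch of that layout, with made-up token ids and logprobs (the values below are illustrative, not real model output):

```python
# Hypothetical distillation payload attached to a scored group.
# Shapes: [sequence][position][top_k]; all values are made up.
group = {
    "tokens": [[101, 2023, 2003]],  # one student sequence of 3 token ids
    "scores": [1.0],
    # top_k = 2 teacher candidates per student token position
    "distill_token_ids": [[[2023, 1996], [2003, 2001], [102, 1012]]],
    "distill_logprobs": [[[-0.1, -2.5], [-0.3, -1.9], [-0.05, -3.2]]],
}

num_sequences = len(group["distill_token_ids"])
positions = len(group["distill_token_ids"][0])
top_k = len(group["distill_token_ids"][0][0])
print(num_sequences, positions, top_k)  # 1 3 2
```

The trainer can index `distill_logprobs[seq][pos]` in lockstep with `tokens[seq][pos]`, which is why the arrays must match the student sequence length.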
### Configuration knobs in environments
Distillation is configured in `BaseEnvConfig` and available via CLI under `--env.*`:
- `--env.distillation_enabled true`
- `--env.teacher_base_url http://localhost:8003/v1`
- `--env.teacher_model_name <teacher_model>`
- `--env.teacher_api_key <optional>` (or `TEACHER_API_KEY`)
- `--env.teacher_top_k 20`
- Optional steering controls:
- `--env.teacher_prefix_text "..."`
- `--env.teacher_system_prompt "..."`
### Self-distillation vs cross-model distillation
Both setups are supported:
- **Self-distillation (same model family for teacher and student)**
Point `teacher_base_url` to a server running the same model (or equivalent checkpoint family) as the student. This is the most stable setup for token-level alignment.
- **Cross-model distillation (different teacher and student models)**
Also supported, but tokenization compatibility becomes more important. If token vocabularies/template behavior differ significantly, alignment quality may degrade.
In practice, self-distillation is usually easiest to bring up first, then cross-model can be layered in once your pipeline is stable.
### Tokenization and alignment details
Atropos handles tokenization in two places:
1. **Student rollout path (`server_type=vllm`)**
- The `/generate` request is built via the vLLM server handler and uses the server-side tokenizer configured by:
- `--openai.tokenizer_name` (or falls back to `--openai.model_name`)
- Recommendation: set `--openai.tokenizer_name` explicitly to match the student serving model.
2. **Teacher top-k parsing path**
- Teacher responses are parsed into token ids/logprobs in `BaseEnv.get_teacher_logprobs`.
- The parser maps teacher token strings into ids using the environment tokenizer (`self.tokenizer`) and then aligns to student sequence length.
Because distillation is token-position based, keeping tokenizer families compatible is strongly recommended, especially for cross-model distillation.
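A hedged sketch of the parsing/alignment step described above. The real logic lives in `BaseEnv.get_teacher_logprobs`; the helper name, toy vocabulary, and padding values here are illustrative assumptions:

```python
# Sketch: map teacher top-k token strings to ids with a tokenizer-like vocab,
# then trim/pad to the student sequence length. `vocab` stands in for the
# environment tokenizer; unknown tokens fall back to id 0 here.
def align_teacher_topk(teacher_topk, student_len, top_k, vocab):
    """teacher_topk: list of {token_string: logprob} dicts, one per position."""
    ids, lps = [], []
    for pos in teacher_topk[:student_len]:  # trim to student length
        pairs = sorted(pos.items(), key=lambda kv: -kv[1])[:top_k]
        ids.append([vocab.get(tok, 0) for tok, _ in pairs])
        lps.append([lp for _, lp in pairs])
    while len(ids) < student_len:  # pad positions the teacher did not cover
        ids.append([0] * top_k)
        lps.append([float("-inf")] * top_k)
    return ids, lps

toy_vocab = {"the": 7, "a": 8, "cat": 9}
topk = [{"the": -0.2, "a": -1.8}, {"cat": -0.1, "the": -2.3}]
ids, lps = align_teacher_topk(topk, student_len=3, top_k=2, vocab=toy_vocab)
print(ids)  # [[7, 8], [9, 7], [0, 0]]
```

The failure mode for cross-model setups is visible here: if the teacher's token strings do not round-trip through the student-side vocabulary, the ids no longer refer to the tokens the teacher actually scored.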
### Minimal bring-up example
Run each command in a separate terminal.
1. **Start Atropos API**
```bash
run-api --port 8002
```
2. **Start teacher server (OpenAI-compatible endpoint)**
```bash
python -m vllm.entrypoints.openai.api_server \
--model "$TEACHER_MODEL" \
--host 0.0.0.0 \
--port 8003
```
3. **Start student server for environments (`/generate` endpoint)**
```bash
python -m example_trainer.vllm_api_server \
--model "$STUDENT_MODEL" \
--port 9001
```
4. **Start environment with distillation enabled**
```bash
python environments/gsm8k_server.py serve \
--env.rollout_server_url "http://localhost:8002" \
--env.distillation_enabled true \
--env.teacher_base_url "http://localhost:8003/v1" \
--env.teacher_model_name "$TEACHER_MODEL" \
--env.teacher_top_k 20 \
--openai.server_type vllm \
--openai.base_url "http://localhost:9001/v1" \
--openai.model_name "$STUDENT_MODEL" \
--openai.tokenizer_name "$STUDENT_MODEL"
```
5. **Start trainer with distillation flags**
```bash
python -m example_trainer.grpo \
--atropos-url "http://localhost:8002" \
--distillation-enabled \
--distillation-coef 0.1 \
--distillation-loss-type kl \
--distillation-temperature 1.0
```
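The example trainer combines GRPO with a distillation term. The sketch below shows one way a top-k KL term like `--distillation-loss-type kl` can be computed from the `distill_logprobs` arrays; the function name, the renormalization of the truncated teacher distribution, and the mean reduction are assumptions, not the trainer's actual implementation:

```python
import math

def topk_kl(teacher_logprobs, student_logprobs_at_topk):
    """Both args: [position][top_k] log-probs for one sequence.
    Renormalizes the teacher's truncated top-k distribution, then returns
    the mean KL(teacher || student) over positions, restricted to top-k."""
    total = 0.0
    for t_lps, s_lps in zip(teacher_logprobs, student_logprobs_at_topk):
        z = sum(math.exp(lp) for lp in t_lps)   # mass captured by top-k
        probs = [math.exp(lp) / z for lp in t_lps]
        total += sum(p * (math.log(p) - s) for p, s in zip(probs, s_lps))
    return total / len(teacher_logprobs)

teacher = [[-0.1, -2.5], [-0.3, -1.9]]   # made-up teacher top-2 logprobs
student = [[-0.2, -2.0], [-0.4, -1.5]]   # student logprobs at the same ids
print(round(topk_kl(teacher, student), 4))  # 0.0977
```

Because the student log-probs are gathered at the teacher's top-k ids without renormalization, this term stays non-negative; scaling it by `--distillation-coef` and adding it to the RL objective is the usual combination.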
### Verification checklist
- Environment logs show distillation fetch:
- `[DISTILL] Fetching teacher logprobs ...`
- `[DISTILL] Added teacher distill arrays ...`
- Teacher logs show completion/chat requests from the environment.
- API contains distill fields in latest example:
```bash
curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids!=null), has_lps:(.distill_logprobs!=null)}'
```
- Trainer logs report distillation metrics (example trainer):
- `Distill: loss=...`
### Important notes
- For `server_type=vllm`, the environment expects a server exposing `/generate` (the custom server in `example_trainer/vllm_api_server.py`), not only `/v1/chat/completions`.
- Prefer explicitly setting `--openai.tokenizer_name` to your student tokenizer to avoid prompt token-ID mismatch.
---
## Testing and Debugging Tools


@@ -1276,7 +1276,11 @@ class BaseEnv(ABC):
                 import traceback

                 logger.error(traceback.format_exc())
         else:
-            logger.debug(f"[DISTILL] Skipped - enabled={self.config.distillation_enabled}, url={self.config.teacher_base_url}")
+            logger.debug(
+                "[DISTILL] Skipped - enabled=%s, url=%s",
+                self.config.distillation_enabled,
+                self.config.teacher_base_url,
+            )

         data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]]
         # send single or list of scored data groups
@@ -1826,7 +1830,11 @@ class BaseEnv(ABC):
            )
            print(f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}")
        else:
-           print(f"[CLI DEBUG] Not merging: default_openai_config_ type={type(default_openai_config_)}, yaml_oai_config type={type(yaml_oai_config)}")
+           print(
+               "[CLI DEBUG] Not merging: default_openai_config_ "
+               f"type={type(default_openai_config_)}, "
+               f"yaml_oai_config type={type(yaml_oai_config)}"
+           )
            openai_config_dict = {}

        # 3. Server Manager Configuration (slurm, testing - not namespaced)


@@ -7,6 +7,7 @@ This wrapper maintains a tree structure of sequences, where:
 - Branching occurs organically from different contexts and n > 1 completions
 """

+import os
 import time
 import uuid
 import warnings
@@ -131,6 +132,10 @@ class ManagedServer:
             # Fallback for tokenizers without chat template
             return "\n".join([f"{m['role']}: {m['content']}" for m in messages])

+    def _debug_requests_enabled(self) -> bool:
+        """Enable verbose request construction logs with ATROPOS_DEBUG_REQUESTS=1."""
+        return os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
+
     def _find_extending_node(self, input_text: str) -> Optional[SequenceNode]:
         """
         Find a node that this input extends (default mode).
@@ -284,6 +289,19 @@ class ManagedServer:
         completion_kwargs = kwargs.copy()
         completion_kwargs["prompt"] = prompt
         completion_kwargs.pop("messages", None)
+        if self._debug_requests_enabled():
+            msg_count = len(messages)
+            prompt_preview = prompt.replace("\n", "\\n")[:600]
+            print(
+                f"[ATROPOS_REQ_DEBUG] chat_completion messages={msg_count} "
+                f"n={completion_kwargs.get('n')} max_tokens={completion_kwargs.get('max_tokens')} "
+                f"temperature={completion_kwargs.get('temperature')}",
+                flush=True,
+            )
+            print(
+                f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}",
+                flush=True,
+            )

         # Set model name if not provided
         if "model" not in completion_kwargs:


@@ -185,8 +185,14 @@ def resolve_openai_configs(
         and len(default_server_configs) >= 2
     )

-    print(f"[RESOLVE DEBUG] is_multi_server_yaml={is_multi_server_yaml}, is_multi_server_default={is_multi_server_default}")
-    print(f"[RESOLVE DEBUG] isinstance(default_server_configs, ServerBaseline) = {isinstance(default_server_configs, ServerBaseline)}")
+    print(
+        "[RESOLVE DEBUG] is_multi_server_yaml="
+        f"{is_multi_server_yaml}, is_multi_server_default={is_multi_server_default}"
+    )
+    print(
+        "[RESOLVE DEBUG] isinstance(default_server_configs, ServerBaseline) = "
+        f"{isinstance(default_server_configs, ServerBaseline)}"
+    )

     if (is_multi_server_yaml or is_multi_server_default) and openai_cli_config:
         raise FailedExecutionException(


@@ -2,6 +2,7 @@
 # see example_trainer/vllm_api_server.py for an example

 import asyncio
+import os
 import warnings

 import aiohttp
@@ -189,6 +190,30 @@ class VLLMServer(APIServer):
         # Prepare request for VLLM native API
         request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
         request_data.update(kwargs)
+        debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
+        if debug_requests:
+            base = self.config.base_url.replace("/v1", "")
+            prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace("\n", "\\n")
+            print(
+                f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate "
+                f"prompt_token_len={len(prompt_tokens)}",
+                flush=True,
+            )
+            print(
+                f"[ATROPOS_REQ_DEBUG] request_meta="
+                f"{{'n': {request_data.get('n')}, 'max_tokens': {request_data.get('max_tokens')}, "
+                f"'temperature': {request_data.get('temperature')}, 'top_p': {request_data.get('top_p')}}}",
+                flush=True,
+            )
+            print(
+                f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}",
+                flush=True,
+            )
+            print(
+                f"[ATROPOS_REQ_DEBUG] curl_base=curl -s -X POST {base}/generate "
+                '-H "Content-Type: application/json" -d \'<JSON_PAYLOAD>\'',
+                flush=True,
+            )

         # Make async request to VLLM /generate endpoint
         async with aiohttp.ClientSession() as session:


@@ -123,7 +123,11 @@ class GSM8kEnv(BaseEnv):
     async def rollout_and_score_eval(self, question: str, answer: str) -> dict:
         """Rollout and score evaluation with detailed sample data collection."""
-        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+        # Important: use ManagedServer's default tokenizer resolution so it uses
+        # the underlying inference server tokenizer (e.g., Qwen) instead of the
+        # environment tokenizer. Passing self.tokenizer here can cause token-ID
+        # mismatch and gibberish generations when model/tokenizer families differ.
+        async with self.server.managed_server() as managed:
             completion = await managed.chat_completion(
                 messages=[
                     {"role": "system", "content": system_prompt},
@@ -231,8 +235,16 @@ class GSM8kEnv(BaseEnv):
         gold_answer = (
             "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}"
         )
+        question_preview = item["question"].replace("\n", " ")[:120]
+        print(
+            f"[GSM8K_DEBUG] collect_start group_size={self.config.group_size} "
+            f"q={question_preview!r}",
+            flush=True,
+        )

-        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+        # Important: do not force env tokenizer into ManagedServer for rollout.
+        # Let ManagedServer use the server's tokenizer to keep prompt token IDs aligned.
+        async with self.server.managed_server() as managed:
             chat_completions = await managed.chat_completion(
                 messages=[{"role": "system", "content": system_prompt}, user_message],
@@ -243,10 +255,24 @@ class GSM8kEnv(BaseEnv):
         state = managed.get_state()
         nodes = state["nodes"]
+        print(
+            f"[GSM8K_DEBUG] completion_batch_received choices={len(chat_completions.choices)} "
+            f"nodes={len(nodes)}",
+            flush=True,
+        )

         to_score = list()
         to_backlog = list()
         for i, chat_completion in enumerate(chat_completions.choices):
+            response_text = chat_completion.message.content or ""
+            response_preview = response_text.replace("\n", " ")[:220]
+            valid_mask_count = sum(1 for m in nodes[i].masked_tokens if m != -100)
+            print(
+                f"[GSM8K_DEBUG] response_received idx={i} finish={chat_completion.finish_reason} "
+                f"tokens={len(nodes[i].tokens)} valid_masked={valid_mask_count} "
+                f"text={response_preview!r}",
+                flush=True,
+            )
             messages = (
                 {"role": "system", "content": system_prompt},
                 user_message,
@@ -263,6 +289,11 @@ class GSM8kEnv(BaseEnv):
                 }
             )
         to_postprocess = await self.score(to_score)
+        accepted = 0 if to_postprocess is None else len(to_postprocess.get("tokens", []))
+        print(
+            f"[GSM8K_DEBUG] collect_done accepted={accepted} submitted={len(to_score)}",
+            flush=True,
+        )
         return to_postprocess, to_backlog

     async def score(
@@ -278,10 +309,15 @@ class GSM8kEnv(BaseEnv):
             extraction_mode="first_match",
             extraction_config=[LatexExtractionConfig()],
         )
+        print(
+            f"[GSM8K_DEBUG] score_start candidates={len(rollout_group_data)} "
+            f"gold_parsed_len={len(gold_parsed)}",
+            flush=True,
+        )
         if len(gold_parsed) != 0:
             # We require the answer to be provided in correct latex (no malformed operators)
             random.shuffle(rollout_group_data)
-            for item in rollout_group_data:
+            for idx, item in enumerate(rollout_group_data):
                 # print(item[0][-1]["content"])
                 answer_parsed = parse(
                     item["messages"][-1]["content"].split("</think>")[-1],
@@ -310,7 +346,19 @@ class GSM8kEnv(BaseEnv):
                 logprobs = item["logprobs"]

                 # remove obviously bad examples
-                if len([1 for i in masks if i != -100]) < 10:
+                valid_mask_count = len([1 for i in masks if i != -100])
+                print(
+                    f"[GSM8K_DEBUG] score_candidate idx={idx} parsed_len={len(answer_parsed)} "
+                    f"reward={bool(reward)} valid_masked={valid_mask_count} "
+                    f"tokens={len(tokens)}",
+                    flush=True,
+                )
+                if valid_mask_count < 10:
+                    print(
+                        f"[GSM8K_DEBUG] drop_candidate idx={idx} reason=valid_masked_lt_10 "
+                        f"value={valid_mask_count}",
+                        flush=True,
+                    )
                     continue
                 scores["tokens"].append(tokens)
                 scores["masks"].append(masks)
@@ -323,6 +371,13 @@ class GSM8kEnv(BaseEnv):
             for score in scores["scores"]:
                 self.percent_correct_buffer.append(max(score, 0))
+            if len(scores["scores"]) == 0:
+                print(
+                    "[GSM8K_DEBUG] drop_group reason=no_valid_candidates_after_filtering",
+                    flush=True,
+                )
+                return None
             # check if all the same
             # print(scores['scores'])
             if all([score == 1 for score in scores["scores"]]):
@@ -330,6 +385,10 @@ class GSM8kEnv(BaseEnv):
                 token_lengths = [len(token) for token in scores["tokens"]]
                 if max(token_lengths) == 0:
                     # What? But don't want to crash a run so just in case...
+                    print(
+                        "[GSM8K_DEBUG] drop_group reason=zero_token_length_after_penalty_branch",
+                        flush=True,
+                    )
                     return None

                 # Get max allowed token length from config
@@ -353,10 +412,20 @@ class GSM8kEnv(BaseEnv):
                     # Apply linear penalty scaling from 1.0 down to 0.0
                     scores["scores"].append(1.0 - percentage_of_range)
             if all([scores["scores"][0] == score for score in scores["scores"]]):
+                print(
+                    f"[GSM8K_DEBUG] drop_group reason=all_scores_identical scores={scores['scores']}",
+                    flush=True,
+                )
                 return None  # If all the same, we return None
+            print(
+                f"[GSM8K_DEBUG] score_done accepted={len(scores['scores'])} "
+                f"scores={scores['scores']}",
+                flush=True,
+            )
             return scores
         else:
             # If the gold solution is not parseable, we return None
+            print("[GSM8K_DEBUG] drop_group reason=gold_unparseable", flush=True)
             return None

     async def get_next_item(self) -> GSM8kRow:


@@ -6,17 +6,12 @@ Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero
 import asyncio
 import random
 import re
-import logging
 from concurrent.futures import ProcessPoolExecutor
 from typing import Dict, List, Optional, Tuple

-import aiohttp
 import wandb
 from datasets import load_dataset

-# Set up logging for debug
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
 from latex2sympy2_extended import NormalizationConfig
 from math_verify import LatexExtractionConfig, parse, verify
 from math_verify.errors import TimeoutException
@@ -129,8 +124,6 @@ class MathEnv(BaseEnv):
         slurm=True,
         testing=False,
     ):
-        print("Initializing MathEnv")
-        print(f"Slurm: {slurm}, Testing: {testing}")
         super().__init__(config, server_configs, slurm, testing)
         self.percent_correct_buffer = list()
         self.eval_metrics = list()
@@ -142,17 +135,6 @@ class MathEnv(BaseEnv):
         self.pass_at_groupsize = list()
         self.iter = 0

-        # Debug: Print distillation config
-        print("=" * 60)
-        print("[MATH_DEBUG] DISTILLATION CONFIGURATION:")
-        print(f"[MATH_DEBUG] distillation_enabled = {config.distillation_enabled}")
-        print(f"[MATH_DEBUG] teacher_base_url = {config.teacher_base_url}")
-        print(f"[MATH_DEBUG] teacher_model_name = {getattr(config, 'teacher_model_name', 'N/A')}")
-        print(f"[MATH_DEBUG] teacher_top_k = {getattr(config, 'teacher_top_k', 'N/A')}")
-        print(f"[MATH_DEBUG] teacher_prefix_text set = {bool(getattr(config, 'teacher_prefix_text', None))}")
-        print(f"[MATH_DEBUG] teacher_system_prompt set = {bool(getattr(config, 'teacher_system_prompt', None))}")
-        print("=" * 60)
-
     @classmethod
     def config_init(cls) -> Tuple[RSConfig, ServerBaseline]:
         env_config = RSConfig(
@@ -269,86 +251,8 @@ class MathEnv(BaseEnv):
                     name,
                 )
             )
-        # Debug: Test teacher connectivity if distillation is enabled
-        if self.config.distillation_enabled and self.config.teacher_base_url:
-            await self._test_teacher_connectivity()
         return

-    async def _test_teacher_connectivity(self):
-        """Test if the teacher model API is reachable."""
-        print("=" * 60)
-        print("[MATH_DEBUG] TESTING TEACHER CONNECTIVITY...")
-        print(f"[MATH_DEBUG] Teacher URL: {self.config.teacher_base_url}")
-        print(f"[MATH_DEBUG] Teacher Model: {getattr(self.config, 'teacher_model_name', 'default')}")
-        try:
-            async with aiohttp.ClientSession() as session:
-                # Test 1: Health check
-                health_url = self.config.teacher_base_url.replace("/v1", "") + "/health"
-                print(f"[MATH_DEBUG] Testing health endpoint: {health_url}")
-                try:
-                    async with session.get(health_url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
-                        print(f"[MATH_DEBUG] Health check status: {resp.status}")
-                        if resp.status == 200:
-                            print("[MATH_DEBUG] ✓ Teacher health check PASSED")
-                        else:
-                            print(f"[MATH_DEBUG] ✗ Teacher health check FAILED: {await resp.text()}")
-                except Exception as e:
-                    print(f"[MATH_DEBUG] ✗ Teacher health check ERROR: {e}")
-
-                # Test 2: Models endpoint
-                models_url = f"{self.config.teacher_base_url}/models"
-                print(f"[MATH_DEBUG] Testing models endpoint: {models_url}")
-                try:
-                    async with session.get(models_url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
-                        print(f"[MATH_DEBUG] Models endpoint status: {resp.status}")
-                        if resp.status == 200:
-                            data = await resp.json()
-                            models = [m.get("id", m) for m in data.get("data", [])]
-                            print(f"[MATH_DEBUG] ✓ Available models: {models}")
-                        else:
-                            print(f"[MATH_DEBUG] ✗ Models endpoint FAILED: {await resp.text()}")
-                except Exception as e:
-                    print(f"[MATH_DEBUG] ✗ Models endpoint ERROR: {e}")
-
-                # Test 3: Simple completion test
-                completions_url = f"{self.config.teacher_base_url}/completions"
-                teacher_model = getattr(self.config, 'teacher_model_name', 'default')
-                test_payload = {
-                    "model": teacher_model,
-                    "prompt": "Hello",
-                    "max_tokens": 5,
-                    "logprobs": 5,
-                    "echo": True,
-                }
-                print(f"[MATH_DEBUG] Testing completions endpoint: {completions_url}")
-                print(f"[MATH_DEBUG] Test payload: {test_payload}")
-                try:
-                    async with session.post(
-                        completions_url,
-                        json=test_payload,
-                        headers={"Content-Type": "application/json"},
-                        timeout=aiohttp.ClientTimeout(total=30),
-                    ) as resp:
-                        print(f"[MATH_DEBUG] Completions status: {resp.status}")
-                        resp_text = await resp.text()
-                        if resp.status == 200:
-                            print("[MATH_DEBUG] ✓ Teacher completions WORKING!")
-                            print(f"[MATH_DEBUG] Response preview: {resp_text[:500]}")
-                        else:
-                            print(f"[MATH_DEBUG] ✗ Teacher completions FAILED: {resp_text[:500]}")
-                except Exception as e:
-                    print(f"[MATH_DEBUG] ✗ Teacher completions ERROR: {e}")
-        except Exception as e:
-            print(f"[MATH_DEBUG] ✗ Teacher connectivity test FAILED: {e}")
-            import traceback
-            traceback.print_exc()
-        print("=" * 60)
-
     async def rollout_and_score_eval(self, question, answer, subset):
         async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
             completion = await managed.completion(
@@ -492,7 +396,6 @@ class MathEnv(BaseEnv):
             )
             if len(self.normal_rollouts) > self.config.num_rollouts_to_keep:
                 self.normal_rollouts.pop(0)
-        print(f"Collected {len(to_postprocess['scores'])} trajectories")
         return to_postprocess, to_backlog

     async def score(
@@ -578,67 +481,8 @@ class MathEnv(BaseEnv):
                 ]
             )

-        # Debug: Log scored group creation
-        print(f"[MATH_DEBUG] Created ScoredDataGroup with {len(scores['tokens'])} sequences")
-        print(f"[MATH_DEBUG] Scores: {scores['scores']}")
-        print(f"[MATH_DEBUG] Token lengths: {[len(t) for t in scores['tokens']]}")
-        has_new_distill = (
-            "distill_token_ids" in scores and "distill_logprobs" in scores
-        )
-        print(f"[MATH_DEBUG] Has distill arrays: {has_new_distill}")
-
         return scores

-    async def handle_send_to_api(
-        self,
-        scored_data,
-        item=None,
-        do_send_to_api: bool = True,
-        abort_on_any_max_length_exceeded: bool = True,
-    ):
-        """Override to add debugging for distillation."""
-        print(f"[MATH_DEBUG] handle_send_to_api called")
-        print(f"[MATH_DEBUG] distillation_enabled: {self.config.distillation_enabled}")
-        print(f"[MATH_DEBUG] teacher_base_url: {self.config.teacher_base_url}")
-        if isinstance(scored_data, list):
-            for i, group in enumerate(scored_data):
-                if group:
-                    has_distill = (
-                        group.get("distill_token_ids") is not None
-                        and group.get("distill_logprobs") is not None
-                    )
-                    print(f"[MATH_DEBUG] Group {i}: {len(group.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
-        elif scored_data:
-            has_distill = (
-                scored_data.get("distill_token_ids") is not None
-                and scored_data.get("distill_logprobs") is not None
-            )
-            print(f"[MATH_DEBUG] Single group: {len(scored_data.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
-
-        # Call parent implementation which does the actual distillation fetch
-        result = await super().handle_send_to_api(
-            scored_data, item, do_send_to_api, abort_on_any_max_length_exceeded
-        )
-
-        # Debug: Check if distillation was added after parent call
-        if isinstance(scored_data, list):
-            for i, group in enumerate(scored_data):
-                if group:
-                    has_distill = (
-                        group.get("distill_token_ids") is not None
-                        and group.get("distill_logprobs") is not None
-                    )
-                    print(f"[MATH_DEBUG] AFTER: Group {i} has_distill_logprobs={has_distill}")
-        elif scored_data:
-            has_distill = (
-                scored_data.get("distill_token_ids") is not None
-                and scored_data.get("distill_logprobs") is not None
-            )
-            print(f"[MATH_DEBUG] AFTER: Single group has_distill_logprobs={has_distill}")
-
-        return result
-
     async def get_next_item(self):
         while True:
             next_item = self.train[self.iter % len(self.train)]
@@ -652,10 +496,7 @@ class MathEnv(BaseEnv):
                 )
                 break
             except TypeError:
-                print(
-                    f"Error in getting next item, trying again, "
-                    f"data: {next_item['question']} -> {next_item['final_answer']}"
-                )
+                continue

         return (prompt, answer, "normal")