on policy clean up

This commit is contained in:
Jai Suphavadeeprasit 2026-02-20 00:35:29 -05:00
parent 79e392c446
commit 1c90fc71b0
7 changed files with 262 additions and 168 deletions

127
README.md
View file

@ -256,6 +256,133 @@ Atropos repo contains an example trainer that should primarily be used as a refe
To use the example trainer, see this page: [training example guide](example_trainer/README.md)
## On-Policy Distillation (Environment + API Flow)
Atropos supports on-policy distillation by fetching teacher top-k token distributions for the same trajectories that are used for RL, then attaching that teacher data to the batches consumed by your trainer.
### How the flow works
1. **Student rollouts are generated by the environment**
- Example: `environments/gsm8k_server.py` samples `group_size` completions from your student inference server.
2. **Environment scores and validates the group**
- Normal Atropos filtering still applies (group structure, optional score-equality checks, max token checks, etc.).
3. **Teacher logprobs are fetched in `BaseEnv.handle_send_to_api`**
- If `distillation_enabled=true` and `teacher_base_url` is set, Atropos calls the teacher endpoint and builds:
- `distill_token_ids` with shape `[sequence][position][top_k]`
- `distill_logprobs` with shape `[sequence][position][top_k]`
4. **Distillation arrays are attached to each scored group**
- Added as `group["distill_token_ids"]` and `group["distill_logprobs"]`.
5. **Atropos API stores and serves these fields unchanged**
- `/scored_data` and `/batch` include the distillation arrays.
6. **Trainer consumes both RL and distillation signals**
- Example trainer computes GRPO + distillation loss from the same batch.
### Configuration knobs in environments
Distillation is configured in `BaseEnvConfig` and available via CLI under `--env.*`:
- `--env.distillation_enabled true`
- `--env.teacher_base_url http://localhost:8003/v1`
- `--env.teacher_model_name <teacher_model>`
- `--env.teacher_api_key <optional>` (or `TEACHER_API_KEY`)
- `--env.teacher_top_k 20`
- Optional steering controls:
- `--env.teacher_prefix_text "..."`
- `--env.teacher_system_prompt "..."`
### Self-distillation vs cross-model distillation
Both setups are supported:
- **Self-distillation (same model family for teacher and student)**
Point `teacher_base_url` to a server running the same model (or equivalent checkpoint family) as the student. This is the most stable setup for token-level alignment.
- **Cross-model distillation (different teacher and student models)**
Also supported, but tokenization compatibility becomes more important. If token vocabularies/template behavior differ significantly, alignment quality may degrade.
In practice, self-distillation is usually easiest to bring up first, then cross-model can be layered in once your pipeline is stable.
### Tokenization and alignment details
Atropos handles tokenization in two places:
1. **Student rollout path (`server_type=vllm`)**
- The `/generate` request is built via the vLLM server handler and uses the server-side tokenizer configured by:
- `--openai.tokenizer_name` (or falls back to `--openai.model_name`)
- Recommendation: set `--openai.tokenizer_name` explicitly to match the student serving model.
2. **Teacher top-k parsing path**
- Teacher responses are parsed into token ids/logprobs in `BaseEnv.get_teacher_logprobs`.
- The parser maps teacher token strings into ids using the environment tokenizer (`self.tokenizer`) and then aligns to student sequence length.
Because distillation is token-position based, keeping tokenizer families compatible is strongly recommended, especially for cross-model distillation.
### Minimal bring-up example
Run each command in a separate terminal.
1. **Start Atropos API**
```bash
run-api --port 8002
```
2. **Start teacher server (OpenAI-compatible endpoint)**
```bash
python -m vllm.entrypoints.openai.api_server \
--model "$TEACHER_MODEL" \
--host 0.0.0.0 \
--port 8003
```
3. **Start student server for environments (`/generate` endpoint)**
```bash
python -m example_trainer.vllm_api_server \
--model "$STUDENT_MODEL" \
--port 9001
```
4. **Start environment with distillation enabled**
```bash
python environments/gsm8k_server.py serve \
--env.rollout_server_url "http://localhost:8002" \
--env.distillation_enabled true \
--env.teacher_base_url "http://localhost:8003/v1" \
--env.teacher_model_name "$TEACHER_MODEL" \
--env.teacher_top_k 20 \
--openai.server_type vllm \
--openai.base_url "http://localhost:9001/v1" \
--openai.model_name "$STUDENT_MODEL" \
--openai.tokenizer_name "$STUDENT_MODEL"
```
5. **Start trainer with distillation flags**
```bash
python -m example_trainer.grpo \
--atropos-url "http://localhost:8002" \
--distillation-enabled \
--distillation-coef 0.1 \
--distillation-loss-type kl \
--distillation-temperature 1.0
```
### Verification checklist
- Environment logs show distillation fetch:
- `[DISTILL] Fetching teacher logprobs ...`
- `[DISTILL] Added teacher distill arrays ...`
- Teacher logs show completion/chat requests from the environment.
- API contains distill fields in latest example:
```bash
curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids!=null), has_lps:(.distill_logprobs!=null)}'
```
- Trainer logs report distillation metrics (example trainer):
- `Distill: loss=...`
### Important notes
- For `server_type=vllm`, the environment expects a server exposing `/generate` (the custom server in `example_trainer/vllm_api_server.py`), not only `/v1/chat/completions`.
- Prefer explicitly setting `--openai.tokenizer_name` to your student tokenizer to avoid prompt token-ID mismatch.
---
## Testing and Debugging Tools

View file

@ -1276,7 +1276,11 @@ class BaseEnv(ABC):
import traceback
logger.error(traceback.format_exc())
else:
logger.debug(f"[DISTILL] Skipped - enabled={self.config.distillation_enabled}, url={self.config.teacher_base_url}")
logger.debug(
"[DISTILL] Skipped - enabled=%s, url=%s",
self.config.distillation_enabled,
self.config.teacher_base_url,
)
data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]]
# send single or list of scored data groups
@ -1826,7 +1830,11 @@ class BaseEnv(ABC):
)
print(f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}")
else:
print(f"[CLI DEBUG] Not merging: default_openai_config_ type={type(default_openai_config_)}, yaml_oai_config type={type(yaml_oai_config)}")
print(
"[CLI DEBUG] Not merging: default_openai_config_ "
f"type={type(default_openai_config_)}, "
f"yaml_oai_config type={type(yaml_oai_config)}"
)
openai_config_dict = {}
# 3. Server Manager Configuration (slurm, testing - not namespaced)

View file

@ -7,6 +7,7 @@ This wrapper maintains a tree structure of sequences, where:
- Branching occurs organically from different contexts and n > 1 completions
"""
import os
import time
import uuid
import warnings
@ -131,6 +132,10 @@ class ManagedServer:
# Fallback for tokenizers without chat template
return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
def _debug_requests_enabled(self) -> bool:
"""Enable verbose request construction logs with ATROPOS_DEBUG_REQUESTS=1."""
return os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
def _find_extending_node(self, input_text: str) -> Optional[SequenceNode]:
"""
Find a node that this input extends (default mode).
@ -284,6 +289,19 @@ class ManagedServer:
completion_kwargs = kwargs.copy()
completion_kwargs["prompt"] = prompt
completion_kwargs.pop("messages", None)
if self._debug_requests_enabled():
msg_count = len(messages)
prompt_preview = prompt.replace("\n", "\\n")[:600]
print(
f"[ATROPOS_REQ_DEBUG] chat_completion messages={msg_count} "
f"n={completion_kwargs.get('n')} max_tokens={completion_kwargs.get('max_tokens')} "
f"temperature={completion_kwargs.get('temperature')}",
flush=True,
)
print(
f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}",
flush=True,
)
# Set model name if not provided
if "model" not in completion_kwargs:

View file

@ -185,8 +185,14 @@ def resolve_openai_configs(
and len(default_server_configs) >= 2
)
print(f"[RESOLVE DEBUG] is_multi_server_yaml={is_multi_server_yaml}, is_multi_server_default={is_multi_server_default}")
print(f"[RESOLVE DEBUG] isinstance(default_server_configs, ServerBaseline) = {isinstance(default_server_configs, ServerBaseline)}")
print(
"[RESOLVE DEBUG] is_multi_server_yaml="
f"{is_multi_server_yaml}, is_multi_server_default={is_multi_server_default}"
)
print(
"[RESOLVE DEBUG] isinstance(default_server_configs, ServerBaseline) = "
f"{isinstance(default_server_configs, ServerBaseline)}"
)
if (is_multi_server_yaml or is_multi_server_default) and openai_cli_config:
raise FailedExecutionException(

View file

@ -2,6 +2,7 @@
# see example_trainer/vllm_api_server.py for an example
import asyncio
import os
import warnings
import aiohttp
@ -189,6 +190,30 @@ class VLLMServer(APIServer):
# Prepare request for VLLM native API
request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
request_data.update(kwargs)
debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
if debug_requests:
base = self.config.base_url.replace("/v1", "")
prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace("\n", "\\n")
print(
f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate "
f"prompt_token_len={len(prompt_tokens)}",
flush=True,
)
print(
f"[ATROPOS_REQ_DEBUG] request_meta="
f"{{'n': {request_data.get('n')}, 'max_tokens': {request_data.get('max_tokens')}, "
f"'temperature': {request_data.get('temperature')}, 'top_p': {request_data.get('top_p')}}}",
flush=True,
)
print(
f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}",
flush=True,
)
print(
f"[ATROPOS_REQ_DEBUG] curl_base=curl -s -X POST {base}/generate "
'-H "Content-Type: application/json" -d \'<JSON_PAYLOAD>\'',
flush=True,
)
# Make async request to VLLM /generate endpoint
async with aiohttp.ClientSession() as session:

View file

@ -123,7 +123,11 @@ class GSM8kEnv(BaseEnv):
async def rollout_and_score_eval(self, question: str, answer: str) -> dict:
"""Rollout and score evaluation with detailed sample data collection."""
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
# Important: use ManagedServer's default tokenizer resolution so it uses
# the underlying inference server tokenizer (e.g., Qwen) instead of the
# environment tokenizer. Passing self.tokenizer here can cause token-ID
# mismatch and gibberish generations when model/tokenizer families differ.
async with self.server.managed_server() as managed:
completion = await managed.chat_completion(
messages=[
{"role": "system", "content": system_prompt},
@ -231,8 +235,16 @@ class GSM8kEnv(BaseEnv):
gold_answer = (
"\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}"
)
question_preview = item["question"].replace("\n", " ")[:120]
print(
f"[GSM8K_DEBUG] collect_start group_size={self.config.group_size} "
f"q={question_preview!r}",
flush=True,
)
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
# Important: do not force env tokenizer into ManagedServer for rollout.
# Let ManagedServer use the server's tokenizer to keep prompt token IDs aligned.
async with self.server.managed_server() as managed:
chat_completions = await managed.chat_completion(
messages=[{"role": "system", "content": system_prompt}, user_message],
@ -243,10 +255,24 @@ class GSM8kEnv(BaseEnv):
state = managed.get_state()
nodes = state["nodes"]
print(
f"[GSM8K_DEBUG] completion_batch_received choices={len(chat_completions.choices)} "
f"nodes={len(nodes)}",
flush=True,
)
to_score = list()
to_backlog = list()
for i, chat_completion in enumerate(chat_completions.choices):
response_text = chat_completion.message.content or ""
response_preview = response_text.replace("\n", " ")[:220]
valid_mask_count = sum(1 for m in nodes[i].masked_tokens if m != -100)
print(
f"[GSM8K_DEBUG] response_received idx={i} finish={chat_completion.finish_reason} "
f"tokens={len(nodes[i].tokens)} valid_masked={valid_mask_count} "
f"text={response_preview!r}",
flush=True,
)
messages = (
{"role": "system", "content": system_prompt},
user_message,
@ -263,6 +289,11 @@ class GSM8kEnv(BaseEnv):
}
)
to_postprocess = await self.score(to_score)
accepted = 0 if to_postprocess is None else len(to_postprocess.get("tokens", []))
print(
f"[GSM8K_DEBUG] collect_done accepted={accepted} submitted={len(to_score)}",
flush=True,
)
return to_postprocess, to_backlog
async def score(
@ -278,10 +309,15 @@ class GSM8kEnv(BaseEnv):
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
print(
f"[GSM8K_DEBUG] score_start candidates={len(rollout_group_data)} "
f"gold_parsed_len={len(gold_parsed)}",
flush=True,
)
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
random.shuffle(rollout_group_data)
for item in rollout_group_data:
for idx, item in enumerate(rollout_group_data):
# print(item[0][-1]["content"])
answer_parsed = parse(
item["messages"][-1]["content"].split("</think>")[-1],
@ -310,7 +346,19 @@ class GSM8kEnv(BaseEnv):
logprobs = item["logprobs"]
# remove obviously bad examples
if len([1 for i in masks if i != -100]) < 10:
valid_mask_count = len([1 for i in masks if i != -100])
print(
f"[GSM8K_DEBUG] score_candidate idx={idx} parsed_len={len(answer_parsed)} "
f"reward={bool(reward)} valid_masked={valid_mask_count} "
f"tokens={len(tokens)}",
flush=True,
)
if valid_mask_count < 10:
print(
f"[GSM8K_DEBUG] drop_candidate idx={idx} reason=valid_masked_lt_10 "
f"value={valid_mask_count}",
flush=True,
)
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
@ -323,6 +371,13 @@ class GSM8kEnv(BaseEnv):
for score in scores["scores"]:
self.percent_correct_buffer.append(max(score, 0))
if len(scores["scores"]) == 0:
print(
"[GSM8K_DEBUG] drop_group reason=no_valid_candidates_after_filtering",
flush=True,
)
return None
# check if all the same
# print(scores['scores'])
if all([score == 1 for score in scores["scores"]]):
@ -330,6 +385,10 @@ class GSM8kEnv(BaseEnv):
token_lengths = [len(token) for token in scores["tokens"]]
if max(token_lengths) == 0:
# What? But don't want to crash a run so just in case...
print(
"[GSM8K_DEBUG] drop_group reason=zero_token_length_after_penalty_branch",
flush=True,
)
return None
# Get max allowed token length from config
@ -353,10 +412,20 @@ class GSM8kEnv(BaseEnv):
# Apply linear penalty scaling from 1.0 down to 0.0
scores["scores"].append(1.0 - percentage_of_range)
if all([scores["scores"][0] == score for score in scores["scores"]]):
print(
f"[GSM8K_DEBUG] drop_group reason=all_scores_identical scores={scores['scores']}",
flush=True,
)
return None # If all the same, we return None
print(
f"[GSM8K_DEBUG] score_done accepted={len(scores['scores'])} "
f"scores={scores['scores']}",
flush=True,
)
return scores
else:
# If the gold solution is not parseable, we return None
print("[GSM8K_DEBUG] drop_group reason=gold_unparseable", flush=True)
return None
async def get_next_item(self) -> GSM8kRow:

View file

@ -6,17 +6,12 @@ Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero
import asyncio
import random
import re
import logging
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, List, Optional, Tuple
import aiohttp
import wandb
from datasets import load_dataset
# Set up logging for debug
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from math_verify.errors import TimeoutException
@ -129,8 +124,6 @@ class MathEnv(BaseEnv):
slurm=True,
testing=False,
):
print("Initializing MathEnv")
print(f"Slurm: {slurm}, Testing: {testing}")
super().__init__(config, server_configs, slurm, testing)
self.percent_correct_buffer = list()
self.eval_metrics = list()
@ -142,17 +135,6 @@ class MathEnv(BaseEnv):
self.pass_at_groupsize = list()
self.iter = 0
# Debug: Print distillation config
print("=" * 60)
print("[MATH_DEBUG] DISTILLATION CONFIGURATION:")
print(f"[MATH_DEBUG] distillation_enabled = {config.distillation_enabled}")
print(f"[MATH_DEBUG] teacher_base_url = {config.teacher_base_url}")
print(f"[MATH_DEBUG] teacher_model_name = {getattr(config, 'teacher_model_name', 'N/A')}")
print(f"[MATH_DEBUG] teacher_top_k = {getattr(config, 'teacher_top_k', 'N/A')}")
print(f"[MATH_DEBUG] teacher_prefix_text set = {bool(getattr(config, 'teacher_prefix_text', None))}")
print(f"[MATH_DEBUG] teacher_system_prompt set = {bool(getattr(config, 'teacher_system_prompt', None))}")
print("=" * 60)
@classmethod
def config_init(cls) -> Tuple[RSConfig, ServerBaseline]:
env_config = RSConfig(
@ -269,86 +251,8 @@ class MathEnv(BaseEnv):
name,
)
)
# Debug: Test teacher connectivity if distillation is enabled
if self.config.distillation_enabled and self.config.teacher_base_url:
await self._test_teacher_connectivity()
return
async def _test_teacher_connectivity(self):
"""Test if the teacher model API is reachable."""
print("=" * 60)
print("[MATH_DEBUG] TESTING TEACHER CONNECTIVITY...")
print(f"[MATH_DEBUG] Teacher URL: {self.config.teacher_base_url}")
print(f"[MATH_DEBUG] Teacher Model: {getattr(self.config, 'teacher_model_name', 'default')}")
try:
async with aiohttp.ClientSession() as session:
# Test 1: Health check
health_url = self.config.teacher_base_url.replace("/v1", "") + "/health"
print(f"[MATH_DEBUG] Testing health endpoint: {health_url}")
try:
async with session.get(health_url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
print(f"[MATH_DEBUG] Health check status: {resp.status}")
if resp.status == 200:
print("[MATH_DEBUG] ✓ Teacher health check PASSED")
else:
print(f"[MATH_DEBUG] ✗ Teacher health check FAILED: {await resp.text()}")
except Exception as e:
print(f"[MATH_DEBUG] ✗ Teacher health check ERROR: {e}")
# Test 2: Models endpoint
models_url = f"{self.config.teacher_base_url}/models"
print(f"[MATH_DEBUG] Testing models endpoint: {models_url}")
try:
async with session.get(models_url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
print(f"[MATH_DEBUG] Models endpoint status: {resp.status}")
if resp.status == 200:
data = await resp.json()
models = [m.get("id", m) for m in data.get("data", [])]
print(f"[MATH_DEBUG] ✓ Available models: {models}")
else:
print(f"[MATH_DEBUG] ✗ Models endpoint FAILED: {await resp.text()}")
except Exception as e:
print(f"[MATH_DEBUG] ✗ Models endpoint ERROR: {e}")
# Test 3: Simple completion test
completions_url = f"{self.config.teacher_base_url}/completions"
teacher_model = getattr(self.config, 'teacher_model_name', 'default')
test_payload = {
"model": teacher_model,
"prompt": "Hello",
"max_tokens": 5,
"logprobs": 5,
"echo": True,
}
print(f"[MATH_DEBUG] Testing completions endpoint: {completions_url}")
print(f"[MATH_DEBUG] Test payload: {test_payload}")
try:
async with session.post(
completions_url,
json=test_payload,
headers={"Content-Type": "application/json"},
timeout=aiohttp.ClientTimeout(total=30),
) as resp:
print(f"[MATH_DEBUG] Completions status: {resp.status}")
resp_text = await resp.text()
if resp.status == 200:
print(f"[MATH_DEBUG] ✓ Teacher completions WORKING!")
print(f"[MATH_DEBUG] Response preview: {resp_text[:500]}")
else:
print(f"[MATH_DEBUG] ✗ Teacher completions FAILED: {resp_text[:500]}")
except Exception as e:
print(f"[MATH_DEBUG] ✗ Teacher completions ERROR: {e}")
except Exception as e:
print(f"[MATH_DEBUG] ✗ Teacher connectivity test FAILED: {e}")
import traceback
traceback.print_exc()
print("=" * 60)
async def rollout_and_score_eval(self, question, answer, subset):
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
completion = await managed.completion(
@ -492,7 +396,6 @@ class MathEnv(BaseEnv):
)
if len(self.normal_rollouts) > self.config.num_rollouts_to_keep:
self.normal_rollouts.pop(0)
print(f"Collected {len(to_postprocess['scores'])} trajectories")
return to_postprocess, to_backlog
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
@ -578,67 +481,8 @@ class MathEnv(BaseEnv):
]
)
# Debug: Log scored group creation
print(f"[MATH_DEBUG] Created ScoredDataGroup with {len(scores['tokens'])} sequences")
print(f"[MATH_DEBUG] Scores: {scores['scores']}")
print(f"[MATH_DEBUG] Token lengths: {[len(t) for t in scores['tokens']]}")
has_new_distill = (
"distill_token_ids" in scores and "distill_logprobs" in scores
)
print(f"[MATH_DEBUG] Has distill arrays: {has_new_distill}")
return scores
async def handle_send_to_api(
self,
scored_data,
item=None,
do_send_to_api: bool = True,
abort_on_any_max_length_exceeded: bool = True,
):
"""Override to add debugging for distillation."""
print(f"[MATH_DEBUG] handle_send_to_api called")
print(f"[MATH_DEBUG] distillation_enabled: {self.config.distillation_enabled}")
print(f"[MATH_DEBUG] teacher_base_url: {self.config.teacher_base_url}")
if isinstance(scored_data, list):
for i, group in enumerate(scored_data):
if group:
has_distill = (
group.get("distill_token_ids") is not None
and group.get("distill_logprobs") is not None
)
print(f"[MATH_DEBUG] Group {i}: {len(group.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
elif scored_data:
has_distill = (
scored_data.get("distill_token_ids") is not None
and scored_data.get("distill_logprobs") is not None
)
print(f"[MATH_DEBUG] Single group: {len(scored_data.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
# Call parent implementation which does the actual distillation fetch
result = await super().handle_send_to_api(
scored_data, item, do_send_to_api, abort_on_any_max_length_exceeded
)
# Debug: Check if distillation was added after parent call
if isinstance(scored_data, list):
for i, group in enumerate(scored_data):
if group:
has_distill = (
group.get("distill_token_ids") is not None
and group.get("distill_logprobs") is not None
)
print(f"[MATH_DEBUG] AFTER: Group {i} has_distill_logprobs={has_distill}")
elif scored_data:
has_distill = (
scored_data.get("distill_token_ids") is not None
and scored_data.get("distill_logprobs") is not None
)
print(f"[MATH_DEBUG] AFTER: Single group has_distill_logprobs={has_distill}")
return result
async def get_next_item(self):
while True:
next_item = self.train[self.iter % len(self.train)]
@ -652,10 +496,7 @@ class MathEnv(BaseEnv):
)
break
except TypeError:
print(
f"Error in getting next item, trying again, "
f"data: {next_item['question']} -> {next_item['final_answer']}"
)
continue
return (prompt, answer, "normal")