Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)
Commit 1c90fc71b0 (parent 79e392c446): "on policy clean up"
7 changed files with 262 additions and 168 deletions
README.md | 127

@@ -256,6 +256,133 @@ Atropos repo contains an example trainer that should primarily be used as a refe

To use the example trainer, see this page: [training example guide](example_trainer/README.md)
## On-Policy Distillation (Environment + API Flow)

Atropos supports on-policy distillation by fetching teacher top-k token distributions for the same trajectories that are used for RL, then attaching that teacher data to the batches consumed by your trainer.

### How the flow works

1. **Student rollouts are generated by the environment**
   - Example: `environments/gsm8k_server.py` samples `group_size` completions from your student inference server.
2. **Environment scores and validates the group**
   - Normal Atropos filtering still applies (group structure, optional score-equality checks, max-token checks, etc.).
3. **Teacher logprobs are fetched in `BaseEnv.handle_send_to_api`**
   - If `distillation_enabled=true` and `teacher_base_url` is set, Atropos calls the teacher endpoint and builds:
     - `distill_token_ids` with shape `[sequence][position][top_k]`
     - `distill_logprobs` with shape `[sequence][position][top_k]`
4. **Distillation arrays are attached to each scored group**
   - Added as `group["distill_token_ids"]` and `group["distill_logprobs"]`.
5. **Atropos API stores and serves these fields unchanged**
   - `/scored_data` and `/batch` include the distillation arrays.
6. **Trainer consumes both RL and distillation signals**
   - The example trainer computes the GRPO loss and the distillation loss from the same batch.
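The `[sequence][position][top_k]` teacher arrays above can be turned into a distillation loss along these lines — a minimal pure-Python sketch of forward KL over the teacher's renormalized top-k mass for a single sequence. The function name, the per-position student-logprob lookup, and the `-20.0` floor for tokens the student did not score are illustrative assumptions, not the example trainer's actual API:

```python
import math


def topk_kl(distill_token_ids, distill_logprobs, student_logprobs):
    """Mean per-position forward KL over the teacher's top-k tokens.

    distill_token_ids / distill_logprobs follow the [position][top_k]
    layout described above (one sequence). student_logprobs maps each
    position to {token_id: logprob} under the student. The teacher's
    top-k mass is renormalized so it sums to 1 at each position.
    """
    kls = []
    for pos, (ids, lps) in enumerate(zip(distill_token_ids, distill_logprobs)):
        # Log of the total top-k probability mass, used for renormalization.
        log_z = math.log(sum(math.exp(lp) for lp in lps))
        kl = 0.0
        for tok, t_lp in zip(ids, lps):
            p = math.exp(t_lp - log_z)
            q_lp = student_logprobs[pos].get(tok, -20.0)  # floor for missing ids
            kl += p * ((t_lp - log_z) - q_lp)
        kls.append(kl)
    return sum(kls) / max(len(kls), 1)
```

The KL is zero when the student matches the teacher on the top-k tokens and grows as the distributions diverge; a real trainer would compute this with tensor ops over the whole batch.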
### Configuration knobs in environments

Distillation is configured in `BaseEnvConfig` and available via the CLI under `--env.*`:

- `--env.distillation_enabled true`
- `--env.teacher_base_url http://localhost:8003/v1`
- `--env.teacher_model_name <teacher_model>`
- `--env.teacher_api_key <optional>` (or the `TEACHER_API_KEY` environment variable)
- `--env.teacher_top_k 20`
- Optional steering controls:
  - `--env.teacher_prefix_text "..."`
  - `--env.teacher_system_prompt "..."`

### Self-distillation vs. cross-model distillation

Both setups are supported:

- **Self-distillation (same model family for teacher and student)**

  Point `teacher_base_url` to a server running the same model (or an equivalent checkpoint family) as the student. This is the most stable setup for token-level alignment.

- **Cross-model distillation (different teacher and student models)**

  Also supported, but tokenization compatibility becomes more important. If token vocabularies or chat-template behavior differ significantly, alignment quality may degrade.

In practice, self-distillation is usually the easiest to bring up first; cross-model distillation can be layered in once your pipeline is stable.
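Before committing to cross-model distillation, it can help to measure how often two tokenizers agree position-by-position on representative text. A toy sketch — the `encode_*` callables stand in for real tokenizer `encode` methods, and this metric is an illustrative heuristic rather than anything Atropos computes:

```python
def tokenizer_overlap(encode_a, encode_b, texts):
    """Fraction of aligned positions at which two tokenizers agree.

    encode_a / encode_b map a string to a list of token ids. A low score
    suggests token-position alignment between teacher and student will be
    poor, which is exactly the failure mode warned about above.
    """
    agree = total = 0
    for text in texts:
        a, b = encode_a(text), encode_b(text)
        # Count positions against the longer encoding so length drift hurts.
        total += max(len(a), len(b))
        agree += sum(1 for x, y in zip(a, b) if x == y)
    return agree / total if total else 1.0
```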
### Tokenization and alignment details

Atropos handles tokenization in two places:

1. **Student rollout path (`server_type=vllm`)**
   - The `/generate` request is built via the vLLM server handler and uses the server-side tokenizer configured by:
     - `--openai.tokenizer_name` (falling back to `--openai.model_name`)
   - Recommendation: set `--openai.tokenizer_name` explicitly to match the student serving model.

2. **Teacher top-k parsing path**
   - Teacher responses are parsed into token ids/logprobs in `BaseEnv.get_teacher_logprobs`.
   - The parser maps teacher token strings to ids using the environment tokenizer (`self.tokenizer`) and then aligns them to the student sequence length.

Because distillation is token-position based, keeping tokenizer families compatible is strongly recommended, especially for cross-model distillation.
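Conceptually, the teacher parsing step reduces to mapping token strings to ids and fitting the result to the student's sequence length. A simplified sketch of that idea — not the actual `get_teacher_logprobs` implementation; the `encode` callable, the pad id, and the `-inf` fill are assumptions for illustration:

```python
def align_teacher_topk(token_strs, logprobs, encode, student_len, pad_id=-1):
    """Map [position][top_k] teacher token *strings* to ids, then align
    to the student sequence length (truncate or pad).

    `encode` stands in for the environment tokenizer mapping one token
    string to one id; shapes follow the README description.
    """
    ids = [[encode(s) for s in pos] for pos in token_strs]
    lps = [list(pos) for pos in logprobs]
    # Truncate positions beyond the student sequence.
    ids, lps = ids[:student_len], lps[:student_len]
    k = len(ids[0]) if ids else 0
    # Pad missing positions with sentinel ids and -inf logprobs.
    while len(ids) < student_len:
        ids.append([pad_id] * k)
        lps.append([float("-inf")] * k)
    return ids, lps
```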
### Minimal bring-up example

Run each command in a separate terminal.

1. **Start the Atropos API**

   ```bash
   run-api --port 8002
   ```

2. **Start the teacher server (OpenAI-compatible endpoint)**

   ```bash
   python -m vllm.entrypoints.openai.api_server \
     --model "$TEACHER_MODEL" \
     --host 0.0.0.0 \
     --port 8003
   ```

3. **Start the student server for environments (`/generate` endpoint)**

   ```bash
   python -m example_trainer.vllm_api_server \
     --model "$STUDENT_MODEL" \
     --port 9001
   ```

4. **Start the environment with distillation enabled**

   ```bash
   python environments/gsm8k_server.py serve \
     --env.rollout_server_url "http://localhost:8002" \
     --env.distillation_enabled true \
     --env.teacher_base_url "http://localhost:8003/v1" \
     --env.teacher_model_name "$TEACHER_MODEL" \
     --env.teacher_top_k 20 \
     --openai.server_type vllm \
     --openai.base_url "http://localhost:9001/v1" \
     --openai.model_name "$STUDENT_MODEL" \
     --openai.tokenizer_name "$STUDENT_MODEL"
   ```

5. **Start the trainer with distillation flags**

   ```bash
   python -m example_trainer.grpo \
     --atropos-url "http://localhost:8002" \
     --distillation-enabled \
     --distillation-coef 0.1 \
     --distillation-loss-type kl \
     --distillation-temperature 1.0
   ```
### Verification checklist

- Environment logs show the distillation fetch:
  - `[DISTILL] Fetching teacher logprobs ...`
  - `[DISTILL] Added teacher distill arrays ...`
- Teacher logs show completion/chat requests from the environment.
- The API includes distill fields in the latest example:

  ```bash
  curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids!=null), has_lps:(.distill_logprobs!=null)}'
  ```

- Trainer logs report distillation metrics (example trainer):
  - `Distill: loss=...`

### Important notes

- For `server_type=vllm`, the environment expects a server exposing `/generate` (the custom server in `example_trainer/vllm_api_server.py`), not only `/v1/chat/completions`.
- Prefer explicitly setting `--openai.tokenizer_name` to your student tokenizer to avoid prompt token-ID mismatches.
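On the trainer side, a cheap structural guard before computing the distillation loss can catch silent misconfiguration early. The field names below follow this README; the helper itself is illustrative, not part of Atropos:

```python
def validate_distill_group(group):
    """Check that a scored group carries aligned distillation arrays.

    Expects group["tokens"] ([sequence][position]) plus the
    distill_token_ids / distill_logprobs arrays described above
    ([sequence][position][top_k]). Returns the top_k width, or raises.
    """
    ids, lps = group.get("distill_token_ids"), group.get("distill_logprobs")
    if ids is None or lps is None:
        raise ValueError("group has no distillation arrays; check [DISTILL] logs")
    if not (len(group["tokens"]) == len(ids) == len(lps)):
        raise ValueError("sequence count mismatch between tokens and distill arrays")
    for seq_ids, seq_lps in zip(ids, lps):
        for pos_ids, pos_lps in zip(seq_ids, seq_lps):
            if len(pos_ids) != len(pos_lps):
                raise ValueError("top-k length mismatch between ids and logprobs")
    return len(ids[0][0]) if ids and ids[0] else 0
```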
---

## Testing and Debugging Tools
@@ -1276,7 +1276,11 @@ class BaseEnv(ABC):
                import traceback
                logger.error(traceback.format_exc())
        else:
-            logger.debug(f"[DISTILL] Skipped - enabled={self.config.distillation_enabled}, url={self.config.teacher_base_url}")
+            logger.debug(
+                "[DISTILL] Skipped - enabled=%s, url=%s",
+                self.config.distillation_enabled,
+                self.config.teacher_base_url,
+            )

        data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]]
        # send single or list of scored data groups
@@ -1826,7 +1830,11 @@ class BaseEnv(ABC):
            )
            print(f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}")
        else:
-            print(f"[CLI DEBUG] Not merging: default_openai_config_ type={type(default_openai_config_)}, yaml_oai_config type={type(yaml_oai_config)}")
+            print(
+                "[CLI DEBUG] Not merging: default_openai_config_ "
+                f"type={type(default_openai_config_)}, "
+                f"yaml_oai_config type={type(yaml_oai_config)}"
+            )
            openai_config_dict = {}

        # 3. Server Manager Configuration (slurm, testing - not namespaced)
@@ -7,6 +7,7 @@ This wrapper maintains a tree structure of sequences, where:
- Branching occurs organically from different contexts and n > 1 completions
"""

+import os
import time
import uuid
import warnings
@@ -131,6 +132,10 @@ class ManagedServer:
            # Fallback for tokenizers without chat template
            return "\n".join([f"{m['role']}: {m['content']}" for m in messages])

+    def _debug_requests_enabled(self) -> bool:
+        """Enable verbose request construction logs with ATROPOS_DEBUG_REQUESTS=1."""
+        return os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
+
    def _find_extending_node(self, input_text: str) -> Optional[SequenceNode]:
        """
        Find a node that this input extends (default mode).
@@ -284,6 +289,19 @@ class ManagedServer:
        completion_kwargs = kwargs.copy()
        completion_kwargs["prompt"] = prompt
        completion_kwargs.pop("messages", None)
+        if self._debug_requests_enabled():
+            msg_count = len(messages)
+            prompt_preview = prompt.replace("\n", "\\n")[:600]
+            print(
+                f"[ATROPOS_REQ_DEBUG] chat_completion messages={msg_count} "
+                f"n={completion_kwargs.get('n')} max_tokens={completion_kwargs.get('max_tokens')} "
+                f"temperature={completion_kwargs.get('temperature')}",
+                flush=True,
+            )
+            print(
+                f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}",
+                flush=True,
+            )

        # Set model name if not provided
        if "model" not in completion_kwargs:
@@ -185,8 +185,14 @@ def resolve_openai_configs(
        and len(default_server_configs) >= 2
    )

-    print(f"[RESOLVE DEBUG] is_multi_server_yaml={is_multi_server_yaml}, is_multi_server_default={is_multi_server_default}")
-    print(f"[RESOLVE DEBUG] isinstance(default_server_configs, ServerBaseline) = {isinstance(default_server_configs, ServerBaseline)}")
+    print(
+        "[RESOLVE DEBUG] is_multi_server_yaml="
+        f"{is_multi_server_yaml}, is_multi_server_default={is_multi_server_default}"
+    )
+    print(
+        "[RESOLVE DEBUG] isinstance(default_server_configs, ServerBaseline) = "
+        f"{isinstance(default_server_configs, ServerBaseline)}"
+    )

    if (is_multi_server_yaml or is_multi_server_default) and openai_cli_config:
        raise FailedExecutionException(
@@ -2,6 +2,7 @@
# see example_trainer/vllm_api_server.py for an example

import asyncio
+import os
import warnings

import aiohttp
@@ -189,6 +190,30 @@ class VLLMServer(APIServer):
        # Prepare request for VLLM native API
        request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
        request_data.update(kwargs)
+        debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
+        if debug_requests:
+            base = self.config.base_url.replace("/v1", "")
+            prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace("\n", "\\n")
+            print(
+                f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate "
+                f"prompt_token_len={len(prompt_tokens)}",
+                flush=True,
+            )
+            print(
+                f"[ATROPOS_REQ_DEBUG] request_meta="
+                f"{{'n': {request_data.get('n')}, 'max_tokens': {request_data.get('max_tokens')}, "
+                f"'temperature': {request_data.get('temperature')}, 'top_p': {request_data.get('top_p')}}}",
+                flush=True,
+            )
+            print(
+                f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}",
+                flush=True,
+            )
+            print(
+                f"[ATROPOS_REQ_DEBUG] curl_base=curl -s -X POST {base}/generate "
+                '-H "Content-Type: application/json" -d \'<JSON_PAYLOAD>\'',
+                flush=True,
+            )

        # Make async request to VLLM /generate endpoint
        async with aiohttp.ClientSession() as session:
@@ -123,7 +123,11 @@ class GSM8kEnv(BaseEnv):
    async def rollout_and_score_eval(self, question: str, answer: str) -> dict:
        """Rollout and score evaluation with detailed sample data collection."""

-        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+        # Important: use ManagedServer's default tokenizer resolution so it uses
+        # the underlying inference server tokenizer (e.g., Qwen) instead of the
+        # environment tokenizer. Passing self.tokenizer here can cause token-ID
+        # mismatch and gibberish generations when model/tokenizer families differ.
+        async with self.server.managed_server() as managed:
            completion = await managed.chat_completion(
                messages=[
                    {"role": "system", "content": system_prompt},
@@ -231,8 +235,16 @@ class GSM8kEnv(BaseEnv):
        gold_answer = (
            "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}"
        )
+        question_preview = item["question"].replace("\n", " ")[:120]
+        print(
+            f"[GSM8K_DEBUG] collect_start group_size={self.config.group_size} "
+            f"q={question_preview!r}",
+            flush=True,
+        )

-        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+        # Important: do not force env tokenizer into ManagedServer for rollout.
+        # Let ManagedServer use the server's tokenizer to keep prompt token IDs aligned.
+        async with self.server.managed_server() as managed:

            chat_completions = await managed.chat_completion(
                messages=[{"role": "system", "content": system_prompt}, user_message],
@@ -243,10 +255,24 @@ class GSM8kEnv(BaseEnv):

            state = managed.get_state()
            nodes = state["nodes"]
+            print(
+                f"[GSM8K_DEBUG] completion_batch_received choices={len(chat_completions.choices)} "
+                f"nodes={len(nodes)}",
+                flush=True,
+            )

        to_score = list()
        to_backlog = list()
        for i, chat_completion in enumerate(chat_completions.choices):
+            response_text = chat_completion.message.content or ""
+            response_preview = response_text.replace("\n", " ")[:220]
+            valid_mask_count = sum(1 for m in nodes[i].masked_tokens if m != -100)
+            print(
+                f"[GSM8K_DEBUG] response_received idx={i} finish={chat_completion.finish_reason} "
+                f"tokens={len(nodes[i].tokens)} valid_masked={valid_mask_count} "
+                f"text={response_preview!r}",
+                flush=True,
+            )
            messages = (
                {"role": "system", "content": system_prompt},
                user_message,
@@ -263,6 +289,11 @@ class GSM8kEnv(BaseEnv):
                }
            )
        to_postprocess = await self.score(to_score)
+        accepted = 0 if to_postprocess is None else len(to_postprocess.get("tokens", []))
+        print(
+            f"[GSM8K_DEBUG] collect_done accepted={accepted} submitted={len(to_score)}",
+            flush=True,
+        )
        return to_postprocess, to_backlog

    async def score(
@@ -278,10 +309,15 @@ class GSM8kEnv(BaseEnv):
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
+        print(
+            f"[GSM8K_DEBUG] score_start candidates={len(rollout_group_data)} "
+            f"gold_parsed_len={len(gold_parsed)}",
+            flush=True,
+        )
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            random.shuffle(rollout_group_data)
-            for item in rollout_group_data:
+            for idx, item in enumerate(rollout_group_data):
                # print(item[0][-1]["content"])
                answer_parsed = parse(
                    item["messages"][-1]["content"].split("</think>")[-1],
@@ -310,7 +346,19 @@ class GSM8kEnv(BaseEnv):
                logprobs = item["logprobs"]

                # remove obviously bad examples
-                if len([1 for i in masks if i != -100]) < 10:
+                valid_mask_count = len([1 for i in masks if i != -100])
+                print(
+                    f"[GSM8K_DEBUG] score_candidate idx={idx} parsed_len={len(answer_parsed)} "
+                    f"reward={bool(reward)} valid_masked={valid_mask_count} "
+                    f"tokens={len(tokens)}",
+                    flush=True,
+                )
+                if valid_mask_count < 10:
+                    print(
+                        f"[GSM8K_DEBUG] drop_candidate idx={idx} reason=valid_masked_lt_10 "
+                        f"value={valid_mask_count}",
+                        flush=True,
+                    )
                    continue
                scores["tokens"].append(tokens)
                scores["masks"].append(masks)
@@ -323,6 +371,13 @@ class GSM8kEnv(BaseEnv):
            for score in scores["scores"]:
                self.percent_correct_buffer.append(max(score, 0))

+            if len(scores["scores"]) == 0:
+                print(
+                    "[GSM8K_DEBUG] drop_group reason=no_valid_candidates_after_filtering",
+                    flush=True,
+                )
+                return None
+
            # check if all the same
            # print(scores['scores'])
            if all([score == 1 for score in scores["scores"]]):
@@ -330,6 +385,10 @@ class GSM8kEnv(BaseEnv):
                token_lengths = [len(token) for token in scores["tokens"]]
                if max(token_lengths) == 0:
                    # What? But don't want to crash a run so just in case...
+                    print(
+                        "[GSM8K_DEBUG] drop_group reason=zero_token_length_after_penalty_branch",
+                        flush=True,
+                    )
                    return None

                # Get max allowed token length from config
@@ -353,10 +412,20 @@ class GSM8kEnv(BaseEnv):
                    # Apply linear penalty scaling from 1.0 down to 0.0
                    scores["scores"].append(1.0 - percentage_of_range)
            if all([scores["scores"][0] == score for score in scores["scores"]]):
+                print(
+                    f"[GSM8K_DEBUG] drop_group reason=all_scores_identical scores={scores['scores']}",
+                    flush=True,
+                )
                return None  # If all the same, we return None
+            print(
+                f"[GSM8K_DEBUG] score_done accepted={len(scores['scores'])} "
+                f"scores={scores['scores']}",
+                flush=True,
+            )
            return scores
        else:
            # If the gold solution is not parseable, we return None
+            print("[GSM8K_DEBUG] drop_group reason=gold_unparseable", flush=True)
            return None

    async def get_next_item(self) -> GSM8kRow:
@@ -6,17 +6,12 @@ Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero
import asyncio
import random
import re
-import logging
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, List, Optional, Tuple

-import aiohttp
import wandb
from datasets import load_dataset

-# Set up logging for debug
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from math_verify.errors import TimeoutException
@@ -129,8 +124,6 @@ class MathEnv(BaseEnv):
        slurm=True,
        testing=False,
    ):
-        print("Initializing MathEnv")
-        print(f"Slurm: {slurm}, Testing: {testing}")
        super().__init__(config, server_configs, slurm, testing)
        self.percent_correct_buffer = list()
        self.eval_metrics = list()
@@ -141,17 +134,6 @@ class MathEnv(BaseEnv):
        self.normal_rollouts = list()
        self.pass_at_groupsize = list()
        self.iter = 0

-        # Debug: Print distillation config
-        print("=" * 60)
-        print("[MATH_DEBUG] DISTILLATION CONFIGURATION:")
-        print(f"[MATH_DEBUG] distillation_enabled = {config.distillation_enabled}")
-        print(f"[MATH_DEBUG] teacher_base_url = {config.teacher_base_url}")
-        print(f"[MATH_DEBUG] teacher_model_name = {getattr(config, 'teacher_model_name', 'N/A')}")
-        print(f"[MATH_DEBUG] teacher_top_k = {getattr(config, 'teacher_top_k', 'N/A')}")
-        print(f"[MATH_DEBUG] teacher_prefix_text set = {bool(getattr(config, 'teacher_prefix_text', None))}")
-        print(f"[MATH_DEBUG] teacher_system_prompt set = {bool(getattr(config, 'teacher_system_prompt', None))}")
-        print("=" * 60)
-
    @classmethod
    def config_init(cls) -> Tuple[RSConfig, ServerBaseline]:
@@ -269,85 +251,7 @@ class MathEnv(BaseEnv):
                    name,
                )
            )

-        # Debug: Test teacher connectivity if distillation is enabled
-        if self.config.distillation_enabled and self.config.teacher_base_url:
-            await self._test_teacher_connectivity()
-
        return

-    async def _test_teacher_connectivity(self):
-        """Test if the teacher model API is reachable."""
-        print("=" * 60)
-        print("[MATH_DEBUG] TESTING TEACHER CONNECTIVITY...")
-        print(f"[MATH_DEBUG] Teacher URL: {self.config.teacher_base_url}")
-        print(f"[MATH_DEBUG] Teacher Model: {getattr(self.config, 'teacher_model_name', 'default')}")
-
-        try:
-            async with aiohttp.ClientSession() as session:
-                # Test 1: Health check
-                health_url = self.config.teacher_base_url.replace("/v1", "") + "/health"
-                print(f"[MATH_DEBUG] Testing health endpoint: {health_url}")
-                try:
-                    async with session.get(health_url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
-                        print(f"[MATH_DEBUG] Health check status: {resp.status}")
-                        if resp.status == 200:
-                            print("[MATH_DEBUG] ✓ Teacher health check PASSED")
-                        else:
-                            print(f"[MATH_DEBUG] ✗ Teacher health check FAILED: {await resp.text()}")
-                except Exception as e:
-                    print(f"[MATH_DEBUG] ✗ Teacher health check ERROR: {e}")
-
-                # Test 2: Models endpoint
-                models_url = f"{self.config.teacher_base_url}/models"
-                print(f"[MATH_DEBUG] Testing models endpoint: {models_url}")
-                try:
-                    async with session.get(models_url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
-                        print(f"[MATH_DEBUG] Models endpoint status: {resp.status}")
-                        if resp.status == 200:
-                            data = await resp.json()
-                            models = [m.get("id", m) for m in data.get("data", [])]
-                            print(f"[MATH_DEBUG] ✓ Available models: {models}")
-                        else:
-                            print(f"[MATH_DEBUG] ✗ Models endpoint FAILED: {await resp.text()}")
-                except Exception as e:
-                    print(f"[MATH_DEBUG] ✗ Models endpoint ERROR: {e}")
-
-                # Test 3: Simple completion test
-                completions_url = f"{self.config.teacher_base_url}/completions"
-                teacher_model = getattr(self.config, 'teacher_model_name', 'default')
-                test_payload = {
-                    "model": teacher_model,
-                    "prompt": "Hello",
-                    "max_tokens": 5,
-                    "logprobs": 5,
-                    "echo": True,
-                }
-                print(f"[MATH_DEBUG] Testing completions endpoint: {completions_url}")
-                print(f"[MATH_DEBUG] Test payload: {test_payload}")
-                try:
-                    async with session.post(
-                        completions_url,
-                        json=test_payload,
-                        headers={"Content-Type": "application/json"},
-                        timeout=aiohttp.ClientTimeout(total=30),
-                    ) as resp:
-                        print(f"[MATH_DEBUG] Completions status: {resp.status}")
-                        resp_text = await resp.text()
-                        if resp.status == 200:
-                            print("[MATH_DEBUG] ✓ Teacher completions WORKING!")
-                            print(f"[MATH_DEBUG] Response preview: {resp_text[:500]}")
-                        else:
-                            print(f"[MATH_DEBUG] ✗ Teacher completions FAILED: {resp_text[:500]}")
-                except Exception as e:
-                    print(f"[MATH_DEBUG] ✗ Teacher completions ERROR: {e}")
-
-        except Exception as e:
-            print(f"[MATH_DEBUG] ✗ Teacher connectivity test FAILED: {e}")
-            import traceback
-            traceback.print_exc()
-
-        print("=" * 60)
-
    async def rollout_and_score_eval(self, question, answer, subset):
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
@@ -492,7 +396,6 @@ class MathEnv(BaseEnv):
            )
            if len(self.normal_rollouts) > self.config.num_rollouts_to_keep:
                self.normal_rollouts.pop(0)
-        print(f"Collected {len(to_postprocess['scores'])} trajectories")
        return to_postprocess, to_backlog

    async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
@@ -578,66 +481,7 @@ class MathEnv(BaseEnv):
                ]
            )

-        # Debug: Log scored group creation
-        print(f"[MATH_DEBUG] Created ScoredDataGroup with {len(scores['tokens'])} sequences")
-        print(f"[MATH_DEBUG] Scores: {scores['scores']}")
-        print(f"[MATH_DEBUG] Token lengths: {[len(t) for t in scores['tokens']]}")
-        has_new_distill = (
-            "distill_token_ids" in scores and "distill_logprobs" in scores
-        )
-        print(f"[MATH_DEBUG] Has distill arrays: {has_new_distill}")
-
        return scores

-    async def handle_send_to_api(
-        self,
-        scored_data,
-        item=None,
-        do_send_to_api: bool = True,
-        abort_on_any_max_length_exceeded: bool = True,
-    ):
-        """Override to add debugging for distillation."""
-        print("[MATH_DEBUG] handle_send_to_api called")
-        print(f"[MATH_DEBUG] distillation_enabled: {self.config.distillation_enabled}")
-        print(f"[MATH_DEBUG] teacher_base_url: {self.config.teacher_base_url}")
-
-        if isinstance(scored_data, list):
-            for i, group in enumerate(scored_data):
-                if group:
-                    has_distill = (
-                        group.get("distill_token_ids") is not None
-                        and group.get("distill_logprobs") is not None
-                    )
-                    print(f"[MATH_DEBUG] Group {i}: {len(group.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
-        elif scored_data:
-            has_distill = (
-                scored_data.get("distill_token_ids") is not None
-                and scored_data.get("distill_logprobs") is not None
-            )
-            print(f"[MATH_DEBUG] Single group: {len(scored_data.get('tokens', []))} seqs, has_distill_logprobs={has_distill}")
-
-        # Call parent implementation which does the actual distillation fetch
-        result = await super().handle_send_to_api(
-            scored_data, item, do_send_to_api, abort_on_any_max_length_exceeded
-        )
-
-        # Debug: Check if distillation was added after parent call
-        if isinstance(scored_data, list):
-            for i, group in enumerate(scored_data):
-                if group:
-                    has_distill = (
-                        group.get("distill_token_ids") is not None
-                        and group.get("distill_logprobs") is not None
-                    )
-                    print(f"[MATH_DEBUG] AFTER: Group {i} has_distill_logprobs={has_distill}")
-        elif scored_data:
-            has_distill = (
-                scored_data.get("distill_token_ids") is not None
-                and scored_data.get("distill_logprobs") is not None
-            )
-            print(f"[MATH_DEBUG] AFTER: Single group has_distill_logprobs={has_distill}")
-
-        return result
-
    async def get_next_item(self):
        while True:
@@ -652,10 +496,7 @@ class MathEnv(BaseEnv):
                )
                break
            except TypeError:
-                print(
-                    f"Error in getting next item, trying again, "
-                    f"data: {next_item['question']} -> {next_item['final_answer']}"
-                )
+                continue
        return (prompt, answer, "normal")