From f44eb810bf605d277953d13036b3244b5811c1e4 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 6 Mar 2026 13:58:38 -0500 Subject: [PATCH 01/64] teacher env init --- atroposlib/envs/teacher_distillation_env.py | 188 ++++++++++++++++++ atroposlib/tests/test_managed_server.py | 24 --- atroposlib/tests/test_server_logprobs.py | 4 +- .../tests/test_teacher_distillation_env.py | 69 +++++++ 4 files changed, 258 insertions(+), 27 deletions(-) create mode 100644 atroposlib/envs/teacher_distillation_env.py create mode 100644 atroposlib/tests/test_teacher_distillation_env.py diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py new file mode 100644 index 00000000..5ff96bc7 --- /dev/null +++ b/atroposlib/envs/teacher_distillation_env.py @@ -0,0 +1,188 @@ +""" +Teacher distillation environment layer. + +This module adds teacher prompt-logprob fetching on top of BaseEnv without +modifying BaseEnv transport behavior. +""" + +from __future__ import annotations + +import asyncio +import logging +from abc import ABC +from typing import Any, Dict, List, Optional, Tuple, Union + +from pydantic import Field + +from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup +from .server_handling.server_baseline import APIServerConfig, ServerBaseline +from .server_handling.server_manager import ServerManager + +logger = logging.getLogger(__name__) + + +class TeacherDistillationConfig(BaseEnvConfig): + teacher_enabled: bool = Field( + default=False, + description="Whether to fetch teacher prompt logprobs for distillation.", + ) + teacher_base_url: Optional[str] = Field( + default=None, + description="Teacher server base URL (OpenAI-compatible).", + ) + teacher_model_name: Optional[str] = Field( + default=None, + description="Teacher model name used in teacher server requests.", + ) + teacher_api_key: str = Field( + default="", + description="Teacher API key, if required by the teacher endpoint.", + ) + teacher_server_type: str = Field( + default="vllm", + description="Teacher server type (e.g. vllm, sglang, trl, openai).", + ) + teacher_tokenizer_name: str = Field( + default="none", + description=( + "Tokenizer name for teacher server. If 'none', teacher_model_name is used." + ), + ) + teacher_top_k: int = Field( + default=1, + ge=1, + description="Top-k prompt logprobs to fetch per token position.", + ) + + +class TeacherDistillationEnv(BaseEnv, ABC): + """ + BaseEnv subclass that enriches scored groups with teacher distillation arrays. + + Distillation payload shape: + - distill_token_ids: [sequence][position][k] + - distill_logprobs: [sequence][position][k] + """ + + env_config_cls = TeacherDistillationConfig + + def __init__( + self, + config: TeacherDistillationConfig, + server_configs: Union[ServerBaseline, List[APIServerConfig]], + slurm: bool = False, + testing: bool = False, + ): + super().__init__(config, server_configs, slurm=slurm, testing=testing) + self.teacher_server: Optional[ServerManager] = None + if config.teacher_enabled: + if not config.teacher_base_url or not config.teacher_model_name: + raise ValueError( + "teacher_enabled=True requires teacher_base_url and teacher_model_name." + ) + teacher_cfg = APIServerConfig( + server_type=config.teacher_server_type, # type: ignore[arg-type] + base_url=config.teacher_base_url, + api_key=config.teacher_api_key, + model_name=config.teacher_model_name, + tokenizer_name=config.teacher_tokenizer_name, + timeout=1200, + ) + self.teacher_server = ServerManager( + [teacher_cfg], + slurm=False, + testing=False, + ) + + async def _fetch_teacher_for_sequence( + self, token_ids: List[int], top_k: int + ) -> Tuple[List[List[int]], List[List[float]]]: + assert self.teacher_server is not None + payload = await self.teacher_server.get_logprobs( + input_ids=token_ids, + top_k=top_k, + max_tokens=1, + split="train", + ) + return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"] + + async def _attach_teacher_distillation( + self, group: ScoredDataGroup + ) -> ScoredDataGroup: + if not self.config.teacher_enabled or self.teacher_server is None: + return group + + seqs = group.get("tokens", []) + if not seqs: + group["distill_token_ids"] = None + group["distill_logprobs"] = None + return group + + top_k = int( + (group.get("group_overrides") or {}).get( + "teacher_top_k", self.config.teacher_top_k + ) + ) + top_k = max(1, top_k) + + tasks = [self._fetch_teacher_for_sequence(seq, top_k) for seq in seqs] + results = await asyncio.gather(*tasks, return_exceptions=True) + + distill_token_ids: List[List[List[int]]] = [] + distill_logprobs: List[List[List[float]]] = [] + for idx, result in enumerate(results): + if isinstance(result, Exception): + logger.warning( + "Teacher logprob fetch failed for seq %s: %s. " + "Dropping distill payload for this group.", + idx, + result, + ) + group["distill_token_ids"] = None + group["distill_logprobs"] = None + return group + token_ids_k, logprobs_k = result + if len(token_ids_k) != len(logprobs_k): + logger.warning( + "Teacher prompt-topk length mismatch for seq %s (%s != %s). " + "Dropping distill payload for this group.", + idx, + len(token_ids_k), + len(logprobs_k), + ) + group["distill_token_ids"] = None + group["distill_logprobs"] = None + return group + distill_token_ids.append(token_ids_k) + distill_logprobs.append(logprobs_k) + + group["distill_token_ids"] = distill_token_ids + group["distill_logprobs"] = distill_logprobs + return group + + async def handle_send_to_api( + self, + scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]], + item: Any = None, + do_send_to_api: bool = True, + abort_on_any_max_length_exceeded: bool = True, + ): + groups = scored_data if isinstance(scored_data, list) else [scored_data] + enriched_groups: List[ScoredDataGroup] = [] + for group in groups: + if group is None: + continue + enriched_groups.append(await self._attach_teacher_distillation(group)) + + payload: Union[ScoredDataGroup, List[ScoredDataGroup]] + if isinstance(scored_data, list): + payload = enriched_groups + else: + payload = enriched_groups[0] if enriched_groups else scored_data + + return await super().handle_send_to_api( + payload, + item=item, + do_send_to_api=do_send_to_api, + abort_on_any_max_length_exceeded=abort_on_any_max_length_exceeded, + ) diff --git a/atroposlib/tests/test_managed_server.py b/atroposlib/tests/test_managed_server.py index 1524aaf7..6f18be08 100644 --- a/atroposlib/tests/test_managed_server.py +++ b/atroposlib/tests/test_managed_server.py @@ -319,30 +319,6 @@ async def test_get_logprobs_messages_passthrough(mock_server): assert len(payload["prompt_topk_logprobs"]) == len(prompt_tokens) -@pytest.mark.asyncio -async def test_get_logprobs_input_ids_only_passthrough(mock_server): - """ManagedServer.get_logprobs supports input_ids-only without requiring prompt.""" - managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer) - input_ids = [10, 20, 30] - - async def _mock_get_logprobs(**kwargs): - assert "input_ids" in kwargs - assert kwargs["input_ids"] == input_ids - assert kwargs.get("prompt") is None - return { - "prompt_tokens": input_ids, - "prompt_topk_token_ids": [[t] for t in input_ids], - "prompt_topk_logprobs": [[-0.1] for _ in input_ids], - } - - mock_server.get_logprobs = _mock_get_logprobs - payload = await managed.get_logprobs(input_ids=input_ids, top_k=1) - - assert payload["prompt_tokens"] == input_ids - assert payload["prompt_topk_token_ids"] == [[10], [20], [30]] - assert payload["prompt_topk_logprobs"] == [[-0.1], [-0.1], [-0.1]] - - @pytest.mark.asyncio async def test_get_logprobs_strict_mode_requires_backend_impl(mock_server): """ManagedServer.get_logprobs requires backend get_logprobs in strict mode.""" diff --git a/atroposlib/tests/test_server_logprobs.py b/atroposlib/tests/test_server_logprobs.py index 8cbd84ad..2da50b42 100644 --- a/atroposlib/tests/test_server_logprobs.py +++ b/atroposlib/tests/test_server_logprobs.py @@ -41,9 +41,7 @@ class _FakeAPIServer(APIServer): class _FakeRoutedServer: - def __init__( - self, name: str, train_slots: int, eval_slots: int, healthy: bool = True - ): + def __init__(self, name: str, train_slots: int, eval_slots: int, healthy: bool = True): self.name = name self.server_healthy = healthy self.sem = AsyncSemWithAdaptiveWeight(4) diff --git a/atroposlib/tests/test_teacher_distillation_env.py b/atroposlib/tests/test_teacher_distillation_env.py new file mode 100644 index 00000000..199f1453 --- /dev/null +++ b/atroposlib/tests/test_teacher_distillation_env.py @@ -0,0 +1,69 @@ +"""Tests for TeacherDistillationEnv distillation enrichment.""" + +from types import SimpleNamespace + +import pytest + +from atroposlib.envs.teacher_distillation_env import TeacherDistillationEnv + + +class _FakeTeacherServer: + def __init__(self, fail_on_call: int = -1): + self.calls = 0 + self.fail_on_call = fail_on_call + + async def get_logprobs(self, **kwargs): + self.calls += 1 + if self.calls == self.fail_on_call: + raise RuntimeError("teacher backend failure") + seq = kwargs["input_ids"] + return { + "prompt_tokens": seq, + "prompt_topk_token_ids": [[tok, tok + 1] for tok in seq], + "prompt_topk_logprobs": [[-0.1, -0.2] for _ in seq], + } + + +class _ConcreteTeacherEnv(TeacherDistillationEnv): + async def get_next_item(self): + return None + + async def evaluate(self, *args, **kwargs): + return None + + +@pytest.mark.asyncio +async def test_attach_teacher_distillation_success(): + env = object.__new__(_ConcreteTeacherEnv) + env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2) + env.teacher_server = _FakeTeacherServer() + + group = { + "tokens": [[1, 2, 3], [4, 5]], + "group_overrides": None, + "masks": [[-100, 2, 3], [-100, 5]], + "scores": [1.0, 0.0], + } + out = await TeacherDistillationEnv._attach_teacher_distillation(env, group) + assert out["distill_token_ids"] is not None + assert out["distill_logprobs"] is not None + assert len(out["distill_token_ids"]) == 2 + assert len(out["distill_token_ids"][0]) == 3 + assert len(out["distill_logprobs"][1]) == 2 + + +@pytest.mark.asyncio +async def test_attach_teacher_distillation_failure_drops_payload(): + env = object.__new__(_ConcreteTeacherEnv) + env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2) + env.teacher_server = _FakeTeacherServer(fail_on_call=2) + + group = { + "tokens": [[1, 2, 3], [4, 5]], + "group_overrides": None, + "masks": [[-100, 2, 3], [-100, 5]], + "scores": [1.0, 0.0], + } + out = await TeacherDistillationEnv._attach_teacher_distillation(env, group) + assert out["distill_token_ids"] is None + assert out["distill_logprobs"] is None From 530fed2877e9cd985005e28bfc54a399a03f4db1 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 6 Mar 2026 14:49:32 -0500 Subject: [PATCH 02/64] testing set up --- environments/gsm8k_server_teacher_distill.py | 51 ++++ example_trainer/cli.py | 20 ++ example_trainer/config.py | 12 + example_trainer/data.py | 120 +++++++- example_trainer/run.py | 3 + ...n_gsm8k_teacher_distill_single_terminal.sh | 267 ++++++++++++++++++ example_trainer/trainers.py | 16 ++ example_trainer/training.py | 112 +++++++- 8 files changed, 599 insertions(+), 2 deletions(-) create mode 100644 environments/gsm8k_server_teacher_distill.py create mode 100755 example_trainer/run_gsm8k_teacher_distill_single_terminal.sh diff --git a/environments/gsm8k_server_teacher_distill.py b/environments/gsm8k_server_teacher_distill.py new file mode 100644 index 00000000..159fa4d3 --- /dev/null +++ b/environments/gsm8k_server_teacher_distill.py @@ -0,0 +1,51 @@ +from typing import Tuple + +from atroposlib.envs.base import APIServerConfig, ServerBaseline +from atroposlib.envs.teacher_distillation_env import ( + TeacherDistillationConfig, + TeacherDistillationEnv, +) +from environments.gsm8k_server import GSM8kEnv + + +class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv): + """ + GSM8K environment variant that enables TeacherDistillationEnv config fields. + + This preserves the original `gsm8k_server.py` while providing a separate entrypoint + for teacher-distillation data collection. + """ + + name = "gsm8k_teacher_distill" + env_config_cls = TeacherDistillationConfig + + @classmethod + def config_init(cls) -> Tuple[TeacherDistillationConfig, ServerBaseline]: + env_config = TeacherDistillationConfig( + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", + group_size=8, + use_wandb=True, + rollout_server_url="http://localhost:8000", + total_steps=1000, + batch_size=12, + steps_per_eval=100, + max_token_length=2048, + wandb_name="gsm8k_teacher_distill", + teacher_enabled=True, + teacher_base_url="http://localhost:8003/v1", + teacher_model_name="mock-teacher", + teacher_api_key="", + teacher_server_type="vllm", + teacher_tokenizer_name="none", + teacher_top_k=4, + ) + server_config = APIServerConfig( + model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", + base_url="http://localhost:9001/v1", + api_key="x", + num_requests_for_eval=256, + ) + return env_config, server_config + +if __name__ == "__main__": + GSM8kTeacherDistillEnv.cli() diff --git a/example_trainer/cli.py b/example_trainer/cli.py index 1e46bfc9..93946d51 100644 --- a/example_trainer/cli.py +++ b/example_trainer/cli.py @@ -163,6 +163,23 @@ def add_grpo_args(parser: argparse.ArgumentParser) -> None: default=0.2, help="PPO-style clipping epsilon. Clips ratio to [1-eps, 1+eps].", ) + group.add_argument( + "--distill-enabled", + action="store_true", + help="Enable teacher distillation loss (requires distill payload in Atropos batch).", + ) + group.add_argument( + "--distill-coef", + type=float, + default=0.0, + help="Coefficient for distillation loss term.", + ) + group.add_argument( + "--distill-temperature", + type=float, + default=1.0, + help="Temperature for teacher top-k distribution in distillation loss.", + ) def add_vllm_args(parser: argparse.ArgumentParser) -> None: @@ -424,6 +441,9 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig: checkpoint_interval=getattr(args, "checkpoint_interval", 3), # GRPO/PPO hyperparameters clip_eps=getattr(args, "clip_eps", 0.2), + distill_enabled=getattr(args, "distill_enabled", False), + distill_coef=getattr(args, "distill_coef", 0.0), + distill_temperature=getattr(args, "distill_temperature", 1.0), adafactor_scale_parameter=getattr(args, "adafactor_scale_parameter", False), adafactor_relative_step=getattr(args, "adafactor_relative_step", False), # vLLM settings diff --git a/example_trainer/config.py b/example_trainer/config.py index 4ddeddb5..03fd80a8 100644 --- a/example_trainer/config.py +++ b/example_trainer/config.py @@ -69,6 +69,18 @@ class TrainingConfig(BaseModel): "Prevents large policy updates that could destabilize training." ), ) + distill_enabled: bool = Field( + False, + description="Enable teacher distillation loss when distill tensors are present.", + ) + distill_coef: float = Field( + 0.0, + description="Weight for distillation loss in total loss.", + ) + distill_temperature: float = Field( + 1.0, + description="Temperature applied when converting teacher top-k logprobs.", + ) # === Device & Storage === device: str = Field( "cuda" if torch.cuda.is_available() else "cpu", description="Device to train on" diff --git a/example_trainer/data.py b/example_trainer/data.py index 16a38564..770d68fa 100644 --- a/example_trainer/data.py +++ b/example_trainer/data.py @@ -29,6 +29,8 @@ def pad_data_to_good_offset( List[torch.Tensor], # advantage_batches List[torch.Tensor], # temperature_batches Optional[List[torch.Tensor]], # inference_logprob_batches (aligned with labels) + Optional[List[torch.Tensor]], # distill_token_id_batches [batch, seq, k] + Optional[List[torch.Tensor]], # distill_logprob_batches [batch, seq, k] ]: """ Pad and batch data from the Atropos API. @@ -45,7 +47,8 @@ def pad_data_to_good_offset( extract_inference_logprobs: Whether to extract inference logprobs Returns: - Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches) + Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, + inference_logprob_batches, distill_token_id_batches, distill_logprob_batches) inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data Note: @@ -73,6 +76,10 @@ def pad_data_to_good_offset( temperatures = [] inference_logprobs_padded: List[np.ndarray] = [] # Padded to match labels shape has_any_logprobs = False + distill_token_ids_padded: List[np.ndarray] = [] + distill_logprobs_padded: List[np.ndarray] = [] + has_any_distill = False + max_distill_k = 1 for item in data["batch"]: # Normalize advantage scores @@ -153,6 +160,77 @@ def pad_data_to_good_offset( np.full(token_setup_len - 1, 1.0, dtype=np.float32) ) + # Extract teacher distillation top-k arrays if available. + # Expected shape in incoming payload: [sequence][position][k]. + if "distill_token_ids" in item and "distill_logprobs" in item: + seq_token_ids = item["distill_token_ids"] + seq_logprobs = item["distill_logprobs"] + if ( + isinstance(seq_token_ids, list) + and isinstance(seq_logprobs, list) + and i < len(seq_token_ids) + and i < len(seq_logprobs) + and seq_token_ids[i] is not None + and seq_logprobs[i] is not None + ): + per_pos_token_ids = seq_token_ids[i] + per_pos_logprobs = seq_logprobs[i] + if ( + isinstance(per_pos_token_ids, list) + and isinstance(per_pos_logprobs, list) + and len(per_pos_token_ids) == len(per_pos_logprobs) + ): + local_k = 1 + for row_ids in per_pos_token_ids: + if isinstance(row_ids, list): + local_k = max(local_k, len(row_ids)) + max_distill_k = max(max_distill_k, local_k) + has_any_distill = True + + rows = max(0, token_setup_len - 1) + token_mat = np.full((rows, local_k), -1, dtype=np.int64) + logprob_mat = np.full( + (rows, local_k), -1e9, dtype=np.float32 + ) + + # Shift by one to align with causal labels like inference_logprobs. + copy_positions = min( + len(per_pos_token_ids), len(per_pos_logprobs), token_setup_len + ) + for pos in range(1, copy_positions): + src_ids = per_pos_token_ids[pos] + src_lps = per_pos_logprobs[pos] + if not isinstance(src_ids, list) or not isinstance(src_lps, list): + continue + topk = min(local_k, len(src_ids), len(src_lps)) + if topk <= 0: + continue + token_mat[pos - 1, :topk] = np.array(src_ids[:topk], dtype=np.int64) + logprob_mat[pos - 1, :topk] = np.array( + src_lps[:topk], dtype=np.float32 + ) + + distill_token_ids_padded.append(token_mat) + distill_logprobs_padded.append(logprob_mat) + else: + rows = max(0, token_setup_len - 1) + distill_token_ids_padded.append( + np.full((rows, 1), -1, dtype=np.int64) + ) + distill_logprobs_padded.append( + np.full((rows, 1), -1e9, dtype=np.float32) + ) + else: + rows = max(0, token_setup_len - 1) + distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64)) + distill_logprobs_padded.append( + np.full((rows, 1), -1e9, dtype=np.float32) + ) + else: + rows = max(0, token_setup_len - 1) + distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64)) + distill_logprobs_padded.append(np.full((rows, 1), -1e9, dtype=np.float32)) + # Extract temperature (priority: override > generation_params > group_overrides > 1.0) t = 1.0 if ( @@ -178,6 +256,8 @@ def pad_data_to_good_offset( advantage_batches = [] temperature_batches = [] inference_logprob_batches = [] + distill_token_id_batches = [] + distill_logprob_batches = [] for start in range(0, len(input_ids), batch_size): end = min(start + batch_size, len(input_ids)) @@ -199,12 +279,42 @@ def pad_data_to_good_offset( torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0)) ) + if distill_token_ids_padded and distill_logprobs_padded: + seq_slice_ids = distill_token_ids_padded[start:end] + seq_slice_lps = distill_logprobs_padded[start:end] + normalized_ids = [] + normalized_lps = [] + for ids_mat, lps_mat in zip(seq_slice_ids, seq_slice_lps): + if ids_mat.shape[1] < max_distill_k: + pad_cols = max_distill_k - ids_mat.shape[1] + ids_mat = np.pad( + ids_mat, ((0, 0), (0, pad_cols)), constant_values=-1 + ) + lps_mat = np.pad( + lps_mat, ((0, 0), (0, pad_cols)), constant_values=-1e9 + ) + normalized_ids.append(ids_mat) + normalized_lps.append(lps_mat) + + distill_token_id_batches.append( + torch.tensor(np.stack(normalized_ids, axis=0), dtype=torch.long) + ) + distill_logprob_batches.append( + torch.tensor(np.stack(normalized_lps, axis=0), dtype=torch.float32) + ) + # Return inference logprob batches if we have any real logprobs final_logprob_batches = ( inference_logprob_batches if (has_any_logprobs and inference_logprob_batches) else None ) + final_distill_token_id_batches = ( + distill_token_id_batches if (has_any_distill and distill_token_id_batches) else None + ) + final_distill_logprob_batches = ( + distill_logprob_batches if (has_any_distill and distill_logprob_batches) else None + ) return ( token_batches, @@ -212,6 +322,8 @@ def pad_data_to_good_offset( advantage_batches, temperature_batches, final_logprob_batches, + final_distill_token_id_batches, + final_distill_logprob_batches, ) @@ -228,6 +340,8 @@ def get_data( List[torch.Tensor], # advantage_batches List[torch.Tensor], # temperature_batches Optional[List[torch.Tensor]], # inference_logprob_batches + Optional[List[torch.Tensor]], # distill_token_id_batches + Optional[List[torch.Tensor]], # distill_logprob_batches ] ], None, # Legacy return (no longer used) @@ -299,6 +413,8 @@ def get_data( adv_batches, temp_batches, inf_logprob_batches, + distill_token_id_batches, + distill_logprob_batches, ) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs) # Include inference logprob batches in the tuple @@ -309,6 +425,8 @@ def get_data( adv_batches, temp_batches, inf_logprob_batches, + distill_token_id_batches, + distill_logprob_batches, ) ) diff --git a/example_trainer/run.py b/example_trainer/run.py index b9b5f88f..d1cf37b2 100644 --- a/example_trainer/run.py +++ b/example_trainer/run.py @@ -201,6 +201,9 @@ def main(): checkpoint_interval=args.checkpoint_interval, # GRPO hyperparameters clip_eps=args.clip_eps, + distill_enabled=getattr(args, "distill_enabled", False), + distill_coef=getattr(args, "distill_coef", 0.0), + distill_temperature=getattr(args, "distill_temperature", 1.0), # vLLM settings vllm_port=args.vllm_port, vllm_gpu_memory_utilization=args.gpu_memory_utilization, diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh new file mode 100755 index 00000000..797f2cb5 --- /dev/null +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -0,0 +1,267 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Single-terminal teacher-distillation runner. +# Starts everything in the background from ONE shell that has GPU access: +# 1) Atropos API +# 2) Student vLLM server +# 3) Teacher vLLM server +# 4) GSM8K teacher-distill environment +# 5) Example trainer (foreground) +# +# Usage: +# chmod +x example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +# ./example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +# +# Optional overrides: +# STUDENT_MODEL="Qwen/Qwen3-4B-Instruct-2507-FP8" +# TEACHER_MODEL="Qwen/Qwen3-30B-A3B-Instruct-2507" +# STUDENT_GPUS="0,1" +# TEACHER_GPUS="4,5,6,7" +# TRAINER_GPU="2" +# STUDENT_TP=2 +# TEACHER_TP=4 +# API_PORT=8002 +# STUDENT_PORT=9001 +# TEACHER_PORT=9003 +# TRAINING_STEPS=100 +# DISTILL_COEF=0.2 +# DISTILL_TEMPERATURE=1.0 +# TEACHER_TOP_K=8 +# DRY_RUN=1 + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +LAUNCH_DIR="$PWD" +cd "$ROOT_DIR" + +STUDENT_MODEL="${STUDENT_MODEL:-Qwen/Qwen3-4B-Instruct-2507-FP8}" +TEACHER_MODEL="${TEACHER_MODEL:-Qwen/Qwen3-30B-A3B-Instruct-2507}" + +STUDENT_GPUS="${STUDENT_GPUS:-0,1}" +TEACHER_GPUS="${TEACHER_GPUS:-4,5,6,7}" +TRAINER_GPU="${TRAINER_GPU:-2}" + +STUDENT_TP="${STUDENT_TP:-2}" +TEACHER_TP="${TEACHER_TP:-4}" + +API_PORT="${API_PORT:-8002}" +STUDENT_PORT="${STUDENT_PORT:-9001}" +TEACHER_PORT="${TEACHER_PORT:-9003}" + +TRAINING_STEPS="${TRAINING_STEPS:-100}" +BATCH_SIZE="${BATCH_SIZE:-2}" +GRAD_ACCUM="${GRAD_ACCUM:-8}" +LR="${LR:-1e-5}" +WARMUP_STEPS="${WARMUP_STEPS:-0}" +CLIP_EPS="${CLIP_EPS:-0.2}" +MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}" +DISTILL_COEF="${DISTILL_COEF:-0.2}" +DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" +TEACHER_TOP_K="${TEACHER_TOP_K:-8}" + +STUDENT_GPU_MEMORY_UTILIZATION="${STUDENT_GPU_MEMORY_UTILIZATION:-0.90}" +TEACHER_GPU_MEMORY_UTILIZATION="${TEACHER_GPU_MEMORY_UTILIZATION:-0.92}" +DTYPE="${DTYPE:-bfloat16}" +SAVE_DIR="${SAVE_DIR:-${LAUNCH_DIR}/saves/gsm8k_teacher_distill}" +LOG_DIR="${LOG_DIR:-${LAUNCH_DIR}/logs/gsm8k_teacher_distill}" +DRY_RUN="${DRY_RUN:-0}" + +ENV_GROUP_SIZE="${ENV_GROUP_SIZE:-4}" +ENV_BATCH_SIZE="${ENV_BATCH_SIZE:-16}" +ENV_TOTAL_STEPS="${ENV_TOTAL_STEPS:-200}" +ENV_STEPS_PER_EVAL="${ENV_STEPS_PER_EVAL:-50}" +ENV_MAX_WORKERS_PER_NODE="${ENV_MAX_WORKERS_PER_NODE:-8}" + +RUN_PIDS=() +RUN_PORTS=() + +log() { + echo "[$(date '+%H:%M:%S')] $*" +} + +kill_port() { + local port="$1" + if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] skip port cleanup for :${port}" + return 0 + fi + if lsof -i ":${port}" -sTCP:LISTEN >/dev/null 2>&1; then + lsof -ti ":${port}" | xargs -r kill -9 || true + fi +} + +wait_for_http() { + local url="$1" + local timeout="${2:-240}" + local name="${3:-endpoint}" + local start + start="$(date +%s)" + while true; do + if curl -fsS "$url" >/dev/null 2>&1; then + log "Ready: ${name} (${url})" + return 0 + fi + if (( "$(date +%s)" - start > timeout )); then + log "Timeout waiting for ${name}: ${url}" + return 1 + fi + sleep 2 + done +} + +start_process() { + local name="$1" + local logfile="$2" + shift 2 + if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] start ${name} (log: ${logfile})" + printf ' ' + printf '%q ' "$@" + printf '\n' + return 0 + fi + log "Starting ${name} (log: ${logfile})" + "$@" >"$logfile" 2>&1 & + local pid=$! + RUN_PIDS+=("$pid") + log "${name} PID=${pid}" +} + +cleanup_all() { + log "Cleaning up processes..." + for pid in "${RUN_PIDS[@]:-}"; do + kill "$pid" >/dev/null 2>&1 || true + done + sleep 1 + for pid in "${RUN_PIDS[@]:-}"; do + kill -9 "$pid" >/dev/null 2>&1 || true + done + for port in "${RUN_PORTS[@]:-}"; do + kill_port "$port" + done +} + +trap cleanup_all EXIT INT TERM + +mkdir -p "$LOG_DIR" "$SAVE_DIR" +RUN_PORTS+=("$API_PORT" "$STUDENT_PORT" "$TEACHER_PORT") +kill_port "$API_PORT" +kill_port "$STUDENT_PORT" +kill_port "$TEACHER_PORT" + +log "Config:" +log " student=${STUDENT_MODEL}" +log " teacher=${TEACHER_MODEL}" +log " gpus student=${STUDENT_GPUS}, teacher=${TEACHER_GPUS}, trainer=${TRAINER_GPU}" +log " ports api=${API_PORT}, student=${STUDENT_PORT}, teacher=${TEACHER_PORT}" +log " logs=${LOG_DIR}" +log " saves=${SAVE_DIR}" + +# 1) Atropos API +start_process "run_api" "${LOG_DIR}/run_api.log" \ + uv run python -m atroposlib.cli.run_api --port "$API_PORT" +if [[ "$DRY_RUN" == "0" ]]; then + wait_for_http "http://localhost:${API_PORT}/info" 60 "run-api" +fi + +# 2) Student vLLM server +start_process "student_vllm" "${LOG_DIR}/student_vllm.log" \ + env CUDA_VISIBLE_DEVICES="$STUDENT_GPUS" \ + uv run python -m example_trainer.vllm_api_server \ + --model "$STUDENT_MODEL" \ + --port "$STUDENT_PORT" \ + --tensor-parallel-size "$STUDENT_TP" \ + --gpu-memory-utilization "$STUDENT_GPU_MEMORY_UTILIZATION" \ + --max-model-len "$MAX_MODEL_LEN" \ + --dtype "$DTYPE" +if [[ "$DRY_RUN" == "0" ]]; then + wait_for_http "http://localhost:${STUDENT_PORT}/health" 420 "student vLLM" +fi + +# 3) Teacher vLLM server +start_process "teacher_vllm" "${LOG_DIR}/teacher_vllm.log" \ + env CUDA_VISIBLE_DEVICES="$TEACHER_GPUS" \ + uv run python -m example_trainer.vllm_api_server \ + --model "$TEACHER_MODEL" \ + --port "$TEACHER_PORT" \ + --tensor-parallel-size "$TEACHER_TP" \ + --gpu-memory-utilization "$TEACHER_GPU_MEMORY_UTILIZATION" \ + --max-model-len "$MAX_MODEL_LEN" \ + --dtype "$DTYPE" +if [[ "$DRY_RUN" == "0" ]]; then + wait_for_http "http://localhost:${TEACHER_PORT}/health" 600 "teacher vLLM" +fi + +# 4) Teacher-distill GSM8K env +start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \ + uv run python environments/gsm8k_server_teacher_distill.py serve \ + --env.group_size "$ENV_GROUP_SIZE" \ + --env.batch_size "$ENV_BATCH_SIZE" \ + --env.total_steps "$ENV_TOTAL_STEPS" \ + --env.steps_per_eval "$ENV_STEPS_PER_EVAL" \ + --env.max_num_workers_per_node "$ENV_MAX_WORKERS_PER_NODE" \ + --env.max_token_length "$MAX_MODEL_LEN" \ + --env.rollout_server_url "http://localhost:${API_PORT}" \ + --env.use_wandb true \ + --env.wandb_name "gsm8k-teacher-distill" \ + --env.distillation_enabled true \ + --env.teacher_enabled true \ + --env.teacher_base_url "http://localhost:${TEACHER_PORT}/v1" \ + --env.teacher_model_name "$TEACHER_MODEL" \ + --env.teacher_top_k "$TEACHER_TOP_K" \ + --openai.api_key "dummy" \ + --openai.base_url "http://localhost:${STUDENT_PORT}/v1" \ + --openai.model_name "$STUDENT_MODEL" \ + --openai.tokenizer_name "$STUDENT_MODEL" \ + --openai.server_type vllm + +log "All services launched." +log "Run logs:" +log " ${LOG_DIR}/run_api.log" +log " ${LOG_DIR}/student_vllm.log" +log " ${LOG_DIR}/teacher_vllm.log" +log " ${LOG_DIR}/env.log" + +# 5) Trainer (foreground, primary output) +if [[ "$DRY_RUN" == "1" ]]; then + log "[DRY RUN] trainer command:" + printf ' ' + printf '%q ' env CUDA_VISIBLE_DEVICES="$TRAINER_GPU" \ + uv run python -m example_trainer.grpo \ + --model-name "$STUDENT_MODEL" \ + --weight-bridge-mode none \ + --device cuda:0 \ + --save-path "$SAVE_DIR" \ + --atropos-url "http://localhost:${API_PORT}" \ + --training-steps "$TRAINING_STEPS" \ + --batch-size "$BATCH_SIZE" \ + --gradient-accumulation-steps "$GRAD_ACCUM" \ + --warmup-steps "$WARMUP_STEPS" \ + --lr "$LR" \ + --clip-eps "$CLIP_EPS" \ + --distill-enabled \ + --distill-coef "$DISTILL_COEF" \ + --distill-temperature "$DISTILL_TEMPERATURE" + printf '\n' + exit 0 +fi + +log "Starting trainer in foreground..." +env CUDA_VISIBLE_DEVICES="$TRAINER_GPU" \ + uv run python -m example_trainer.grpo \ + --model-name "$STUDENT_MODEL" \ + --weight-bridge-mode none \ + --device cuda:0 \ + --save-path "$SAVE_DIR" \ + --atropos-url "http://localhost:${API_PORT}" \ + --training-steps "$TRAINING_STEPS" \ + --batch-size "$BATCH_SIZE" \ + --gradient-accumulation-steps "$GRAD_ACCUM" \ + --warmup-steps "$WARMUP_STEPS" \ + --lr "$LR" \ + --clip-eps "$CLIP_EPS" \ + --distill-enabled \ + --distill-coef "$DISTILL_COEF" \ + --distill-temperature "$DISTILL_TEMPERATURE" | tee "${LOG_DIR}/trainer.log" + +log "Training finished." diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py index 4c9e2893..cc96cee5 100644 --- a/example_trainer/trainers.py +++ b/example_trainer/trainers.py @@ -170,6 +170,8 @@ def train_legacy(config: TrainingConfig): batch_data[:4] ) inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None + distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None + distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) @@ -192,6 +194,8 @@ def train_legacy(config: TrainingConfig): config, step_idx=step, inference_logprob_batches=inference_logprob_batches, + distill_token_id_batches=distill_token_id_batches, + distill_logprob_batches=distill_logprob_batches, ) step_time = time.time() - step_start benchmark_stats["step_times"].append(step_time) @@ -324,6 +328,8 @@ def train_shared_vllm(config: TrainingConfig): batch_data[:4] ) inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None + distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None + distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) @@ -339,6 +345,8 @@ def train_shared_vllm(config: TrainingConfig): config, step_idx=step, inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation + distill_token_id_batches=distill_token_id_batches, + distill_logprob_batches=distill_logprob_batches, ) step_time = time.time() - step_start benchmark_stats["step_times"].append(step_time) @@ -484,6 +492,8 @@ def train_lora(config: TrainingConfig): batch_data[:4] ) inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None + distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None + distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) @@ -499,6 +509,8 @@ def train_lora(config: TrainingConfig): config, step_idx=step, inference_logprob_batches=inference_logprob_batches, + distill_token_id_batches=distill_token_id_batches, + distill_logprob_batches=distill_logprob_batches, ) step_time = time.time() - step_start benchmark_stats["step_times"].append(step_time) @@ -706,6 +718,8 @@ def train_lora_restart(config: TrainingConfig): batch_data[:4] ) inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None + distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None + distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) @@ -721,6 +735,8 @@ def train_lora_restart(config: TrainingConfig): config, step_idx=step, inference_logprob_batches=inference_logprob_batches, + distill_token_id_batches=distill_token_id_batches, + distill_logprob_batches=distill_logprob_batches, ) step_time = time.time() - step_start benchmark_stats["step_times"].append(step_time) diff --git a/example_trainer/training.py b/example_trainer/training.py index 035d45c7..c5b739e9 100644 --- a/example_trainer/training.py +++ b/example_trainer/training.py @@ -70,6 +70,11 @@ def compute_grpo_loss( gradient_accumulation_steps: int, inference_logprobs: Optional[torch.Tensor] = None, clip_eps: float = 0.2, + distill_token_ids: Optional[torch.Tensor] = None, + distill_logprobs: Optional[torch.Tensor] = None, + distill_enabled: bool = False, + distill_coef: float = 0.0, + distill_temperature: float = 1.0, ) -> Tuple[torch.Tensor, dict]: """ Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch. @@ -125,6 +130,9 @@ def compute_grpo_loss( logprob_diff_abs_mean = 0.0 logprob_diff_max = 0.0 + distill_loss_value = torch.tensor(0.0, device=logp_per_token.device) + distill_token_count = 0.0 + # === GRPO/PPO Loss Computation === if inference_logprobs is not None: # Move inference logprobs to correct device/dtype @@ -187,7 +195,23 @@ def compute_grpo_loss( # Average over tokens, then over batch policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean() - total_loss = policy_loss / gradient_accumulation_steps + if ( + distill_enabled + and distill_coef > 0 + and distill_token_ids is not None + and distill_logprobs is not None + ): + distill_loss_value, distill_token_count = compute_distillation_loss( + logits=scaled_logits, + labels=labels, + distill_token_ids=distill_token_ids.to(logits.device), + distill_logprobs=distill_logprobs.to(logits.device, logits.dtype), + temperature=max(1e-6, float(distill_temperature)), + ) + + total_loss = (policy_loss + distill_coef * distill_loss_value) / ( + gradient_accumulation_steps + ) # Compute metrics for logging with torch.no_grad(): @@ -253,11 +277,66 @@ def compute_grpo_loss( "logprob_diff_mean": logprob_diff_mean, "logprob_diff_abs_mean": logprob_diff_abs_mean, "logprob_diff_max": logprob_diff_max, + "distill_loss": ( + distill_loss_value.item() + if torch.is_tensor(distill_loss_value) + else float(distill_loss_value) + ), + "distill_token_count": distill_token_count, } return total_loss, metrics +def compute_distillation_loss( + logits: torch.Tensor, + labels: torch.Tensor, + distill_token_ids: torch.Tensor, + distill_logprobs: torch.Tensor, + temperature: float = 1.0, +) -> Tuple[torch.Tensor, float]: + """ + Compute token-level distillation loss from teacher top-k prompt logprobs. + + Args: + logits: Student logits [batch, seq_len, vocab] + labels: Labels [batch, seq_len], -100 for masked positions + distill_token_ids: Teacher top-k token IDs [batch, seq_len, k], -1 padded + distill_logprobs: Teacher top-k logprobs [batch, seq_len, k], very negative padded + temperature: Distillation temperature + + Returns: + Tuple of (distillation loss scalar, valid token count) + """ + if distill_token_ids.dim() != 3 or distill_logprobs.dim() != 3: + return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0 + + if distill_token_ids.shape[:2] != labels.shape or distill_logprobs.shape != distill_token_ids.shape: + return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0 + + temp = max(1e-6, float(temperature)) + student_log_probs = F.log_softmax(logits / temp, dim=-1) + + valid_ids = distill_token_ids >= 0 + label_mask = labels != -100 + valid_pos = label_mask & valid_ids.any(dim=-1) + if not valid_pos.any(): + return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0 + + gather_ids = distill_token_ids.clamp_min(0).long() + student_logp_topk = torch.gather(student_log_probs, dim=-1, index=gather_ids) + + masked_teacher_logprobs = distill_logprobs.masked_fill(~valid_ids, -1e9) + teacher_probs = F.softmax(masked_teacher_logprobs / temp, dim=-1) + + per_token_loss = -(teacher_probs * student_logp_topk).sum(dim=-1) + per_token_loss = per_token_loss * valid_pos.to(per_token_loss.dtype) + + token_count = valid_pos.sum().item() + loss = per_token_loss.sum() / valid_pos.sum().clamp_min(1).to(per_token_loss.dtype) + return loss, float(token_count) + + def run_training_step( model: torch.nn.Module, optimizer: torch.optim.Optimizer, @@ -268,6 +347,8 @@ def run_training_step( config: TrainingConfig, step_idx: int, inference_logprob_batches: Optional[List[torch.Tensor]] = None, + distill_token_id_batches: Optional[List[torch.Tensor]] = None, + distill_logprob_batches: Optional[List[torch.Tensor]] = None, ) -> dict: """ Run a single training step with gradient accumulation. @@ -302,6 +383,8 @@ def run_training_step( total_logprob_diff_mean = 0.0 total_logprob_diff_abs_mean = 0.0 total_logprob_diff_max = 0.0 + total_distill_loss = 0.0 + total_distill_tokens = 0.0 grad_norm = 0.0 all_training_logprobs: List[torch.Tensor] = [] all_inference_logprobs: List[torch.Tensor] = [] @@ -335,6 +418,16 @@ def run_training_step( inference_logprob_batches ): inf_logprobs = inference_logprob_batches[batch_idx] + distill_ids = None + if distill_token_id_batches is not None and batch_idx < len( + distill_token_id_batches + ): + distill_ids = distill_token_id_batches[batch_idx] + distill_lps = None + if distill_logprob_batches is not None and batch_idx < len( + distill_logprob_batches + ): + distill_lps = distill_logprob_batches[batch_idx] loss, metrics = compute_grpo_loss( model, @@ -345,6 +438,11 @@ def run_training_step( config.gradient_accumulation_steps, inference_logprobs=inf_logprobs, clip_eps=clip_eps, + distill_token_ids=distill_ids, + distill_logprobs=distill_lps, + distill_enabled=bool(getattr(config, "distill_enabled", False)), + distill_coef=float(getattr(config, "distill_coef", 0.0)), + distill_temperature=float(getattr(config, "distill_temperature", 1.0)), ) loss.backward() @@ -364,6 +462,8 @@ def run_training_step( total_logprob_diff_max = max( total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0) ) + total_distill_loss += metrics.get("distill_loss", 0.0) + total_distill_tokens += metrics.get("distill_token_count", 0.0) # Collect logprobs for alignment monitoring if "training_logprobs" in metrics and metrics["training_logprobs"] is not None: @@ -399,6 +499,8 @@ def run_training_step( # GRPO-specific metrics (averaged over batches) "mean_ratio": total_mean_ratio / num_batches, "clipped_fraction": total_clipped_fraction / num_batches, + "distill_loss": total_distill_loss / num_batches, + "distill_token_count": total_distill_tokens, } # Compute logprob alignment stats for monitoring @@ -472,6 +574,12 @@ def log_metrics( clipped_frac = metrics.get("clipped_fraction", 0) print(f" GRPO: ratio={mean_ratio:.3f}, clipped={clipped_frac*100:.1f}%") + if metrics.get("distill_token_count", 0) > 0: + print( + " Distill: " + f"loss={metrics.get('distill_loss', 0.0):.4f}, " + f"tokens={int(metrics.get('distill_token_count', 0))}" + ) # Advantage distribution if "pos_count" in metrics or "neg_count" in metrics: @@ -494,6 +602,8 @@ def log_metrics( # GRPO-specific metrics "grpo/mean_ratio": mean_ratio, "grpo/clipped_fraction": clipped_frac, + "distill/loss": metrics.get("distill_loss", 0.0), + "distill/token_count": metrics.get("distill_token_count", 0.0), } # Add timing metrics if present for key in [ From d5ca760f367a3af32e30aa0b7c4cea775f2b3755 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Sat, 7 Mar 2026 17:50:14 -0500 Subject: [PATCH 03/64] command change --- example_trainer/README.md | 2 +- .../run_gsm8k_teacher_distill_single_terminal.sh | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/example_trainer/README.md b/example_trainer/README.md index a6820614..ddb96b8a 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -550,7 +550,7 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands ### WandB Logging -```bash + ```bash --use-wandb \ --wandb-project "my-grpo-training" \ --wandb-group "hermes-8b-gsm8k" diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 797f2cb5..dac797a3 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -34,6 +34,7 @@ ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" LAUNCH_DIR="$PWD" cd "$ROOT_DIR" +PYTHON_BIN="${PYTHON_BIN:-python3}" STUDENT_MODEL="${STUDENT_MODEL:-Qwen/Qwen3-4B-Instruct-2507-FP8}" TEACHER_MODEL="${TEACHER_MODEL:-Qwen/Qwen3-30B-A3B-Instruct-2507}" @@ -159,15 +160,15 @@ log " saves=${SAVE_DIR}" # 1) Atropos API start_process "run_api" "${LOG_DIR}/run_api.log" \ - uv run python -m atroposlib.cli.run_api --port "$API_PORT" + run-api --port "$API_PORT" if [[ "$DRY_RUN" == "0" ]]; then - wait_for_http "http://localhost:${API_PORT}/info" 60 "run-api" + wait_for_http "http://localhost:${API_PORT}/info" 180 "run-api" fi # 2) Student vLLM server start_process "student_vllm" "${LOG_DIR}/student_vllm.log" \ env CUDA_VISIBLE_DEVICES="$STUDENT_GPUS" \ - uv run python -m example_trainer.vllm_api_server \ + "$PYTHON_BIN" -m example_trainer.vllm_api_server \ --model "$STUDENT_MODEL" \ --port "$STUDENT_PORT" \ --tensor-parallel-size "$STUDENT_TP" \ @@ -181,7 +182,7 @@ fi # 3) Teacher vLLM server start_process "teacher_vllm" "${LOG_DIR}/teacher_vllm.log" \ env CUDA_VISIBLE_DEVICES="$TEACHER_GPUS" \ - uv run python -m example_trainer.vllm_api_server \ + "$PYTHON_BIN" -m example_trainer.vllm_api_server \ --model "$TEACHER_MODEL" \ --port "$TEACHER_PORT" \ --tensor-parallel-size "$TEACHER_TP" \ @@ -194,7 +195,7 @@ fi # 4) Teacher-distill GSM8K env start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \ - uv run python environments/gsm8k_server_teacher_distill.py serve \ + "$PYTHON_BIN" environments/gsm8k_server_teacher_distill.py serve \ --env.group_size "$ENV_GROUP_SIZE" \ --env.batch_size "$ENV_BATCH_SIZE" \ --env.total_steps "$ENV_TOTAL_STEPS" \ @@ -227,7 +228,7 @@ if [[ "$DRY_RUN" == "1" ]]; then log "[DRY RUN] trainer command:" printf ' ' printf '%q ' env CUDA_VISIBLE_DEVICES="$TRAINER_GPU" \ - uv run python -m example_trainer.grpo \ + "$PYTHON_BIN" -m example_trainer.grpo \ --model-name "$STUDENT_MODEL" \ --weight-bridge-mode none \ --device cuda:0 \ @@ -248,7 +249,7 @@ fi log "Starting trainer in foreground..." env CUDA_VISIBLE_DEVICES="$TRAINER_GPU" \ - uv run python -m example_trainer.grpo \ + "$PYTHON_BIN" -m example_trainer.grpo \ --model-name "$STUDENT_MODEL" \ --weight-bridge-mode none \ --device cuda:0 \ From ad364ac77171fbced7531948b3dd1baf8cf601f8 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Sun, 8 Mar 2026 00:06:20 -0500 Subject: [PATCH 04/64] increase timeout cause vllm is super slow all of a sudden --- example_trainer/run_gsm8k_teacher_distill_single_terminal.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index dac797a3..566c05b9 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -190,7 +190,7 @@ start_process "teacher_vllm" "${LOG_DIR}/teacher_vllm.log" \ --max-model-len "$MAX_MODEL_LEN" \ --dtype "$DTYPE" if [[ "$DRY_RUN" == "0" ]]; then - wait_for_http "http://localhost:${TEACHER_PORT}/health" 600 "teacher vLLM" + wait_for_http "http://localhost:${TEACHER_PORT}/health" 1800 "teacher vLLM" fi # 4) Teacher-distill GSM8K env From 985311eb946eda547d3b2d1d2d3205200e4a0585 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Sun, 8 Mar 2026 16:31:09 -0400 Subject: [PATCH 05/64] trial --- example_trainer/run_gsm8k_teacher_distill_single_terminal.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 566c05b9..f06b62fa 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -205,7 +205,6 @@ start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \ --env.rollout_server_url "http://localhost:${API_PORT}" \ --env.use_wandb true \ --env.wandb_name "gsm8k-teacher-distill" \ - --env.distillation_enabled true \ --env.teacher_enabled true \ --env.teacher_base_url "http://localhost:${TEACHER_PORT}/v1" \ --env.teacher_model_name "$TEACHER_MODEL" \ From e5633527ba065630dbc2497d5d0cfa16f4dc3d4f Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Sun, 8 Mar 2026 17:12:29 -0400 Subject: [PATCH 06/64] quicker training --- ...n_gsm8k_teacher_distill_single_terminal.sh | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index f06b62fa..9d8b1729 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -18,7 +18,7 @@ set -euo pipefail # TEACHER_MODEL="Qwen/Qwen3-30B-A3B-Instruct-2507" # STUDENT_GPUS="0,1" # TEACHER_GPUS="4,5,6,7" -# TRAINER_GPU="2" +# TRAINER_GPUS="0,1" # STUDENT_TP=2 # TEACHER_TP=4 # API_PORT=8002 @@ -40,7 +40,7 @@ TEACHER_MODEL="${TEACHER_MODEL:-Qwen/Qwen3-30B-A3B-Instruct-2507}" STUDENT_GPUS="${STUDENT_GPUS:-0,1}" TEACHER_GPUS="${TEACHER_GPUS:-4,5,6,7}" -TRAINER_GPU="${TRAINER_GPU:-2}" +TRAINER_GPUS="${TRAINER_GPUS:-$STUDENT_GPUS}" STUDENT_TP="${STUDENT_TP:-2}" TEACHER_TP="${TEACHER_TP:-4}" @@ -65,6 +65,7 @@ TEACHER_GPU_MEMORY_UTILIZATION="${TEACHER_GPU_MEMORY_UTILIZATION:-0.92}" DTYPE="${DTYPE:-bfloat16}" SAVE_DIR="${SAVE_DIR:-${LAUNCH_DIR}/saves/gsm8k_teacher_distill}" LOG_DIR="${LOG_DIR:-${LAUNCH_DIR}/logs/gsm8k_teacher_distill}" +BRIDGE_DIR="${BRIDGE_DIR:-${LOG_DIR}/bridge}" DRY_RUN="${DRY_RUN:-0}" ENV_GROUP_SIZE="${ENV_GROUP_SIZE:-4}" @@ -144,7 +145,7 @@ cleanup_all() { trap cleanup_all EXIT INT TERM -mkdir -p "$LOG_DIR" "$SAVE_DIR" +mkdir -p "$LOG_DIR" "$SAVE_DIR" "$BRIDGE_DIR" RUN_PORTS+=("$API_PORT" "$STUDENT_PORT" "$TEACHER_PORT") kill_port "$API_PORT" kill_port "$STUDENT_PORT" @@ -153,10 +154,11 @@ kill_port "$TEACHER_PORT" log "Config:" log " student=${STUDENT_MODEL}" log " teacher=${TEACHER_MODEL}" -log " gpus student=${STUDENT_GPUS}, teacher=${TEACHER_GPUS}, trainer=${TRAINER_GPU}" +log " gpus student=${STUDENT_GPUS}, teacher=${TEACHER_GPUS}, trainer=${TRAINER_GPUS}" log " ports api=${API_PORT}, student=${STUDENT_PORT}, teacher=${TEACHER_PORT}" log " logs=${LOG_DIR}" log " saves=${SAVE_DIR}" +log " bridge=${BRIDGE_DIR}" # 1) Atropos API start_process "run_api" "${LOG_DIR}/run_api.log" \ @@ -167,14 +169,15 @@ fi # 2) Student vLLM server start_process "student_vllm" "${LOG_DIR}/student_vllm.log" \ - env CUDA_VISIBLE_DEVICES="$STUDENT_GPUS" \ + env CUDA_VISIBLE_DEVICES="$STUDENT_GPUS" VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR="$BRIDGE_DIR" \ "$PYTHON_BIN" -m example_trainer.vllm_api_server \ --model "$STUDENT_MODEL" \ --port "$STUDENT_PORT" \ --tensor-parallel-size "$STUDENT_TP" \ --gpu-memory-utilization "$STUDENT_GPU_MEMORY_UTILIZATION" \ --max-model-len "$MAX_MODEL_LEN" \ - --dtype "$DTYPE" + --dtype "$DTYPE" \ + --enforce-eager if [[ "$DRY_RUN" == "0" ]]; then wait_for_http "http://localhost:${STUDENT_PORT}/health" 420 "student vLLM" fi @@ -226,13 +229,15 @@ log " ${LOG_DIR}/env.log" if [[ "$DRY_RUN" == "1" ]]; then log "[DRY RUN] trainer command:" printf ' ' - printf '%q ' env CUDA_VISIBLE_DEVICES="$TRAINER_GPU" \ + printf '%q ' env CUDA_VISIBLE_DEVICES="$TRAINER_GPUS" \ "$PYTHON_BIN" -m example_trainer.grpo \ --model-name "$STUDENT_MODEL" \ - --weight-bridge-mode none \ + --weight-bridge-mode shared_vllm \ --device cuda:0 \ --save-path "$SAVE_DIR" \ --atropos-url "http://localhost:${API_PORT}" \ + --vllm-port "$STUDENT_PORT" \ + --vllm-config-path "${BRIDGE_DIR}/vllm_bridge_config.json" \ --training-steps "$TRAINING_STEPS" \ --batch-size "$BATCH_SIZE" \ --gradient-accumulation-steps "$GRAD_ACCUM" \ @@ -247,13 +252,15 @@ if [[ "$DRY_RUN" == "1" ]]; then fi log "Starting trainer in foreground..." -env CUDA_VISIBLE_DEVICES="$TRAINER_GPU" \ +env CUDA_VISIBLE_DEVICES="$TRAINER_GPUS" \ "$PYTHON_BIN" -m example_trainer.grpo \ --model-name "$STUDENT_MODEL" \ - --weight-bridge-mode none \ + --weight-bridge-mode shared_vllm \ --device cuda:0 \ --save-path "$SAVE_DIR" \ --atropos-url "http://localhost:${API_PORT}" \ + --vllm-port "$STUDENT_PORT" \ + --vllm-config-path "${BRIDGE_DIR}/vllm_bridge_config.json" \ --training-steps "$TRAINING_STEPS" \ --batch-size "$BATCH_SIZE" \ --gradient-accumulation-steps "$GRAD_ACCUM" \ From 81f90a67b5a35bd2e50095369669064000788d9e Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Sun, 8 Mar 2026 18:09:26 -0400 Subject: [PATCH 07/64] forgot something easy --- ...n_gsm8k_teacher_distill_single_terminal.sh | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 9d8b1729..5170a45a 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -16,10 +16,10 @@ set -euo pipefail # Optional overrides: # STUDENT_MODEL="Qwen/Qwen3-4B-Instruct-2507-FP8" # TEACHER_MODEL="Qwen/Qwen3-30B-A3B-Instruct-2507" -# STUDENT_GPUS="0,1" +# STUDENT_GPUS="0" # TEACHER_GPUS="4,5,6,7" -# TRAINER_GPUS="0,1" -# STUDENT_TP=2 +# TRAINER_GPUS="0" +# STUDENT_TP=1 # TEACHER_TP=4 # API_PORT=8002 # STUDENT_PORT=9001 @@ -38,11 +38,11 @@ PYTHON_BIN="${PYTHON_BIN:-python3}" STUDENT_MODEL="${STUDENT_MODEL:-Qwen/Qwen3-4B-Instruct-2507-FP8}" TEACHER_MODEL="${TEACHER_MODEL:-Qwen/Qwen3-30B-A3B-Instruct-2507}" -STUDENT_GPUS="${STUDENT_GPUS:-0,1}" +STUDENT_GPUS="${STUDENT_GPUS:-0}" TEACHER_GPUS="${TEACHER_GPUS:-4,5,6,7}" TRAINER_GPUS="${TRAINER_GPUS:-$STUDENT_GPUS}" -STUDENT_TP="${STUDENT_TP:-2}" +STUDENT_TP="${STUDENT_TP:-1}" TEACHER_TP="${TEACHER_TP:-4}" API_PORT="${API_PORT:-8002}" @@ -160,6 +160,20 @@ log " logs=${LOG_DIR}" log " saves=${SAVE_DIR}" log " bridge=${BRIDGE_DIR}" +# Shared-vLLM attach path currently expects the student server to expose +# unsharded weights. Keep the student on TP=1 and the trainer on the same GPU set. +if [[ "$STUDENT_TP" != "1" ]]; then + log "ERROR: shared_vllm teacher-distill runner currently requires STUDENT_TP=1." + log " The current attach path does not support TP-sharded student bridge weights." + exit 2 +fi + +if [[ "$TRAINER_GPUS" != "$STUDENT_GPUS" ]]; then + log "ERROR: TRAINER_GPUS must match STUDENT_GPUS for shared_vllm mode." + log " Got student=${STUDENT_GPUS}, trainer=${TRAINER_GPUS}" + exit 2 +fi + # 1) Atropos API start_process "run_api" "${LOG_DIR}/run_api.log" \ run-api --port "$API_PORT" From 4f33ab8bf4d69c49aa430a67e05352467897cb92 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 9 Mar 2026 10:18:13 -0400 Subject: [PATCH 08/64] apparently not so easy --- environments/gsm8k_server.py | 17 +++++++++++++++++ ...n_gsm8k_teacher_distill_single_terminal.sh | 8 ++++++-- example_trainer/vllm_api_server.py | 19 +++++++++++++++++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 87823526..295dd8df 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -1,3 +1,4 @@ +import logging import random import time from typing import Dict, List, Optional, Tuple, TypedDict, Union @@ -31,6 +32,8 @@ It is important that you provide your answer in the correct format. If you do not, you will not receive credit for your answer. So please end your answer with \\boxed{your answer here}""" +logger = logging.getLogger(__name__) + class GSM8kRow(TypedDict): question: str @@ -232,6 +235,12 @@ class GSM8kEnv(BaseEnv): ) async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + logger.info( + "gsm8k collect_trajectories start group_size=%s max_tokens=%s question_chars=%s", + self.config.group_size, + self.config.max_token_length, + len(item["question"]), + ) chat_completions = await managed.chat_completion( messages=[{"role": "system", "content": system_prompt}, user_message], @@ -239,9 +248,17 @@ class GSM8kEnv(BaseEnv): max_tokens=self.config.max_token_length, temperature=1.0, ) + logger.info( + "gsm8k collect_trajectories completion_received choices=%s", + len(chat_completions.choices), + ) state = managed.get_state() nodes = state["nodes"] + logger.info( + "gsm8k collect_trajectories managed_state_nodes=%s", + len(nodes), + ) to_score = list() to_backlog = list() diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 5170a45a..db9402f9 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -56,6 +56,7 @@ LR="${LR:-1e-5}" WARMUP_STEPS="${WARMUP_STEPS:-0}" CLIP_EPS="${CLIP_EPS:-0.2}" MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}" +ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-4096}" DISTILL_COEF="${DISTILL_COEF:-0.2}" DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" TEACHER_TOP_K="${TEACHER_TOP_K:-8}" @@ -72,7 +73,8 @@ ENV_GROUP_SIZE="${ENV_GROUP_SIZE:-4}" ENV_BATCH_SIZE="${ENV_BATCH_SIZE:-16}" ENV_TOTAL_STEPS="${ENV_TOTAL_STEPS:-200}" ENV_STEPS_PER_EVAL="${ENV_STEPS_PER_EVAL:-50}" -ENV_MAX_WORKERS_PER_NODE="${ENV_MAX_WORKERS_PER_NODE:-8}" +ENV_MAX_WORKERS_PER_NODE="${ENV_MAX_WORKERS_PER_NODE:-1}" +ENV_WORKER_TIMEOUT="${ENV_WORKER_TIMEOUT:-1800}" RUN_PIDS=() RUN_PORTS=() @@ -159,6 +161,7 @@ log " ports api=${API_PORT}, student=${STUDENT_PORT}, teacher=${TEACHER_PORT}" log " logs=${LOG_DIR}" log " saves=${SAVE_DIR}" log " bridge=${BRIDGE_DIR}" +log " env max_token_length=${ENV_MAX_TOKEN_LENGTH}, env workers=${ENV_MAX_WORKERS_PER_NODE}, env worker_timeout=${ENV_WORKER_TIMEOUT}" # Shared-vLLM attach path currently expects the student server to expose # unsharded weights. Keep the student on TP=1 and the trainer on the same GPU set. @@ -218,7 +221,8 @@ start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \ --env.total_steps "$ENV_TOTAL_STEPS" \ --env.steps_per_eval "$ENV_STEPS_PER_EVAL" \ --env.max_num_workers_per_node "$ENV_MAX_WORKERS_PER_NODE" \ - --env.max_token_length "$MAX_MODEL_LEN" \ + --env.max_token_length "$ENV_MAX_TOKEN_LENGTH" \ + --env.worker_timeout "$ENV_WORKER_TIMEOUT" \ --env.rollout_server_url "http://localhost:${API_PORT}" \ --env.use_wandb true \ --env.wandb_name "gsm8k-teacher-distill" \ diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py index 2846f14f..da5b6608 100644 --- a/example_trainer/vllm_api_server.py +++ b/example_trainer/vllm_api_server.py @@ -296,6 +296,17 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: if engine is None: raise HTTPException(status_code=503, detail="Engine not initialized") + request_preview = { + "has_prompt": "prompt" in request_dict, + "n": request_dict.get("n"), + "max_tokens": request_dict.get("max_tokens"), + "temperature": request_dict.get("temperature"), + "top_p": request_dict.get("top_p"), + "logprobs": request_dict.get("logprobs"), + "prompt_logprobs": request_dict.get("prompt_logprobs"), + } + logger.info("POST /generate received %s", request_preview) + prompt = request_dict.pop("prompt") stream = request_dict.pop("stream", False) request_dict["output_kind"] = RequestOutputKind.FINAL_ONLY @@ -325,6 +336,7 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: async for request_output in results_generator: final_output = request_output except asyncio.CancelledError: + logger.warning("POST /generate cancelled request_id=%s", request_id) return Response(status_code=499) assert final_output is not None @@ -348,6 +360,13 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: ret["prompt_token_ids"] = final_output.prompt_token_ids ret["token_ids"] = [x.token_ids for x in final_output.outputs] + logger.info( + "POST /generate completed request_id=%s outputs=%s finish_reasons=%s", + request_id, + len(text_outputs), + finish_reasons, + ) + return JSONResponse(ret) From bb2736db4ebc5f4da96ea5891aa48721b72b9dc5 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 9 Mar 2026 21:25:58 -0400 Subject: [PATCH 09/64] next --- .../envs/server_handling/managed_server.py | 17 ++++++++++++ .../envs/server_handling/vllm_server.py | 27 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index 9d46f265..a85ac0d3 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -447,14 +447,31 @@ class ManagedServer: if not self.track_tree and self.tokenizer is not None: input_ids = self._compute_input_ids(prompt, extending_node) completion_kwargs["input_ids"] = input_ids + logger.info( + "managed_server chat_completion prepared input_ids=%s extending=%s", + len(input_ids), + extending_node is not None, + ) + else: + logger.info( + "managed_server chat_completion using prompt passthrough track_tree=%s tokenizer=%s", + self.track_tree, + self.tokenizer is not None, + ) # Call the tokens and logprobs wrapper directly + logger.info("managed_server chat_completion calling backend completion wrapper") ( prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons, ) = await self.server.tokens_and_logprobs_completion(**completion_kwargs) + logger.info( + "managed_server chat_completion backend returned prompt_tokens=%s outputs=%s", + len(prompt_tokens), + len(output_tokens_list), + ) # Track each completion and build choices n = len(output_tokens_list) diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 3c35bebb..107ad2d1 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -2,6 +2,7 @@ # see example_trainer/vllm_api_server.py for an example import asyncio +import logging import warnings from typing import Any, Dict, List, Tuple @@ -19,6 +20,8 @@ from atroposlib.envs.server_handling.server_baseline import ( ReasoningConfig, ) +logger = logging.getLogger(__name__) + class VLLMServer(APIServer): """ @@ -190,6 +193,14 @@ class VLLMServer(APIServer): # Prepare request for VLLM native API request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0} request_data.update(kwargs) + logger.info( + "vllm_server completion POST start base_url=%s prompt_tokens=%s n=%s max_tokens=%s temperature=%s", + self.config.base_url, + len(prompt_tokens), + request_data.get("n"), + request_data.get("max_tokens"), + request_data.get("temperature"), + ) # Make async request to VLLM /generate endpoint async with aiohttp.ClientSession() as session: @@ -205,6 +216,11 @@ class VLLMServer(APIServer): ) as response: response.raise_for_status() results = await response.json() + logger.info( + "vllm_server completion POST done outputs=%s finish_reasons=%s", + len(results.get("logprobs", [])), + len(results.get("finish_reasons", [])), + ) output_tokens_list = [] output_logprobs_list = [] finish_reasons_list = [] @@ -314,6 +330,13 @@ class VLLMServer(APIServer): request_data["temperature"] = 0.0 request_data["top_p"] = 1.0 request_data.setdefault("max_tokens", 1) + logger.info( + "vllm_server get_logprobs POST start base_url=%s prompt_tokens=%s top_k=%s max_tokens=%s", + self.config.base_url, + len(prompt_tokens), + top_k, + request_data.get("max_tokens"), + ) async with aiohttp.ClientSession() as session: async with session.post( @@ -328,6 +351,10 @@ class VLLMServer(APIServer): ) as response: response.raise_for_status() results = await response.json() + logger.info( + "vllm_server get_logprobs POST done prompt_logprobs_present=%s", + results.get("prompt_logprobs") is not None, + ) raw_prompt_logprobs = results.get("prompt_logprobs") if raw_prompt_logprobs is None: From 64794e7c721502b24c0e4ebbe10ef329147f0d45 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 10 Mar 2026 11:12:48 -0400 Subject: [PATCH 10/64] sneaky bug --- atroposlib/envs/server_handling/managed_server.py | 8 ++++---- atroposlib/envs/server_handling/openai_server.py | 4 ++-- atroposlib/envs/server_handling/vllm_server.py | 8 ++++---- environments/gsm8k_server.py | 6 +++--- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index a85ac0d3..af74fa55 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -447,27 +447,27 @@ class ManagedServer: if not self.track_tree and self.tokenizer is not None: input_ids = self._compute_input_ids(prompt, extending_node) completion_kwargs["input_ids"] = input_ids - logger.info( + logger.warning( "managed_server chat_completion prepared input_ids=%s extending=%s", len(input_ids), extending_node is not None, ) else: - logger.info( + logger.warning( "managed_server chat_completion using prompt passthrough track_tree=%s tokenizer=%s", self.track_tree, self.tokenizer is not None, ) # Call the tokens and logprobs wrapper directly - logger.info("managed_server chat_completion calling backend completion wrapper") + logger.warning("managed_server chat_completion calling backend completion wrapper") ( prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons, ) = await self.server.tokens_and_logprobs_completion(**completion_kwargs) - logger.info( + logger.warning( "managed_server chat_completion backend returned prompt_tokens=%s outputs=%s", len(prompt_tokens), len(output_tokens_list), diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index fecc5828..871ba9fb 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -210,7 +210,7 @@ def resolve_openai_configs( f"Error creating final OpenAI configuration from merged settings: {e}\n" f"Merged Dict: {openai_config_dict}" ) from e - server_configs = final_openai_config + server_configs = [final_openai_config] elif isinstance(default_server_configs, ServerBaseline): # Pure ServerBaseline (not APIServerConfig) - no CLI overrides possible logger.info("Using ServerBaseline configuration.") @@ -231,7 +231,7 @@ def resolve_openai_configs( ) from e if isinstance(default_server_configs, APIServerConfig): - server_configs = final_openai_config + server_configs = [final_openai_config] elif isinstance(default_server_configs, list): server_configs = [final_openai_config] else: diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 107ad2d1..18b8333e 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -193,7 +193,7 @@ class VLLMServer(APIServer): # Prepare request for VLLM native API request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0} request_data.update(kwargs) - logger.info( + logger.warning( "vllm_server completion POST start base_url=%s prompt_tokens=%s n=%s max_tokens=%s temperature=%s", self.config.base_url, len(prompt_tokens), @@ -216,7 +216,7 @@ class VLLMServer(APIServer): ) as response: response.raise_for_status() results = await response.json() - logger.info( + logger.warning( "vllm_server completion POST done outputs=%s finish_reasons=%s", len(results.get("logprobs", [])), len(results.get("finish_reasons", [])), @@ -330,7 +330,7 @@ class VLLMServer(APIServer): request_data["temperature"] = 0.0 request_data["top_p"] = 1.0 request_data.setdefault("max_tokens", 1) - logger.info( + logger.warning( "vllm_server get_logprobs POST start base_url=%s prompt_tokens=%s top_k=%s max_tokens=%s", self.config.base_url, len(prompt_tokens), @@ -351,7 +351,7 @@ class VLLMServer(APIServer): ) as response: response.raise_for_status() results = await response.json() - logger.info( + logger.warning( "vllm_server get_logprobs POST done prompt_logprobs_present=%s", results.get("prompt_logprobs") is not None, ) diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 295dd8df..5112d1a4 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -235,7 +235,7 @@ class GSM8kEnv(BaseEnv): ) async with self.server.managed_server(tokenizer=self.tokenizer) as managed: - logger.info( + logger.warning( "gsm8k collect_trajectories start group_size=%s max_tokens=%s question_chars=%s", self.config.group_size, self.config.max_token_length, @@ -248,14 +248,14 @@ class GSM8kEnv(BaseEnv): max_tokens=self.config.max_token_length, temperature=1.0, ) - logger.info( + logger.warning( "gsm8k collect_trajectories completion_received choices=%s", len(chat_completions.choices), ) state = managed.get_state() nodes = state["nodes"] - logger.info( + logger.warning( "gsm8k collect_trajectories managed_state_nodes=%s", len(nodes), ) From 09ad401995a8e4688a6561ed09847e948a2b2d8e Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 10 Mar 2026 11:18:05 -0400 Subject: [PATCH 11/64] sneaky bug logging --- .../envs/server_handling/openai_server.py | 13 ++++++++++++ .../envs/server_handling/server_manager.py | 20 +++++++++++++++++++ atroposlib/envs/teacher_distillation_env.py | 7 +++++++ 3 files changed, 40 insertions(+) diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index 871ba9fb..40a993fe 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -241,4 +241,17 @@ def resolve_openai_configs( ) server_configs = [final_openai_config] + if isinstance(server_configs, list): + logger.warning( + "resolve_openai_configs: returning list of %s config(s), URLs: %s", + len(server_configs), + [c.base_url for c in server_configs], + ) + else: + logger.warning( + "resolve_openai_configs: returning single %s (base_url=%s) — " + "ServerManager will use template mode!", + type(server_configs).__name__, + getattr(server_configs, "base_url", "N/A"), + ) return server_configs diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index b9c493f9..b24698a6 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -1,5 +1,6 @@ import asyncio import inspect +import logging import os import warnings from contextlib import asynccontextmanager @@ -25,6 +26,8 @@ from atroposlib.envs.server_handling.sglang_server import SGLangServer from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer from atroposlib.envs.server_handling.vllm_server import VLLMServer +logger = logging.getLogger(__name__) + class ServerManagerConfig(BaseModel): slurm: bool = Field( @@ -103,6 +106,13 @@ class ServerManager: self.servers = [ServerHarness()] return if not isinstance(configs, list): + logger.warning( + "ServerManager: configs is NOT a list (type=%s). " + "Using auto-generated URLs (template mode). " + "Passed base_url=%s will be IGNORED.", + type(configs).__name__, + getattr(configs, "base_url", "N/A"), + ) urls = [] if os.environ.get("SLURM_JOB_NODELIST", None) is not None: nodelist = ( @@ -145,11 +155,21 @@ class ServerManager: server_class(config, reasoning_config=reasoning_config) for config in openai_configs ] + logger.warning( + "ServerManager: auto-generated %s server(s) at URLs: %s", + len(self.servers), + [c.base_url for c in openai_configs], + ) elif not slurm: self.servers = [ server_class(config, reasoning_config=reasoning_config) for config in configs ] + logger.warning( + "ServerManager: using %s explicit config(s) at URLs: %s", + len(self.servers), + [c.base_url for c in configs], + ) else: nodelist = ( os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}') diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py index 5ff96bc7..54a66ff0 100644 --- a/atroposlib/envs/teacher_distillation_env.py +++ b/atroposlib/envs/teacher_distillation_env.py @@ -93,6 +93,13 @@ class TeacherDistillationEnv(BaseEnv, ABC): slurm=False, testing=False, ) + logger.warning( + "TeacherDistillationEnv: teacher server configured at %s " + "(model=%s, top_k=%s)", + config.teacher_base_url, + config.teacher_model_name, + config.teacher_top_k, + ) async def _fetch_teacher_for_sequence( self, token_ids: List[int], top_k: int From d1fd89f99296a57c17cae924b2d1d09a16c7ec7e Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 10 Mar 2026 11:43:40 -0400 Subject: [PATCH 12/64] non blocking test --- ...n_gsm8k_teacher_distill_single_terminal.sh | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index db9402f9..76acd6f4 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -61,8 +61,8 @@ DISTILL_COEF="${DISTILL_COEF:-0.2}" DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" TEACHER_TOP_K="${TEACHER_TOP_K:-8}" -STUDENT_GPU_MEMORY_UTILIZATION="${STUDENT_GPU_MEMORY_UTILIZATION:-0.90}" -TEACHER_GPU_MEMORY_UTILIZATION="${TEACHER_GPU_MEMORY_UTILIZATION:-0.92}" +STUDENT_GPU_MEMORY_UTILIZATION="${STUDENT_GPU_MEMORY_UTILIZATION:-0.95}" +TEACHER_GPU_MEMORY_UTILIZATION="${TEACHER_GPU_MEMORY_UTILIZATION:-0.95}" DTYPE="${DTYPE:-bfloat16}" SAVE_DIR="${SAVE_DIR:-${LAUNCH_DIR}/saves/gsm8k_teacher_distill}" LOG_DIR="${LOG_DIR:-${LAUNCH_DIR}/logs/gsm8k_teacher_distill}" @@ -269,8 +269,8 @@ if [[ "$DRY_RUN" == "1" ]]; then exit 0 fi -log "Starting trainer in foreground..." -env CUDA_VISIBLE_DEVICES="$TRAINER_GPUS" \ +start_process "trainer" "${LOG_DIR}/trainer.log" \ + env CUDA_VISIBLE_DEVICES="$TRAINER_GPUS" \ "$PYTHON_BIN" -m example_trainer.grpo \ --model-name "$STUDENT_MODEL" \ --weight-bridge-mode shared_vllm \ @@ -287,6 +287,20 @@ env CUDA_VISIBLE_DEVICES="$TRAINER_GPUS" \ --clip-eps "$CLIP_EPS" \ --distill-enabled \ --distill-coef "$DISTILL_COEF" \ - --distill-temperature "$DISTILL_TEMPERATURE" | tee "${LOG_DIR}/trainer.log" + --distill-temperature "$DISTILL_TEMPERATURE" -log "Training finished." +log "All processes running in background." +log "" +log "Monitor with:" +log " tail -f ${LOG_DIR}/trainer.log" +log " tail -f ${LOG_DIR}/env.log" +log " tail -f ${LOG_DIR}/student_vllm.log" +log " tail -f ${LOG_DIR}/teacher_vllm.log" +log "" +log "Test endpoints:" +log " curl -s http://localhost:${STUDENT_PORT}/health" +log " curl -s http://localhost:${TEACHER_PORT}/health" +log " curl -s http://localhost:${STUDENT_PORT}/bridge/is_paused | jq ." +log "" +log "Press Ctrl+C to stop all processes." +wait From 057c9fe870e111c720a7246e583dce9216c81126 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 10 Mar 2026 11:58:48 -0400 Subject: [PATCH 13/64] shorten worker timeout --- .../run_gsm8k_teacher_distill_single_terminal.sh | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 76acd6f4..50fc9081 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -74,7 +74,7 @@ ENV_BATCH_SIZE="${ENV_BATCH_SIZE:-16}" ENV_TOTAL_STEPS="${ENV_TOTAL_STEPS:-200}" ENV_STEPS_PER_EVAL="${ENV_STEPS_PER_EVAL:-50}" ENV_MAX_WORKERS_PER_NODE="${ENV_MAX_WORKERS_PER_NODE:-1}" -ENV_WORKER_TIMEOUT="${ENV_WORKER_TIMEOUT:-1800}" +ENV_WORKER_TIMEOUT="${ENV_WORKER_TIMEOUT:-300}" RUN_PIDS=() RUN_PORTS=() @@ -302,5 +302,6 @@ log " curl -s http://localhost:${STUDENT_PORT}/health" log " curl -s http://localhost:${TEACHER_PORT}/health" log " curl -s http://localhost:${STUDENT_PORT}/bridge/is_paused | jq ." log "" -log "Press Ctrl+C to stop all processes." -wait +log "To stop all processes:" +log " kill ${RUN_PIDS[*]:-} 2>/dev/null; sleep 1; kill -9 ${RUN_PIDS[*]:-} 2>/dev/null" +trap - EXIT INT TERM From e84686b4fdffacd13060ab931ec63ac0c672cfe5 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 10 Mar 2026 12:30:41 -0400 Subject: [PATCH 14/64] remove enforce eager --- example_trainer/run_gsm8k_lora_matrix.sh | 3 +-- example_trainer/run_gsm8k_teacher_distill_single_terminal.sh | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/example_trainer/run_gsm8k_lora_matrix.sh b/example_trainer/run_gsm8k_lora_matrix.sh index 121106e0..48bad3ce 100755 --- a/example_trainer/run_gsm8k_lora_matrix.sh +++ b/example_trainer/run_gsm8k_lora_matrix.sh @@ -248,8 +248,7 @@ run_shared_vllm() { --port "$vllm_port" \ --gpu-memory-utilization "$SHARED_GPU_MEMORY_UTILIZATION" \ --max-model-len "$MAX_MODEL_LEN" \ - --dtype "$DTYPE" \ - --enforce-eager + --dtype "$DTYPE" if [[ "$DRY_RUN" == "1" ]]; then log "[DRY RUN] wait for http://localhost:${vllm_port}/health" else diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 50fc9081..6b22d767 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -193,8 +193,7 @@ start_process "student_vllm" "${LOG_DIR}/student_vllm.log" \ --tensor-parallel-size "$STUDENT_TP" \ --gpu-memory-utilization "$STUDENT_GPU_MEMORY_UTILIZATION" \ --max-model-len "$MAX_MODEL_LEN" \ - --dtype "$DTYPE" \ - --enforce-eager + --dtype "$DTYPE" if [[ "$DRY_RUN" == "0" ]]; then wait_for_http "http://localhost:${STUDENT_PORT}/health" 420 "student vLLM" fi From e79af5ff694ab975ea243030867121f682ee9d9d Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 10 Mar 2026 23:39:55 -0400 Subject: [PATCH 15/64] testing config --- environments/gsm8k_server.py | 5 +++-- example_trainer/run_gsm8k_teacher_distill_single_terminal.sh | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 5112d1a4..2697ef30 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -368,8 +368,9 @@ class GSM8kEnv(BaseEnv): percentage_of_range = min(percentage_of_range, 1.0) # Apply linear penalty scaling from 1.0 down to 0.0 scores["scores"].append(1.0 - percentage_of_range) - if all([scores["scores"][0] == score for score in scores["scores"]]): - return None # If all the same, we return None + # NOTE: identical-score filter disabled for testing. + # if all([scores["scores"][0] == score for score in scores["scores"]]): + # return None return scores else: # If the gold solution is not parseable, we return None diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 6b22d767..1a8b66e2 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -56,7 +56,7 @@ LR="${LR:-1e-5}" WARMUP_STEPS="${WARMUP_STEPS:-0}" CLIP_EPS="${CLIP_EPS:-0.2}" MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}" -ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-4096}" +ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-16384}" DISTILL_COEF="${DISTILL_COEF:-0.2}" DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" TEACHER_TOP_K="${TEACHER_TOP_K:-8}" From abba562d4a8124aeb631ab66883b3a88ce9a6e44 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Wed, 11 Mar 2026 00:26:22 -0400 Subject: [PATCH 16/64] testing config --- example_trainer/run_gsm8k_teacher_distill_single_terminal.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 1a8b66e2..5680f679 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -56,7 +56,7 @@ LR="${LR:-1e-5}" WARMUP_STEPS="${WARMUP_STEPS:-0}" CLIP_EPS="${CLIP_EPS:-0.2}" MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}" -ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-16384}" +ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-8192}" DISTILL_COEF="${DISTILL_COEF:-0.2}" DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" TEACHER_TOP_K="${TEACHER_TOP_K:-8}" From 82be8719790c3a3cc72739b3168cd4b34db96c9a Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Wed, 11 Mar 2026 00:53:34 -0400 Subject: [PATCH 17/64] testing config --- example_trainer/run_gsm8k_teacher_distill_single_terminal.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 5680f679..6b22d767 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -56,7 +56,7 @@ LR="${LR:-1e-5}" WARMUP_STEPS="${WARMUP_STEPS:-0}" CLIP_EPS="${CLIP_EPS:-0.2}" MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}" -ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-8192}" +ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-4096}" DISTILL_COEF="${DISTILL_COEF:-0.2}" DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" TEACHER_TOP_K="${TEACHER_TOP_K:-8}" From 98a5d3b334a6036b1e8530dc82375be9708ee941 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Wed, 11 Mar 2026 01:23:58 -0400 Subject: [PATCH 18/64] testing config --- .../run_gsm8k_teacher_distill_single_terminal.sh | 1 + example_trainer/vllm_api_server.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 6b22d767..7fdcdccd 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -229,6 +229,7 @@ start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \ --env.teacher_base_url "http://localhost:${TEACHER_PORT}/v1" \ --env.teacher_model_name "$TEACHER_MODEL" \ --env.teacher_top_k "$TEACHER_TOP_K" \ + --env.ensure_scores_are_not_same false \ --openai.api_key "dummy" \ --openai.base_url "http://localhost:${STUDENT_PORT}/v1" \ --openai.model_name "$STUDENT_MODEL" \ diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py index da5b6608..131861d5 100644 --- a/example_trainer/vllm_api_server.py +++ b/example_trainer/vllm_api_server.py @@ -360,6 +360,13 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: ret["prompt_token_ids"] = final_output.prompt_token_ids ret["token_ids"] = [x.token_ids for x in final_output.outputs] + if sampling_params.prompt_logprobs is not None and final_output.prompt_logprobs is not None: + ret["prompt_logprobs"] = [ + {int(tok_id): lp.logprob for tok_id, lp in pos.items()} + if pos is not None else None + for pos in final_output.prompt_logprobs + ] + logger.info( "POST /generate completed request_id=%s outputs=%s finish_reasons=%s", request_id, From 78c0a6d08250e38283457174a2140a07dd39c05e Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Wed, 11 Mar 2026 01:49:49 -0400 Subject: [PATCH 19/64] tokenizer bug --- atroposlib/envs/teacher_distillation_env.py | 248 +++++++++++++++++++- 1 file changed, 241 insertions(+), 7 deletions(-) diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py index 54a66ff0..521c3762 100644 --- a/atroposlib/envs/teacher_distillation_env.py +++ b/atroposlib/envs/teacher_distillation_env.py @@ -3,6 +3,43 @@ Teacher distillation environment layer. This module adds teacher prompt-logprob fetching on top of BaseEnv without modifying BaseEnv transport behavior. + +Cross-tokenizer distillation +---------------------------- +When student and teacher use the same tokenizer family (e.g. both Qwen3) the +student's raw token IDs can be forwarded directly to the teacher vLLM and the +returned top-k token IDs can be used as-is in the student logit lookup. + +When tokenizers differ (e.g. Llama student, Qwen teacher) two problems arise: + + 1. Token-ID aliasing: student token 42 = " the" in Llama, but 42 = "ท" in + Qwen. Sending student IDs to the teacher causes it to score garbage. + + 2. Vocab-space mismatch: the teacher's top-k IDs live in the teacher's + vocabulary. The student logit lookup at those IDs would access random + tokens in the student vocab. + +This module fixes both problems automatically: + + • Re-tokenization – student tokens are decoded to plain text and + re-tokenized with the teacher tokenizer before being sent to the teacher + server. The teacher therefore always scores the correct text. + + • Character-level position alignment – after re-tokenisation the teacher + has a different number of tokens than the student. A character-offset + map is built (requires a fast HuggingFace tokenizer) to project each + teacher logprob position back onto the student token it overlaps with. + + • Vocabulary remapping – teacher top-k token IDs (teacher vocab) are + decoded to text fragments and re-encoded with the student tokenizer so + that the final distill_token_ids live in the student vocabulary and can + be looked up directly in the student logit tensor. + +Same-tokenizer fast path +------------------------ +When teacher_tokenizer_name resolves to the same underlying vocabulary as the +student tokenizer the original fast path (no decode / re-tokenize / remap) is +taken automatically. """ from __future__ import annotations @@ -45,7 +82,9 @@ class TeacherDistillationConfig(BaseEnvConfig): teacher_tokenizer_name: str = Field( default="none", description=( - "Tokenizer name for teacher server. If 'none', teacher_model_name is used." + "Tokenizer name for teacher server. If 'none', teacher_model_name is used. " + "When this resolves to a different vocabulary than the student tokenizer, " + "cross-tokenizer alignment is applied automatically." ), ) teacher_top_k: int = Field( @@ -60,8 +99,8 @@ class TeacherDistillationEnv(BaseEnv, ABC): BaseEnv subclass that enriches scored groups with teacher distillation arrays. Distillation payload shape: - - distill_token_ids: [sequence][position][k] - - distill_logprobs: [sequence][position][k] + - distill_token_ids: [sequence][position][k] (student vocab IDs) + - distill_logprobs: [sequence][position][k] """ env_config_cls = TeacherDistillationConfig @@ -75,17 +114,29 @@ class TeacherDistillationEnv(BaseEnv, ABC): ): super().__init__(config, server_configs, slurm=slurm, testing=testing) self.teacher_server: Optional[ServerManager] = None + # Teacher tokenizer (only loaded when tokenizers may differ). + self._teacher_tokenizer = None + # True when student and teacher share the same vocabulary. + self._same_tokenizer: bool = True + # LRU-style cache: teacher_token_id -> student_token_id + self._vocab_remap_cache: Dict[int, int] = {} + if config.teacher_enabled: if not config.teacher_base_url or not config.teacher_model_name: raise ValueError( "teacher_enabled=True requires teacher_base_url and teacher_model_name." ) + teacher_tok_name = ( + config.teacher_model_name + if config.teacher_tokenizer_name in ("none", "") + else config.teacher_tokenizer_name + ) teacher_cfg = APIServerConfig( server_type=config.teacher_server_type, # type: ignore[arg-type] base_url=config.teacher_base_url, api_key=config.teacher_api_key, model_name=config.teacher_model_name, - tokenizer_name=config.teacher_tokenizer_name, + tokenizer_name=teacher_tok_name, timeout=1200, ) self.teacher_server = ServerManager( @@ -93,25 +144,208 @@ class TeacherDistillationEnv(BaseEnv, ABC): slurm=False, testing=False, ) + + # Detect vocabulary mismatch. + student_tok_name = getattr(self.tokenizer, "name_or_path", None) or "" + if student_tok_name and teacher_tok_name and student_tok_name != teacher_tok_name: + try: + from transformers import AutoTokenizer + + self._teacher_tokenizer = AutoTokenizer.from_pretrained( + teacher_tok_name, use_fast=True + ) + self._same_tokenizer = False + logger.warning( + "TeacherDistillationEnv: cross-tokenizer mode active. " + "student=%s teacher=%s. " + "Token IDs will be decoded → re-tokenized → vocab-remapped.", + student_tok_name, + teacher_tok_name, + ) + except Exception as exc: + logger.warning( + "TeacherDistillationEnv: could not load teacher tokenizer '%s' " + "(%s). Falling back to same-tokenizer (fast) path — only safe if " + "student and teacher share the same vocabulary.", + teacher_tok_name, + exc, + ) + self._same_tokenizer = True + else: + self._same_tokenizer = True + logger.warning( "TeacherDistillationEnv: teacher server configured at %s " - "(model=%s, top_k=%s)", + "(model=%s, top_k=%s, same_tokenizer=%s)", config.teacher_base_url, config.teacher_model_name, config.teacher_top_k, + self._same_tokenizer, ) + # ------------------------------------------------------------------ + # Cross-tokenizer helpers + # ------------------------------------------------------------------ + + def _build_student_teacher_alignment( + self, + text: str, + student_ids: List[int], + teacher_ids: List[int], + ) -> List[List[int]]: + """ + For each student token position return the list of teacher token positions + whose character spans overlap with the student token's character span. + + Requires fast (Rust-backed) HuggingFace tokenizers that support + return_offsets_mapping. Falls back to a proportional approximation + if offset mapping is unavailable. + """ + student_len = len(student_ids) + teacher_len = len(teacher_ids) + + try: + s_enc = self.tokenizer( + text, return_offsets_mapping=True, add_special_tokens=False + ) + t_enc = self._teacher_tokenizer( + text, return_offsets_mapping=True, add_special_tokens=False + ) + s_offsets: List[Tuple[int, int]] = s_enc["offset_mapping"][:student_len] + t_offsets: List[Tuple[int, int]] = t_enc["offset_mapping"][:teacher_len] + + alignment: List[List[int]] = [] + for s_start, s_end in s_offsets: + overlapping = [ + t_idx + for t_idx, (t_start, t_end) in enumerate(t_offsets) + if t_start < s_end and t_end > s_start and s_end > s_start + ] + alignment.append(overlapping) + return alignment + + except Exception as exc: + logger.warning( + "TeacherDistillationEnv: offset-mapping alignment failed (%s). " + "Using proportional fallback.", + exc, + ) + ratio = teacher_len / max(student_len, 1) + return [[int(i * ratio)] for i in range(student_len)] + + def _remap_teacher_token_to_student(self, teacher_token_id: int) -> int: + """ + Convert a teacher vocabulary token ID to the best-matching student + vocabulary token ID by decoding the teacher token to text then + re-encoding with the student tokenizer. + + Results are cached to avoid repeated tokenizer calls. + """ + if teacher_token_id in self._vocab_remap_cache: + return self._vocab_remap_cache[teacher_token_id] + + try: + text = self._teacher_tokenizer.decode( + [teacher_token_id], clean_up_tokenization_spaces=False + ) + student_ids = self.tokenizer.encode(text, add_special_tokens=False) + # Use the first student token as the representative. + sid = int(student_ids[0]) if student_ids else teacher_token_id + except Exception: + sid = teacher_token_id + + self._vocab_remap_cache[teacher_token_id] = sid + return sid + + def _align_and_remap( + self, + student_ids: List[int], + teacher_topk_ids: List[List[int]], + teacher_topk_lps: List[List[float]], + alignment: List[List[int]], + ) -> Tuple[List[List[int]], List[List[float]]]: + """ + Project teacher logprobs (teacher positions, teacher vocab) onto + student positions in student vocab. + + For each student token position: + 1. Collect all teacher top-k entries from overlapping teacher positions. + 2. Remap each teacher token ID to the student vocab. + 3. Merge duplicates by keeping the maximum logprob. + 4. Return the top-k entries sorted by descending logprob. + """ + k = max(1, len(teacher_topk_ids[0]) if teacher_topk_ids else 1) + result_ids: List[List[int]] = [] + result_lps: List[List[float]] = [] + + for s_idx in range(len(student_ids)): + t_positions = alignment[s_idx] if s_idx < len(alignment) else [] + if not t_positions: + result_ids.append([]) + result_lps.append([]) + continue + + # Merge all overlapping teacher positions, remap vocab. + merged: Dict[int, float] = {} + for t_idx in t_positions: + if t_idx >= len(teacher_topk_ids): + continue + for tid, tlp in zip(teacher_topk_ids[t_idx], teacher_topk_lps[t_idx]): + sid = self._remap_teacher_token_to_student(tid) + merged[sid] = max(merged.get(sid, -1e9), tlp) + + sorted_items = sorted(merged.items(), key=lambda x: -x[1]) + top = sorted_items[:k] + result_ids.append([int(sid) for sid, _ in top]) + result_lps.append([float(lp) for _, lp in top]) + + return result_ids, result_lps + + # ------------------------------------------------------------------ + # Core fetch + # ------------------------------------------------------------------ + async def _fetch_teacher_for_sequence( self, token_ids: List[int], top_k: int ) -> Tuple[List[List[int]], List[List[float]]]: assert self.teacher_server is not None + + if self._same_tokenizer or self._teacher_tokenizer is None: + # Fast path: same vocabulary — send student IDs directly. + payload = await self.teacher_server.get_logprobs( + input_ids=token_ids, + top_k=top_k, + max_tokens=1, + split="train", + ) + return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"] + + # Cross-tokenizer path: + # 1. Decode student tokens → plain text. + # 2. Re-tokenize with teacher tokenizer → teacher IDs. + # 3. Send teacher IDs to teacher vLLM. + # 4. Align teacher positions → student positions. + # 5. Remap teacher vocab IDs → student vocab IDs. + text = self.tokenizer.decode(token_ids, clean_up_tokenization_spaces=False) + teacher_ids: List[int] = self._teacher_tokenizer.encode( + text, add_special_tokens=False + ) + payload = await self.teacher_server.get_logprobs( - input_ids=token_ids, + input_ids=teacher_ids, top_k=top_k, max_tokens=1, split="train", ) - return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"] + teacher_topk_ids = payload["prompt_topk_token_ids"] + teacher_topk_lps = payload["prompt_topk_logprobs"] + + alignment = self._build_student_teacher_alignment(text, token_ids, teacher_ids) + return self._align_and_remap(token_ids, teacher_topk_ids, teacher_topk_lps, alignment) + + # ------------------------------------------------------------------ + # Group enrichment + # ------------------------------------------------------------------ async def _attach_teacher_distillation( self, group: ScoredDataGroup From f1cfc137eca10dc2b195a78b2338121f560de84b Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Wed, 11 Mar 2026 11:05:44 -0400 Subject: [PATCH 20/64] tokenizer bug --- example_trainer/run_gsm8k_teacher_distill_single_terminal.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 7fdcdccd..7f54a646 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -56,7 +56,7 @@ LR="${LR:-1e-5}" WARMUP_STEPS="${WARMUP_STEPS:-0}" CLIP_EPS="${CLIP_EPS:-0.2}" MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}" -ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-4096}" +ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-4608}" DISTILL_COEF="${DISTILL_COEF:-0.2}" DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" TEACHER_TOP_K="${TEACHER_TOP_K:-8}" From c275687fba1ff51947744cf3630755cea0a171df Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Wed, 11 Mar 2026 11:38:57 -0400 Subject: [PATCH 21/64] tokenizer bug --- example_trainer/run_gsm8k_teacher_distill_single_terminal.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 7f54a646..3c064580 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -56,7 +56,7 @@ LR="${LR:-1e-5}" WARMUP_STEPS="${WARMUP_STEPS:-0}" CLIP_EPS="${CLIP_EPS:-0.2}" MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}" -ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-4608}" +ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-16384}" DISTILL_COEF="${DISTILL_COEF:-0.2}" DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" TEACHER_TOP_K="${TEACHER_TOP_K:-8}" @@ -74,7 +74,7 @@ ENV_BATCH_SIZE="${ENV_BATCH_SIZE:-16}" ENV_TOTAL_STEPS="${ENV_TOTAL_STEPS:-200}" ENV_STEPS_PER_EVAL="${ENV_STEPS_PER_EVAL:-50}" ENV_MAX_WORKERS_PER_NODE="${ENV_MAX_WORKERS_PER_NODE:-1}" -ENV_WORKER_TIMEOUT="${ENV_WORKER_TIMEOUT:-300}" +ENV_WORKER_TIMEOUT="${ENV_WORKER_TIMEOUT:-1800}" RUN_PIDS=() RUN_PORTS=() From 3a440f847c5d9e84f23d2caf69283b9487b0ef32 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Wed, 11 Mar 2026 12:38:03 -0400 Subject: [PATCH 22/64] tokenizer bug --- example_trainer/run_gsm8k_teacher_distill_single_terminal.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 3c064580..6cda7d1b 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -262,6 +262,7 @@ if [[ "$DRY_RUN" == "1" ]]; then --warmup-steps "$WARMUP_STEPS" \ --lr "$LR" \ --clip-eps "$CLIP_EPS" \ + --seq-len "$ENV_MAX_TOKEN_LENGTH" \ --distill-enabled \ --distill-coef "$DISTILL_COEF" \ --distill-temperature "$DISTILL_TEMPERATURE" @@ -285,6 +286,7 @@ start_process "trainer" "${LOG_DIR}/trainer.log" \ --warmup-steps "$WARMUP_STEPS" \ --lr "$LR" \ --clip-eps "$CLIP_EPS" \ + --seq-len "$ENV_MAX_TOKEN_LENGTH" \ --distill-enabled \ --distill-coef "$DISTILL_COEF" \ --distill-temperature "$DISTILL_TEMPERATURE" From b457a678ce420fad84d6cca2a34933a7cdeb1514 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Wed, 11 Mar 2026 16:56:51 -0400 Subject: [PATCH 23/64] tokenizer bug --- example_trainer/run_gsm8k_teacher_distill_single_terminal.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 6cda7d1b..1b58b738 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -56,6 +56,7 @@ LR="${LR:-1e-5}" WARMUP_STEPS="${WARMUP_STEPS:-0}" CLIP_EPS="${CLIP_EPS:-0.2}" MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}" +TEACHER_MAX_MODEL_LEN="${TEACHER_MAX_MODEL_LEN:-32768}" ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-16384}" DISTILL_COEF="${DISTILL_COEF:-0.2}" DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" @@ -206,7 +207,7 @@ start_process "teacher_vllm" "${LOG_DIR}/teacher_vllm.log" \ --port "$TEACHER_PORT" \ --tensor-parallel-size "$TEACHER_TP" \ --gpu-memory-utilization "$TEACHER_GPU_MEMORY_UTILIZATION" \ - --max-model-len "$MAX_MODEL_LEN" \ + --max-model-len "$TEACHER_MAX_MODEL_LEN" \ --dtype "$DTYPE" if [[ "$DRY_RUN" == "0" ]]; then wait_for_http "http://localhost:${TEACHER_PORT}/health" 1800 "teacher vLLM" From 2f371e03fc0bbea0d4d97f513c80d76a021e1eb6 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Wed, 11 Mar 2026 19:35:12 -0400 Subject: [PATCH 24/64] tokenizer bug --- atroposlib/envs/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 3d3b6c20..7aa391ba 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -907,7 +907,7 @@ class BaseEnv(ABC): "ensure your trainer handles this appropriately." ) elif abort_on_any_max_length_exceeded and any( - [len(x) >= self.max_token_len for x in group["tokens"]] + [len(x) > self.max_token_len for x in group["tokens"]] ): logger.warning("Token length is too long in a group, skipping...") continue From 8a348beccd9dac4c9a9d9566144149adfb92ed03 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Wed, 11 Mar 2026 19:37:17 -0400 Subject: [PATCH 25/64] tokenizer bug --- .../run_gsm8k_teacher_distill_single_terminal.sh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 1b58b738..311f668d 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -57,6 +57,9 @@ WARMUP_STEPS="${WARMUP_STEPS:-0}" CLIP_EPS="${CLIP_EPS:-0.2}" MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}" TEACHER_MAX_MODEL_LEN="${TEACHER_MAX_MODEL_LEN:-32768}" +# Trainer seq_len must be larger than ENV_MAX_TOKEN_LENGTH to accommodate +# chat template overhead (~400-800 tokens for Qwen3 thinking format). +TRAINER_SEQ_LEN="${TRAINER_SEQ_LEN:-20480}" ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-16384}" DISTILL_COEF="${DISTILL_COEF:-0.2}" DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" @@ -263,7 +266,7 @@ if [[ "$DRY_RUN" == "1" ]]; then --warmup-steps "$WARMUP_STEPS" \ --lr "$LR" \ --clip-eps "$CLIP_EPS" \ - --seq-len "$ENV_MAX_TOKEN_LENGTH" \ + --seq-len "$TRAINER_SEQ_LEN" \ --distill-enabled \ --distill-coef "$DISTILL_COEF" \ --distill-temperature "$DISTILL_TEMPERATURE" @@ -287,7 +290,7 @@ start_process "trainer" "${LOG_DIR}/trainer.log" \ --warmup-steps "$WARMUP_STEPS" \ --lr "$LR" \ --clip-eps "$CLIP_EPS" \ - --seq-len "$ENV_MAX_TOKEN_LENGTH" \ + --seq-len "$TRAINER_SEQ_LEN" \ --distill-enabled \ --distill-coef "$DISTILL_COEF" \ --distill-temperature "$DISTILL_TEMPERATURE" From 34a39367dc7c3a735d8b583ca50b488574da1ee8 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Wed, 11 Mar 2026 22:15:38 -0400 Subject: [PATCH 26/64] tokenizer bug --- example_trainer/run_gsm8k_teacher_distill_single_terminal.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 311f668d..db9f137f 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -35,7 +35,7 @@ LAUNCH_DIR="$PWD" cd "$ROOT_DIR" PYTHON_BIN="${PYTHON_BIN:-python3}" -STUDENT_MODEL="${STUDENT_MODEL:-Qwen/Qwen3-4B-Instruct-2507-FP8}" +STUDENT_MODEL="${STUDENT_MODEL:-Qwen/Qwen3-4B-Instruct}" TEACHER_MODEL="${TEACHER_MODEL:-Qwen/Qwen3-30B-A3B-Instruct-2507}" STUDENT_GPUS="${STUDENT_GPUS:-0}" From fd5b426f9fd24eca526bb2c6947240c7097a106f Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Wed, 11 Mar 2026 23:58:36 -0400 Subject: [PATCH 27/64] tokenizer bug --- example_trainer/run_gsm8k_teacher_distill_single_terminal.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index db9f137f..cde59a3c 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -35,7 +35,7 @@ LAUNCH_DIR="$PWD" cd "$ROOT_DIR" PYTHON_BIN="${PYTHON_BIN:-python3}" -STUDENT_MODEL="${STUDENT_MODEL:-Qwen/Qwen3-4B-Instruct}" +STUDENT_MODEL="${STUDENT_MODEL:-Qwen/Qwen3-4B}" TEACHER_MODEL="${TEACHER_MODEL:-Qwen/Qwen3-30B-A3B-Instruct-2507}" STUDENT_GPUS="${STUDENT_GPUS:-0}" From c37516b289d7086a29a105548932294844d3e0c3 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Thu, 12 Mar 2026 09:06:50 -0400 Subject: [PATCH 28/64] tokenizer bug --- .../run_gsm8k_teacher_distill_single_terminal.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index cde59a3c..94021717 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -251,8 +251,8 @@ log " ${LOG_DIR}/env.log" if [[ "$DRY_RUN" == "1" ]]; then log "[DRY RUN] trainer command:" printf ' ' - printf '%q ' env CUDA_VISIBLE_DEVICES="$TRAINER_GPUS" \ - "$PYTHON_BIN" -m example_trainer.grpo \ + printf '%q ' env CUDA_VISIBLE_DEVICES="$TRAINER_GPUS" PYTHONUNBUFFERED=1 \ + "$PYTHON_BIN" -u -m example_trainer.grpo \ --model-name "$STUDENT_MODEL" \ --weight-bridge-mode shared_vllm \ --device cuda:0 \ @@ -275,8 +275,8 @@ if [[ "$DRY_RUN" == "1" ]]; then fi start_process "trainer" "${LOG_DIR}/trainer.log" \ - env CUDA_VISIBLE_DEVICES="$TRAINER_GPUS" \ - "$PYTHON_BIN" -m example_trainer.grpo \ + env CUDA_VISIBLE_DEVICES="$TRAINER_GPUS" PYTHONUNBUFFERED=1 \ + "$PYTHON_BIN" -u -m example_trainer.grpo \ --model-name "$STUDENT_MODEL" \ --weight-bridge-mode shared_vllm \ --device cuda:0 \ From a54dfe7a135b383143529d22d1f7c3f7ed3dd45c Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Thu, 12 Mar 2026 09:53:33 -0400 Subject: [PATCH 29/64] tokenizer bug --- .../run_gsm8k_teacher_distill_single_terminal.sh | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 94021717..197599d0 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -65,6 +65,9 @@ DISTILL_COEF="${DISTILL_COEF:-0.2}" DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" TEACHER_TOP_K="${TEACHER_TOP_K:-8}" +WANDB_PROJECT="${WANDB_PROJECT:-gsm8k-teacher-distill}" +WANDB_GROUP="${WANDB_GROUP:-}" + STUDENT_GPU_MEMORY_UTILIZATION="${STUDENT_GPU_MEMORY_UTILIZATION:-0.95}" TEACHER_GPU_MEMORY_UTILIZATION="${TEACHER_GPU_MEMORY_UTILIZATION:-0.95}" DTYPE="${DTYPE:-bfloat16}" @@ -166,6 +169,7 @@ log " logs=${LOG_DIR}" log " saves=${SAVE_DIR}" log " bridge=${BRIDGE_DIR}" log " env max_token_length=${ENV_MAX_TOKEN_LENGTH}, env workers=${ENV_MAX_WORKERS_PER_NODE}, env worker_timeout=${ENV_WORKER_TIMEOUT}" +log " wandb project=${WANDB_PROJECT}${WANDB_GROUP:+, group=${WANDB_GROUP}}" # Shared-vLLM attach path currently expects the student server to expose # unsharded weights. Keep the student on TP=1 and the trainer on the same GPU set. @@ -269,7 +273,10 @@ if [[ "$DRY_RUN" == "1" ]]; then --seq-len "$TRAINER_SEQ_LEN" \ --distill-enabled \ --distill-coef "$DISTILL_COEF" \ - --distill-temperature "$DISTILL_TEMPERATURE" + --distill-temperature "$DISTILL_TEMPERATURE" \ + --use-wandb \ + --wandb-project "$WANDB_PROJECT" \ + ${WANDB_GROUP:+--wandb-group "$WANDB_GROUP"} printf '\n' exit 0 fi @@ -293,7 +300,10 @@ start_process "trainer" "${LOG_DIR}/trainer.log" \ --seq-len "$TRAINER_SEQ_LEN" \ --distill-enabled \ --distill-coef "$DISTILL_COEF" \ - --distill-temperature "$DISTILL_TEMPERATURE" + --distill-temperature "$DISTILL_TEMPERATURE" \ + --use-wandb \ + --wandb-project "$WANDB_PROJECT" \ + ${WANDB_GROUP:+--wandb-group "$WANDB_GROUP"} log "All processes running in background." log "" From 62ef2fcc2efb60d1ddb9a03b9c60099ee935a483 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Thu, 12 Mar 2026 12:20:54 -0400 Subject: [PATCH 30/64] training kernel --- example_trainer/training.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/example_trainer/training.py b/example_trainer/training.py index c5b739e9..673ed795 100644 --- a/example_trainer/training.py +++ b/example_trainer/training.py @@ -315,7 +315,6 @@ def compute_distillation_loss( return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0 temp = max(1e-6, float(temperature)) - student_log_probs = F.log_softmax(logits / temp, dim=-1) valid_ids = distill_token_ids >= 0 label_mask = labels != -100 @@ -324,7 +323,14 @@ def compute_distillation_loss( return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0 gather_ids = distill_token_ids.clamp_min(0).long() - student_logp_topk = torch.gather(student_log_probs, dim=-1, index=gather_ids) + + # Avoid materializing the full [batch, seq_len, vocab] log_softmax tensor + # (e.g. [2, 20480, 151936] = ~12.5 GB) which is the main cause of OOM/hangs. + # Instead: gather raw logits at top-k positions, then subtract logsumexp. + # Output tensors are [batch, seq_len, k] (tiny) not [batch, seq_len, vocab]. + scaled_logits = logits / temp + log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True) # [b, s, 1] + student_logp_topk = torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer masked_teacher_logprobs = distill_logprobs.masked_fill(~valid_ids, -1e9) teacher_probs = F.softmax(masked_teacher_logprobs / temp, dim=-1) @@ -408,6 +414,13 @@ def run_training_step( for batch_idx, (tokens, labels, advantages, temperatures) in enumerate( zip(token_batches, label_batches, advantage_batches, temperature_batches) ): + print( + f" [Step] micro-batch {batch_idx+1}/{num_batches} " + f"tokens={tokens.shape} " + f"gpu_mem={torch.cuda.memory_allocated()/1e9:.1f}GB " + f"gpu_reserved={torch.cuda.memory_reserved()/1e9:.1f}GB", + flush=True, + ) tokens = tokens.to(config.device) labels = labels.to(config.device) advantages = advantages.to(config.device) @@ -429,6 +442,7 @@ def run_training_step( ): distill_lps = distill_logprob_batches[batch_idx] + print(f" [Step] micro-batch {batch_idx+1} forward pass...", flush=True) loss, metrics = compute_grpo_loss( model, tokens, @@ -445,7 +459,13 @@ def run_training_step( distill_temperature=float(getattr(config, "distill_temperature", 1.0)), ) + print( + f" [Step] micro-batch {batch_idx+1} loss={loss.item():.4f} " + f"backward...", + flush=True, + ) loss.backward() + print(f" [Step] micro-batch {batch_idx+1} backward done", flush=True) total_loss += loss.item() total_pos_logp += metrics["pos_logp"] total_neg_logp += metrics["neg_logp"] From c26432b963999e7df78b75597e03dede351011d7 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Thu, 12 Mar 2026 12:31:09 -0400 Subject: [PATCH 31/64] training kernel --- atroposlib/envs/teacher_distillation_env.py | 44 ++++++++++++++++----- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py index 521c3762..12c16079 100644 --- a/atroposlib/envs/teacher_distillation_env.py +++ b/atroposlib/envs/teacher_distillation_env.py @@ -146,22 +146,48 @@ class TeacherDistillationEnv(BaseEnv, ABC): ) # Detect vocabulary mismatch. + # Compare by name first; if names differ, load the teacher tokenizer + # and do a vocab-size sanity check. Same-family models (e.g. Qwen3-4B + # and Qwen3-30B) share the same vocabulary, so even though the + # name_or_path strings differ they should use the fast path. student_tok_name = getattr(self.tokenizer, "name_or_path", None) or "" if student_tok_name and teacher_tok_name and student_tok_name != teacher_tok_name: try: from transformers import AutoTokenizer - self._teacher_tokenizer = AutoTokenizer.from_pretrained( + loaded = AutoTokenizer.from_pretrained( teacher_tok_name, use_fast=True ) - self._same_tokenizer = False - logger.warning( - "TeacherDistillationEnv: cross-tokenizer mode active. " - "student=%s teacher=%s. " - "Token IDs will be decoded → re-tokenized → vocab-remapped.", - student_tok_name, - teacher_tok_name, - ) + student_vocab_size = getattr(self.tokenizer, "vocab_size", None) + teacher_vocab_size = getattr(loaded, "vocab_size", None) + if ( + student_vocab_size is not None + and teacher_vocab_size is not None + and student_vocab_size == teacher_vocab_size + ): + # Same vocab size — treat as same tokenizer (fast path). + # This covers same-family models (e.g. all Qwen3 variants). + self._same_tokenizer = True + logger.warning( + "TeacherDistillationEnv: names differ but vocab sizes match " + "(%d tokens). Using fast (same-tokenizer) path. " + "student=%s teacher=%s", + student_vocab_size, + student_tok_name, + teacher_tok_name, + ) + else: + self._teacher_tokenizer = loaded + self._same_tokenizer = False + logger.warning( + "TeacherDistillationEnv: cross-tokenizer mode active. " + "student=%s (%s tokens) teacher=%s (%s tokens). " + "Token IDs will be decoded → re-tokenized → vocab-remapped.", + student_tok_name, + student_vocab_size, + teacher_tok_name, + teacher_vocab_size, + ) except Exception as exc: logger.warning( "TeacherDistillationEnv: could not load teacher tokenizer '%s' " From 7ec622a098f614adda2a90e08221700eefeb98b7 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Thu, 12 Mar 2026 13:19:34 -0400 Subject: [PATCH 32/64] training ideas --- ...n_gsm8k_teacher_distill_single_terminal.sh | 81 ++++++++----------- 1 file changed, 35 insertions(+), 46 deletions(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 197599d0..5b100ce0 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -251,59 +251,48 @@ log " ${LOG_DIR}/student_vllm.log" log " ${LOG_DIR}/teacher_vllm.log" log " ${LOG_DIR}/env.log" -# 5) Trainer (foreground, primary output) +# 5) Trainer (background) +TRAINER_CMD=( + env + CUDA_VISIBLE_DEVICES="$TRAINER_GPUS" + PYTHONUNBUFFERED=1 + "$PYTHON_BIN" + -u + -m + example_trainer.grpo + --model-name "$STUDENT_MODEL" + --weight-bridge-mode shared_vllm + --device cuda:0 + --save-path "$SAVE_DIR" + --atropos-url "http://localhost:${API_PORT}" + --vllm-port "$STUDENT_PORT" + --vllm-config-path "${BRIDGE_DIR}/vllm_bridge_config.json" + --training-steps "$TRAINING_STEPS" + --batch-size "$BATCH_SIZE" + --gradient-accumulation-steps "$GRAD_ACCUM" + --warmup-steps "$WARMUP_STEPS" + --lr "$LR" + --clip-eps "$CLIP_EPS" + --seq-len "$TRAINER_SEQ_LEN" + --distill-enabled + --distill-coef "$DISTILL_COEF" + --distill-temperature "$DISTILL_TEMPERATURE" + --use-wandb + --wandb-project "$WANDB_PROJECT" +) +if [[ -n "$WANDB_GROUP" ]]; then + TRAINER_CMD+=(--wandb-group "$WANDB_GROUP") +fi + if [[ "$DRY_RUN" == "1" ]]; then log "[DRY RUN] trainer command:" printf ' ' - printf '%q ' env CUDA_VISIBLE_DEVICES="$TRAINER_GPUS" PYTHONUNBUFFERED=1 \ - "$PYTHON_BIN" -u -m example_trainer.grpo \ - --model-name "$STUDENT_MODEL" \ - --weight-bridge-mode shared_vllm \ - --device cuda:0 \ - --save-path "$SAVE_DIR" \ - --atropos-url "http://localhost:${API_PORT}" \ - --vllm-port "$STUDENT_PORT" \ - --vllm-config-path "${BRIDGE_DIR}/vllm_bridge_config.json" \ - --training-steps "$TRAINING_STEPS" \ - --batch-size "$BATCH_SIZE" \ - --gradient-accumulation-steps "$GRAD_ACCUM" \ - --warmup-steps "$WARMUP_STEPS" \ - --lr "$LR" \ - --clip-eps "$CLIP_EPS" \ - --seq-len "$TRAINER_SEQ_LEN" \ - --distill-enabled \ - --distill-coef "$DISTILL_COEF" \ - --distill-temperature "$DISTILL_TEMPERATURE" \ - --use-wandb \ - --wandb-project "$WANDB_PROJECT" \ - ${WANDB_GROUP:+--wandb-group "$WANDB_GROUP"} + printf '%q ' "${TRAINER_CMD[@]}" printf '\n' exit 0 fi -start_process "trainer" "${LOG_DIR}/trainer.log" \ - env CUDA_VISIBLE_DEVICES="$TRAINER_GPUS" PYTHONUNBUFFERED=1 \ - "$PYTHON_BIN" -u -m example_trainer.grpo \ - --model-name "$STUDENT_MODEL" \ - --weight-bridge-mode shared_vllm \ - --device cuda:0 \ - --save-path "$SAVE_DIR" \ - --atropos-url "http://localhost:${API_PORT}" \ - --vllm-port "$STUDENT_PORT" \ - --vllm-config-path "${BRIDGE_DIR}/vllm_bridge_config.json" \ - --training-steps "$TRAINING_STEPS" \ - --batch-size "$BATCH_SIZE" \ - --gradient-accumulation-steps "$GRAD_ACCUM" \ - --warmup-steps "$WARMUP_STEPS" \ - --lr "$LR" \ - --clip-eps "$CLIP_EPS" \ - --seq-len "$TRAINER_SEQ_LEN" \ - --distill-enabled \ - --distill-coef "$DISTILL_COEF" \ - --distill-temperature "$DISTILL_TEMPERATURE" \ - --use-wandb \ - --wandb-project "$WANDB_PROJECT" \ - ${WANDB_GROUP:+--wandb-group "$WANDB_GROUP"} +start_process "trainer" "${LOG_DIR}/trainer.log" "${TRAINER_CMD[@]}" log "All processes running in background." log "" From a43b0b7e72f50ee86da2334a4c25f63e47ea2259 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Thu, 12 Mar 2026 14:51:28 -0400 Subject: [PATCH 33/64] training kernel --- atroposlib/api/server.py | 30 +++++++++++++++++++++++++++--- example_trainer/data.py | 29 +++++++++++++++++++++++++++++ example_trainer/trainers.py | 11 +++++++++++ 3 files changed, 67 insertions(+), 3 deletions(-) diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index 3a0fb999..0ca0f02d 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -364,7 +364,15 @@ async def get_batch(): app.state.started = True if len(app.state.curr_batch) > 0: - return {"batch": app.state.curr_batch.pop()} + curr_batch = app.state.curr_batch.pop() + logger.warning( + "API /batch returning prebuilt batch: groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s", + len(curr_batch), + sum(len(x["tokens"]) for x in curr_batch), + len(app.state.curr_batch), + len(app.state.queue), + ) + return {"batch": curr_batch} else: new_batches = [] # Check if any envs have minimum allocations @@ -394,6 +402,17 @@ async def get_batch(): ) steps_to_take = len(new_batches) if steps_to_take == 0: + now = time.time() + last_empty_log = getattr(app.state, "_last_empty_batch_log", 0.0) + if now - last_empty_log > 30: + logger.warning( + "API /batch no full batch ready: queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s", + len(app.state.queue), + sum(len(x.get("tokens", [])) for x in app.state.queue), + len(app.state.curr_batch), + getattr(app.state, "batchsize", -1), + ) + app.state._last_empty_batch_log = now return {"batch": None} app.state.status_dict["step"] += steps_to_take # chunk it @@ -401,9 +420,14 @@ async def get_batch(): app.state.curr_batch.append(batch) curr_batch = app.state.curr_batch.pop() # check length before sending - logger.info( - "Sending batch of %s sequences", + logger.warning( + "API /batch built %s trainer batch(es); returning one with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s", + steps_to_take, + len(curr_batch), sum(len(x["tokens"]) for x in curr_batch), + len(app.state.curr_batch), + len(app.state.queue), + app.state.status_dict["step"], ) return {"batch": curr_batch} diff --git a/example_trainer/data.py b/example_trainer/data.py index 770d68fa..bf7f5b19 100644 --- a/example_trainer/data.py +++ b/example_trainer/data.py @@ -366,11 +366,23 @@ def get_data( """ batches = [] _logged_logprob_warning = False + empty_polls = 0 while True: data = get_batch(url=atropos_url) if data["batch"] is not None: + empty_polls = 0 + num_groups = len(data["batch"]) + num_sequences = sum(len(item["tokens"]) for item in data["batch"]) + max_seq_len = max( + max(len(seq) for seq in item["tokens"]) for item in data["batch"] + ) + print( + " [Data] received API batch: " + f"groups={num_groups} sequences={num_sequences} max_seq_len={max_seq_len}", + flush=True, + ) # DEBUG: Check if inference_logprobs exists in the data if not _logged_logprob_warning: has_logprobs = any( @@ -407,6 +419,7 @@ def get_data( _logged_logprob_warning = True # Process and accumulate batches (now includes batched inference logprobs) + print(" [Data] padding / batching API payload...", flush=True) ( token_batches, label_batches, @@ -416,6 +429,12 @@ def get_data( distill_token_id_batches, distill_logprob_batches, ) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs) + batch_shapes = [tuple(tb.shape) for tb in token_batches] + print( + " [Data] pad_data_to_good_offset done: " + f"micro_batches={len(token_batches)} token_batch_shapes={batch_shapes}", + flush=True, + ) # Include inference logprob batches in the tuple batches.append( @@ -432,7 +451,17 @@ def get_data( elif len(batches) > 0: # Return accumulated batches when no more data + print( + f" [Data] returning {len(batches)} assembled trainer batch tuple(s)", + flush=True, + ) return batches, None else: # Wait for data + empty_polls += 1 + if empty_polls == 1 or empty_polls % 30 == 0: + print( + f" [Data] no batch ready yet (polls_without_data={empty_polls})", + flush=True, + ) time.sleep(1) diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py index cc96cee5..bff1763f 100644 --- a/example_trainer/trainers.py +++ b/example_trainer/trainers.py @@ -317,12 +317,17 @@ def train_shared_vllm(config: TrainingConfig): # Fetch data (with inference logprobs for proper GRPO loss) data_fetch_start = time.time() if len(batches) == 0: + print(" [Trainer] requesting data from Atropos API...", flush=True) batches, _ = get_data( config.batch_size, config.seq_len, config.atropos_url, extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs ) + print( + f" [Trainer] get_data returned {len(batches)} trainer batch tuple(s)", + flush=True, + ) batch_data = batches.pop(0) token_batches, label_batches, advantage_batches, temperature_batches = ( batch_data[:4] @@ -330,6 +335,12 @@ def train_shared_vllm(config: TrainingConfig): inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None + token_shapes = [tuple(tb.shape) for tb in token_batches] + print( + " [Trainer] selected trainer batch: " + f"micro_batches={len(token_batches)} token_batch_shapes={token_shapes}", + flush=True, + ) data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) From 690e670e646902e5f8aa613647ded5ccc5f2b4b7 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Thu, 12 Mar 2026 16:11:06 -0400 Subject: [PATCH 34/64] investigating weird training issue --- atroposlib/api/server.py | 29 ++++++++++++++++++++++++----- example_trainer/api.py | 10 +++++++++- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index 0ca0f02d..ac134300 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -4,7 +4,7 @@ import time import uuid from typing import Any, Dict, List, Optional -from fastapi import FastAPI, status +from fastapi import FastAPI, Request, status from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from fastapi.responses import PlainTextResponse @@ -351,7 +351,7 @@ async def info(): @app.get("/batch") -async def get_batch(): +async def get_batch(request: Request): # Check if trainer has registered first if not hasattr(app.state, "started"): return { @@ -363,10 +363,21 @@ async def get_batch(): if not app.state.started: app.state.started = True + client = request.client + client_addr = ( + f"{client.host}:{client.port}" if client is not None else "unknown-client" + ) + client_tag = request.headers.get("x-atropos-client", "unknown") + client_pid = request.headers.get("x-atropos-pid", "unknown") + if len(app.state.curr_batch) > 0: curr_batch = app.state.curr_batch.pop() logger.warning( - "API /batch returning prebuilt batch: groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s", + "API /batch returning prebuilt batch to client=%s pid=%s addr=%s: " + "groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s", + client_tag, + client_pid, + client_addr, len(curr_batch), sum(len(x["tokens"]) for x in curr_batch), len(app.state.curr_batch), @@ -406,7 +417,11 @@ async def get_batch(): last_empty_log = getattr(app.state, "_last_empty_batch_log", 0.0) if now - last_empty_log > 30: logger.warning( - "API /batch no full batch ready: queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s", + "API /batch no full batch ready for client=%s pid=%s addr=%s: " + "queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s", + client_tag, + client_pid, + client_addr, len(app.state.queue), sum(len(x.get("tokens", [])) for x in app.state.queue), len(app.state.curr_batch), @@ -421,8 +436,12 @@ async def get_batch(): curr_batch = app.state.curr_batch.pop() # check length before sending logger.warning( - "API /batch built %s trainer batch(es); returning one with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s", + "API /batch built %s trainer batch(es); returning one to client=%s pid=%s addr=%s " + "with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s", steps_to_take, + client_tag, + client_pid, + client_addr, len(curr_batch), sum(len(x["tokens"]) for x in curr_batch), len(app.state.curr_batch), diff --git a/example_trainer/api.py b/example_trainer/api.py index 21c4288e..561e318f 100644 --- a/example_trainer/api.py +++ b/example_trainer/api.py @@ -7,6 +7,7 @@ Handles communication with the Atropos API server for: - Batch retrieval """ +import os import time as _time import requests @@ -99,7 +100,14 @@ def get_batch(url: str = "http://localhost:8000"): Raises: RuntimeError: If trainer is not registered or other API error """ - data = requests.get(f"{url}/batch", timeout=10).json() + data = requests.get( + f"{url}/batch", + headers={ + "X-Atropos-Client": "trainer", + "X-Atropos-Pid": str(os.getpid()), + }, + timeout=10, + ).json() # Check if there was an error (trainer not registered) if data.get("status") == "error": From 3df0e456591928bf9b06036785e4c74d728dca5d Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Thu, 12 Mar 2026 20:02:15 -0400 Subject: [PATCH 35/64] investigating weird training issue --- example_trainer/api.py | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/example_trainer/api.py b/example_trainer/api.py index 561e318f..fe0ac38a 100644 --- a/example_trainer/api.py +++ b/example_trainer/api.py @@ -100,14 +100,34 @@ def get_batch(url: str = "http://localhost:8000"): Raises: RuntimeError: If trainer is not registered or other API error """ - data = requests.get( - f"{url}/batch", - headers={ - "X-Atropos-Client": "trainer", - "X-Atropos-Pid": str(os.getpid()), - }, - timeout=10, - ).json() + try: + response = requests.get( + f"{url}/batch", + headers={ + "X-Atropos-Client": "trainer", + "X-Atropos-Pid": str(os.getpid()), + }, + timeout=10, + ) + print( + f" [Trainer/API] GET /batch status={response.status_code}", + flush=True, + ) + data = response.json() + batch = data.get("batch") + if batch is None: + print(" [Trainer/API] parsed batch=None", flush=True) + else: + num_groups = len(batch) + num_sequences = sum(len(item["tokens"]) for item in batch) + print( + " [Trainer/API] parsed batch payload: " + f"groups={num_groups} sequences={num_sequences}", + flush=True, + ) + except Exception as exc: + print(f" [Trainer/API] GET /batch failed: {exc!r}", flush=True) + raise # Check if there was an error (trainer not registered) if data.get("status") == "error": From d8857eb69faf8e34cad247d0b4a274f5710f13f4 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 13 Mar 2026 09:07:11 -0400 Subject: [PATCH 36/64] investigating weird training issue --- .../run_gsm8k_teacher_distill_single_terminal.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 5b100ce0..418a87ea 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -49,9 +49,9 @@ API_PORT="${API_PORT:-8002}" STUDENT_PORT="${STUDENT_PORT:-9001}" TEACHER_PORT="${TEACHER_PORT:-9003}" -TRAINING_STEPS="${TRAINING_STEPS:-100}" -BATCH_SIZE="${BATCH_SIZE:-2}" -GRAD_ACCUM="${GRAD_ACCUM:-8}" +TRAINING_STEPS="${TRAINING_STEPS:-20}" +BATCH_SIZE="${BATCH_SIZE:-1}" +GRAD_ACCUM="${GRAD_ACCUM:-4}" LR="${LR:-1e-5}" WARMUP_STEPS="${WARMUP_STEPS:-0}" CLIP_EPS="${CLIP_EPS:-0.2}" @@ -68,8 +68,8 @@ TEACHER_TOP_K="${TEACHER_TOP_K:-8}" WANDB_PROJECT="${WANDB_PROJECT:-gsm8k-teacher-distill}" WANDB_GROUP="${WANDB_GROUP:-}" -STUDENT_GPU_MEMORY_UTILIZATION="${STUDENT_GPU_MEMORY_UTILIZATION:-0.95}" -TEACHER_GPU_MEMORY_UTILIZATION="${TEACHER_GPU_MEMORY_UTILIZATION:-0.95}" +STUDENT_GPU_MEMORY_UTILIZATION="${STUDENT_GPU_MEMORY_UTILIZATION:-0.60}" +TEACHER_GPU_MEMORY_UTILIZATION="${TEACHER_GPU_MEMORY_UTILIZATION:-0.85}" DTYPE="${DTYPE:-bfloat16}" SAVE_DIR="${SAVE_DIR:-${LAUNCH_DIR}/saves/gsm8k_teacher_distill}" LOG_DIR="${LOG_DIR:-${LAUNCH_DIR}/logs/gsm8k_teacher_distill}" @@ -77,7 +77,7 @@ BRIDGE_DIR="${BRIDGE_DIR:-${LOG_DIR}/bridge}" DRY_RUN="${DRY_RUN:-0}" ENV_GROUP_SIZE="${ENV_GROUP_SIZE:-4}" -ENV_BATCH_SIZE="${ENV_BATCH_SIZE:-16}" +ENV_BATCH_SIZE="${ENV_BATCH_SIZE:-8}" ENV_TOTAL_STEPS="${ENV_TOTAL_STEPS:-200}" ENV_STEPS_PER_EVAL="${ENV_STEPS_PER_EVAL:-50}" ENV_MAX_WORKERS_PER_NODE="${ENV_MAX_WORKERS_PER_NODE:-1}" From d1b0dee8f75cc4dd472a54f7f28f15571e0b2bc6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:14:05 +0000 Subject: [PATCH 37/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../envs/server_handling/managed_server.py | 4 ++- atroposlib/envs/teacher_distillation_env.py | 10 ++++-- atroposlib/tests/test_server_logprobs.py | 4 ++- environments/gsm8k_server_teacher_distill.py | 1 + example_trainer/data.py | 32 +++++++++++++------ example_trainer/training.py | 9 ++++-- example_trainer/vllm_api_server.py | 12 +++++-- 7 files changed, 53 insertions(+), 19 deletions(-) diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index af74fa55..a8e97077 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -460,7 +460,9 @@ class ManagedServer: ) # Call the tokens and logprobs wrapper directly - logger.warning("managed_server chat_completion calling backend completion wrapper") + logger.warning( + "managed_server chat_completion calling backend completion wrapper" + ) ( prompt_tokens, output_tokens_list, diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py index 12c16079..1c88ab62 100644 --- a/atroposlib/envs/teacher_distillation_env.py +++ b/atroposlib/envs/teacher_distillation_env.py @@ -151,7 +151,11 @@ class TeacherDistillationEnv(BaseEnv, ABC): # and Qwen3-30B) share the same vocabulary, so even though the # name_or_path strings differ they should use the fast path. student_tok_name = getattr(self.tokenizer, "name_or_path", None) or "" - if student_tok_name and teacher_tok_name and student_tok_name != teacher_tok_name: + if ( + student_tok_name + and teacher_tok_name + and student_tok_name != teacher_tok_name + ): try: from transformers import AutoTokenizer @@ -367,7 +371,9 @@ class TeacherDistillationEnv(BaseEnv, ABC): teacher_topk_lps = payload["prompt_topk_logprobs"] alignment = self._build_student_teacher_alignment(text, token_ids, teacher_ids) - return self._align_and_remap(token_ids, teacher_topk_ids, teacher_topk_lps, alignment) + return self._align_and_remap( + token_ids, teacher_topk_ids, teacher_topk_lps, alignment + ) # ------------------------------------------------------------------ # Group enrichment diff --git a/atroposlib/tests/test_server_logprobs.py b/atroposlib/tests/test_server_logprobs.py index 2da50b42..8cbd84ad 100644 --- a/atroposlib/tests/test_server_logprobs.py +++ b/atroposlib/tests/test_server_logprobs.py @@ -41,7 +41,9 @@ class _FakeAPIServer(APIServer): class _FakeRoutedServer: - def __init__(self, name: str, train_slots: int, eval_slots: int, healthy: bool = True): + def __init__( + self, name: str, train_slots: int, eval_slots: int, healthy: bool = True + ): self.name = name self.server_healthy = healthy self.sem = AsyncSemWithAdaptiveWeight(4) diff --git a/environments/gsm8k_server_teacher_distill.py b/environments/gsm8k_server_teacher_distill.py index 159fa4d3..8276436b 100644 --- a/environments/gsm8k_server_teacher_distill.py +++ b/environments/gsm8k_server_teacher_distill.py @@ -47,5 +47,6 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv): ) return env_config, server_config + if __name__ == "__main__": GSM8kTeacherDistillEnv.cli() diff --git a/example_trainer/data.py b/example_trainer/data.py index bf7f5b19..4823eb64 100644 --- a/example_trainer/data.py +++ b/example_trainer/data.py @@ -189,23 +189,27 @@ def pad_data_to_good_offset( rows = max(0, token_setup_len - 1) token_mat = np.full((rows, local_k), -1, dtype=np.int64) - logprob_mat = np.full( - (rows, local_k), -1e9, dtype=np.float32 - ) + logprob_mat = np.full((rows, local_k), -1e9, dtype=np.float32) # Shift by one to align with causal labels like inference_logprobs. copy_positions = min( - len(per_pos_token_ids), len(per_pos_logprobs), token_setup_len + len(per_pos_token_ids), + len(per_pos_logprobs), + token_setup_len, ) for pos in range(1, copy_positions): src_ids = per_pos_token_ids[pos] src_lps = per_pos_logprobs[pos] - if not isinstance(src_ids, list) or not isinstance(src_lps, list): + if not isinstance(src_ids, list) or not isinstance( + src_lps, list + ): continue topk = min(local_k, len(src_ids), len(src_lps)) if topk <= 0: continue - token_mat[pos - 1, :topk] = np.array(src_ids[:topk], dtype=np.int64) + token_mat[pos - 1, :topk] = np.array( + src_ids[:topk], dtype=np.int64 + ) logprob_mat[pos - 1, :topk] = np.array( src_lps[:topk], dtype=np.float32 ) @@ -222,14 +226,18 @@ def pad_data_to_good_offset( ) else: rows = max(0, token_setup_len - 1) - distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64)) + distill_token_ids_padded.append( + np.full((rows, 1), -1, dtype=np.int64) + ) distill_logprobs_padded.append( np.full((rows, 1), -1e9, dtype=np.float32) ) else: rows = max(0, token_setup_len - 1) distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64)) - distill_logprobs_padded.append(np.full((rows, 1), -1e9, dtype=np.float32)) + distill_logprobs_padded.append( + np.full((rows, 1), -1e9, dtype=np.float32) + ) # Extract temperature (priority: override > generation_params > group_overrides > 1.0) t = 1.0 @@ -310,10 +318,14 @@ def pad_data_to_good_offset( else None ) final_distill_token_id_batches = ( - distill_token_id_batches if (has_any_distill and distill_token_id_batches) else None + distill_token_id_batches + if (has_any_distill and distill_token_id_batches) + else None ) final_distill_logprob_batches = ( - distill_logprob_batches if (has_any_distill and distill_logprob_batches) else None + distill_logprob_batches + if (has_any_distill and distill_logprob_batches) + else None ) return ( diff --git a/example_trainer/training.py b/example_trainer/training.py index 673ed795..b7cab944 100644 --- a/example_trainer/training.py +++ b/example_trainer/training.py @@ -311,7 +311,10 @@ def compute_distillation_loss( if distill_token_ids.dim() != 3 or distill_logprobs.dim() != 3: return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0 - if distill_token_ids.shape[:2] != labels.shape or distill_logprobs.shape != distill_token_ids.shape: + if ( + distill_token_ids.shape[:2] != labels.shape + or distill_logprobs.shape != distill_token_ids.shape + ): return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0 temp = max(1e-6, float(temperature)) @@ -330,7 +333,9 @@ def compute_distillation_loss( # Output tensors are [batch, seq_len, k] (tiny) not [batch, seq_len, vocab]. scaled_logits = logits / temp log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True) # [b, s, 1] - student_logp_topk = torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer + student_logp_topk = ( + torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer + ) masked_teacher_logprobs = distill_logprobs.masked_fill(~valid_ids, -1e9) teacher_probs = F.softmax(masked_teacher_logprobs / temp, dim=-1) diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py index 131861d5..b20704e2 100644 --- a/example_trainer/vllm_api_server.py +++ b/example_trainer/vllm_api_server.py @@ -360,10 +360,16 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: ret["prompt_token_ids"] = final_output.prompt_token_ids ret["token_ids"] = [x.token_ids for x in final_output.outputs] - if sampling_params.prompt_logprobs is not None and final_output.prompt_logprobs is not None: + if ( + sampling_params.prompt_logprobs is not None + and final_output.prompt_logprobs is not None + ): ret["prompt_logprobs"] = [ - {int(tok_id): lp.logprob for tok_id, lp in pos.items()} - if pos is not None else None + ( + {int(tok_id): lp.logprob for tok_id, lp in pos.items()} + if pos is not None + else None + ) for pos in final_output.prompt_logprobs ] From 600c54f5f88aa68e19920fe236fb09c227c4e703 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 13 Mar 2026 12:09:08 -0400 Subject: [PATCH 38/64] clean log --- atroposlib/api/server.py | 53 ++------------- atroposlib/envs/base.py | 2 +- .../envs/server_handling/managed_server.py | 19 ------ .../envs/server_handling/server_manager.py | 17 ----- .../envs/server_handling/vllm_server.py | 28 -------- example_trainer/api.py | 37 +++-------- example_trainer/data.py | 65 ------------------- 7 files changed, 15 insertions(+), 206 deletions(-) diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index ac134300..3a0fb999 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -4,7 +4,7 @@ import time import uuid from typing import Any, Dict, List, Optional -from fastapi import FastAPI, Request, status +from fastapi import FastAPI, status from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from fastapi.responses import PlainTextResponse @@ -351,7 +351,7 @@ async def info(): @app.get("/batch") -async def get_batch(request: Request): +async def get_batch(): # Check if trainer has registered first if not hasattr(app.state, "started"): return { @@ -363,27 +363,8 @@ async def get_batch(request: Request): if not app.state.started: app.state.started = True - client = request.client - client_addr = ( - f"{client.host}:{client.port}" if client is not None else "unknown-client" - ) - client_tag = request.headers.get("x-atropos-client", "unknown") - client_pid = request.headers.get("x-atropos-pid", "unknown") - if len(app.state.curr_batch) > 0: - curr_batch = app.state.curr_batch.pop() - logger.warning( - "API /batch returning prebuilt batch to client=%s pid=%s addr=%s: " - "groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s", - client_tag, - client_pid, - client_addr, - len(curr_batch), - sum(len(x["tokens"]) for x in curr_batch), - len(app.state.curr_batch), - len(app.state.queue), - ) - return {"batch": curr_batch} + return {"batch": app.state.curr_batch.pop()} else: new_batches = [] # Check if any envs have minimum allocations @@ -413,21 +394,6 @@ async def get_batch(request: Request): ) steps_to_take = len(new_batches) if steps_to_take == 0: - now = time.time() - last_empty_log = getattr(app.state, "_last_empty_batch_log", 0.0) - if now - last_empty_log > 30: - logger.warning( - "API /batch no full batch ready for client=%s pid=%s addr=%s: " - "queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s", - client_tag, - client_pid, - client_addr, - len(app.state.queue), - sum(len(x.get("tokens", [])) for x in app.state.queue), - len(app.state.curr_batch), - getattr(app.state, "batchsize", -1), - ) - app.state._last_empty_batch_log = now return {"batch": None} app.state.status_dict["step"] += steps_to_take # chunk it @@ -435,18 +401,9 @@ async def get_batch(request: Request): app.state.curr_batch.append(batch) curr_batch = app.state.curr_batch.pop() # check length before sending - logger.warning( - "API /batch built %s trainer batch(es); returning one to client=%s pid=%s addr=%s " - "with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s", - steps_to_take, - client_tag, - client_pid, - client_addr, - len(curr_batch), + logger.info( + "Sending batch of %s sequences", sum(len(x["tokens"]) for x in curr_batch), - len(app.state.curr_batch), - len(app.state.queue), - app.state.status_dict["step"], ) return {"batch": curr_batch} diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 7aa391ba..3d3b6c20 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -907,7 +907,7 @@ class BaseEnv(ABC): "ensure your trainer handles this appropriately." ) elif abort_on_any_max_length_exceeded and any( - [len(x) > self.max_token_len for x in group["tokens"]] + [len(x) >= self.max_token_len for x in group["tokens"]] ): logger.warning("Token length is too long in a group, skipping...") continue diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index a8e97077..9d46f265 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -447,33 +447,14 @@ class ManagedServer: if not self.track_tree and self.tokenizer is not None: input_ids = self._compute_input_ids(prompt, extending_node) completion_kwargs["input_ids"] = input_ids - logger.warning( - "managed_server chat_completion prepared input_ids=%s extending=%s", - len(input_ids), - extending_node is not None, - ) - else: - logger.warning( - "managed_server chat_completion using prompt passthrough track_tree=%s tokenizer=%s", - self.track_tree, - self.tokenizer is not None, - ) # Call the tokens and logprobs wrapper directly - logger.warning( - "managed_server chat_completion calling backend completion wrapper" - ) ( prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons, ) = await self.server.tokens_and_logprobs_completion(**completion_kwargs) - logger.warning( - "managed_server chat_completion backend returned prompt_tokens=%s outputs=%s", - len(prompt_tokens), - len(output_tokens_list), - ) # Track each completion and build choices n = len(output_tokens_list) diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index b24698a6..d34f69c9 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -106,13 +106,6 @@ class ServerManager: self.servers = [ServerHarness()] return if not isinstance(configs, list): - logger.warning( - "ServerManager: configs is NOT a list (type=%s). " - "Using auto-generated URLs (template mode). " - "Passed base_url=%s will be IGNORED.", - type(configs).__name__, - getattr(configs, "base_url", "N/A"), - ) urls = [] if os.environ.get("SLURM_JOB_NODELIST", None) is not None: nodelist = ( @@ -155,21 +148,11 @@ class ServerManager: server_class(config, reasoning_config=reasoning_config) for config in openai_configs ] - logger.warning( - "ServerManager: auto-generated %s server(s) at URLs: %s", - len(self.servers), - [c.base_url for c in openai_configs], - ) elif not slurm: self.servers = [ server_class(config, reasoning_config=reasoning_config) for config in configs ] - logger.warning( - "ServerManager: using %s explicit config(s) at URLs: %s", - len(self.servers), - [c.base_url for c in configs], - ) else: nodelist = ( os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}') diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 18b8333e..acc26830 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -193,14 +193,6 @@ class VLLMServer(APIServer): # Prepare request for VLLM native API request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0} request_data.update(kwargs) - logger.warning( - "vllm_server completion POST start base_url=%s prompt_tokens=%s n=%s max_tokens=%s temperature=%s", - self.config.base_url, - len(prompt_tokens), - request_data.get("n"), - request_data.get("max_tokens"), - request_data.get("temperature"), - ) # Make async request to VLLM /generate endpoint async with aiohttp.ClientSession() as session: @@ -216,11 +208,6 @@ class VLLMServer(APIServer): ) as response: response.raise_for_status() results = await response.json() - logger.warning( - "vllm_server completion POST done outputs=%s finish_reasons=%s", - len(results.get("logprobs", [])), - len(results.get("finish_reasons", [])), - ) output_tokens_list = [] output_logprobs_list = [] finish_reasons_list = [] @@ -330,13 +317,6 @@ class VLLMServer(APIServer): request_data["temperature"] = 0.0 request_data["top_p"] = 1.0 request_data.setdefault("max_tokens", 1) - logger.warning( - "vllm_server get_logprobs POST start base_url=%s prompt_tokens=%s top_k=%s max_tokens=%s", - self.config.base_url, - len(prompt_tokens), - top_k, - request_data.get("max_tokens"), - ) async with aiohttp.ClientSession() as session: async with session.post( @@ -351,10 +331,6 @@ class VLLMServer(APIServer): ) as response: response.raise_for_status() results = await response.json() - logger.warning( - "vllm_server get_logprobs POST done prompt_logprobs_present=%s", - results.get("prompt_logprobs") is not None, - ) raw_prompt_logprobs = results.get("prompt_logprobs") if raw_prompt_logprobs is None: @@ -451,10 +427,6 @@ def resolve_openai_configs( elif isinstance(default_server_configs, list): server_configs = [final_openai_config] else: - logger.warning( - f"Unexpected type for default_server_configs: {type(default_server_configs)}. " - f"Proceeding with single OpenAI server configuration based on merged settings." - ) server_configs = [final_openai_config] return server_configs diff --git a/example_trainer/api.py b/example_trainer/api.py index fe0ac38a..dc51af4f 100644 --- a/example_trainer/api.py +++ b/example_trainer/api.py @@ -100,34 +100,15 @@ def get_batch(url: str = "http://localhost:8000"): Raises: RuntimeError: If trainer is not registered or other API error """ - try: - response = requests.get( - f"{url}/batch", - headers={ - "X-Atropos-Client": "trainer", - "X-Atropos-Pid": str(os.getpid()), - }, - timeout=10, - ) - print( - f" [Trainer/API] GET /batch status={response.status_code}", - flush=True, - ) - data = response.json() - batch = data.get("batch") - if batch is None: - print(" [Trainer/API] parsed batch=None", flush=True) - else: - num_groups = len(batch) - num_sequences = sum(len(item["tokens"]) for item in batch) - print( - " [Trainer/API] parsed batch payload: " - f"groups={num_groups} sequences={num_sequences}", - flush=True, - ) - except Exception as exc: - print(f" [Trainer/API] GET /batch failed: {exc!r}", flush=True) - raise + response = requests.get( + f"{url}/batch", + headers={ + "X-Atropos-Client": "trainer", + "X-Atropos-Pid": str(os.getpid()), + }, + timeout=10, + ) + data = response.json() # Check if there was an error (trainer not registered) if data.get("status") == "error": diff --git a/example_trainer/data.py b/example_trainer/data.py index 4823eb64..0aa1a88a 100644 --- a/example_trainer/data.py +++ b/example_trainer/data.py @@ -377,61 +377,12 @@ def get_data( - inference_logprob_batches are aligned with labels for proper GRPO loss computation """ batches = [] - _logged_logprob_warning = False - empty_polls = 0 while True: data = get_batch(url=atropos_url) if data["batch"] is not None: - empty_polls = 0 - num_groups = len(data["batch"]) - num_sequences = sum(len(item["tokens"]) for item in data["batch"]) - max_seq_len = max( - max(len(seq) for seq in item["tokens"]) for item in data["batch"] - ) - print( - " [Data] received API batch: " - f"groups={num_groups} sequences={num_sequences} max_seq_len={max_seq_len}", - flush=True, - ) - # DEBUG: Check if inference_logprobs exists in the data - if not _logged_logprob_warning: - has_logprobs = any( - "inference_logprobs" in item for item in data["batch"] - ) - if has_logprobs: - # Check if they're non-empty - sample_item = next( - ( - item - for item in data["batch"] - if "inference_logprobs" in item - ), - None, - ) - if sample_item and sample_item.get("inference_logprobs"): - sample_lp = ( - sample_item["inference_logprobs"][0] - if sample_item["inference_logprobs"] - else [] - ) - print( - f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})" - ) - else: - print( - " [Data] ⚠ inference_logprobs key exists but is empty!" - ) - else: - print(" [Data] ⚠ NO inference_logprobs in batch data!") - print( - f" [Data] Keys in first item: {list(data['batch'][0].keys())}" - ) - _logged_logprob_warning = True - # Process and accumulate batches (now includes batched inference logprobs) - print(" [Data] padding / batching API payload...", flush=True) ( token_batches, label_batches, @@ -441,12 +392,6 @@ def get_data( distill_token_id_batches, distill_logprob_batches, ) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs) - batch_shapes = [tuple(tb.shape) for tb in token_batches] - print( - " [Data] pad_data_to_good_offset done: " - f"micro_batches={len(token_batches)} token_batch_shapes={batch_shapes}", - flush=True, - ) # Include inference logprob batches in the tuple batches.append( @@ -463,17 +408,7 @@ def get_data( elif len(batches) > 0: # Return accumulated batches when no more data - print( - f" [Data] returning {len(batches)} assembled trainer batch tuple(s)", - flush=True, - ) return batches, None else: # Wait for data - empty_polls += 1 - if empty_polls == 1 or empty_polls % 30 == 0: - print( - f" [Data] no batch ready yet (polls_without_data={empty_polls})", - flush=True, - ) time.sleep(1) From 862cd3667d28a60434100ee01b6bbdc12cf00b0e Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 13 Mar 2026 12:38:52 -0400 Subject: [PATCH 39/64] clean logging --- .../envs/server_handling/server_manager.py | 3 --- .../envs/server_handling/vllm_server.py | 3 --- atroposlib/tests/test_managed_server.py | 24 +++++++++++++++++++ 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index d34f69c9..b9c493f9 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -1,6 +1,5 @@ import asyncio import inspect -import logging import os import warnings from contextlib import asynccontextmanager @@ -26,8 +25,6 @@ from atroposlib.envs.server_handling.sglang_server import SGLangServer from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer from atroposlib.envs.server_handling.vllm_server import VLLMServer -logger = logging.getLogger(__name__) - class ServerManagerConfig(BaseModel): slurm: bool = Field( diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index acc26830..aaee28d7 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -2,7 +2,6 @@ # see example_trainer/vllm_api_server.py for an example import asyncio -import logging import warnings from typing import Any, Dict, List, Tuple @@ -20,8 +19,6 @@ from atroposlib.envs.server_handling.server_baseline import ( ReasoningConfig, ) -logger = logging.getLogger(__name__) - class VLLMServer(APIServer): """ diff --git a/atroposlib/tests/test_managed_server.py b/atroposlib/tests/test_managed_server.py index 6f18be08..1524aaf7 100644 --- a/atroposlib/tests/test_managed_server.py +++ b/atroposlib/tests/test_managed_server.py @@ -319,6 +319,30 @@ async def test_get_logprobs_messages_passthrough(mock_server): assert len(payload["prompt_topk_logprobs"]) == len(prompt_tokens) +@pytest.mark.asyncio +async def test_get_logprobs_input_ids_only_passthrough(mock_server): + """ManagedServer.get_logprobs supports input_ids-only without requiring prompt.""" + managed = ManagedServer(mock_server, tokenizer=mock_server.tokenizer) + input_ids = [10, 20, 30] + + async def _mock_get_logprobs(**kwargs): + assert "input_ids" in kwargs + assert kwargs["input_ids"] == input_ids + assert kwargs.get("prompt") is None + return { + "prompt_tokens": input_ids, + "prompt_topk_token_ids": [[t] for t in input_ids], + "prompt_topk_logprobs": [[-0.1] for _ in input_ids], + } + + mock_server.get_logprobs = _mock_get_logprobs + payload = await managed.get_logprobs(input_ids=input_ids, top_k=1) + + assert payload["prompt_tokens"] == input_ids + assert payload["prompt_topk_token_ids"] == [[10], [20], [30]] + assert payload["prompt_topk_logprobs"] == [[-0.1], [-0.1], [-0.1]] + + @pytest.mark.asyncio async def test_get_logprobs_strict_mode_requires_backend_impl(mock_server): """ManagedServer.get_logprobs requires backend get_logprobs in strict mode.""" From 148a4fd5eb6ad74ff689905c6afa3c2dd00f5cce Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 13 Mar 2026 12:52:52 -0400 Subject: [PATCH 40/64] remove training code --- example_trainer/cli.py | 20 ----- example_trainer/config.py | 12 --- example_trainer/data.py | 168 ++++++++---------------------------- example_trainer/run.py | 3 - example_trainer/trainers.py | 27 ------ example_trainer/training.py | 137 +---------------------------- 6 files changed, 38 insertions(+), 329 deletions(-) diff --git a/example_trainer/cli.py b/example_trainer/cli.py index 93946d51..1e46bfc9 100644 --- a/example_trainer/cli.py +++ b/example_trainer/cli.py @@ -163,23 +163,6 @@ def add_grpo_args(parser: argparse.ArgumentParser) -> None: default=0.2, help="PPO-style clipping epsilon. Clips ratio to [1-eps, 1+eps].", ) - group.add_argument( - "--distill-enabled", - action="store_true", - help="Enable teacher distillation loss (requires distill payload in Atropos batch).", - ) - group.add_argument( - "--distill-coef", - type=float, - default=0.0, - help="Coefficient for distillation loss term.", - ) - group.add_argument( - "--distill-temperature", - type=float, - default=1.0, - help="Temperature for teacher top-k distribution in distillation loss.", - ) def add_vllm_args(parser: argparse.ArgumentParser) -> None: @@ -441,9 +424,6 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig: checkpoint_interval=getattr(args, "checkpoint_interval", 3), # GRPO/PPO hyperparameters clip_eps=getattr(args, "clip_eps", 0.2), - distill_enabled=getattr(args, "distill_enabled", False), - distill_coef=getattr(args, "distill_coef", 0.0), - distill_temperature=getattr(args, "distill_temperature", 1.0), adafactor_scale_parameter=getattr(args, "adafactor_scale_parameter", False), adafactor_relative_step=getattr(args, "adafactor_relative_step", False), # vLLM settings diff --git a/example_trainer/config.py b/example_trainer/config.py index 03fd80a8..4ddeddb5 100644 --- a/example_trainer/config.py +++ b/example_trainer/config.py @@ -69,18 +69,6 @@ class TrainingConfig(BaseModel): "Prevents large policy updates that could destabilize training." ), ) - distill_enabled: bool = Field( - False, - description="Enable teacher distillation loss when distill tensors are present.", - ) - distill_coef: float = Field( - 0.0, - description="Weight for distillation loss in total loss.", - ) - distill_temperature: float = Field( - 1.0, - description="Temperature applied when converting teacher top-k logprobs.", - ) # === Device & Storage === device: str = Field( "cuda" if torch.cuda.is_available() else "cpu", description="Device to train on" diff --git a/example_trainer/data.py b/example_trainer/data.py index 0aa1a88a..16a38564 100644 --- a/example_trainer/data.py +++ b/example_trainer/data.py @@ -29,8 +29,6 @@ def pad_data_to_good_offset( List[torch.Tensor], # advantage_batches List[torch.Tensor], # temperature_batches Optional[List[torch.Tensor]], # inference_logprob_batches (aligned with labels) - Optional[List[torch.Tensor]], # distill_token_id_batches [batch, seq, k] - Optional[List[torch.Tensor]], # distill_logprob_batches [batch, seq, k] ]: """ Pad and batch data from the Atropos API. @@ -47,8 +45,7 @@ def pad_data_to_good_offset( extract_inference_logprobs: Whether to extract inference logprobs Returns: - Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, - inference_logprob_batches, distill_token_id_batches, distill_logprob_batches) + Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches) inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data Note: @@ -76,10 +73,6 @@ def pad_data_to_good_offset( temperatures = [] inference_logprobs_padded: List[np.ndarray] = [] # Padded to match labels shape has_any_logprobs = False - distill_token_ids_padded: List[np.ndarray] = [] - distill_logprobs_padded: List[np.ndarray] = [] - has_any_distill = False - max_distill_k = 1 for item in data["batch"]: # Normalize advantage scores @@ -160,85 +153,6 @@ def pad_data_to_good_offset( np.full(token_setup_len - 1, 1.0, dtype=np.float32) ) - # Extract teacher distillation top-k arrays if available. - # Expected shape in incoming payload: [sequence][position][k]. - if "distill_token_ids" in item and "distill_logprobs" in item: - seq_token_ids = item["distill_token_ids"] - seq_logprobs = item["distill_logprobs"] - if ( - isinstance(seq_token_ids, list) - and isinstance(seq_logprobs, list) - and i < len(seq_token_ids) - and i < len(seq_logprobs) - and seq_token_ids[i] is not None - and seq_logprobs[i] is not None - ): - per_pos_token_ids = seq_token_ids[i] - per_pos_logprobs = seq_logprobs[i] - if ( - isinstance(per_pos_token_ids, list) - and isinstance(per_pos_logprobs, list) - and len(per_pos_token_ids) == len(per_pos_logprobs) - ): - local_k = 1 - for row_ids in per_pos_token_ids: - if isinstance(row_ids, list): - local_k = max(local_k, len(row_ids)) - max_distill_k = max(max_distill_k, local_k) - has_any_distill = True - - rows = max(0, token_setup_len - 1) - token_mat = np.full((rows, local_k), -1, dtype=np.int64) - logprob_mat = np.full((rows, local_k), -1e9, dtype=np.float32) - - # Shift by one to align with causal labels like inference_logprobs. - copy_positions = min( - len(per_pos_token_ids), - len(per_pos_logprobs), - token_setup_len, - ) - for pos in range(1, copy_positions): - src_ids = per_pos_token_ids[pos] - src_lps = per_pos_logprobs[pos] - if not isinstance(src_ids, list) or not isinstance( - src_lps, list - ): - continue - topk = min(local_k, len(src_ids), len(src_lps)) - if topk <= 0: - continue - token_mat[pos - 1, :topk] = np.array( - src_ids[:topk], dtype=np.int64 - ) - logprob_mat[pos - 1, :topk] = np.array( - src_lps[:topk], dtype=np.float32 - ) - - distill_token_ids_padded.append(token_mat) - distill_logprobs_padded.append(logprob_mat) - else: - rows = max(0, token_setup_len - 1) - distill_token_ids_padded.append( - np.full((rows, 1), -1, dtype=np.int64) - ) - distill_logprobs_padded.append( - np.full((rows, 1), -1e9, dtype=np.float32) - ) - else: - rows = max(0, token_setup_len - 1) - distill_token_ids_padded.append( - np.full((rows, 1), -1, dtype=np.int64) - ) - distill_logprobs_padded.append( - np.full((rows, 1), -1e9, dtype=np.float32) - ) - else: - rows = max(0, token_setup_len - 1) - distill_token_ids_padded.append(np.full((rows, 1), -1, dtype=np.int64)) - distill_logprobs_padded.append( - np.full((rows, 1), -1e9, dtype=np.float32) - ) - # Extract temperature (priority: override > generation_params > group_overrides > 1.0) t = 1.0 if ( @@ -264,8 +178,6 @@ def pad_data_to_good_offset( advantage_batches = [] temperature_batches = [] inference_logprob_batches = [] - distill_token_id_batches = [] - distill_logprob_batches = [] for start in range(0, len(input_ids), batch_size): end = min(start + batch_size, len(input_ids)) @@ -287,46 +199,12 @@ def pad_data_to_good_offset( torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0)) ) - if distill_token_ids_padded and distill_logprobs_padded: - seq_slice_ids = distill_token_ids_padded[start:end] - seq_slice_lps = distill_logprobs_padded[start:end] - normalized_ids = [] - normalized_lps = [] - for ids_mat, lps_mat in zip(seq_slice_ids, seq_slice_lps): - if ids_mat.shape[1] < max_distill_k: - pad_cols = max_distill_k - ids_mat.shape[1] - ids_mat = np.pad( - ids_mat, ((0, 0), (0, pad_cols)), constant_values=-1 - ) - lps_mat = np.pad( - lps_mat, ((0, 0), (0, pad_cols)), constant_values=-1e9 - ) - normalized_ids.append(ids_mat) - normalized_lps.append(lps_mat) - - distill_token_id_batches.append( - torch.tensor(np.stack(normalized_ids, axis=0), dtype=torch.long) - ) - distill_logprob_batches.append( - torch.tensor(np.stack(normalized_lps, axis=0), dtype=torch.float32) - ) - # Return inference logprob batches if we have any real logprobs final_logprob_batches = ( inference_logprob_batches if (has_any_logprobs and inference_logprob_batches) else None ) - final_distill_token_id_batches = ( - distill_token_id_batches - if (has_any_distill and distill_token_id_batches) - else None - ) - final_distill_logprob_batches = ( - distill_logprob_batches - if (has_any_distill and distill_logprob_batches) - else None - ) return ( token_batches, @@ -334,8 +212,6 @@ def pad_data_to_good_offset( advantage_batches, temperature_batches, final_logprob_batches, - final_distill_token_id_batches, - final_distill_logprob_batches, ) @@ -352,8 +228,6 @@ def get_data( List[torch.Tensor], # advantage_batches List[torch.Tensor], # temperature_batches Optional[List[torch.Tensor]], # inference_logprob_batches - Optional[List[torch.Tensor]], # distill_token_id_batches - Optional[List[torch.Tensor]], # distill_logprob_batches ] ], None, # Legacy return (no longer used) @@ -377,11 +251,47 @@ def get_data( - inference_logprob_batches are aligned with labels for proper GRPO loss computation """ batches = [] + _logged_logprob_warning = False while True: data = get_batch(url=atropos_url) if data["batch"] is not None: + # DEBUG: Check if inference_logprobs exists in the data + if not _logged_logprob_warning: + has_logprobs = any( + "inference_logprobs" in item for item in data["batch"] + ) + if has_logprobs: + # Check if they're non-empty + sample_item = next( + ( + item + for item in data["batch"] + if "inference_logprobs" in item + ), + None, + ) + if sample_item and sample_item.get("inference_logprobs"): + sample_lp = ( + sample_item["inference_logprobs"][0] + if sample_item["inference_logprobs"] + else [] + ) + print( + f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})" + ) + else: + print( + " [Data] ⚠ inference_logprobs key exists but is empty!" + ) + else: + print(" [Data] ⚠ NO inference_logprobs in batch data!") + print( + f" [Data] Keys in first item: {list(data['batch'][0].keys())}" + ) + _logged_logprob_warning = True + # Process and accumulate batches (now includes batched inference logprobs) ( token_batches, @@ -389,8 +299,6 @@ def get_data( adv_batches, temp_batches, inf_logprob_batches, - distill_token_id_batches, - distill_logprob_batches, ) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs) # Include inference logprob batches in the tuple @@ -401,8 +309,6 @@ def get_data( adv_batches, temp_batches, inf_logprob_batches, - distill_token_id_batches, - distill_logprob_batches, ) ) diff --git a/example_trainer/run.py b/example_trainer/run.py index d1cf37b2..b9b5f88f 100644 --- a/example_trainer/run.py +++ b/example_trainer/run.py @@ -201,9 +201,6 @@ def main(): checkpoint_interval=args.checkpoint_interval, # GRPO hyperparameters clip_eps=args.clip_eps, - distill_enabled=getattr(args, "distill_enabled", False), - distill_coef=getattr(args, "distill_coef", 0.0), - distill_temperature=getattr(args, "distill_temperature", 1.0), # vLLM settings vllm_port=args.vllm_port, vllm_gpu_memory_utilization=args.gpu_memory_utilization, diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py index bff1763f..4c9e2893 100644 --- a/example_trainer/trainers.py +++ b/example_trainer/trainers.py @@ -170,8 +170,6 @@ def train_legacy(config: TrainingConfig): batch_data[:4] ) inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None - distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None - distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) @@ -194,8 +192,6 @@ def train_legacy(config: TrainingConfig): config, step_idx=step, inference_logprob_batches=inference_logprob_batches, - distill_token_id_batches=distill_token_id_batches, - distill_logprob_batches=distill_logprob_batches, ) step_time = time.time() - step_start benchmark_stats["step_times"].append(step_time) @@ -317,30 +313,17 @@ def train_shared_vllm(config: TrainingConfig): # Fetch data (with inference logprobs for proper GRPO loss) data_fetch_start = time.time() if len(batches) == 0: - print(" [Trainer] requesting data from Atropos API...", flush=True) batches, _ = get_data( config.batch_size, config.seq_len, config.atropos_url, extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs ) - print( - f" [Trainer] get_data returned {len(batches)} trainer batch tuple(s)", - flush=True, - ) batch_data = batches.pop(0) token_batches, label_batches, advantage_batches, temperature_batches = ( batch_data[:4] ) inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None - distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None - distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None - token_shapes = [tuple(tb.shape) for tb in token_batches] - print( - " [Trainer] selected trainer batch: " - f"micro_batches={len(token_batches)} token_batch_shapes={token_shapes}", - flush=True, - ) data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) @@ -356,8 +339,6 @@ def train_shared_vllm(config: TrainingConfig): config, step_idx=step, inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation - distill_token_id_batches=distill_token_id_batches, - distill_logprob_batches=distill_logprob_batches, ) step_time = time.time() - step_start benchmark_stats["step_times"].append(step_time) @@ -503,8 +484,6 @@ def train_lora(config: TrainingConfig): batch_data[:4] ) inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None - distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None - distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) @@ -520,8 +499,6 @@ def train_lora(config: TrainingConfig): config, step_idx=step, inference_logprob_batches=inference_logprob_batches, - distill_token_id_batches=distill_token_id_batches, - distill_logprob_batches=distill_logprob_batches, ) step_time = time.time() - step_start benchmark_stats["step_times"].append(step_time) @@ -729,8 +706,6 @@ def train_lora_restart(config: TrainingConfig): batch_data[:4] ) inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None - distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None - distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) @@ -746,8 +721,6 @@ def train_lora_restart(config: TrainingConfig): config, step_idx=step, inference_logprob_batches=inference_logprob_batches, - distill_token_id_batches=distill_token_id_batches, - distill_logprob_batches=distill_logprob_batches, ) step_time = time.time() - step_start benchmark_stats["step_times"].append(step_time) diff --git a/example_trainer/training.py b/example_trainer/training.py index b7cab944..035d45c7 100644 --- a/example_trainer/training.py +++ b/example_trainer/training.py @@ -70,11 +70,6 @@ def compute_grpo_loss( gradient_accumulation_steps: int, inference_logprobs: Optional[torch.Tensor] = None, clip_eps: float = 0.2, - distill_token_ids: Optional[torch.Tensor] = None, - distill_logprobs: Optional[torch.Tensor] = None, - distill_enabled: bool = False, - distill_coef: float = 0.0, - distill_temperature: float = 1.0, ) -> Tuple[torch.Tensor, dict]: """ Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch. @@ -130,9 +125,6 @@ def compute_grpo_loss( logprob_diff_abs_mean = 0.0 logprob_diff_max = 0.0 - distill_loss_value = torch.tensor(0.0, device=logp_per_token.device) - distill_token_count = 0.0 - # === GRPO/PPO Loss Computation === if inference_logprobs is not None: # Move inference logprobs to correct device/dtype @@ -195,23 +187,7 @@ def compute_grpo_loss( # Average over tokens, then over batch policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean() - if ( - distill_enabled - and distill_coef > 0 - and distill_token_ids is not None - and distill_logprobs is not None - ): - distill_loss_value, distill_token_count = compute_distillation_loss( - logits=scaled_logits, - labels=labels, - distill_token_ids=distill_token_ids.to(logits.device), - distill_logprobs=distill_logprobs.to(logits.device, logits.dtype), - temperature=max(1e-6, float(distill_temperature)), - ) - - total_loss = (policy_loss + distill_coef * distill_loss_value) / ( - gradient_accumulation_steps - ) + total_loss = policy_loss / gradient_accumulation_steps # Compute metrics for logging with torch.no_grad(): @@ -277,77 +253,11 @@ def compute_grpo_loss( "logprob_diff_mean": logprob_diff_mean, "logprob_diff_abs_mean": logprob_diff_abs_mean, "logprob_diff_max": logprob_diff_max, - "distill_loss": ( - distill_loss_value.item() - if torch.is_tensor(distill_loss_value) - else float(distill_loss_value) - ), - "distill_token_count": distill_token_count, } return total_loss, metrics -def compute_distillation_loss( - logits: torch.Tensor, - labels: torch.Tensor, - distill_token_ids: torch.Tensor, - distill_logprobs: torch.Tensor, - temperature: float = 1.0, -) -> Tuple[torch.Tensor, float]: - """ - Compute token-level distillation loss from teacher top-k prompt logprobs. - - Args: - logits: Student logits [batch, seq_len, vocab] - labels: Labels [batch, seq_len], -100 for masked positions - distill_token_ids: Teacher top-k token IDs [batch, seq_len, k], -1 padded - distill_logprobs: Teacher top-k logprobs [batch, seq_len, k], very negative padded - temperature: Distillation temperature - - Returns: - Tuple of (distillation loss scalar, valid token count) - """ - if distill_token_ids.dim() != 3 or distill_logprobs.dim() != 3: - return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0 - - if ( - distill_token_ids.shape[:2] != labels.shape - or distill_logprobs.shape != distill_token_ids.shape - ): - return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0 - - temp = max(1e-6, float(temperature)) - - valid_ids = distill_token_ids >= 0 - label_mask = labels != -100 - valid_pos = label_mask & valid_ids.any(dim=-1) - if not valid_pos.any(): - return torch.tensor(0.0, device=logits.device, dtype=logits.dtype), 0.0 - - gather_ids = distill_token_ids.clamp_min(0).long() - - # Avoid materializing the full [batch, seq_len, vocab] log_softmax tensor - # (e.g. [2, 20480, 151936] = ~12.5 GB) which is the main cause of OOM/hangs. - # Instead: gather raw logits at top-k positions, then subtract logsumexp. - # Output tensors are [batch, seq_len, k] (tiny) not [batch, seq_len, vocab]. - scaled_logits = logits / temp - log_normalizer = torch.logsumexp(scaled_logits, dim=-1, keepdim=True) # [b, s, 1] - student_logp_topk = ( - torch.gather(scaled_logits, dim=-1, index=gather_ids) - log_normalizer - ) - - masked_teacher_logprobs = distill_logprobs.masked_fill(~valid_ids, -1e9) - teacher_probs = F.softmax(masked_teacher_logprobs / temp, dim=-1) - - per_token_loss = -(teacher_probs * student_logp_topk).sum(dim=-1) - per_token_loss = per_token_loss * valid_pos.to(per_token_loss.dtype) - - token_count = valid_pos.sum().item() - loss = per_token_loss.sum() / valid_pos.sum().clamp_min(1).to(per_token_loss.dtype) - return loss, float(token_count) - - def run_training_step( model: torch.nn.Module, optimizer: torch.optim.Optimizer, @@ -358,8 +268,6 @@ def run_training_step( config: TrainingConfig, step_idx: int, inference_logprob_batches: Optional[List[torch.Tensor]] = None, - distill_token_id_batches: Optional[List[torch.Tensor]] = None, - distill_logprob_batches: Optional[List[torch.Tensor]] = None, ) -> dict: """ Run a single training step with gradient accumulation. @@ -394,8 +302,6 @@ def run_training_step( total_logprob_diff_mean = 0.0 total_logprob_diff_abs_mean = 0.0 total_logprob_diff_max = 0.0 - total_distill_loss = 0.0 - total_distill_tokens = 0.0 grad_norm = 0.0 all_training_logprobs: List[torch.Tensor] = [] all_inference_logprobs: List[torch.Tensor] = [] @@ -419,13 +325,6 @@ def run_training_step( for batch_idx, (tokens, labels, advantages, temperatures) in enumerate( zip(token_batches, label_batches, advantage_batches, temperature_batches) ): - print( - f" [Step] micro-batch {batch_idx+1}/{num_batches} " - f"tokens={tokens.shape} " - f"gpu_mem={torch.cuda.memory_allocated()/1e9:.1f}GB " - f"gpu_reserved={torch.cuda.memory_reserved()/1e9:.1f}GB", - flush=True, - ) tokens = tokens.to(config.device) labels = labels.to(config.device) advantages = advantages.to(config.device) @@ -436,18 +335,7 @@ def run_training_step( inference_logprob_batches ): inf_logprobs = inference_logprob_batches[batch_idx] - distill_ids = None - if distill_token_id_batches is not None and batch_idx < len( - distill_token_id_batches - ): - distill_ids = distill_token_id_batches[batch_idx] - distill_lps = None - if distill_logprob_batches is not None and batch_idx < len( - distill_logprob_batches - ): - distill_lps = distill_logprob_batches[batch_idx] - print(f" [Step] micro-batch {batch_idx+1} forward pass...", flush=True) loss, metrics = compute_grpo_loss( model, tokens, @@ -457,20 +345,9 @@ def run_training_step( config.gradient_accumulation_steps, inference_logprobs=inf_logprobs, clip_eps=clip_eps, - distill_token_ids=distill_ids, - distill_logprobs=distill_lps, - distill_enabled=bool(getattr(config, "distill_enabled", False)), - distill_coef=float(getattr(config, "distill_coef", 0.0)), - distill_temperature=float(getattr(config, "distill_temperature", 1.0)), ) - print( - f" [Step] micro-batch {batch_idx+1} loss={loss.item():.4f} " - f"backward...", - flush=True, - ) loss.backward() - print(f" [Step] micro-batch {batch_idx+1} backward done", flush=True) total_loss += loss.item() total_pos_logp += metrics["pos_logp"] total_neg_logp += metrics["neg_logp"] @@ -487,8 +364,6 @@ def run_training_step( total_logprob_diff_max = max( total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0) ) - total_distill_loss += metrics.get("distill_loss", 0.0) - total_distill_tokens += metrics.get("distill_token_count", 0.0) # Collect logprobs for alignment monitoring if "training_logprobs" in metrics and metrics["training_logprobs"] is not None: @@ -524,8 +399,6 @@ def run_training_step( # GRPO-specific metrics (averaged over batches) "mean_ratio": total_mean_ratio / num_batches, "clipped_fraction": total_clipped_fraction / num_batches, - "distill_loss": total_distill_loss / num_batches, - "distill_token_count": total_distill_tokens, } # Compute logprob alignment stats for monitoring @@ -599,12 +472,6 @@ def log_metrics( clipped_frac = metrics.get("clipped_fraction", 0) print(f" GRPO: ratio={mean_ratio:.3f}, clipped={clipped_frac*100:.1f}%") - if metrics.get("distill_token_count", 0) > 0: - print( - " Distill: " - f"loss={metrics.get('distill_loss', 0.0):.4f}, " - f"tokens={int(metrics.get('distill_token_count', 0))}" - ) # Advantage distribution if "pos_count" in metrics or "neg_count" in metrics: @@ -627,8 +494,6 @@ def log_metrics( # GRPO-specific metrics "grpo/mean_ratio": mean_ratio, "grpo/clipped_fraction": clipped_frac, - "distill/loss": metrics.get("distill_loss", 0.0), - "distill/token_count": metrics.get("distill_token_count", 0.0), } # Add timing metrics if present for key in [ From a1b545c7344cb4953fad81257bb7c6aa8ebac712 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 13 Mar 2026 13:19:28 -0400 Subject: [PATCH 41/64] remove cross tokenization and fix location of configs --- README.md | 41 +++ atroposlib/envs/teacher_distillation_env.py | 343 +++--------------- .../tests/test_teacher_distillation_env.py | 26 ++ environments/gsm8k_server_teacher_distill.py | 11 +- example_trainer/README.md | 23 ++ ...n_gsm8k_teacher_distill_single_terminal.sh | 5 +- 6 files changed, 147 insertions(+), 302 deletions(-) diff --git a/README.md b/README.md index 3b533a9b..6def2fdf 100644 --- a/README.md +++ b/README.md @@ -298,6 +298,47 @@ curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids! - Trainers should validate alignment assumptions they require (sequence length, per-position top-k, etc.). - Teacher-side architecture and prompt/rendering strategy are intentionally out of scope for this PR. +### TeacherDistillationEnv follow-up + +The follow-up teacher environment uses a dedicated teacher server config and +attaches teacher prompt logprobs before the group is sent to the API. + +Teacher config shape: + +```python +TeacherDistillationConfig( + teacher_enabled=True, + teacher_server=APIServerConfig( + base_url="http://localhost:9003/v1", + model_name="Qwen/Qwen3-30B-A3B-Instruct-2507", + api_key="", + server_type="vllm", + ), + teacher_top_k=8, +) +``` + +CLI shape: + +```bash +--env.teacher_enabled true \ +--env.teacher_server.base_url "http://localhost:9003/v1" \ +--env.teacher_server.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \ +--env.teacher_server.server_type vllm \ +--env.teacher_top_k 8 +``` + +Tokenizer requirement: + +- Teacher distillation currently requires the teacher and student to use the same tokenizer vocabulary. +- If the tokenizers do not match, `TeacherDistillationEnv` raises an error instead of attempting token conversion. + +Why same-tokenizer is required: + +- `distill_token_ids` are consumed as student-vocabulary IDs by the trainer. +- If the teacher uses a different vocabulary, the same integer token ID refers to different text on the teacher and student sides. +- A decode/re-tokenize/remap pipeline is not a safe drop-in fix because it changes both token positions and token identities, which breaks the exact per-position token supervision that the current distillation loss assumes. + --- ## Testing and Debugging Tools diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py index 1c88ab62..d8284335 100644 --- a/atroposlib/envs/teacher_distillation_env.py +++ b/atroposlib/envs/teacher_distillation_env.py @@ -4,42 +4,10 @@ Teacher distillation environment layer. This module adds teacher prompt-logprob fetching on top of BaseEnv without modifying BaseEnv transport behavior. -Cross-tokenizer distillation ----------------------------- -When student and teacher use the same tokenizer family (e.g. both Qwen3) the -student's raw token IDs can be forwarded directly to the teacher vLLM and the -returned top-k token IDs can be used as-is in the student logit lookup. - -When tokenizers differ (e.g. Llama student, Qwen teacher) two problems arise: - - 1. Token-ID aliasing: student token 42 = " the" in Llama, but 42 = "ท" in - Qwen. Sending student IDs to the teacher causes it to score garbage. - - 2. Vocab-space mismatch: the teacher's top-k IDs live in the teacher's - vocabulary. The student logit lookup at those IDs would access random - tokens in the student vocab. - -This module fixes both problems automatically: - - • Re-tokenization – student tokens are decoded to plain text and - re-tokenized with the teacher tokenizer before being sent to the teacher - server. The teacher therefore always scores the correct text. - - • Character-level position alignment – after re-tokenisation the teacher - has a different number of tokens than the student. A character-offset - map is built (requires a fast HuggingFace tokenizer) to project each - teacher logprob position back onto the student token it overlaps with. - - • Vocabulary remapping – teacher top-k token IDs (teacher vocab) are - decoded to text fragments and re-encoded with the student tokenizer so - that the final distill_token_ids live in the student vocabulary and can - be looked up directly in the student logit tensor. - -Same-tokenizer fast path ------------------------- -When teacher_tokenizer_name resolves to the same underlying vocabulary as the -student tokenizer the original fast path (no decode / re-tokenize / remap) is -taken automatically. +This implementation supports same-tokenizer distillation only. The teacher and +student must share the same tokenizer vocabulary so the student's token IDs can +be forwarded directly to the teacher and the returned teacher top-k token IDs +can be looked up directly in the student's logits. """ from __future__ import annotations @@ -47,7 +15,7 @@ from __future__ import annotations import asyncio import logging from abc import ABC -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union from pydantic import Field @@ -63,29 +31,9 @@ class TeacherDistillationConfig(BaseEnvConfig): default=False, description="Whether to fetch teacher prompt logprobs for distillation.", ) - teacher_base_url: Optional[str] = Field( + teacher_server: Optional[APIServerConfig] = Field( default=None, - description="Teacher server base URL (OpenAI-compatible).", - ) - teacher_model_name: Optional[str] = Field( - default=None, - description="Teacher model name used in teacher server requests.", - ) - teacher_api_key: str = Field( - default="", - description="Teacher API key, if required by the teacher endpoint.", - ) - teacher_server_type: str = Field( - default="vllm", - description="Teacher server type (e.g. vllm, sglang, trl, openai).", - ) - teacher_tokenizer_name: str = Field( - default="none", - description=( - "Tokenizer name for teacher server. If 'none', teacher_model_name is used. " - "When this resolves to a different vocabulary than the student tokenizer, " - "cross-tokenizer alignment is applied automatically." - ), + description="Teacher inference server configuration.", ) teacher_top_k: int = Field( default=1, @@ -114,266 +62,71 @@ class TeacherDistillationEnv(BaseEnv, ABC): ): super().__init__(config, server_configs, slurm=slurm, testing=testing) self.teacher_server: Optional[ServerManager] = None - # Teacher tokenizer (only loaded when tokenizers may differ). - self._teacher_tokenizer = None - # True when student and teacher share the same vocabulary. - self._same_tokenizer: bool = True - # LRU-style cache: teacher_token_id -> student_token_id - self._vocab_remap_cache: Dict[int, int] = {} if config.teacher_enabled: - if not config.teacher_base_url or not config.teacher_model_name: + if config.teacher_server is None: raise ValueError( - "teacher_enabled=True requires teacher_base_url and teacher_model_name." + "teacher_enabled=True requires a teacher_server configuration." ) - teacher_tok_name = ( - config.teacher_model_name - if config.teacher_tokenizer_name in ("none", "") - else config.teacher_tokenizer_name - ) - teacher_cfg = APIServerConfig( - server_type=config.teacher_server_type, # type: ignore[arg-type] - base_url=config.teacher_base_url, - api_key=config.teacher_api_key, - model_name=config.teacher_model_name, - tokenizer_name=teacher_tok_name, - timeout=1200, + teacher_cfg = config.teacher_server.model_copy( + update={ + "tokenizer_name": ( + config.teacher_server.model_name + if config.teacher_server.tokenizer_name in ("", "none") + else config.teacher_server.tokenizer_name + ), + "timeout": 1200, + } ) self.teacher_server = ServerManager( [teacher_cfg], slurm=False, testing=False, ) - - # Detect vocabulary mismatch. - # Compare by name first; if names differ, load the teacher tokenizer - # and do a vocab-size sanity check. Same-family models (e.g. Qwen3-4B - # and Qwen3-30B) share the same vocabulary, so even though the - # name_or_path strings differ they should use the fast path. - student_tok_name = getattr(self.tokenizer, "name_or_path", None) or "" - if ( - student_tok_name - and teacher_tok_name - and student_tok_name != teacher_tok_name - ): - try: - from transformers import AutoTokenizer - - loaded = AutoTokenizer.from_pretrained( - teacher_tok_name, use_fast=True - ) - student_vocab_size = getattr(self.tokenizer, "vocab_size", None) - teacher_vocab_size = getattr(loaded, "vocab_size", None) - if ( - student_vocab_size is not None - and teacher_vocab_size is not None - and student_vocab_size == teacher_vocab_size - ): - # Same vocab size — treat as same tokenizer (fast path). - # This covers same-family models (e.g. all Qwen3 variants). - self._same_tokenizer = True - logger.warning( - "TeacherDistillationEnv: names differ but vocab sizes match " - "(%d tokens). Using fast (same-tokenizer) path. " - "student=%s teacher=%s", - student_vocab_size, - student_tok_name, - teacher_tok_name, - ) - else: - self._teacher_tokenizer = loaded - self._same_tokenizer = False - logger.warning( - "TeacherDistillationEnv: cross-tokenizer mode active. " - "student=%s (%s tokens) teacher=%s (%s tokens). " - "Token IDs will be decoded → re-tokenized → vocab-remapped.", - student_tok_name, - student_vocab_size, - teacher_tok_name, - teacher_vocab_size, - ) - except Exception as exc: - logger.warning( - "TeacherDistillationEnv: could not load teacher tokenizer '%s' " - "(%s). Falling back to same-tokenizer (fast) path — only safe if " - "student and teacher share the same vocabulary.", - teacher_tok_name, - exc, - ) - self._same_tokenizer = True - else: - self._same_tokenizer = True - - logger.warning( - "TeacherDistillationEnv: teacher server configured at %s " - "(model=%s, top_k=%s, same_tokenizer=%s)", - config.teacher_base_url, - config.teacher_model_name, - config.teacher_top_k, - self._same_tokenizer, - ) - - # ------------------------------------------------------------------ - # Cross-tokenizer helpers - # ------------------------------------------------------------------ - - def _build_student_teacher_alignment( - self, - text: str, - student_ids: List[int], - teacher_ids: List[int], - ) -> List[List[int]]: - """ - For each student token position return the list of teacher token positions - whose character spans overlap with the student token's character span. - - Requires fast (Rust-backed) HuggingFace tokenizers that support - return_offsets_mapping. Falls back to a proportional approximation - if offset mapping is unavailable. - """ - student_len = len(student_ids) - teacher_len = len(teacher_ids) - - try: - s_enc = self.tokenizer( - text, return_offsets_mapping=True, add_special_tokens=False - ) - t_enc = self._teacher_tokenizer( - text, return_offsets_mapping=True, add_special_tokens=False - ) - s_offsets: List[Tuple[int, int]] = s_enc["offset_mapping"][:student_len] - t_offsets: List[Tuple[int, int]] = t_enc["offset_mapping"][:teacher_len] - - alignment: List[List[int]] = [] - for s_start, s_end in s_offsets: - overlapping = [ - t_idx - for t_idx, (t_start, t_end) in enumerate(t_offsets) - if t_start < s_end and t_end > s_start and s_end > s_start - ] - alignment.append(overlapping) - return alignment - - except Exception as exc: - logger.warning( - "TeacherDistillationEnv: offset-mapping alignment failed (%s). " - "Using proportional fallback.", - exc, - ) - ratio = teacher_len / max(student_len, 1) - return [[int(i * ratio)] for i in range(student_len)] - - def _remap_teacher_token_to_student(self, teacher_token_id: int) -> int: - """ - Convert a teacher vocabulary token ID to the best-matching student - vocabulary token ID by decoding the teacher token to text then - re-encoding with the student tokenizer. - - Results are cached to avoid repeated tokenizer calls. - """ - if teacher_token_id in self._vocab_remap_cache: - return self._vocab_remap_cache[teacher_token_id] - - try: - text = self._teacher_tokenizer.decode( - [teacher_token_id], clean_up_tokenization_spaces=False - ) - student_ids = self.tokenizer.encode(text, add_special_tokens=False) - # Use the first student token as the representative. - sid = int(student_ids[0]) if student_ids else teacher_token_id - except Exception: - sid = teacher_token_id - - self._vocab_remap_cache[teacher_token_id] = sid - return sid - - def _align_and_remap( - self, - student_ids: List[int], - teacher_topk_ids: List[List[int]], - teacher_topk_lps: List[List[float]], - alignment: List[List[int]], - ) -> Tuple[List[List[int]], List[List[float]]]: - """ - Project teacher logprobs (teacher positions, teacher vocab) onto - student positions in student vocab. - - For each student token position: - 1. Collect all teacher top-k entries from overlapping teacher positions. - 2. Remap each teacher token ID to the student vocab. - 3. Merge duplicates by keeping the maximum logprob. - 4. Return the top-k entries sorted by descending logprob. - """ - k = max(1, len(teacher_topk_ids[0]) if teacher_topk_ids else 1) - result_ids: List[List[int]] = [] - result_lps: List[List[float]] = [] - - for s_idx in range(len(student_ids)): - t_positions = alignment[s_idx] if s_idx < len(alignment) else [] - if not t_positions: - result_ids.append([]) - result_lps.append([]) - continue - - # Merge all overlapping teacher positions, remap vocab. - merged: Dict[int, float] = {} - for t_idx in t_positions: - if t_idx >= len(teacher_topk_ids): - continue - for tid, tlp in zip(teacher_topk_ids[t_idx], teacher_topk_lps[t_idx]): - sid = self._remap_teacher_token_to_student(tid) - merged[sid] = max(merged.get(sid, -1e9), tlp) - - sorted_items = sorted(merged.items(), key=lambda x: -x[1]) - top = sorted_items[:k] - result_ids.append([int(sid) for sid, _ in top]) - result_lps.append([float(lp) for _, lp in top]) - - return result_ids, result_lps + self._validate_teacher_tokenizer_compatibility(teacher_cfg.tokenizer_name) # ------------------------------------------------------------------ # Core fetch # ------------------------------------------------------------------ + def _validate_teacher_tokenizer_compatibility(self, teacher_tokenizer_name: str) -> None: + student_tok_name = getattr(self.tokenizer, "name_or_path", None) or "" + if student_tok_name == teacher_tokenizer_name: + return + + try: + from transformers import AutoTokenizer + + teacher_tokenizer = AutoTokenizer.from_pretrained( + teacher_tokenizer_name, use_fast=True + ) + except Exception as exc: + raise ValueError( + "Cross-tokenizer distillation is not supported in this PR, and the " + f"teacher tokenizer for '{teacher_tokenizer_name}' could not be loaded to " + f"verify compatibility: {exc}" + ) from exc + + student_vocab = self.tokenizer.get_vocab() + teacher_vocab = teacher_tokenizer.get_vocab() + if student_vocab != teacher_vocab: + raise ValueError( + "Cross-tokenizer distillation is not supported in this PR. " + f"Student tokenizer '{student_tok_name or type(self.tokenizer).__name__}' " + f"and teacher tokenizer '{teacher_tokenizer_name}' do not match." + ) + async def _fetch_teacher_for_sequence( self, token_ids: List[int], top_k: int ) -> Tuple[List[List[int]], List[List[float]]]: assert self.teacher_server is not None - - if self._same_tokenizer or self._teacher_tokenizer is None: - # Fast path: same vocabulary — send student IDs directly. - payload = await self.teacher_server.get_logprobs( - input_ids=token_ids, - top_k=top_k, - max_tokens=1, - split="train", - ) - return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"] - - # Cross-tokenizer path: - # 1. Decode student tokens → plain text. - # 2. Re-tokenize with teacher tokenizer → teacher IDs. - # 3. Send teacher IDs to teacher vLLM. - # 4. Align teacher positions → student positions. - # 5. Remap teacher vocab IDs → student vocab IDs. - text = self.tokenizer.decode(token_ids, clean_up_tokenization_spaces=False) - teacher_ids: List[int] = self._teacher_tokenizer.encode( - text, add_special_tokens=False - ) - payload = await self.teacher_server.get_logprobs( - input_ids=teacher_ids, + input_ids=token_ids, top_k=top_k, max_tokens=1, split="train", ) - teacher_topk_ids = payload["prompt_topk_token_ids"] - teacher_topk_lps = payload["prompt_topk_logprobs"] - - alignment = self._build_student_teacher_alignment(text, token_ids, teacher_ids) - return self._align_and_remap( - token_ids, teacher_topk_ids, teacher_topk_lps, alignment - ) + return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"] # ------------------------------------------------------------------ # Group enrichment diff --git a/atroposlib/tests/test_teacher_distillation_env.py b/atroposlib/tests/test_teacher_distillation_env.py index 199f1453..7f5262e7 100644 --- a/atroposlib/tests/test_teacher_distillation_env.py +++ b/atroposlib/tests/test_teacher_distillation_env.py @@ -67,3 +67,29 @@ async def test_attach_teacher_distillation_failure_drops_payload(): out = await TeacherDistillationEnv._attach_teacher_distillation(env, group) assert out["distill_token_ids"] is None assert out["distill_logprobs"] is None + + +def test_teacher_tokenizer_mismatch_raises(monkeypatch): + env = object.__new__(_ConcreteTeacherEnv) + + class _StudentTokenizer: + name_or_path = "student-model" + + def get_vocab(self): + return {"a": 1} + + class _TeacherTokenizer: + def get_vocab(self): + return {"b": 1} + + env.tokenizer = _StudentTokenizer() + monkeypatch.setattr( + "transformers.AutoTokenizer.from_pretrained", + lambda *args, **kwargs: _TeacherTokenizer(), + ) + + with pytest.raises(ValueError, match="Cross-tokenizer distillation is not supported"): + TeacherDistillationEnv._validate_teacher_tokenizer_compatibility( + env, + teacher_tokenizer_name="teacher-model", + ) diff --git a/environments/gsm8k_server_teacher_distill.py b/environments/gsm8k_server_teacher_distill.py index 8276436b..49caabec 100644 --- a/environments/gsm8k_server_teacher_distill.py +++ b/environments/gsm8k_server_teacher_distill.py @@ -32,11 +32,12 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv): max_token_length=2048, wandb_name="gsm8k_teacher_distill", teacher_enabled=True, - teacher_base_url="http://localhost:8003/v1", - teacher_model_name="mock-teacher", - teacher_api_key="", - teacher_server_type="vllm", - teacher_tokenizer_name="none", + teacher_server=APIServerConfig( + base_url="http://localhost:8003/v1", + model_name="mock-teacher", + api_key="", + server_type="vllm", + ), teacher_top_k=4, ) server_config = APIServerConfig( diff --git a/example_trainer/README.md b/example_trainer/README.md index ddb96b8a..b889f440 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -304,6 +304,29 @@ environment uses the `/generate` path and includes token-level 4. Trainer extracts and aligns logprobs with training labels 5. GRPO loss uses these rollout logprobs in importance-ratio terms +### 1b. Teacher distillation requires the same tokenizer + +When distillation data is attached to Atropos batches, the trainer treats +`distill_token_ids` as indices into the student's logit tensor. That only works +if the teacher and student share the same tokenizer vocabulary. + +What to configure on the environment side: + +```bash +--env.teacher_enabled true \ +--env.teacher_server.base_url "http://localhost:9003/v1" \ +--env.teacher_server.model_name "$TEACHER_MODEL" \ +--env.teacher_server.server_type vllm \ +--env.teacher_top_k 8 +``` + +Why cross-tokenizer conversion is not acceptable here: + +- Teacher token ID `1234` and student token ID `1234` can correspond to different text. +- Re-tokenizing teacher text changes token boundaries, so teacher position `i` may no longer correspond to student position `i`. +- Remapping teacher top-k tokens back into student vocab can collapse multiple teacher candidates into one student token or expand one teacher token into multiple student tokens. +- The current distillation loss expects exact per-position supervision in student token space, so an approximate remapping would silently produce misleading targets. + ### 2. Clipping ```bash diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 418a87ea..91cecf8a 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -234,8 +234,9 @@ start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \ --env.use_wandb true \ --env.wandb_name "gsm8k-teacher-distill" \ --env.teacher_enabled true \ - --env.teacher_base_url "http://localhost:${TEACHER_PORT}/v1" \ - --env.teacher_model_name "$TEACHER_MODEL" \ + --env.teacher_server.base_url "http://localhost:${TEACHER_PORT}/v1" \ + --env.teacher_server.model_name "$TEACHER_MODEL" \ + --env.teacher_server.server_type vllm \ --env.teacher_top_k "$TEACHER_TOP_K" \ --env.ensure_scores_are_not_same false \ --openai.api_key "dummy" \ From 994e9c287dcf3a249900ef3e9574ecc21a7ef154 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Mar 2026 17:20:56 +0000 Subject: [PATCH 42/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .secrets.baseline | 4 ++-- atroposlib/envs/teacher_distillation_env.py | 4 +++- atroposlib/tests/test_teacher_distillation_env.py | 4 +++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.secrets.baseline b/.secrets.baseline index 7059d28b..e4785ede 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -133,7 +133,7 @@ "filename": "README.md", "hashed_secret": "a8253456364f1bfc7da7ae4a1db5b45d106317a5", "is_verified": false, - "line_number": 454 + "line_number": 495 } ], "SLURM.md": [ @@ -561,5 +561,5 @@ } ] }, - "generated_at": "2026-03-02T22:46:56Z" + "generated_at": "2026-03-13T17:20:46Z" } diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py index d8284335..85e040c7 100644 --- a/atroposlib/envs/teacher_distillation_env.py +++ b/atroposlib/envs/teacher_distillation_env.py @@ -89,7 +89,9 @@ class TeacherDistillationEnv(BaseEnv, ABC): # Core fetch # ------------------------------------------------------------------ - def _validate_teacher_tokenizer_compatibility(self, teacher_tokenizer_name: str) -> None: + def _validate_teacher_tokenizer_compatibility( + self, teacher_tokenizer_name: str + ) -> None: student_tok_name = getattr(self.tokenizer, "name_or_path", None) or "" if student_tok_name == teacher_tokenizer_name: return diff --git a/atroposlib/tests/test_teacher_distillation_env.py b/atroposlib/tests/test_teacher_distillation_env.py index 7f5262e7..65262984 100644 --- a/atroposlib/tests/test_teacher_distillation_env.py +++ b/atroposlib/tests/test_teacher_distillation_env.py @@ -88,7 +88,9 @@ def test_teacher_tokenizer_mismatch_raises(monkeypatch): lambda *args, **kwargs: _TeacherTokenizer(), ) - with pytest.raises(ValueError, match="Cross-tokenizer distillation is not supported"): + with pytest.raises( + ValueError, match="Cross-tokenizer distillation is not supported" + ): TeacherDistillationEnv._validate_teacher_tokenizer_compatibility( env, teacher_tokenizer_name="teacher-model", From 322e7e66237a54185dc09011460f6065b2647ddc Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 13 Mar 2026 13:29:47 -0400 Subject: [PATCH 43/64] remove comments --- .../envs/server_handling/openai_server.py | 22 +------------------ environments/gsm8k_server.py | 15 ------------- example_trainer/README.md | 2 +- example_trainer/api.py | 5 ----- example_trainer/vllm_api_server.py | 11 ---------- 5 files changed, 2 insertions(+), 53 deletions(-) diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index 40a993fe..f99c14e2 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -230,28 +230,8 @@ def resolve_openai_configs( f"Merged Dict: {openai_config_dict}" ) from e - if isinstance(default_server_configs, APIServerConfig): - server_configs = [final_openai_config] - elif isinstance(default_server_configs, list): + if isinstance(default_server_configs, list): server_configs = [final_openai_config] else: - logger.warning( - f"Unexpected type for default_server_configs: {type(default_server_configs)}. " - f"Proceeding with single OpenAI server configuration based on merged settings." - ) server_configs = [final_openai_config] - - if isinstance(server_configs, list): - logger.warning( - "resolve_openai_configs: returning list of %s config(s), URLs: %s", - len(server_configs), - [c.base_url for c in server_configs], - ) - else: - logger.warning( - "resolve_openai_configs: returning single %s (base_url=%s) — " - "ServerManager will use template mode!", - type(server_configs).__name__, - getattr(server_configs, "base_url", "N/A"), - ) return server_configs diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 2697ef30..f8437f7b 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -235,30 +235,15 @@ class GSM8kEnv(BaseEnv): ) async with self.server.managed_server(tokenizer=self.tokenizer) as managed: - logger.warning( - "gsm8k collect_trajectories start group_size=%s max_tokens=%s question_chars=%s", - self.config.group_size, - self.config.max_token_length, - len(item["question"]), - ) - chat_completions = await managed.chat_completion( messages=[{"role": "system", "content": system_prompt}, user_message], n=self.config.group_size, max_tokens=self.config.max_token_length, temperature=1.0, ) - logger.warning( - "gsm8k collect_trajectories completion_received choices=%s", - len(chat_completions.choices), - ) state = managed.get_state() nodes = state["nodes"] - logger.warning( - "gsm8k collect_trajectories managed_state_nodes=%s", - len(nodes), - ) to_score = list() to_backlog = list() diff --git a/example_trainer/README.md b/example_trainer/README.md index b889f440..8023ab73 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -573,7 +573,7 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands ### WandB Logging - ```bash +```bash --use-wandb \ --wandb-project "my-grpo-training" \ --wandb-group "hermes-8b-gsm8k" diff --git a/example_trainer/api.py b/example_trainer/api.py index dc51af4f..1bc8a1bd 100644 --- a/example_trainer/api.py +++ b/example_trainer/api.py @@ -7,7 +7,6 @@ Handles communication with the Atropos API server for: - Batch retrieval """ -import os import time as _time import requests @@ -102,10 +101,6 @@ def get_batch(url: str = "http://localhost:8000"): """ response = requests.get( f"{url}/batch", - headers={ - "X-Atropos-Client": "trainer", - "X-Atropos-Pid": str(os.getpid()), - }, timeout=10, ) data = response.json() diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py index b20704e2..24d40326 100644 --- a/example_trainer/vllm_api_server.py +++ b/example_trainer/vllm_api_server.py @@ -296,17 +296,6 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: if engine is None: raise HTTPException(status_code=503, detail="Engine not initialized") - request_preview = { - "has_prompt": "prompt" in request_dict, - "n": request_dict.get("n"), - "max_tokens": request_dict.get("max_tokens"), - "temperature": request_dict.get("temperature"), - "top_p": request_dict.get("top_p"), - "logprobs": request_dict.get("logprobs"), - "prompt_logprobs": request_dict.get("prompt_logprobs"), - } - logger.info("POST /generate received %s", request_preview) - prompt = request_dict.pop("prompt") stream = request_dict.pop("stream", False) request_dict["output_kind"] = RequestOutputKind.FINAL_ONLY From a8cdb53a4d90c1b0b2a5ee37c4c2dc1522b14893 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 13 Mar 2026 16:12:05 -0400 Subject: [PATCH 44/64] address problems --- README.md | 4 + .../envs/server_handling/vllm_server.py | 4 +- atroposlib/envs/teacher_distillation_env.py | 74 +++++++++++++------ .../tests/test_teacher_distillation_env.py | 36 +++++++++ environments/gsm8k_server_teacher_distill.py | 1 + example_trainer/README.md | 4 + 6 files changed, 99 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 6def2fdf..076f886a 100644 --- a/README.md +++ b/README.md @@ -318,6 +318,10 @@ TeacherDistillationConfig( ) ``` +If `teacher_server.model_name` is a deployment alias rather than a tokenizer +identifier, set `teacher_server.tokenizer_name` explicitly so the env can +validate tokenizer compatibility. + CLI shape: ```bash diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index aaee28d7..cc5bf9a5 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -281,7 +281,7 @@ class VLLMServer(APIServer): ), "Prompt or input_ids is required for get_logprobs!" top_k = int(kwargs.pop("top_k", kwargs.pop("top_logprobs", 1))) - top_k = max(1, top_k) + top_k = max(0, top_k) # Use input_ids if provided (from ManagedServer), otherwise tokenize prompt from_prompt_text = False @@ -420,7 +420,7 @@ def resolve_openai_configs( ) from e if isinstance(default_server_configs, APIServerConfig): - server_configs = final_openai_config + server_configs = [final_openai_config] elif isinstance(default_server_configs, list): server_configs = [final_openai_config] else: diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py index 85e040c7..4c8bfa75 100644 --- a/atroposlib/envs/teacher_distillation_env.py +++ b/atroposlib/envs/teacher_distillation_env.py @@ -33,12 +33,16 @@ class TeacherDistillationConfig(BaseEnvConfig): ) teacher_server: Optional[APIServerConfig] = Field( default=None, - description="Teacher inference server configuration.", + description="Fallback teacher server configuration when not provided at init.", ) teacher_top_k: int = Field( - default=1, - ge=1, - description="Top-k prompt logprobs to fetch per token position.", + default=0, + ge=-1, + description=( + "Number of extra prompt logprobs to fetch beyond the selected token. " + "Use 0 for selected-token-only prompt logprobs and <= -1 to disable " + "teacher fetching." + ), ) @@ -57,6 +61,9 @@ class TeacherDistillationEnv(BaseEnv, ABC): self, config: TeacherDistillationConfig, server_configs: Union[ServerBaseline, List[APIServerConfig]], + teacher_server_configs: Optional[ + Union[ServerBaseline, List[APIServerConfig]] + ] = None, slurm: bool = False, testing: bool = False, ): @@ -64,26 +71,42 @@ class TeacherDistillationEnv(BaseEnv, ABC): self.teacher_server: Optional[ServerManager] = None if config.teacher_enabled: - if config.teacher_server is None: + teacher_config_source = teacher_server_configs + if teacher_config_source is None and config.teacher_server is not None: + teacher_config_source = [ + config.teacher_server.model_copy( + update={ + "tokenizer_name": ( + config.teacher_server.model_name + if config.teacher_server.tokenizer_name in ("", "none") + else config.teacher_server.tokenizer_name + ), + "timeout": 1200, + } + ) + ] + + if teacher_config_source is None: raise ValueError( - "teacher_enabled=True requires a teacher_server configuration." + "teacher_enabled=True requires teacher_server_configs at init " + "or a fallback teacher_server config." ) - teacher_cfg = config.teacher_server.model_copy( - update={ - "tokenizer_name": ( - config.teacher_server.model_name - if config.teacher_server.tokenizer_name in ("", "none") - else config.teacher_server.tokenizer_name - ), - "timeout": 1200, - } - ) self.teacher_server = ServerManager( - [teacher_cfg], + teacher_config_source, slurm=False, testing=False, ) - self._validate_teacher_tokenizer_compatibility(teacher_cfg.tokenizer_name) + if isinstance(teacher_config_source, list): + teacher_cfg = teacher_config_source[0] + else: + teacher_cfg = teacher_config_source + + teacher_tokenizer_name = ( + teacher_cfg.model_name + if getattr(teacher_cfg, "tokenizer_name", "none") in ("", "none") + else teacher_cfg.tokenizer_name + ) + self._validate_teacher_tokenizer_compatibility(teacher_tokenizer_name) # ------------------------------------------------------------------ # Core fetch @@ -146,12 +169,19 @@ class TeacherDistillationEnv(BaseEnv, ABC): group["distill_logprobs"] = None return group + group_overrides = group.get("group_overrides") or {} + if group_overrides.get("skip_teacher_top_k", False): + group["distill_token_ids"] = None + group["distill_logprobs"] = None + return group + top_k = int( - (group.get("group_overrides") or {}).get( - "teacher_top_k", self.config.teacher_top_k - ) + group_overrides.get("teacher_top_k", self.config.teacher_top_k) ) - top_k = max(1, top_k) + if top_k <= -1: + group["distill_token_ids"] = None + group["distill_logprobs"] = None + return group tasks = [self._fetch_teacher_for_sequence(seq, top_k) for seq in seqs] results = await asyncio.gather(*tasks, return_exceptions=True) diff --git a/atroposlib/tests/test_teacher_distillation_env.py b/atroposlib/tests/test_teacher_distillation_env.py index 65262984..2c0ddf17 100644 --- a/atroposlib/tests/test_teacher_distillation_env.py +++ b/atroposlib/tests/test_teacher_distillation_env.py @@ -69,6 +69,42 @@ async def test_attach_teacher_distillation_failure_drops_payload(): assert out["distill_logprobs"] is None +@pytest.mark.asyncio +async def test_attach_teacher_distillation_negative_topk_skips_fetch(): + env = object.__new__(_ConcreteTeacherEnv) + env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=-1) + env.teacher_server = _FakeTeacherServer() + + group = { + "tokens": [[1, 2, 3]], + "group_overrides": None, + "masks": [[-100, 2, 3]], + "scores": [1.0], + } + out = await TeacherDistillationEnv._attach_teacher_distillation(env, group) + assert env.teacher_server.calls == 0 + assert out["distill_token_ids"] is None + assert out["distill_logprobs"] is None + + +@pytest.mark.asyncio +async def test_attach_teacher_distillation_group_override_can_skip_fetch(): + env = object.__new__(_ConcreteTeacherEnv) + env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2) + env.teacher_server = _FakeTeacherServer() + + group = { + "tokens": [[1, 2, 3]], + "group_overrides": {"skip_teacher_top_k": True}, + "masks": [[-100, 2, 3]], + "scores": [1.0], + } + out = await TeacherDistillationEnv._attach_teacher_distillation(env, group) + assert env.teacher_server.calls == 0 + assert out["distill_token_ids"] is None + assert out["distill_logprobs"] is None + + def test_teacher_tokenizer_mismatch_raises(monkeypatch): env = object.__new__(_ConcreteTeacherEnv) diff --git a/environments/gsm8k_server_teacher_distill.py b/environments/gsm8k_server_teacher_distill.py index 49caabec..5aa33a01 100644 --- a/environments/gsm8k_server_teacher_distill.py +++ b/environments/gsm8k_server_teacher_distill.py @@ -37,6 +37,7 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv): model_name="mock-teacher", api_key="", server_type="vllm", + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", ), teacher_top_k=4, ) diff --git a/example_trainer/README.md b/example_trainer/README.md index 8023ab73..8596a849 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -320,6 +320,10 @@ What to configure on the environment side: --env.teacher_top_k 8 ``` +If `$TEACHER_MODEL` is a deployment alias instead of a tokenizer identifier, +also set `--env.teacher_server.tokenizer_name ...` so the env can validate +tokenizer compatibility. + Why cross-tokenizer conversion is not acceptable here: - Teacher token ID `1234` and student token ID `1234` can correspond to different text. From 82964b6e48c33eac30f0e6ea0d9971c9276e9351 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Mar 2026 20:13:30 +0000 Subject: [PATCH 45/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .secrets.baseline | 4 ++-- atroposlib/envs/teacher_distillation_env.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/.secrets.baseline b/.secrets.baseline index e4785ede..2651783a 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -133,7 +133,7 @@ "filename": "README.md", "hashed_secret": "a8253456364f1bfc7da7ae4a1db5b45d106317a5", "is_verified": false, - "line_number": 495 + "line_number": 499 } ], "SLURM.md": [ @@ -561,5 +561,5 @@ } ] }, - "generated_at": "2026-03-13T17:20:46Z" + "generated_at": "2026-03-13T20:13:21Z" } diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py index 4c8bfa75..fdac54c5 100644 --- a/atroposlib/envs/teacher_distillation_env.py +++ b/atroposlib/envs/teacher_distillation_env.py @@ -175,9 +175,7 @@ class TeacherDistillationEnv(BaseEnv, ABC): group["distill_logprobs"] = None return group - top_k = int( - group_overrides.get("teacher_top_k", self.config.teacher_top_k) - ) + top_k = int(group_overrides.get("teacher_top_k", self.config.teacher_top_k)) if top_k <= -1: group["distill_token_ids"] = None group["distill_logprobs"] = None From 697c594c721158a5dfae5aac95bae37de50d8d18 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 13 Mar 2026 16:57:46 -0400 Subject: [PATCH 46/64] changes --- atroposlib/envs/server_handling/openai_server.py | 12 ++++++++++++ atroposlib/envs/server_handling/vllm_server.py | 4 ++++ environments/gsm8k_server.py | 9 ++------- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index f99c14e2..54f03fb4 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -234,4 +234,16 @@ def resolve_openai_configs( server_configs = [final_openai_config] else: server_configs = [final_openai_config] + + if isinstance(server_configs, list): + logger.info( + "resolve_openai_configs returning %s config(s) with URLs: %s", + len(server_configs), + [getattr(c, "base_url", None) for c in server_configs], + ) + else: + logger.info( + "resolve_openai_configs returning %s", + type(server_configs).__name__, + ) return server_configs diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index cc5bf9a5..1c9cb24d 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -424,6 +424,10 @@ def resolve_openai_configs( elif isinstance(default_server_configs, list): server_configs = [final_openai_config] else: + logger.warning( + f"Unexpected type for default_server_configs: {type(default_server_configs)}. " + "Proceeding with single OpenAI server configuration based on merged settings." + ) server_configs = [final_openai_config] return server_configs diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index f8437f7b..2cb74795 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -1,4 +1,3 @@ -import logging import random import time from typing import Dict, List, Optional, Tuple, TypedDict, Union @@ -32,9 +31,6 @@ It is important that you provide your answer in the correct format. If you do not, you will not receive credit for your answer. So please end your answer with \\boxed{your answer here}""" -logger = logging.getLogger(__name__) - - class GSM8kRow(TypedDict): question: str answer: str @@ -353,9 +349,8 @@ class GSM8kEnv(BaseEnv): percentage_of_range = min(percentage_of_range, 1.0) # Apply linear penalty scaling from 1.0 down to 0.0 scores["scores"].append(1.0 - percentage_of_range) - # NOTE: identical-score filter disabled for testing. - # if all([scores["scores"][0] == score for score in scores["scores"]]): - # return None + if all([scores["scores"][0] == score for score in scores["scores"]]): + return None return scores else: # If the gold solution is not parseable, we return None From 6c564799bc4779829ba9d72518747c40524be589 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Mar 2026 21:02:04 +0000 Subject: [PATCH 47/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- environments/gsm8k_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 2cb74795..de13f8c9 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -31,6 +31,7 @@ It is important that you provide your answer in the correct format. If you do not, you will not receive credit for your answer. So please end your answer with \\boxed{your answer here}""" + class GSM8kRow(TypedDict): question: str answer: str From 1b8ff075c42600efa6e43eb844b42fafeabb6636 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 13 Mar 2026 17:23:40 -0400 Subject: [PATCH 48/64] adding tests --- .../envs/server_handling/vllm_server.py | 17 ++++-- atroposlib/tests/test_server_logprobs.py | 52 +++++++++++++++++++ .../tests/test_teacher_distillation_env.py | 18 +++++++ 3 files changed, 84 insertions(+), 3 deletions(-) diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 1c9cb24d..72e7140e 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -401,6 +401,19 @@ def resolve_openai_configs( raise FailedExecutionException( f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}" ) from e + elif isinstance(default_server_configs, APIServerConfig): + # Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline + logger.info( + "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." + ) + try: + final_openai_config = APIServerConfig(**openai_config_dict) + except Exception as e: + raise FailedExecutionException( + f"Error creating final OpenAI configuration from merged settings: {e}\n" + f"Merged Dict: {openai_config_dict}" + ) from e + server_configs = [final_openai_config] elif isinstance(default_server_configs, ServerBaseline): logger.info("Using ServerBaseline configuration.") server_configs = default_server_configs @@ -419,9 +432,7 @@ def resolve_openai_configs( f"Merged Dict: {openai_config_dict}" ) from e - if isinstance(default_server_configs, APIServerConfig): - server_configs = [final_openai_config] - elif isinstance(default_server_configs, list): + if isinstance(default_server_configs, list): server_configs = [final_openai_config] else: logger.warning( diff --git a/atroposlib/tests/test_server_logprobs.py b/atroposlib/tests/test_server_logprobs.py index 8cbd84ad..531b1578 100644 --- a/atroposlib/tests/test_server_logprobs.py +++ b/atroposlib/tests/test_server_logprobs.py @@ -1,7 +1,13 @@ """Tests for get_logprobs wrappers and server-manager routing.""" +import logging + import pytest +from atroposlib.envs.server_handling.openai_server import resolve_openai_configs +from atroposlib.envs.server_handling.vllm_server import ( + resolve_openai_configs as resolve_vllm_configs, +) from atroposlib.envs.server_handling.server_baseline import ( APIServer, APIServerConfig, @@ -103,3 +109,49 @@ async def test_server_manager_get_logprobs_routes_to_most_available_server(): out_eval = await ServerManager.get_logprobs(manager, prompt="x", split="eval") assert out_eval["server"] == "s1" assert s1.calls == 1 + + +def test_resolve_openai_configs_wraps_single_api_server_config_in_list(): + default_server_config = APIServerConfig( + model_name="test-model", + base_url="http://localhost:9001/v1", + api_key="x", + server_type="openai", + ) + merged_config = default_server_config.model_dump() + + server_configs = resolve_openai_configs( + default_server_configs=default_server_config, + openai_config_dict=merged_config, + yaml_config={}, + cli_passed_flags={}, + logger=logging.getLogger("test"), + ) + + assert isinstance(server_configs, list) + assert len(server_configs) == 1 + assert isinstance(server_configs[0], APIServerConfig) + assert server_configs[0].base_url == "http://localhost:9001/v1" + + +def test_resolve_vllm_configs_wraps_single_api_server_config_in_list(): + default_server_config = APIServerConfig( + model_name="test-model", + base_url="http://localhost:9001/v1", + api_key="x", + server_type="vllm", + ) + merged_config = default_server_config.model_dump() + + server_configs = resolve_vllm_configs( + default_server_configs=default_server_config, + openai_config_dict=merged_config, + yaml_config={}, + cli_passed_flags={}, + logger=logging.getLogger("test"), + ) + + assert isinstance(server_configs, list) + assert len(server_configs) == 1 + assert isinstance(server_configs[0], APIServerConfig) + assert server_configs[0].base_url == "http://localhost:9001/v1" diff --git a/atroposlib/tests/test_teacher_distillation_env.py b/atroposlib/tests/test_teacher_distillation_env.py index 2c0ddf17..7c8cb439 100644 --- a/atroposlib/tests/test_teacher_distillation_env.py +++ b/atroposlib/tests/test_teacher_distillation_env.py @@ -87,6 +87,24 @@ async def test_attach_teacher_distillation_negative_topk_skips_fetch(): assert out["distill_logprobs"] is None +@pytest.mark.asyncio +async def test_attach_teacher_distillation_zero_topk_passthrough(): + env = object.__new__(_ConcreteTeacherEnv) + env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0) + env.teacher_server = _FakeTeacherServer() + + group = { + "tokens": [[1, 2, 3]], + "group_overrides": None, + "masks": [[-100, 2, 3]], + "scores": [1.0], + } + out = await TeacherDistillationEnv._attach_teacher_distillation(env, group) + assert env.teacher_server.calls == 1 + assert out["distill_token_ids"] is not None + assert out["distill_logprobs"] is not None + + @pytest.mark.asyncio async def test_attach_teacher_distillation_group_override_can_skip_fetch(): env = object.__new__(_ConcreteTeacherEnv) From 12ba3cc3bdc2e7ac2ad8efccabbc003a99e7524a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Mar 2026 21:25:23 +0000 Subject: [PATCH 49/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- atroposlib/tests/test_server_logprobs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/atroposlib/tests/test_server_logprobs.py b/atroposlib/tests/test_server_logprobs.py index 531b1578..7e545355 100644 --- a/atroposlib/tests/test_server_logprobs.py +++ b/atroposlib/tests/test_server_logprobs.py @@ -5,15 +5,15 @@ import logging import pytest from atroposlib.envs.server_handling.openai_server import resolve_openai_configs -from atroposlib.envs.server_handling.vllm_server import ( - resolve_openai_configs as resolve_vllm_configs, -) from atroposlib.envs.server_handling.server_baseline import ( APIServer, APIServerConfig, AsyncSemWithAdaptiveWeight, ) from atroposlib.envs.server_handling.server_manager import ServerManager +from atroposlib.envs.server_handling.vllm_server import ( + resolve_openai_configs as resolve_vllm_configs, +) class _FakeAPIServer(APIServer): From a171358f2e73b084baa713f9ff77ce0d7809d6b2 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 13 Mar 2026 18:49:01 -0400 Subject: [PATCH 50/64] structural changes --- README.md | 37 ++- atroposlib/envs/teacher_distillation_env.py | 254 ++++++++++++++++-- .../tests/test_teacher_distillation_env.py | 149 ++++++++++ environments/gsm8k_server_teacher_distill.py | 18 +- example_trainer/README.md | 8 +- ...n_gsm8k_teacher_distill_single_terminal.sh | 6 +- 6 files changed, 422 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index 076f886a..a0885cdd 100644 --- a/README.md +++ b/README.md @@ -308,30 +308,43 @@ Teacher config shape: ```python TeacherDistillationConfig( teacher_enabled=True, - teacher_server=APIServerConfig( - base_url="http://localhost:9003/v1", - model_name="Qwen/Qwen3-30B-A3B-Instruct-2507", - api_key="", - server_type="vllm", - ), teacher_top_k=8, ) ``` -If `teacher_server.model_name` is a deployment alias rather than a tokenizer -identifier, set `teacher_server.tokenizer_name` explicitly so the env can -validate tokenizer compatibility. +Teacher server configs are passed separately at init, just like the primary +`server_configs`: + +```python +env = MyTeacherEnv( + config=env_config, + server_configs=student_server_configs, + teacher_server_configs=[ + APIServerConfig( + base_url="http://localhost:9003/v1", + model_name="Qwen/Qwen3-30B-A3B-Instruct-2507", + api_key="", + server_type="vllm", + tokenizer_name="Qwen/Qwen3-30B-A3B-Instruct-2507", + ) + ], +) +``` CLI shape: ```bash --env.teacher_enabled true \ ---env.teacher_server.base_url "http://localhost:9003/v1" \ ---env.teacher_server.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \ ---env.teacher_server.server_type vllm \ +--teacher.base_url "http://localhost:9003/v1" \ +--teacher.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \ +--teacher.server_type vllm \ --env.teacher_top_k 8 ``` +If `--teacher.model_name` is a deployment alias rather than a tokenizer +identifier, also set `--teacher.tokenizer_name ...` so the env can validate +tokenizer compatibility. + Tokenizer requirement: - Teacher distillation currently requires the teacher and student to use the same tokenizer vocabulary. diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py index fdac54c5..1b3cda8f 100644 --- a/atroposlib/envs/teacher_distillation_env.py +++ b/atroposlib/envs/teacher_distillation_env.py @@ -15,13 +15,24 @@ from __future__ import annotations import asyncio import logging from abc import ABC -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union +import yaml from pydantic import Field +from pydantic_cli import Cmd +from rich import print as rprint from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup +from .constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE +from .server_handling.openai_server import resolve_openai_configs from .server_handling.server_baseline import APIServerConfig, ServerBaseline -from .server_handling.server_manager import ServerManager +from .server_handling.server_manager import ServerManager, ServerManagerConfig +from ..utils.cli import ( + extract_namespace, + get_double_dash_flags, + get_prefixed_pydantic_model, + merge_dicts, +) logger = logging.getLogger(__name__) @@ -31,10 +42,6 @@ class TeacherDistillationConfig(BaseEnvConfig): default=False, description="Whether to fetch teacher prompt logprobs for distillation.", ) - teacher_server: Optional[APIServerConfig] = Field( - default=None, - description="Fallback teacher server configuration when not provided at init.", - ) teacher_top_k: int = Field( default=0, ge=-1, @@ -56,6 +63,220 @@ class TeacherDistillationEnv(BaseEnv, ABC): """ env_config_cls = TeacherDistillationConfig + teacher_namespace = "teacher" + + @classmethod + def teacher_config_init( + cls, + ) -> Optional[Union[ServerBaseline, List[APIServerConfig], APIServerConfig]]: + return None + + @classmethod + def _resolve_teacher_server_configs( + cls, + default_teacher_server_configs: Optional[ + Union[ServerBaseline, List[APIServerConfig], APIServerConfig] + ], + yaml_config: Dict[str, Any], + cli_passed_flags: Dict[str, Any], + ) -> Optional[Union[ServerBaseline, List[APIServerConfig]]]: + teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}" + teacher_cli_passed_args = extract_namespace(cli_passed_flags, teacher_full_prefix) + yaml_teacher_config = yaml_config.get(cls.teacher_namespace, {}) + + if ( + default_teacher_server_configs is None + and not teacher_cli_passed_args + and not yaml_teacher_config + ): + return None + + effective_teacher_server_configs = default_teacher_server_configs + if effective_teacher_server_configs is None: + effective_teacher_server_configs = APIServerConfig() + elif isinstance(effective_teacher_server_configs, ServerBaseline) and ( + teacher_cli_passed_args or yaml_teacher_config + ): + effective_teacher_server_configs = APIServerConfig( + **effective_teacher_server_configs.model_dump() + ) + + if ( + isinstance(effective_teacher_server_configs, list) + and len(effective_teacher_server_configs) == 1 + ): + default_teacher_config = effective_teacher_server_configs[0] + else: + default_teacher_config = effective_teacher_server_configs + + if isinstance(yaml_teacher_config, list) and len(yaml_teacher_config) == 1: + yaml_teacher_config = yaml_teacher_config[0] + + if isinstance(default_teacher_config, APIServerConfig) and isinstance( + yaml_teacher_config, dict + ): + teacher_config_dict = merge_dicts( + default_teacher_config.model_dump(), + yaml_teacher_config, + teacher_cli_passed_args, + ) + else: + teacher_config_dict = {} + + teacher_yaml_wrapped = {OPENAI_NAMESPACE: yaml_teacher_config} + teacher_cli_wrapped = { + f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}{key}": value + for key, value in teacher_cli_passed_args.items() + } + return resolve_openai_configs( + default_server_configs=effective_teacher_server_configs, + openai_config_dict=teacher_config_dict, + yaml_config=teacher_yaml_wrapped, + cli_passed_flags=teacher_cli_wrapped, + logger=logger, + ) + + @classmethod + def get_cli_serve_config_cls(cls) -> type: + default_env_config, default_server_configs = cls.config_init() + default_teacher_server_configs = cls.teacher_config_init() + + env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}" + openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}" + teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}" + teacher_cli_base = get_prefixed_pydantic_model( + APIServerConfig, teacher_full_prefix + ) + + class CliServeConfig( + get_prefixed_pydantic_model(type(default_env_config), env_full_prefix), + get_prefixed_pydantic_model(APIServerConfig, openai_full_prefix), + teacher_cli_base, + ServerManagerConfig, + Cmd, + ): + config: str | None = Field( + default=None, + description="Path to .yaml config file. CLI args override this.", + ) + + def run(self) -> None: + wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name" + if ( + getattr(self, wandb_name_attr, None) is None + and cls.name is not None + ): + setattr(self, wandb_name_attr, cls.name) + + if self.config is not None: + with open(self.config, "r") as f: + yaml_config = yaml.safe_load(f) + logger.info("Loaded config from %s", self.config) + else: + yaml_config = {} + + cli_passed_flags = get_double_dash_flags() + + env_config_dict = merge_dicts( + default_env_config.model_dump(), + yaml_config.get(ENV_NAMESPACE, {}), + extract_namespace(cli_passed_flags, env_full_prefix), + ) + + oai_cli_passed_args = extract_namespace( + cli_passed_flags, openai_full_prefix + ) + yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {}) + + effective_server_configs = default_server_configs + if isinstance(effective_server_configs, ServerBaseline) and ( + oai_cli_passed_args or yaml_oai_config + ): + effective_server_configs = APIServerConfig( + **effective_server_configs.model_dump() + ) + + if ( + isinstance(effective_server_configs, list) + and len(effective_server_configs) == 1 + ): + default_openai_config_ = effective_server_configs[0] + else: + default_openai_config_ = effective_server_configs + + if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1: + yaml_oai_config = yaml_oai_config[0] + + if isinstance(default_openai_config_, APIServerConfig) and isinstance( + yaml_oai_config, dict + ): + openai_config_dict = merge_dicts( + default_openai_config_.model_dump(), + yaml_oai_config, + oai_cli_passed_args, + ) + else: + openai_config_dict = {} + + server_manager_cli_passed_flags = {} + if "slurm" in cli_passed_flags: + server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"] + if "testing" in cli_passed_flags: + server_manager_cli_passed_flags["testing"] = cli_passed_flags[ + "testing" + ] + + server_manager_yaml_dict = {} + if "slurm" in yaml_config: + server_manager_yaml_dict["slurm"] = yaml_config["slurm"] + if "testing" in yaml_config: + server_manager_yaml_dict["testing"] = yaml_config["testing"] + + server_manager_config_dict = merge_dicts( + ServerManagerConfig().model_dump(), + server_manager_yaml_dict, + server_manager_cli_passed_flags, + ) + + env_config = type(default_env_config)(**env_config_dict) + server_manager_config = ServerManagerConfig( + **server_manager_config_dict + ) + openai_configs = resolve_openai_configs( + default_server_configs=effective_server_configs, + openai_config_dict=openai_config_dict, + yaml_config=yaml_config, + cli_passed_flags=cli_passed_flags, + logger=logger, + ) + teacher_configs = cls._resolve_teacher_server_configs( + default_teacher_server_configs=default_teacher_server_configs, + yaml_config=yaml_config, + cli_passed_flags=cli_passed_flags, + ) + + env_kwargs = { + "config": env_config, + "server_configs": openai_configs, + "slurm": server_manager_config.slurm, + "testing": server_manager_config.testing, + } + if teacher_configs is not None: + env_kwargs["teacher_server_configs"] = teacher_configs + env = cls(**env_kwargs) + rprint(env_config) + rprint(openai_configs) + if teacher_configs is not None: + rprint(teacher_configs) + + try: + loop = asyncio.get_running_loop() + task = loop.create_task(env.env_manager()) + loop.run_until_complete(task) + except RuntimeError: + asyncio.run(env.env_manager()) + + return CliServeConfig def __init__( self, @@ -71,26 +292,11 @@ class TeacherDistillationEnv(BaseEnv, ABC): self.teacher_server: Optional[ServerManager] = None if config.teacher_enabled: - teacher_config_source = teacher_server_configs - if teacher_config_source is None and config.teacher_server is not None: - teacher_config_source = [ - config.teacher_server.model_copy( - update={ - "tokenizer_name": ( - config.teacher_server.model_name - if config.teacher_server.tokenizer_name in ("", "none") - else config.teacher_server.tokenizer_name - ), - "timeout": 1200, - } - ) - ] - - if teacher_config_source is None: + if teacher_server_configs is None: raise ValueError( - "teacher_enabled=True requires teacher_server_configs at init " - "or a fallback teacher_server config." + "teacher_enabled=True requires teacher_server_configs at init." ) + teacher_config_source = teacher_server_configs self.teacher_server = ServerManager( teacher_config_source, slurm=False, diff --git a/atroposlib/tests/test_teacher_distillation_env.py b/atroposlib/tests/test_teacher_distillation_env.py index 7c8cb439..c789670d 100644 --- a/atroposlib/tests/test_teacher_distillation_env.py +++ b/atroposlib/tests/test_teacher_distillation_env.py @@ -4,6 +4,7 @@ from types import SimpleNamespace import pytest +from atroposlib.envs.server_handling.server_baseline import APIServerConfig from atroposlib.envs.teacher_distillation_env import TeacherDistillationEnv @@ -32,6 +33,20 @@ class _ConcreteTeacherEnv(TeacherDistillationEnv): return None +class _DummyTokenizer: + name_or_path = "student-model" + + def get_vocab(self): + return {"a": 1} + + +class _CapturingServerManager: + def __init__(self, configs, slurm=False, testing=False): + self.configs = configs + self.slurm = slurm + self.testing = testing + + @pytest.mark.asyncio async def test_attach_teacher_distillation_success(): env = object.__new__(_ConcreteTeacherEnv) @@ -105,6 +120,32 @@ async def test_attach_teacher_distillation_zero_topk_passthrough(): assert out["distill_logprobs"] is not None +@pytest.mark.asyncio +async def test_attach_teacher_distillation_group_override_topk_is_used(): + env = object.__new__(_ConcreteTeacherEnv) + env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0) + + seen_topks = [] + + async def _fake_fetch(seq, top_k): + seen_topks.append(top_k) + return [[tok] for tok in seq], [[-0.1] for _ in seq] + + env.teacher_server = object() + env._fetch_teacher_for_sequence = _fake_fetch + + group = { + "tokens": [[1, 2, 3], [4, 5]], + "group_overrides": {"teacher_top_k": 7}, + "masks": [[-100, 2, 3], [-100, 5]], + "scores": [1.0, 0.0], + } + out = await TeacherDistillationEnv._attach_teacher_distillation(env, group) + assert seen_topks == [7, 7] + assert out["distill_token_ids"] is not None + assert out["distill_logprobs"] is not None + + @pytest.mark.asyncio async def test_attach_teacher_distillation_group_override_can_skip_fetch(): env = object.__new__(_ConcreteTeacherEnv) @@ -149,3 +190,111 @@ def test_teacher_tokenizer_mismatch_raises(monkeypatch): env, teacher_tokenizer_name="teacher-model", ) + + +def test_init_requires_teacher_server_source(monkeypatch): + from atroposlib.envs import teacher_distillation_env as module + + def _fake_base_init(self, config, server_configs, slurm=False, testing=False): + self.config = config + self.tokenizer = _DummyTokenizer() + + monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init) + + config = SimpleNamespace( + teacher_enabled=True, + teacher_top_k=0, + ) + with pytest.raises(ValueError, match="teacher_enabled=True requires"): + _ConcreteTeacherEnv( + config=config, + server_configs=[], + ) + + +def test_init_uses_explicit_teacher_server_configs(monkeypatch): + from atroposlib.envs import teacher_distillation_env as module + + called = {} + + def _fake_base_init(self, config, server_configs, slurm=False, testing=False): + self.config = config + self.tokenizer = _DummyTokenizer() + + def _fake_validate(self, teacher_tokenizer_name): + called["teacher_tokenizer_name"] = teacher_tokenizer_name + + monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init) + monkeypatch.setattr(module, "ServerManager", _CapturingServerManager) + monkeypatch.setattr( + _ConcreteTeacherEnv, + "_validate_teacher_tokenizer_compatibility", + _fake_validate, + ) + + explicit_cfg = APIServerConfig( + model_name="explicit-model", + tokenizer_name="explicit-tokenizer", + base_url="http://explicit/v1", + api_key="x", + server_type="vllm", + ) + config = SimpleNamespace( + teacher_enabled=True, + teacher_top_k=0, + ) + + env = _ConcreteTeacherEnv( + config=config, + server_configs=[], + teacher_server_configs=[explicit_cfg], + ) + + assert isinstance(env.teacher_server, _CapturingServerManager) + assert env.teacher_server.configs == [explicit_cfg] + assert called["teacher_tokenizer_name"] == "explicit-tokenizer" + + +def test_resolve_teacher_server_configs_returns_none_when_unset(): + assert ( + _ConcreteTeacherEnv._resolve_teacher_server_configs( + default_teacher_server_configs=None, + yaml_config={}, + cli_passed_flags={}, + ) + is None + ) + + +def test_resolve_teacher_server_configs_uses_teacher_namespace(monkeypatch): + from atroposlib.envs import teacher_distillation_env as module + + captured = {} + + def _fake_resolve(**kwargs): + captured.update(kwargs) + return ["resolved"] + + monkeypatch.setattr(module, "resolve_openai_configs", _fake_resolve) + + default_cfg = APIServerConfig( + model_name="teacher-model", + base_url="http://teacher/v1", + api_key="x", + server_type="vllm", + ) + + out = _ConcreteTeacherEnv._resolve_teacher_server_configs( + default_teacher_server_configs=default_cfg, + yaml_config={"teacher": {"tokenizer_name": "teacher-tokenizer"}}, + cli_passed_flags={"teacher.base_url": "http://override/v1"}, + ) + + assert out == ["resolved"] + assert captured["openai_config_dict"]["base_url"] == "http://override/v1" + assert captured["openai_config_dict"]["tokenizer_name"] == "teacher-tokenizer" + assert captured["yaml_config"] == { + "openai": {"tokenizer_name": "teacher-tokenizer"} + } + assert captured["cli_passed_flags"] == {"openai.base_url": "http://override/v1"} + diff --git a/environments/gsm8k_server_teacher_distill.py b/environments/gsm8k_server_teacher_distill.py index 5aa33a01..59106f10 100644 --- a/environments/gsm8k_server_teacher_distill.py +++ b/environments/gsm8k_server_teacher_distill.py @@ -32,13 +32,6 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv): max_token_length=2048, wandb_name="gsm8k_teacher_distill", teacher_enabled=True, - teacher_server=APIServerConfig( - base_url="http://localhost:8003/v1", - model_name="mock-teacher", - api_key="", - server_type="vllm", - tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", - ), teacher_top_k=4, ) server_config = APIServerConfig( @@ -49,6 +42,17 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv): ) return env_config, server_config + @classmethod + def teacher_config_init(cls) -> APIServerConfig: + return APIServerConfig( + base_url="http://localhost:9003/v1", + model_name="mock-teacher", + api_key="", + server_type="vllm", + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", + timeout=1200, + ) + if __name__ == "__main__": GSM8kTeacherDistillEnv.cli() diff --git a/example_trainer/README.md b/example_trainer/README.md index 8596a849..de31b3eb 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -314,14 +314,14 @@ What to configure on the environment side: ```bash --env.teacher_enabled true \ ---env.teacher_server.base_url "http://localhost:9003/v1" \ ---env.teacher_server.model_name "$TEACHER_MODEL" \ ---env.teacher_server.server_type vllm \ +--teacher.base_url "http://localhost:9003/v1" \ +--teacher.model_name "$TEACHER_MODEL" \ +--teacher.server_type vllm \ --env.teacher_top_k 8 ``` If `$TEACHER_MODEL` is a deployment alias instead of a tokenizer identifier, -also set `--env.teacher_server.tokenizer_name ...` so the env can validate +also set `--teacher.tokenizer_name ...` so the env can validate tokenizer compatibility. Why cross-tokenizer conversion is not acceptable here: diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 91cecf8a..fead9ba6 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -234,9 +234,9 @@ start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \ --env.use_wandb true \ --env.wandb_name "gsm8k-teacher-distill" \ --env.teacher_enabled true \ - --env.teacher_server.base_url "http://localhost:${TEACHER_PORT}/v1" \ - --env.teacher_server.model_name "$TEACHER_MODEL" \ - --env.teacher_server.server_type vllm \ + --teacher.base_url "http://localhost:${TEACHER_PORT}/v1" \ + --teacher.model_name "$TEACHER_MODEL" \ + --teacher.server_type vllm \ --env.teacher_top_k "$TEACHER_TOP_K" \ --env.ensure_scores_are_not_same false \ --openai.api_key "dummy" \ From 3a85ede8ba2e4b547abd3134b7e0c6de2d415b2c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Mar 2026 22:51:54 +0000 Subject: [PATCH 51/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .secrets.baseline | 4 ++-- atroposlib/envs/teacher_distillation_env.py | 14 ++++++++------ atroposlib/tests/test_teacher_distillation_env.py | 1 - 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/.secrets.baseline b/.secrets.baseline index 2651783a..31d5870f 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -133,7 +133,7 @@ "filename": "README.md", "hashed_secret": "a8253456364f1bfc7da7ae4a1db5b45d106317a5", "is_verified": false, - "line_number": 499 + "line_number": 512 } ], "SLURM.md": [ @@ -561,5 +561,5 @@ } ] }, - "generated_at": "2026-03-13T20:13:21Z" + "generated_at": "2026-03-13T22:51:44Z" } diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py index 1b3cda8f..892d424c 100644 --- a/atroposlib/envs/teacher_distillation_env.py +++ b/atroposlib/envs/teacher_distillation_env.py @@ -22,17 +22,17 @@ from pydantic import Field from pydantic_cli import Cmd from rich import print as rprint -from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup -from .constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE -from .server_handling.openai_server import resolve_openai_configs -from .server_handling.server_baseline import APIServerConfig, ServerBaseline -from .server_handling.server_manager import ServerManager, ServerManagerConfig from ..utils.cli import ( extract_namespace, get_double_dash_flags, get_prefixed_pydantic_model, merge_dicts, ) +from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup +from .constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE +from .server_handling.openai_server import resolve_openai_configs +from .server_handling.server_baseline import APIServerConfig, ServerBaseline +from .server_handling.server_manager import ServerManager, ServerManagerConfig logger = logging.getLogger(__name__) @@ -81,7 +81,9 @@ class TeacherDistillationEnv(BaseEnv, ABC): cli_passed_flags: Dict[str, Any], ) -> Optional[Union[ServerBaseline, List[APIServerConfig]]]: teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}" - teacher_cli_passed_args = extract_namespace(cli_passed_flags, teacher_full_prefix) + teacher_cli_passed_args = extract_namespace( + cli_passed_flags, teacher_full_prefix + ) yaml_teacher_config = yaml_config.get(cls.teacher_namespace, {}) if ( diff --git a/atroposlib/tests/test_teacher_distillation_env.py b/atroposlib/tests/test_teacher_distillation_env.py index c789670d..b348e65a 100644 --- a/atroposlib/tests/test_teacher_distillation_env.py +++ b/atroposlib/tests/test_teacher_distillation_env.py @@ -297,4 +297,3 @@ def test_resolve_teacher_server_configs_uses_teacher_namespace(monkeypatch): "openai": {"tokenizer_name": "teacher-tokenizer"} } assert captured["cli_passed_flags"] == {"openai.base_url": "http://override/v1"} - From 9bd299b3efc98ee3c18f8f9c8b3230984f567726 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 13 Mar 2026 20:41:41 -0400 Subject: [PATCH 52/64] better logging for devex --- README.md | 18 ++++++++++++++++++ atroposlib/envs/teacher_distillation_env.py | 6 +++++- .../tests/test_teacher_distillation_env.py | 2 +- example_trainer/README.md | 5 +++++ 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a0885cdd..4d12f592 100644 --- a/README.md +++ b/README.md @@ -331,6 +331,17 @@ env = MyTeacherEnv( ) ``` +You can either: + +- build a teacher-enabled env by mixing `TeacherDistillationEnv` into an existing + `BaseEnv`-derived env such as `GSM8kEnv`, or +- subclass `TeacherDistillationEnv` directly and implement the usual environment + methods yourself. + +In both cases, `TeacherDistillationEnv` still assumes the normal `BaseEnv` +runtime contract: tokenized rollouts, `ScoredDataGroup` payloads, and the +standard `handle_send_to_api(...)` transport path. + CLI shape: ```bash @@ -345,6 +356,13 @@ If `--teacher.model_name` is a deployment alias rather than a tokenizer identifier, also set `--teacher.tokenizer_name ...` so the env can validate tokenizer compatibility. +Scope note: + +- The teacher-aware CLI wiring currently exists for `serve`. +- If `teacher_enabled=True`, the generic `process` and `evaluate` commands will + fail loudly at env construction time unless you instantiate the env yourself + and pass `teacher_server_configs=...`. + Tokenizer requirement: - Teacher distillation currently requires the teacher and student to use the same tokenizer vocabulary. diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py index 892d424c..64f62b14 100644 --- a/atroposlib/envs/teacher_distillation_env.py +++ b/atroposlib/envs/teacher_distillation_env.py @@ -296,7 +296,11 @@ class TeacherDistillationEnv(BaseEnv, ABC): if config.teacher_enabled: if teacher_server_configs is None: raise ValueError( - "teacher_enabled=True requires teacher_server_configs at init." + "teacher_enabled=True but no teacher server configuration was " + "provided. Pass teacher_server_configs=... when instantiating " + "the environment directly, or use the teacher-aware 'serve' CLI " + "path with --teacher.* flags. The generic BaseEnv 'process' and " + "'evaluate' commands do not currently wire teacher_server_configs." ) teacher_config_source = teacher_server_configs self.teacher_server = ServerManager( diff --git a/atroposlib/tests/test_teacher_distillation_env.py b/atroposlib/tests/test_teacher_distillation_env.py index b348e65a..b35c48ee 100644 --- a/atroposlib/tests/test_teacher_distillation_env.py +++ b/atroposlib/tests/test_teacher_distillation_env.py @@ -205,7 +205,7 @@ def test_init_requires_teacher_server_source(monkeypatch): teacher_enabled=True, teacher_top_k=0, ) - with pytest.raises(ValueError, match="teacher_enabled=True requires"): + with pytest.raises(ValueError, match="no teacher server configuration was provided"): _ConcreteTeacherEnv( config=config, server_configs=[], diff --git a/example_trainer/README.md b/example_trainer/README.md index de31b3eb..4563e12c 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -324,6 +324,11 @@ If `$TEACHER_MODEL` is a deployment alias instead of a tokenizer identifier, also set `--teacher.tokenizer_name ...` so the env can validate tokenizer compatibility. +The teacher-aware CLI path is currently wired for `serve`. If +`teacher_enabled=True`, the generic `process` and `evaluate` commands are not +teacher-aware and will fail loudly unless the environment is instantiated +manually with `teacher_server_configs=...`. + Why cross-tokenizer conversion is not acceptable here: - Teacher token ID `1234` and student token ID `1234` can correspond to different text. From f053c77a6249386520efd8075c9beb9c40024e14 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 14 Mar 2026 00:43:19 +0000 Subject: [PATCH 53/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .secrets.baseline | 4 ++-- atroposlib/tests/test_teacher_distillation_env.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.secrets.baseline b/.secrets.baseline index 31d5870f..f93f324a 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -133,7 +133,7 @@ "filename": "README.md", "hashed_secret": "a8253456364f1bfc7da7ae4a1db5b45d106317a5", "is_verified": false, - "line_number": 512 + "line_number": 530 } ], "SLURM.md": [ @@ -561,5 +561,5 @@ } ] }, - "generated_at": "2026-03-13T22:51:44Z" + "generated_at": "2026-03-14T00:43:09Z" } diff --git a/atroposlib/tests/test_teacher_distillation_env.py b/atroposlib/tests/test_teacher_distillation_env.py index b35c48ee..11c586d9 100644 --- a/atroposlib/tests/test_teacher_distillation_env.py +++ b/atroposlib/tests/test_teacher_distillation_env.py @@ -205,7 +205,9 @@ def test_init_requires_teacher_server_source(monkeypatch): teacher_enabled=True, teacher_top_k=0, ) - with pytest.raises(ValueError, match="no teacher server configuration was provided"): + with pytest.raises( + ValueError, match="no teacher server configuration was provided" + ): _ConcreteTeacherEnv( config=config, server_configs=[], From 805a0c0eaca8aedec37dccb310d7ed0b56133e98 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 13 Mar 2026 20:52:40 -0400 Subject: [PATCH 54/64] revert to similar structure --- .../envs/server_handling/openai_server.py | 20 +++++++------------ .../envs/server_handling/vllm_server.py | 4 +++- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index 54f03fb4..0a0f3751 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -230,20 +230,14 @@ def resolve_openai_configs( f"Merged Dict: {openai_config_dict}" ) from e - if isinstance(default_server_configs, list): + if isinstance(default_server_configs, APIServerConfig): + server_configs = [final_openai_config] + elif isinstance(default_server_configs, list): server_configs = [final_openai_config] else: + logger.warning( + f"Unexpected type for default_server_configs: {type(default_server_configs)}. " + "Proceeding with single OpenAI server configuration based on merged settings." + ) server_configs = [final_openai_config] - - if isinstance(server_configs, list): - logger.info( - "resolve_openai_configs returning %s config(s) with URLs: %s", - len(server_configs), - [getattr(c, "base_url", None) for c in server_configs], - ) - else: - logger.info( - "resolve_openai_configs returning %s", - type(server_configs).__name__, - ) return server_configs diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 72e7140e..935d3816 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -432,7 +432,9 @@ def resolve_openai_configs( f"Merged Dict: {openai_config_dict}" ) from e - if isinstance(default_server_configs, list): + if isinstance(default_server_configs, APIServerConfig): + server_configs = [final_openai_config] + elif isinstance(default_server_configs, list): server_configs = [final_openai_config] else: logger.warning( From 7aba0d3fc865da9a1e34b266dd861f5962aaac56 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Sat, 14 Mar 2026 11:20:15 -0400 Subject: [PATCH 55/64] fresh eyes check --- atroposlib/envs/teacher_distillation_env.py | 7 ++- .../tests/test_teacher_distillation_env.py | 43 +++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py index 64f62b14..1f0e2110 100644 --- a/atroposlib/envs/teacher_distillation_env.py +++ b/atroposlib/envs/teacher_distillation_env.py @@ -285,7 +285,7 @@ class TeacherDistillationEnv(BaseEnv, ABC): config: TeacherDistillationConfig, server_configs: Union[ServerBaseline, List[APIServerConfig]], teacher_server_configs: Optional[ - Union[ServerBaseline, List[APIServerConfig]] + Union[ServerBaseline, APIServerConfig, List[APIServerConfig]] ] = None, slurm: bool = False, testing: bool = False, @@ -302,7 +302,10 @@ class TeacherDistillationEnv(BaseEnv, ABC): "path with --teacher.* flags. The generic BaseEnv 'process' and " "'evaluate' commands do not currently wire teacher_server_configs." ) - teacher_config_source = teacher_server_configs + if isinstance(teacher_server_configs, APIServerConfig): + teacher_config_source = [teacher_server_configs] + else: + teacher_config_source = teacher_server_configs self.teacher_server = ServerManager( teacher_config_source, slurm=False, diff --git a/atroposlib/tests/test_teacher_distillation_env.py b/atroposlib/tests/test_teacher_distillation_env.py index 11c586d9..c8825218 100644 --- a/atroposlib/tests/test_teacher_distillation_env.py +++ b/atroposlib/tests/test_teacher_distillation_env.py @@ -257,6 +257,49 @@ def test_init_uses_explicit_teacher_server_configs(monkeypatch): assert called["teacher_tokenizer_name"] == "explicit-tokenizer" +def test_init_wraps_bare_teacher_api_server_config(monkeypatch): + from atroposlib.envs import teacher_distillation_env as module + + called = {} + + def _fake_base_init(self, config, server_configs, slurm=False, testing=False): + self.config = config + self.tokenizer = _DummyTokenizer() + + def _fake_validate(self, teacher_tokenizer_name): + called["teacher_tokenizer_name"] = teacher_tokenizer_name + + monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init) + monkeypatch.setattr(module, "ServerManager", _CapturingServerManager) + monkeypatch.setattr( + _ConcreteTeacherEnv, + "_validate_teacher_tokenizer_compatibility", + _fake_validate, + ) + + explicit_cfg = APIServerConfig( + model_name="explicit-model", + tokenizer_name="explicit-tokenizer", + base_url="http://explicit/v1", + api_key="x", + server_type="vllm", + ) + config = SimpleNamespace( + teacher_enabled=True, + teacher_top_k=0, + ) + + env = _ConcreteTeacherEnv( + config=config, + server_configs=[], + teacher_server_configs=explicit_cfg, + ) + + assert isinstance(env.teacher_server, _CapturingServerManager) + assert env.teacher_server.configs == [explicit_cfg] + assert called["teacher_tokenizer_name"] == "explicit-tokenizer" + + def test_resolve_teacher_server_configs_returns_none_when_unset(): assert ( _ConcreteTeacherEnv._resolve_teacher_server_configs( From 79baac1ea7c8eb1617ca7c2b1c29bb5022953b1b Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 17 Mar 2026 12:23:35 -0400 Subject: [PATCH 56/64] clean --- .../envs/server_handling/openai_server.py | 17 +++++------------ atroposlib/envs/server_handling/vllm_server.py | 17 +++++------------ 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index 0a0f3751..98f68682 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -199,16 +199,12 @@ def resolve_openai_configs( f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}" ) from e elif isinstance(default_server_configs, APIServerConfig): - # Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline - logger.info( - "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." - ) + logger.info("Using single OpenAI server configuration.") try: final_openai_config = APIServerConfig(**openai_config_dict) except Exception as e: raise FailedExecutionException( - f"Error creating final OpenAI configuration from merged settings: {e}\n" - f"Merged Dict: {openai_config_dict}" + f"Error creating final OpenAI configuration: {e}" ) from e server_configs = [final_openai_config] elif isinstance(default_server_configs, ServerBaseline): @@ -219,15 +215,12 @@ def resolve_openai_configs( logger.info("Using default multi-server configuration (length >= 2).") server_configs = default_server_configs else: - logger.info( - "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." - ) + logger.info("Using single OpenAI server configuration.") try: final_openai_config = APIServerConfig(**openai_config_dict) except Exception as e: raise FailedExecutionException( - f"Error creating final OpenAI configuration from merged settings: {e}\n" - f"Merged Dict: {openai_config_dict}" + f"Error creating final OpenAI configuration: {e}" ) from e if isinstance(default_server_configs, APIServerConfig): @@ -237,7 +230,7 @@ def resolve_openai_configs( else: logger.warning( f"Unexpected type for default_server_configs: {type(default_server_configs)}. " - "Proceeding with single OpenAI server configuration based on merged settings." + "Proceeding with single OpenAI server configuration." ) server_configs = [final_openai_config] return server_configs diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 935d3816..c8ce5e4b 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -402,16 +402,12 @@ def resolve_openai_configs( f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}" ) from e elif isinstance(default_server_configs, APIServerConfig): - # Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline - logger.info( - "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." - ) + logger.info("Using single OpenAI server configuration.") try: final_openai_config = APIServerConfig(**openai_config_dict) except Exception as e: raise FailedExecutionException( - f"Error creating final OpenAI configuration from merged settings: {e}\n" - f"Merged Dict: {openai_config_dict}" + f"Error creating final OpenAI configuration: {e}" ) from e server_configs = [final_openai_config] elif isinstance(default_server_configs, ServerBaseline): @@ -421,15 +417,12 @@ def resolve_openai_configs( logger.info("Using default multi-server configuration (length >= 2).") server_configs = default_server_configs else: - logger.info( - "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." - ) + logger.info("Using single OpenAI server configuration.") try: final_openai_config = APIServerConfig(**openai_config_dict) except Exception as e: raise FailedExecutionException( - f"Error creating final OpenAI configuration from merged settings: {e}\n" - f"Merged Dict: {openai_config_dict}" + f"Error creating final OpenAI configuration: {e}" ) from e if isinstance(default_server_configs, APIServerConfig): @@ -439,7 +432,7 @@ def resolve_openai_configs( else: logger.warning( f"Unexpected type for default_server_configs: {type(default_server_configs)}. " - "Proceeding with single OpenAI server configuration based on merged settings." + "Proceeding with single OpenAI server configuration." ) server_configs = [final_openai_config] From 41947e98d65ae72d3cef374fdbfff5fc4b9a1147 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 17 Mar 2026 12:25:38 -0400 Subject: [PATCH 57/64] clean --- atroposlib/envs/server_handling/vllm_server.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index c8ce5e4b..ba272e76 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -401,15 +401,6 @@ def resolve_openai_configs( raise FailedExecutionException( f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}" ) from e - elif isinstance(default_server_configs, APIServerConfig): - logger.info("Using single OpenAI server configuration.") - try: - final_openai_config = APIServerConfig(**openai_config_dict) - except Exception as e: - raise FailedExecutionException( - f"Error creating final OpenAI configuration: {e}" - ) from e - server_configs = [final_openai_config] elif isinstance(default_server_configs, ServerBaseline): logger.info("Using ServerBaseline configuration.") server_configs = default_server_configs From 45f569f3afccf8d51439356adfe48860aa4eaaee Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Wed, 18 Mar 2026 09:20:08 -0400 Subject: [PATCH 58/64] clean --- .../envs/server_handling/openai_server.py | 21 ++++++++++++------- .../envs/server_handling/vllm_server.py | 13 +++++++----- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index 98f68682..a715054b 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -199,14 +199,18 @@ def resolve_openai_configs( f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}" ) from e elif isinstance(default_server_configs, APIServerConfig): - logger.info("Using single OpenAI server configuration.") + # Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline + logger.info( + "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." + ) try: final_openai_config = APIServerConfig(**openai_config_dict) except Exception as e: raise FailedExecutionException( - f"Error creating final OpenAI configuration: {e}" + f"Error creating final OpenAI configuration from merged settings: {e}\n" + f"Merged Dict: {openai_config_dict}" ) from e - server_configs = [final_openai_config] + server_configs = final_openai_config elif isinstance(default_server_configs, ServerBaseline): # Pure ServerBaseline (not APIServerConfig) - no CLI overrides possible logger.info("Using ServerBaseline configuration.") @@ -215,22 +219,25 @@ def resolve_openai_configs( logger.info("Using default multi-server configuration (length >= 2).") server_configs = default_server_configs else: - logger.info("Using single OpenAI server configuration.") + logger.info( + "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." + ) try: final_openai_config = APIServerConfig(**openai_config_dict) except Exception as e: raise FailedExecutionException( - f"Error creating final OpenAI configuration: {e}" + f"Error creating final OpenAI configuration from merged settings: {e}\n" + f"Merged Dict: {openai_config_dict}" ) from e if isinstance(default_server_configs, APIServerConfig): - server_configs = [final_openai_config] + server_configs = final_openai_config elif isinstance(default_server_configs, list): server_configs = [final_openai_config] else: logger.warning( f"Unexpected type for default_server_configs: {type(default_server_configs)}. " - "Proceeding with single OpenAI server configuration." + f"Proceeding with single OpenAI server configuration based on merged settings." ) server_configs = [final_openai_config] return server_configs diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index ba272e76..3c35bebb 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -281,7 +281,7 @@ class VLLMServer(APIServer): ), "Prompt or input_ids is required for get_logprobs!" top_k = int(kwargs.pop("top_k", kwargs.pop("top_logprobs", 1))) - top_k = max(0, top_k) + top_k = max(1, top_k) # Use input_ids if provided (from ManagedServer), otherwise tokenize prompt from_prompt_text = False @@ -408,22 +408,25 @@ def resolve_openai_configs( logger.info("Using default multi-server configuration (length >= 2).") server_configs = default_server_configs else: - logger.info("Using single OpenAI server configuration.") + logger.info( + "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." + ) try: final_openai_config = APIServerConfig(**openai_config_dict) except Exception as e: raise FailedExecutionException( - f"Error creating final OpenAI configuration: {e}" + f"Error creating final OpenAI configuration from merged settings: {e}\n" + f"Merged Dict: {openai_config_dict}" ) from e if isinstance(default_server_configs, APIServerConfig): - server_configs = [final_openai_config] + server_configs = final_openai_config elif isinstance(default_server_configs, list): server_configs = [final_openai_config] else: logger.warning( f"Unexpected type for default_server_configs: {type(default_server_configs)}. " - "Proceeding with single OpenAI server configuration." + f"Proceeding with single OpenAI server configuration based on merged settings." ) server_configs = [final_openai_config] From 79ff1642f8b4668516855faf8ca1118f22aa9fac Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 23 Mar 2026 11:18:14 -0700 Subject: [PATCH 59/64] revert gsm8k --- environments/gsm8k_server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index de13f8c9..87823526 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -232,6 +232,7 @@ class GSM8kEnv(BaseEnv): ) async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + chat_completions = await managed.chat_completion( messages=[{"role": "system", "content": system_prompt}, user_message], n=self.config.group_size, @@ -351,7 +352,7 @@ class GSM8kEnv(BaseEnv): # Apply linear penalty scaling from 1.0 down to 0.0 scores["scores"].append(1.0 - percentage_of_range) if all([scores["scores"][0] == score for score in scores["scores"]]): - return None + return None # If all the same, we return None return scores else: # If the gold solution is not parseable, we return None From 8745f0533e7fd7e532930e4838540460458de38e Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 23 Mar 2026 11:23:47 -0700 Subject: [PATCH 60/64] revert teacher logprobs --- atroposlib/tests/test_server_logprobs.py | 52 ------------------------ 1 file changed, 52 deletions(-) diff --git a/atroposlib/tests/test_server_logprobs.py b/atroposlib/tests/test_server_logprobs.py index 7e545355..8cbd84ad 100644 --- a/atroposlib/tests/test_server_logprobs.py +++ b/atroposlib/tests/test_server_logprobs.py @@ -1,19 +1,13 @@ """Tests for get_logprobs wrappers and server-manager routing.""" -import logging - import pytest -from atroposlib.envs.server_handling.openai_server import resolve_openai_configs from atroposlib.envs.server_handling.server_baseline import ( APIServer, APIServerConfig, AsyncSemWithAdaptiveWeight, ) from atroposlib.envs.server_handling.server_manager import ServerManager -from atroposlib.envs.server_handling.vllm_server import ( - resolve_openai_configs as resolve_vllm_configs, -) class _FakeAPIServer(APIServer): @@ -109,49 +103,3 @@ async def test_server_manager_get_logprobs_routes_to_most_available_server(): out_eval = await ServerManager.get_logprobs(manager, prompt="x", split="eval") assert out_eval["server"] == "s1" assert s1.calls == 1 - - -def test_resolve_openai_configs_wraps_single_api_server_config_in_list(): - default_server_config = APIServerConfig( - model_name="test-model", - base_url="http://localhost:9001/v1", - api_key="x", - server_type="openai", - ) - merged_config = default_server_config.model_dump() - - server_configs = resolve_openai_configs( - default_server_configs=default_server_config, - openai_config_dict=merged_config, - yaml_config={}, - cli_passed_flags={}, - logger=logging.getLogger("test"), - ) - - assert isinstance(server_configs, list) - assert len(server_configs) == 1 - assert isinstance(server_configs[0], APIServerConfig) - assert server_configs[0].base_url == "http://localhost:9001/v1" - - -def test_resolve_vllm_configs_wraps_single_api_server_config_in_list(): - default_server_config = APIServerConfig( - model_name="test-model", - base_url="http://localhost:9001/v1", - api_key="x", - server_type="vllm", - ) - merged_config = default_server_config.model_dump() - - server_configs = resolve_vllm_configs( - default_server_configs=default_server_config, - openai_config_dict=merged_config, - yaml_config={}, - cli_passed_flags={}, - logger=logging.getLogger("test"), - ) - - assert isinstance(server_configs, list) - assert len(server_configs) == 1 - assert isinstance(server_configs[0], APIServerConfig) - assert server_configs[0].base_url == "http://localhost:9001/v1" From 295bb9c4464eb49b677567eb7d5909c1cc0286a0 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 23 Mar 2026 11:25:28 -0700 Subject: [PATCH 61/64] revert openai server --- atroposlib/envs/server_handling/openai_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index a715054b..dbf1c2d9 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -240,4 +240,5 @@ def resolve_openai_configs( f"Proceeding with single OpenAI server configuration based on merged settings." ) server_configs = [final_openai_config] + return server_configs From 75a032bf3e7fc327357c5d0276d5df1e5a3cb041 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 23 Mar 2026 11:26:05 -0700 Subject: [PATCH 62/64] revert openai server --- atroposlib/envs/server_handling/openai_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index dbf1c2d9..fecc5828 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -240,5 +240,5 @@ def resolve_openai_configs( f"Proceeding with single OpenAI server configuration based on merged settings." ) server_configs = [final_openai_config] - + return server_configs From fae87dcaaa0070a6d2fbc764797b0de9b97bc7c6 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 23 Mar 2026 11:28:40 -0700 Subject: [PATCH 63/64] clean vllm tonight --- example_trainer/api.py | 6 +----- example_trainer/vllm_api_server.py | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/example_trainer/api.py b/example_trainer/api.py index 1bc8a1bd..21c4288e 100644 --- a/example_trainer/api.py +++ b/example_trainer/api.py @@ -99,11 +99,7 @@ def get_batch(url: str = "http://localhost:8000"): Raises: RuntimeError: If trainer is not registered or other API error """ - response = requests.get( - f"{url}/batch", - timeout=10, - ) - data = response.json() + data = requests.get(f"{url}/batch", timeout=10).json() # Check if there was an error (trainer not registered) if data.get("status") == "error": diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py index 24d40326..10759466 100644 --- a/example_trainer/vllm_api_server.py +++ b/example_trainer/vllm_api_server.py @@ -325,7 +325,6 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: async for request_output in results_generator: final_output = request_output except asyncio.CancelledError: - logger.warning("POST /generate cancelled request_id=%s", request_id) return Response(status_code=499) assert final_output is not None From e1542ee731d4c9176fcb304cf16220c86f587920 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 23 Mar 2026 11:30:04 -0700 Subject: [PATCH 64/64] clean example trainer --- example_trainer/vllm_api_server.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py index 10759466..2846f14f 100644 --- a/example_trainer/vllm_api_server.py +++ b/example_trainer/vllm_api_server.py @@ -348,26 +348,6 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: ret["prompt_token_ids"] = final_output.prompt_token_ids ret["token_ids"] = [x.token_ids for x in final_output.outputs] - if ( - sampling_params.prompt_logprobs is not None - and final_output.prompt_logprobs is not None - ): - ret["prompt_logprobs"] = [ - ( - {int(tok_id): lp.logprob for tok_id, lp in pos.items()} - if pos is not None - else None - ) - for pos in final_output.prompt_logprobs - ] - - logger.info( - "POST /generate completed request_id=%s outputs=%s finish_reasons=%s", - request_id, - len(text_outputs), - finish_reasons, - ) - return JSONResponse(ret)