Merge branch 'NousResearch:main' into feat/trainer-inference-optimization

Prakarsh Kaushik 2026-03-31 00:57:26 +05:30 committed by GitHub
commit 6825974cd0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 1285 additions and 4 deletions


@@ -133,7 +133,7 @@
"filename": "README.md",
"hashed_secret": "a8253456364f1bfc7da7ae4a1db5b45d106317a5",
"is_verified": false,
"line_number": 454
"line_number": 530
}
],
"SLURM.md": [
@@ -561,5 +561,5 @@
}
]
},
"generated_at": "2026-03-02T22:46:56Z"
"generated_at": "2026-03-14T00:43:09Z"
}


@@ -298,6 +298,82 @@ curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids!
- Trainers should validate alignment assumptions they require (sequence length, per-position top-k, etc.).
- Teacher-side architecture and prompt/rendering strategy are intentionally out of scope for this PR.
### TeacherDistillationEnv follow-up

The follow-up teacher environment uses a dedicated teacher server config and
attaches teacher prompt logprobs before the group is sent to the API.

Teacher config shape:

```python
TeacherDistillationConfig(
    teacher_enabled=True,
    teacher_top_k=8,
)
```
Teacher server configs are passed separately at init, just like the primary
`server_configs`:

```python
env = MyTeacherEnv(
    config=env_config,
    server_configs=student_server_configs,
    teacher_server_configs=[
        APIServerConfig(
            base_url="http://localhost:9003/v1",
            model_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
            api_key="",
            server_type="vllm",
            tokenizer_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
        )
    ],
)
```
You can either:

- build a teacher-enabled env by mixing `TeacherDistillationEnv` into an existing
  `BaseEnv`-derived env such as `GSM8kEnv`, or
- subclass `TeacherDistillationEnv` directly and implement the usual environment
  methods yourself.

In both cases, `TeacherDistillationEnv` still assumes the normal `BaseEnv`
runtime contract: tokenized rollouts, `ScoredDataGroup` payloads, and the
standard `handle_send_to_api(...)` transport path.
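
The mixin wiring can be sketched with toy stand-in classes (illustrative only — the real `BaseEnv`, `TeacherDistillationEnv`, and task envs live in `atroposlib`/`environments`, and the payload values here are made up):

```python
import asyncio


# Stand-ins for the real classes, to show the method-resolution order only.
class BaseEnvStub:
    async def handle_send_to_api(self, group):
        # Plays the role of BaseEnv's transport path.
        return group


class TeacherLayerStub(BaseEnvStub):
    async def handle_send_to_api(self, group):
        # Plays the role of TeacherDistillationEnv: enrich, then delegate.
        group = dict(group, distill_token_ids=[[[1]]], distill_logprobs=[[[-0.5]]])
        return await super().handle_send_to_api(group)


class TaskEnvStub(BaseEnvStub):
    # Plays the role of an existing task env that does not override
    # handle_send_to_api itself (GSM8K-style).
    pass


# The mixin option: task env first, teacher layer second, mirroring
# `class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv)`.
class TeacherTaskEnv(TaskEnvStub, TeacherLayerStub):
    pass


out = asyncio.run(TeacherTaskEnv().handle_send_to_api({"tokens": [[1, 2]]}))
print(sorted(out))  # → ['distill_logprobs', 'distill_token_ids', 'tokens']
```

Because the task env does not override `handle_send_to_api`, Python's MRO resolves the call to the teacher layer, which enriches the group before delegating to the base transport.
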
CLI shape:

```bash
--env.teacher_enabled true \
--teacher.base_url "http://localhost:9003/v1" \
--teacher.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \
--teacher.server_type vllm \
--env.teacher_top_k 8
```

If `--teacher.model_name` is a deployment alias rather than a tokenizer
identifier, also set `--teacher.tokenizer_name ...` so the env can validate
tokenizer compatibility.
Scope note:

- The teacher-aware CLI wiring currently exists for `serve`.
- If `teacher_enabled=True`, the generic `process` and `evaluate` commands will
  fail loudly at env construction time unless you instantiate the env yourself
  and pass `teacher_server_configs=...`.

Tokenizer requirement:

- Teacher distillation currently requires the teacher and student to use the same tokenizer vocabulary.
- If the tokenizers do not match, `TeacherDistillationEnv` raises an error instead of attempting token conversion.

Why same-tokenizer is required:

- `distill_token_ids` are consumed as student-vocabulary IDs by the trainer.
- If the teacher uses a different vocabulary, the same integer token ID refers to different text on the teacher and student sides.
- A decode/re-tokenize/remap pipeline is not a safe drop-in fix because it changes both token positions and token identities, which breaks the exact per-position token supervision that the current distillation loss assumes.

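
The ID-collision point can be illustrated with made-up three-token vocabularies (toy data, not real tokenizers):

```python
# Token ID 11 names "cat" for the student but "dog" for a mismatched
# teacher, so using teacher token IDs as indices into student logits is
# only sound when the vocabularies are identical.
student_vocab = {"the": 10, "cat": 11, "sat": 12}
matched_teacher_vocab = {"the": 10, "cat": 11, "sat": 12}
mismatched_teacher_vocab = {"the": 10, "dog": 11, "sat": 12}


def text_of(vocab, token_id):
    # Reverse lookup: which text does this integer ID name in this vocab?
    return next(tok for tok, tid in vocab.items() if tid == token_id)


# The env enforces the strict form of this check: exact vocab equality.
assert student_vocab == matched_teacher_vocab
assert student_vocab != mismatched_teacher_vocab
print(text_of(student_vocab, 11), text_of(mismatched_teacher_vocab, 11))  # → cat dog
```
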
---
## Testing and Debugging Tools


@@ -0,0 +1,459 @@
"""
Teacher distillation environment layer.
This module adds teacher prompt-logprob fetching on top of BaseEnv without
modifying BaseEnv transport behavior.
This implementation supports same-tokenizer distillation only. The teacher and
student must share the same tokenizer vocabulary so the student's token IDs can
be forwarded directly to the teacher and the returned teacher top-k token IDs
can be looked up directly in the student's logits.
"""
from __future__ import annotations
import asyncio
import logging
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple, Union
import yaml
from pydantic import Field
from pydantic_cli import Cmd
from rich import print as rprint
from ..utils.cli import (
extract_namespace,
get_double_dash_flags,
get_prefixed_pydantic_model,
merge_dicts,
)
from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
from .constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
from .server_handling.openai_server import resolve_openai_configs
from .server_handling.server_baseline import APIServerConfig, ServerBaseline
from .server_handling.server_manager import ServerManager, ServerManagerConfig
logger = logging.getLogger(__name__)
class TeacherDistillationConfig(BaseEnvConfig):
    teacher_enabled: bool = Field(
        default=False,
        description="Whether to fetch teacher prompt logprobs for distillation.",
    )
    teacher_top_k: int = Field(
        default=0,
        ge=-1,
        description=(
            "Number of extra prompt logprobs to fetch beyond the selected token. "
            "Use 0 for selected-token-only prompt logprobs and <= -1 to disable "
            "teacher fetching."
        ),
    )
class TeacherDistillationEnv(BaseEnv, ABC):
    """
    BaseEnv subclass that enriches scored groups with teacher distillation arrays.

    Distillation payload shape:
    - distill_token_ids: [sequence][position][k] (student vocab IDs)
    - distill_logprobs: [sequence][position][k]
    """

    env_config_cls = TeacherDistillationConfig
    teacher_namespace = "teacher"

    @classmethod
    def teacher_config_init(
        cls,
    ) -> Optional[Union[ServerBaseline, List[APIServerConfig], APIServerConfig]]:
        return None

    @classmethod
    def _resolve_teacher_server_configs(
        cls,
        default_teacher_server_configs: Optional[
            Union[ServerBaseline, List[APIServerConfig], APIServerConfig]
        ],
        yaml_config: Dict[str, Any],
        cli_passed_flags: Dict[str, Any],
    ) -> Optional[Union[ServerBaseline, List[APIServerConfig]]]:
        teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}"
        teacher_cli_passed_args = extract_namespace(
            cli_passed_flags, teacher_full_prefix
        )
        yaml_teacher_config = yaml_config.get(cls.teacher_namespace, {})
        if (
            default_teacher_server_configs is None
            and not teacher_cli_passed_args
            and not yaml_teacher_config
        ):
            return None
        effective_teacher_server_configs = default_teacher_server_configs
        if effective_teacher_server_configs is None:
            effective_teacher_server_configs = APIServerConfig()
        elif isinstance(effective_teacher_server_configs, ServerBaseline) and (
            teacher_cli_passed_args or yaml_teacher_config
        ):
            effective_teacher_server_configs = APIServerConfig(
                **effective_teacher_server_configs.model_dump()
            )
        if (
            isinstance(effective_teacher_server_configs, list)
            and len(effective_teacher_server_configs) == 1
        ):
            default_teacher_config = effective_teacher_server_configs[0]
        else:
            default_teacher_config = effective_teacher_server_configs
        if isinstance(yaml_teacher_config, list) and len(yaml_teacher_config) == 1:
            yaml_teacher_config = yaml_teacher_config[0]
        if isinstance(default_teacher_config, APIServerConfig) and isinstance(
            yaml_teacher_config, dict
        ):
            teacher_config_dict = merge_dicts(
                default_teacher_config.model_dump(),
                yaml_teacher_config,
                teacher_cli_passed_args,
            )
        else:
            teacher_config_dict = {}
        teacher_yaml_wrapped = {OPENAI_NAMESPACE: yaml_teacher_config}
        teacher_cli_wrapped = {
            f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}{key}": value
            for key, value in teacher_cli_passed_args.items()
        }
        return resolve_openai_configs(
            default_server_configs=effective_teacher_server_configs,
            openai_config_dict=teacher_config_dict,
            yaml_config=teacher_yaml_wrapped,
            cli_passed_flags=teacher_cli_wrapped,
            logger=logger,
        )
    @classmethod
    def get_cli_serve_config_cls(cls) -> type:
        default_env_config, default_server_configs = cls.config_init()
        default_teacher_server_configs = cls.teacher_config_init()
        env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
        openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
        teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}"
        teacher_cli_base = get_prefixed_pydantic_model(
            APIServerConfig, teacher_full_prefix
        )

        class CliServeConfig(
            get_prefixed_pydantic_model(type(default_env_config), env_full_prefix),
            get_prefixed_pydantic_model(APIServerConfig, openai_full_prefix),
            teacher_cli_base,
            ServerManagerConfig,
            Cmd,
        ):
            config: str | None = Field(
                default=None,
                description="Path to .yaml config file. CLI args override this.",
            )

            def run(self) -> None:
                wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
                if (
                    getattr(self, wandb_name_attr, None) is None
                    and cls.name is not None
                ):
                    setattr(self, wandb_name_attr, cls.name)
                if self.config is not None:
                    with open(self.config, "r") as f:
                        yaml_config = yaml.safe_load(f)
                    logger.info("Loaded config from %s", self.config)
                else:
                    yaml_config = {}
                cli_passed_flags = get_double_dash_flags()
                env_config_dict = merge_dicts(
                    default_env_config.model_dump(),
                    yaml_config.get(ENV_NAMESPACE, {}),
                    extract_namespace(cli_passed_flags, env_full_prefix),
                )
                oai_cli_passed_args = extract_namespace(
                    cli_passed_flags, openai_full_prefix
                )
                yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
                effective_server_configs = default_server_configs
                if isinstance(effective_server_configs, ServerBaseline) and (
                    oai_cli_passed_args or yaml_oai_config
                ):
                    effective_server_configs = APIServerConfig(
                        **effective_server_configs.model_dump()
                    )
                if (
                    isinstance(effective_server_configs, list)
                    and len(effective_server_configs) == 1
                ):
                    default_openai_config_ = effective_server_configs[0]
                else:
                    default_openai_config_ = effective_server_configs
                if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
                    yaml_oai_config = yaml_oai_config[0]
                if isinstance(default_openai_config_, APIServerConfig) and isinstance(
                    yaml_oai_config, dict
                ):
                    openai_config_dict = merge_dicts(
                        default_openai_config_.model_dump(),
                        yaml_oai_config,
                        oai_cli_passed_args,
                    )
                else:
                    openai_config_dict = {}
                server_manager_cli_passed_flags = {}
                if "slurm" in cli_passed_flags:
                    server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
                if "testing" in cli_passed_flags:
                    server_manager_cli_passed_flags["testing"] = cli_passed_flags[
                        "testing"
                    ]
                server_manager_yaml_dict = {}
                if "slurm" in yaml_config:
                    server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
                if "testing" in yaml_config:
                    server_manager_yaml_dict["testing"] = yaml_config["testing"]
                server_manager_config_dict = merge_dicts(
                    ServerManagerConfig().model_dump(),
                    server_manager_yaml_dict,
                    server_manager_cli_passed_flags,
                )
                env_config = type(default_env_config)(**env_config_dict)
                server_manager_config = ServerManagerConfig(
                    **server_manager_config_dict
                )
                openai_configs = resolve_openai_configs(
                    default_server_configs=effective_server_configs,
                    openai_config_dict=openai_config_dict,
                    yaml_config=yaml_config,
                    cli_passed_flags=cli_passed_flags,
                    logger=logger,
                )
                teacher_configs = cls._resolve_teacher_server_configs(
                    default_teacher_server_configs=default_teacher_server_configs,
                    yaml_config=yaml_config,
                    cli_passed_flags=cli_passed_flags,
                )
                env_kwargs = {
                    "config": env_config,
                    "server_configs": openai_configs,
                    "slurm": server_manager_config.slurm,
                    "testing": server_manager_config.testing,
                }
                if teacher_configs is not None:
                    env_kwargs["teacher_server_configs"] = teacher_configs
                env = cls(**env_kwargs)
                rprint(env_config)
                rprint(openai_configs)
                if teacher_configs is not None:
                    rprint(teacher_configs)
                try:
                    loop = asyncio.get_running_loop()
                except RuntimeError:
                    # Normal CLI path: no event loop is running yet.
                    asyncio.run(env.env_manager())
                else:
                    # Inside an already-running loop (e.g. a notebook):
                    # run_until_complete would raise on a running loop, so
                    # schedule the manager as a task instead.
                    loop.create_task(env.env_manager())

        return CliServeConfig
    def __init__(
        self,
        config: TeacherDistillationConfig,
        server_configs: Union[ServerBaseline, List[APIServerConfig]],
        teacher_server_configs: Optional[
            Union[ServerBaseline, APIServerConfig, List[APIServerConfig]]
        ] = None,
        slurm: bool = False,
        testing: bool = False,
    ):
        super().__init__(config, server_configs, slurm=slurm, testing=testing)
        self.teacher_server: Optional[ServerManager] = None
        if config.teacher_enabled:
            if teacher_server_configs is None:
                raise ValueError(
                    "teacher_enabled=True but no teacher server configuration was "
                    "provided. Pass teacher_server_configs=... when instantiating "
                    "the environment directly, or use the teacher-aware 'serve' CLI "
                    "path with --teacher.* flags. The generic BaseEnv 'process' and "
                    "'evaluate' commands do not currently wire teacher_server_configs."
                )
            if isinstance(teacher_server_configs, APIServerConfig):
                teacher_config_source = [teacher_server_configs]
            else:
                teacher_config_source = teacher_server_configs
            self.teacher_server = ServerManager(
                teacher_config_source,
                slurm=False,
                testing=False,
            )
            if isinstance(teacher_config_source, list):
                teacher_cfg = teacher_config_source[0]
            else:
                teacher_cfg = teacher_config_source
            teacher_tokenizer_name = (
                teacher_cfg.model_name
                if getattr(teacher_cfg, "tokenizer_name", "none") in (None, "", "none")
                else teacher_cfg.tokenizer_name
            )
            self._validate_teacher_tokenizer_compatibility(teacher_tokenizer_name)
    # ------------------------------------------------------------------
    # Core fetch
    # ------------------------------------------------------------------
    def _validate_teacher_tokenizer_compatibility(
        self, teacher_tokenizer_name: str
    ) -> None:
        student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
        if student_tok_name == teacher_tokenizer_name:
            return
        try:
            from transformers import AutoTokenizer

            teacher_tokenizer = AutoTokenizer.from_pretrained(
                teacher_tokenizer_name, use_fast=True
            )
        except Exception as exc:
            raise ValueError(
                "Cross-tokenizer distillation is not supported in this PR, and the "
                f"teacher tokenizer for '{teacher_tokenizer_name}' could not be loaded to "
                f"verify compatibility: {exc}"
            ) from exc
        student_vocab = self.tokenizer.get_vocab()
        teacher_vocab = teacher_tokenizer.get_vocab()
        if student_vocab != teacher_vocab:
            raise ValueError(
                "Cross-tokenizer distillation is not supported in this PR. "
                f"Student tokenizer '{student_tok_name or type(self.tokenizer).__name__}' "
                f"and teacher tokenizer '{teacher_tokenizer_name}' do not match."
            )
    async def _fetch_teacher_for_sequence(
        self, token_ids: List[int], top_k: int
    ) -> Tuple[List[List[int]], List[List[float]]]:
        assert self.teacher_server is not None
        payload = await self.teacher_server.get_logprobs(
            input_ids=token_ids,
            top_k=top_k,
            max_tokens=1,
            split="train",
        )
        return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]

    # ------------------------------------------------------------------
    # Group enrichment
    # ------------------------------------------------------------------
    async def _attach_teacher_distillation(
        self, group: ScoredDataGroup
    ) -> ScoredDataGroup:
        if not self.config.teacher_enabled or self.teacher_server is None:
            return group
        seqs = group.get("tokens", [])
        if not seqs:
            group["distill_token_ids"] = None
            group["distill_logprobs"] = None
            return group
        group_overrides = group.get("group_overrides") or {}
        if group_overrides.get("skip_teacher_top_k", False):
            group["distill_token_ids"] = None
            group["distill_logprobs"] = None
            return group
        top_k = int(group_overrides.get("teacher_top_k", self.config.teacher_top_k))
        if top_k <= -1:
            group["distill_token_ids"] = None
            group["distill_logprobs"] = None
            return group
        tasks = [self._fetch_teacher_for_sequence(seq, top_k) for seq in seqs]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        distill_token_ids: List[List[List[int]]] = []
        distill_logprobs: List[List[List[float]]] = []
        for idx, result in enumerate(results):
            if isinstance(result, Exception):
                logger.warning(
                    "Teacher logprob fetch failed for seq %s: %s. "
                    "Dropping distill payload for this group.",
                    idx,
                    result,
                )
                group["distill_token_ids"] = None
                group["distill_logprobs"] = None
                return group
            token_ids_k, logprobs_k = result
            if len(token_ids_k) != len(logprobs_k):
                logger.warning(
                    "Teacher prompt-topk length mismatch for seq %s (%s != %s). "
                    "Dropping distill payload for this group.",
                    idx,
                    len(token_ids_k),
                    len(logprobs_k),
                )
                group["distill_token_ids"] = None
                group["distill_logprobs"] = None
                return group
            distill_token_ids.append(token_ids_k)
            distill_logprobs.append(logprobs_k)
        group["distill_token_ids"] = distill_token_ids
        group["distill_logprobs"] = distill_logprobs
        return group
    async def handle_send_to_api(
        self,
        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
        item: Any = None,
        do_send_to_api: bool = True,
        abort_on_any_max_length_exceeded: bool = True,
    ):
        groups = scored_data if isinstance(scored_data, list) else [scored_data]
        enriched_groups: List[ScoredDataGroup] = []
        for group in groups:
            if group is None:
                continue
            enriched_groups.append(await self._attach_teacher_distillation(group))
        payload: Union[ScoredDataGroup, List[ScoredDataGroup]]
        if isinstance(scored_data, list):
            payload = enriched_groups
        else:
            payload = enriched_groups[0] if enriched_groups else scored_data
        return await super().handle_send_to_api(
            payload,
            item=item,
            do_send_to_api=do_send_to_api,
            abort_on_any_max_length_exceeded=abort_on_any_max_length_exceeded,
        )


@@ -0,0 +1,344 @@
"""Tests for TeacherDistillationEnv distillation enrichment."""
from types import SimpleNamespace
import pytest
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.teacher_distillation_env import TeacherDistillationEnv
class _FakeTeacherServer:
def __init__(self, fail_on_call: int = -1):
self.calls = 0
self.fail_on_call = fail_on_call
async def get_logprobs(self, **kwargs):
self.calls += 1
if self.calls == self.fail_on_call:
raise RuntimeError("teacher backend failure")
seq = kwargs["input_ids"]
return {
"prompt_tokens": seq,
"prompt_topk_token_ids": [[tok, tok + 1] for tok in seq],
"prompt_topk_logprobs": [[-0.1, -0.2] for _ in seq],
}
class _ConcreteTeacherEnv(TeacherDistillationEnv):
async def get_next_item(self):
return None
async def evaluate(self, *args, **kwargs):
return None
class _DummyTokenizer:
name_or_path = "student-model"
def get_vocab(self):
return {"a": 1}
class _CapturingServerManager:
def __init__(self, configs, slurm=False, testing=False):
self.configs = configs
self.slurm = slurm
self.testing = testing
@pytest.mark.asyncio
async def test_attach_teacher_distillation_success():
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
    env.teacher_server = _FakeTeacherServer()
    group = {
        "tokens": [[1, 2, 3], [4, 5]],
        "group_overrides": None,
        "masks": [[-100, 2, 3], [-100, 5]],
        "scores": [1.0, 0.0],
    }
    out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
    assert out["distill_token_ids"] is not None
    assert out["distill_logprobs"] is not None
    assert len(out["distill_token_ids"]) == 2
    assert len(out["distill_token_ids"][0]) == 3
    assert len(out["distill_logprobs"][1]) == 2


@pytest.mark.asyncio
async def test_attach_teacher_distillation_failure_drops_payload():
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
    env.teacher_server = _FakeTeacherServer(fail_on_call=2)
    group = {
        "tokens": [[1, 2, 3], [4, 5]],
        "group_overrides": None,
        "masks": [[-100, 2, 3], [-100, 5]],
        "scores": [1.0, 0.0],
    }
    out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
    assert out["distill_token_ids"] is None
    assert out["distill_logprobs"] is None


@pytest.mark.asyncio
async def test_attach_teacher_distillation_negative_topk_skips_fetch():
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=-1)
    env.teacher_server = _FakeTeacherServer()
    group = {
        "tokens": [[1, 2, 3]],
        "group_overrides": None,
        "masks": [[-100, 2, 3]],
        "scores": [1.0],
    }
    out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
    assert env.teacher_server.calls == 0
    assert out["distill_token_ids"] is None
    assert out["distill_logprobs"] is None


@pytest.mark.asyncio
async def test_attach_teacher_distillation_zero_topk_passthrough():
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)
    env.teacher_server = _FakeTeacherServer()
    group = {
        "tokens": [[1, 2, 3]],
        "group_overrides": None,
        "masks": [[-100, 2, 3]],
        "scores": [1.0],
    }
    out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
    assert env.teacher_server.calls == 1
    assert out["distill_token_ids"] is not None
    assert out["distill_logprobs"] is not None
@pytest.mark.asyncio
async def test_attach_teacher_distillation_group_override_topk_is_used():
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)
    seen_topks = []

    async def _fake_fetch(seq, top_k):
        seen_topks.append(top_k)
        return [[tok] for tok in seq], [[-0.1] for _ in seq]

    env.teacher_server = object()
    env._fetch_teacher_for_sequence = _fake_fetch
    group = {
        "tokens": [[1, 2, 3], [4, 5]],
        "group_overrides": {"teacher_top_k": 7},
        "masks": [[-100, 2, 3], [-100, 5]],
        "scores": [1.0, 0.0],
    }
    out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
    assert seen_topks == [7, 7]
    assert out["distill_token_ids"] is not None
    assert out["distill_logprobs"] is not None


@pytest.mark.asyncio
async def test_attach_teacher_distillation_group_override_can_skip_fetch():
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
    env.teacher_server = _FakeTeacherServer()
    group = {
        "tokens": [[1, 2, 3]],
        "group_overrides": {"skip_teacher_top_k": True},
        "masks": [[-100, 2, 3]],
        "scores": [1.0],
    }
    out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
    assert env.teacher_server.calls == 0
    assert out["distill_token_ids"] is None
    assert out["distill_logprobs"] is None
def test_teacher_tokenizer_mismatch_raises(monkeypatch):
    env = object.__new__(_ConcreteTeacherEnv)

    class _StudentTokenizer:
        name_or_path = "student-model"

        def get_vocab(self):
            return {"a": 1}

    class _TeacherTokenizer:
        def get_vocab(self):
            return {"b": 1}

    env.tokenizer = _StudentTokenizer()
    monkeypatch.setattr(
        "transformers.AutoTokenizer.from_pretrained",
        lambda *args, **kwargs: _TeacherTokenizer(),
    )
    with pytest.raises(
        ValueError, match="Cross-tokenizer distillation is not supported"
    ):
        TeacherDistillationEnv._validate_teacher_tokenizer_compatibility(
            env,
            teacher_tokenizer_name="teacher-model",
        )


def test_init_requires_teacher_server_source(monkeypatch):
    from atroposlib.envs import teacher_distillation_env as module

    def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
        self.config = config
        self.tokenizer = _DummyTokenizer()

    monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
    config = SimpleNamespace(
        teacher_enabled=True,
        teacher_top_k=0,
    )
    with pytest.raises(
        ValueError, match="no teacher server configuration was provided"
    ):
        _ConcreteTeacherEnv(
            config=config,
            server_configs=[],
        )
def test_init_uses_explicit_teacher_server_configs(monkeypatch):
    from atroposlib.envs import teacher_distillation_env as module

    called = {}

    def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
        self.config = config
        self.tokenizer = _DummyTokenizer()

    def _fake_validate(self, teacher_tokenizer_name):
        called["teacher_tokenizer_name"] = teacher_tokenizer_name

    monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
    monkeypatch.setattr(module, "ServerManager", _CapturingServerManager)
    monkeypatch.setattr(
        _ConcreteTeacherEnv,
        "_validate_teacher_tokenizer_compatibility",
        _fake_validate,
    )
    explicit_cfg = APIServerConfig(
        model_name="explicit-model",
        tokenizer_name="explicit-tokenizer",
        base_url="http://explicit/v1",
        api_key="x",
        server_type="vllm",
    )
    config = SimpleNamespace(
        teacher_enabled=True,
        teacher_top_k=0,
    )
    env = _ConcreteTeacherEnv(
        config=config,
        server_configs=[],
        teacher_server_configs=[explicit_cfg],
    )
    assert isinstance(env.teacher_server, _CapturingServerManager)
    assert env.teacher_server.configs == [explicit_cfg]
    assert called["teacher_tokenizer_name"] == "explicit-tokenizer"


def test_init_wraps_bare_teacher_api_server_config(monkeypatch):
    from atroposlib.envs import teacher_distillation_env as module

    called = {}

    def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
        self.config = config
        self.tokenizer = _DummyTokenizer()

    def _fake_validate(self, teacher_tokenizer_name):
        called["teacher_tokenizer_name"] = teacher_tokenizer_name

    monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
    monkeypatch.setattr(module, "ServerManager", _CapturingServerManager)
    monkeypatch.setattr(
        _ConcreteTeacherEnv,
        "_validate_teacher_tokenizer_compatibility",
        _fake_validate,
    )
    explicit_cfg = APIServerConfig(
        model_name="explicit-model",
        tokenizer_name="explicit-tokenizer",
        base_url="http://explicit/v1",
        api_key="x",
        server_type="vllm",
    )
    config = SimpleNamespace(
        teacher_enabled=True,
        teacher_top_k=0,
    )
    env = _ConcreteTeacherEnv(
        config=config,
        server_configs=[],
        teacher_server_configs=explicit_cfg,
    )
    assert isinstance(env.teacher_server, _CapturingServerManager)
    assert env.teacher_server.configs == [explicit_cfg]
    assert called["teacher_tokenizer_name"] == "explicit-tokenizer"
def test_resolve_teacher_server_configs_returns_none_when_unset():
    assert (
        _ConcreteTeacherEnv._resolve_teacher_server_configs(
            default_teacher_server_configs=None,
            yaml_config={},
            cli_passed_flags={},
        )
        is None
    )


def test_resolve_teacher_server_configs_uses_teacher_namespace(monkeypatch):
    from atroposlib.envs import teacher_distillation_env as module

    captured = {}

    def _fake_resolve(**kwargs):
        captured.update(kwargs)
        return ["resolved"]

    monkeypatch.setattr(module, "resolve_openai_configs", _fake_resolve)
    default_cfg = APIServerConfig(
        model_name="teacher-model",
        base_url="http://teacher/v1",
        api_key="x",
        server_type="vllm",
    )
    out = _ConcreteTeacherEnv._resolve_teacher_server_configs(
        default_teacher_server_configs=default_cfg,
        yaml_config={"teacher": {"tokenizer_name": "teacher-tokenizer"}},
        cli_passed_flags={"teacher.base_url": "http://override/v1"},
    )
    assert out == ["resolved"]
    assert captured["openai_config_dict"]["base_url"] == "http://override/v1"
    assert captured["openai_config_dict"]["tokenizer_name"] == "teacher-tokenizer"
    assert captured["yaml_config"] == {
        "openai": {"tokenizer_name": "teacher-tokenizer"}
    }
    assert captured["cli_passed_flags"] == {"openai.base_url": "http://override/v1"}


@@ -0,0 +1,58 @@
from typing import Tuple

from atroposlib.envs.base import APIServerConfig
from atroposlib.envs.teacher_distillation_env import (
    TeacherDistillationConfig,
    TeacherDistillationEnv,
)
from environments.gsm8k_server import GSM8kEnv


class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv):
    """
    GSM8K environment variant that enables TeacherDistillationEnv config fields.

    This preserves the original `gsm8k_server.py` while providing a separate entrypoint
    for teacher-distillation data collection.
    """

    name = "gsm8k_teacher_distill"
    env_config_cls = TeacherDistillationConfig

    @classmethod
    def config_init(cls) -> Tuple[TeacherDistillationConfig, APIServerConfig]:
        env_config = TeacherDistillationConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=2048,
            wandb_name="gsm8k_teacher_distill",
            teacher_enabled=True,
            teacher_top_k=4,
        )
        server_config = APIServerConfig(
            model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            base_url="http://localhost:9001/v1",
            api_key="x",
            num_requests_for_eval=256,
        )
        return env_config, server_config

    @classmethod
    def teacher_config_init(cls) -> APIServerConfig:
        return APIServerConfig(
            base_url="http://localhost:9003/v1",
            model_name="mock-teacher",
            api_key="",
            server_type="vllm",
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            timeout=1200,
        )


if __name__ == "__main__":
    GSM8kTeacherDistillEnv.cli()


@@ -304,6 +304,38 @@ environment uses the `/generate` path and includes token-level
4. Trainer extracts and aligns logprobs with training labels
5. GRPO loss uses these rollout logprobs in importance-ratio terms
### 1b. Teacher distillation requires the same tokenizer

When distillation data is attached to Atropos batches, the trainer treats
`distill_token_ids` as indices into the student's logit tensor. That only works
if the teacher and student share the same tokenizer vocabulary.

What to configure on the environment side:

```bash
--env.teacher_enabled true \
--teacher.base_url "http://localhost:9003/v1" \
--teacher.model_name "$TEACHER_MODEL" \
--teacher.server_type vllm \
--env.teacher_top_k 8
```

If `$TEACHER_MODEL` is a deployment alias instead of a tokenizer identifier,
also set `--teacher.tokenizer_name ...` so the env can validate
tokenizer compatibility.

The teacher-aware CLI path is currently wired for `serve`. If
`teacher_enabled=True`, the generic `process` and `evaluate` commands are not
teacher-aware and will fail loudly unless the environment is instantiated
manually with `teacher_server_configs=...`.

Why cross-tokenizer conversion is not acceptable here:

- Teacher token ID `1234` and student token ID `1234` can correspond to different text.
- Re-tokenizing teacher text changes token boundaries, so teacher position `i` may no longer correspond to student position `i`.
- Remapping teacher top-k tokens back into student vocab can collapse multiple teacher candidates into one student token or expand one teacher token into multiple student tokens.
- The current distillation loss expects exact per-position supervision in student token space, so an approximate remapping would silently produce misleading targets.

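
What "indices into the student's logit tensor" means at a single position can be sketched in pure Python. The function name and exact loss form here are illustrative, not the repo's trainer code; the sketch assumes the teacher's truncated top-k mass is renormalized into a distribution:

```python
import math


def _log_softmax(logits):
    # Stable log-softmax over a plain list of floats.
    m = max(logits)
    z = sum(math.exp(x - m) for x in logits)
    return [x - m - math.log(z) for x in logits]


def distill_loss_position(student_logits, topk_ids, topk_teacher_logprobs):
    """Cross-entropy of the student against teacher top-k at one position.

    topk_ids index directly into student_logits -- valid only because the
    teacher and student share a tokenizer vocabulary.
    """
    student_lp = _log_softmax(student_logits)
    # Renormalize the teacher's truncated top-k mass into a distribution.
    weights = [math.exp(lp) for lp in topk_teacher_logprobs]
    total = sum(weights)
    return -sum(
        (w / total) * student_lp[tid] for w, tid in zip(weights, topk_ids)
    )


# One position over a 5-token toy vocab, teacher top-2 supervision:
loss = distill_loss_position(
    student_logits=[2.0, 0.5, 0.1, -1.0, 0.0],
    topk_ids=[0, 1],
    topk_teacher_logprobs=[-0.2, -1.8],
)
print(round(loss, 3))  # → 0.695
```

With a mismatched vocabulary, `student_lp[tid]` would score the wrong token entirely, which is exactly the silent-corruption failure mode described above.
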
### 2. Clipping
```bash


@@ -248,8 +248,7 @@ run_shared_vllm() {
--port "$vllm_port" \
--gpu-memory-utilization "$SHARED_GPU_MEMORY_UTILIZATION" \
--max-model-len "$MAX_MODEL_LEN" \
--dtype "$DTYPE" \
--enforce-eager
--dtype "$DTYPE"
if [[ "$DRY_RUN" == "1" ]]; then
log "[DRY RUN] wait for http://localhost:${vllm_port}/health"
else


@@ -0,0 +1,313 @@
#!/usr/bin/env bash
set -euo pipefail
# Single-terminal teacher-distillation runner.
# Starts everything in the background from ONE shell that has GPU access:
# 1) Atropos API
# 2) Student vLLM server
# 3) Teacher vLLM server
# 4) GSM8K teacher-distill environment
# 5) Example trainer (foreground)
#
# Usage:
# chmod +x example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
# ./example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
#
# Optional overrides:
# STUDENT_MODEL="Qwen/Qwen3-4B-Instruct-2507-FP8"
# TEACHER_MODEL="Qwen/Qwen3-30B-A3B-Instruct-2507"
# STUDENT_GPUS="0"
# TEACHER_GPUS="4,5,6,7"
# TRAINER_GPUS="0"
# STUDENT_TP=1
# TEACHER_TP=4
# API_PORT=8002
# STUDENT_PORT=9001
# TEACHER_PORT=9003
# TRAINING_STEPS=100
# DISTILL_COEF=0.2
# DISTILL_TEMPERATURE=1.0
# TEACHER_TOP_K=8
# DRY_RUN=1
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
LAUNCH_DIR="$PWD"
cd "$ROOT_DIR"
PYTHON_BIN="${PYTHON_BIN:-python3}"
STUDENT_MODEL="${STUDENT_MODEL:-Qwen/Qwen3-4B}"
TEACHER_MODEL="${TEACHER_MODEL:-Qwen/Qwen3-30B-A3B-Instruct-2507}"
STUDENT_GPUS="${STUDENT_GPUS:-0}"
TEACHER_GPUS="${TEACHER_GPUS:-4,5,6,7}"
TRAINER_GPUS="${TRAINER_GPUS:-$STUDENT_GPUS}"
STUDENT_TP="${STUDENT_TP:-1}"
TEACHER_TP="${TEACHER_TP:-4}"
API_PORT="${API_PORT:-8002}"
STUDENT_PORT="${STUDENT_PORT:-9001}"
TEACHER_PORT="${TEACHER_PORT:-9003}"
TRAINING_STEPS="${TRAINING_STEPS:-20}"
BATCH_SIZE="${BATCH_SIZE:-1}"
GRAD_ACCUM="${GRAD_ACCUM:-4}"
LR="${LR:-1e-5}"
WARMUP_STEPS="${WARMUP_STEPS:-0}"
CLIP_EPS="${CLIP_EPS:-0.2}"
MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}"
TEACHER_MAX_MODEL_LEN="${TEACHER_MAX_MODEL_LEN:-32768}"
# Trainer seq_len must be larger than ENV_MAX_TOKEN_LENGTH to accommodate
# chat template overhead (~400-800 tokens for Qwen3 thinking format).
TRAINER_SEQ_LEN="${TRAINER_SEQ_LEN:-20480}"
ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-16384}"
DISTILL_COEF="${DISTILL_COEF:-0.2}"
DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}"
TEACHER_TOP_K="${TEACHER_TOP_K:-8}"
WANDB_PROJECT="${WANDB_PROJECT:-gsm8k-teacher-distill}"
WANDB_GROUP="${WANDB_GROUP:-}"
STUDENT_GPU_MEMORY_UTILIZATION="${STUDENT_GPU_MEMORY_UTILIZATION:-0.60}"
TEACHER_GPU_MEMORY_UTILIZATION="${TEACHER_GPU_MEMORY_UTILIZATION:-0.85}"
DTYPE="${DTYPE:-bfloat16}"
SAVE_DIR="${SAVE_DIR:-${LAUNCH_DIR}/saves/gsm8k_teacher_distill}"
LOG_DIR="${LOG_DIR:-${LAUNCH_DIR}/logs/gsm8k_teacher_distill}"
BRIDGE_DIR="${BRIDGE_DIR:-${LOG_DIR}/bridge}"
DRY_RUN="${DRY_RUN:-0}"
ENV_GROUP_SIZE="${ENV_GROUP_SIZE:-4}"
ENV_BATCH_SIZE="${ENV_BATCH_SIZE:-8}"
ENV_TOTAL_STEPS="${ENV_TOTAL_STEPS:-200}"
ENV_STEPS_PER_EVAL="${ENV_STEPS_PER_EVAL:-50}"
ENV_MAX_WORKERS_PER_NODE="${ENV_MAX_WORKERS_PER_NODE:-1}"
ENV_WORKER_TIMEOUT="${ENV_WORKER_TIMEOUT:-1800}"
RUN_PIDS=()
RUN_PORTS=()
log() {
echo "[$(date '+%H:%M:%S')] $*"
}
kill_port() {
local port="$1"
if [[ "$DRY_RUN" == "1" ]]; then
log "[DRY RUN] skip port cleanup for :${port}"
return 0
fi
if lsof -i ":${port}" -sTCP:LISTEN >/dev/null 2>&1; then
lsof -ti ":${port}" | xargs -r kill -9 || true
fi
}
wait_for_http() {
local url="$1"
local timeout="${2:-240}"
local name="${3:-endpoint}"
local start
start="$(date +%s)"
while true; do
if curl -fsS "$url" >/dev/null 2>&1; then
log "Ready: ${name} (${url})"
return 0
fi
if (( "$(date +%s)" - start > timeout )); then
log "Timeout waiting for ${name}: ${url}"
return 1
fi
sleep 2
done
}
start_process() {
local name="$1"
local logfile="$2"
shift 2
if [[ "$DRY_RUN" == "1" ]]; then
log "[DRY RUN] start ${name} (log: ${logfile})"
printf ' '
printf '%q ' "$@"
printf '\n'
return 0
fi
log "Starting ${name} (log: ${logfile})"
"$@" >"$logfile" 2>&1 &
local pid=$!
RUN_PIDS+=("$pid")
log "${name} PID=${pid}"
}
cleanup_all() {
log "Cleaning up processes..."
for pid in "${RUN_PIDS[@]:-}"; do
kill "$pid" >/dev/null 2>&1 || true
done
sleep 1
for pid in "${RUN_PIDS[@]:-}"; do
kill -9 "$pid" >/dev/null 2>&1 || true
done
for port in "${RUN_PORTS[@]:-}"; do
kill_port "$port"
done
}
trap cleanup_all EXIT INT TERM
mkdir -p "$LOG_DIR" "$SAVE_DIR" "$BRIDGE_DIR"
RUN_PORTS+=("$API_PORT" "$STUDENT_PORT" "$TEACHER_PORT")
kill_port "$API_PORT"
kill_port "$STUDENT_PORT"
kill_port "$TEACHER_PORT"
log "Config:"
log " student=${STUDENT_MODEL}"
log " teacher=${TEACHER_MODEL}"
log " gpus student=${STUDENT_GPUS}, teacher=${TEACHER_GPUS}, trainer=${TRAINER_GPUS}"
log " ports api=${API_PORT}, student=${STUDENT_PORT}, teacher=${TEACHER_PORT}"
log " logs=${LOG_DIR}"
log " saves=${SAVE_DIR}"
log " bridge=${BRIDGE_DIR}"
log " env max_token_length=${ENV_MAX_TOKEN_LENGTH}, env workers=${ENV_MAX_WORKERS_PER_NODE}, env worker_timeout=${ENV_WORKER_TIMEOUT}"
log " wandb project=${WANDB_PROJECT}${WANDB_GROUP:+, group=${WANDB_GROUP}}"
# Shared-vLLM attach path currently expects the student server to expose
# unsharded weights. Keep the student on TP=1 and the trainer on the same GPU set.
if [[ "$STUDENT_TP" != "1" ]]; then
log "ERROR: shared_vllm teacher-distill runner currently requires STUDENT_TP=1."
log " The current attach path does not support TP-sharded student bridge weights."
exit 2
fi
if [[ "$TRAINER_GPUS" != "$STUDENT_GPUS" ]]; then
log "ERROR: TRAINER_GPUS must match STUDENT_GPUS for shared_vllm mode."
log " Got student=${STUDENT_GPUS}, trainer=${TRAINER_GPUS}"
exit 2
fi
# 1) Atropos API
start_process "run_api" "${LOG_DIR}/run_api.log" \
run-api --port "$API_PORT"
if [[ "$DRY_RUN" == "0" ]]; then
wait_for_http "http://localhost:${API_PORT}/info" 180 "run-api"
fi
# 2) Student vLLM server
start_process "student_vllm" "${LOG_DIR}/student_vllm.log" \
env CUDA_VISIBLE_DEVICES="$STUDENT_GPUS" VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR="$BRIDGE_DIR" \
"$PYTHON_BIN" -m example_trainer.vllm_api_server \
--model "$STUDENT_MODEL" \
--port "$STUDENT_PORT" \
--tensor-parallel-size "$STUDENT_TP" \
--gpu-memory-utilization "$STUDENT_GPU_MEMORY_UTILIZATION" \
--max-model-len "$MAX_MODEL_LEN" \
--dtype "$DTYPE"
if [[ "$DRY_RUN" == "0" ]]; then
wait_for_http "http://localhost:${STUDENT_PORT}/health" 420 "student vLLM"
fi
# 3) Teacher vLLM server
start_process "teacher_vllm" "${LOG_DIR}/teacher_vllm.log" \
env CUDA_VISIBLE_DEVICES="$TEACHER_GPUS" \
"$PYTHON_BIN" -m example_trainer.vllm_api_server \
--model "$TEACHER_MODEL" \
--port "$TEACHER_PORT" \
--tensor-parallel-size "$TEACHER_TP" \
--gpu-memory-utilization "$TEACHER_GPU_MEMORY_UTILIZATION" \
--max-model-len "$TEACHER_MAX_MODEL_LEN" \
--dtype "$DTYPE"
if [[ "$DRY_RUN" == "0" ]]; then
wait_for_http "http://localhost:${TEACHER_PORT}/health" 1800 "teacher vLLM"
fi
# 4) Teacher-distill GSM8K env
start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \
"$PYTHON_BIN" environments/gsm8k_server_teacher_distill.py serve \
--env.group_size "$ENV_GROUP_SIZE" \
--env.batch_size "$ENV_BATCH_SIZE" \
--env.total_steps "$ENV_TOTAL_STEPS" \
--env.steps_per_eval "$ENV_STEPS_PER_EVAL" \
--env.max_num_workers_per_node "$ENV_MAX_WORKERS_PER_NODE" \
--env.max_token_length "$ENV_MAX_TOKEN_LENGTH" \
--env.worker_timeout "$ENV_WORKER_TIMEOUT" \
--env.rollout_server_url "http://localhost:${API_PORT}" \
--env.use_wandb true \
--env.wandb_name "gsm8k-teacher-distill" \
--env.teacher_enabled true \
--teacher.base_url "http://localhost:${TEACHER_PORT}/v1" \
--teacher.model_name "$TEACHER_MODEL" \
--teacher.server_type vllm \
--env.teacher_top_k "$TEACHER_TOP_K" \
--env.ensure_scores_are_not_same false \
--openai.api_key "dummy" \
--openai.base_url "http://localhost:${STUDENT_PORT}/v1" \
--openai.model_name "$STUDENT_MODEL" \
--openai.tokenizer_name "$STUDENT_MODEL" \
--openai.server_type vllm
log "All services launched."
log "Run logs:"
log " ${LOG_DIR}/run_api.log"
log " ${LOG_DIR}/student_vllm.log"
log " ${LOG_DIR}/teacher_vllm.log"
log " ${LOG_DIR}/env.log"
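# Optional wait for the shared-weights bridge config before launching the
# trainer (assumption: the student vLLM server writes vllm_bridge_config.json
# into LOGDIR when VLLM_ENABLE_SHARED_WEIGHTS=1; adjust the filename/path if
# your server writes it elsewhere).
if [[ "$DRY_RUN" == "0" ]]; then
  for _ in $(seq 1 60); do
    [[ -f "${BRIDGE_DIR}/vllm_bridge_config.json" ]] && break
    sleep 2
  done
  if [[ ! -f "${BRIDGE_DIR}/vllm_bridge_config.json" ]]; then
    log "WARNING: ${BRIDGE_DIR}/vllm_bridge_config.json not found; trainer may fail to attach."
  fi
fi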
# 5) Trainer (background)
TRAINER_CMD=(
env
CUDA_VISIBLE_DEVICES="$TRAINER_GPUS"
PYTHONUNBUFFERED=1
"$PYTHON_BIN"
-u
-m
example_trainer.grpo
--model-name "$STUDENT_MODEL"
--weight-bridge-mode shared_vllm
--device cuda:0
--save-path "$SAVE_DIR"
--atropos-url "http://localhost:${API_PORT}"
--vllm-port "$STUDENT_PORT"
--vllm-config-path "${BRIDGE_DIR}/vllm_bridge_config.json"
--training-steps "$TRAINING_STEPS"
--batch-size "$BATCH_SIZE"
--gradient-accumulation-steps "$GRAD_ACCUM"
--warmup-steps "$WARMUP_STEPS"
--lr "$LR"
--clip-eps "$CLIP_EPS"
--seq-len "$TRAINER_SEQ_LEN"
--distill-enabled
--distill-coef "$DISTILL_COEF"
--distill-temperature "$DISTILL_TEMPERATURE"
--use-wandb
--wandb-project "$WANDB_PROJECT"
)
if [[ -n "$WANDB_GROUP" ]]; then
TRAINER_CMD+=(--wandb-group "$WANDB_GROUP")
fi
if [[ "$DRY_RUN" == "1" ]]; then
log "[DRY RUN] trainer command:"
printf ' '
printf '%q ' "${TRAINER_CMD[@]}"
printf '\n'
exit 0
fi
start_process "trainer" "${LOG_DIR}/trainer.log" "${TRAINER_CMD[@]}"
log "All processes running in background."
log ""
log "Monitor with:"
log " tail -f ${LOG_DIR}/trainer.log"
log " tail -f ${LOG_DIR}/env.log"
log " tail -f ${LOG_DIR}/student_vllm.log"
log " tail -f ${LOG_DIR}/teacher_vllm.log"
log ""
log "Test endpoints:"
log " curl -s http://localhost:${STUDENT_PORT}/health"
log " curl -s http://localhost:${TEACHER_PORT}/health"
log " curl -s http://localhost:${STUDENT_PORT}/bridge/is_paused | jq ."
log ""
log "To stop all processes:"
log " kill ${RUN_PIDS[*]:-} 2>/dev/null; sleep 1; kill -9 ${RUN_PIDS[*]:-} 2>/dev/null"
trap - EXIT INT TERM