structural changes

This commit is contained in:
Jai Suphavadeeprasit 2026-03-13 18:49:01 -04:00
parent 12ba3cc3bd
commit a171358f2e
6 changed files with 422 additions and 50 deletions

View file

@ -308,30 +308,43 @@ Teacher config shape:
```python
TeacherDistillationConfig(
teacher_enabled=True,
teacher_server=APIServerConfig(
base_url="http://localhost:9003/v1",
model_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
api_key="",
server_type="vllm",
),
teacher_top_k=8,
)
```
If `teacher_server.model_name` is a deployment alias rather than a tokenizer
identifier, set `teacher_server.tokenizer_name` explicitly so the env can
validate tokenizer compatibility.
Teacher server configs are passed separately at init, just like the primary
`server_configs`:
```python
env = MyTeacherEnv(
config=env_config,
server_configs=student_server_configs,
teacher_server_configs=[
APIServerConfig(
base_url="http://localhost:9003/v1",
model_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
api_key="",
server_type="vllm",
tokenizer_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
)
],
)
```
CLI shape:
```bash
--env.teacher_enabled true \
--env.teacher_server.base_url "http://localhost:9003/v1" \
--env.teacher_server.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \
--env.teacher_server.server_type vllm \
--teacher.base_url "http://localhost:9003/v1" \
--teacher.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \
--teacher.server_type vllm \
--env.teacher_top_k 8
```
If `--teacher.model_name` is a deployment alias rather than a tokenizer
identifier, also set `--teacher.tokenizer_name ...` so the env can validate
tokenizer compatibility.
Tokenizer requirement:
- Teacher distillation currently requires the teacher and student to use the same tokenizer vocabulary.

View file

@ -15,13 +15,24 @@ from __future__ import annotations
import asyncio
import logging
from abc import ABC
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import yaml
from pydantic import Field
from pydantic_cli import Cmd
from rich import print as rprint
from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
from .constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
from .server_handling.openai_server import resolve_openai_configs
from .server_handling.server_baseline import APIServerConfig, ServerBaseline
from .server_handling.server_manager import ServerManager
from .server_handling.server_manager import ServerManager, ServerManagerConfig
from ..utils.cli import (
extract_namespace,
get_double_dash_flags,
get_prefixed_pydantic_model,
merge_dicts,
)
logger = logging.getLogger(__name__)
@ -31,10 +42,6 @@ class TeacherDistillationConfig(BaseEnvConfig):
default=False,
description="Whether to fetch teacher prompt logprobs for distillation.",
)
teacher_server: Optional[APIServerConfig] = Field(
default=None,
description="Fallback teacher server configuration when not provided at init.",
)
teacher_top_k: int = Field(
default=0,
ge=-1,
@ -56,6 +63,220 @@ class TeacherDistillationEnv(BaseEnv, ABC):
"""
env_config_cls = TeacherDistillationConfig
teacher_namespace = "teacher"
@classmethod
def teacher_config_init(
    cls,
) -> Optional[Union[ServerBaseline, List[APIServerConfig], APIServerConfig]]:
    """Return the default teacher server configuration for this environment.

    Hook for subclasses to supply a baseline teacher server config (a single
    ``APIServerConfig``, a list of them, or a ``ServerBaseline``). The base
    implementation returns ``None``, meaning no default: teacher servers must
    then come from YAML/CLI (the ``teacher.*`` namespace) or be passed
    explicitly as ``teacher_server_configs`` at init.
    """
    return None
@classmethod
def _resolve_teacher_server_configs(
    cls,
    default_teacher_server_configs: Optional[
        Union[ServerBaseline, List[APIServerConfig], APIServerConfig]
    ],
    yaml_config: Dict[str, Any],
    cli_passed_flags: Dict[str, Any],
) -> Optional[Union[ServerBaseline, List[APIServerConfig]]]:
    """Resolve teacher server configs from defaults, YAML, and CLI flags.

    Mirrors the primary OpenAI-server resolution, but sources settings from
    the ``teacher`` namespace (``cls.teacher_namespace``) instead of
    ``openai``. Precedence is defaults < YAML < CLI (later sources win, per
    ``merge_dicts``).

    Args:
        default_teacher_server_configs: Code-level default from
            ``teacher_config_init()``; may be None.
        yaml_config: Full parsed YAML config; only the ``teacher`` key is read.
        cli_passed_flags: Flat ``--a.b`` flag dict; only ``teacher.*`` keys
            are read.

    Returns:
        Resolved configs from ``resolve_openai_configs``, or ``None`` when no
        teacher configuration was provided from any source.
    """
    teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}"
    teacher_cli_passed_args = extract_namespace(cli_passed_flags, teacher_full_prefix)
    yaml_teacher_config = yaml_config.get(cls.teacher_namespace, {})
    # No source provided teacher settings at all -> teacher stays disabled.
    if (
        default_teacher_server_configs is None
        and not teacher_cli_passed_args
        and not yaml_teacher_config
    ):
        return None
    effective_teacher_server_configs = default_teacher_server_configs
    if effective_teacher_server_configs is None:
        # YAML/CLI provided overrides but no code default: start from a blank config.
        effective_teacher_server_configs = APIServerConfig()
    elif isinstance(effective_teacher_server_configs, ServerBaseline) and (
        teacher_cli_passed_args or yaml_teacher_config
    ):
        # Overrides target APIServerConfig fields, so lift the baseline into one.
        effective_teacher_server_configs = APIServerConfig(
            **effective_teacher_server_configs.model_dump()
        )
    # A single-element list is treated like a scalar config for merging purposes.
    if (
        isinstance(effective_teacher_server_configs, list)
        and len(effective_teacher_server_configs) == 1
    ):
        default_teacher_config = effective_teacher_server_configs[0]
    else:
        default_teacher_config = effective_teacher_server_configs
    if isinstance(yaml_teacher_config, list) and len(yaml_teacher_config) == 1:
        yaml_teacher_config = yaml_teacher_config[0]
    if isinstance(default_teacher_config, APIServerConfig) and isinstance(
        yaml_teacher_config, dict
    ):
        # Scalar default + dict YAML: merge field-wise, CLI winning last.
        teacher_config_dict = merge_dicts(
            default_teacher_config.model_dump(),
            yaml_teacher_config,
            teacher_cli_passed_args,
        )
    else:
        # Multi-server or non-dict shapes: defer merging to resolve_openai_configs.
        teacher_config_dict = {}
    # resolve_openai_configs expects the "openai" namespace, so re-wrap the
    # teacher-namespace YAML/CLI values under that key before delegating.
    teacher_yaml_wrapped = {OPENAI_NAMESPACE: yaml_teacher_config}
    teacher_cli_wrapped = {
        f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}{key}": value
        for key, value in teacher_cli_passed_args.items()
    }
    return resolve_openai_configs(
        default_server_configs=effective_teacher_server_configs,
        openai_config_dict=teacher_config_dict,
        yaml_config=teacher_yaml_wrapped,
        cli_passed_flags=teacher_cli_wrapped,
        logger=logger,
    )
@classmethod
def get_cli_serve_config_cls(cls) -> type:
    """Build the pydantic-cli ``serve`` command class for this env.

    Extends the standard CLI surface with a ``teacher.*`` flag namespace
    (mirroring ``openai.*``) so a teacher server can be configured from the
    command line or YAML. Returns the generated ``Cmd`` subclass.
    """
    default_env_config, default_server_configs = cls.config_init()
    default_teacher_server_configs = cls.teacher_config_init()
    env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
    openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
    teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}"
    # Adds --teacher.* flags (APIServerConfig fields) to the CLI model below.
    teacher_cli_base = get_prefixed_pydantic_model(
        APIServerConfig, teacher_full_prefix
    )

    class CliServeConfig(
        get_prefixed_pydantic_model(type(default_env_config), env_full_prefix),
        get_prefixed_pydantic_model(APIServerConfig, openai_full_prefix),
        teacher_cli_base,
        ServerManagerConfig,
        Cmd,
    ):
        # Optional YAML config file; CLI flags take precedence over its values.
        config: str | None = Field(
            default=None,
            description="Path to .yaml config file. CLI args override this.",
        )

        def run(self) -> None:
            """Resolve env/server/teacher configs and run the environment."""
            # Default the wandb run name to the env's class-level name.
            wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
            if (
                getattr(self, wandb_name_attr, None) is None
                and cls.name is not None
            ):
                setattr(self, wandb_name_attr, cls.name)
            if self.config is not None:
                with open(self.config, "r") as f:
                    yaml_config = yaml.safe_load(f)
                logger.info("Loaded config from %s", self.config)
            else:
                yaml_config = {}
            cli_passed_flags = get_double_dash_flags()
            # Env config precedence: code defaults < YAML < CLI flags.
            env_config_dict = merge_dicts(
                default_env_config.model_dump(),
                yaml_config.get(ENV_NAMESPACE, {}),
                extract_namespace(cli_passed_flags, env_full_prefix),
            )
            oai_cli_passed_args = extract_namespace(
                cli_passed_flags, openai_full_prefix
            )
            yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
            effective_server_configs = default_server_configs
            # Overrides target APIServerConfig fields; lift a ServerBaseline
            # into one only when overrides were actually given.
            if isinstance(effective_server_configs, ServerBaseline) and (
                oai_cli_passed_args or yaml_oai_config
            ):
                effective_server_configs = APIServerConfig(
                    **effective_server_configs.model_dump()
                )
            # Single-element lists are merged like a scalar config.
            if (
                isinstance(effective_server_configs, list)
                and len(effective_server_configs) == 1
            ):
                default_openai_config_ = effective_server_configs[0]
            else:
                default_openai_config_ = effective_server_configs
            if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
                yaml_oai_config = yaml_oai_config[0]
            if isinstance(default_openai_config_, APIServerConfig) and isinstance(
                yaml_oai_config, dict
            ):
                openai_config_dict = merge_dicts(
                    default_openai_config_.model_dump(),
                    yaml_oai_config,
                    oai_cli_passed_args,
                )
            else:
                # Multi-server / non-dict shapes are resolved downstream.
                openai_config_dict = {}
            # slurm/testing are top-level (unnamespaced) flags and YAML keys.
            server_manager_cli_passed_flags = {}
            if "slurm" in cli_passed_flags:
                server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
            if "testing" in cli_passed_flags:
                server_manager_cli_passed_flags["testing"] = cli_passed_flags[
                    "testing"
                ]
            server_manager_yaml_dict = {}
            if "slurm" in yaml_config:
                server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
            if "testing" in yaml_config:
                server_manager_yaml_dict["testing"] = yaml_config["testing"]
            server_manager_config_dict = merge_dicts(
                ServerManagerConfig().model_dump(),
                server_manager_yaml_dict,
                server_manager_cli_passed_flags,
            )
            env_config = type(default_env_config)(**env_config_dict)
            server_manager_config = ServerManagerConfig(
                **server_manager_config_dict
            )
            openai_configs = resolve_openai_configs(
                default_server_configs=effective_server_configs,
                openai_config_dict=openai_config_dict,
                yaml_config=yaml_config,
                cli_passed_flags=cli_passed_flags,
                logger=logger,
            )
            teacher_configs = cls._resolve_teacher_server_configs(
                default_teacher_server_configs=default_teacher_server_configs,
                yaml_config=yaml_config,
                cli_passed_flags=cli_passed_flags,
            )
            env_kwargs = {
                "config": env_config,
                "server_configs": openai_configs,
                "slurm": server_manager_config.slurm,
                "testing": server_manager_config.testing,
            }
            # Omit the kwarg entirely when no teacher was configured, so the
            # env's __init__ default applies.
            if teacher_configs is not None:
                env_kwargs["teacher_server_configs"] = teacher_configs
            env = cls(**env_kwargs)
            rprint(env_config)
            rprint(openai_configs)
            if teacher_configs is not None:
                rprint(teacher_configs)
            try:
                # NOTE(review): run_until_complete on an already-running loop
                # raises RuntimeError itself — confirm this branch is reachable
                # as intended; the except below handles the no-loop case.
                loop = asyncio.get_running_loop()
                task = loop.create_task(env.env_manager())
                loop.run_until_complete(task)
            except RuntimeError:
                asyncio.run(env.env_manager())

    return CliServeConfig
def __init__(
self,
@ -71,26 +292,11 @@ class TeacherDistillationEnv(BaseEnv, ABC):
self.teacher_server: Optional[ServerManager] = None
if config.teacher_enabled:
teacher_config_source = teacher_server_configs
if teacher_config_source is None and config.teacher_server is not None:
teacher_config_source = [
config.teacher_server.model_copy(
update={
"tokenizer_name": (
config.teacher_server.model_name
if config.teacher_server.tokenizer_name in ("", "none")
else config.teacher_server.tokenizer_name
),
"timeout": 1200,
}
)
]
if teacher_config_source is None:
if teacher_server_configs is None:
raise ValueError(
"teacher_enabled=True requires teacher_server_configs at init "
"or a fallback teacher_server config."
"teacher_enabled=True requires teacher_server_configs at init."
)
teacher_config_source = teacher_server_configs
self.teacher_server = ServerManager(
teacher_config_source,
slurm=False,

View file

@ -4,6 +4,7 @@ from types import SimpleNamespace
import pytest
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.teacher_distillation_env import TeacherDistillationEnv
@ -32,6 +33,20 @@ class _ConcreteTeacherEnv(TeacherDistillationEnv):
return None
class _DummyTokenizer:
name_or_path = "student-model"
def get_vocab(self):
return {"a": 1}
class _CapturingServerManager:
def __init__(self, configs, slurm=False, testing=False):
self.configs = configs
self.slurm = slurm
self.testing = testing
@pytest.mark.asyncio
async def test_attach_teacher_distillation_success():
env = object.__new__(_ConcreteTeacherEnv)
@ -105,6 +120,32 @@ async def test_attach_teacher_distillation_zero_topk_passthrough():
assert out["distill_logprobs"] is not None
@pytest.mark.asyncio
async def test_attach_teacher_distillation_group_override_topk_is_used():
    """A group's ``group_overrides.teacher_top_k`` must win over the env
    config's ``teacher_top_k`` (0 here) for every sequence in the group."""
    # Bypass __init__ so no real servers/tokenizers are constructed.
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)
    seen_topks = []

    async def _fake_fetch(seq, top_k):
        # Record the top_k the env asked for and echo trivial per-token data.
        seen_topks.append(top_k)
        return [[tok] for tok in seq], [[-0.1] for _ in seq]

    # Any truthy sentinel: the method only needs a teacher server to exist.
    env.teacher_server = object()
    env._fetch_teacher_for_sequence = _fake_fetch
    group = {
        "tokens": [[1, 2, 3], [4, 5]],
        "group_overrides": {"teacher_top_k": 7},
        "masks": [[-100, 2, 3], [-100, 5]],
        "scores": [1.0, 0.0],
    }
    out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
    # Override value 7 used for both sequences, not the config's 0.
    assert seen_topks == [7, 7]
    assert out["distill_token_ids"] is not None
    assert out["distill_logprobs"] is not None
@pytest.mark.asyncio
async def test_attach_teacher_distillation_group_override_can_skip_fetch():
env = object.__new__(_ConcreteTeacherEnv)
@ -149,3 +190,111 @@ def test_teacher_tokenizer_mismatch_raises(monkeypatch):
env,
teacher_tokenizer_name="teacher-model",
)
def test_init_requires_teacher_server_source(monkeypatch):
    """With teacher_enabled=True and no teacher_server_configs, __init__
    must raise a ValueError."""
    from atroposlib.envs import teacher_distillation_env as module

    def _stub_base_init(self, config, server_configs, slurm=False, testing=False):
        # Skip real BaseEnv setup; provide only what the teacher path reads.
        self.config = config
        self.tokenizer = _DummyTokenizer()

    monkeypatch.setattr(module.BaseEnv, "__init__", _stub_base_init)

    teacher_cfg = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)
    with pytest.raises(ValueError, match="teacher_enabled=True requires"):
        _ConcreteTeacherEnv(config=teacher_cfg, server_configs=[])
def test_init_uses_explicit_teacher_server_configs(monkeypatch):
    """Explicit ``teacher_server_configs`` at init are passed straight to the
    ServerManager, and tokenizer validation sees the explicit tokenizer name."""
    from atroposlib.envs import teacher_distillation_env as module

    called = {}

    def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
        # Skip real BaseEnv setup; provide only what the teacher path reads.
        self.config = config
        self.tokenizer = _DummyTokenizer()

    def _fake_validate(self, teacher_tokenizer_name):
        # Capture the tokenizer name __init__ chose for validation.
        called["teacher_tokenizer_name"] = teacher_tokenizer_name

    monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
    # Swap in the capturing double so no real server manager is started.
    monkeypatch.setattr(module, "ServerManager", _CapturingServerManager)
    monkeypatch.setattr(
        _ConcreteTeacherEnv,
        "_validate_teacher_tokenizer_compatibility",
        _fake_validate,
    )
    explicit_cfg = APIServerConfig(
        model_name="explicit-model",
        tokenizer_name="explicit-tokenizer",
        base_url="http://explicit/v1",
        api_key="x",
        server_type="vllm",
    )
    config = SimpleNamespace(
        teacher_enabled=True,
        teacher_top_k=0,
    )
    env = _ConcreteTeacherEnv(
        config=config,
        server_configs=[],
        teacher_server_configs=[explicit_cfg],
    )
    # The explicit config list must reach the (captured) ServerManager untouched.
    assert isinstance(env.teacher_server, _CapturingServerManager)
    assert env.teacher_server.configs == [explicit_cfg]
    assert called["teacher_tokenizer_name"] == "explicit-tokenizer"
def test_resolve_teacher_server_configs_returns_none_when_unset():
    """No code default, no YAML, no CLI flags -> the resolver yields None."""
    resolved = _ConcreteTeacherEnv._resolve_teacher_server_configs(
        default_teacher_server_configs=None,
        yaml_config={},
        cli_passed_flags={},
    )
    assert resolved is None
def test_resolve_teacher_server_configs_uses_teacher_namespace(monkeypatch):
    """The resolver reads the ``teacher`` namespace, merges defaults < YAML <
    CLI, and re-wraps everything under ``openai`` before delegating to
    resolve_openai_configs."""
    from atroposlib.envs import teacher_distillation_env as module

    captured = {}

    def _fake_resolve(**kwargs):
        # Record exactly what the resolver forwards downstream.
        captured.update(kwargs)
        return ["resolved"]

    monkeypatch.setattr(module, "resolve_openai_configs", _fake_resolve)
    default_cfg = APIServerConfig(
        model_name="teacher-model",
        base_url="http://teacher/v1",
        api_key="x",
        server_type="vllm",
    )
    out = _ConcreteTeacherEnv._resolve_teacher_server_configs(
        default_teacher_server_configs=default_cfg,
        yaml_config={"teacher": {"tokenizer_name": "teacher-tokenizer"}},
        cli_passed_flags={"teacher.base_url": "http://override/v1"},
    )
    assert out == ["resolved"]
    # CLI value overrides the default base_url; YAML adds tokenizer_name.
    assert captured["openai_config_dict"]["base_url"] == "http://override/v1"
    assert captured["openai_config_dict"]["tokenizer_name"] == "teacher-tokenizer"
    # teacher-namespace inputs arrive re-keyed under the "openai" namespace.
    assert captured["yaml_config"] == {
        "openai": {"tokenizer_name": "teacher-tokenizer"}
    }
    assert captured["cli_passed_flags"] == {"openai.base_url": "http://override/v1"}

View file

@ -32,13 +32,6 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv):
max_token_length=2048,
wandb_name="gsm8k_teacher_distill",
teacher_enabled=True,
teacher_server=APIServerConfig(
base_url="http://localhost:8003/v1",
model_name="mock-teacher",
api_key="",
server_type="vllm",
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
),
teacher_top_k=4,
)
server_config = APIServerConfig(
@ -49,6 +42,17 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv):
)
return env_config, server_config
@classmethod
def teacher_config_init(cls) -> APIServerConfig:
    """Return the default teacher server config: a local mock vLLM endpoint
    with an explicit tokenizer name and a long request timeout."""
    teacher_settings = dict(
        base_url="http://localhost:9003/v1",
        model_name="mock-teacher",
        api_key="",
        server_type="vllm",
        tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
        timeout=1200,
    )
    return APIServerConfig(**teacher_settings)
if __name__ == "__main__":
    # Script entry point: delegate to the env's shared CLI handler.
    GSM8kTeacherDistillEnv.cli()

View file

@ -314,14 +314,14 @@ What to configure on the environment side:
```bash
--env.teacher_enabled true \
--env.teacher_server.base_url "http://localhost:9003/v1" \
--env.teacher_server.model_name "$TEACHER_MODEL" \
--env.teacher_server.server_type vllm \
--teacher.base_url "http://localhost:9003/v1" \
--teacher.model_name "$TEACHER_MODEL" \
--teacher.server_type vllm \
--env.teacher_top_k 8
```
If `$TEACHER_MODEL` is a deployment alias instead of a tokenizer identifier,
also set `--env.teacher_server.tokenizer_name ...` so the env can validate
also set `--teacher.tokenizer_name ...` so the env can validate
tokenizer compatibility.
Why cross-tokenizer conversion is not acceptable here:

View file

@ -234,9 +234,9 @@ start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \
--env.use_wandb true \
--env.wandb_name "gsm8k-teacher-distill" \
--env.teacher_enabled true \
--env.teacher_server.base_url "http://localhost:${TEACHER_PORT}/v1" \
--env.teacher_server.model_name "$TEACHER_MODEL" \
--env.teacher_server.server_type vllm \
--teacher.base_url "http://localhost:${TEACHER_PORT}/v1" \
--teacher.model_name "$TEACHER_MODEL" \
--teacher.server_type vllm \
--env.teacher_top_k "$TEACHER_TOP_K" \
--env.ensure_scores_are_not_same false \
--openai.api_key "dummy" \