diff --git a/README.md b/README.md index 076f886a..a0885cdd 100644 --- a/README.md +++ b/README.md @@ -308,30 +308,43 @@ Teacher config shape: ```python TeacherDistillationConfig( teacher_enabled=True, - teacher_server=APIServerConfig( - base_url="http://localhost:9003/v1", - model_name="Qwen/Qwen3-30B-A3B-Instruct-2507", - api_key="", - server_type="vllm", - ), teacher_top_k=8, ) ``` -If `teacher_server.model_name` is a deployment alias rather than a tokenizer -identifier, set `teacher_server.tokenizer_name` explicitly so the env can -validate tokenizer compatibility. +Teacher server configs are passed separately at init, just like the primary +`server_configs`: + +```python +env = MyTeacherEnv( + config=env_config, + server_configs=student_server_configs, + teacher_server_configs=[ + APIServerConfig( + base_url="http://localhost:9003/v1", + model_name="Qwen/Qwen3-30B-A3B-Instruct-2507", + api_key="", + server_type="vllm", + tokenizer_name="Qwen/Qwen3-30B-A3B-Instruct-2507", + ) + ], +) +``` CLI shape: ```bash --env.teacher_enabled true \ ---env.teacher_server.base_url "http://localhost:9003/v1" \ ---env.teacher_server.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \ ---env.teacher_server.server_type vllm \ +--teacher.base_url "http://localhost:9003/v1" \ +--teacher.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \ +--teacher.server_type vllm \ --env.teacher_top_k 8 ``` +If `--teacher.model_name` is a deployment alias rather than a tokenizer +identifier, also set `--teacher.tokenizer_name ...` so the env can validate +tokenizer compatibility. + Tokenizer requirement: - Teacher distillation currently requires the teacher and student to use the same tokenizer vocabulary. diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py index fdac54c5..1b3cda8f 100644 --- a/atroposlib/envs/teacher_distillation_env.py +++ b/atroposlib/envs/teacher_distillation_env.py @@ -15,13 +15,24 @@ from __future__ import annotations import asyncio import logging from abc import ABC -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union +import yaml from pydantic import Field +from pydantic_cli import Cmd +from rich import print as rprint from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup +from .constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE +from .server_handling.openai_server import resolve_openai_configs from .server_handling.server_baseline import APIServerConfig, ServerBaseline -from .server_handling.server_manager import ServerManager +from .server_handling.server_manager import ServerManager, ServerManagerConfig +from ..utils.cli import ( + extract_namespace, + get_double_dash_flags, + get_prefixed_pydantic_model, + merge_dicts, +) logger = logging.getLogger(__name__) @@ -31,10 +42,6 @@ class TeacherDistillationConfig(BaseEnvConfig): default=False, description="Whether to fetch teacher prompt logprobs for distillation.", ) - teacher_server: Optional[APIServerConfig] = Field( - default=None, - description="Fallback teacher server configuration when not provided at init.", - ) teacher_top_k: int = Field( default=0, ge=-1, @@ -56,6 +63,220 @@ class TeacherDistillationEnv(BaseEnv, ABC): """ env_config_cls = TeacherDistillationConfig + teacher_namespace = "teacher" + + @classmethod + def teacher_config_init( + cls, + ) -> Optional[Union[ServerBaseline, List[APIServerConfig], APIServerConfig]]: + return None + + @classmethod + def _resolve_teacher_server_configs( + cls, + default_teacher_server_configs: Optional[ + Union[ServerBaseline, List[APIServerConfig], APIServerConfig] + ], + yaml_config: Dict[str, Any], + cli_passed_flags: Dict[str, Any], + ) -> Optional[Union[ServerBaseline, List[APIServerConfig]]]: + teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}" + teacher_cli_passed_args = extract_namespace(cli_passed_flags, teacher_full_prefix) + yaml_teacher_config = yaml_config.get(cls.teacher_namespace, {}) + + if ( + default_teacher_server_configs is None + and not teacher_cli_passed_args + and not yaml_teacher_config + ): + return None + + effective_teacher_server_configs = default_teacher_server_configs + if effective_teacher_server_configs is None: + effective_teacher_server_configs = APIServerConfig() + elif isinstance(effective_teacher_server_configs, ServerBaseline) and ( + teacher_cli_passed_args or yaml_teacher_config + ): + effective_teacher_server_configs = APIServerConfig( + **effective_teacher_server_configs.model_dump() + ) + + if ( + isinstance(effective_teacher_server_configs, list) + and len(effective_teacher_server_configs) == 1 + ): + default_teacher_config = effective_teacher_server_configs[0] + else: + default_teacher_config = effective_teacher_server_configs + + if isinstance(yaml_teacher_config, list) and len(yaml_teacher_config) == 1: + yaml_teacher_config = yaml_teacher_config[0] + + if isinstance(default_teacher_config, APIServerConfig) and isinstance( + yaml_teacher_config, dict + ): + teacher_config_dict = merge_dicts( + default_teacher_config.model_dump(), + yaml_teacher_config, + teacher_cli_passed_args, + ) + else: + teacher_config_dict = {} + + teacher_yaml_wrapped = {OPENAI_NAMESPACE: yaml_teacher_config} + teacher_cli_wrapped = { + f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}{key}": value + for key, value in teacher_cli_passed_args.items() + } + return resolve_openai_configs( + default_server_configs=effective_teacher_server_configs, + openai_config_dict=teacher_config_dict, + yaml_config=teacher_yaml_wrapped, + cli_passed_flags=teacher_cli_wrapped, + logger=logger, + ) + + @classmethod + def get_cli_serve_config_cls(cls) -> type: + default_env_config, default_server_configs = cls.config_init() + default_teacher_server_configs = cls.teacher_config_init() + + env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}" + openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}" + teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}" + teacher_cli_base = get_prefixed_pydantic_model( + APIServerConfig, teacher_full_prefix + ) + + class CliServeConfig( + get_prefixed_pydantic_model(type(default_env_config), env_full_prefix), + get_prefixed_pydantic_model(APIServerConfig, openai_full_prefix), + teacher_cli_base, + ServerManagerConfig, + Cmd, + ): + config: str | None = Field( + default=None, + description="Path to .yaml config file. CLI args override this.", + ) + + def run(self) -> None: + wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name" + if ( + getattr(self, wandb_name_attr, None) is None + and cls.name is not None + ): + setattr(self, wandb_name_attr, cls.name) + + if self.config is not None: + with open(self.config, "r") as f: + yaml_config = yaml.safe_load(f) + logger.info("Loaded config from %s", self.config) + else: + yaml_config = {} + + cli_passed_flags = get_double_dash_flags() + + env_config_dict = merge_dicts( + default_env_config.model_dump(), + yaml_config.get(ENV_NAMESPACE, {}), + extract_namespace(cli_passed_flags, env_full_prefix), + ) + + oai_cli_passed_args = extract_namespace( + cli_passed_flags, openai_full_prefix + ) + yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {}) + + effective_server_configs = default_server_configs + if isinstance(effective_server_configs, ServerBaseline) and ( + oai_cli_passed_args or yaml_oai_config + ): + effective_server_configs = APIServerConfig( + **effective_server_configs.model_dump() + ) + + if ( + isinstance(effective_server_configs, list) + and len(effective_server_configs) == 1 + ): + default_openai_config_ = effective_server_configs[0] + else: + default_openai_config_ = effective_server_configs + + if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1: + yaml_oai_config = yaml_oai_config[0] + + if isinstance(default_openai_config_, APIServerConfig) and isinstance( + yaml_oai_config, dict + ): + openai_config_dict = merge_dicts( + default_openai_config_.model_dump(), + yaml_oai_config, + oai_cli_passed_args, + ) + else: + openai_config_dict = {} + + server_manager_cli_passed_flags = {} + if "slurm" in cli_passed_flags: + server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"] + if "testing" in cli_passed_flags: + server_manager_cli_passed_flags["testing"] = cli_passed_flags[ + "testing" + ] + + server_manager_yaml_dict = {} + if "slurm" in yaml_config: + server_manager_yaml_dict["slurm"] = yaml_config["slurm"] + if "testing" in yaml_config: + server_manager_yaml_dict["testing"] = yaml_config["testing"] + + server_manager_config_dict = merge_dicts( + ServerManagerConfig().model_dump(), + server_manager_yaml_dict, + server_manager_cli_passed_flags, + ) + + env_config = type(default_env_config)(**env_config_dict) + server_manager_config = ServerManagerConfig( + **server_manager_config_dict + ) + openai_configs = resolve_openai_configs( + default_server_configs=effective_server_configs, + openai_config_dict=openai_config_dict, + yaml_config=yaml_config, + cli_passed_flags=cli_passed_flags, + logger=logger, + ) + teacher_configs = cls._resolve_teacher_server_configs( + default_teacher_server_configs=default_teacher_server_configs, + yaml_config=yaml_config, + cli_passed_flags=cli_passed_flags, + ) + + env_kwargs = { + "config": env_config, + "server_configs": openai_configs, + "slurm": server_manager_config.slurm, + "testing": server_manager_config.testing, + } + if teacher_configs is not None: + env_kwargs["teacher_server_configs"] = teacher_configs + env = cls(**env_kwargs) + rprint(env_config) + rprint(openai_configs) + if teacher_configs is not None: + rprint(teacher_configs) + + try: + loop = asyncio.get_running_loop() + task = loop.create_task(env.env_manager()) + loop.run_until_complete(task) + except RuntimeError: + asyncio.run(env.env_manager()) + + return CliServeConfig def __init__( self, @@ -71,26 +292,11 @@ class TeacherDistillationEnv(BaseEnv, ABC): self.teacher_server: Optional[ServerManager] = None if config.teacher_enabled: - teacher_config_source = teacher_server_configs - if teacher_config_source is None and config.teacher_server is not None: - teacher_config_source = [ - config.teacher_server.model_copy( - update={ - "tokenizer_name": ( - config.teacher_server.model_name - if config.teacher_server.tokenizer_name in ("", "none") - else config.teacher_server.tokenizer_name - ), - "timeout": 1200, - } - ) - ] - - if teacher_config_source is None: + if teacher_server_configs is None: raise ValueError( - "teacher_enabled=True requires teacher_server_configs at init " - "or a fallback teacher_server config." + "teacher_enabled=True requires teacher_server_configs at init." ) + teacher_config_source = teacher_server_configs self.teacher_server = ServerManager( teacher_config_source, slurm=False, diff --git a/atroposlib/tests/test_teacher_distillation_env.py b/atroposlib/tests/test_teacher_distillation_env.py index 7c8cb439..c789670d 100644 --- a/atroposlib/tests/test_teacher_distillation_env.py +++ b/atroposlib/tests/test_teacher_distillation_env.py @@ -4,6 +4,7 @@ from types import SimpleNamespace import pytest +from atroposlib.envs.server_handling.server_baseline import APIServerConfig from atroposlib.envs.teacher_distillation_env import TeacherDistillationEnv @@ -32,6 +33,20 @@ class _ConcreteTeacherEnv(TeacherDistillationEnv): return None +class _DummyTokenizer: + name_or_path = "student-model" + + def get_vocab(self): + return {"a": 1} + + +class _CapturingServerManager: + def __init__(self, configs, slurm=False, testing=False): + self.configs = configs + self.slurm = slurm + self.testing = testing + + @pytest.mark.asyncio async def test_attach_teacher_distillation_success(): env = object.__new__(_ConcreteTeacherEnv) @@ -105,6 +120,32 @@ async def test_attach_teacher_distillation_zero_topk_passthrough(): assert out["distill_logprobs"] is not None +@pytest.mark.asyncio +async def test_attach_teacher_distillation_group_override_topk_is_used(): + env = object.__new__(_ConcreteTeacherEnv) + env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0) + + seen_topks = [] + + async def _fake_fetch(seq, top_k): + seen_topks.append(top_k) + return [[tok] for tok in seq], [[-0.1] for _ in seq] + + env.teacher_server = object() + env._fetch_teacher_for_sequence = _fake_fetch + + group = { + "tokens": [[1, 2, 3], [4, 5]], + "group_overrides": {"teacher_top_k": 7}, + "masks": [[-100, 2, 3], [-100, 5]], + "scores": [1.0, 0.0], + } + out = await TeacherDistillationEnv._attach_teacher_distillation(env, group) + assert seen_topks == [7, 7] + assert out["distill_token_ids"] is not None + assert out["distill_logprobs"] is not None + + @pytest.mark.asyncio async def test_attach_teacher_distillation_group_override_can_skip_fetch(): env = object.__new__(_ConcreteTeacherEnv) @@ -149,3 +190,111 @@ def test_teacher_tokenizer_mismatch_raises(monkeypatch): env, teacher_tokenizer_name="teacher-model", ) + + +def test_init_requires_teacher_server_source(monkeypatch): + from atroposlib.envs import teacher_distillation_env as module + + def _fake_base_init(self, config, server_configs, slurm=False, testing=False): + self.config = config + self.tokenizer = _DummyTokenizer() + + monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init) + + config = SimpleNamespace( + teacher_enabled=True, + teacher_top_k=0, + ) + with pytest.raises(ValueError, match="teacher_enabled=True requires"): + _ConcreteTeacherEnv( + config=config, + server_configs=[], + ) + + +def test_init_uses_explicit_teacher_server_configs(monkeypatch): + from atroposlib.envs import teacher_distillation_env as module + + called = {} + + def _fake_base_init(self, config, server_configs, slurm=False, testing=False): + self.config = config + self.tokenizer = _DummyTokenizer() + + def _fake_validate(self, teacher_tokenizer_name): + called["teacher_tokenizer_name"] = teacher_tokenizer_name + + monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init) + monkeypatch.setattr(module, "ServerManager", _CapturingServerManager) + monkeypatch.setattr( + _ConcreteTeacherEnv, + "_validate_teacher_tokenizer_compatibility", + _fake_validate, + ) + + explicit_cfg = APIServerConfig( + model_name="explicit-model", + tokenizer_name="explicit-tokenizer", + base_url="http://explicit/v1", + api_key="x", + server_type="vllm", + ) + config = SimpleNamespace( + teacher_enabled=True, + teacher_top_k=0, + ) + + env = _ConcreteTeacherEnv( + config=config, + server_configs=[], + teacher_server_configs=[explicit_cfg], + ) + + assert isinstance(env.teacher_server, _CapturingServerManager) + assert env.teacher_server.configs == [explicit_cfg] + assert called["teacher_tokenizer_name"] == "explicit-tokenizer" + + +def test_resolve_teacher_server_configs_returns_none_when_unset(): + assert ( + _ConcreteTeacherEnv._resolve_teacher_server_configs( + default_teacher_server_configs=None, + yaml_config={}, + cli_passed_flags={}, + ) + is None + ) + + +def test_resolve_teacher_server_configs_uses_teacher_namespace(monkeypatch): + from atroposlib.envs import teacher_distillation_env as module + + captured = {} + + def _fake_resolve(**kwargs): + captured.update(kwargs) + return ["resolved"] + + monkeypatch.setattr(module, "resolve_openai_configs", _fake_resolve) + + default_cfg = APIServerConfig( + model_name="teacher-model", + base_url="http://teacher/v1", + api_key="x", + server_type="vllm", + ) + + out = _ConcreteTeacherEnv._resolve_teacher_server_configs( + default_teacher_server_configs=default_cfg, + yaml_config={"teacher": {"tokenizer_name": "teacher-tokenizer"}}, + cli_passed_flags={"teacher.base_url": "http://override/v1"}, + ) + + assert out == ["resolved"] + assert captured["openai_config_dict"]["base_url"] == "http://override/v1" + assert captured["openai_config_dict"]["tokenizer_name"] == "teacher-tokenizer" + assert captured["yaml_config"] == { + "openai": {"tokenizer_name": "teacher-tokenizer"} + } + assert captured["cli_passed_flags"] == {"openai.base_url": "http://override/v1"} + diff --git a/environments/gsm8k_server_teacher_distill.py b/environments/gsm8k_server_teacher_distill.py index 5aa33a01..59106f10 100644 --- a/environments/gsm8k_server_teacher_distill.py +++ b/environments/gsm8k_server_teacher_distill.py @@ -32,13 +32,6 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv): max_token_length=2048, wandb_name="gsm8k_teacher_distill", teacher_enabled=True, - teacher_server=APIServerConfig( - base_url="http://localhost:8003/v1", - model_name="mock-teacher", - api_key="", - server_type="vllm", - tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", - ), teacher_top_k=4, ) server_config = APIServerConfig( @@ -49,6 +42,17 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv): ) return env_config, server_config + @classmethod + def teacher_config_init(cls) -> APIServerConfig: + return APIServerConfig( + base_url="http://localhost:9003/v1", + model_name="mock-teacher", + api_key="", + server_type="vllm", + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", + timeout=1200, + ) + if __name__ == "__main__": GSM8kTeacherDistillEnv.cli() diff --git a/example_trainer/README.md b/example_trainer/README.md index 8596a849..de31b3eb 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -314,14 +314,14 @@ What to configure on the environment side: ```bash --env.teacher_enabled true \ ---env.teacher_server.base_url "http://localhost:9003/v1" \ ---env.teacher_server.model_name "$TEACHER_MODEL" \ ---env.teacher_server.server_type vllm \ +--teacher.base_url "http://localhost:9003/v1" \ +--teacher.model_name "$TEACHER_MODEL" \ +--teacher.server_type vllm \ --env.teacher_top_k 8 ``` If `$TEACHER_MODEL` is a deployment alias instead of a tokenizer identifier, -also set `--env.teacher_server.tokenizer_name ...` so the env can validate +also set `--teacher.tokenizer_name ...` so the env can validate tokenizer compatibility. Why cross-tokenizer conversion is not acceptable here: diff --git a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh index 91cecf8a..fead9ba6 100755 --- a/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh +++ b/example_trainer/run_gsm8k_teacher_distill_single_terminal.sh @@ -234,9 +234,9 @@ start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \ --env.use_wandb true \ --env.wandb_name "gsm8k-teacher-distill" \ --env.teacher_enabled true \ - --env.teacher_server.base_url "http://localhost:${TEACHER_PORT}/v1" \ - --env.teacher_server.model_name "$TEACHER_MODEL" \ - --env.teacher_server.server_type vllm \ + --teacher.base_url "http://localhost:${TEACHER_PORT}/v1" \ + --teacher.model_name "$TEACHER_MODEL" \ + --teacher.server_type vllm \ --env.teacher_top_k "$TEACHER_TOP_K" \ --env.ensure_scores_are_not_same false \ --openai.api_key "dummy" \