diff --git a/atroposlib/envs/teacher_distillation_env.py b/atroposlib/envs/teacher_distillation_env.py index 64f62b14..1f0e2110 100644 --- a/atroposlib/envs/teacher_distillation_env.py +++ b/atroposlib/envs/teacher_distillation_env.py @@ -285,7 +285,7 @@ class TeacherDistillationEnv(BaseEnv, ABC): config: TeacherDistillationConfig, server_configs: Union[ServerBaseline, List[APIServerConfig]], teacher_server_configs: Optional[ - Union[ServerBaseline, List[APIServerConfig]] + Union[ServerBaseline, APIServerConfig, List[APIServerConfig]] ] = None, slurm: bool = False, testing: bool = False, @@ -302,7 +302,10 @@ class TeacherDistillationEnv(BaseEnv, ABC): "path with --teacher.* flags. The generic BaseEnv 'process' and " "'evaluate' commands do not currently wire teacher_server_configs." ) - teacher_config_source = teacher_server_configs + if isinstance(teacher_server_configs, APIServerConfig): + teacher_config_source = [teacher_server_configs] + else: + teacher_config_source = teacher_server_configs self.teacher_server = ServerManager( teacher_config_source, slurm=False, diff --git a/atroposlib/tests/test_teacher_distillation_env.py b/atroposlib/tests/test_teacher_distillation_env.py index 11c586d9..c8825218 100644 --- a/atroposlib/tests/test_teacher_distillation_env.py +++ b/atroposlib/tests/test_teacher_distillation_env.py @@ -257,6 +257,49 @@ def test_init_uses_explicit_teacher_server_configs(monkeypatch): assert called["teacher_tokenizer_name"] == "explicit-tokenizer" +def test_init_wraps_bare_teacher_api_server_config(monkeypatch): + from atroposlib.envs import teacher_distillation_env as module + + called = {} + + def _fake_base_init(self, config, server_configs, slurm=False, testing=False): + self.config = config + self.tokenizer = _DummyTokenizer() + + def _fake_validate(self, teacher_tokenizer_name): + called["teacher_tokenizer_name"] = teacher_tokenizer_name + + monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init) + monkeypatch.setattr(module, "ServerManager", _CapturingServerManager) + monkeypatch.setattr( + _ConcreteTeacherEnv, + "_validate_teacher_tokenizer_compatibility", + _fake_validate, + ) + + explicit_cfg = APIServerConfig( + model_name="explicit-model", + tokenizer_name="explicit-tokenizer", + base_url="http://explicit/v1", + api_key="x", + server_type="vllm", + ) + config = SimpleNamespace( + teacher_enabled=True, + teacher_top_k=0, + ) + + env = _ConcreteTeacherEnv( + config=config, + server_configs=[], + teacher_server_configs=explicit_cfg, + ) + + assert isinstance(env.teacher_server, _CapturingServerManager) + assert env.teacher_server.configs == [explicit_cfg] + assert called["teacher_tokenizer_name"] == "explicit-tokenizer" + + def test_resolve_teacher_server_configs_returns_none_when_unset(): assert ( _ConcreteTeacherEnv._resolve_teacher_server_configs(