fresh eyes check

This commit is contained in:
Jai Suphavadeeprasit 2026-03-14 11:20:15 -04:00
parent 805a0c0eac
commit 7aba0d3fc8
2 changed files with 48 additions and 2 deletions

View file

@ -285,7 +285,7 @@ class TeacherDistillationEnv(BaseEnv, ABC):
config: TeacherDistillationConfig,
server_configs: Union[ServerBaseline, List[APIServerConfig]],
teacher_server_configs: Optional[
Union[ServerBaseline, List[APIServerConfig]]
Union[ServerBaseline, APIServerConfig, List[APIServerConfig]]
] = None,
slurm: bool = False,
testing: bool = False,
@ -302,7 +302,10 @@ class TeacherDistillationEnv(BaseEnv, ABC):
"path with --teacher.* flags. The generic BaseEnv 'process' and "
"'evaluate' commands do not currently wire teacher_server_configs."
)
teacher_config_source = teacher_server_configs
if isinstance(teacher_server_configs, APIServerConfig):
teacher_config_source = [teacher_server_configs]
else:
teacher_config_source = teacher_server_configs
self.teacher_server = ServerManager(
teacher_config_source,
slurm=False,

View file

@ -257,6 +257,49 @@ def test_init_uses_explicit_teacher_server_configs(monkeypatch):
assert called["teacher_tokenizer_name"] == "explicit-tokenizer"
def test_init_wraps_bare_teacher_api_server_config(monkeypatch):
from atroposlib.envs import teacher_distillation_env as module
called = {}
def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
self.config = config
self.tokenizer = _DummyTokenizer()
def _fake_validate(self, teacher_tokenizer_name):
called["teacher_tokenizer_name"] = teacher_tokenizer_name
monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
monkeypatch.setattr(module, "ServerManager", _CapturingServerManager)
monkeypatch.setattr(
_ConcreteTeacherEnv,
"_validate_teacher_tokenizer_compatibility",
_fake_validate,
)
explicit_cfg = APIServerConfig(
model_name="explicit-model",
tokenizer_name="explicit-tokenizer",
base_url="http://explicit/v1",
api_key="x",
server_type="vllm",
)
config = SimpleNamespace(
teacher_enabled=True,
teacher_top_k=0,
)
env = _ConcreteTeacherEnv(
config=config,
server_configs=[],
teacher_server_configs=explicit_cfg,
)
assert isinstance(env.teacher_server, _CapturingServerManager)
assert env.teacher_server.configs == [explicit_cfg]
assert called["teacher_tokenizer_name"] == "explicit-tokenizer"
def test_resolve_teacher_server_configs_returns_none_when_unset():
assert (
_ConcreteTeacherEnv._resolve_teacher_server_configs(