mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fresh eyes check
This commit is contained in:
parent
805a0c0eac
commit
7aba0d3fc8
2 changed files with 48 additions and 2 deletions
|
|
@ -285,7 +285,7 @@ class TeacherDistillationEnv(BaseEnv, ABC):
|
|||
config: TeacherDistillationConfig,
|
||||
server_configs: Union[ServerBaseline, List[APIServerConfig]],
|
||||
teacher_server_configs: Optional[
|
||||
Union[ServerBaseline, List[APIServerConfig]]
|
||||
Union[ServerBaseline, APIServerConfig, List[APIServerConfig]]
|
||||
] = None,
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
|
|
@ -302,6 +302,9 @@ class TeacherDistillationEnv(BaseEnv, ABC):
|
|||
"path with --teacher.* flags. The generic BaseEnv 'process' and "
|
||||
"'evaluate' commands do not currently wire teacher_server_configs."
|
||||
)
|
||||
if isinstance(teacher_server_configs, APIServerConfig):
|
||||
teacher_config_source = [teacher_server_configs]
|
||||
else:
|
||||
teacher_config_source = teacher_server_configs
|
||||
self.teacher_server = ServerManager(
|
||||
teacher_config_source,
|
||||
|
|
|
|||
|
|
@ -257,6 +257,49 @@ def test_init_uses_explicit_teacher_server_configs(monkeypatch):
|
|||
assert called["teacher_tokenizer_name"] == "explicit-tokenizer"
|
||||
|
||||
|
||||
def test_init_wraps_bare_teacher_api_server_config(monkeypatch):
|
||||
from atroposlib.envs import teacher_distillation_env as module
|
||||
|
||||
called = {}
|
||||
|
||||
def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
|
||||
self.config = config
|
||||
self.tokenizer = _DummyTokenizer()
|
||||
|
||||
def _fake_validate(self, teacher_tokenizer_name):
|
||||
called["teacher_tokenizer_name"] = teacher_tokenizer_name
|
||||
|
||||
monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
|
||||
monkeypatch.setattr(module, "ServerManager", _CapturingServerManager)
|
||||
monkeypatch.setattr(
|
||||
_ConcreteTeacherEnv,
|
||||
"_validate_teacher_tokenizer_compatibility",
|
||||
_fake_validate,
|
||||
)
|
||||
|
||||
explicit_cfg = APIServerConfig(
|
||||
model_name="explicit-model",
|
||||
tokenizer_name="explicit-tokenizer",
|
||||
base_url="http://explicit/v1",
|
||||
api_key="x",
|
||||
server_type="vllm",
|
||||
)
|
||||
config = SimpleNamespace(
|
||||
teacher_enabled=True,
|
||||
teacher_top_k=0,
|
||||
)
|
||||
|
||||
env = _ConcreteTeacherEnv(
|
||||
config=config,
|
||||
server_configs=[],
|
||||
teacher_server_configs=explicit_cfg,
|
||||
)
|
||||
|
||||
assert isinstance(env.teacher_server, _CapturingServerManager)
|
||||
assert env.teacher_server.configs == [explicit_cfg]
|
||||
assert called["teacher_tokenizer_name"] == "explicit-tokenizer"
|
||||
|
||||
|
||||
def test_resolve_teacher_server_configs_returns_none_when_unset():
|
||||
assert (
|
||||
_ConcreteTeacherEnv._resolve_teacher_server_configs(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue