mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
Merge branch 'NousResearch:main' into feat/reward-normalization
This commit is contained in:
commit
39fb1d6870
8 changed files with 1285 additions and 4 deletions
459
atroposlib/envs/teacher_distillation_env.py
Normal file
459
atroposlib/envs/teacher_distillation_env.py
Normal file
|
|
@ -0,0 +1,459 @@
|
|||
"""
|
||||
Teacher distillation environment layer.
|
||||
|
||||
This module adds teacher prompt-logprob fetching on top of BaseEnv without
|
||||
modifying BaseEnv transport behavior.
|
||||
|
||||
This implementation supports same-tokenizer distillation only. The teacher and
|
||||
student must share the same tokenizer vocabulary so the student's token IDs can
|
||||
be forwarded directly to the teacher and the returned teacher top-k token IDs
|
||||
can be looked up directly in the student's logits.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import Field
|
||||
from pydantic_cli import Cmd
|
||||
from rich import print as rprint
|
||||
|
||||
from ..utils.cli import (
|
||||
extract_namespace,
|
||||
get_double_dash_flags,
|
||||
get_prefixed_pydantic_model,
|
||||
merge_dicts,
|
||||
)
|
||||
from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
|
||||
from .constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
|
||||
from .server_handling.openai_server import resolve_openai_configs
|
||||
from .server_handling.server_baseline import APIServerConfig, ServerBaseline
|
||||
from .server_handling.server_manager import ServerManager, ServerManagerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TeacherDistillationConfig(BaseEnvConfig):
|
||||
teacher_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Whether to fetch teacher prompt logprobs for distillation.",
|
||||
)
|
||||
teacher_top_k: int = Field(
|
||||
default=0,
|
||||
ge=-1,
|
||||
description=(
|
||||
"Number of extra prompt logprobs to fetch beyond the selected token. "
|
||||
"Use 0 for selected-token-only prompt logprobs and <= -1 to disable "
|
||||
"teacher fetching."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TeacherDistillationEnv(BaseEnv, ABC):
|
||||
"""
|
||||
BaseEnv subclass that enriches scored groups with teacher distillation arrays.
|
||||
|
||||
Distillation payload shape:
|
||||
- distill_token_ids: [sequence][position][k] (student vocab IDs)
|
||||
- distill_logprobs: [sequence][position][k]
|
||||
"""
|
||||
|
||||
env_config_cls = TeacherDistillationConfig
|
||||
teacher_namespace = "teacher"
|
||||
|
||||
@classmethod
|
||||
def teacher_config_init(
|
||||
cls,
|
||||
) -> Optional[Union[ServerBaseline, List[APIServerConfig], APIServerConfig]]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _resolve_teacher_server_configs(
|
||||
cls,
|
||||
default_teacher_server_configs: Optional[
|
||||
Union[ServerBaseline, List[APIServerConfig], APIServerConfig]
|
||||
],
|
||||
yaml_config: Dict[str, Any],
|
||||
cli_passed_flags: Dict[str, Any],
|
||||
) -> Optional[Union[ServerBaseline, List[APIServerConfig]]]:
|
||||
teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}"
|
||||
teacher_cli_passed_args = extract_namespace(
|
||||
cli_passed_flags, teacher_full_prefix
|
||||
)
|
||||
yaml_teacher_config = yaml_config.get(cls.teacher_namespace, {})
|
||||
|
||||
if (
|
||||
default_teacher_server_configs is None
|
||||
and not teacher_cli_passed_args
|
||||
and not yaml_teacher_config
|
||||
):
|
||||
return None
|
||||
|
||||
effective_teacher_server_configs = default_teacher_server_configs
|
||||
if effective_teacher_server_configs is None:
|
||||
effective_teacher_server_configs = APIServerConfig()
|
||||
elif isinstance(effective_teacher_server_configs, ServerBaseline) and (
|
||||
teacher_cli_passed_args or yaml_teacher_config
|
||||
):
|
||||
effective_teacher_server_configs = APIServerConfig(
|
||||
**effective_teacher_server_configs.model_dump()
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(effective_teacher_server_configs, list)
|
||||
and len(effective_teacher_server_configs) == 1
|
||||
):
|
||||
default_teacher_config = effective_teacher_server_configs[0]
|
||||
else:
|
||||
default_teacher_config = effective_teacher_server_configs
|
||||
|
||||
if isinstance(yaml_teacher_config, list) and len(yaml_teacher_config) == 1:
|
||||
yaml_teacher_config = yaml_teacher_config[0]
|
||||
|
||||
if isinstance(default_teacher_config, APIServerConfig) and isinstance(
|
||||
yaml_teacher_config, dict
|
||||
):
|
||||
teacher_config_dict = merge_dicts(
|
||||
default_teacher_config.model_dump(),
|
||||
yaml_teacher_config,
|
||||
teacher_cli_passed_args,
|
||||
)
|
||||
else:
|
||||
teacher_config_dict = {}
|
||||
|
||||
teacher_yaml_wrapped = {OPENAI_NAMESPACE: yaml_teacher_config}
|
||||
teacher_cli_wrapped = {
|
||||
f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}{key}": value
|
||||
for key, value in teacher_cli_passed_args.items()
|
||||
}
|
||||
return resolve_openai_configs(
|
||||
default_server_configs=effective_teacher_server_configs,
|
||||
openai_config_dict=teacher_config_dict,
|
||||
yaml_config=teacher_yaml_wrapped,
|
||||
cli_passed_flags=teacher_cli_wrapped,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_cli_serve_config_cls(cls) -> type:
|
||||
default_env_config, default_server_configs = cls.config_init()
|
||||
default_teacher_server_configs = cls.teacher_config_init()
|
||||
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||
teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}"
|
||||
teacher_cli_base = get_prefixed_pydantic_model(
|
||||
APIServerConfig, teacher_full_prefix
|
||||
)
|
||||
|
||||
class CliServeConfig(
|
||||
get_prefixed_pydantic_model(type(default_env_config), env_full_prefix),
|
||||
get_prefixed_pydantic_model(APIServerConfig, openai_full_prefix),
|
||||
teacher_cli_base,
|
||||
ServerManagerConfig,
|
||||
Cmd,
|
||||
):
|
||||
config: str | None = Field(
|
||||
default=None,
|
||||
description="Path to .yaml config file. CLI args override this.",
|
||||
)
|
||||
|
||||
def run(self) -> None:
|
||||
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
|
||||
if (
|
||||
getattr(self, wandb_name_attr, None) is None
|
||||
and cls.name is not None
|
||||
):
|
||||
setattr(self, wandb_name_attr, cls.name)
|
||||
|
||||
if self.config is not None:
|
||||
with open(self.config, "r") as f:
|
||||
yaml_config = yaml.safe_load(f)
|
||||
logger.info("Loaded config from %s", self.config)
|
||||
else:
|
||||
yaml_config = {}
|
||||
|
||||
cli_passed_flags = get_double_dash_flags()
|
||||
|
||||
env_config_dict = merge_dicts(
|
||||
default_env_config.model_dump(),
|
||||
yaml_config.get(ENV_NAMESPACE, {}),
|
||||
extract_namespace(cli_passed_flags, env_full_prefix),
|
||||
)
|
||||
|
||||
oai_cli_passed_args = extract_namespace(
|
||||
cli_passed_flags, openai_full_prefix
|
||||
)
|
||||
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
|
||||
|
||||
effective_server_configs = default_server_configs
|
||||
if isinstance(effective_server_configs, ServerBaseline) and (
|
||||
oai_cli_passed_args or yaml_oai_config
|
||||
):
|
||||
effective_server_configs = APIServerConfig(
|
||||
**effective_server_configs.model_dump()
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(effective_server_configs, list)
|
||||
and len(effective_server_configs) == 1
|
||||
):
|
||||
default_openai_config_ = effective_server_configs[0]
|
||||
else:
|
||||
default_openai_config_ = effective_server_configs
|
||||
|
||||
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
|
||||
yaml_oai_config = yaml_oai_config[0]
|
||||
|
||||
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
||||
yaml_oai_config, dict
|
||||
):
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config_.model_dump(),
|
||||
yaml_oai_config,
|
||||
oai_cli_passed_args,
|
||||
)
|
||||
else:
|
||||
openai_config_dict = {}
|
||||
|
||||
server_manager_cli_passed_flags = {}
|
||||
if "slurm" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
|
||||
if "testing" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["testing"] = cli_passed_flags[
|
||||
"testing"
|
||||
]
|
||||
|
||||
server_manager_yaml_dict = {}
|
||||
if "slurm" in yaml_config:
|
||||
server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
|
||||
if "testing" in yaml_config:
|
||||
server_manager_yaml_dict["testing"] = yaml_config["testing"]
|
||||
|
||||
server_manager_config_dict = merge_dicts(
|
||||
ServerManagerConfig().model_dump(),
|
||||
server_manager_yaml_dict,
|
||||
server_manager_cli_passed_flags,
|
||||
)
|
||||
|
||||
env_config = type(default_env_config)(**env_config_dict)
|
||||
server_manager_config = ServerManagerConfig(
|
||||
**server_manager_config_dict
|
||||
)
|
||||
openai_configs = resolve_openai_configs(
|
||||
default_server_configs=effective_server_configs,
|
||||
openai_config_dict=openai_config_dict,
|
||||
yaml_config=yaml_config,
|
||||
cli_passed_flags=cli_passed_flags,
|
||||
logger=logger,
|
||||
)
|
||||
teacher_configs = cls._resolve_teacher_server_configs(
|
||||
default_teacher_server_configs=default_teacher_server_configs,
|
||||
yaml_config=yaml_config,
|
||||
cli_passed_flags=cli_passed_flags,
|
||||
)
|
||||
|
||||
env_kwargs = {
|
||||
"config": env_config,
|
||||
"server_configs": openai_configs,
|
||||
"slurm": server_manager_config.slurm,
|
||||
"testing": server_manager_config.testing,
|
||||
}
|
||||
if teacher_configs is not None:
|
||||
env_kwargs["teacher_server_configs"] = teacher_configs
|
||||
env = cls(**env_kwargs)
|
||||
rprint(env_config)
|
||||
rprint(openai_configs)
|
||||
if teacher_configs is not None:
|
||||
rprint(teacher_configs)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = loop.create_task(env.env_manager())
|
||||
loop.run_until_complete(task)
|
||||
except RuntimeError:
|
||||
asyncio.run(env.env_manager())
|
||||
|
||||
return CliServeConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TeacherDistillationConfig,
|
||||
server_configs: Union[ServerBaseline, List[APIServerConfig]],
|
||||
teacher_server_configs: Optional[
|
||||
Union[ServerBaseline, APIServerConfig, List[APIServerConfig]]
|
||||
] = None,
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm=slurm, testing=testing)
|
||||
self.teacher_server: Optional[ServerManager] = None
|
||||
|
||||
if config.teacher_enabled:
|
||||
if teacher_server_configs is None:
|
||||
raise ValueError(
|
||||
"teacher_enabled=True but no teacher server configuration was "
|
||||
"provided. Pass teacher_server_configs=... when instantiating "
|
||||
"the environment directly, or use the teacher-aware 'serve' CLI "
|
||||
"path with --teacher.* flags. The generic BaseEnv 'process' and "
|
||||
"'evaluate' commands do not currently wire teacher_server_configs."
|
||||
)
|
||||
if isinstance(teacher_server_configs, APIServerConfig):
|
||||
teacher_config_source = [teacher_server_configs]
|
||||
else:
|
||||
teacher_config_source = teacher_server_configs
|
||||
self.teacher_server = ServerManager(
|
||||
teacher_config_source,
|
||||
slurm=False,
|
||||
testing=False,
|
||||
)
|
||||
if isinstance(teacher_config_source, list):
|
||||
teacher_cfg = teacher_config_source[0]
|
||||
else:
|
||||
teacher_cfg = teacher_config_source
|
||||
|
||||
teacher_tokenizer_name = (
|
||||
teacher_cfg.model_name
|
||||
if getattr(teacher_cfg, "tokenizer_name", "none") in ("", "none")
|
||||
else teacher_cfg.tokenizer_name
|
||||
)
|
||||
self._validate_teacher_tokenizer_compatibility(teacher_tokenizer_name)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core fetch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _validate_teacher_tokenizer_compatibility(
|
||||
self, teacher_tokenizer_name: str
|
||||
) -> None:
|
||||
student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
|
||||
if student_tok_name == teacher_tokenizer_name:
|
||||
return
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
teacher_tokenizer = AutoTokenizer.from_pretrained(
|
||||
teacher_tokenizer_name, use_fast=True
|
||||
)
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
"Cross-tokenizer distillation is not supported in this PR, and the "
|
||||
f"teacher tokenizer for '{teacher_tokenizer_name}' could not be loaded to "
|
||||
f"verify compatibility: {exc}"
|
||||
) from exc
|
||||
|
||||
student_vocab = self.tokenizer.get_vocab()
|
||||
teacher_vocab = teacher_tokenizer.get_vocab()
|
||||
if student_vocab != teacher_vocab:
|
||||
raise ValueError(
|
||||
"Cross-tokenizer distillation is not supported in this PR. "
|
||||
f"Student tokenizer '{student_tok_name or type(self.tokenizer).__name__}' "
|
||||
f"and teacher tokenizer '{teacher_tokenizer_name}' do not match."
|
||||
)
|
||||
|
||||
async def _fetch_teacher_for_sequence(
|
||||
self, token_ids: List[int], top_k: int
|
||||
) -> Tuple[List[List[int]], List[List[float]]]:
|
||||
assert self.teacher_server is not None
|
||||
payload = await self.teacher_server.get_logprobs(
|
||||
input_ids=token_ids,
|
||||
top_k=top_k,
|
||||
max_tokens=1,
|
||||
split="train",
|
||||
)
|
||||
return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Group enrichment
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _attach_teacher_distillation(
|
||||
self, group: ScoredDataGroup
|
||||
) -> ScoredDataGroup:
|
||||
if not self.config.teacher_enabled or self.teacher_server is None:
|
||||
return group
|
||||
|
||||
seqs = group.get("tokens", [])
|
||||
if not seqs:
|
||||
group["distill_token_ids"] = None
|
||||
group["distill_logprobs"] = None
|
||||
return group
|
||||
|
||||
group_overrides = group.get("group_overrides") or {}
|
||||
if group_overrides.get("skip_teacher_top_k", False):
|
||||
group["distill_token_ids"] = None
|
||||
group["distill_logprobs"] = None
|
||||
return group
|
||||
|
||||
top_k = int(group_overrides.get("teacher_top_k", self.config.teacher_top_k))
|
||||
if top_k <= -1:
|
||||
group["distill_token_ids"] = None
|
||||
group["distill_logprobs"] = None
|
||||
return group
|
||||
|
||||
tasks = [self._fetch_teacher_for_sequence(seq, top_k) for seq in seqs]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
distill_token_ids: List[List[List[int]]] = []
|
||||
distill_logprobs: List[List[List[float]]] = []
|
||||
for idx, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(
|
||||
"Teacher logprob fetch failed for seq %s: %s. "
|
||||
"Dropping distill payload for this group.",
|
||||
idx,
|
||||
result,
|
||||
)
|
||||
group["distill_token_ids"] = None
|
||||
group["distill_logprobs"] = None
|
||||
return group
|
||||
token_ids_k, logprobs_k = result
|
||||
if len(token_ids_k) != len(logprobs_k):
|
||||
logger.warning(
|
||||
"Teacher prompt-topk length mismatch for seq %s (%s != %s). "
|
||||
"Dropping distill payload for this group.",
|
||||
idx,
|
||||
len(token_ids_k),
|
||||
len(logprobs_k),
|
||||
)
|
||||
group["distill_token_ids"] = None
|
||||
group["distill_logprobs"] = None
|
||||
return group
|
||||
distill_token_ids.append(token_ids_k)
|
||||
distill_logprobs.append(logprobs_k)
|
||||
|
||||
group["distill_token_ids"] = distill_token_ids
|
||||
group["distill_logprobs"] = distill_logprobs
|
||||
return group
|
||||
|
||||
async def handle_send_to_api(
|
||||
self,
|
||||
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
|
||||
item: Any = None,
|
||||
do_send_to_api: bool = True,
|
||||
abort_on_any_max_length_exceeded: bool = True,
|
||||
):
|
||||
groups = scored_data if isinstance(scored_data, list) else [scored_data]
|
||||
enriched_groups: List[ScoredDataGroup] = []
|
||||
for group in groups:
|
||||
if group is None:
|
||||
continue
|
||||
enriched_groups.append(await self._attach_teacher_distillation(group))
|
||||
|
||||
payload: Union[ScoredDataGroup, List[ScoredDataGroup]]
|
||||
if isinstance(scored_data, list):
|
||||
payload = enriched_groups
|
||||
else:
|
||||
payload = enriched_groups[0] if enriched_groups else scored_data
|
||||
|
||||
return await super().handle_send_to_api(
|
||||
payload,
|
||||
item=item,
|
||||
do_send_to_api=do_send_to_api,
|
||||
abort_on_any_max_length_exceeded=abort_on_any_max_length_exceeded,
|
||||
)
|
||||
344
atroposlib/tests/test_teacher_distillation_env.py
Normal file
344
atroposlib/tests/test_teacher_distillation_env.py
Normal file
|
|
@ -0,0 +1,344 @@
|
|||
"""Tests for TeacherDistillationEnv distillation enrichment."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
||||
from atroposlib.envs.teacher_distillation_env import TeacherDistillationEnv
|
||||
|
||||
|
||||
class _FakeTeacherServer:
|
||||
def __init__(self, fail_on_call: int = -1):
|
||||
self.calls = 0
|
||||
self.fail_on_call = fail_on_call
|
||||
|
||||
async def get_logprobs(self, **kwargs):
|
||||
self.calls += 1
|
||||
if self.calls == self.fail_on_call:
|
||||
raise RuntimeError("teacher backend failure")
|
||||
seq = kwargs["input_ids"]
|
||||
return {
|
||||
"prompt_tokens": seq,
|
||||
"prompt_topk_token_ids": [[tok, tok + 1] for tok in seq],
|
||||
"prompt_topk_logprobs": [[-0.1, -0.2] for _ in seq],
|
||||
}
|
||||
|
||||
|
||||
class _ConcreteTeacherEnv(TeacherDistillationEnv):
|
||||
async def get_next_item(self):
|
||||
return None
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
class _DummyTokenizer:
|
||||
name_or_path = "student-model"
|
||||
|
||||
def get_vocab(self):
|
||||
return {"a": 1}
|
||||
|
||||
|
||||
class _CapturingServerManager:
|
||||
def __init__(self, configs, slurm=False, testing=False):
|
||||
self.configs = configs
|
||||
self.slurm = slurm
|
||||
self.testing = testing
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_success():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
|
||||
env.teacher_server = _FakeTeacherServer()
|
||||
|
||||
group = {
|
||||
"tokens": [[1, 2, 3], [4, 5]],
|
||||
"group_overrides": None,
|
||||
"masks": [[-100, 2, 3], [-100, 5]],
|
||||
"scores": [1.0, 0.0],
|
||||
}
|
||||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert out["distill_token_ids"] is not None
|
||||
assert out["distill_logprobs"] is not None
|
||||
assert len(out["distill_token_ids"]) == 2
|
||||
assert len(out["distill_token_ids"][0]) == 3
|
||||
assert len(out["distill_logprobs"][1]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_failure_drops_payload():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
|
||||
env.teacher_server = _FakeTeacherServer(fail_on_call=2)
|
||||
|
||||
group = {
|
||||
"tokens": [[1, 2, 3], [4, 5]],
|
||||
"group_overrides": None,
|
||||
"masks": [[-100, 2, 3], [-100, 5]],
|
||||
"scores": [1.0, 0.0],
|
||||
}
|
||||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert out["distill_token_ids"] is None
|
||||
assert out["distill_logprobs"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_negative_topk_skips_fetch():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=-1)
|
||||
env.teacher_server = _FakeTeacherServer()
|
||||
|
||||
group = {
|
||||
"tokens": [[1, 2, 3]],
|
||||
"group_overrides": None,
|
||||
"masks": [[-100, 2, 3]],
|
||||
"scores": [1.0],
|
||||
}
|
||||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert env.teacher_server.calls == 0
|
||||
assert out["distill_token_ids"] is None
|
||||
assert out["distill_logprobs"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_zero_topk_passthrough():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)
|
||||
env.teacher_server = _FakeTeacherServer()
|
||||
|
||||
group = {
|
||||
"tokens": [[1, 2, 3]],
|
||||
"group_overrides": None,
|
||||
"masks": [[-100, 2, 3]],
|
||||
"scores": [1.0],
|
||||
}
|
||||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert env.teacher_server.calls == 1
|
||||
assert out["distill_token_ids"] is not None
|
||||
assert out["distill_logprobs"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_group_override_topk_is_used():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)
|
||||
|
||||
seen_topks = []
|
||||
|
||||
async def _fake_fetch(seq, top_k):
|
||||
seen_topks.append(top_k)
|
||||
return [[tok] for tok in seq], [[-0.1] for _ in seq]
|
||||
|
||||
env.teacher_server = object()
|
||||
env._fetch_teacher_for_sequence = _fake_fetch
|
||||
|
||||
group = {
|
||||
"tokens": [[1, 2, 3], [4, 5]],
|
||||
"group_overrides": {"teacher_top_k": 7},
|
||||
"masks": [[-100, 2, 3], [-100, 5]],
|
||||
"scores": [1.0, 0.0],
|
||||
}
|
||||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert seen_topks == [7, 7]
|
||||
assert out["distill_token_ids"] is not None
|
||||
assert out["distill_logprobs"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_group_override_can_skip_fetch():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
|
||||
env.teacher_server = _FakeTeacherServer()
|
||||
|
||||
group = {
|
||||
"tokens": [[1, 2, 3]],
|
||||
"group_overrides": {"skip_teacher_top_k": True},
|
||||
"masks": [[-100, 2, 3]],
|
||||
"scores": [1.0],
|
||||
}
|
||||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert env.teacher_server.calls == 0
|
||||
assert out["distill_token_ids"] is None
|
||||
assert out["distill_logprobs"] is None
|
||||
|
||||
|
||||
def test_teacher_tokenizer_mismatch_raises(monkeypatch):
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
|
||||
class _StudentTokenizer:
|
||||
name_or_path = "student-model"
|
||||
|
||||
def get_vocab(self):
|
||||
return {"a": 1}
|
||||
|
||||
class _TeacherTokenizer:
|
||||
def get_vocab(self):
|
||||
return {"b": 1}
|
||||
|
||||
env.tokenizer = _StudentTokenizer()
|
||||
monkeypatch.setattr(
|
||||
"transformers.AutoTokenizer.from_pretrained",
|
||||
lambda *args, **kwargs: _TeacherTokenizer(),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Cross-tokenizer distillation is not supported"
|
||||
):
|
||||
TeacherDistillationEnv._validate_teacher_tokenizer_compatibility(
|
||||
env,
|
||||
teacher_tokenizer_name="teacher-model",
|
||||
)
|
||||
|
||||
|
||||
def test_init_requires_teacher_server_source(monkeypatch):
|
||||
from atroposlib.envs import teacher_distillation_env as module
|
||||
|
||||
def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
|
||||
self.config = config
|
||||
self.tokenizer = _DummyTokenizer()
|
||||
|
||||
monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
|
||||
|
||||
config = SimpleNamespace(
|
||||
teacher_enabled=True,
|
||||
teacher_top_k=0,
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError, match="no teacher server configuration was provided"
|
||||
):
|
||||
_ConcreteTeacherEnv(
|
||||
config=config,
|
||||
server_configs=[],
|
||||
)
|
||||
|
||||
|
||||
def test_init_uses_explicit_teacher_server_configs(monkeypatch):
|
||||
from atroposlib.envs import teacher_distillation_env as module
|
||||
|
||||
called = {}
|
||||
|
||||
def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
|
||||
self.config = config
|
||||
self.tokenizer = _DummyTokenizer()
|
||||
|
||||
def _fake_validate(self, teacher_tokenizer_name):
|
||||
called["teacher_tokenizer_name"] = teacher_tokenizer_name
|
||||
|
||||
monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
|
||||
monkeypatch.setattr(module, "ServerManager", _CapturingServerManager)
|
||||
monkeypatch.setattr(
|
||||
_ConcreteTeacherEnv,
|
||||
"_validate_teacher_tokenizer_compatibility",
|
||||
_fake_validate,
|
||||
)
|
||||
|
||||
explicit_cfg = APIServerConfig(
|
||||
model_name="explicit-model",
|
||||
tokenizer_name="explicit-tokenizer",
|
||||
base_url="http://explicit/v1",
|
||||
api_key="x",
|
||||
server_type="vllm",
|
||||
)
|
||||
config = SimpleNamespace(
|
||||
teacher_enabled=True,
|
||||
teacher_top_k=0,
|
||||
)
|
||||
|
||||
env = _ConcreteTeacherEnv(
|
||||
config=config,
|
||||
server_configs=[],
|
||||
teacher_server_configs=[explicit_cfg],
|
||||
)
|
||||
|
||||
assert isinstance(env.teacher_server, _CapturingServerManager)
|
||||
assert env.teacher_server.configs == [explicit_cfg]
|
||||
assert called["teacher_tokenizer_name"] == "explicit-tokenizer"
|
||||
|
||||
|
||||
def test_init_wraps_bare_teacher_api_server_config(monkeypatch):
|
||||
from atroposlib.envs import teacher_distillation_env as module
|
||||
|
||||
called = {}
|
||||
|
||||
def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
|
||||
self.config = config
|
||||
self.tokenizer = _DummyTokenizer()
|
||||
|
||||
def _fake_validate(self, teacher_tokenizer_name):
|
||||
called["teacher_tokenizer_name"] = teacher_tokenizer_name
|
||||
|
||||
monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
|
||||
monkeypatch.setattr(module, "ServerManager", _CapturingServerManager)
|
||||
monkeypatch.setattr(
|
||||
_ConcreteTeacherEnv,
|
||||
"_validate_teacher_tokenizer_compatibility",
|
||||
_fake_validate,
|
||||
)
|
||||
|
||||
explicit_cfg = APIServerConfig(
|
||||
model_name="explicit-model",
|
||||
tokenizer_name="explicit-tokenizer",
|
||||
base_url="http://explicit/v1",
|
||||
api_key="x",
|
||||
server_type="vllm",
|
||||
)
|
||||
config = SimpleNamespace(
|
||||
teacher_enabled=True,
|
||||
teacher_top_k=0,
|
||||
)
|
||||
|
||||
env = _ConcreteTeacherEnv(
|
||||
config=config,
|
||||
server_configs=[],
|
||||
teacher_server_configs=explicit_cfg,
|
||||
)
|
||||
|
||||
assert isinstance(env.teacher_server, _CapturingServerManager)
|
||||
assert env.teacher_server.configs == [explicit_cfg]
|
||||
assert called["teacher_tokenizer_name"] == "explicit-tokenizer"
|
||||
|
||||
|
||||
def test_resolve_teacher_server_configs_returns_none_when_unset():
|
||||
assert (
|
||||
_ConcreteTeacherEnv._resolve_teacher_server_configs(
|
||||
default_teacher_server_configs=None,
|
||||
yaml_config={},
|
||||
cli_passed_flags={},
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_teacher_server_configs_uses_teacher_namespace(monkeypatch):
|
||||
from atroposlib.envs import teacher_distillation_env as module
|
||||
|
||||
captured = {}
|
||||
|
||||
def _fake_resolve(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return ["resolved"]
|
||||
|
||||
monkeypatch.setattr(module, "resolve_openai_configs", _fake_resolve)
|
||||
|
||||
default_cfg = APIServerConfig(
|
||||
model_name="teacher-model",
|
||||
base_url="http://teacher/v1",
|
||||
api_key="x",
|
||||
server_type="vllm",
|
||||
)
|
||||
|
||||
out = _ConcreteTeacherEnv._resolve_teacher_server_configs(
|
||||
default_teacher_server_configs=default_cfg,
|
||||
yaml_config={"teacher": {"tokenizer_name": "teacher-tokenizer"}},
|
||||
cli_passed_flags={"teacher.base_url": "http://override/v1"},
|
||||
)
|
||||
|
||||
assert out == ["resolved"]
|
||||
assert captured["openai_config_dict"]["base_url"] == "http://override/v1"
|
||||
assert captured["openai_config_dict"]["tokenizer_name"] == "teacher-tokenizer"
|
||||
assert captured["yaml_config"] == {
|
||||
"openai": {"tokenizer_name": "teacher-tokenizer"}
|
||||
}
|
||||
assert captured["cli_passed_flags"] == {"openai.base_url": "http://override/v1"}
|
||||
Loading…
Add table
Add a link
Reference in a new issue