atropos/atroposlib/envs/teacher_distillation_env.py
2026-03-13 22:51:58 +00:00

452 lines
17 KiB
Python

"""
Teacher distillation environment layer.
This module adds teacher prompt-logprob fetching on top of BaseEnv without
modifying BaseEnv transport behavior.
This implementation supports same-tokenizer distillation only. The teacher and
student must share the same tokenizer vocabulary so the student's token IDs can
be forwarded directly to the teacher and the returned teacher top-k token IDs
can be looked up directly in the student's logits.
"""
from __future__ import annotations
import asyncio
import logging
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple, Union
import yaml
from pydantic import Field
from pydantic_cli import Cmd
from rich import print as rprint
from ..utils.cli import (
extract_namespace,
get_double_dash_flags,
get_prefixed_pydantic_model,
merge_dicts,
)
from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
from .constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
from .server_handling.openai_server import resolve_openai_configs
from .server_handling.server_baseline import APIServerConfig, ServerBaseline
from .server_handling.server_manager import ServerManager, ServerManagerConfig
# Module-level logger used for config-resolution info and teacher-fetch warnings.
logger = logging.getLogger(__name__)
class TeacherDistillationConfig(BaseEnvConfig):
    """Env config extending BaseEnvConfig with teacher-distillation settings."""

    # Master switch: when False no teacher server is required and scored
    # groups pass through handle_send_to_api unchanged.
    teacher_enabled: bool = Field(
        default=False,
        description="Whether to fetch teacher prompt logprobs for distillation.",
    )
    # Per-position fetch width: -1 disables fetching, 0 fetches only the
    # selected token's logprob, k > 0 fetches k extra top-k entries.
    # NOTE: group_overrides["teacher_top_k"] can bypass the ge=-1 bound, so
    # the consumer also re-checks `top_k <= -1` at use time.
    teacher_top_k: int = Field(
        default=0,
        ge=-1,
        description=(
            "Number of extra prompt logprobs to fetch beyond the selected token. "
            "Use 0 for selected-token-only prompt logprobs and <= -1 to disable "
            "teacher fetching."
        ),
    )
class TeacherDistillationEnv(BaseEnv, ABC):
    """
    BaseEnv subclass that enriches scored groups with teacher distillation arrays.
    Distillation payload shape:
    - distill_token_ids: [sequence][position][k] (student vocab IDs)
    - distill_logprobs: [sequence][position][k]
    """

    # Config class the CLI machinery instantiates for this env.
    env_config_cls = TeacherDistillationConfig
    # Namespace under which teacher-server CLI flags and YAML keys live.
    teacher_namespace = "teacher"
    @classmethod
    def teacher_config_init(
        cls,
    ) -> Optional[Union[ServerBaseline, List[APIServerConfig], APIServerConfig]]:
        # Subclass hook: return default teacher server configs, or None to
        # leave the teacher unconfigured unless YAML/CLI supplies values.
        return None
    @classmethod
    def _resolve_teacher_server_configs(
        cls,
        default_teacher_server_configs: Optional[
            Union[ServerBaseline, List[APIServerConfig], APIServerConfig]
        ],
        yaml_config: Dict[str, Any],
        cli_passed_flags: Dict[str, Any],
    ) -> Optional[Union[ServerBaseline, List[APIServerConfig]]]:
        """
        Merge teacher-server settings from code defaults, YAML, and CLI flags.

        Returns None when no teacher configuration exists from any source;
        otherwise delegates final resolution to resolve_openai_configs.
        """
        teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}"
        # CLI flags under the teacher namespace, with the prefix stripped.
        teacher_cli_passed_args = extract_namespace(
            cli_passed_flags, teacher_full_prefix
        )
        yaml_teacher_config = yaml_config.get(cls.teacher_namespace, {})
        # No code default, no YAML section, no CLI flags -> teacher unset.
        if (
            default_teacher_server_configs is None
            and not teacher_cli_passed_args
            and not yaml_teacher_config
        ):
            return None
        effective_teacher_server_configs = default_teacher_server_configs
        if effective_teacher_server_configs is None:
            # YAML/CLI provided values but no code default: start from a
            # blank single-server config that the overrides can fill in.
            effective_teacher_server_configs = APIServerConfig()
        elif isinstance(effective_teacher_server_configs, ServerBaseline) and (
            teacher_cli_passed_args or yaml_teacher_config
        ):
            # Overrides exist, so promote the baseline to a concrete
            # APIServerConfig that can absorb per-field overrides.
            effective_teacher_server_configs = APIServerConfig(
                **effective_teacher_server_configs.model_dump()
            )
        # Collapse a single-element list so field-level merging applies.
        if (
            isinstance(effective_teacher_server_configs, list)
            and len(effective_teacher_server_configs) == 1
        ):
            default_teacher_config = effective_teacher_server_configs[0]
        else:
            default_teacher_config = effective_teacher_server_configs
        if isinstance(yaml_teacher_config, list) and len(yaml_teacher_config) == 1:
            yaml_teacher_config = yaml_teacher_config[0]
        # Field-level merge is only well-defined for one dict-shaped config;
        # multi-server lists fall through with an empty merge dict.
        if isinstance(default_teacher_config, APIServerConfig) and isinstance(
            yaml_teacher_config, dict
        ):
            teacher_config_dict = merge_dicts(
                default_teacher_config.model_dump(),
                yaml_teacher_config,
                teacher_cli_passed_args,
            )
        else:
            teacher_config_dict = {}
        # Re-wrap teacher settings under the OpenAI namespace so the shared
        # resolver can be reused unchanged.
        teacher_yaml_wrapped = {OPENAI_NAMESPACE: yaml_teacher_config}
        teacher_cli_wrapped = {
            f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}{key}": value
            for key, value in teacher_cli_passed_args.items()
        }
        return resolve_openai_configs(
            default_server_configs=effective_teacher_server_configs,
            openai_config_dict=teacher_config_dict,
            yaml_config=teacher_yaml_wrapped,
            cli_passed_flags=teacher_cli_wrapped,
            logger=logger,
        )
@classmethod
def get_cli_serve_config_cls(cls) -> type:
default_env_config, default_server_configs = cls.config_init()
default_teacher_server_configs = cls.teacher_config_init()
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}"
teacher_cli_base = get_prefixed_pydantic_model(
APIServerConfig, teacher_full_prefix
)
class CliServeConfig(
get_prefixed_pydantic_model(type(default_env_config), env_full_prefix),
get_prefixed_pydantic_model(APIServerConfig, openai_full_prefix),
teacher_cli_base,
ServerManagerConfig,
Cmd,
):
config: str | None = Field(
default=None,
description="Path to .yaml config file. CLI args override this.",
)
def run(self) -> None:
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
if (
getattr(self, wandb_name_attr, None) is None
and cls.name is not None
):
setattr(self, wandb_name_attr, cls.name)
if self.config is not None:
with open(self.config, "r") as f:
yaml_config = yaml.safe_load(f)
logger.info("Loaded config from %s", self.config)
else:
yaml_config = {}
cli_passed_flags = get_double_dash_flags()
env_config_dict = merge_dicts(
default_env_config.model_dump(),
yaml_config.get(ENV_NAMESPACE, {}),
extract_namespace(cli_passed_flags, env_full_prefix),
)
oai_cli_passed_args = extract_namespace(
cli_passed_flags, openai_full_prefix
)
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
effective_server_configs = default_server_configs
if isinstance(effective_server_configs, ServerBaseline) and (
oai_cli_passed_args or yaml_oai_config
):
effective_server_configs = APIServerConfig(
**effective_server_configs.model_dump()
)
if (
isinstance(effective_server_configs, list)
and len(effective_server_configs) == 1
):
default_openai_config_ = effective_server_configs[0]
else:
default_openai_config_ = effective_server_configs
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
yaml_oai_config = yaml_oai_config[0]
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
yaml_oai_config, dict
):
openai_config_dict = merge_dicts(
default_openai_config_.model_dump(),
yaml_oai_config,
oai_cli_passed_args,
)
else:
openai_config_dict = {}
server_manager_cli_passed_flags = {}
if "slurm" in cli_passed_flags:
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
if "testing" in cli_passed_flags:
server_manager_cli_passed_flags["testing"] = cli_passed_flags[
"testing"
]
server_manager_yaml_dict = {}
if "slurm" in yaml_config:
server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
if "testing" in yaml_config:
server_manager_yaml_dict["testing"] = yaml_config["testing"]
server_manager_config_dict = merge_dicts(
ServerManagerConfig().model_dump(),
server_manager_yaml_dict,
server_manager_cli_passed_flags,
)
env_config = type(default_env_config)(**env_config_dict)
server_manager_config = ServerManagerConfig(
**server_manager_config_dict
)
openai_configs = resolve_openai_configs(
default_server_configs=effective_server_configs,
openai_config_dict=openai_config_dict,
yaml_config=yaml_config,
cli_passed_flags=cli_passed_flags,
logger=logger,
)
teacher_configs = cls._resolve_teacher_server_configs(
default_teacher_server_configs=default_teacher_server_configs,
yaml_config=yaml_config,
cli_passed_flags=cli_passed_flags,
)
env_kwargs = {
"config": env_config,
"server_configs": openai_configs,
"slurm": server_manager_config.slurm,
"testing": server_manager_config.testing,
}
if teacher_configs is not None:
env_kwargs["teacher_server_configs"] = teacher_configs
env = cls(**env_kwargs)
rprint(env_config)
rprint(openai_configs)
if teacher_configs is not None:
rprint(teacher_configs)
try:
loop = asyncio.get_running_loop()
task = loop.create_task(env.env_manager())
loop.run_until_complete(task)
except RuntimeError:
asyncio.run(env.env_manager())
return CliServeConfig
def __init__(
self,
config: TeacherDistillationConfig,
server_configs: Union[ServerBaseline, List[APIServerConfig]],
teacher_server_configs: Optional[
Union[ServerBaseline, List[APIServerConfig]]
] = None,
slurm: bool = False,
testing: bool = False,
):
super().__init__(config, server_configs, slurm=slurm, testing=testing)
self.teacher_server: Optional[ServerManager] = None
if config.teacher_enabled:
if teacher_server_configs is None:
raise ValueError(
"teacher_enabled=True requires teacher_server_configs at init."
)
teacher_config_source = teacher_server_configs
self.teacher_server = ServerManager(
teacher_config_source,
slurm=False,
testing=False,
)
if isinstance(teacher_config_source, list):
teacher_cfg = teacher_config_source[0]
else:
teacher_cfg = teacher_config_source
teacher_tokenizer_name = (
teacher_cfg.model_name
if getattr(teacher_cfg, "tokenizer_name", "none") in ("", "none")
else teacher_cfg.tokenizer_name
)
self._validate_teacher_tokenizer_compatibility(teacher_tokenizer_name)
    # ------------------------------------------------------------------
    # Tokenizer compatibility validation
    # ------------------------------------------------------------------
    def _validate_teacher_tokenizer_compatibility(
        self, teacher_tokenizer_name: str
    ) -> None:
        """
        Raise ValueError unless the teacher tokenizer matches the student's.

        Same-tokenizer distillation requires identical token->id vocabularies
        so teacher top-k token IDs index directly into student logits.
        """
        student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
        # Fast path: identical names are assumed compatible without loading
        # the teacher tokenizer.
        if student_tok_name == teacher_tokenizer_name:
            return
        try:
            # Lazy import: transformers is only needed when a teacher with a
            # differently-named tokenizer must be verified.
            from transformers import AutoTokenizer
            teacher_tokenizer = AutoTokenizer.from_pretrained(
                teacher_tokenizer_name, use_fast=True
            )
        except Exception as exc:
            raise ValueError(
                "Cross-tokenizer distillation is not supported in this PR, and the "
                f"teacher tokenizer for '{teacher_tokenizer_name}' could not be loaded to "
                f"verify compatibility: {exc}"
            ) from exc
        # Names differ, so compare the full token->id maps for equality.
        student_vocab = self.tokenizer.get_vocab()
        teacher_vocab = teacher_tokenizer.get_vocab()
        if student_vocab != teacher_vocab:
            raise ValueError(
                "Cross-tokenizer distillation is not supported in this PR. "
                f"Student tokenizer '{student_tok_name or type(self.tokenizer).__name__}' "
                f"and teacher tokenizer '{teacher_tokenizer_name}' do not match."
            )
async def _fetch_teacher_for_sequence(
self, token_ids: List[int], top_k: int
) -> Tuple[List[List[int]], List[List[float]]]:
assert self.teacher_server is not None
payload = await self.teacher_server.get_logprobs(
input_ids=token_ids,
top_k=top_k,
max_tokens=1,
split="train",
)
return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]
# ------------------------------------------------------------------
# Group enrichment
# ------------------------------------------------------------------
async def _attach_teacher_distillation(
self, group: ScoredDataGroup
) -> ScoredDataGroup:
if not self.config.teacher_enabled or self.teacher_server is None:
return group
seqs = group.get("tokens", [])
if not seqs:
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
group_overrides = group.get("group_overrides") or {}
if group_overrides.get("skip_teacher_top_k", False):
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
top_k = int(group_overrides.get("teacher_top_k", self.config.teacher_top_k))
if top_k <= -1:
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
tasks = [self._fetch_teacher_for_sequence(seq, top_k) for seq in seqs]
results = await asyncio.gather(*tasks, return_exceptions=True)
distill_token_ids: List[List[List[int]]] = []
distill_logprobs: List[List[List[float]]] = []
for idx, result in enumerate(results):
if isinstance(result, Exception):
logger.warning(
"Teacher logprob fetch failed for seq %s: %s. "
"Dropping distill payload for this group.",
idx,
result,
)
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
token_ids_k, logprobs_k = result
if len(token_ids_k) != len(logprobs_k):
logger.warning(
"Teacher prompt-topk length mismatch for seq %s (%s != %s). "
"Dropping distill payload for this group.",
idx,
len(token_ids_k),
len(logprobs_k),
)
group["distill_token_ids"] = None
group["distill_logprobs"] = None
return group
distill_token_ids.append(token_ids_k)
distill_logprobs.append(logprobs_k)
group["distill_token_ids"] = distill_token_ids
group["distill_logprobs"] = distill_logprobs
return group
async def handle_send_to_api(
self,
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
item: Any = None,
do_send_to_api: bool = True,
abort_on_any_max_length_exceeded: bool = True,
):
groups = scored_data if isinstance(scored_data, list) else [scored_data]
enriched_groups: List[ScoredDataGroup] = []
for group in groups:
if group is None:
continue
enriched_groups.append(await self._attach_teacher_distillation(group))
payload: Union[ScoredDataGroup, List[ScoredDataGroup]]
if isinstance(scored_data, list):
payload = enriched_groups
else:
payload = enriched_groups[0] if enriched_groups else scored_data
return await super().handle_send_to_api(
payload,
item=item,
do_send_to_api=do_send_to_api,
abort_on_any_max_length_exceeded=abort_on_any_max_length_exceeded,
)