diff --git a/atroposlib/envs/student_distillation_env.py b/atroposlib/envs/student_distillation_env.py
new file mode 100644
index 00000000..3f94d1d3
--- /dev/null
+++ b/atroposlib/envs/student_distillation_env.py
@@ -0,0 +1,205 @@
+"""
+Student/self-distillation environment layer.
+
+This module adds prompt-logprob fetching from the student rollout server itself
+before the scored group is sent to the API.
+
+By default, the student server scores the exact token IDs already present in the
+group. Override-driven prompt/message scoring is supported only when the
+resulting prompt tokenization matches the original token sequence exactly.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from abc import ABC
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from pydantic import Field
+
+from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
+
+logger = logging.getLogger(__name__)
+
+
+class StudentDistillationConfig(BaseEnvConfig):
+    student_distill_enabled: bool = Field(
+        default=False,
+        description="Whether to fetch prompt logprobs from the student server itself.",
+    )
+    student_top_k: int = Field(
+        default=0,
+        ge=-1,
+        description=(
+            "Number of extra prompt logprobs to fetch beyond the selected token. "
+            "Use 0 for selected-token-only prompt logprobs and <= -1 to disable "
+            "student distillation fetching."
+        ),
+    )
+
+
+class StudentDistillationEnv(BaseEnv, ABC):
+    """
+    BaseEnv subclass that enriches scored groups with self-distillation arrays.
+
+    Distillation payload shape:
+    - distill_token_ids: [sequence][position][k]
+    - distill_logprobs: [sequence][position][k]
+    """
+
+    env_config_cls = StudentDistillationConfig
+
+    def _get_student_logprob_overrides(
+        self, group: ScoredDataGroup, seq_idx: int
+    ) -> Dict[str, Any]:
+        merged: Dict[str, Any] = {}
+
+        group_overrides = group.get("group_overrides") or {}
+        group_kwargs = group_overrides.get(
+            "student_logprob_kwargs",
+            group_overrides.get("student_distill_kwargs"),
+        )
+        if isinstance(group_kwargs, dict):
+            merged.update(group_kwargs)
+
+        overrides = group.get("overrides") or []
+        if seq_idx < len(overrides):
+            seq_overrides = overrides[seq_idx] or {}
+            seq_kwargs = seq_overrides.get(
+                "student_logprob_kwargs",
+                seq_overrides.get("student_distill_kwargs"),
+            )
+            if isinstance(seq_kwargs, dict):
+                merged.update(seq_kwargs)
+
+        return merged
+
+    async def _fetch_student_for_sequence(
+        self,
+        token_ids: List[int],
+        top_k: int,
+        extra_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[List[List[int]], List[List[float]]]:
+        request_kwargs: Dict[str, Any] = {
+            "input_ids": token_ids,
+            "top_k": top_k,
+            "max_tokens": 1,
+            "split": "train",
+        }
+        if extra_kwargs:
+            request_kwargs.update(extra_kwargs)
+            if extra_kwargs.get("messages") is not None or extra_kwargs.get(
+                "prompt"
+            ) is not None:
+                # Let message/prompt overrides drive tokenization instead of the
+                # original input_ids when callers explicitly request it.
+                request_kwargs.pop("input_ids", None)
+
+        if request_kwargs.get("messages") is not None:
+            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+                payload = await managed.get_logprobs(**request_kwargs)
+        else:
+            payload = await self.server.get_logprobs(**request_kwargs)
+
+        if payload.get("prompt_tokens") != token_ids:
+            raise ValueError(
+                "Student distillation request did not align to the original token "
+                "sequence. Override-driven prompt/messages are only supported when "
+                "they reproduce the exact same prompt tokens."
+            )
+
+        return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]
+
+    async def _attach_student_distillation(
+        self, group: ScoredDataGroup
+    ) -> ScoredDataGroup:
+        if not self.config.student_distill_enabled:
+            return group
+
+        seqs = group.get("tokens", [])
+        if not seqs:
+            group["distill_token_ids"] = None
+            group["distill_logprobs"] = None
+            return group
+
+        group_overrides = group.get("group_overrides") or {}
+        if group_overrides.get("skip_student_top_k", False):
+            group["distill_token_ids"] = None
+            group["distill_logprobs"] = None
+            return group
+
+        top_k = int(group_overrides.get("student_top_k", self.config.student_top_k))
+        if top_k <= -1:
+            group["distill_token_ids"] = None
+            group["distill_logprobs"] = None
+            return group
+
+        tasks = [
+            self._fetch_student_for_sequence(
+                seq,
+                top_k,
+                self._get_student_logprob_overrides(group, idx),
+            )
+            for idx, seq in enumerate(seqs)
+        ]
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        distill_token_ids: List[List[List[int]]] = []
+        distill_logprobs: List[List[List[float]]] = []
+        for idx, result in enumerate(results):
+            if isinstance(result, Exception):
+                logger.warning(
+                    "Student logprob fetch failed for seq %s: %s. "
+                    "Dropping distill payload for this group.",
+                    idx,
+                    result,
+                )
+                group["distill_token_ids"] = None
+                group["distill_logprobs"] = None
+                return group
+            token_ids_k, logprobs_k = result
+            if len(token_ids_k) != len(logprobs_k):
+                logger.warning(
+                    "Student prompt-topk length mismatch for seq %s (%s != %s). "
+                    "Dropping distill payload for this group.",
+                    idx,
+                    len(token_ids_k),
+                    len(logprobs_k),
+                )
+                group["distill_token_ids"] = None
+                group["distill_logprobs"] = None
+                return group
+            distill_token_ids.append(token_ids_k)
+            distill_logprobs.append(logprobs_k)
+
+        group["distill_token_ids"] = distill_token_ids
+        group["distill_logprobs"] = distill_logprobs
+        return group
+
+    async def handle_send_to_api(
+        self,
+        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
+        item: Any = None,
+        do_send_to_api: bool = True,
+        abort_on_any_max_length_exceeded: bool = True,
+    ):
+        groups = scored_data if isinstance(scored_data, list) else [scored_data]
+        enriched_groups: List[ScoredDataGroup] = []
+        for group in groups:
+            if group is None:
+                continue
+            enriched_groups.append(await self._attach_student_distillation(group))
+
+        payload: Union[ScoredDataGroup, List[ScoredDataGroup]]
+        if isinstance(scored_data, list):
+            payload = enriched_groups
+        else:
+            payload = enriched_groups[0] if enriched_groups else scored_data
+
+        return await super().handle_send_to_api(
+            payload,
+            item=item,
+            do_send_to_api=do_send_to_api,
+            abort_on_any_max_length_exceeded=abort_on_any_max_length_exceeded,
+        )
diff --git a/atroposlib/tests/test_student_distillation_env.py b/atroposlib/tests/test_student_distillation_env.py
new file mode 100644
index 00000000..3d7f71d0
--- /dev/null
+++ b/atroposlib/tests/test_student_distillation_env.py
@@ -0,0 +1,234 @@
+"""Tests for StudentDistillationEnv distillation enrichment."""
+
+from contextlib import asynccontextmanager
+from types import SimpleNamespace
+
+import pytest
+
+from atroposlib.envs.student_distillation_env import StudentDistillationEnv
+
+
+class _FakeManagedServer:
+    def __init__(self, prompt_tokens):
+        self.prompt_tokens = prompt_tokens
+        self.calls = 0
+        self.kwargs = []
+
+    async def get_logprobs(self, **kwargs):
+        self.calls += 1
+        self.kwargs.append(kwargs)
+        return {
+            "prompt_tokens": self.prompt_tokens,
+            "prompt_topk_token_ids": [[tok, tok + 1] for tok in self.prompt_tokens],
+            "prompt_topk_logprobs": [[-0.1, -0.2] for _ in self.prompt_tokens],
+        }
+
+
+class _FakeStudentServer:
+    def __init__(self, fail_on_call: int = -1, managed_prompt_tokens=None):
+        self.calls = 0
+        self.fail_on_call = fail_on_call
+        self.kwargs = []
+        self.managed_calls = 0
+        self.managed = _FakeManagedServer(
+            prompt_tokens=managed_prompt_tokens if managed_prompt_tokens is not None else []
+        )
+
+    async def get_logprobs(self, **kwargs):
+        self.calls += 1
+        self.kwargs.append(kwargs)
+        if self.calls == self.fail_on_call:
+            raise RuntimeError("student backend failure")
+        seq = kwargs["input_ids"]
+        return {
+            "prompt_tokens": seq,
+            "prompt_topk_token_ids": [[tok, tok + 1] for tok in seq],
+            "prompt_topk_logprobs": [[-0.1, -0.2] for _ in seq],
+        }
+
+    @asynccontextmanager
+    async def managed_server(self, tokenizer=None):
+        self.managed_calls += 1
+        yield self.managed
+
+
+class _ConcreteStudentEnv(StudentDistillationEnv):
+    async def get_next_item(self):
+        return None
+
+    async def evaluate(self, *args, **kwargs):
+        return None
+
+
+@pytest.mark.asyncio
+async def test_attach_student_distillation_success():
+    env = object.__new__(_ConcreteStudentEnv)
+    env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)
+    env.server = _FakeStudentServer()
+
+    group = {
+        "tokens": [[1, 2, 3], [4, 5]],
+        "group_overrides": None,
+        "overrides": None,
+        "masks": [[-100, 2, 3], [-100, 5]],
+        "scores": [1.0, 0.0],
+    }
+    out = await StudentDistillationEnv._attach_student_distillation(env, group)
+    assert out["distill_token_ids"] is not None
+    assert out["distill_logprobs"] is not None
+    assert len(out["distill_token_ids"]) == 2
+    assert len(out["distill_token_ids"][0]) == 3
+    assert len(out["distill_logprobs"][1]) == 2
+    assert env.server.calls == 2
+
+
+@pytest.mark.asyncio
+async def test_attach_student_distillation_failure_drops_payload():
+    env = object.__new__(_ConcreteStudentEnv)
+    env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)
+    env.server = _FakeStudentServer(fail_on_call=2)
+
+    group = {
+        "tokens": [[1, 2, 3], [4, 5]],
+        "group_overrides": None,
+        "overrides": None,
+        "masks": [[-100, 2, 3], [-100, 5]],
+        "scores": [1.0, 0.0],
+    }
+    out = await StudentDistillationEnv._attach_student_distillation(env, group)
+    assert out["distill_token_ids"] is None
+    assert out["distill_logprobs"] is None
+
+
+@pytest.mark.asyncio
+async def test_attach_student_distillation_negative_topk_skips_fetch():
+    env = object.__new__(_ConcreteStudentEnv)
+    env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=-1)
+    env.server = _FakeStudentServer()
+
+    group = {
+        "tokens": [[1, 2, 3]],
+        "group_overrides": None,
+        "overrides": None,
+        "masks": [[-100, 2, 3]],
+        "scores": [1.0],
+    }
+    out = await StudentDistillationEnv._attach_student_distillation(env, group)
+    assert env.server.calls == 0
+    assert out["distill_token_ids"] is None
+    assert out["distill_logprobs"] is None
+
+
+@pytest.mark.asyncio
+async def test_attach_student_distillation_zero_topk_passthrough():
+    env = object.__new__(_ConcreteStudentEnv)
+    env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=0)
+    env.server = _FakeStudentServer()
+
+    group = {
+        "tokens": [[1, 2, 3]],
+        "group_overrides": None,
+        "overrides": None,
+        "masks": [[-100, 2, 3]],
+        "scores": [1.0],
+    }
+    out = await StudentDistillationEnv._attach_student_distillation(env, group)
+    assert env.server.calls == 1
+    assert out["distill_token_ids"] is not None
+    assert out["distill_logprobs"] is not None
+
+
+@pytest.mark.asyncio
+async def test_attach_student_distillation_group_override_topk_is_used():
+    env = object.__new__(_ConcreteStudentEnv)
+    env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=0)
+    env.server = _FakeStudentServer()
+
+    group = {
+        "tokens": [[1, 2, 3], [4, 5]],
+        "group_overrides": {"student_top_k": 7},
+        "overrides": None,
+        "masks": [[-100, 2, 3], [-100, 5]],
+        "scores": [1.0, 0.0],
+    }
+    out = await StudentDistillationEnv._attach_student_distillation(env, group)
+    assert env.server.kwargs[0]["top_k"] == 7
+    assert env.server.kwargs[1]["top_k"] == 7
+    assert out["distill_token_ids"] is not None
+    assert out["distill_logprobs"] is not None
+
+
+@pytest.mark.asyncio
+async def test_attach_student_distillation_group_override_can_skip_fetch():
+    env = object.__new__(_ConcreteStudentEnv)
+    env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)
+    env.server = _FakeStudentServer()
+
+    group = {
+        "tokens": [[1, 2, 3]],
+        "group_overrides": {"skip_student_top_k": True},
+        "overrides": None,
+        "masks": [[-100, 2, 3]],
+        "scores": [1.0],
+    }
+    out = await StudentDistillationEnv._attach_student_distillation(env, group)
+    assert env.server.calls == 0
+    assert out["distill_token_ids"] is None
+    assert out["distill_logprobs"] is None
+
+
+def test_get_student_logprob_overrides_merges_group_and_sequence():
+    env = object.__new__(_ConcreteStudentEnv)
+    group = {
+        "group_overrides": {
+            "student_logprob_kwargs": {"temperature": 0.0, "prompt": "group"}
+        },
+        "overrides": [
+            {"student_logprob_kwargs": {"prompt": "seq", "top_p": 1.0}},
+        ],
+    }
+
+    out = StudentDistillationEnv._get_student_logprob_overrides(env, group, 0)
+    assert out == {"temperature": 0.0, "prompt": "seq", "top_p": 1.0}
+
+
+@pytest.mark.asyncio
+async def test_fetch_student_for_sequence_uses_managed_server_for_messages():
+    env = object.__new__(_ConcreteStudentEnv)
+    env.server = _FakeStudentServer(managed_prompt_tokens=[1, 2, 3])
+    env.tokenizer = object()
+
+    out = await StudentDistillationEnv._fetch_student_for_sequence(
+        env,
+        token_ids=[1, 2, 3],
+        top_k=2,
+        extra_kwargs={"messages": [{"role": "user", "content": "hi"}]},
+    )
+
+    assert env.server.calls == 0
+    assert env.server.managed_calls == 1
+    assert env.server.managed.calls == 1
+    assert "input_ids" not in env.server.managed.kwargs[0]
+    assert out[0] == [[1, 2], [2, 3], [3, 4]]
+
+
+@pytest.mark.asyncio
+async def test_fetch_student_for_sequence_mismatch_drops_payload():
+    env = object.__new__(_ConcreteStudentEnv)
+    env.config = SimpleNamespace(student_distill_enabled=True, student_top_k=2)
+    env.server = _FakeStudentServer(managed_prompt_tokens=[9, 9])
+    env.tokenizer = object()
+
+    group = {
+        "tokens": [[1, 2, 3]],
+        "group_overrides": {
+            "student_logprob_kwargs": {"messages": [{"role": "user", "content": "hi"}]}
+        },
+        "overrides": None,
+        "masks": [[-100, 2, 3]],
+        "scores": [1.0],
+    }
+
+    out = await StudentDistillationEnv._attach_student_distillation(env, group)
+    assert out["distill_token_ids"] is None
+    assert out["distill_logprobs"] is None
diff --git a/environments/gsm8k_server_student_distill.py b/environments/gsm8k_server_student_distill.py
new file mode 100644
index 00000000..68d0b3fc
--- /dev/null
+++ b/environments/gsm8k_server_student_distill.py
@@ -0,0 +1,41 @@
+from typing import Tuple
+
+from atroposlib.envs.base import APIServerConfig
+from atroposlib.envs.student_distillation_env import (
+    StudentDistillationConfig,
+    StudentDistillationEnv,
+)
+
+from environments.gsm8k_server import GSM8kEnv
+
+
+class GSM8kStudentDistillEnv(GSM8kEnv, StudentDistillationEnv):
+    name = "gsm8k_student_distill"
+    env_config_cls = StudentDistillationConfig
+
+    @classmethod
+    def config_init(cls) -> Tuple[StudentDistillationConfig, APIServerConfig]:
+        env_config = StudentDistillationConfig(
+            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
+            group_size=8,
+            use_wandb=True,
+            rollout_server_url="http://localhost:8000",
+            total_steps=1000,
+            batch_size=12,
+            steps_per_eval=100,
+            max_token_length=2048,
+            wandb_name="gsm8k_student_distill",
+            student_distill_enabled=True,
+            student_top_k=4,
+        )
+        server_config = APIServerConfig(
+            model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
+            base_url="http://localhost:9001/v1",
+            api_key="x",
+            num_requests_for_eval=256,
+        )
+        return env_config, server_config
+
+
+if __name__ == "__main__":
+    GSM8kStudentDistillEnv.cli()