mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Merge branch 'NousResearch:main' into feat/reward-normalization
This commit is contained in:
commit
39fb1d6870
8 changed files with 1285 additions and 4 deletions
|
|
@ -133,7 +133,7 @@
|
|||
"filename": "README.md",
|
||||
"hashed_secret": "a8253456364f1bfc7da7ae4a1db5b45d106317a5",
|
||||
"is_verified": false,
|
||||
"line_number": 454
|
||||
"line_number": 530
|
||||
}
|
||||
],
|
||||
"SLURM.md": [
|
||||
|
|
@ -561,5 +561,5 @@
|
|||
}
|
||||
]
|
||||
},
|
||||
"generated_at": "2026-03-02T22:46:56Z"
|
||||
"generated_at": "2026-03-14T00:43:09Z"
|
||||
}
|
||||
|
|
|
|||
76
README.md
76
README.md
|
|
@ -298,6 +298,82 @@ curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids!
|
|||
- Trainers should validate alignment assumptions they require (sequence length, per-position top-k, etc.).
|
||||
- Teacher-side architecture and prompt/rendering strategy are intentionally out of scope for this PR.
|
||||
|
||||
### TeacherDistillationEnv follow-up
|
||||
|
||||
The follow-up teacher environment uses a dedicated teacher server config and
|
||||
attaches teacher prompt logprobs before the group is sent to the API.
|
||||
|
||||
Teacher config shape:
|
||||
|
||||
```python
|
||||
TeacherDistillationConfig(
|
||||
teacher_enabled=True,
|
||||
teacher_top_k=8,
|
||||
)
|
||||
```
|
||||
|
||||
Teacher server configs are passed separately at init, just like the primary
|
||||
`server_configs`:
|
||||
|
||||
```python
|
||||
env = MyTeacherEnv(
|
||||
config=env_config,
|
||||
server_configs=student_server_configs,
|
||||
teacher_server_configs=[
|
||||
APIServerConfig(
|
||||
base_url="http://localhost:9003/v1",
|
||||
model_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
|
||||
api_key="",
|
||||
server_type="vllm",
|
||||
tokenizer_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
|
||||
)
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
You can either:
|
||||
|
||||
- build a teacher-enabled env by mixing `TeacherDistillationEnv` into an existing
|
||||
`BaseEnv`-derived env such as `GSM8kEnv`, or
|
||||
- subclass `TeacherDistillationEnv` directly and implement the usual environment
|
||||
methods yourself.
|
||||
|
||||
In both cases, `TeacherDistillationEnv` still assumes the normal `BaseEnv`
|
||||
runtime contract: tokenized rollouts, `ScoredDataGroup` payloads, and the
|
||||
standard `handle_send_to_api(...)` transport path.
|
||||
|
||||
CLI shape:
|
||||
|
||||
```bash
|
||||
--env.teacher_enabled true \
|
||||
--teacher.base_url "http://localhost:9003/v1" \
|
||||
--teacher.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \
|
||||
--teacher.server_type vllm \
|
||||
--env.teacher_top_k 8
|
||||
```
|
||||
|
||||
If `--teacher.model_name` is a deployment alias rather than a tokenizer
|
||||
identifier, also set `--teacher.tokenizer_name ...` so the env can validate
|
||||
tokenizer compatibility.
|
||||
|
||||
Scope note:
|
||||
|
||||
- The teacher-aware CLI wiring currently exists for `serve`.
|
||||
- If `teacher_enabled=True`, the generic `process` and `evaluate` commands will
|
||||
fail loudly at env construction time unless you instantiate the env yourself
|
||||
and pass `teacher_server_configs=...`.
|
||||
|
||||
Tokenizer requirement:
|
||||
|
||||
- Teacher distillation currently requires the teacher and student to use the same tokenizer vocabulary.
|
||||
- If the tokenizers do not match, `TeacherDistillationEnv` raises an error instead of attempting token conversion.
|
||||
|
||||
Why same-tokenizer is required:
|
||||
|
||||
- `distill_token_ids` are consumed as student-vocabulary IDs by the trainer.
|
||||
- If the teacher uses a different vocabulary, the same integer token ID refers to different text on the teacher and student sides.
|
||||
- A decode/re-tokenize/remap pipeline is not a safe drop-in fix because it changes both token positions and token identities, which breaks the exact per-position token supervision that the current distillation loss assumes.
|
||||
|
||||
---
|
||||
|
||||
## Testing and Debugging Tools
|
||||
|
|
|
|||
459
atroposlib/envs/teacher_distillation_env.py
Normal file
459
atroposlib/envs/teacher_distillation_env.py
Normal file
|
|
@ -0,0 +1,459 @@
|
|||
"""
|
||||
Teacher distillation environment layer.
|
||||
|
||||
This module adds teacher prompt-logprob fetching on top of BaseEnv without
|
||||
modifying BaseEnv transport behavior.
|
||||
|
||||
This implementation supports same-tokenizer distillation only. The teacher and
|
||||
student must share the same tokenizer vocabulary so the student's token IDs can
|
||||
be forwarded directly to the teacher and the returned teacher top-k token IDs
|
||||
can be looked up directly in the student's logits.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import Field
|
||||
from pydantic_cli import Cmd
|
||||
from rich import print as rprint
|
||||
|
||||
from ..utils.cli import (
|
||||
extract_namespace,
|
||||
get_double_dash_flags,
|
||||
get_prefixed_pydantic_model,
|
||||
merge_dicts,
|
||||
)
|
||||
from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
|
||||
from .constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
|
||||
from .server_handling.openai_server import resolve_openai_configs
|
||||
from .server_handling.server_baseline import APIServerConfig, ServerBaseline
|
||||
from .server_handling.server_manager import ServerManager, ServerManagerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TeacherDistillationConfig(BaseEnvConfig):
|
||||
teacher_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Whether to fetch teacher prompt logprobs for distillation.",
|
||||
)
|
||||
teacher_top_k: int = Field(
|
||||
default=0,
|
||||
ge=-1,
|
||||
description=(
|
||||
"Number of extra prompt logprobs to fetch beyond the selected token. "
|
||||
"Use 0 for selected-token-only prompt logprobs and <= -1 to disable "
|
||||
"teacher fetching."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TeacherDistillationEnv(BaseEnv, ABC):
|
||||
"""
|
||||
BaseEnv subclass that enriches scored groups with teacher distillation arrays.
|
||||
|
||||
Distillation payload shape:
|
||||
- distill_token_ids: [sequence][position][k] (student vocab IDs)
|
||||
- distill_logprobs: [sequence][position][k]
|
||||
"""
|
||||
|
||||
env_config_cls = TeacherDistillationConfig
|
||||
teacher_namespace = "teacher"
|
||||
|
||||
@classmethod
|
||||
def teacher_config_init(
|
||||
cls,
|
||||
) -> Optional[Union[ServerBaseline, List[APIServerConfig], APIServerConfig]]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _resolve_teacher_server_configs(
|
||||
cls,
|
||||
default_teacher_server_configs: Optional[
|
||||
Union[ServerBaseline, List[APIServerConfig], APIServerConfig]
|
||||
],
|
||||
yaml_config: Dict[str, Any],
|
||||
cli_passed_flags: Dict[str, Any],
|
||||
) -> Optional[Union[ServerBaseline, List[APIServerConfig]]]:
|
||||
teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}"
|
||||
teacher_cli_passed_args = extract_namespace(
|
||||
cli_passed_flags, teacher_full_prefix
|
||||
)
|
||||
yaml_teacher_config = yaml_config.get(cls.teacher_namespace, {})
|
||||
|
||||
if (
|
||||
default_teacher_server_configs is None
|
||||
and not teacher_cli_passed_args
|
||||
and not yaml_teacher_config
|
||||
):
|
||||
return None
|
||||
|
||||
effective_teacher_server_configs = default_teacher_server_configs
|
||||
if effective_teacher_server_configs is None:
|
||||
effective_teacher_server_configs = APIServerConfig()
|
||||
elif isinstance(effective_teacher_server_configs, ServerBaseline) and (
|
||||
teacher_cli_passed_args or yaml_teacher_config
|
||||
):
|
||||
effective_teacher_server_configs = APIServerConfig(
|
||||
**effective_teacher_server_configs.model_dump()
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(effective_teacher_server_configs, list)
|
||||
and len(effective_teacher_server_configs) == 1
|
||||
):
|
||||
default_teacher_config = effective_teacher_server_configs[0]
|
||||
else:
|
||||
default_teacher_config = effective_teacher_server_configs
|
||||
|
||||
if isinstance(yaml_teacher_config, list) and len(yaml_teacher_config) == 1:
|
||||
yaml_teacher_config = yaml_teacher_config[0]
|
||||
|
||||
if isinstance(default_teacher_config, APIServerConfig) and isinstance(
|
||||
yaml_teacher_config, dict
|
||||
):
|
||||
teacher_config_dict = merge_dicts(
|
||||
default_teacher_config.model_dump(),
|
||||
yaml_teacher_config,
|
||||
teacher_cli_passed_args,
|
||||
)
|
||||
else:
|
||||
teacher_config_dict = {}
|
||||
|
||||
teacher_yaml_wrapped = {OPENAI_NAMESPACE: yaml_teacher_config}
|
||||
teacher_cli_wrapped = {
|
||||
f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}{key}": value
|
||||
for key, value in teacher_cli_passed_args.items()
|
||||
}
|
||||
return resolve_openai_configs(
|
||||
default_server_configs=effective_teacher_server_configs,
|
||||
openai_config_dict=teacher_config_dict,
|
||||
yaml_config=teacher_yaml_wrapped,
|
||||
cli_passed_flags=teacher_cli_wrapped,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_cli_serve_config_cls(cls) -> type:
|
||||
default_env_config, default_server_configs = cls.config_init()
|
||||
default_teacher_server_configs = cls.teacher_config_init()
|
||||
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||
teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}"
|
||||
teacher_cli_base = get_prefixed_pydantic_model(
|
||||
APIServerConfig, teacher_full_prefix
|
||||
)
|
||||
|
||||
class CliServeConfig(
|
||||
get_prefixed_pydantic_model(type(default_env_config), env_full_prefix),
|
||||
get_prefixed_pydantic_model(APIServerConfig, openai_full_prefix),
|
||||
teacher_cli_base,
|
||||
ServerManagerConfig,
|
||||
Cmd,
|
||||
):
|
||||
config: str | None = Field(
|
||||
default=None,
|
||||
description="Path to .yaml config file. CLI args override this.",
|
||||
)
|
||||
|
||||
def run(self) -> None:
|
||||
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
|
||||
if (
|
||||
getattr(self, wandb_name_attr, None) is None
|
||||
and cls.name is not None
|
||||
):
|
||||
setattr(self, wandb_name_attr, cls.name)
|
||||
|
||||
if self.config is not None:
|
||||
with open(self.config, "r") as f:
|
||||
yaml_config = yaml.safe_load(f)
|
||||
logger.info("Loaded config from %s", self.config)
|
||||
else:
|
||||
yaml_config = {}
|
||||
|
||||
cli_passed_flags = get_double_dash_flags()
|
||||
|
||||
env_config_dict = merge_dicts(
|
||||
default_env_config.model_dump(),
|
||||
yaml_config.get(ENV_NAMESPACE, {}),
|
||||
extract_namespace(cli_passed_flags, env_full_prefix),
|
||||
)
|
||||
|
||||
oai_cli_passed_args = extract_namespace(
|
||||
cli_passed_flags, openai_full_prefix
|
||||
)
|
||||
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
|
||||
|
||||
effective_server_configs = default_server_configs
|
||||
if isinstance(effective_server_configs, ServerBaseline) and (
|
||||
oai_cli_passed_args or yaml_oai_config
|
||||
):
|
||||
effective_server_configs = APIServerConfig(
|
||||
**effective_server_configs.model_dump()
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(effective_server_configs, list)
|
||||
and len(effective_server_configs) == 1
|
||||
):
|
||||
default_openai_config_ = effective_server_configs[0]
|
||||
else:
|
||||
default_openai_config_ = effective_server_configs
|
||||
|
||||
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
|
||||
yaml_oai_config = yaml_oai_config[0]
|
||||
|
||||
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
||||
yaml_oai_config, dict
|
||||
):
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config_.model_dump(),
|
||||
yaml_oai_config,
|
||||
oai_cli_passed_args,
|
||||
)
|
||||
else:
|
||||
openai_config_dict = {}
|
||||
|
||||
server_manager_cli_passed_flags = {}
|
||||
if "slurm" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
|
||||
if "testing" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["testing"] = cli_passed_flags[
|
||||
"testing"
|
||||
]
|
||||
|
||||
server_manager_yaml_dict = {}
|
||||
if "slurm" in yaml_config:
|
||||
server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
|
||||
if "testing" in yaml_config:
|
||||
server_manager_yaml_dict["testing"] = yaml_config["testing"]
|
||||
|
||||
server_manager_config_dict = merge_dicts(
|
||||
ServerManagerConfig().model_dump(),
|
||||
server_manager_yaml_dict,
|
||||
server_manager_cli_passed_flags,
|
||||
)
|
||||
|
||||
env_config = type(default_env_config)(**env_config_dict)
|
||||
server_manager_config = ServerManagerConfig(
|
||||
**server_manager_config_dict
|
||||
)
|
||||
openai_configs = resolve_openai_configs(
|
||||
default_server_configs=effective_server_configs,
|
||||
openai_config_dict=openai_config_dict,
|
||||
yaml_config=yaml_config,
|
||||
cli_passed_flags=cli_passed_flags,
|
||||
logger=logger,
|
||||
)
|
||||
teacher_configs = cls._resolve_teacher_server_configs(
|
||||
default_teacher_server_configs=default_teacher_server_configs,
|
||||
yaml_config=yaml_config,
|
||||
cli_passed_flags=cli_passed_flags,
|
||||
)
|
||||
|
||||
env_kwargs = {
|
||||
"config": env_config,
|
||||
"server_configs": openai_configs,
|
||||
"slurm": server_manager_config.slurm,
|
||||
"testing": server_manager_config.testing,
|
||||
}
|
||||
if teacher_configs is not None:
|
||||
env_kwargs["teacher_server_configs"] = teacher_configs
|
||||
env = cls(**env_kwargs)
|
||||
rprint(env_config)
|
||||
rprint(openai_configs)
|
||||
if teacher_configs is not None:
|
||||
rprint(teacher_configs)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = loop.create_task(env.env_manager())
|
||||
loop.run_until_complete(task)
|
||||
except RuntimeError:
|
||||
asyncio.run(env.env_manager())
|
||||
|
||||
return CliServeConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TeacherDistillationConfig,
|
||||
server_configs: Union[ServerBaseline, List[APIServerConfig]],
|
||||
teacher_server_configs: Optional[
|
||||
Union[ServerBaseline, APIServerConfig, List[APIServerConfig]]
|
||||
] = None,
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm=slurm, testing=testing)
|
||||
self.teacher_server: Optional[ServerManager] = None
|
||||
|
||||
if config.teacher_enabled:
|
||||
if teacher_server_configs is None:
|
||||
raise ValueError(
|
||||
"teacher_enabled=True but no teacher server configuration was "
|
||||
"provided. Pass teacher_server_configs=... when instantiating "
|
||||
"the environment directly, or use the teacher-aware 'serve' CLI "
|
||||
"path with --teacher.* flags. The generic BaseEnv 'process' and "
|
||||
"'evaluate' commands do not currently wire teacher_server_configs."
|
||||
)
|
||||
if isinstance(teacher_server_configs, APIServerConfig):
|
||||
teacher_config_source = [teacher_server_configs]
|
||||
else:
|
||||
teacher_config_source = teacher_server_configs
|
||||
self.teacher_server = ServerManager(
|
||||
teacher_config_source,
|
||||
slurm=False,
|
||||
testing=False,
|
||||
)
|
||||
if isinstance(teacher_config_source, list):
|
||||
teacher_cfg = teacher_config_source[0]
|
||||
else:
|
||||
teacher_cfg = teacher_config_source
|
||||
|
||||
teacher_tokenizer_name = (
|
||||
teacher_cfg.model_name
|
||||
if getattr(teacher_cfg, "tokenizer_name", "none") in ("", "none")
|
||||
else teacher_cfg.tokenizer_name
|
||||
)
|
||||
self._validate_teacher_tokenizer_compatibility(teacher_tokenizer_name)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core fetch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _validate_teacher_tokenizer_compatibility(
|
||||
self, teacher_tokenizer_name: str
|
||||
) -> None:
|
||||
student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
|
||||
if student_tok_name == teacher_tokenizer_name:
|
||||
return
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
teacher_tokenizer = AutoTokenizer.from_pretrained(
|
||||
teacher_tokenizer_name, use_fast=True
|
||||
)
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
"Cross-tokenizer distillation is not supported in this PR, and the "
|
||||
f"teacher tokenizer for '{teacher_tokenizer_name}' could not be loaded to "
|
||||
f"verify compatibility: {exc}"
|
||||
) from exc
|
||||
|
||||
student_vocab = self.tokenizer.get_vocab()
|
||||
teacher_vocab = teacher_tokenizer.get_vocab()
|
||||
if student_vocab != teacher_vocab:
|
||||
raise ValueError(
|
||||
"Cross-tokenizer distillation is not supported in this PR. "
|
||||
f"Student tokenizer '{student_tok_name or type(self.tokenizer).__name__}' "
|
||||
f"and teacher tokenizer '{teacher_tokenizer_name}' do not match."
|
||||
)
|
||||
|
||||
async def _fetch_teacher_for_sequence(
|
||||
self, token_ids: List[int], top_k: int
|
||||
) -> Tuple[List[List[int]], List[List[float]]]:
|
||||
assert self.teacher_server is not None
|
||||
payload = await self.teacher_server.get_logprobs(
|
||||
input_ids=token_ids,
|
||||
top_k=top_k,
|
||||
max_tokens=1,
|
||||
split="train",
|
||||
)
|
||||
return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Group enrichment
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _attach_teacher_distillation(
|
||||
self, group: ScoredDataGroup
|
||||
) -> ScoredDataGroup:
|
||||
if not self.config.teacher_enabled or self.teacher_server is None:
|
||||
return group
|
||||
|
||||
seqs = group.get("tokens", [])
|
||||
if not seqs:
|
||||
group["distill_token_ids"] = None
|
||||
group["distill_logprobs"] = None
|
||||
return group
|
||||
|
||||
group_overrides = group.get("group_overrides") or {}
|
||||
if group_overrides.get("skip_teacher_top_k", False):
|
||||
group["distill_token_ids"] = None
|
||||
group["distill_logprobs"] = None
|
||||
return group
|
||||
|
||||
top_k = int(group_overrides.get("teacher_top_k", self.config.teacher_top_k))
|
||||
if top_k <= -1:
|
||||
group["distill_token_ids"] = None
|
||||
group["distill_logprobs"] = None
|
||||
return group
|
||||
|
||||
tasks = [self._fetch_teacher_for_sequence(seq, top_k) for seq in seqs]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
distill_token_ids: List[List[List[int]]] = []
|
||||
distill_logprobs: List[List[List[float]]] = []
|
||||
for idx, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(
|
||||
"Teacher logprob fetch failed for seq %s: %s. "
|
||||
"Dropping distill payload for this group.",
|
||||
idx,
|
||||
result,
|
||||
)
|
||||
group["distill_token_ids"] = None
|
||||
group["distill_logprobs"] = None
|
||||
return group
|
||||
token_ids_k, logprobs_k = result
|
||||
if len(token_ids_k) != len(logprobs_k):
|
||||
logger.warning(
|
||||
"Teacher prompt-topk length mismatch for seq %s (%s != %s). "
|
||||
"Dropping distill payload for this group.",
|
||||
idx,
|
||||
len(token_ids_k),
|
||||
len(logprobs_k),
|
||||
)
|
||||
group["distill_token_ids"] = None
|
||||
group["distill_logprobs"] = None
|
||||
return group
|
||||
distill_token_ids.append(token_ids_k)
|
||||
distill_logprobs.append(logprobs_k)
|
||||
|
||||
group["distill_token_ids"] = distill_token_ids
|
||||
group["distill_logprobs"] = distill_logprobs
|
||||
return group
|
||||
|
||||
async def handle_send_to_api(
|
||||
self,
|
||||
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
|
||||
item: Any = None,
|
||||
do_send_to_api: bool = True,
|
||||
abort_on_any_max_length_exceeded: bool = True,
|
||||
):
|
||||
groups = scored_data if isinstance(scored_data, list) else [scored_data]
|
||||
enriched_groups: List[ScoredDataGroup] = []
|
||||
for group in groups:
|
||||
if group is None:
|
||||
continue
|
||||
enriched_groups.append(await self._attach_teacher_distillation(group))
|
||||
|
||||
payload: Union[ScoredDataGroup, List[ScoredDataGroup]]
|
||||
if isinstance(scored_data, list):
|
||||
payload = enriched_groups
|
||||
else:
|
||||
payload = enriched_groups[0] if enriched_groups else scored_data
|
||||
|
||||
return await super().handle_send_to_api(
|
||||
payload,
|
||||
item=item,
|
||||
do_send_to_api=do_send_to_api,
|
||||
abort_on_any_max_length_exceeded=abort_on_any_max_length_exceeded,
|
||||
)
|
||||
344
atroposlib/tests/test_teacher_distillation_env.py
Normal file
344
atroposlib/tests/test_teacher_distillation_env.py
Normal file
|
|
@ -0,0 +1,344 @@
|
|||
"""Tests for TeacherDistillationEnv distillation enrichment."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
||||
from atroposlib.envs.teacher_distillation_env import TeacherDistillationEnv
|
||||
|
||||
|
||||
class _FakeTeacherServer:
|
||||
def __init__(self, fail_on_call: int = -1):
|
||||
self.calls = 0
|
||||
self.fail_on_call = fail_on_call
|
||||
|
||||
async def get_logprobs(self, **kwargs):
|
||||
self.calls += 1
|
||||
if self.calls == self.fail_on_call:
|
||||
raise RuntimeError("teacher backend failure")
|
||||
seq = kwargs["input_ids"]
|
||||
return {
|
||||
"prompt_tokens": seq,
|
||||
"prompt_topk_token_ids": [[tok, tok + 1] for tok in seq],
|
||||
"prompt_topk_logprobs": [[-0.1, -0.2] for _ in seq],
|
||||
}
|
||||
|
||||
|
||||
class _ConcreteTeacherEnv(TeacherDistillationEnv):
|
||||
async def get_next_item(self):
|
||||
return None
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
class _DummyTokenizer:
|
||||
name_or_path = "student-model"
|
||||
|
||||
def get_vocab(self):
|
||||
return {"a": 1}
|
||||
|
||||
|
||||
class _CapturingServerManager:
|
||||
def __init__(self, configs, slurm=False, testing=False):
|
||||
self.configs = configs
|
||||
self.slurm = slurm
|
||||
self.testing = testing
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_success():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
|
||||
env.teacher_server = _FakeTeacherServer()
|
||||
|
||||
group = {
|
||||
"tokens": [[1, 2, 3], [4, 5]],
|
||||
"group_overrides": None,
|
||||
"masks": [[-100, 2, 3], [-100, 5]],
|
||||
"scores": [1.0, 0.0],
|
||||
}
|
||||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert out["distill_token_ids"] is not None
|
||||
assert out["distill_logprobs"] is not None
|
||||
assert len(out["distill_token_ids"]) == 2
|
||||
assert len(out["distill_token_ids"][0]) == 3
|
||||
assert len(out["distill_logprobs"][1]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_failure_drops_payload():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
|
||||
env.teacher_server = _FakeTeacherServer(fail_on_call=2)
|
||||
|
||||
group = {
|
||||
"tokens": [[1, 2, 3], [4, 5]],
|
||||
"group_overrides": None,
|
||||
"masks": [[-100, 2, 3], [-100, 5]],
|
||||
"scores": [1.0, 0.0],
|
||||
}
|
||||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert out["distill_token_ids"] is None
|
||||
assert out["distill_logprobs"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_negative_topk_skips_fetch():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=-1)
|
||||
env.teacher_server = _FakeTeacherServer()
|
||||
|
||||
group = {
|
||||
"tokens": [[1, 2, 3]],
|
||||
"group_overrides": None,
|
||||
"masks": [[-100, 2, 3]],
|
||||
"scores": [1.0],
|
||||
}
|
||||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert env.teacher_server.calls == 0
|
||||
assert out["distill_token_ids"] is None
|
||||
assert out["distill_logprobs"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_zero_topk_passthrough():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)
|
||||
env.teacher_server = _FakeTeacherServer()
|
||||
|
||||
group = {
|
||||
"tokens": [[1, 2, 3]],
|
||||
"group_overrides": None,
|
||||
"masks": [[-100, 2, 3]],
|
||||
"scores": [1.0],
|
||||
}
|
||||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert env.teacher_server.calls == 1
|
||||
assert out["distill_token_ids"] is not None
|
||||
assert out["distill_logprobs"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_group_override_topk_is_used():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)
|
||||
|
||||
seen_topks = []
|
||||
|
||||
async def _fake_fetch(seq, top_k):
|
||||
seen_topks.append(top_k)
|
||||
return [[tok] for tok in seq], [[-0.1] for _ in seq]
|
||||
|
||||
env.teacher_server = object()
|
||||
env._fetch_teacher_for_sequence = _fake_fetch
|
||||
|
||||
group = {
|
||||
"tokens": [[1, 2, 3], [4, 5]],
|
||||
"group_overrides": {"teacher_top_k": 7},
|
||||
"masks": [[-100, 2, 3], [-100, 5]],
|
||||
"scores": [1.0, 0.0],
|
||||
}
|
||||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert seen_topks == [7, 7]
|
||||
assert out["distill_token_ids"] is not None
|
||||
assert out["distill_logprobs"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_group_override_can_skip_fetch():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
|
||||
env.teacher_server = _FakeTeacherServer()
|
||||
|
||||
group = {
|
||||
"tokens": [[1, 2, 3]],
|
||||
"group_overrides": {"skip_teacher_top_k": True},
|
||||
"masks": [[-100, 2, 3]],
|
||||
"scores": [1.0],
|
||||
}
|
||||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert env.teacher_server.calls == 0
|
||||
assert out["distill_token_ids"] is None
|
||||
assert out["distill_logprobs"] is None
|
||||
|
||||
|
||||
def test_teacher_tokenizer_mismatch_raises(monkeypatch):
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
|
||||
class _StudentTokenizer:
|
||||
name_or_path = "student-model"
|
||||
|
||||
def get_vocab(self):
|
||||
return {"a": 1}
|
||||
|
||||
class _TeacherTokenizer:
|
||||
def get_vocab(self):
|
||||
return {"b": 1}
|
||||
|
||||
env.tokenizer = _StudentTokenizer()
|
||||
monkeypatch.setattr(
|
||||
"transformers.AutoTokenizer.from_pretrained",
|
||||
lambda *args, **kwargs: _TeacherTokenizer(),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Cross-tokenizer distillation is not supported"
|
||||
):
|
||||
TeacherDistillationEnv._validate_teacher_tokenizer_compatibility(
|
||||
env,
|
||||
teacher_tokenizer_name="teacher-model",
|
||||
)
|
||||
|
||||
|
||||
def test_init_requires_teacher_server_source(monkeypatch):
|
||||
from atroposlib.envs import teacher_distillation_env as module
|
||||
|
||||
def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
|
||||
self.config = config
|
||||
self.tokenizer = _DummyTokenizer()
|
||||
|
||||
monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
|
||||
|
||||
config = SimpleNamespace(
|
||||
teacher_enabled=True,
|
||||
teacher_top_k=0,
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError, match="no teacher server configuration was provided"
|
||||
):
|
||||
_ConcreteTeacherEnv(
|
||||
config=config,
|
||||
server_configs=[],
|
||||
)
|
||||
|
||||
|
||||
def test_init_uses_explicit_teacher_server_configs(monkeypatch):
|
||||
from atroposlib.envs import teacher_distillation_env as module
|
||||
|
||||
called = {}
|
||||
|
||||
def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
|
||||
self.config = config
|
||||
self.tokenizer = _DummyTokenizer()
|
||||
|
||||
def _fake_validate(self, teacher_tokenizer_name):
|
||||
called["teacher_tokenizer_name"] = teacher_tokenizer_name
|
||||
|
||||
monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
|
||||
monkeypatch.setattr(module, "ServerManager", _CapturingServerManager)
|
||||
monkeypatch.setattr(
|
||||
_ConcreteTeacherEnv,
|
||||
"_validate_teacher_tokenizer_compatibility",
|
||||
_fake_validate,
|
||||
)
|
||||
|
||||
explicit_cfg = APIServerConfig(
|
||||
model_name="explicit-model",
|
||||
tokenizer_name="explicit-tokenizer",
|
||||
base_url="http://explicit/v1",
|
||||
api_key="x",
|
||||
server_type="vllm",
|
||||
)
|
||||
config = SimpleNamespace(
|
||||
teacher_enabled=True,
|
||||
teacher_top_k=0,
|
||||
)
|
||||
|
||||
env = _ConcreteTeacherEnv(
|
||||
config=config,
|
||||
server_configs=[],
|
||||
teacher_server_configs=[explicit_cfg],
|
||||
)
|
||||
|
||||
assert isinstance(env.teacher_server, _CapturingServerManager)
|
||||
assert env.teacher_server.configs == [explicit_cfg]
|
||||
assert called["teacher_tokenizer_name"] == "explicit-tokenizer"
|
||||
|
||||
|
||||
def test_init_wraps_bare_teacher_api_server_config(monkeypatch):
|
||||
from atroposlib.envs import teacher_distillation_env as module
|
||||
|
||||
called = {}
|
||||
|
||||
def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
|
||||
self.config = config
|
||||
self.tokenizer = _DummyTokenizer()
|
||||
|
||||
def _fake_validate(self, teacher_tokenizer_name):
|
||||
called["teacher_tokenizer_name"] = teacher_tokenizer_name
|
||||
|
||||
monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
|
||||
monkeypatch.setattr(module, "ServerManager", _CapturingServerManager)
|
||||
monkeypatch.setattr(
|
||||
_ConcreteTeacherEnv,
|
||||
"_validate_teacher_tokenizer_compatibility",
|
||||
_fake_validate,
|
||||
)
|
||||
|
||||
explicit_cfg = APIServerConfig(
|
||||
model_name="explicit-model",
|
||||
tokenizer_name="explicit-tokenizer",
|
||||
base_url="http://explicit/v1",
|
||||
api_key="x",
|
||||
server_type="vllm",
|
||||
)
|
||||
config = SimpleNamespace(
|
||||
teacher_enabled=True,
|
||||
teacher_top_k=0,
|
||||
)
|
||||
|
||||
env = _ConcreteTeacherEnv(
|
||||
config=config,
|
||||
server_configs=[],
|
||||
teacher_server_configs=explicit_cfg,
|
||||
)
|
||||
|
||||
assert isinstance(env.teacher_server, _CapturingServerManager)
|
||||
assert env.teacher_server.configs == [explicit_cfg]
|
||||
assert called["teacher_tokenizer_name"] == "explicit-tokenizer"
|
||||
|
||||
|
||||
def test_resolve_teacher_server_configs_returns_none_when_unset():
|
||||
assert (
|
||||
_ConcreteTeacherEnv._resolve_teacher_server_configs(
|
||||
default_teacher_server_configs=None,
|
||||
yaml_config={},
|
||||
cli_passed_flags={},
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_teacher_server_configs_uses_teacher_namespace(monkeypatch):
|
||||
from atroposlib.envs import teacher_distillation_env as module
|
||||
|
||||
captured = {}
|
||||
|
||||
def _fake_resolve(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return ["resolved"]
|
||||
|
||||
monkeypatch.setattr(module, "resolve_openai_configs", _fake_resolve)
|
||||
|
||||
default_cfg = APIServerConfig(
|
||||
model_name="teacher-model",
|
||||
base_url="http://teacher/v1",
|
||||
api_key="x",
|
||||
server_type="vllm",
|
||||
)
|
||||
|
||||
out = _ConcreteTeacherEnv._resolve_teacher_server_configs(
|
||||
default_teacher_server_configs=default_cfg,
|
||||
yaml_config={"teacher": {"tokenizer_name": "teacher-tokenizer"}},
|
||||
cli_passed_flags={"teacher.base_url": "http://override/v1"},
|
||||
)
|
||||
|
||||
assert out == ["resolved"]
|
||||
assert captured["openai_config_dict"]["base_url"] == "http://override/v1"
|
||||
assert captured["openai_config_dict"]["tokenizer_name"] == "teacher-tokenizer"
|
||||
assert captured["yaml_config"] == {
|
||||
"openai": {"tokenizer_name": "teacher-tokenizer"}
|
||||
}
|
||||
assert captured["cli_passed_flags"] == {"openai.base_url": "http://override/v1"}
|
||||
58
environments/gsm8k_server_teacher_distill.py
Normal file
58
environments/gsm8k_server_teacher_distill.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
from typing import Tuple
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, ServerBaseline
|
||||
from atroposlib.envs.teacher_distillation_env import (
|
||||
TeacherDistillationConfig,
|
||||
TeacherDistillationEnv,
|
||||
)
|
||||
from environments.gsm8k_server import GSM8kEnv
|
||||
|
||||
|
||||
class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv):
|
||||
"""
|
||||
GSM8K environment variant that enables TeacherDistillationEnv config fields.
|
||||
|
||||
This preserves the original `gsm8k_server.py` while providing a separate entrypoint
|
||||
for teacher-distillation data collection.
|
||||
"""
|
||||
|
||||
name = "gsm8k_teacher_distill"
|
||||
env_config_cls = TeacherDistillationConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[TeacherDistillationConfig, ServerBaseline]:
|
||||
env_config = TeacherDistillationConfig(
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=12,
|
||||
steps_per_eval=100,
|
||||
max_token_length=2048,
|
||||
wandb_name="gsm8k_teacher_distill",
|
||||
teacher_enabled=True,
|
||||
teacher_top_k=4,
|
||||
)
|
||||
server_config = APIServerConfig(
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
)
|
||||
return env_config, server_config
|
||||
|
||||
@classmethod
|
||||
def teacher_config_init(cls) -> APIServerConfig:
|
||||
return APIServerConfig(
|
||||
base_url="http://localhost:9003/v1",
|
||||
model_name="mock-teacher",
|
||||
api_key="",
|
||||
server_type="vllm",
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
timeout=1200,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GSM8kTeacherDistillEnv.cli()
|
||||
|
|
@ -304,6 +304,38 @@ environment uses the `/generate` path and includes token-level
|
|||
4. Trainer extracts and aligns logprobs with training labels
|
||||
5. GRPO loss uses these rollout logprobs in importance-ratio terms
|
||||
|
||||
### 1b. Teacher distillation requires the same tokenizer
|
||||
|
||||
When distillation data is attached to Atropos batches, the trainer treats
|
||||
`distill_token_ids` as indices into the student's logit tensor. That only works
|
||||
if the teacher and student share the same tokenizer vocabulary.
|
||||
|
||||
What to configure on the environment side:
|
||||
|
||||
```bash
|
||||
--env.teacher_enabled true \
|
||||
--teacher.base_url "http://localhost:9003/v1" \
|
||||
--teacher.model_name "$TEACHER_MODEL" \
|
||||
--teacher.server_type vllm \
|
||||
--env.teacher_top_k 8
|
||||
```
|
||||
|
||||
If `$TEACHER_MODEL` is a deployment alias instead of a tokenizer identifier,
|
||||
also set `--teacher.tokenizer_name ...` so the env can validate
|
||||
tokenizer compatibility.
|
||||
|
||||
The teacher-aware CLI path is currently wired for `serve`. If
|
||||
`teacher_enabled=True`, the generic `process` and `evaluate` commands are not
|
||||
teacher-aware and will fail loudly unless the environment is instantiated
|
||||
manually with `teacher_server_configs=...`.
|
||||
|
||||
Why cross-tokenizer conversion is not acceptable here:
|
||||
|
||||
- Teacher token ID `1234` and student token ID `1234` can correspond to different text.
|
||||
- Re-tokenizing teacher text changes token boundaries, so teacher position `i` may no longer correspond to student position `i`.
|
||||
- Remapping teacher top-k tokens back into student vocab can collapse multiple teacher candidates into one student token or expand one teacher token into multiple student tokens.
|
||||
- The current distillation loss expects exact per-position supervision in student token space, so an approximate remapping would silently produce misleading targets.
|
||||
|
||||
### 2. Clipping
|
||||
|
||||
```bash
|
||||
|
|
|
|||
|
|
@ -248,8 +248,7 @@ run_shared_vllm() {
|
|||
--port "$vllm_port" \
|
||||
--gpu-memory-utilization "$SHARED_GPU_MEMORY_UTILIZATION" \
|
||||
--max-model-len "$MAX_MODEL_LEN" \
|
||||
--dtype "$DTYPE" \
|
||||
--enforce-eager
|
||||
--dtype "$DTYPE"
|
||||
if [[ "$DRY_RUN" == "1" ]]; then
|
||||
log "[DRY RUN] wait for http://localhost:${vllm_port}/health"
|
||||
else
|
||||
|
|
|
|||
313
example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
Executable file
313
example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
Executable file
|
|
@ -0,0 +1,313 @@
|
|||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Single-terminal teacher-distillation runner.
|
||||
# Starts everything in the background from ONE shell that has GPU access:
|
||||
# 1) Atropos API
|
||||
# 2) Student vLLM server
|
||||
# 3) Teacher vLLM server
|
||||
# 4) GSM8K teacher-distill environment
|
||||
# 5) Example trainer (foreground)
|
||||
#
|
||||
# Usage:
|
||||
# chmod +x example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
|
||||
# ./example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
|
||||
#
|
||||
# Optional overrides:
|
||||
# STUDENT_MODEL="Qwen/Qwen3-4B-Instruct-2507-FP8"
|
||||
# TEACHER_MODEL="Qwen/Qwen3-30B-A3B-Instruct-2507"
|
||||
# STUDENT_GPUS="0"
|
||||
# TEACHER_GPUS="4,5,6,7"
|
||||
# TRAINER_GPUS="0"
|
||||
# STUDENT_TP=1
|
||||
# TEACHER_TP=4
|
||||
# API_PORT=8002
|
||||
# STUDENT_PORT=9001
|
||||
# TEACHER_PORT=9003
|
||||
# TRAINING_STEPS=100
|
||||
# DISTILL_COEF=0.2
|
||||
# DISTILL_TEMPERATURE=1.0
|
||||
# TEACHER_TOP_K=8
|
||||
# DRY_RUN=1
|
||||
|
||||
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
LAUNCH_DIR="$PWD"
|
||||
cd "$ROOT_DIR"
|
||||
|
||||
PYTHON_BIN="${PYTHON_BIN:-python3}"
|
||||
STUDENT_MODEL="${STUDENT_MODEL:-Qwen/Qwen3-4B}"
|
||||
TEACHER_MODEL="${TEACHER_MODEL:-Qwen/Qwen3-30B-A3B-Instruct-2507}"
|
||||
|
||||
STUDENT_GPUS="${STUDENT_GPUS:-0}"
|
||||
TEACHER_GPUS="${TEACHER_GPUS:-4,5,6,7}"
|
||||
TRAINER_GPUS="${TRAINER_GPUS:-$STUDENT_GPUS}"
|
||||
|
||||
STUDENT_TP="${STUDENT_TP:-1}"
|
||||
TEACHER_TP="${TEACHER_TP:-4}"
|
||||
|
||||
API_PORT="${API_PORT:-8002}"
|
||||
STUDENT_PORT="${STUDENT_PORT:-9001}"
|
||||
TEACHER_PORT="${TEACHER_PORT:-9003}"
|
||||
|
||||
TRAINING_STEPS="${TRAINING_STEPS:-20}"
|
||||
BATCH_SIZE="${BATCH_SIZE:-1}"
|
||||
GRAD_ACCUM="${GRAD_ACCUM:-4}"
|
||||
LR="${LR:-1e-5}"
|
||||
WARMUP_STEPS="${WARMUP_STEPS:-0}"
|
||||
CLIP_EPS="${CLIP_EPS:-0.2}"
|
||||
MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}"
|
||||
TEACHER_MAX_MODEL_LEN="${TEACHER_MAX_MODEL_LEN:-32768}"
|
||||
# Trainer seq_len must be larger than ENV_MAX_TOKEN_LENGTH to accommodate
|
||||
# chat template overhead (~400-800 tokens for Qwen3 thinking format).
|
||||
TRAINER_SEQ_LEN="${TRAINER_SEQ_LEN:-20480}"
|
||||
ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-16384}"
|
||||
DISTILL_COEF="${DISTILL_COEF:-0.2}"
|
||||
DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}"
|
||||
TEACHER_TOP_K="${TEACHER_TOP_K:-8}"
|
||||
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-gsm8k-teacher-distill}"
|
||||
WANDB_GROUP="${WANDB_GROUP:-}"
|
||||
|
||||
STUDENT_GPU_MEMORY_UTILIZATION="${STUDENT_GPU_MEMORY_UTILIZATION:-0.60}"
|
||||
TEACHER_GPU_MEMORY_UTILIZATION="${TEACHER_GPU_MEMORY_UTILIZATION:-0.85}"
|
||||
DTYPE="${DTYPE:-bfloat16}"
|
||||
SAVE_DIR="${SAVE_DIR:-${LAUNCH_DIR}/saves/gsm8k_teacher_distill}"
|
||||
LOG_DIR="${LOG_DIR:-${LAUNCH_DIR}/logs/gsm8k_teacher_distill}"
|
||||
BRIDGE_DIR="${BRIDGE_DIR:-${LOG_DIR}/bridge}"
|
||||
DRY_RUN="${DRY_RUN:-0}"
|
||||
|
||||
ENV_GROUP_SIZE="${ENV_GROUP_SIZE:-4}"
|
||||
ENV_BATCH_SIZE="${ENV_BATCH_SIZE:-8}"
|
||||
ENV_TOTAL_STEPS="${ENV_TOTAL_STEPS:-200}"
|
||||
ENV_STEPS_PER_EVAL="${ENV_STEPS_PER_EVAL:-50}"
|
||||
ENV_MAX_WORKERS_PER_NODE="${ENV_MAX_WORKERS_PER_NODE:-1}"
|
||||
ENV_WORKER_TIMEOUT="${ENV_WORKER_TIMEOUT:-1800}"
|
||||
|
||||
RUN_PIDS=()
|
||||
RUN_PORTS=()
|
||||
|
||||
log() {
|
||||
echo "[$(date '+%H:%M:%S')] $*"
|
||||
}
|
||||
|
||||
kill_port() {
|
||||
local port="$1"
|
||||
if [[ "$DRY_RUN" == "1" ]]; then
|
||||
log "[DRY RUN] skip port cleanup for :${port}"
|
||||
return 0
|
||||
fi
|
||||
if lsof -i ":${port}" -sTCP:LISTEN >/dev/null 2>&1; then
|
||||
lsof -ti ":${port}" | xargs -r kill -9 || true
|
||||
fi
|
||||
}
|
||||
|
||||
wait_for_http() {
|
||||
local url="$1"
|
||||
local timeout="${2:-240}"
|
||||
local name="${3:-endpoint}"
|
||||
local start
|
||||
start="$(date +%s)"
|
||||
while true; do
|
||||
if curl -fsS "$url" >/dev/null 2>&1; then
|
||||
log "Ready: ${name} (${url})"
|
||||
return 0
|
||||
fi
|
||||
if (( "$(date +%s)" - start > timeout )); then
|
||||
log "Timeout waiting for ${name}: ${url}"
|
||||
return 1
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
}
|
||||
|
||||
start_process() {
|
||||
local name="$1"
|
||||
local logfile="$2"
|
||||
shift 2
|
||||
if [[ "$DRY_RUN" == "1" ]]; then
|
||||
log "[DRY RUN] start ${name} (log: ${logfile})"
|
||||
printf ' '
|
||||
printf '%q ' "$@"
|
||||
printf '\n'
|
||||
return 0
|
||||
fi
|
||||
log "Starting ${name} (log: ${logfile})"
|
||||
"$@" >"$logfile" 2>&1 &
|
||||
local pid=$!
|
||||
RUN_PIDS+=("$pid")
|
||||
log "${name} PID=${pid}"
|
||||
}
|
||||
|
||||
cleanup_all() {
|
||||
log "Cleaning up processes..."
|
||||
for pid in "${RUN_PIDS[@]:-}"; do
|
||||
kill "$pid" >/dev/null 2>&1 || true
|
||||
done
|
||||
sleep 1
|
||||
for pid in "${RUN_PIDS[@]:-}"; do
|
||||
kill -9 "$pid" >/dev/null 2>&1 || true
|
||||
done
|
||||
for port in "${RUN_PORTS[@]:-}"; do
|
||||
kill_port "$port"
|
||||
done
|
||||
}
|
||||
|
||||
trap cleanup_all EXIT INT TERM
|
||||
|
||||
mkdir -p "$LOG_DIR" "$SAVE_DIR" "$BRIDGE_DIR"
|
||||
RUN_PORTS+=("$API_PORT" "$STUDENT_PORT" "$TEACHER_PORT")
|
||||
kill_port "$API_PORT"
|
||||
kill_port "$STUDENT_PORT"
|
||||
kill_port "$TEACHER_PORT"
|
||||
|
||||
log "Config:"
|
||||
log " student=${STUDENT_MODEL}"
|
||||
log " teacher=${TEACHER_MODEL}"
|
||||
log " gpus student=${STUDENT_GPUS}, teacher=${TEACHER_GPUS}, trainer=${TRAINER_GPUS}"
|
||||
log " ports api=${API_PORT}, student=${STUDENT_PORT}, teacher=${TEACHER_PORT}"
|
||||
log " logs=${LOG_DIR}"
|
||||
log " saves=${SAVE_DIR}"
|
||||
log " bridge=${BRIDGE_DIR}"
|
||||
log " env max_token_length=${ENV_MAX_TOKEN_LENGTH}, env workers=${ENV_MAX_WORKERS_PER_NODE}, env worker_timeout=${ENV_WORKER_TIMEOUT}"
|
||||
log " wandb project=${WANDB_PROJECT}${WANDB_GROUP:+, group=${WANDB_GROUP}}"
|
||||
|
||||
# Shared-vLLM attach path currently expects the student server to expose
|
||||
# unsharded weights. Keep the student on TP=1 and the trainer on the same GPU set.
|
||||
if [[ "$STUDENT_TP" != "1" ]]; then
|
||||
log "ERROR: shared_vllm teacher-distill runner currently requires STUDENT_TP=1."
|
||||
log " The current attach path does not support TP-sharded student bridge weights."
|
||||
exit 2
|
||||
fi
|
||||
|
||||
if [[ "$TRAINER_GPUS" != "$STUDENT_GPUS" ]]; then
|
||||
log "ERROR: TRAINER_GPUS must match STUDENT_GPUS for shared_vllm mode."
|
||||
log " Got student=${STUDENT_GPUS}, trainer=${TRAINER_GPUS}"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
# 1) Atropos API
|
||||
start_process "run_api" "${LOG_DIR}/run_api.log" \
|
||||
run-api --port "$API_PORT"
|
||||
if [[ "$DRY_RUN" == "0" ]]; then
|
||||
wait_for_http "http://localhost:${API_PORT}/info" 180 "run-api"
|
||||
fi
|
||||
|
||||
# 2) Student vLLM server
|
||||
start_process "student_vllm" "${LOG_DIR}/student_vllm.log" \
|
||||
env CUDA_VISIBLE_DEVICES="$STUDENT_GPUS" VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR="$BRIDGE_DIR" \
|
||||
"$PYTHON_BIN" -m example_trainer.vllm_api_server \
|
||||
--model "$STUDENT_MODEL" \
|
||||
--port "$STUDENT_PORT" \
|
||||
--tensor-parallel-size "$STUDENT_TP" \
|
||||
--gpu-memory-utilization "$STUDENT_GPU_MEMORY_UTILIZATION" \
|
||||
--max-model-len "$MAX_MODEL_LEN" \
|
||||
--dtype "$DTYPE"
|
||||
if [[ "$DRY_RUN" == "0" ]]; then
|
||||
wait_for_http "http://localhost:${STUDENT_PORT}/health" 420 "student vLLM"
|
||||
fi
|
||||
|
||||
# 3) Teacher vLLM server
|
||||
start_process "teacher_vllm" "${LOG_DIR}/teacher_vllm.log" \
|
||||
env CUDA_VISIBLE_DEVICES="$TEACHER_GPUS" \
|
||||
"$PYTHON_BIN" -m example_trainer.vllm_api_server \
|
||||
--model "$TEACHER_MODEL" \
|
||||
--port "$TEACHER_PORT" \
|
||||
--tensor-parallel-size "$TEACHER_TP" \
|
||||
--gpu-memory-utilization "$TEACHER_GPU_MEMORY_UTILIZATION" \
|
||||
--max-model-len "$TEACHER_MAX_MODEL_LEN" \
|
||||
--dtype "$DTYPE"
|
||||
if [[ "$DRY_RUN" == "0" ]]; then
|
||||
wait_for_http "http://localhost:${TEACHER_PORT}/health" 1800 "teacher vLLM"
|
||||
fi
|
||||
|
||||
# 4) Teacher-distill GSM8K env
|
||||
start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \
|
||||
"$PYTHON_BIN" environments/gsm8k_server_teacher_distill.py serve \
|
||||
--env.group_size "$ENV_GROUP_SIZE" \
|
||||
--env.batch_size "$ENV_BATCH_SIZE" \
|
||||
--env.total_steps "$ENV_TOTAL_STEPS" \
|
||||
--env.steps_per_eval "$ENV_STEPS_PER_EVAL" \
|
||||
--env.max_num_workers_per_node "$ENV_MAX_WORKERS_PER_NODE" \
|
||||
--env.max_token_length "$ENV_MAX_TOKEN_LENGTH" \
|
||||
--env.worker_timeout "$ENV_WORKER_TIMEOUT" \
|
||||
--env.rollout_server_url "http://localhost:${API_PORT}" \
|
||||
--env.use_wandb true \
|
||||
--env.wandb_name "gsm8k-teacher-distill" \
|
||||
--env.teacher_enabled true \
|
||||
--teacher.base_url "http://localhost:${TEACHER_PORT}/v1" \
|
||||
--teacher.model_name "$TEACHER_MODEL" \
|
||||
--teacher.server_type vllm \
|
||||
--env.teacher_top_k "$TEACHER_TOP_K" \
|
||||
--env.ensure_scores_are_not_same false \
|
||||
--openai.api_key "dummy" \
|
||||
--openai.base_url "http://localhost:${STUDENT_PORT}/v1" \
|
||||
--openai.model_name "$STUDENT_MODEL" \
|
||||
--openai.tokenizer_name "$STUDENT_MODEL" \
|
||||
--openai.server_type vllm
|
||||
|
||||
log "All services launched."
|
||||
log "Run logs:"
|
||||
log " ${LOG_DIR}/run_api.log"
|
||||
log " ${LOG_DIR}/student_vllm.log"
|
||||
log " ${LOG_DIR}/teacher_vllm.log"
|
||||
log " ${LOG_DIR}/env.log"
|
||||
|
||||
# 5) Trainer (background)
|
||||
TRAINER_CMD=(
|
||||
env
|
||||
CUDA_VISIBLE_DEVICES="$TRAINER_GPUS"
|
||||
PYTHONUNBUFFERED=1
|
||||
"$PYTHON_BIN"
|
||||
-u
|
||||
-m
|
||||
example_trainer.grpo
|
||||
--model-name "$STUDENT_MODEL"
|
||||
--weight-bridge-mode shared_vllm
|
||||
--device cuda:0
|
||||
--save-path "$SAVE_DIR"
|
||||
--atropos-url "http://localhost:${API_PORT}"
|
||||
--vllm-port "$STUDENT_PORT"
|
||||
--vllm-config-path "${BRIDGE_DIR}/vllm_bridge_config.json"
|
||||
--training-steps "$TRAINING_STEPS"
|
||||
--batch-size "$BATCH_SIZE"
|
||||
--gradient-accumulation-steps "$GRAD_ACCUM"
|
||||
--warmup-steps "$WARMUP_STEPS"
|
||||
--lr "$LR"
|
||||
--clip-eps "$CLIP_EPS"
|
||||
--seq-len "$TRAINER_SEQ_LEN"
|
||||
--distill-enabled
|
||||
--distill-coef "$DISTILL_COEF"
|
||||
--distill-temperature "$DISTILL_TEMPERATURE"
|
||||
--use-wandb
|
||||
--wandb-project "$WANDB_PROJECT"
|
||||
)
|
||||
if [[ -n "$WANDB_GROUP" ]]; then
|
||||
TRAINER_CMD+=(--wandb-group "$WANDB_GROUP")
|
||||
fi
|
||||
|
||||
if [[ "$DRY_RUN" == "1" ]]; then
|
||||
log "[DRY RUN] trainer command:"
|
||||
printf ' '
|
||||
printf '%q ' "${TRAINER_CMD[@]}"
|
||||
printf '\n'
|
||||
exit 0
|
||||
fi
|
||||
|
||||
start_process "trainer" "${LOG_DIR}/trainer.log" "${TRAINER_CMD[@]}"
|
||||
|
||||
log "All processes running in background."
|
||||
log ""
|
||||
log "Monitor with:"
|
||||
log " tail -f ${LOG_DIR}/trainer.log"
|
||||
log " tail -f ${LOG_DIR}/env.log"
|
||||
log " tail -f ${LOG_DIR}/student_vllm.log"
|
||||
log " tail -f ${LOG_DIR}/teacher_vllm.log"
|
||||
log ""
|
||||
log "Test endpoints:"
|
||||
log " curl -s http://localhost:${STUDENT_PORT}/health"
|
||||
log " curl -s http://localhost:${TEACHER_PORT}/health"
|
||||
log " curl -s http://localhost:${STUDENT_PORT}/bridge/is_paused | jq ."
|
||||
log ""
|
||||
log "To stop all processes:"
|
||||
log " kill ${RUN_PIDS[*]:-} 2>/dev/null; sleep 1; kill -9 ${RUN_PIDS[*]:-} 2>/dev/null"
|
||||
trap - EXIT INT TERM
|
||||
Loading…
Add table
Add a link
Reference in a new issue