mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-02 17:45:50 +00:00
Merge branch 'NousResearch:main' into feat/reward-normalization
This commit is contained in:
commit
39fb1d6870
8 changed files with 1285 additions and 4 deletions
|
|
@ -133,7 +133,7 @@
|
||||||
"filename": "README.md",
|
"filename": "README.md",
|
||||||
"hashed_secret": "a8253456364f1bfc7da7ae4a1db5b45d106317a5",
|
"hashed_secret": "a8253456364f1bfc7da7ae4a1db5b45d106317a5",
|
||||||
"is_verified": false,
|
"is_verified": false,
|
||||||
"line_number": 454
|
"line_number": 530
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"SLURM.md": [
|
"SLURM.md": [
|
||||||
|
|
@ -561,5 +561,5 @@
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"generated_at": "2026-03-02T22:46:56Z"
|
"generated_at": "2026-03-14T00:43:09Z"
|
||||||
}
|
}
|
||||||
|
|
|
||||||
76
README.md
76
README.md
|
|
@ -298,6 +298,82 @@ curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids!
|
||||||
- Trainers should validate alignment assumptions they require (sequence length, per-position top-k, etc.).
|
- Trainers should validate alignment assumptions they require (sequence length, per-position top-k, etc.).
|
||||||
- Teacher-side architecture and prompt/rendering strategy are intentionally out of scope for this PR.
|
- Teacher-side architecture and prompt/rendering strategy are intentionally out of scope for this PR.
|
||||||
|
|
||||||
|
### TeacherDistillationEnv follow-up
|
||||||
|
|
||||||
|
The follow-up teacher environment uses a dedicated teacher server config and
|
||||||
|
attaches teacher prompt logprobs before the group is sent to the API.
|
||||||
|
|
||||||
|
Teacher config shape:
|
||||||
|
|
||||||
|
```python
|
||||||
|
TeacherDistillationConfig(
|
||||||
|
teacher_enabled=True,
|
||||||
|
teacher_top_k=8,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Teacher server configs are passed separately at init, just like the primary
|
||||||
|
`server_configs`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
env = MyTeacherEnv(
|
||||||
|
config=env_config,
|
||||||
|
server_configs=student_server_configs,
|
||||||
|
teacher_server_configs=[
|
||||||
|
APIServerConfig(
|
||||||
|
base_url="http://localhost:9003/v1",
|
||||||
|
model_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
|
||||||
|
api_key="",
|
||||||
|
server_type="vllm",
|
||||||
|
tokenizer_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can either:
|
||||||
|
|
||||||
|
- build a teacher-enabled env by mixing `TeacherDistillationEnv` into an existing
|
||||||
|
`BaseEnv`-derived env such as `GSM8kEnv`, or
|
||||||
|
- subclass `TeacherDistillationEnv` directly and implement the usual environment
|
||||||
|
methods yourself.
|
||||||
|
|
||||||
|
In both cases, `TeacherDistillationEnv` still assumes the normal `BaseEnv`
|
||||||
|
runtime contract: tokenized rollouts, `ScoredDataGroup` payloads, and the
|
||||||
|
standard `handle_send_to_api(...)` transport path.
|
||||||
|
|
||||||
|
CLI shape:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--env.teacher_enabled true \
|
||||||
|
--teacher.base_url "http://localhost:9003/v1" \
|
||||||
|
--teacher.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \
|
||||||
|
--teacher.server_type vllm \
|
||||||
|
--env.teacher_top_k 8
|
||||||
|
```
|
||||||
|
|
||||||
|
If `--teacher.model_name` is a deployment alias rather than a tokenizer
|
||||||
|
identifier, also set `--teacher.tokenizer_name ...` so the env can validate
|
||||||
|
tokenizer compatibility.
|
||||||
|
|
||||||
|
Scope note:
|
||||||
|
|
||||||
|
- The teacher-aware CLI wiring currently exists for `serve`.
|
||||||
|
- If `teacher_enabled=True`, the generic `process` and `evaluate` commands will
|
||||||
|
fail loudly at env construction time unless you instantiate the env yourself
|
||||||
|
and pass `teacher_server_configs=...`.
|
||||||
|
|
||||||
|
Tokenizer requirement:
|
||||||
|
|
||||||
|
- Teacher distillation currently requires the teacher and student to use the same tokenizer vocabulary.
|
||||||
|
- If the tokenizers do not match, `TeacherDistillationEnv` raises an error instead of attempting token conversion.
|
||||||
|
|
||||||
|
Why same-tokenizer is required:
|
||||||
|
|
||||||
|
- `distill_token_ids` are consumed as student-vocabulary IDs by the trainer.
|
||||||
|
- If the teacher uses a different vocabulary, the same integer token ID refers to different text on the teacher and student sides.
|
||||||
|
- A decode/re-tokenize/remap pipeline is not a safe drop-in fix because it changes both token positions and token identities, which breaks the exact per-position token supervision that the current distillation loss assumes.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Testing and Debugging Tools
|
## Testing and Debugging Tools
|
||||||
|
|
|
||||||
459
atroposlib/envs/teacher_distillation_env.py
Normal file
459
atroposlib/envs/teacher_distillation_env.py
Normal file
|
|
@ -0,0 +1,459 @@
|
||||||
|
"""
|
||||||
|
Teacher distillation environment layer.
|
||||||
|
|
||||||
|
This module adds teacher prompt-logprob fetching on top of BaseEnv without
|
||||||
|
modifying BaseEnv transport behavior.
|
||||||
|
|
||||||
|
This implementation supports same-tokenizer distillation only. The teacher and
|
||||||
|
student must share the same tokenizer vocabulary so the student's token IDs can
|
||||||
|
be forwarded directly to the teacher and the returned teacher top-k token IDs
|
||||||
|
can be looked up directly in the student's logits.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from abc import ABC
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic_cli import Cmd
|
||||||
|
from rich import print as rprint
|
||||||
|
|
||||||
|
from ..utils.cli import (
|
||||||
|
extract_namespace,
|
||||||
|
get_double_dash_flags,
|
||||||
|
get_prefixed_pydantic_model,
|
||||||
|
merge_dicts,
|
||||||
|
)
|
||||||
|
from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
|
||||||
|
from .constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
|
||||||
|
from .server_handling.openai_server import resolve_openai_configs
|
||||||
|
from .server_handling.server_baseline import APIServerConfig, ServerBaseline
|
||||||
|
from .server_handling.server_manager import ServerManager, ServerManagerConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TeacherDistillationConfig(BaseEnvConfig):
    """Env config knobs that control teacher prompt-logprob fetching."""

    # Master switch: when False the env behaves exactly like a plain BaseEnv.
    teacher_enabled: bool = Field(
        default=False,
        description="Whether to fetch teacher prompt logprobs for distillation.",
    )
    # 0 => selected-token-only prompt logprobs; <= -1 => disable fetching.
    teacher_top_k: int = Field(
        default=0,
        ge=-1,
        description=(
            "Number of extra prompt logprobs to fetch beyond the selected token. "
            "Use 0 for selected-token-only prompt logprobs and <= -1 to disable "
            "teacher fetching."
        ),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class TeacherDistillationEnv(BaseEnv, ABC):
    """
    BaseEnv subclass that enriches scored groups with teacher distillation arrays.

    Distillation payload shape:
    - distill_token_ids: [sequence][position][k] (student vocab IDs)
    - distill_logprobs: [sequence][position][k]
    """

    env_config_cls = TeacherDistillationConfig
    # CLI/yaml namespace under which teacher server settings live.
    teacher_namespace = "teacher"

    @classmethod
    def teacher_config_init(
        cls,
    ) -> Optional[Union[ServerBaseline, List[APIServerConfig], APIServerConfig]]:
        """Hook for subclasses to supply default teacher server configs.

        The base implementation supplies none; teacher settings then come only
        from the CLI/yaml (or from explicit teacher_server_configs at init).
        """
        return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _resolve_teacher_server_configs(
|
||||||
|
cls,
|
||||||
|
default_teacher_server_configs: Optional[
|
||||||
|
Union[ServerBaseline, List[APIServerConfig], APIServerConfig]
|
||||||
|
],
|
||||||
|
yaml_config: Dict[str, Any],
|
||||||
|
cli_passed_flags: Dict[str, Any],
|
||||||
|
) -> Optional[Union[ServerBaseline, List[APIServerConfig]]]:
|
||||||
|
teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}"
|
||||||
|
teacher_cli_passed_args = extract_namespace(
|
||||||
|
cli_passed_flags, teacher_full_prefix
|
||||||
|
)
|
||||||
|
yaml_teacher_config = yaml_config.get(cls.teacher_namespace, {})
|
||||||
|
|
||||||
|
if (
|
||||||
|
default_teacher_server_configs is None
|
||||||
|
and not teacher_cli_passed_args
|
||||||
|
and not yaml_teacher_config
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
effective_teacher_server_configs = default_teacher_server_configs
|
||||||
|
if effective_teacher_server_configs is None:
|
||||||
|
effective_teacher_server_configs = APIServerConfig()
|
||||||
|
elif isinstance(effective_teacher_server_configs, ServerBaseline) and (
|
||||||
|
teacher_cli_passed_args or yaml_teacher_config
|
||||||
|
):
|
||||||
|
effective_teacher_server_configs = APIServerConfig(
|
||||||
|
**effective_teacher_server_configs.model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
isinstance(effective_teacher_server_configs, list)
|
||||||
|
and len(effective_teacher_server_configs) == 1
|
||||||
|
):
|
||||||
|
default_teacher_config = effective_teacher_server_configs[0]
|
||||||
|
else:
|
||||||
|
default_teacher_config = effective_teacher_server_configs
|
||||||
|
|
||||||
|
if isinstance(yaml_teacher_config, list) and len(yaml_teacher_config) == 1:
|
||||||
|
yaml_teacher_config = yaml_teacher_config[0]
|
||||||
|
|
||||||
|
if isinstance(default_teacher_config, APIServerConfig) and isinstance(
|
||||||
|
yaml_teacher_config, dict
|
||||||
|
):
|
||||||
|
teacher_config_dict = merge_dicts(
|
||||||
|
default_teacher_config.model_dump(),
|
||||||
|
yaml_teacher_config,
|
||||||
|
teacher_cli_passed_args,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
teacher_config_dict = {}
|
||||||
|
|
||||||
|
teacher_yaml_wrapped = {OPENAI_NAMESPACE: yaml_teacher_config}
|
||||||
|
teacher_cli_wrapped = {
|
||||||
|
f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}{key}": value
|
||||||
|
for key, value in teacher_cli_passed_args.items()
|
||||||
|
}
|
||||||
|
return resolve_openai_configs(
|
||||||
|
default_server_configs=effective_teacher_server_configs,
|
||||||
|
openai_config_dict=teacher_config_dict,
|
||||||
|
yaml_config=teacher_yaml_wrapped,
|
||||||
|
cli_passed_flags=teacher_cli_wrapped,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cli_serve_config_cls(cls) -> type:
|
||||||
|
default_env_config, default_server_configs = cls.config_init()
|
||||||
|
default_teacher_server_configs = cls.teacher_config_init()
|
||||||
|
|
||||||
|
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||||
|
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||||
|
teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}"
|
||||||
|
teacher_cli_base = get_prefixed_pydantic_model(
|
||||||
|
APIServerConfig, teacher_full_prefix
|
||||||
|
)
|
||||||
|
|
||||||
|
class CliServeConfig(
|
||||||
|
get_prefixed_pydantic_model(type(default_env_config), env_full_prefix),
|
||||||
|
get_prefixed_pydantic_model(APIServerConfig, openai_full_prefix),
|
||||||
|
teacher_cli_base,
|
||||||
|
ServerManagerConfig,
|
||||||
|
Cmd,
|
||||||
|
):
|
||||||
|
config: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Path to .yaml config file. CLI args override this.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
|
||||||
|
if (
|
||||||
|
getattr(self, wandb_name_attr, None) is None
|
||||||
|
and cls.name is not None
|
||||||
|
):
|
||||||
|
setattr(self, wandb_name_attr, cls.name)
|
||||||
|
|
||||||
|
if self.config is not None:
|
||||||
|
with open(self.config, "r") as f:
|
||||||
|
yaml_config = yaml.safe_load(f)
|
||||||
|
logger.info("Loaded config from %s", self.config)
|
||||||
|
else:
|
||||||
|
yaml_config = {}
|
||||||
|
|
||||||
|
cli_passed_flags = get_double_dash_flags()
|
||||||
|
|
||||||
|
env_config_dict = merge_dicts(
|
||||||
|
default_env_config.model_dump(),
|
||||||
|
yaml_config.get(ENV_NAMESPACE, {}),
|
||||||
|
extract_namespace(cli_passed_flags, env_full_prefix),
|
||||||
|
)
|
||||||
|
|
||||||
|
oai_cli_passed_args = extract_namespace(
|
||||||
|
cli_passed_flags, openai_full_prefix
|
||||||
|
)
|
||||||
|
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
|
||||||
|
|
||||||
|
effective_server_configs = default_server_configs
|
||||||
|
if isinstance(effective_server_configs, ServerBaseline) and (
|
||||||
|
oai_cli_passed_args or yaml_oai_config
|
||||||
|
):
|
||||||
|
effective_server_configs = APIServerConfig(
|
||||||
|
**effective_server_configs.model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
isinstance(effective_server_configs, list)
|
||||||
|
and len(effective_server_configs) == 1
|
||||||
|
):
|
||||||
|
default_openai_config_ = effective_server_configs[0]
|
||||||
|
else:
|
||||||
|
default_openai_config_ = effective_server_configs
|
||||||
|
|
||||||
|
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
|
||||||
|
yaml_oai_config = yaml_oai_config[0]
|
||||||
|
|
||||||
|
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
||||||
|
yaml_oai_config, dict
|
||||||
|
):
|
||||||
|
openai_config_dict = merge_dicts(
|
||||||
|
default_openai_config_.model_dump(),
|
||||||
|
yaml_oai_config,
|
||||||
|
oai_cli_passed_args,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
openai_config_dict = {}
|
||||||
|
|
||||||
|
server_manager_cli_passed_flags = {}
|
||||||
|
if "slurm" in cli_passed_flags:
|
||||||
|
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
|
||||||
|
if "testing" in cli_passed_flags:
|
||||||
|
server_manager_cli_passed_flags["testing"] = cli_passed_flags[
|
||||||
|
"testing"
|
||||||
|
]
|
||||||
|
|
||||||
|
server_manager_yaml_dict = {}
|
||||||
|
if "slurm" in yaml_config:
|
||||||
|
server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
|
||||||
|
if "testing" in yaml_config:
|
||||||
|
server_manager_yaml_dict["testing"] = yaml_config["testing"]
|
||||||
|
|
||||||
|
server_manager_config_dict = merge_dicts(
|
||||||
|
ServerManagerConfig().model_dump(),
|
||||||
|
server_manager_yaml_dict,
|
||||||
|
server_manager_cli_passed_flags,
|
||||||
|
)
|
||||||
|
|
||||||
|
env_config = type(default_env_config)(**env_config_dict)
|
||||||
|
server_manager_config = ServerManagerConfig(
|
||||||
|
**server_manager_config_dict
|
||||||
|
)
|
||||||
|
openai_configs = resolve_openai_configs(
|
||||||
|
default_server_configs=effective_server_configs,
|
||||||
|
openai_config_dict=openai_config_dict,
|
||||||
|
yaml_config=yaml_config,
|
||||||
|
cli_passed_flags=cli_passed_flags,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
teacher_configs = cls._resolve_teacher_server_configs(
|
||||||
|
default_teacher_server_configs=default_teacher_server_configs,
|
||||||
|
yaml_config=yaml_config,
|
||||||
|
cli_passed_flags=cli_passed_flags,
|
||||||
|
)
|
||||||
|
|
||||||
|
env_kwargs = {
|
||||||
|
"config": env_config,
|
||||||
|
"server_configs": openai_configs,
|
||||||
|
"slurm": server_manager_config.slurm,
|
||||||
|
"testing": server_manager_config.testing,
|
||||||
|
}
|
||||||
|
if teacher_configs is not None:
|
||||||
|
env_kwargs["teacher_server_configs"] = teacher_configs
|
||||||
|
env = cls(**env_kwargs)
|
||||||
|
rprint(env_config)
|
||||||
|
rprint(openai_configs)
|
||||||
|
if teacher_configs is not None:
|
||||||
|
rprint(teacher_configs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
task = loop.create_task(env.env_manager())
|
||||||
|
loop.run_until_complete(task)
|
||||||
|
except RuntimeError:
|
||||||
|
asyncio.run(env.env_manager())
|
||||||
|
|
||||||
|
return CliServeConfig
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: TeacherDistillationConfig,
|
||||||
|
server_configs: Union[ServerBaseline, List[APIServerConfig]],
|
||||||
|
teacher_server_configs: Optional[
|
||||||
|
Union[ServerBaseline, APIServerConfig, List[APIServerConfig]]
|
||||||
|
] = None,
|
||||||
|
slurm: bool = False,
|
||||||
|
testing: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(config, server_configs, slurm=slurm, testing=testing)
|
||||||
|
self.teacher_server: Optional[ServerManager] = None
|
||||||
|
|
||||||
|
if config.teacher_enabled:
|
||||||
|
if teacher_server_configs is None:
|
||||||
|
raise ValueError(
|
||||||
|
"teacher_enabled=True but no teacher server configuration was "
|
||||||
|
"provided. Pass teacher_server_configs=... when instantiating "
|
||||||
|
"the environment directly, or use the teacher-aware 'serve' CLI "
|
||||||
|
"path with --teacher.* flags. The generic BaseEnv 'process' and "
|
||||||
|
"'evaluate' commands do not currently wire teacher_server_configs."
|
||||||
|
)
|
||||||
|
if isinstance(teacher_server_configs, APIServerConfig):
|
||||||
|
teacher_config_source = [teacher_server_configs]
|
||||||
|
else:
|
||||||
|
teacher_config_source = teacher_server_configs
|
||||||
|
self.teacher_server = ServerManager(
|
||||||
|
teacher_config_source,
|
||||||
|
slurm=False,
|
||||||
|
testing=False,
|
||||||
|
)
|
||||||
|
if isinstance(teacher_config_source, list):
|
||||||
|
teacher_cfg = teacher_config_source[0]
|
||||||
|
else:
|
||||||
|
teacher_cfg = teacher_config_source
|
||||||
|
|
||||||
|
teacher_tokenizer_name = (
|
||||||
|
teacher_cfg.model_name
|
||||||
|
if getattr(teacher_cfg, "tokenizer_name", "none") in ("", "none")
|
||||||
|
else teacher_cfg.tokenizer_name
|
||||||
|
)
|
||||||
|
self._validate_teacher_tokenizer_compatibility(teacher_tokenizer_name)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Core fetch
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _validate_teacher_tokenizer_compatibility(
|
||||||
|
self, teacher_tokenizer_name: str
|
||||||
|
) -> None:
|
||||||
|
student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
|
||||||
|
if student_tok_name == teacher_tokenizer_name:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
teacher_tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
teacher_tokenizer_name, use_fast=True
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
raise ValueError(
|
||||||
|
"Cross-tokenizer distillation is not supported in this PR, and the "
|
||||||
|
f"teacher tokenizer for '{teacher_tokenizer_name}' could not be loaded to "
|
||||||
|
f"verify compatibility: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
student_vocab = self.tokenizer.get_vocab()
|
||||||
|
teacher_vocab = teacher_tokenizer.get_vocab()
|
||||||
|
if student_vocab != teacher_vocab:
|
||||||
|
raise ValueError(
|
||||||
|
"Cross-tokenizer distillation is not supported in this PR. "
|
||||||
|
f"Student tokenizer '{student_tok_name or type(self.tokenizer).__name__}' "
|
||||||
|
f"and teacher tokenizer '{teacher_tokenizer_name}' do not match."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _fetch_teacher_for_sequence(
|
||||||
|
self, token_ids: List[int], top_k: int
|
||||||
|
) -> Tuple[List[List[int]], List[List[float]]]:
|
||||||
|
assert self.teacher_server is not None
|
||||||
|
payload = await self.teacher_server.get_logprobs(
|
||||||
|
input_ids=token_ids,
|
||||||
|
top_k=top_k,
|
||||||
|
max_tokens=1,
|
||||||
|
split="train",
|
||||||
|
)
|
||||||
|
return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Group enrichment
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _attach_teacher_distillation(
|
||||||
|
self, group: ScoredDataGroup
|
||||||
|
) -> ScoredDataGroup:
|
||||||
|
if not self.config.teacher_enabled or self.teacher_server is None:
|
||||||
|
return group
|
||||||
|
|
||||||
|
seqs = group.get("tokens", [])
|
||||||
|
if not seqs:
|
||||||
|
group["distill_token_ids"] = None
|
||||||
|
group["distill_logprobs"] = None
|
||||||
|
return group
|
||||||
|
|
||||||
|
group_overrides = group.get("group_overrides") or {}
|
||||||
|
if group_overrides.get("skip_teacher_top_k", False):
|
||||||
|
group["distill_token_ids"] = None
|
||||||
|
group["distill_logprobs"] = None
|
||||||
|
return group
|
||||||
|
|
||||||
|
top_k = int(group_overrides.get("teacher_top_k", self.config.teacher_top_k))
|
||||||
|
if top_k <= -1:
|
||||||
|
group["distill_token_ids"] = None
|
||||||
|
group["distill_logprobs"] = None
|
||||||
|
return group
|
||||||
|
|
||||||
|
tasks = [self._fetch_teacher_for_sequence(seq, top_k) for seq in seqs]
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
distill_token_ids: List[List[List[int]]] = []
|
||||||
|
distill_logprobs: List[List[List[float]]] = []
|
||||||
|
for idx, result in enumerate(results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.warning(
|
||||||
|
"Teacher logprob fetch failed for seq %s: %s. "
|
||||||
|
"Dropping distill payload for this group.",
|
||||||
|
idx,
|
||||||
|
result,
|
||||||
|
)
|
||||||
|
group["distill_token_ids"] = None
|
||||||
|
group["distill_logprobs"] = None
|
||||||
|
return group
|
||||||
|
token_ids_k, logprobs_k = result
|
||||||
|
if len(token_ids_k) != len(logprobs_k):
|
||||||
|
logger.warning(
|
||||||
|
"Teacher prompt-topk length mismatch for seq %s (%s != %s). "
|
||||||
|
"Dropping distill payload for this group.",
|
||||||
|
idx,
|
||||||
|
len(token_ids_k),
|
||||||
|
len(logprobs_k),
|
||||||
|
)
|
||||||
|
group["distill_token_ids"] = None
|
||||||
|
group["distill_logprobs"] = None
|
||||||
|
return group
|
||||||
|
distill_token_ids.append(token_ids_k)
|
||||||
|
distill_logprobs.append(logprobs_k)
|
||||||
|
|
||||||
|
group["distill_token_ids"] = distill_token_ids
|
||||||
|
group["distill_logprobs"] = distill_logprobs
|
||||||
|
return group
|
||||||
|
|
||||||
|
async def handle_send_to_api(
|
||||||
|
self,
|
||||||
|
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
|
||||||
|
item: Any = None,
|
||||||
|
do_send_to_api: bool = True,
|
||||||
|
abort_on_any_max_length_exceeded: bool = True,
|
||||||
|
):
|
||||||
|
groups = scored_data if isinstance(scored_data, list) else [scored_data]
|
||||||
|
enriched_groups: List[ScoredDataGroup] = []
|
||||||
|
for group in groups:
|
||||||
|
if group is None:
|
||||||
|
continue
|
||||||
|
enriched_groups.append(await self._attach_teacher_distillation(group))
|
||||||
|
|
||||||
|
payload: Union[ScoredDataGroup, List[ScoredDataGroup]]
|
||||||
|
if isinstance(scored_data, list):
|
||||||
|
payload = enriched_groups
|
||||||
|
else:
|
||||||
|
payload = enriched_groups[0] if enriched_groups else scored_data
|
||||||
|
|
||||||
|
return await super().handle_send_to_api(
|
||||||
|
payload,
|
||||||
|
item=item,
|
||||||
|
do_send_to_api=do_send_to_api,
|
||||||
|
abort_on_any_max_length_exceeded=abort_on_any_max_length_exceeded,
|
||||||
|
)
|
||||||
344
atroposlib/tests/test_teacher_distillation_env.py
Normal file
344
atroposlib/tests/test_teacher_distillation_env.py
Normal file
|
|
@ -0,0 +1,344 @@
|
||||||
|
"""Tests for TeacherDistillationEnv distillation enrichment."""
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
||||||
|
from atroposlib.envs.teacher_distillation_env import TeacherDistillationEnv
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeTeacherServer:
|
||||||
|
def __init__(self, fail_on_call: int = -1):
|
||||||
|
self.calls = 0
|
||||||
|
self.fail_on_call = fail_on_call
|
||||||
|
|
||||||
|
async def get_logprobs(self, **kwargs):
|
||||||
|
self.calls += 1
|
||||||
|
if self.calls == self.fail_on_call:
|
||||||
|
raise RuntimeError("teacher backend failure")
|
||||||
|
seq = kwargs["input_ids"]
|
||||||
|
return {
|
||||||
|
"prompt_tokens": seq,
|
||||||
|
"prompt_topk_token_ids": [[tok, tok + 1] for tok in seq],
|
||||||
|
"prompt_topk_logprobs": [[-0.1, -0.2] for _ in seq],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class _ConcreteTeacherEnv(TeacherDistillationEnv):
    """Minimal concrete subclass so the abstract env can be instantiated."""

    async def get_next_item(self):
        return None

    async def evaluate(self, *args, **kwargs):
        return None
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyTokenizer:
|
||||||
|
name_or_path = "student-model"
|
||||||
|
|
||||||
|
def get_vocab(self):
|
||||||
|
return {"a": 1}
|
||||||
|
|
||||||
|
|
||||||
|
class _CapturingServerManager:
|
||||||
|
def __init__(self, configs, slurm=False, testing=False):
|
||||||
|
self.configs = configs
|
||||||
|
self.slurm = slurm
|
||||||
|
self.testing = testing
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_attach_teacher_distillation_success():
    """Happy path: every sequence gets a [position][k] distill payload."""
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
    env.teacher_server = _FakeTeacherServer()

    scored = {
        "tokens": [[1, 2, 3], [4, 5]],
        "group_overrides": None,
        "masks": [[-100, 2, 3], [-100, 5]],
        "scores": [1.0, 0.0],
    }
    enriched = await TeacherDistillationEnv._attach_teacher_distillation(env, scored)

    ids = enriched["distill_token_ids"]
    logprobs = enriched["distill_logprobs"]
    assert ids is not None
    assert logprobs is not None
    assert len(ids) == 2
    assert len(ids[0]) == 3
    assert len(logprobs[1]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_attach_teacher_distillation_failure_drops_payload():
    """A backend error on any sequence nulls out the whole distill payload."""
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
    # Second fetch raises; the group must then carry no distill arrays.
    env.teacher_server = _FakeTeacherServer(fail_on_call=2)

    scored = {
        "tokens": [[1, 2, 3], [4, 5]],
        "group_overrides": None,
        "masks": [[-100, 2, 3], [-100, 5]],
        "scores": [1.0, 0.0],
    }
    enriched = await TeacherDistillationEnv._attach_teacher_distillation(env, scored)
    assert enriched["distill_token_ids"] is None
    assert enriched["distill_logprobs"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_attach_teacher_distillation_negative_topk_skips_fetch():
    """teacher_top_k <= -1 disables fetching without touching the server."""
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=-1)
    env.teacher_server = _FakeTeacherServer()

    scored = {
        "tokens": [[1, 2, 3]],
        "group_overrides": None,
        "masks": [[-100, 2, 3]],
        "scores": [1.0],
    }
    enriched = await TeacherDistillationEnv._attach_teacher_distillation(env, scored)
    assert env.teacher_server.calls == 0
    assert enriched["distill_token_ids"] is None
    assert enriched["distill_logprobs"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_attach_teacher_distillation_zero_topk_passthrough():
    """top_k == 0 (selected-token-only) still performs the fetch."""
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)
    env.teacher_server = _FakeTeacherServer()

    scored = {
        "tokens": [[1, 2, 3]],
        "group_overrides": None,
        "masks": [[-100, 2, 3]],
        "scores": [1.0],
    }
    enriched = await TeacherDistillationEnv._attach_teacher_distillation(env, scored)
    assert env.teacher_server.calls == 1
    assert enriched["distill_token_ids"] is not None
    assert enriched["distill_logprobs"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_attach_teacher_distillation_group_override_topk_is_used():
    """A per-group teacher_top_k override wins over the config default."""
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)

    observed_topks = []

    async def _record_fetch(seq, top_k):
        observed_topks.append(top_k)
        return [[tok] for tok in seq], [[-0.1] for _ in seq]

    # Truthy sentinel so the enabled-check passes; fetch itself is stubbed.
    env.teacher_server = object()
    env._fetch_teacher_for_sequence = _record_fetch

    scored = {
        "tokens": [[1, 2, 3], [4, 5]],
        "group_overrides": {"teacher_top_k": 7},
        "masks": [[-100, 2, 3], [-100, 5]],
        "scores": [1.0, 0.0],
    }
    enriched = await TeacherDistillationEnv._attach_teacher_distillation(env, scored)
    assert observed_topks == [7, 7]
    assert enriched["distill_token_ids"] is not None
    assert enriched["distill_logprobs"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_attach_teacher_distillation_group_override_can_skip_fetch():
    """skip_teacher_top_k in group_overrides suppresses the fetch entirely."""
    env = object.__new__(_ConcreteTeacherEnv)
    env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=2)
    env.teacher_server = _FakeTeacherServer()

    scored = {
        "tokens": [[1, 2, 3]],
        "group_overrides": {"skip_teacher_top_k": True},
        "masks": [[-100, 2, 3]],
        "scores": [1.0],
    }
    enriched = await TeacherDistillationEnv._attach_teacher_distillation(env, scored)
    assert env.teacher_server.calls == 0
    assert enriched["distill_token_ids"] is None
    assert enriched["distill_logprobs"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_teacher_tokenizer_mismatch_raises(monkeypatch):
    """Different vocabularies must raise instead of attempting conversion."""
    env = object.__new__(_ConcreteTeacherEnv)

    class _StudentTokenizer:
        name_or_path = "student-model"

        def get_vocab(self):
            return {"a": 1}

    class _TeacherTokenizer:
        def get_vocab(self):
            return {"b": 1}

    env.tokenizer = _StudentTokenizer()
    monkeypatch.setattr(
        "transformers.AutoTokenizer.from_pretrained",
        lambda *args, **kwargs: _TeacherTokenizer(),
    )

    with pytest.raises(
        ValueError, match="Cross-tokenizer distillation is not supported"
    ):
        TeacherDistillationEnv._validate_teacher_tokenizer_compatibility(
            env, teacher_tokenizer_name="teacher-model"
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_requires_teacher_server_source(monkeypatch):
    """teacher_enabled without teacher_server_configs must fail loudly."""
    from atroposlib.envs import teacher_distillation_env as module

    def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
        self.config = config
        self.tokenizer = _DummyTokenizer()

    monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)

    cfg = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)
    with pytest.raises(
        ValueError, match="no teacher server configuration was provided"
    ):
        _ConcreteTeacherEnv(config=cfg, server_configs=[])
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_uses_explicit_teacher_server_configs(monkeypatch):
    """Explicit teacher configs are passed straight through to ServerManager."""
    from atroposlib.envs import teacher_distillation_env as module

    recorded = {}

    def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
        self.config = config
        self.tokenizer = _DummyTokenizer()

    def _fake_validate(self, teacher_tokenizer_name):
        recorded["teacher_tokenizer_name"] = teacher_tokenizer_name

    monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
    monkeypatch.setattr(module, "ServerManager", _CapturingServerManager)
    monkeypatch.setattr(
        _ConcreteTeacherEnv,
        "_validate_teacher_tokenizer_compatibility",
        _fake_validate,
    )

    explicit_cfg = APIServerConfig(
        model_name="explicit-model",
        tokenizer_name="explicit-tokenizer",
        base_url="http://explicit/v1",
        api_key="x",
        server_type="vllm",
    )
    cfg = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)

    env = _ConcreteTeacherEnv(
        config=cfg,
        server_configs=[],
        teacher_server_configs=[explicit_cfg],
    )

    assert isinstance(env.teacher_server, _CapturingServerManager)
    assert env.teacher_server.configs == [explicit_cfg]
    assert recorded["teacher_tokenizer_name"] == "explicit-tokenizer"
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_wraps_bare_teacher_api_server_config(monkeypatch):
    """A bare APIServerConfig (not a list) is wrapped into a one-element list."""
    from atroposlib.envs import teacher_distillation_env as module

    recorded = {}

    def _stub_base_init(self, config, server_configs, slurm=False, testing=False):
        self.config = config
        self.tokenizer = _DummyTokenizer()

    def _capture_validate(self, teacher_tokenizer_name):
        recorded["teacher_tokenizer_name"] = teacher_tokenizer_name

    monkeypatch.setattr(module.BaseEnv, "__init__", _stub_base_init)
    monkeypatch.setattr(module, "ServerManager", _CapturingServerManager)
    monkeypatch.setattr(
        _ConcreteTeacherEnv,
        "_validate_teacher_tokenizer_compatibility",
        _capture_validate,
    )

    bare_cfg = APIServerConfig(
        model_name="explicit-model",
        tokenizer_name="explicit-tokenizer",
        base_url="http://explicit/v1",
        api_key="x",
        server_type="vllm",
    )

    env = _ConcreteTeacherEnv(
        config=SimpleNamespace(teacher_enabled=True, teacher_top_k=0),
        server_configs=[],
        teacher_server_configs=bare_cfg,  # note: NOT wrapped in a list
    )

    assert isinstance(env.teacher_server, _CapturingServerManager)
    assert env.teacher_server.configs == [bare_cfg]
    assert recorded["teacher_tokenizer_name"] == "explicit-tokenizer"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_teacher_server_configs_returns_none_when_unset():
    """With no defaults, YAML, or CLI flags, resolution yields None."""
    resolved = _ConcreteTeacherEnv._resolve_teacher_server_configs(
        default_teacher_server_configs=None,
        yaml_config={},
        cli_passed_flags={},
    )
    assert resolved is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_teacher_server_configs_uses_teacher_namespace(monkeypatch):
    """`teacher.*` YAML/CLI keys are remapped to the `openai.*` namespace."""
    from atroposlib.envs import teacher_distillation_env as module

    call_kwargs = {}

    def _spy_resolve(**kwargs):
        call_kwargs.update(kwargs)
        return ["resolved"]

    monkeypatch.setattr(module, "resolve_openai_configs", _spy_resolve)

    defaults = APIServerConfig(
        model_name="teacher-model",
        base_url="http://teacher/v1",
        api_key="x",
        server_type="vllm",
    )

    result = _ConcreteTeacherEnv._resolve_teacher_server_configs(
        default_teacher_server_configs=defaults,
        yaml_config={"teacher": {"tokenizer_name": "teacher-tokenizer"}},
        cli_passed_flags={"teacher.base_url": "http://override/v1"},
    )

    assert result == ["resolved"]
    # CLI flag overrides the default; YAML supplies the tokenizer name.
    assert call_kwargs["openai_config_dict"]["base_url"] == "http://override/v1"
    assert call_kwargs["openai_config_dict"]["tokenizer_name"] == "teacher-tokenizer"
    # Both YAML and CLI inputs arrive re-namespaced under "openai".
    assert call_kwargs["yaml_config"] == {
        "openai": {"tokenizer_name": "teacher-tokenizer"}
    }
    assert call_kwargs["cli_passed_flags"] == {"openai.base_url": "http://override/v1"}
|
||||||
58
environments/gsm8k_server_teacher_distill.py
Normal file
58
environments/gsm8k_server_teacher_distill.py
Normal file
|
|
@ -0,0 +1,58 @@
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from atroposlib.envs.base import APIServerConfig, ServerBaseline
|
||||||
|
from atroposlib.envs.teacher_distillation_env import (
|
||||||
|
TeacherDistillationConfig,
|
||||||
|
TeacherDistillationEnv,
|
||||||
|
)
|
||||||
|
from environments.gsm8k_server import GSM8kEnv
|
||||||
|
|
||||||
|
|
||||||
|
class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv):
    """
    GSM8K environment variant that enables TeacherDistillationEnv config fields.

    This preserves the original `gsm8k_server.py` while providing a separate entrypoint
    for teacher-distillation data collection.
    """

    name = "gsm8k_teacher_distill"
    env_config_cls = TeacherDistillationConfig

    @classmethod
    def config_init(cls) -> Tuple[TeacherDistillationConfig, APIServerConfig]:
        """Return the default env config and the student inference server config.

        Fix: the previous annotation declared ``ServerBaseline`` as the second
        element, but the method returns an ``APIServerConfig``; the annotation
        now matches the actual return value.
        """
        env_config = TeacherDistillationConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=2048,
            wandb_name="gsm8k_teacher_distill",
            # Distillation-specific knobs from TeacherDistillationConfig.
            teacher_enabled=True,
            teacher_top_k=4,
        )
        server_config = APIServerConfig(
            model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            base_url="http://localhost:9001/v1",
            api_key="x",
            num_requests_for_eval=256,
        )
        return env_config, server_config

    @classmethod
    def teacher_config_init(cls) -> APIServerConfig:
        """Default teacher inference server config (separate vLLM endpoint)."""
        return APIServerConfig(
            base_url="http://localhost:9003/v1",
            model_name="mock-teacher",
            api_key="",
            server_type="vllm",
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            timeout=1200,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Script entrypoint: delegate to the environment's `cli()` command-line
# interface (used by the launcher script to run `serve`).
if __name__ == "__main__":
    GSM8kTeacherDistillEnv.cli()
|
||||||
|
|
@ -304,6 +304,38 @@ environment uses the `/generate` path and includes token-level
|
||||||
4. Trainer extracts and aligns logprobs with training labels
|
4. Trainer extracts and aligns logprobs with training labels
|
||||||
5. GRPO loss uses these rollout logprobs in importance-ratio terms
|
5. GRPO loss uses these rollout logprobs in importance-ratio terms
|
||||||
|
|
||||||
|
### 1b. Teacher distillation requires the same tokenizer
|
||||||
|
|
||||||
|
When distillation data is attached to Atropos batches, the trainer treats
|
||||||
|
`distill_token_ids` as indices into the student's logit tensor. That only works
|
||||||
|
if the teacher and student share the same tokenizer vocabulary.
|
||||||
|
|
||||||
|
What to configure on the environment side:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--env.teacher_enabled true \
|
||||||
|
--teacher.base_url "http://localhost:9003/v1" \
|
||||||
|
--teacher.model_name "$TEACHER_MODEL" \
|
||||||
|
--teacher.server_type vllm \
|
||||||
|
--env.teacher_top_k 8
|
||||||
|
```
|
||||||
|
|
||||||
|
If `$TEACHER_MODEL` is a deployment alias instead of a tokenizer identifier,
|
||||||
|
also set `--teacher.tokenizer_name ...` so the env can validate
|
||||||
|
tokenizer compatibility.
|
||||||
|
|
||||||
|
The teacher-aware CLI path is currently wired for `serve`. If
|
||||||
|
`teacher_enabled=True`, the generic `process` and `evaluate` commands are not
|
||||||
|
teacher-aware and will fail loudly unless the environment is instantiated
|
||||||
|
manually with `teacher_server_configs=...`.
|
||||||
|
|
||||||
|
Why cross-tokenizer conversion is not acceptable here:
|
||||||
|
|
||||||
|
- Teacher token ID `1234` and student token ID `1234` can correspond to different text.
|
||||||
|
- Re-tokenizing teacher text changes token boundaries, so teacher position `i` may no longer correspond to student position `i`.
|
||||||
|
- Remapping teacher top-k tokens back into student vocab can collapse multiple teacher candidates into one student token or expand one teacher token into multiple student tokens.
|
||||||
|
- The current distillation loss expects exact per-position supervision in student token space, so an approximate remapping would silently produce misleading targets.
|
||||||
|
|
||||||
### 2. Clipping
|
### 2. Clipping
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
|
||||||
|
|
@ -248,8 +248,7 @@ run_shared_vllm() {
|
||||||
--port "$vllm_port" \
|
--port "$vllm_port" \
|
||||||
--gpu-memory-utilization "$SHARED_GPU_MEMORY_UTILIZATION" \
|
--gpu-memory-utilization "$SHARED_GPU_MEMORY_UTILIZATION" \
|
||||||
--max-model-len "$MAX_MODEL_LEN" \
|
--max-model-len "$MAX_MODEL_LEN" \
|
||||||
--dtype "$DTYPE" \
|
--dtype "$DTYPE"
|
||||||
--enforce-eager
|
|
||||||
if [[ "$DRY_RUN" == "1" ]]; then
|
if [[ "$DRY_RUN" == "1" ]]; then
|
||||||
log "[DRY RUN] wait for http://localhost:${vllm_port}/health"
|
log "[DRY RUN] wait for http://localhost:${vllm_port}/health"
|
||||||
else
|
else
|
||||||
|
|
|
||||||
313
example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
Executable file
313
example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
Executable file
|
|
@ -0,0 +1,313 @@
|
||||||
|
#!/usr/bin/env bash
set -euo pipefail

# Single-terminal teacher-distillation runner.
# Starts everything in the background from ONE shell that has GPU access:
#   1) Atropos API
#   2) Student vLLM server
#   3) Teacher vLLM server
#   4) GSM8K teacher-distill environment
#   5) Example trainer (foreground)
#
# Usage:
#   chmod +x example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
#   ./example_trainer/run_gsm8k_teacher_distill_single_terminal.sh
#
# Every variable below can be overridden from the calling environment, e.g.:
#   STUDENT_MODEL="Qwen/Qwen3-4B-Instruct-2507-FP8"
#   TEACHER_MODEL="Qwen/Qwen3-30B-A3B-Instruct-2507"
#   STUDENT_GPUS="0"  TEACHER_GPUS="4,5,6,7"  TRAINER_GPUS="0"
#   STUDENT_TP=1  TEACHER_TP=4
#   API_PORT=8002  STUDENT_PORT=9001  TEACHER_PORT=9003
#   TRAINING_STEPS=100  DISTILL_COEF=0.2  DISTILL_TEMPERATURE=1.0
#   TEACHER_TOP_K=8  DRY_RUN=1

# Run from the repo root, but remember where we were launched from so that
# save/log paths default to the caller's directory.
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
LAUNCH_DIR="$PWD"
cd "$ROOT_DIR"

PYTHON_BIN="${PYTHON_BIN:-python3}"
STUDENT_MODEL="${STUDENT_MODEL:-Qwen/Qwen3-4B}"
TEACHER_MODEL="${TEACHER_MODEL:-Qwen/Qwen3-30B-A3B-Instruct-2507}"

# GPU placement: trainer shares the student's GPUs (required for shared_vllm).
STUDENT_GPUS="${STUDENT_GPUS:-0}"
TEACHER_GPUS="${TEACHER_GPUS:-4,5,6,7}"
TRAINER_GPUS="${TRAINER_GPUS:-$STUDENT_GPUS}"

STUDENT_TP="${STUDENT_TP:-1}"
TEACHER_TP="${TEACHER_TP:-4}"

API_PORT="${API_PORT:-8002}"
STUDENT_PORT="${STUDENT_PORT:-9001}"
TEACHER_PORT="${TEACHER_PORT:-9003}"

# Trainer hyperparameters.
TRAINING_STEPS="${TRAINING_STEPS:-20}"
BATCH_SIZE="${BATCH_SIZE:-1}"
GRAD_ACCUM="${GRAD_ACCUM:-4}"
LR="${LR:-1e-5}"
WARMUP_STEPS="${WARMUP_STEPS:-0}"
CLIP_EPS="${CLIP_EPS:-0.2}"
MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}"
TEACHER_MAX_MODEL_LEN="${TEACHER_MAX_MODEL_LEN:-32768}"
# Trainer seq_len must be larger than ENV_MAX_TOKEN_LENGTH to accommodate
# chat template overhead (~400-800 tokens for Qwen3 thinking format).
TRAINER_SEQ_LEN="${TRAINER_SEQ_LEN:-20480}"
ENV_MAX_TOKEN_LENGTH="${ENV_MAX_TOKEN_LENGTH:-16384}"
DISTILL_COEF="${DISTILL_COEF:-0.2}"
DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}"
TEACHER_TOP_K="${TEACHER_TOP_K:-8}"

WANDB_PROJECT="${WANDB_PROJECT:-gsm8k-teacher-distill}"
WANDB_GROUP="${WANDB_GROUP:-}"

# Student gets lower memory utilization because the trainer shares its GPUs.
STUDENT_GPU_MEMORY_UTILIZATION="${STUDENT_GPU_MEMORY_UTILIZATION:-0.60}"
TEACHER_GPU_MEMORY_UTILIZATION="${TEACHER_GPU_MEMORY_UTILIZATION:-0.85}"
DTYPE="${DTYPE:-bfloat16}"
SAVE_DIR="${SAVE_DIR:-${LAUNCH_DIR}/saves/gsm8k_teacher_distill}"
LOG_DIR="${LOG_DIR:-${LAUNCH_DIR}/logs/gsm8k_teacher_distill}"
BRIDGE_DIR="${BRIDGE_DIR:-${LOG_DIR}/bridge}"
DRY_RUN="${DRY_RUN:-0}"

# Environment-side rollout parameters.
ENV_GROUP_SIZE="${ENV_GROUP_SIZE:-4}"
ENV_BATCH_SIZE="${ENV_BATCH_SIZE:-8}"
ENV_TOTAL_STEPS="${ENV_TOTAL_STEPS:-200}"
ENV_STEPS_PER_EVAL="${ENV_STEPS_PER_EVAL:-50}"
ENV_MAX_WORKERS_PER_NODE="${ENV_MAX_WORKERS_PER_NODE:-1}"
ENV_WORKER_TIMEOUT="${ENV_WORKER_TIMEOUT:-1800}"

# Bookkeeping for cleanup_all: PIDs of background children and ports to free.
RUN_PIDS=()
RUN_PORTS=()
|
||||||
|
|
||||||
|
# Timestamped logging helper; prints all arguments on one line.
log() {
    printf '[%s] %s\n' "$(date '+%H:%M:%S')" "$*"
}
|
||||||
|
|
||||||
|
# Force-kill any process listening on the given TCP port (no-op in DRY_RUN).
kill_port() {
    local port="$1"

    # Dry runs never touch live processes.
    if [[ "$DRY_RUN" == "1" ]]; then
        log "[DRY RUN] skip port cleanup for :${port}"
        return 0
    fi

    # Only attempt the kill when something is actually listening.
    if ! lsof -i ":${port}" -sTCP:LISTEN >/dev/null 2>&1; then
        return 0
    fi
    lsof -ti ":${port}" | xargs -r kill -9 || true
}
|
||||||
|
|
||||||
|
# Poll a URL until it answers 2xx or the timeout (seconds) elapses.
#   $1 = url, $2 = timeout (default 240), $3 = human-readable name
wait_for_http() {
    local url="$1"
    local timeout="${2:-240}"
    local name="${3:-endpoint}"
    local began
    began="$(date +%s)"

    # Keep probing every 2s until the endpoint responds.
    while ! curl -fsS "$url" >/dev/null 2>&1; do
        if (( "$(date +%s)" - began > timeout )); then
            log "Timeout waiting for ${name}: ${url}"
            return 1
        fi
        sleep 2
    done

    log "Ready: ${name} (${url})"
    return 0
}
|
||||||
|
|
||||||
|
# Launch a named background process with output redirected to a logfile.
#   $1 = display name, $2 = logfile, $3.. = command to run
# In DRY_RUN mode the shell-quoted command is printed instead of executed.
start_process() {
    local name="$1" logfile="$2"
    shift 2

    if [[ "$DRY_RUN" == "1" ]]; then
        log "[DRY RUN] start ${name} (log: ${logfile})"
        printf '  '
        printf '%q ' "$@"
        printf '\n'
        return 0
    fi

    log "Starting ${name} (log: ${logfile})"
    "$@" >"$logfile" 2>&1 &
    local child=$!
    RUN_PIDS+=("$child")  # tracked so cleanup_all can terminate it
    log "${name} PID=${child}"
}
|
||||||
|
|
||||||
|
# Terminate every tracked child (TERM, brief grace, then KILL) and free ports.
cleanup_all() {
    local pid port

    log "Cleaning up processes..."

    # First a polite TERM to each tracked child...
    for pid in "${RUN_PIDS[@]:-}"; do
        kill "$pid" >/dev/null 2>&1 || true
    done
    sleep 1
    # ...then a hard KILL for anything that survived the grace period.
    for pid in "${RUN_PIDS[@]:-}"; do
        kill -9 "$pid" >/dev/null 2>&1 || true
    done

    # Finally, sweep any stray listeners off the ports we claimed.
    for port in "${RUN_PORTS[@]:-}"; do
        kill_port "$port"
    done
}
|
||||||
|
|
||||||
|
# Ensure background children and ports are cleaned up on any exit path.
trap cleanup_all EXIT INT TERM

mkdir -p "$LOG_DIR" "$SAVE_DIR" "$BRIDGE_DIR"
RUN_PORTS+=("$API_PORT" "$STUDENT_PORT" "$TEACHER_PORT")
# Free our ports up-front in case a previous run crashed and left listeners.
kill_port "$API_PORT"
kill_port "$STUDENT_PORT"
kill_port "$TEACHER_PORT"

log "Config:"
log " student=${STUDENT_MODEL}"
log " teacher=${TEACHER_MODEL}"
log " gpus student=${STUDENT_GPUS}, teacher=${TEACHER_GPUS}, trainer=${TRAINER_GPUS}"
log " ports api=${API_PORT}, student=${STUDENT_PORT}, teacher=${TEACHER_PORT}"
log " logs=${LOG_DIR}"
log " saves=${SAVE_DIR}"
log " bridge=${BRIDGE_DIR}"
log " env max_token_length=${ENV_MAX_TOKEN_LENGTH}, env workers=${ENV_MAX_WORKERS_PER_NODE}, env worker_timeout=${ENV_WORKER_TIMEOUT}"
log " wandb project=${WANDB_PROJECT}${WANDB_GROUP:+, group=${WANDB_GROUP}}"

# Shared-vLLM attach path currently expects the student server to expose
# unsharded weights. Keep the student on TP=1 and the trainer on the same GPU set.
if [[ "$STUDENT_TP" != "1" ]]; then
    log "ERROR: shared_vllm teacher-distill runner currently requires STUDENT_TP=1."
    log " The current attach path does not support TP-sharded student bridge weights."
    exit 2
fi

if [[ "$TRAINER_GPUS" != "$STUDENT_GPUS" ]]; then
    log "ERROR: TRAINER_GPUS must match STUDENT_GPUS for shared_vllm mode."
    log " Got student=${STUDENT_GPUS}, trainer=${TRAINER_GPUS}"
    exit 2
fi

# 1) Atropos API
start_process "run_api" "${LOG_DIR}/run_api.log" \
    run-api --port "$API_PORT"
if [[ "$DRY_RUN" == "0" ]]; then
    wait_for_http "http://localhost:${API_PORT}/info" 180 "run-api"
fi

# 2) Student vLLM server (shared-weights bridge enabled for the trainer)
start_process "student_vllm" "${LOG_DIR}/student_vllm.log" \
    env CUDA_VISIBLE_DEVICES="$STUDENT_GPUS" VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR="$BRIDGE_DIR" \
    "$PYTHON_BIN" -m example_trainer.vllm_api_server \
    --model "$STUDENT_MODEL" \
    --port "$STUDENT_PORT" \
    --tensor-parallel-size "$STUDENT_TP" \
    --gpu-memory-utilization "$STUDENT_GPU_MEMORY_UTILIZATION" \
    --max-model-len "$MAX_MODEL_LEN" \
    --dtype "$DTYPE"
if [[ "$DRY_RUN" == "0" ]]; then
    wait_for_http "http://localhost:${STUDENT_PORT}/health" 420 "student vLLM"
fi

# 3) Teacher vLLM server (larger timeout: big model, multi-GPU load)
start_process "teacher_vllm" "${LOG_DIR}/teacher_vllm.log" \
    env CUDA_VISIBLE_DEVICES="$TEACHER_GPUS" \
    "$PYTHON_BIN" -m example_trainer.vllm_api_server \
    --model "$TEACHER_MODEL" \
    --port "$TEACHER_PORT" \
    --tensor-parallel-size "$TEACHER_TP" \
    --gpu-memory-utilization "$TEACHER_GPU_MEMORY_UTILIZATION" \
    --max-model-len "$TEACHER_MAX_MODEL_LEN" \
    --dtype "$DTYPE"
if [[ "$DRY_RUN" == "0" ]]; then
    wait_for_http "http://localhost:${TEACHER_PORT}/health" 1800 "teacher vLLM"
fi

# 4) Teacher-distill GSM8K env: student rollouts via --openai.*, teacher
#    prompt-logprob attachment via --teacher.* and --env.teacher_* flags.
start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \
    "$PYTHON_BIN" environments/gsm8k_server_teacher_distill.py serve \
    --env.group_size "$ENV_GROUP_SIZE" \
    --env.batch_size "$ENV_BATCH_SIZE" \
    --env.total_steps "$ENV_TOTAL_STEPS" \
    --env.steps_per_eval "$ENV_STEPS_PER_EVAL" \
    --env.max_num_workers_per_node "$ENV_MAX_WORKERS_PER_NODE" \
    --env.max_token_length "$ENV_MAX_TOKEN_LENGTH" \
    --env.worker_timeout "$ENV_WORKER_TIMEOUT" \
    --env.rollout_server_url "http://localhost:${API_PORT}" \
    --env.use_wandb true \
    --env.wandb_name "gsm8k-teacher-distill" \
    --env.teacher_enabled true \
    --teacher.base_url "http://localhost:${TEACHER_PORT}/v1" \
    --teacher.model_name "$TEACHER_MODEL" \
    --teacher.server_type vllm \
    --env.teacher_top_k "$TEACHER_TOP_K" \
    --env.ensure_scores_are_not_same false \
    --openai.api_key "dummy" \
    --openai.base_url "http://localhost:${STUDENT_PORT}/v1" \
    --openai.model_name "$STUDENT_MODEL" \
    --openai.tokenizer_name "$STUDENT_MODEL" \
    --openai.server_type vllm

log "All services launched."
log "Run logs:"
log " ${LOG_DIR}/run_api.log"
log " ${LOG_DIR}/student_vllm.log"
log " ${LOG_DIR}/teacher_vllm.log"
log " ${LOG_DIR}/env.log"

# 5) Trainer (background). Built as an array so DRY_RUN can print it quoted.
TRAINER_CMD=(
    env
    CUDA_VISIBLE_DEVICES="$TRAINER_GPUS"
    PYTHONUNBUFFERED=1
    "$PYTHON_BIN"
    -u
    -m
    example_trainer.grpo
    --model-name "$STUDENT_MODEL"
    --weight-bridge-mode shared_vllm
    --device cuda:0
    --save-path "$SAVE_DIR"
    --atropos-url "http://localhost:${API_PORT}"
    --vllm-port "$STUDENT_PORT"
    --vllm-config-path "${BRIDGE_DIR}/vllm_bridge_config.json"
    --training-steps "$TRAINING_STEPS"
    --batch-size "$BATCH_SIZE"
    --gradient-accumulation-steps "$GRAD_ACCUM"
    --warmup-steps "$WARMUP_STEPS"
    --lr "$LR"
    --clip-eps "$CLIP_EPS"
    --seq-len "$TRAINER_SEQ_LEN"
    --distill-enabled
    --distill-coef "$DISTILL_COEF"
    --distill-temperature "$DISTILL_TEMPERATURE"
    --use-wandb
    --wandb-project "$WANDB_PROJECT"
)
if [[ -n "$WANDB_GROUP" ]]; then
    TRAINER_CMD+=(--wandb-group "$WANDB_GROUP")
fi

if [[ "$DRY_RUN" == "1" ]]; then
    log "[DRY RUN] trainer command:"
    printf ' '
    printf '%q ' "${TRAINER_CMD[@]}"
    printf '\n'
    exit 0
fi

start_process "trainer" "${LOG_DIR}/trainer.log" "${TRAINER_CMD[@]}"

log "All processes running in background."
log ""
log "Monitor with:"
log " tail -f ${LOG_DIR}/trainer.log"
log " tail -f ${LOG_DIR}/env.log"
log " tail -f ${LOG_DIR}/student_vllm.log"
log " tail -f ${LOG_DIR}/teacher_vllm.log"
log ""
log "Test endpoints:"
log " curl -s http://localhost:${STUDENT_PORT}/health"
log " curl -s http://localhost:${TEACHER_PORT}/health"
log " curl -s http://localhost:${STUDENT_PORT}/bridge/is_paused | jq ."
log ""
log "To stop all processes:"
log " kill ${RUN_PIDS[*]:-} 2>/dev/null; sleep 1; kill -9 ${RUN_PIDS[*]:-} 2>/dev/null"
# Processes are left running on success, so disarm the cleanup trap.
trap - EXIT INT TERM
|
||||||
Loading…
Add table
Add a link
Reference in a new issue