mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
structural changes
This commit is contained in:
parent
12ba3cc3bd
commit
a171358f2e
6 changed files with 422 additions and 50 deletions
37
README.md
37
README.md
|
|
@ -308,30 +308,43 @@ Teacher config shape:
|
|||
```python
|
||||
TeacherDistillationConfig(
|
||||
teacher_enabled=True,
|
||||
teacher_server=APIServerConfig(
|
||||
base_url="http://localhost:9003/v1",
|
||||
model_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
|
||||
api_key="",
|
||||
server_type="vllm",
|
||||
),
|
||||
teacher_top_k=8,
|
||||
)
|
||||
```
|
||||
|
||||
If `teacher_server.model_name` is a deployment alias rather than a tokenizer
|
||||
identifier, set `teacher_server.tokenizer_name` explicitly so the env can
|
||||
validate tokenizer compatibility.
|
||||
Teacher server configs are passed separately at init, just like the primary
|
||||
`server_configs`:
|
||||
|
||||
```python
|
||||
env = MyTeacherEnv(
|
||||
config=env_config,
|
||||
server_configs=student_server_configs,
|
||||
teacher_server_configs=[
|
||||
APIServerConfig(
|
||||
base_url="http://localhost:9003/v1",
|
||||
model_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
|
||||
api_key="",
|
||||
server_type="vllm",
|
||||
tokenizer_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
|
||||
)
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
CLI shape:
|
||||
|
||||
```bash
|
||||
--env.teacher_enabled true \
|
||||
--env.teacher_server.base_url "http://localhost:9003/v1" \
|
||||
--env.teacher_server.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \
|
||||
--env.teacher_server.server_type vllm \
|
||||
--teacher.base_url "http://localhost:9003/v1" \
|
||||
--teacher.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \
|
||||
--teacher.server_type vllm \
|
||||
--env.teacher_top_k 8
|
||||
```
|
||||
|
||||
If `--teacher.model_name` is a deployment alias rather than a tokenizer
|
||||
identifier, also set `--teacher.tokenizer_name ...` so the env can validate
|
||||
tokenizer compatibility.
|
||||
|
||||
Tokenizer requirement:
|
||||
|
||||
- Teacher distillation currently requires the teacher and student to use the same tokenizer vocabulary.
|
||||
|
|
|
|||
|
|
@ -15,13 +15,24 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import logging
|
||||
from abc import ABC
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import Field
|
||||
from pydantic_cli import Cmd
|
||||
from rich import print as rprint
|
||||
|
||||
from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
|
||||
from .constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
|
||||
from .server_handling.openai_server import resolve_openai_configs
|
||||
from .server_handling.server_baseline import APIServerConfig, ServerBaseline
|
||||
from .server_handling.server_manager import ServerManager
|
||||
from .server_handling.server_manager import ServerManager, ServerManagerConfig
|
||||
from ..utils.cli import (
|
||||
extract_namespace,
|
||||
get_double_dash_flags,
|
||||
get_prefixed_pydantic_model,
|
||||
merge_dicts,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -31,10 +42,6 @@ class TeacherDistillationConfig(BaseEnvConfig):
|
|||
default=False,
|
||||
description="Whether to fetch teacher prompt logprobs for distillation.",
|
||||
)
|
||||
teacher_server: Optional[APIServerConfig] = Field(
|
||||
default=None,
|
||||
description="Fallback teacher server configuration when not provided at init.",
|
||||
)
|
||||
teacher_top_k: int = Field(
|
||||
default=0,
|
||||
ge=-1,
|
||||
|
|
@ -56,6 +63,220 @@ class TeacherDistillationEnv(BaseEnv, ABC):
|
|||
"""
|
||||
|
||||
env_config_cls = TeacherDistillationConfig
|
||||
teacher_namespace = "teacher"
|
||||
|
||||
@classmethod
|
||||
def teacher_config_init(
|
||||
cls,
|
||||
) -> Optional[Union[ServerBaseline, List[APIServerConfig], APIServerConfig]]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _resolve_teacher_server_configs(
|
||||
cls,
|
||||
default_teacher_server_configs: Optional[
|
||||
Union[ServerBaseline, List[APIServerConfig], APIServerConfig]
|
||||
],
|
||||
yaml_config: Dict[str, Any],
|
||||
cli_passed_flags: Dict[str, Any],
|
||||
) -> Optional[Union[ServerBaseline, List[APIServerConfig]]]:
|
||||
teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}"
|
||||
teacher_cli_passed_args = extract_namespace(cli_passed_flags, teacher_full_prefix)
|
||||
yaml_teacher_config = yaml_config.get(cls.teacher_namespace, {})
|
||||
|
||||
if (
|
||||
default_teacher_server_configs is None
|
||||
and not teacher_cli_passed_args
|
||||
and not yaml_teacher_config
|
||||
):
|
||||
return None
|
||||
|
||||
effective_teacher_server_configs = default_teacher_server_configs
|
||||
if effective_teacher_server_configs is None:
|
||||
effective_teacher_server_configs = APIServerConfig()
|
||||
elif isinstance(effective_teacher_server_configs, ServerBaseline) and (
|
||||
teacher_cli_passed_args or yaml_teacher_config
|
||||
):
|
||||
effective_teacher_server_configs = APIServerConfig(
|
||||
**effective_teacher_server_configs.model_dump()
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(effective_teacher_server_configs, list)
|
||||
and len(effective_teacher_server_configs) == 1
|
||||
):
|
||||
default_teacher_config = effective_teacher_server_configs[0]
|
||||
else:
|
||||
default_teacher_config = effective_teacher_server_configs
|
||||
|
||||
if isinstance(yaml_teacher_config, list) and len(yaml_teacher_config) == 1:
|
||||
yaml_teacher_config = yaml_teacher_config[0]
|
||||
|
||||
if isinstance(default_teacher_config, APIServerConfig) and isinstance(
|
||||
yaml_teacher_config, dict
|
||||
):
|
||||
teacher_config_dict = merge_dicts(
|
||||
default_teacher_config.model_dump(),
|
||||
yaml_teacher_config,
|
||||
teacher_cli_passed_args,
|
||||
)
|
||||
else:
|
||||
teacher_config_dict = {}
|
||||
|
||||
teacher_yaml_wrapped = {OPENAI_NAMESPACE: yaml_teacher_config}
|
||||
teacher_cli_wrapped = {
|
||||
f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}{key}": value
|
||||
for key, value in teacher_cli_passed_args.items()
|
||||
}
|
||||
return resolve_openai_configs(
|
||||
default_server_configs=effective_teacher_server_configs,
|
||||
openai_config_dict=teacher_config_dict,
|
||||
yaml_config=teacher_yaml_wrapped,
|
||||
cli_passed_flags=teacher_cli_wrapped,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_cli_serve_config_cls(cls) -> type:
|
||||
default_env_config, default_server_configs = cls.config_init()
|
||||
default_teacher_server_configs = cls.teacher_config_init()
|
||||
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||
teacher_full_prefix = f"{cls.teacher_namespace}{NAMESPACE_SEP}"
|
||||
teacher_cli_base = get_prefixed_pydantic_model(
|
||||
APIServerConfig, teacher_full_prefix
|
||||
)
|
||||
|
||||
class CliServeConfig(
|
||||
get_prefixed_pydantic_model(type(default_env_config), env_full_prefix),
|
||||
get_prefixed_pydantic_model(APIServerConfig, openai_full_prefix),
|
||||
teacher_cli_base,
|
||||
ServerManagerConfig,
|
||||
Cmd,
|
||||
):
|
||||
config: str | None = Field(
|
||||
default=None,
|
||||
description="Path to .yaml config file. CLI args override this.",
|
||||
)
|
||||
|
||||
def run(self) -> None:
|
||||
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
|
||||
if (
|
||||
getattr(self, wandb_name_attr, None) is None
|
||||
and cls.name is not None
|
||||
):
|
||||
setattr(self, wandb_name_attr, cls.name)
|
||||
|
||||
if self.config is not None:
|
||||
with open(self.config, "r") as f:
|
||||
yaml_config = yaml.safe_load(f)
|
||||
logger.info("Loaded config from %s", self.config)
|
||||
else:
|
||||
yaml_config = {}
|
||||
|
||||
cli_passed_flags = get_double_dash_flags()
|
||||
|
||||
env_config_dict = merge_dicts(
|
||||
default_env_config.model_dump(),
|
||||
yaml_config.get(ENV_NAMESPACE, {}),
|
||||
extract_namespace(cli_passed_flags, env_full_prefix),
|
||||
)
|
||||
|
||||
oai_cli_passed_args = extract_namespace(
|
||||
cli_passed_flags, openai_full_prefix
|
||||
)
|
||||
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
|
||||
|
||||
effective_server_configs = default_server_configs
|
||||
if isinstance(effective_server_configs, ServerBaseline) and (
|
||||
oai_cli_passed_args or yaml_oai_config
|
||||
):
|
||||
effective_server_configs = APIServerConfig(
|
||||
**effective_server_configs.model_dump()
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(effective_server_configs, list)
|
||||
and len(effective_server_configs) == 1
|
||||
):
|
||||
default_openai_config_ = effective_server_configs[0]
|
||||
else:
|
||||
default_openai_config_ = effective_server_configs
|
||||
|
||||
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
|
||||
yaml_oai_config = yaml_oai_config[0]
|
||||
|
||||
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
||||
yaml_oai_config, dict
|
||||
):
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config_.model_dump(),
|
||||
yaml_oai_config,
|
||||
oai_cli_passed_args,
|
||||
)
|
||||
else:
|
||||
openai_config_dict = {}
|
||||
|
||||
server_manager_cli_passed_flags = {}
|
||||
if "slurm" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
|
||||
if "testing" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["testing"] = cli_passed_flags[
|
||||
"testing"
|
||||
]
|
||||
|
||||
server_manager_yaml_dict = {}
|
||||
if "slurm" in yaml_config:
|
||||
server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
|
||||
if "testing" in yaml_config:
|
||||
server_manager_yaml_dict["testing"] = yaml_config["testing"]
|
||||
|
||||
server_manager_config_dict = merge_dicts(
|
||||
ServerManagerConfig().model_dump(),
|
||||
server_manager_yaml_dict,
|
||||
server_manager_cli_passed_flags,
|
||||
)
|
||||
|
||||
env_config = type(default_env_config)(**env_config_dict)
|
||||
server_manager_config = ServerManagerConfig(
|
||||
**server_manager_config_dict
|
||||
)
|
||||
openai_configs = resolve_openai_configs(
|
||||
default_server_configs=effective_server_configs,
|
||||
openai_config_dict=openai_config_dict,
|
||||
yaml_config=yaml_config,
|
||||
cli_passed_flags=cli_passed_flags,
|
||||
logger=logger,
|
||||
)
|
||||
teacher_configs = cls._resolve_teacher_server_configs(
|
||||
default_teacher_server_configs=default_teacher_server_configs,
|
||||
yaml_config=yaml_config,
|
||||
cli_passed_flags=cli_passed_flags,
|
||||
)
|
||||
|
||||
env_kwargs = {
|
||||
"config": env_config,
|
||||
"server_configs": openai_configs,
|
||||
"slurm": server_manager_config.slurm,
|
||||
"testing": server_manager_config.testing,
|
||||
}
|
||||
if teacher_configs is not None:
|
||||
env_kwargs["teacher_server_configs"] = teacher_configs
|
||||
env = cls(**env_kwargs)
|
||||
rprint(env_config)
|
||||
rprint(openai_configs)
|
||||
if teacher_configs is not None:
|
||||
rprint(teacher_configs)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = loop.create_task(env.env_manager())
|
||||
loop.run_until_complete(task)
|
||||
except RuntimeError:
|
||||
asyncio.run(env.env_manager())
|
||||
|
||||
return CliServeConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -71,26 +292,11 @@ class TeacherDistillationEnv(BaseEnv, ABC):
|
|||
self.teacher_server: Optional[ServerManager] = None
|
||||
|
||||
if config.teacher_enabled:
|
||||
teacher_config_source = teacher_server_configs
|
||||
if teacher_config_source is None and config.teacher_server is not None:
|
||||
teacher_config_source = [
|
||||
config.teacher_server.model_copy(
|
||||
update={
|
||||
"tokenizer_name": (
|
||||
config.teacher_server.model_name
|
||||
if config.teacher_server.tokenizer_name in ("", "none")
|
||||
else config.teacher_server.tokenizer_name
|
||||
),
|
||||
"timeout": 1200,
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
if teacher_config_source is None:
|
||||
if teacher_server_configs is None:
|
||||
raise ValueError(
|
||||
"teacher_enabled=True requires teacher_server_configs at init "
|
||||
"or a fallback teacher_server config."
|
||||
"teacher_enabled=True requires teacher_server_configs at init."
|
||||
)
|
||||
teacher_config_source = teacher_server_configs
|
||||
self.teacher_server = ServerManager(
|
||||
teacher_config_source,
|
||||
slurm=False,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from types import SimpleNamespace
|
|||
|
||||
import pytest
|
||||
|
||||
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
|
||||
from atroposlib.envs.teacher_distillation_env import TeacherDistillationEnv
|
||||
|
||||
|
||||
|
|
@ -32,6 +33,20 @@ class _ConcreteTeacherEnv(TeacherDistillationEnv):
|
|||
return None
|
||||
|
||||
|
||||
class _DummyTokenizer:
|
||||
name_or_path = "student-model"
|
||||
|
||||
def get_vocab(self):
|
||||
return {"a": 1}
|
||||
|
||||
|
||||
class _CapturingServerManager:
|
||||
def __init__(self, configs, slurm=False, testing=False):
|
||||
self.configs = configs
|
||||
self.slurm = slurm
|
||||
self.testing = testing
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_success():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
|
|
@ -105,6 +120,32 @@ async def test_attach_teacher_distillation_zero_topk_passthrough():
|
|||
assert out["distill_logprobs"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_group_override_topk_is_used():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
env.config = SimpleNamespace(teacher_enabled=True, teacher_top_k=0)
|
||||
|
||||
seen_topks = []
|
||||
|
||||
async def _fake_fetch(seq, top_k):
|
||||
seen_topks.append(top_k)
|
||||
return [[tok] for tok in seq], [[-0.1] for _ in seq]
|
||||
|
||||
env.teacher_server = object()
|
||||
env._fetch_teacher_for_sequence = _fake_fetch
|
||||
|
||||
group = {
|
||||
"tokens": [[1, 2, 3], [4, 5]],
|
||||
"group_overrides": {"teacher_top_k": 7},
|
||||
"masks": [[-100, 2, 3], [-100, 5]],
|
||||
"scores": [1.0, 0.0],
|
||||
}
|
||||
out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
|
||||
assert seen_topks == [7, 7]
|
||||
assert out["distill_token_ids"] is not None
|
||||
assert out["distill_logprobs"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_teacher_distillation_group_override_can_skip_fetch():
|
||||
env = object.__new__(_ConcreteTeacherEnv)
|
||||
|
|
@ -149,3 +190,111 @@ def test_teacher_tokenizer_mismatch_raises(monkeypatch):
|
|||
env,
|
||||
teacher_tokenizer_name="teacher-model",
|
||||
)
|
||||
|
||||
|
||||
def test_init_requires_teacher_server_source(monkeypatch):
|
||||
from atroposlib.envs import teacher_distillation_env as module
|
||||
|
||||
def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
|
||||
self.config = config
|
||||
self.tokenizer = _DummyTokenizer()
|
||||
|
||||
monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
|
||||
|
||||
config = SimpleNamespace(
|
||||
teacher_enabled=True,
|
||||
teacher_top_k=0,
|
||||
)
|
||||
with pytest.raises(ValueError, match="teacher_enabled=True requires"):
|
||||
_ConcreteTeacherEnv(
|
||||
config=config,
|
||||
server_configs=[],
|
||||
)
|
||||
|
||||
|
||||
def test_init_uses_explicit_teacher_server_configs(monkeypatch):
|
||||
from atroposlib.envs import teacher_distillation_env as module
|
||||
|
||||
called = {}
|
||||
|
||||
def _fake_base_init(self, config, server_configs, slurm=False, testing=False):
|
||||
self.config = config
|
||||
self.tokenizer = _DummyTokenizer()
|
||||
|
||||
def _fake_validate(self, teacher_tokenizer_name):
|
||||
called["teacher_tokenizer_name"] = teacher_tokenizer_name
|
||||
|
||||
monkeypatch.setattr(module.BaseEnv, "__init__", _fake_base_init)
|
||||
monkeypatch.setattr(module, "ServerManager", _CapturingServerManager)
|
||||
monkeypatch.setattr(
|
||||
_ConcreteTeacherEnv,
|
||||
"_validate_teacher_tokenizer_compatibility",
|
||||
_fake_validate,
|
||||
)
|
||||
|
||||
explicit_cfg = APIServerConfig(
|
||||
model_name="explicit-model",
|
||||
tokenizer_name="explicit-tokenizer",
|
||||
base_url="http://explicit/v1",
|
||||
api_key="x",
|
||||
server_type="vllm",
|
||||
)
|
||||
config = SimpleNamespace(
|
||||
teacher_enabled=True,
|
||||
teacher_top_k=0,
|
||||
)
|
||||
|
||||
env = _ConcreteTeacherEnv(
|
||||
config=config,
|
||||
server_configs=[],
|
||||
teacher_server_configs=[explicit_cfg],
|
||||
)
|
||||
|
||||
assert isinstance(env.teacher_server, _CapturingServerManager)
|
||||
assert env.teacher_server.configs == [explicit_cfg]
|
||||
assert called["teacher_tokenizer_name"] == "explicit-tokenizer"
|
||||
|
||||
|
||||
def test_resolve_teacher_server_configs_returns_none_when_unset():
|
||||
assert (
|
||||
_ConcreteTeacherEnv._resolve_teacher_server_configs(
|
||||
default_teacher_server_configs=None,
|
||||
yaml_config={},
|
||||
cli_passed_flags={},
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_teacher_server_configs_uses_teacher_namespace(monkeypatch):
|
||||
from atroposlib.envs import teacher_distillation_env as module
|
||||
|
||||
captured = {}
|
||||
|
||||
def _fake_resolve(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return ["resolved"]
|
||||
|
||||
monkeypatch.setattr(module, "resolve_openai_configs", _fake_resolve)
|
||||
|
||||
default_cfg = APIServerConfig(
|
||||
model_name="teacher-model",
|
||||
base_url="http://teacher/v1",
|
||||
api_key="x",
|
||||
server_type="vllm",
|
||||
)
|
||||
|
||||
out = _ConcreteTeacherEnv._resolve_teacher_server_configs(
|
||||
default_teacher_server_configs=default_cfg,
|
||||
yaml_config={"teacher": {"tokenizer_name": "teacher-tokenizer"}},
|
||||
cli_passed_flags={"teacher.base_url": "http://override/v1"},
|
||||
)
|
||||
|
||||
assert out == ["resolved"]
|
||||
assert captured["openai_config_dict"]["base_url"] == "http://override/v1"
|
||||
assert captured["openai_config_dict"]["tokenizer_name"] == "teacher-tokenizer"
|
||||
assert captured["yaml_config"] == {
|
||||
"openai": {"tokenizer_name": "teacher-tokenizer"}
|
||||
}
|
||||
assert captured["cli_passed_flags"] == {"openai.base_url": "http://override/v1"}
|
||||
|
||||
|
|
|
|||
|
|
@ -32,13 +32,6 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv):
|
|||
max_token_length=2048,
|
||||
wandb_name="gsm8k_teacher_distill",
|
||||
teacher_enabled=True,
|
||||
teacher_server=APIServerConfig(
|
||||
base_url="http://localhost:8003/v1",
|
||||
model_name="mock-teacher",
|
||||
api_key="",
|
||||
server_type="vllm",
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
),
|
||||
teacher_top_k=4,
|
||||
)
|
||||
server_config = APIServerConfig(
|
||||
|
|
@ -49,6 +42,17 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv):
|
|||
)
|
||||
return env_config, server_config
|
||||
|
||||
@classmethod
|
||||
def teacher_config_init(cls) -> APIServerConfig:
|
||||
return APIServerConfig(
|
||||
base_url="http://localhost:9003/v1",
|
||||
model_name="mock-teacher",
|
||||
api_key="",
|
||||
server_type="vllm",
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
timeout=1200,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GSM8kTeacherDistillEnv.cli()
|
||||
|
|
|
|||
|
|
@ -314,14 +314,14 @@ What to configure on the environment side:
|
|||
|
||||
```bash
|
||||
--env.teacher_enabled true \
|
||||
--env.teacher_server.base_url "http://localhost:9003/v1" \
|
||||
--env.teacher_server.model_name "$TEACHER_MODEL" \
|
||||
--env.teacher_server.server_type vllm \
|
||||
--teacher.base_url "http://localhost:9003/v1" \
|
||||
--teacher.model_name "$TEACHER_MODEL" \
|
||||
--teacher.server_type vllm \
|
||||
--env.teacher_top_k 8
|
||||
```
|
||||
|
||||
If `$TEACHER_MODEL` is a deployment alias instead of a tokenizer identifier,
|
||||
also set `--env.teacher_server.tokenizer_name ...` so the env can validate
|
||||
also set `--teacher.tokenizer_name ...` so the env can validate
|
||||
tokenizer compatibility.
|
||||
|
||||
Why cross-tokenizer conversion is not acceptable here:
|
||||
|
|
|
|||
|
|
@ -234,9 +234,9 @@ start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \
|
|||
--env.use_wandb true \
|
||||
--env.wandb_name "gsm8k-teacher-distill" \
|
||||
--env.teacher_enabled true \
|
||||
--env.teacher_server.base_url "http://localhost:${TEACHER_PORT}/v1" \
|
||||
--env.teacher_server.model_name "$TEACHER_MODEL" \
|
||||
--env.teacher_server.server_type vllm \
|
||||
--teacher.base_url "http://localhost:${TEACHER_PORT}/v1" \
|
||||
--teacher.model_name "$TEACHER_MODEL" \
|
||||
--teacher.server_type vllm \
|
||||
--env.teacher_top_k "$TEACHER_TOP_K" \
|
||||
--env.ensure_scores_are_not_same false \
|
||||
--openai.api_key "dummy" \
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue