mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
remove cross tokenization and fix location of configs
This commit is contained in:
parent 148a4fd5eb
commit a1b545c734
6 changed files with 147 additions and 302 deletions
README.md (41 changes)
@ -298,6 +298,47 @@ curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids!
- Trainers should validate alignment assumptions they require (sequence length, per-position top-k, etc.).
- Teacher-side architecture and prompt/rendering strategy are intentionally out of scope for this PR.

### TeacherDistillationEnv follow-up

The follow-up teacher environment uses a dedicated teacher server config and
attaches teacher prompt logprobs before the group is sent to the API.

Teacher config shape:

```python
TeacherDistillationConfig(
    teacher_enabled=True,
    teacher_server=APIServerConfig(
        base_url="http://localhost:9003/v1",
        model_name="Qwen/Qwen3-30B-A3B-Instruct-2507",
        api_key="",
        server_type="vllm",
    ),
    teacher_top_k=8,
)
```

CLI shape:

```bash
--env.teacher_enabled true \
--env.teacher_server.base_url "http://localhost:9003/v1" \
--env.teacher_server.model_name "Qwen/Qwen3-30B-A3B-Instruct-2507" \
--env.teacher_server.server_type vllm \
--env.teacher_top_k 8
```

Tokenizer requirement:

- Teacher distillation currently requires the teacher and student to use the same tokenizer vocabulary.
- If the tokenizers do not match, `TeacherDistillationEnv` raises an error instead of attempting token conversion.

Why same-tokenizer is required:

- `distill_token_ids` are consumed as student-vocabulary IDs by the trainer.
- If the teacher uses a different vocabulary, the same integer token ID refers to different text on the teacher and student sides.
- A decode/re-tokenize/remap pipeline is not a safe drop-in fix because it changes both token positions and token identities, which breaks the exact per-position token supervision that the current distillation loss assumes.
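A toy illustration of the second bullet, using invented two-entry vocabularies: the same integer ID names different text on each side, so a teacher top-k entry is meaningless as a student index.

```python
# Invented ID->text maps standing in for two different tokenizer vocabularies.
student_id_to_text = {42: " the", 43: " cat"}
teacher_id_to_text = {42: "ท", 43: " dog"}

# A teacher top-k entry (teacher_id=42, logprob=-0.1) looks like a valid
# student index, but supervising student token 42 targets " the" while the
# teacher actually scored "ท".
teacher_topk = [(42, -0.1)]
for tid, lp in teacher_topk:
    print(repr(student_id_to_text[tid]), "!=", repr(teacher_id_to_text[tid]))
```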

---

## Testing and Debugging Tools
@ -4,42 +4,10 @@ Teacher distillation environment layer.
This module adds teacher prompt-logprob fetching on top of BaseEnv without
modifying BaseEnv transport behavior.

Cross-tokenizer distillation
----------------------------
When student and teacher use the same tokenizer family (e.g. both Qwen3) the
student's raw token IDs can be forwarded directly to the teacher vLLM and the
returned top-k token IDs can be used as-is in the student logit lookup.

When tokenizers differ (e.g. Llama student, Qwen teacher) two problems arise:

1. Token-ID aliasing: student token 42 = " the" in Llama, but 42 = "ท" in
   Qwen. Sending student IDs to the teacher causes it to score garbage.

2. Vocab-space mismatch: the teacher's top-k IDs live in the teacher's
   vocabulary. The student logit lookup at those IDs would access random
   tokens in the student vocab.

This module fixes both problems automatically:

• Re-tokenization – student tokens are decoded to plain text and
  re-tokenized with the teacher tokenizer before being sent to the teacher
  server. The teacher therefore always scores the correct text.

• Character-level position alignment – after re-tokenization the teacher
  has a different number of tokens than the student. A character-offset
  map is built (requires a fast HuggingFace tokenizer) to project each
  teacher logprob position back onto the student token it overlaps with.

• Vocabulary remapping – teacher top-k token IDs (teacher vocab) are
  decoded to text fragments and re-encoded with the student tokenizer so
  that the final distill_token_ids live in the student vocabulary and can
  be looked up directly in the student logit tensor.

Same-tokenizer fast path
------------------------
When teacher_tokenizer_name resolves to the same underlying vocabulary as the
student tokenizer the original fast path (no decode / re-tokenize / remap) is
taken automatically.
This implementation supports same-tokenizer distillation only. The teacher and
student must share the same tokenizer vocabulary so the student's token IDs can
be forwarded directly to the teacher and the returned teacher top-k token IDs
can be looked up directly in the student's logits.
"""

from __future__ import annotations
@ -47,7 +15,7 @@ from __future__ import annotations

import asyncio
import logging
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union

from pydantic import Field
@ -63,29 +31,9 @@ class TeacherDistillationConfig(BaseEnvConfig):
        default=False,
        description="Whether to fetch teacher prompt logprobs for distillation.",
    )
    teacher_base_url: Optional[str] = Field(
    teacher_server: Optional[APIServerConfig] = Field(
        default=None,
        description="Teacher server base URL (OpenAI-compatible).",
    )
    teacher_model_name: Optional[str] = Field(
        default=None,
        description="Teacher model name used in teacher server requests.",
    )
    teacher_api_key: str = Field(
        default="",
        description="Teacher API key, if required by the teacher endpoint.",
    )
    teacher_server_type: str = Field(
        default="vllm",
        description="Teacher server type (e.g. vllm, sglang, trl, openai).",
    )
    teacher_tokenizer_name: str = Field(
        default="none",
        description=(
            "Tokenizer name for teacher server. If 'none', teacher_model_name is used. "
            "When this resolves to a different vocabulary than the student tokenizer, "
            "cross-tokenizer alignment is applied automatically."
        ),
        description="Teacher inference server configuration.",
    )
    teacher_top_k: int = Field(
        default=1,
@ -114,266 +62,71 @@ class TeacherDistillationEnv(BaseEnv, ABC):
    ):
        super().__init__(config, server_configs, slurm=slurm, testing=testing)
        self.teacher_server: Optional[ServerManager] = None
        # Teacher tokenizer (only loaded when tokenizers may differ).
        self._teacher_tokenizer = None
        # True when student and teacher share the same vocabulary.
        self._same_tokenizer: bool = True
        # LRU-style cache: teacher_token_id -> student_token_id
        self._vocab_remap_cache: Dict[int, int] = {}

        if config.teacher_enabled:
            if not config.teacher_base_url or not config.teacher_model_name:
            if config.teacher_server is None:
                raise ValueError(
                    "teacher_enabled=True requires teacher_base_url and teacher_model_name."
                    "teacher_enabled=True requires a teacher_server configuration."
                )
            teacher_tok_name = (
                config.teacher_model_name
                if config.teacher_tokenizer_name in ("none", "")
                else config.teacher_tokenizer_name
            )
            teacher_cfg = APIServerConfig(
                server_type=config.teacher_server_type,  # type: ignore[arg-type]
                base_url=config.teacher_base_url,
                api_key=config.teacher_api_key,
                model_name=config.teacher_model_name,
                tokenizer_name=teacher_tok_name,
                timeout=1200,
            teacher_cfg = config.teacher_server.model_copy(
                update={
                    "tokenizer_name": (
                        config.teacher_server.model_name
                        if config.teacher_server.tokenizer_name in ("", "none")
                        else config.teacher_server.tokenizer_name
                    ),
                    "timeout": 1200,
                }
            )
            self.teacher_server = ServerManager(
                [teacher_cfg],
                slurm=False,
                testing=False,
            )

            # Detect vocabulary mismatch.
            # Compare by name first; if names differ, load the teacher tokenizer
            # and do a vocab-size sanity check. Same-family models (e.g. Qwen3-4B
            # and Qwen3-30B) share the same vocabulary, so even though the
            # name_or_path strings differ they should use the fast path.
            student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
            if (
                student_tok_name
                and teacher_tok_name
                and student_tok_name != teacher_tok_name
            ):
                try:
                    from transformers import AutoTokenizer

                    loaded = AutoTokenizer.from_pretrained(
                        teacher_tok_name, use_fast=True
                    )
                    student_vocab_size = getattr(self.tokenizer, "vocab_size", None)
                    teacher_vocab_size = getattr(loaded, "vocab_size", None)
                    if (
                        student_vocab_size is not None
                        and teacher_vocab_size is not None
                        and student_vocab_size == teacher_vocab_size
                    ):
                        # Same vocab size — treat as same tokenizer (fast path).
                        # This covers same-family models (e.g. all Qwen3 variants).
                        self._same_tokenizer = True
                        logger.warning(
                            "TeacherDistillationEnv: names differ but vocab sizes match "
                            "(%d tokens). Using fast (same-tokenizer) path. "
                            "student=%s teacher=%s",
                            student_vocab_size,
                            student_tok_name,
                            teacher_tok_name,
                        )
                    else:
                        self._teacher_tokenizer = loaded
                        self._same_tokenizer = False
                        logger.warning(
                            "TeacherDistillationEnv: cross-tokenizer mode active. "
                            "student=%s (%s tokens) teacher=%s (%s tokens). "
                            "Token IDs will be decoded → re-tokenized → vocab-remapped.",
                            student_tok_name,
                            student_vocab_size,
                            teacher_tok_name,
                            teacher_vocab_size,
                        )
                except Exception as exc:
                    logger.warning(
                        "TeacherDistillationEnv: could not load teacher tokenizer '%s' "
                        "(%s). Falling back to same-tokenizer (fast) path — only safe if "
                        "student and teacher share the same vocabulary.",
                        teacher_tok_name,
                        exc,
                    )
                    self._same_tokenizer = True
            else:
                self._same_tokenizer = True

            logger.warning(
                "TeacherDistillationEnv: teacher server configured at %s "
                "(model=%s, top_k=%s, same_tokenizer=%s)",
                config.teacher_base_url,
                config.teacher_model_name,
                config.teacher_top_k,
                self._same_tokenizer,
            )

    # ------------------------------------------------------------------
    # Cross-tokenizer helpers
    # ------------------------------------------------------------------

    def _build_student_teacher_alignment(
        self,
        text: str,
        student_ids: List[int],
        teacher_ids: List[int],
    ) -> List[List[int]]:
        """
        For each student token position return the list of teacher token positions
        whose character spans overlap with the student token's character span.

        Requires fast (Rust-backed) HuggingFace tokenizers that support
        return_offsets_mapping. Falls back to a proportional approximation
        if offset mapping is unavailable.
        """
        student_len = len(student_ids)
        teacher_len = len(teacher_ids)

        try:
            s_enc = self.tokenizer(
                text, return_offsets_mapping=True, add_special_tokens=False
            )
            t_enc = self._teacher_tokenizer(
                text, return_offsets_mapping=True, add_special_tokens=False
            )
            s_offsets: List[Tuple[int, int]] = s_enc["offset_mapping"][:student_len]
            t_offsets: List[Tuple[int, int]] = t_enc["offset_mapping"][:teacher_len]

            alignment: List[List[int]] = []
            for s_start, s_end in s_offsets:
                overlapping = [
                    t_idx
                    for t_idx, (t_start, t_end) in enumerate(t_offsets)
                    if t_start < s_end and t_end > s_start and s_end > s_start
                ]
                alignment.append(overlapping)
            return alignment

        except Exception as exc:
            logger.warning(
                "TeacherDistillationEnv: offset-mapping alignment failed (%s). "
                "Using proportional fallback.",
                exc,
            )
            ratio = teacher_len / max(student_len, 1)
            return [[int(i * ratio)] for i in range(student_len)]

    def _remap_teacher_token_to_student(self, teacher_token_id: int) -> int:
        """
        Convert a teacher vocabulary token ID to the best-matching student
        vocabulary token ID by decoding the teacher token to text then
        re-encoding with the student tokenizer.

        Results are cached to avoid repeated tokenizer calls.
        """
        if teacher_token_id in self._vocab_remap_cache:
            return self._vocab_remap_cache[teacher_token_id]

        try:
            text = self._teacher_tokenizer.decode(
                [teacher_token_id], clean_up_tokenization_spaces=False
            )
            student_ids = self.tokenizer.encode(text, add_special_tokens=False)
            # Use the first student token as the representative.
            sid = int(student_ids[0]) if student_ids else teacher_token_id
        except Exception:
            sid = teacher_token_id

        self._vocab_remap_cache[teacher_token_id] = sid
        return sid

    def _align_and_remap(
        self,
        student_ids: List[int],
        teacher_topk_ids: List[List[int]],
        teacher_topk_lps: List[List[float]],
        alignment: List[List[int]],
    ) -> Tuple[List[List[int]], List[List[float]]]:
        """
        Project teacher logprobs (teacher positions, teacher vocab) onto
        student positions in student vocab.

        For each student token position:
        1. Collect all teacher top-k entries from overlapping teacher positions.
        2. Remap each teacher token ID to the student vocab.
        3. Merge duplicates by keeping the maximum logprob.
        4. Return the top-k entries sorted by descending logprob.
        """
        k = max(1, len(teacher_topk_ids[0]) if teacher_topk_ids else 1)
        result_ids: List[List[int]] = []
        result_lps: List[List[float]] = []

        for s_idx in range(len(student_ids)):
            t_positions = alignment[s_idx] if s_idx < len(alignment) else []
            if not t_positions:
                result_ids.append([])
                result_lps.append([])
                continue

            # Merge all overlapping teacher positions, remap vocab.
            merged: Dict[int, float] = {}
            for t_idx in t_positions:
                if t_idx >= len(teacher_topk_ids):
                    continue
                for tid, tlp in zip(teacher_topk_ids[t_idx], teacher_topk_lps[t_idx]):
                    sid = self._remap_teacher_token_to_student(tid)
                    merged[sid] = max(merged.get(sid, -1e9), tlp)

            sorted_items = sorted(merged.items(), key=lambda x: -x[1])
            top = sorted_items[:k]
            result_ids.append([int(sid) for sid, _ in top])
            result_lps.append([float(lp) for _, lp in top])

        return result_ids, result_lps
            self._validate_teacher_tokenizer_compatibility(teacher_cfg.tokenizer_name)

    # ------------------------------------------------------------------
    # Core fetch
    # ------------------------------------------------------------------

    def _validate_teacher_tokenizer_compatibility(self, teacher_tokenizer_name: str) -> None:
        student_tok_name = getattr(self.tokenizer, "name_or_path", None) or ""
        if student_tok_name == teacher_tokenizer_name:
            return

        try:
            from transformers import AutoTokenizer

            teacher_tokenizer = AutoTokenizer.from_pretrained(
                teacher_tokenizer_name, use_fast=True
            )
        except Exception as exc:
            raise ValueError(
                "Cross-tokenizer distillation is not supported in this PR, and the "
                f"teacher tokenizer for '{teacher_tokenizer_name}' could not be loaded to "
                f"verify compatibility: {exc}"
            ) from exc

        student_vocab = self.tokenizer.get_vocab()
        teacher_vocab = teacher_tokenizer.get_vocab()
        if student_vocab != teacher_vocab:
            raise ValueError(
                "Cross-tokenizer distillation is not supported in this PR. "
                f"Student tokenizer '{student_tok_name or type(self.tokenizer).__name__}' "
                f"and teacher tokenizer '{teacher_tokenizer_name}' do not match."
            )

    async def _fetch_teacher_for_sequence(
        self, token_ids: List[int], top_k: int
    ) -> Tuple[List[List[int]], List[List[float]]]:
        assert self.teacher_server is not None

        if self._same_tokenizer or self._teacher_tokenizer is None:
            # Fast path: same vocabulary — send student IDs directly.
            payload = await self.teacher_server.get_logprobs(
                input_ids=token_ids,
                top_k=top_k,
                max_tokens=1,
                split="train",
            )
            return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]

        # Cross-tokenizer path:
        # 1. Decode student tokens → plain text.
        # 2. Re-tokenize with teacher tokenizer → teacher IDs.
        # 3. Send teacher IDs to teacher vLLM.
        # 4. Align teacher positions → student positions.
        # 5. Remap teacher vocab IDs → student vocab IDs.
        text = self.tokenizer.decode(token_ids, clean_up_tokenization_spaces=False)
        teacher_ids: List[int] = self._teacher_tokenizer.encode(
            text, add_special_tokens=False
        )

        payload = await self.teacher_server.get_logprobs(
            input_ids=teacher_ids,
            input_ids=token_ids,
            top_k=top_k,
            max_tokens=1,
            split="train",
        )
        teacher_topk_ids = payload["prompt_topk_token_ids"]
        teacher_topk_lps = payload["prompt_topk_logprobs"]

        alignment = self._build_student_teacher_alignment(text, token_ids, teacher_ids)
        return self._align_and_remap(
            token_ids, teacher_topk_ids, teacher_topk_lps, alignment
        )
        return payload["prompt_topk_token_ids"], payload["prompt_topk_logprobs"]

    # ------------------------------------------------------------------
    # Group enrichment
@ -67,3 +67,29 @@ async def test_attach_teacher_distillation_failure_drops_payload():
    out = await TeacherDistillationEnv._attach_teacher_distillation(env, group)
    assert out["distill_token_ids"] is None
    assert out["distill_logprobs"] is None


def test_teacher_tokenizer_mismatch_raises(monkeypatch):
    env = object.__new__(_ConcreteTeacherEnv)

    class _StudentTokenizer:
        name_or_path = "student-model"

        def get_vocab(self):
            return {"a": 1}

    class _TeacherTokenizer:
        def get_vocab(self):
            return {"b": 1}

    env.tokenizer = _StudentTokenizer()
    monkeypatch.setattr(
        "transformers.AutoTokenizer.from_pretrained",
        lambda *args, **kwargs: _TeacherTokenizer(),
    )

    with pytest.raises(ValueError, match="Cross-tokenizer distillation is not supported"):
        TeacherDistillationEnv._validate_teacher_tokenizer_compatibility(
            env,
            teacher_tokenizer_name="teacher-model",
        )
@ -32,11 +32,12 @@ class GSM8kTeacherDistillEnv(GSM8kEnv, TeacherDistillationEnv):
            max_token_length=2048,
            wandb_name="gsm8k_teacher_distill",
            teacher_enabled=True,
            teacher_base_url="http://localhost:8003/v1",
            teacher_model_name="mock-teacher",
            teacher_api_key="",
            teacher_server_type="vllm",
            teacher_tokenizer_name="none",
            teacher_server=APIServerConfig(
                base_url="http://localhost:8003/v1",
                model_name="mock-teacher",
                api_key="",
                server_type="vllm",
            ),
            teacher_top_k=4,
        )
        server_config = APIServerConfig(
@ -304,6 +304,29 @@ environment uses the `/generate` path and includes token-level
4. Trainer extracts and aligns logprobs with training labels
5. GRPO loss uses these rollout logprobs in importance-ratio terms

### 1b. Teacher distillation requires the same tokenizer

When distillation data is attached to Atropos batches, the trainer treats
`distill_token_ids` as indices into the student's logit tensor. That only works
if the teacher and student share the same tokenizer vocabulary.
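The indexing assumption above can be sketched with plain Python lists standing in for the student logit tensor. All shapes, values, and IDs below are invented for illustration.

```python
# student_logits[position][token_id] -> logit for that token at that position.
student_logits = [
    [0.1, 0.9, 0.0, 0.3],  # position 0
    [0.2, 0.1, 0.7, 0.4],  # position 1
]

# Per-position teacher top-k token IDs, already in the student vocabulary.
distill_token_ids = [[1, 3], [2, 0]]

# The trainer looks these IDs up directly; with a mismatched vocabulary the
# same lookup would silently fetch logits for the wrong tokens.
gathered = [
    [student_logits[pos][tid] for tid in ids]
    for pos, ids in enumerate(distill_token_ids)
]
print(gathered)  # [[0.9, 0.3], [0.7, 0.2]]
```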

What to configure on the environment side:

```bash
--env.teacher_enabled true \
--env.teacher_server.base_url "http://localhost:9003/v1" \
--env.teacher_server.model_name "$TEACHER_MODEL" \
--env.teacher_server.server_type vllm \
--env.teacher_top_k 8
```

Why cross-tokenizer conversion is not acceptable here:

- Teacher token ID `1234` and student token ID `1234` can correspond to different text.
- Re-tokenizing teacher text changes token boundaries, so teacher position `i` may no longer correspond to student position `i`.
- Remapping teacher top-k tokens back into student vocab can collapse multiple teacher candidates into one student token or expand one teacher token into multiple student tokens.
- The current distillation loss expects exact per-position supervision in student token space, so an approximate remapping would silently produce misleading targets.
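The collapse described in the third bullet can be demonstrated with an invented teacher-to-student remap table: two distinct teacher candidates land on the same student token, and one candidate's probability mass is silently discarded.

```python
# Invented top-k entries (teacher_id, logprob) and remap table; in reality the
# remap would come from decode + re-encode with the student tokenizer.
teacher_topk = [(501, -0.2), (777, -0.5), (123, -1.0)]
teacher_to_student = {501: 9, 777: 9, 123: 4}  # 501 and 777 both map to 9

merged = {}
for tid, lp in teacher_topk:
    sid = teacher_to_student[tid]
    # Duplicates merge by max logprob, as an approximate remap would do.
    merged[sid] = max(merged.get(sid, float("-inf")), lp)

# Three teacher candidates survive as only two student targets.
print(merged)  # {9: -0.2, 4: -1.0}
```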

### 2. Clipping

```bash
@ -234,8 +234,9 @@ start_process "gsm8k_teacher_env" "${LOG_DIR}/env.log" \
    --env.use_wandb true \
    --env.wandb_name "gsm8k-teacher-distill" \
    --env.teacher_enabled true \
    --env.teacher_base_url "http://localhost:${TEACHER_PORT}/v1" \
    --env.teacher_model_name "$TEACHER_MODEL" \
    --env.teacher_server.base_url "http://localhost:${TEACHER_PORT}/v1" \
    --env.teacher_server.model_name "$TEACHER_MODEL" \
    --env.teacher_server.server_type vllm \
    --env.teacher_top_k "$TEACHER_TOP_K" \
    --env.ensure_scores_are_not_same false \
    --openai.api_key "dummy" \