mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
narrow down scope
This commit is contained in:
parent
e8d0e74877
commit
f343b24a6a
3 changed files with 22 additions and 869 deletions
|
|
@ -49,7 +49,6 @@ from .server_handling.server_manager import (
|
|||
ServerManager,
|
||||
ServerManagerConfig,
|
||||
)
|
||||
from .server_handling.teacher_client import TeacherClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
|
@ -212,51 +211,6 @@ class BaseEnvConfig(BaseModel):
|
|||
"no thinking prompt is injected. Use HERMES_REASONING_PROMPT from "
|
||||
"eval_helpers for the standard Hermes reasoning prompt.",
|
||||
)
|
||||
# On-policy distillation settings
|
||||
distillation_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Enable on-policy distillation. When True, automatically fetches teacher logprobs "
|
||||
"after scoring and includes them in data sent to trainer.",
|
||||
)
|
||||
teacher_base_url: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Base URL of teacher model for distillation. Supports any OpenAI-compatible API "
|
||||
"(vLLM, OpenAI, Together, etc.). Examples: 'http://localhost:8001/v1', 'https://api.openai.com/v1'",
|
||||
)
|
||||
teacher_model_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Model name for teacher API calls (e.g., 'gpt-4o', 'meta-llama/Llama-3-70b'). "
|
||||
"If None, uses 'default' which works for single-model vLLM servers.",
|
||||
)
|
||||
teacher_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for teacher model. Can also be set via TEACHER_API_KEY env var.",
|
||||
)
|
||||
teacher_top_k: int = Field(
|
||||
default=20,
|
||||
description="Number of top logprobs to fetch from teacher model per position.",
|
||||
)
|
||||
teacher_prefix_text: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional text prefix prepended to teacher scoring prompt. "
|
||||
"Useful for behavior steering. Prefix token positions are trimmed out "
|
||||
"before sending distillation arrays to the trainer, preserving alignment.",
|
||||
)
|
||||
teacher_system_prompt: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional teacher system prompt. For completion-style teacher APIs, "
|
||||
"this is converted to a textual prefix. For chat fallback, this is injected "
|
||||
"as a leading system message.",
|
||||
)
|
||||
teacher_prompt_template: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional template-first teacher prompt renderer. "
|
||||
"Uses Python format-style variables from runtime context/overrides "
|
||||
"(e.g., {question}, {answer}, {episodes}). If set, this is preferred over "
|
||||
"mode-specific prompt building.",
|
||||
)
|
||||
|
||||
|
||||
class BaseEnv(ABC):
|
||||
name: Optional[str] = None
|
||||
env_config_cls: BaseEnvConfig = BaseEnvConfig
|
||||
|
|
@ -305,9 +259,6 @@ class BaseEnv(ABC):
|
|||
self.curr_step = 0
|
||||
self.max_token_len = -1
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
|
||||
self.teacher_client = TeacherClient(
|
||||
config=self.config, tokenizer=self.tokenizer, logger=logger
|
||||
)
|
||||
self.completion_lengths = []
|
||||
self.max_num_workers = config.max_num_workers
|
||||
if self.max_num_workers == -1:
|
||||
|
|
@ -363,46 +314,6 @@ class BaseEnv(ABC):
|
|||
# Calculate derived batch size
|
||||
return int(self.config.batch_size * effective_fraction)
|
||||
|
||||
async def get_teacher_logprobs(
|
||||
self,
|
||||
token_sequences: List[List[int]],
|
||||
messages_list: Optional[List[List[Dict]]] = None,
|
||||
seq_overrides: Optional[List[Dict[str, Any]]] = None,
|
||||
group_overrides: Optional[Dict[str, Any]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
|
||||
return await self.teacher_client.get_teacher_logprobs(
|
||||
token_sequences=token_sequences,
|
||||
messages_list=messages_list,
|
||||
seq_overrides=seq_overrides,
|
||||
group_overrides=group_overrides,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
def _align_teacher_topk_to_tokens(
|
||||
self,
|
||||
seq_token_ids: List[List[int]],
|
||||
seq_logprobs: List[List[float]],
|
||||
target_token_len: int,
|
||||
prefix_token_len: int = 0,
|
||||
) -> Tuple[List[List[int]], List[List[float]]]:
|
||||
return self.teacher_client._align_teacher_topk_to_tokens(
|
||||
seq_token_ids=seq_token_ids,
|
||||
seq_logprobs=seq_logprobs,
|
||||
target_token_len=target_token_len,
|
||||
prefix_token_len=prefix_token_len,
|
||||
)
|
||||
|
||||
def _parse_completion_logprobs(
|
||||
self, data: Dict, top_k: int
|
||||
) -> Tuple[List[List[int]], List[List[float]]]:
|
||||
return self.teacher_client._parse_completion_logprobs(data=data, top_k=top_k)
|
||||
|
||||
def _parse_chat_logprobs(
|
||||
self, data: Dict, top_k: int
|
||||
) -> Tuple[List[List[int]], List[List[float]]]:
|
||||
return self.teacher_client._parse_chat_logprobs(data=data, top_k=top_k)
|
||||
|
||||
@classmethod
|
||||
def config_init(
|
||||
cls,
|
||||
|
|
@ -1019,52 +930,6 @@ class BaseEnv(ABC):
|
|||
valid_groups.append(group)
|
||||
|
||||
if valid_groups and do_send_to_api:
|
||||
# On-policy distillation: fetch teacher logprobs if enabled
|
||||
if self.config.distillation_enabled and self.config.teacher_base_url:
|
||||
logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups")
|
||||
for group in valid_groups:
|
||||
seq_overrides = group.get("overrides") or []
|
||||
group_overrides = (
|
||||
group.get("group_overrides")
|
||||
if isinstance(group.get("group_overrides"), dict)
|
||||
else {}
|
||||
)
|
||||
has_new_format = (
|
||||
group.get("distill_token_ids") is not None
|
||||
and group.get("distill_logprobs") is not None
|
||||
)
|
||||
if not has_new_format:
|
||||
try:
|
||||
teacher_token_ids, teacher_logprobs = await self.get_teacher_logprobs(
|
||||
token_sequences=group["tokens"],
|
||||
messages_list=group.get("messages"),
|
||||
seq_overrides=seq_overrides,
|
||||
group_overrides=group_overrides,
|
||||
)
|
||||
if teacher_token_ids and teacher_logprobs:
|
||||
group["distill_token_ids"] = teacher_token_ids
|
||||
group["distill_logprobs"] = teacher_logprobs
|
||||
logger.info(
|
||||
f"[DISTILL] Added teacher distill arrays for {len(teacher_token_ids)} sequences"
|
||||
)
|
||||
else:
|
||||
logger.warning("[DISTILL] get_teacher_logprobs returned empty")
|
||||
except Exception as e:
|
||||
logger.error(f"[DISTILL] Failed to fetch teacher logprobs: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
self.teacher_client.assert_distill_arrays_aligned(
|
||||
token_sequences=group["tokens"],
|
||||
distill_token_ids=group.get("distill_token_ids"),
|
||||
distill_logprobs=group.get("distill_logprobs"),
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"[DISTILL] Skipped - enabled=%s, url=%s",
|
||||
self.config.distillation_enabled,
|
||||
self.config.teacher_base_url,
|
||||
)
|
||||
|
||||
data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]]
|
||||
# send single or list of scored data groups
|
||||
if not original_was_list and len(valid_groups) == 1:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue