narrow down scope

Jai Suphavadeeprasit 2026-02-27 11:14:42 -05:00
parent e8d0e74877
commit f343b24a6a
3 changed files with 22 additions and 869 deletions


@@ -49,7 +49,6 @@ from .server_handling.server_manager import (
     ServerManager,
     ServerManagerConfig,
 )
-from .server_handling.teacher_client import TeacherClient
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -212,51 +211,6 @@ class BaseEnvConfig(BaseModel):
         "no thinking prompt is injected. Use HERMES_REASONING_PROMPT from "
         "eval_helpers for the standard Hermes reasoning prompt.",
     )
-
-    # On-policy distillation settings
-    distillation_enabled: bool = Field(
-        default=False,
-        description="Enable on-policy distillation. When True, automatically fetches teacher logprobs "
-        "after scoring and includes them in data sent to trainer.",
-    )
-    teacher_base_url: Optional[str] = Field(
-        default=None,
-        description="Base URL of teacher model for distillation. Supports any OpenAI-compatible API "
-        "(vLLM, OpenAI, Together, etc.). Examples: 'http://localhost:8001/v1', 'https://api.openai.com/v1'",
-    )
-    teacher_model_name: Optional[str] = Field(
-        default=None,
-        description="Model name for teacher API calls (e.g., 'gpt-4o', 'meta-llama/Llama-3-70b'). "
-        "If None, uses 'default' which works for single-model vLLM servers.",
-    )
-    teacher_api_key: Optional[str] = Field(
-        default=None,
-        description="API key for teacher model. Can also be set via TEACHER_API_KEY env var.",
-    )
-    teacher_top_k: int = Field(
-        default=20,
-        description="Number of top logprobs to fetch from teacher model per position.",
-    )
-    teacher_prefix_text: Optional[str] = Field(
-        default=None,
-        description="Optional text prefix prepended to teacher scoring prompt. "
-        "Useful for behavior steering. Prefix token positions are trimmed out "
-        "before sending distillation arrays to the trainer, preserving alignment.",
-    )
-    teacher_system_prompt: Optional[str] = Field(
-        default=None,
-        description="Optional teacher system prompt. For completion-style teacher APIs, "
-        "this is converted to a textual prefix. For chat fallback, this is injected "
-        "as a leading system message.",
-    )
-    teacher_prompt_template: Optional[str] = Field(
-        default=None,
-        description="Optional template-first teacher prompt renderer. "
-        "Uses Python format-style variables from runtime context/overrides "
-        "(e.g., {question}, {answer}, {episodes}). If set, this is preferred over "
-        "mode-specific prompt building.",
-    )
-
 class BaseEnv(ABC):
     name: Optional[str] = None
     env_config_cls: BaseEnvConfig = BaseEnvConfig
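
Reviewer note: for anyone migrating off these fields, here is a minimal sketch of the config surface being deleted. DistillSettings below is a hypothetical stand-in that mirrors the removed Pydantic fields (defaults and semantics taken from the descriptions above), not code from this repository:

from typing import Optional

from pydantic import BaseModel, Field


# Hypothetical stand-in mirroring the deleted BaseEnvConfig fields.
class DistillSettings(BaseModel):
    distillation_enabled: bool = False
    teacher_base_url: Optional[str] = None    # any OpenAI-compatible endpoint
    teacher_model_name: Optional[str] = None  # None -> 'default' (single-model vLLM)
    teacher_api_key: Optional[str] = None     # or the TEACHER_API_KEY env var
    teacher_top_k: int = Field(default=20, description="top logprobs per position")
    teacher_prompt_template: Optional[str] = None  # e.g. "Q: {question}\nA: {answer}"


settings = DistillSettings(
    distillation_enabled=True,
    teacher_base_url="http://localhost:8001/v1",
)
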
@@ -305,9 +259,6 @@ class BaseEnv(ABC):
         self.curr_step = 0
         self.max_token_len = -1
         self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
-        self.teacher_client = TeacherClient(
-            config=self.config, tokenizer=self.tokenizer, logger=logger
-        )
         self.completion_lengths = []
         self.max_num_workers = config.max_num_workers
         if self.max_num_workers == -1:
@@ -363,46 +314,6 @@ class BaseEnv(ABC):
         # Calculate derived batch size
         return int(self.config.batch_size * effective_fraction)
-
-    async def get_teacher_logprobs(
-        self,
-        token_sequences: List[List[int]],
-        messages_list: Optional[List[List[Dict]]] = None,
-        seq_overrides: Optional[List[Dict[str, Any]]] = None,
-        group_overrides: Optional[Dict[str, Any]] = None,
-        top_k: Optional[int] = None,
-    ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
-        return await self.teacher_client.get_teacher_logprobs(
-            token_sequences=token_sequences,
-            messages_list=messages_list,
-            seq_overrides=seq_overrides,
-            group_overrides=group_overrides,
-            top_k=top_k,
-        )
-
-    def _align_teacher_topk_to_tokens(
-        self,
-        seq_token_ids: List[List[int]],
-        seq_logprobs: List[List[float]],
-        target_token_len: int,
-        prefix_token_len: int = 0,
-    ) -> Tuple[List[List[int]], List[List[float]]]:
-        return self.teacher_client._align_teacher_topk_to_tokens(
-            seq_token_ids=seq_token_ids,
-            seq_logprobs=seq_logprobs,
-            target_token_len=target_token_len,
-            prefix_token_len=prefix_token_len,
-        )
-
-    def _parse_completion_logprobs(
-        self, data: Dict, top_k: int
-    ) -> Tuple[List[List[int]], List[List[float]]]:
-        return self.teacher_client._parse_completion_logprobs(data=data, top_k=top_k)
-
-    def _parse_chat_logprobs(
-        self, data: Dict, top_k: int
-    ) -> Tuple[List[List[int]], List[List[float]]]:
-        return self.teacher_client._parse_chat_logprobs(data=data, top_k=top_k)
 
     @classmethod
     def config_init(
         cls,
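
Reviewer note: the four methods deleted above were thin delegations to TeacherClient. Their signatures imply the alignment contract sketched below: trim the positions covered by teacher_prefix_text, then fit the per-position top-k arrays to the policy-side token length. This is reconstructed from parameter names only; the removed TeacherClient implementation may have differed:

from typing import List, Tuple


def align_teacher_topk(
    seq_token_ids: List[List[int]],   # per position: top-k teacher token ids
    seq_logprobs: List[List[float]],  # per position: top-k teacher logprobs
    target_token_len: int,            # policy-side sequence length
    prefix_token_len: int = 0,        # teacher prefix tokens to drop
) -> Tuple[List[List[int]], List[List[float]]]:
    # Drop prefix positions so teacher arrays line up with policy tokens.
    ids = seq_token_ids[prefix_token_len:][:target_token_len]
    lps = seq_logprobs[prefix_token_len:][:target_token_len]
    # Pad with empty top-k slots if the teacher returned fewer positions.
    while len(ids) < target_token_len:
        ids.append([])
        lps.append([])
    return ids, lps
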
@@ -1019,52 +930,6 @@ class BaseEnv(ABC):
                 valid_groups.append(group)
 
         if valid_groups and do_send_to_api:
-            # On-policy distillation: fetch teacher logprobs if enabled
-            if self.config.distillation_enabled and self.config.teacher_base_url:
-                logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups")
-                for group in valid_groups:
-                    seq_overrides = group.get("overrides") or []
-                    group_overrides = (
-                        group.get("group_overrides")
-                        if isinstance(group.get("group_overrides"), dict)
-                        else {}
-                    )
-                    has_new_format = (
-                        group.get("distill_token_ids") is not None
-                        and group.get("distill_logprobs") is not None
-                    )
-                    if not has_new_format:
-                        try:
-                            teacher_token_ids, teacher_logprobs = await self.get_teacher_logprobs(
-                                token_sequences=group["tokens"],
-                                messages_list=group.get("messages"),
-                                seq_overrides=seq_overrides,
-                                group_overrides=group_overrides,
-                            )
-                            if teacher_token_ids and teacher_logprobs:
-                                group["distill_token_ids"] = teacher_token_ids
-                                group["distill_logprobs"] = teacher_logprobs
-                                logger.info(
-                                    f"[DISTILL] Added teacher distill arrays for {len(teacher_token_ids)} sequences"
-                                )
-                            else:
-                                logger.warning("[DISTILL] get_teacher_logprobs returned empty")
-                        except Exception as e:
-                            logger.error(f"[DISTILL] Failed to fetch teacher logprobs: {e}")
-                            import traceback
-                            logger.error(traceback.format_exc())
-                    self.teacher_client.assert_distill_arrays_aligned(
-                        token_sequences=group["tokens"],
-                        distill_token_ids=group.get("distill_token_ids"),
-                        distill_logprobs=group.get("distill_logprobs"),
-                    )
-            else:
-                logger.debug(
-                    "[DISTILL] Skipped - enabled=%s, url=%s",
-                    self.config.distillation_enabled,
-                    self.config.teacher_base_url,
-                )
-
             data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]]
             # send single or list of scored data groups
             if not original_was_list and len(valid_groups) == 1:
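
Reviewer note: for the record, the deleted block implemented this per-group flow: leave groups that already carry distill arrays untouched, otherwise fetch teacher logprobs and attach them under distill_token_ids / distill_logprobs before sending to the trainer. A minimal sketch with the fetcher abstracted out; attach_distill_arrays is illustrative, not repository code:

import logging

logger = logging.getLogger(__name__)


async def attach_distill_arrays(group: dict, fetch_logprobs) -> None:
    # Groups may already carry teacher arrays in the new format; skip those.
    if (
        group.get("distill_token_ids") is not None
        and group.get("distill_logprobs") is not None
    ):
        return
    try:
        token_ids, logprobs = await fetch_logprobs(
            token_sequences=group["tokens"],
            messages_list=group.get("messages"),
        )
    except Exception:
        logger.exception("[DISTILL] Failed to fetch teacher logprobs")
        return
    if token_ids and logprobs:
        group["distill_token_ids"] = token_ids
        group["distill_logprobs"] = logprobs
    else:
        logger.warning("[DISTILL] fetch returned empty")
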