[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2026-02-20 04:58:43 +00:00
parent ccdd5a1ca6
commit 60fb6cae11
11 changed files with 221 additions and 136 deletions


@@ -360,47 +360,51 @@ class BaseEnv(ABC):
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
"""
Fetch top-K logprobs from teacher model for given sequences.
Supports any OpenAI-compatible API (vLLM, OpenAI, Together, etc.).
Args:
token_sequences: List of token ID sequences to get logprobs for
messages_list: Optional list of message histories (for chat APIs).
If provided, uses chat/completions with logprobs.
top_k: Number of top logprobs to fetch (defaults to config.teacher_top_k)
Returns:
Tuple of (distill_token_ids, distill_logprobs), both shaped as:
[batch][position][top_k].
Returns ([], []) if teacher_base_url is not configured.
"""
logger.info(f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences")
logger.info(
f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences"
)
logger.info(f"[TEACHER] teacher_base_url={self.config.teacher_base_url}")
if not self.config.teacher_base_url:
logger.warning("[TEACHER] No teacher_base_url configured, returning empty")
return [], []
if top_k is None:
top_k = self.config.teacher_top_k
# Get API key from config or environment
api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "")
model_name = self.config.teacher_model_name or "default"
logger.info(f"[TEACHER] Using model={model_name}, top_k={top_k}")
headers = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
token_id_results: List[List[List[int]]] = []
logprob_results: List[List[List[float]]] = []
try:
async with aiohttp.ClientSession() as session:
for i, tokens in enumerate(token_sequences):
logger.info(f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens")
logger.info(
f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens"
)
# Decode original sequence and optionally prepend teacher steering text.
base_text = self.tokenizer.decode(tokens, skip_special_tokens=False)
steering_prefix = ""
@@ -413,11 +417,15 @@ class BaseEnv(ABC):
steering_prefix += self.config.teacher_prefix_text
full_text = steering_prefix + base_text
prefix_token_len = (
- len(self.tokenizer.encode(steering_prefix, add_special_tokens=False))
+ len(
+ self.tokenizer.encode(
+ steering_prefix, add_special_tokens=False
+ )
+ )
if steering_prefix
else 0
)
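The prefix_token_len bookkeeping exists so that the teacher's logprobs for the steering text can be discarded before aligning against the student's own tokens. A toy illustration of the arithmetic (numbers invented; the exact handling inside _align_teacher_topk_to_tokens may differ in detail):

# Invented numbers: a 7-token steering prefix prepended to a 120-token sequence.
prefix_token_len = 7
target_token_len = 120

# The teacher scores all 127 positions; alignment is expected to keep only the
# last 120 of them, i.e. the positions belonging to the original sequence.
teacher_positions = list(range(prefix_token_len + target_token_len))
kept = teacher_positions[prefix_token_len:prefix_token_len + target_token_len]
assert len(kept) == target_token_len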
# Try vLLM-style completions first (supports prompt_logprobs)
# This is most efficient as it doesn't generate new tokens
request_data = {
@@ -428,7 +436,7 @@ class BaseEnv(ABC):
"logprobs": top_k,
"echo": True, # Include prompt in response with logprobs
}
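The completions request assembled here leans on the echo trick: ask the server to return the prompt itself annotated with per-token logprobs instead of generating anything. A standalone sketch of such a payload for an OpenAI-compatible /v1/completions endpoint (the model name, prompt, and max_tokens value are assumptions for illustration, not taken from this diff):

import json

# Hypothetical payload: vLLM-style servers honor echo + logprobs and return
# top-k logprobs for every prompt position without generating new tokens.
request_data = {
    "model": "teacher-model",         # assumed
    "prompt": "The quick brown fox",  # assumed
    "max_tokens": 0,                  # assumed: score the prompt, generate nothing
    "logprobs": 20,                   # top-k per prompt position
    "echo": True,                     # include the prompt (with logprobs) in the response
}
print(json.dumps(request_data, indent=2))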
try:
async with session.post(
f"{self.config.teacher_base_url}/completions",
@@ -438,22 +446,24 @@ class BaseEnv(ABC):
) as response:
if response.status == 200:
data = await response.json()
- seq_token_ids, seq_logprobs = self._parse_completion_logprobs(
- data, top_k
+ seq_token_ids, seq_logprobs = (
+ self._parse_completion_logprobs(data, top_k)
+ )
if seq_token_ids and seq_logprobs:
- aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens(
- seq_token_ids,
- seq_logprobs,
- target_token_len=len(tokens),
- prefix_token_len=prefix_token_len,
+ aligned_ids, aligned_lps = (
+ self._align_teacher_topk_to_tokens(
+ seq_token_ids,
+ seq_logprobs,
+ target_token_len=len(tokens),
+ prefix_token_len=prefix_token_len,
+ )
+ )
token_id_results.append(aligned_ids)
logprob_results.append(aligned_lps)
continue
except Exception:
pass # Fall through to chat completions
# Fallback: Use chat/completions with logprobs (OpenAI style)
# This requires messages format
if messages_list and i < len(messages_list):
@@ -476,7 +486,7 @@ class BaseEnv(ABC):
}
)
messages.append({"role": "user", "content": full_text})
chat_request = {
"model": model_name,
"messages": messages,
@@ -485,7 +495,7 @@ class BaseEnv(ABC):
"logprobs": True,
"top_logprobs": top_k,
}
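The chat fallback uses the standard chat-completions logprob fields instead; these logprobs describe generated tokens rather than prompt tokens, which is why the code further down treats misaligned results so conservatively. A sketch of the payload shape (values beyond "logprobs" and "top_logprobs" are illustrative assumptions):

# Hypothetical chat/completions payload; only the logprob fields mirror
# what is visible in this diff, the rest is assumed.
chat_request = {
    "model": "teacher-model",
    "messages": [{"role": "user", "content": "The quick brown fox"}],
    "max_tokens": 1,
    "temperature": 0.0,
    "logprobs": True,
    "top_logprobs": 20,
}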
try:
async with session.post(
f"{self.config.teacher_base_url}/chat/completions",
@@ -501,11 +511,13 @@ class BaseEnv(ABC):
# Chat fallback logprobs are for generated tokens, not prompt tokens.
# To keep alignment correct for distillation, return empty per-position rows.
if seq_token_ids and len(seq_token_ids) >= len(tokens):
- aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens(
- seq_token_ids,
- seq_logprobs,
- target_token_len=len(tokens),
- prefix_token_len=0,
+ aligned_ids, aligned_lps = (
+ self._align_teacher_topk_to_tokens(
+ seq_token_ids,
+ seq_logprobs,
+ target_token_len=len(tokens),
+ prefix_token_len=0,
+ )
+ )
else:
aligned_ids = [[] for _ in range(len(tokens))]
@@ -513,16 +525,20 @@ class BaseEnv(ABC):
token_id_results.append(aligned_ids)
logprob_results.append(aligned_lps)
else:
logger.warning(f"Teacher API returned {response.status}")
token_id_results.append([[] for _ in range(len(tokens))])
logger.warning(
f"Teacher API returned {response.status}"
)
token_id_results.append(
[[] for _ in range(len(tokens))]
)
logprob_results.append([[] for _ in range(len(tokens))])
except Exception as e:
logger.warning(f"Teacher chat request failed: {e}")
token_id_results.append([[] for _ in range(len(tokens))])
logprob_results.append([[] for _ in range(len(tokens))])
return token_id_results, logprob_results
except Exception as e:
logger.error(f"Error fetching teacher logprobs: {e}")
return [], []
@@ -556,7 +572,7 @@ class BaseEnv(ABC):
aligned_lps.extend([[] for _ in range(pad_count)])
return aligned_ids, aligned_lps
def _parse_completion_logprobs(
self, data: Dict, top_k: int
) -> Tuple[List[List[int]], List[List[float]]]:
@@ -564,10 +580,10 @@ class BaseEnv(ABC):
try:
choice = data.get("choices", [{}])[0]
logprobs_data = choice.get("logprobs", {})
# vLLM returns top_logprobs as list of dicts
top_logprobs = logprobs_data.get("top_logprobs", [])
if not top_logprobs:
return [], []
@@ -580,15 +596,15 @@ class BaseEnv(ABC):
elif isinstance(pos_logprobs, dict):
# Format: {token_str: logprob, ...}
sorted_items = sorted(
- pos_logprobs.items(),
- key=lambda x: x[1],
- reverse=True
+ pos_logprobs.items(), key=lambda x: x[1], reverse=True
)[:top_k]
pos_ids: List[int] = []
pos_lps: List[float] = []
for token_str, logprob in sorted_items:
# Convert token string to ID
- token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
+ token_ids = self.tokenizer.encode(
+ token_str, add_special_tokens=False
+ )
if token_ids:
pos_ids.append(int(token_ids[0]))
pos_lps.append(float(logprob))
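For reference, the structure this parser walks is the legacy completions logprobs block described in the comments above, where top_logprobs is a list with one {token_string: logprob} dict per position. A fabricated, runnable fragment showing what the loop sees:

# Fabricated /v1/completions response fragment (echo=True, logprobs=2).
data = {
    "choices": [
        {
            "logprobs": {
                "top_logprobs": [
                    {"The": -0.11, "A": -2.53},         # position 0 candidates
                    {" quick": -0.40, " slow": -1.95},  # position 1 candidates
                ]
            }
        }
    ]
}

for pos, pos_logprobs in enumerate(data["choices"][0]["logprobs"]["top_logprobs"]):
    # Highest-probability candidates first, as in the sorted() call above.
    print(pos, sorted(pos_logprobs.items(), key=lambda x: x[1], reverse=True))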
@@ -602,7 +618,7 @@ class BaseEnv(ABC):
except Exception as e:
logger.warning(f"Error parsing completion logprobs: {e}")
return [], []
def _parse_chat_logprobs(
self, data: Dict, top_k: int
) -> Tuple[List[List[int]], List[List[float]]]:
@@ -610,14 +626,14 @@ class BaseEnv(ABC):
try:
choice = data.get("choices", [{}])[0]
logprobs_data = choice.get("logprobs", {})
if not logprobs_data:
return [], []
content = logprobs_data.get("content", [])
seq_token_ids: List[List[int]] = []
seq_logprobs: List[List[float]] = []
for token_data in content:
top_logprobs = token_data.get("top_logprobs", [])
pos_ids: List[int] = []
@@ -626,7 +642,9 @@ class BaseEnv(ABC):
token_str = item.get("token", "")
logprob = item.get("logprob", 0.0)
# Convert token string to ID
- token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
+ token_ids = self.tokenizer.encode(
+ token_str, add_special_tokens=False
+ )
if token_ids:
pos_ids.append(int(token_ids[0]))
pos_lps.append(float(logprob))
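The chat variant reads the newer layout, where logprobs live under choices[0].logprobs.content with one entry per generated token. A fabricated fragment of what this loop consumes:

# Fabricated chat/completions response fragment (logprobs=True, top_logprobs=2).
data = {
    "choices": [
        {
            "logprobs": {
                "content": [
                    {
                        "token": "Hello",
                        "logprob": -0.02,
                        "top_logprobs": [
                            {"token": "Hello", "logprob": -0.02},
                            {"token": "Hi", "logprob": -4.10},
                        ],
                    }
                ]
            }
        }
    ]
}

for token_data in data["choices"][0]["logprobs"]["content"]:
    for item in token_data["top_logprobs"]:
        print(item["token"], item["logprob"])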
@@ -1251,7 +1269,9 @@ class BaseEnv(ABC):
if valid_groups and do_send_to_api:
# On-policy distillation: fetch teacher logprobs if enabled
if self.config.distillation_enabled and self.config.teacher_base_url:
logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups")
logger.info(
f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups"
)
for group in valid_groups:
has_new_format = (
group.get("distill_token_ids") is not None
@@ -1259,9 +1279,11 @@ class BaseEnv(ABC):
)
if not has_new_format:
try:
- teacher_token_ids, teacher_logprobs = await self.get_teacher_logprobs(
- token_sequences=group["tokens"],
- messages_list=group.get("messages"),
+ teacher_token_ids, teacher_logprobs = (
+ await self.get_teacher_logprobs(
+ token_sequences=group["tokens"],
+ messages_list=group.get("messages"),
+ )
+ )
if teacher_token_ids and teacher_logprobs:
group["distill_token_ids"] = teacher_token_ids
@@ -1270,10 +1292,15 @@ class BaseEnv(ABC):
f"[DISTILL] Added teacher distill arrays for {len(teacher_token_ids)} sequences"
)
else:
logger.warning("[DISTILL] get_teacher_logprobs returned empty")
logger.warning(
"[DISTILL] get_teacher_logprobs returned empty"
)
except Exception as e:
logger.error(f"[DISTILL] Failed to fetch teacher logprobs: {e}")
logger.error(
f"[DISTILL] Failed to fetch teacher logprobs: {e}"
)
import traceback
logger.error(traceback.format_exc())
else:
logger.debug(
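To make the group mutation above concrete: when the teacher call succeeds, two arrays shaped like the group's token sequences are attached under distill_token_ids and distill_logprobs. A toy sketch (all values invented, and any group fields beyond those named in the diff are assumptions):

# Toy group with two 3-token rollouts before augmentation.
group = {"tokens": [[1, 2, 3], [4, 5, 6]], "messages": None}

# After get_teacher_logprobs succeeds, the code above attaches parallel
# [sequence][position][top_k] arrays that mirror "tokens".
group["distill_token_ids"] = [[[2, 9], [3, 8], [0, 1]], [[5, 9], [6, 8], [0, 1]]]
group["distill_logprobs"] = [
    [[-0.1, -2.0], [-0.2, -1.7], [-0.3, -1.5]],
    [[-0.1, -2.1], [-0.2, -1.9], [-0.4, -1.6]],
]
assert len(group["distill_token_ids"]) == len(group["tokens"])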
@@ -1788,13 +1815,13 @@ class BaseEnv(ABC):
cli_passed_flags, openai_full_prefix
) # CLI args
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
# Debug logging for CLI args
print(f"[CLI DEBUG] cli_passed_flags = {cli_passed_flags}")
print(f"[CLI DEBUG] openai_full_prefix = {openai_full_prefix}")
print(f"[CLI DEBUG] oai_cli_passed_args = {oai_cli_passed_args}")
print(f"[CLI DEBUG] yaml_oai_config = {yaml_oai_config}")
# Auto-convert ServerBaseline to APIServerConfig when CLI/YAML overrides are provided
# This allows any environment to use --openai.* CLI args without modifying config_init
# Use a new variable to avoid UnboundLocalError from closure scoping
@@ -1808,7 +1835,7 @@ class BaseEnv(ABC):
logger.info(
"Auto-converted ServerBaseline to APIServerConfig for CLI/YAML overrides"
)
if (
isinstance(effective_server_configs, list)
and len(effective_server_configs) == 1
@@ -1822,13 +1849,17 @@ class BaseEnv(ABC):
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
yaml_oai_config, dict
):
print(f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}")
print(
f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}"
)
openai_config_dict = merge_dicts(
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
yaml_oai_config,
oai_cli_passed_args,
)
print(f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}")
print(
f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}"
)
else:
print(
"[CLI DEBUG] Not merging: default_openai_config_ "


@@ -165,7 +165,9 @@ def resolve_openai_configs(
"""
from atroposlib.envs.server_handling.server_manager import ServerBaseline
print(f"[RESOLVE DEBUG] default_server_configs type = {type(default_server_configs)}")
print(
f"[RESOLVE DEBUG] default_server_configs type = {type(default_server_configs)}"
)
print(f"[RESOLVE DEBUG] openai_config_dict = {openai_config_dict}")
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
@@ -216,7 +218,9 @@ def resolve_openai_configs(
elif isinstance(default_server_configs, APIServerConfig):
# Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline
print("[RESOLVE DEBUG] Taking APIServerConfig merged path")
logger.info("Using single OpenAI server configuration based on merged settings (default/YAML/CLI).")
logger.info(
"Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
)
try:
final_openai_config = APIServerConfig(**openai_config_dict)
except Exception as e:
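The note about checking APIServerConfig before ServerBaseline is the usual subclass-before-base isinstance ordering rule: because APIServerConfig inherits from ServerBaseline, a base-class check placed first would swallow every APIServerConfig instance. A self-contained illustration with toy classes (not the real config types):

# Toy stand-ins: Derived plays APIServerConfig, Base plays ServerBaseline.
class Base: ...
class Derived(Base): ...

cfg = Derived()

# Wrong order: the base-class branch also matches Derived instances.
if isinstance(cfg, Base):
    picked = "base branch"
elif isinstance(cfg, Derived):
    picked = "derived branch"  # unreachable
print(picked)  # -> base branch

# Correct order, as the diff does: test the subclass first.
if isinstance(cfg, Derived):
    picked = "derived branch"
elif isinstance(cfg, Base):
    picked = "base branch"
print(picked)  # -> derived branch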


@@ -193,7 +193,9 @@ class VLLMServer(APIServer):
debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
if debug_requests:
base = self.config.base_url.replace("/v1", "")
- prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace("\n", "\\n")
+ prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace(
+ "\n", "\\n"
+ )
print(
f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate "
f"prompt_token_len={len(prompt_tokens)}",
@@ -211,7 +213,7 @@ class VLLMServer(APIServer):
)
print(
f"[ATROPOS_REQ_DEBUG] curl_base=curl -s -X POST {base}/generate "
'-H "Content-Type: application/json" -d \'<JSON_PAYLOAD>\'',
"-H \"Content-Type: application/json\" -d '<JSON_PAYLOAD>'",
flush=True,
)
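The whole debug path is gated on an environment variable, so it costs nothing unless explicitly requested. A hedged sketch of the opt-in pattern on its own (the command line in the comment is illustrative):

import os

# Same gating as the diff: export ATROPOS_DEBUG_REQUESTS=1 to get prompt
# previews and a ready-to-edit curl template on stdout, e.g.
#   ATROPOS_DEBUG_REQUESTS=1 python run_my_env.py   # command name is illustrative
if os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1":
    print("[ATROPOS_REQ_DEBUG] request debugging enabled", flush=True)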