mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
ccdd5a1ca6
commit
60fb6cae11
11 changed files with 221 additions and 136 deletions
|
|
@ -360,47 +360,51 @@ class BaseEnv(ABC):
|
|||
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
|
||||
"""
|
||||
Fetch top-K logprobs from teacher model for given sequences.
|
||||
|
||||
|
||||
Supports any OpenAI-compatible API (vLLM, OpenAI, Together, etc.).
|
||||
|
||||
|
||||
Args:
|
||||
token_sequences: List of token ID sequences to get logprobs for
|
||||
messages_list: Optional list of message histories (for chat APIs).
|
||||
If provided, uses chat/completions with logprobs.
|
||||
top_k: Number of top logprobs to fetch (defaults to config.teacher_top_k)
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (distill_token_ids, distill_logprobs), both shaped as:
|
||||
[batch][position][top_k].
|
||||
Returns ([], []) if teacher_base_url is not configured.
|
||||
"""
|
||||
logger.info(f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences")
|
||||
logger.info(
|
||||
f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences"
|
||||
)
|
||||
logger.info(f"[TEACHER] teacher_base_url={self.config.teacher_base_url}")
|
||||
|
||||
|
||||
if not self.config.teacher_base_url:
|
||||
logger.warning("[TEACHER] No teacher_base_url configured, returning empty")
|
||||
return [], []
|
||||
|
||||
|
||||
if top_k is None:
|
||||
top_k = self.config.teacher_top_k
|
||||
|
||||
|
||||
# Get API key from config or environment
|
||||
api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "")
|
||||
model_name = self.config.teacher_model_name or "default"
|
||||
|
||||
|
||||
logger.info(f"[TEACHER] Using model={model_name}, top_k={top_k}")
|
||||
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
|
||||
token_id_results: List[List[List[int]]] = []
|
||||
logprob_results: List[List[List[float]]] = []
|
||||
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for i, tokens in enumerate(token_sequences):
|
||||
logger.info(f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens")
|
||||
logger.info(
|
||||
f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens"
|
||||
)
|
||||
# Decode original sequence and optionally prepend teacher steering text.
|
||||
base_text = self.tokenizer.decode(tokens, skip_special_tokens=False)
|
||||
steering_prefix = ""
|
||||
|
|
@ -413,11 +417,15 @@ class BaseEnv(ABC):
|
|||
steering_prefix += self.config.teacher_prefix_text
|
||||
full_text = steering_prefix + base_text
|
||||
prefix_token_len = (
|
||||
len(self.tokenizer.encode(steering_prefix, add_special_tokens=False))
|
||||
len(
|
||||
self.tokenizer.encode(
|
||||
steering_prefix, add_special_tokens=False
|
||||
)
|
||||
)
|
||||
if steering_prefix
|
||||
else 0
|
||||
)
|
||||
|
||||
|
||||
# Try vLLM-style completions first (supports prompt_logprobs)
|
||||
# This is most efficient as it doesn't generate new tokens
|
||||
request_data = {
|
||||
|
|
@ -428,7 +436,7 @@ class BaseEnv(ABC):
|
|||
"logprobs": top_k,
|
||||
"echo": True, # Include prompt in response with logprobs
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
async with session.post(
|
||||
f"{self.config.teacher_base_url}/completions",
|
||||
|
|
@ -438,22 +446,24 @@ class BaseEnv(ABC):
|
|||
) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
seq_token_ids, seq_logprobs = self._parse_completion_logprobs(
|
||||
data, top_k
|
||||
seq_token_ids, seq_logprobs = (
|
||||
self._parse_completion_logprobs(data, top_k)
|
||||
)
|
||||
if seq_token_ids and seq_logprobs:
|
||||
aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens(
|
||||
seq_token_ids,
|
||||
seq_logprobs,
|
||||
target_token_len=len(tokens),
|
||||
prefix_token_len=prefix_token_len,
|
||||
aligned_ids, aligned_lps = (
|
||||
self._align_teacher_topk_to_tokens(
|
||||
seq_token_ids,
|
||||
seq_logprobs,
|
||||
target_token_len=len(tokens),
|
||||
prefix_token_len=prefix_token_len,
|
||||
)
|
||||
)
|
||||
token_id_results.append(aligned_ids)
|
||||
logprob_results.append(aligned_lps)
|
||||
continue
|
||||
except Exception:
|
||||
pass # Fall through to chat completions
|
||||
|
||||
|
||||
# Fallback: Use chat/completions with logprobs (OpenAI style)
|
||||
# This requires messages format
|
||||
if messages_list and i < len(messages_list):
|
||||
|
|
@ -476,7 +486,7 @@ class BaseEnv(ABC):
|
|||
}
|
||||
)
|
||||
messages.append({"role": "user", "content": full_text})
|
||||
|
||||
|
||||
chat_request = {
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
|
|
@ -485,7 +495,7 @@ class BaseEnv(ABC):
|
|||
"logprobs": True,
|
||||
"top_logprobs": top_k,
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
async with session.post(
|
||||
f"{self.config.teacher_base_url}/chat/completions",
|
||||
|
|
@ -501,11 +511,13 @@ class BaseEnv(ABC):
|
|||
# Chat fallback logprobs are for generated tokens, not prompt tokens.
|
||||
# To keep alignment correct for distillation, return empty per-position rows.
|
||||
if seq_token_ids and len(seq_token_ids) >= len(tokens):
|
||||
aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens(
|
||||
seq_token_ids,
|
||||
seq_logprobs,
|
||||
target_token_len=len(tokens),
|
||||
prefix_token_len=0,
|
||||
aligned_ids, aligned_lps = (
|
||||
self._align_teacher_topk_to_tokens(
|
||||
seq_token_ids,
|
||||
seq_logprobs,
|
||||
target_token_len=len(tokens),
|
||||
prefix_token_len=0,
|
||||
)
|
||||
)
|
||||
else:
|
||||
aligned_ids = [[] for _ in range(len(tokens))]
|
||||
|
|
@ -513,16 +525,20 @@ class BaseEnv(ABC):
|
|||
token_id_results.append(aligned_ids)
|
||||
logprob_results.append(aligned_lps)
|
||||
else:
|
||||
logger.warning(f"Teacher API returned {response.status}")
|
||||
token_id_results.append([[] for _ in range(len(tokens))])
|
||||
logger.warning(
|
||||
f"Teacher API returned {response.status}"
|
||||
)
|
||||
token_id_results.append(
|
||||
[[] for _ in range(len(tokens))]
|
||||
)
|
||||
logprob_results.append([[] for _ in range(len(tokens))])
|
||||
except Exception as e:
|
||||
logger.warning(f"Teacher chat request failed: {e}")
|
||||
token_id_results.append([[] for _ in range(len(tokens))])
|
||||
logprob_results.append([[] for _ in range(len(tokens))])
|
||||
|
||||
|
||||
return token_id_results, logprob_results
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching teacher logprobs: {e}")
|
||||
return [], []
|
||||
|
|
@ -556,7 +572,7 @@ class BaseEnv(ABC):
|
|||
aligned_lps.extend([[] for _ in range(pad_count)])
|
||||
|
||||
return aligned_ids, aligned_lps
|
||||
|
||||
|
||||
def _parse_completion_logprobs(
|
||||
self, data: Dict, top_k: int
|
||||
) -> Tuple[List[List[int]], List[List[float]]]:
|
||||
|
|
@ -564,10 +580,10 @@ class BaseEnv(ABC):
|
|||
try:
|
||||
choice = data.get("choices", [{}])[0]
|
||||
logprobs_data = choice.get("logprobs", {})
|
||||
|
||||
|
||||
# vLLM returns top_logprobs as list of dicts
|
||||
top_logprobs = logprobs_data.get("top_logprobs", [])
|
||||
|
||||
|
||||
if not top_logprobs:
|
||||
return [], []
|
||||
|
||||
|
|
@ -580,15 +596,15 @@ class BaseEnv(ABC):
|
|||
elif isinstance(pos_logprobs, dict):
|
||||
# Format: {token_str: logprob, ...}
|
||||
sorted_items = sorted(
|
||||
pos_logprobs.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True
|
||||
pos_logprobs.items(), key=lambda x: x[1], reverse=True
|
||||
)[:top_k]
|
||||
pos_ids: List[int] = []
|
||||
pos_lps: List[float] = []
|
||||
for token_str, logprob in sorted_items:
|
||||
# Convert token string to ID
|
||||
token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
|
||||
token_ids = self.tokenizer.encode(
|
||||
token_str, add_special_tokens=False
|
||||
)
|
||||
if token_ids:
|
||||
pos_ids.append(int(token_ids[0]))
|
||||
pos_lps.append(float(logprob))
|
||||
|
|
@ -602,7 +618,7 @@ class BaseEnv(ABC):
|
|||
except Exception as e:
|
||||
logger.warning(f"Error parsing completion logprobs: {e}")
|
||||
return [], []
|
||||
|
||||
|
||||
def _parse_chat_logprobs(
|
||||
self, data: Dict, top_k: int
|
||||
) -> Tuple[List[List[int]], List[List[float]]]:
|
||||
|
|
@ -610,14 +626,14 @@ class BaseEnv(ABC):
|
|||
try:
|
||||
choice = data.get("choices", [{}])[0]
|
||||
logprobs_data = choice.get("logprobs", {})
|
||||
|
||||
|
||||
if not logprobs_data:
|
||||
return [], []
|
||||
|
||||
|
||||
content = logprobs_data.get("content", [])
|
||||
seq_token_ids: List[List[int]] = []
|
||||
seq_logprobs: List[List[float]] = []
|
||||
|
||||
|
||||
for token_data in content:
|
||||
top_logprobs = token_data.get("top_logprobs", [])
|
||||
pos_ids: List[int] = []
|
||||
|
|
@ -626,7 +642,9 @@ class BaseEnv(ABC):
|
|||
token_str = item.get("token", "")
|
||||
logprob = item.get("logprob", 0.0)
|
||||
# Convert token string to ID
|
||||
token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
|
||||
token_ids = self.tokenizer.encode(
|
||||
token_str, add_special_tokens=False
|
||||
)
|
||||
if token_ids:
|
||||
pos_ids.append(int(token_ids[0]))
|
||||
pos_lps.append(float(logprob))
|
||||
|
|
@ -1251,7 +1269,9 @@ class BaseEnv(ABC):
|
|||
if valid_groups and do_send_to_api:
|
||||
# On-policy distillation: fetch teacher logprobs if enabled
|
||||
if self.config.distillation_enabled and self.config.teacher_base_url:
|
||||
logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups")
|
||||
logger.info(
|
||||
f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups"
|
||||
)
|
||||
for group in valid_groups:
|
||||
has_new_format = (
|
||||
group.get("distill_token_ids") is not None
|
||||
|
|
@ -1259,9 +1279,11 @@ class BaseEnv(ABC):
|
|||
)
|
||||
if not has_new_format:
|
||||
try:
|
||||
teacher_token_ids, teacher_logprobs = await self.get_teacher_logprobs(
|
||||
token_sequences=group["tokens"],
|
||||
messages_list=group.get("messages"),
|
||||
teacher_token_ids, teacher_logprobs = (
|
||||
await self.get_teacher_logprobs(
|
||||
token_sequences=group["tokens"],
|
||||
messages_list=group.get("messages"),
|
||||
)
|
||||
)
|
||||
if teacher_token_ids and teacher_logprobs:
|
||||
group["distill_token_ids"] = teacher_token_ids
|
||||
|
|
@ -1270,10 +1292,15 @@ class BaseEnv(ABC):
|
|||
f"[DISTILL] Added teacher distill arrays for {len(teacher_token_ids)} sequences"
|
||||
)
|
||||
else:
|
||||
logger.warning("[DISTILL] get_teacher_logprobs returned empty")
|
||||
logger.warning(
|
||||
"[DISTILL] get_teacher_logprobs returned empty"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[DISTILL] Failed to fetch teacher logprobs: {e}")
|
||||
logger.error(
|
||||
f"[DISTILL] Failed to fetch teacher logprobs: {e}"
|
||||
)
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
else:
|
||||
logger.debug(
|
||||
|
|
@ -1788,13 +1815,13 @@ class BaseEnv(ABC):
|
|||
cli_passed_flags, openai_full_prefix
|
||||
) # CLI args
|
||||
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
|
||||
|
||||
|
||||
# Debug logging for CLI args
|
||||
print(f"[CLI DEBUG] cli_passed_flags = {cli_passed_flags}")
|
||||
print(f"[CLI DEBUG] openai_full_prefix = {openai_full_prefix}")
|
||||
print(f"[CLI DEBUG] oai_cli_passed_args = {oai_cli_passed_args}")
|
||||
print(f"[CLI DEBUG] yaml_oai_config = {yaml_oai_config}")
|
||||
|
||||
|
||||
# Auto-convert ServerBaseline to APIServerConfig when CLI/YAML overrides are provided
|
||||
# This allows any environment to use --openai.* CLI args without modifying config_init
|
||||
# Use a new variable to avoid UnboundLocalError from closure scoping
|
||||
|
|
@ -1808,7 +1835,7 @@ class BaseEnv(ABC):
|
|||
logger.info(
|
||||
"Auto-converted ServerBaseline to APIServerConfig for CLI/YAML overrides"
|
||||
)
|
||||
|
||||
|
||||
if (
|
||||
isinstance(effective_server_configs, list)
|
||||
and len(effective_server_configs) == 1
|
||||
|
|
@ -1822,13 +1849,17 @@ class BaseEnv(ABC):
|
|||
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
||||
yaml_oai_config, dict
|
||||
):
|
||||
print(f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}")
|
||||
print(
|
||||
f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}"
|
||||
)
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
|
||||
yaml_oai_config,
|
||||
oai_cli_passed_args,
|
||||
)
|
||||
print(f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}")
|
||||
print(
|
||||
f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"[CLI DEBUG] Not merging: default_openai_config_ "
|
||||
|
|
|
|||
|
|
@ -165,7 +165,9 @@ def resolve_openai_configs(
|
|||
"""
|
||||
from atroposlib.envs.server_handling.server_manager import ServerBaseline
|
||||
|
||||
print(f"[RESOLVE DEBUG] default_server_configs type = {type(default_server_configs)}")
|
||||
print(
|
||||
f"[RESOLVE DEBUG] default_server_configs type = {type(default_server_configs)}"
|
||||
)
|
||||
print(f"[RESOLVE DEBUG] openai_config_dict = {openai_config_dict}")
|
||||
|
||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
|
@ -216,7 +218,9 @@ def resolve_openai_configs(
|
|||
elif isinstance(default_server_configs, APIServerConfig):
|
||||
# Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline
|
||||
print("[RESOLVE DEBUG] Taking APIServerConfig merged path")
|
||||
logger.info("Using single OpenAI server configuration based on merged settings (default/YAML/CLI).")
|
||||
logger.info(
|
||||
"Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
|
||||
)
|
||||
try:
|
||||
final_openai_config = APIServerConfig(**openai_config_dict)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -193,7 +193,9 @@ class VLLMServer(APIServer):
|
|||
debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
|
||||
if debug_requests:
|
||||
base = self.config.base_url.replace("/v1", "")
|
||||
prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace("\n", "\\n")
|
||||
prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace(
|
||||
"\n", "\\n"
|
||||
)
|
||||
print(
|
||||
f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate "
|
||||
f"prompt_token_len={len(prompt_tokens)}",
|
||||
|
|
@ -211,7 +213,7 @@ class VLLMServer(APIServer):
|
|||
)
|
||||
print(
|
||||
f"[ATROPOS_REQ_DEBUG] curl_base=curl -s -X POST {base}/generate "
|
||||
'-H "Content-Type: application/json" -d \'<JSON_PAYLOAD>\'',
|
||||
"-H \"Content-Type: application/json\" -d '<JSON_PAYLOAD>'",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue