mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
ccdd5a1ca6
commit
60fb6cae11
11 changed files with 221 additions and 136 deletions
|
|
@ -294,10 +294,10 @@ Distillation is configured in `BaseEnvConfig` and available via CLI under `--env
|
|||
|
||||
Both setups are supported:
|
||||
|
||||
- **Self-distillation (same model family for teacher and student)**
|
||||
- **Self-distillation (same model family for teacher and student)**
|
||||
Point `teacher_base_url` to a server running the same model (or equivalent checkpoint family) as the student. This is the most stable setup for token-level alignment.
|
||||
|
||||
- **Cross-model distillation (different teacher and student models)**
|
||||
- **Cross-model distillation (different teacher and student models)**
|
||||
Also supported, but tokenization compatibility becomes more important. If token vocabularies/template behavior differ significantly, alignment quality may degrade.
|
||||
|
||||
In practice, self-distillation is usually easiest to bring up first, then cross-model can be layered in once your pipeline is stable.
|
||||
|
|
|
|||
|
|
@ -360,47 +360,51 @@ class BaseEnv(ABC):
|
|||
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
|
||||
"""
|
||||
Fetch top-K logprobs from teacher model for given sequences.
|
||||
|
||||
|
||||
Supports any OpenAI-compatible API (vLLM, OpenAI, Together, etc.).
|
||||
|
||||
|
||||
Args:
|
||||
token_sequences: List of token ID sequences to get logprobs for
|
||||
messages_list: Optional list of message histories (for chat APIs).
|
||||
If provided, uses chat/completions with logprobs.
|
||||
top_k: Number of top logprobs to fetch (defaults to config.teacher_top_k)
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (distill_token_ids, distill_logprobs), both shaped as:
|
||||
[batch][position][top_k].
|
||||
Returns ([], []) if teacher_base_url is not configured.
|
||||
"""
|
||||
logger.info(f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences")
|
||||
logger.info(
|
||||
f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences"
|
||||
)
|
||||
logger.info(f"[TEACHER] teacher_base_url={self.config.teacher_base_url}")
|
||||
|
||||
|
||||
if not self.config.teacher_base_url:
|
||||
logger.warning("[TEACHER] No teacher_base_url configured, returning empty")
|
||||
return [], []
|
||||
|
||||
|
||||
if top_k is None:
|
||||
top_k = self.config.teacher_top_k
|
||||
|
||||
|
||||
# Get API key from config or environment
|
||||
api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "")
|
||||
model_name = self.config.teacher_model_name or "default"
|
||||
|
||||
|
||||
logger.info(f"[TEACHER] Using model={model_name}, top_k={top_k}")
|
||||
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
|
||||
token_id_results: List[List[List[int]]] = []
|
||||
logprob_results: List[List[List[float]]] = []
|
||||
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for i, tokens in enumerate(token_sequences):
|
||||
logger.info(f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens")
|
||||
logger.info(
|
||||
f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens"
|
||||
)
|
||||
# Decode original sequence and optionally prepend teacher steering text.
|
||||
base_text = self.tokenizer.decode(tokens, skip_special_tokens=False)
|
||||
steering_prefix = ""
|
||||
|
|
@ -413,11 +417,15 @@ class BaseEnv(ABC):
|
|||
steering_prefix += self.config.teacher_prefix_text
|
||||
full_text = steering_prefix + base_text
|
||||
prefix_token_len = (
|
||||
len(self.tokenizer.encode(steering_prefix, add_special_tokens=False))
|
||||
len(
|
||||
self.tokenizer.encode(
|
||||
steering_prefix, add_special_tokens=False
|
||||
)
|
||||
)
|
||||
if steering_prefix
|
||||
else 0
|
||||
)
|
||||
|
||||
|
||||
# Try vLLM-style completions first (supports prompt_logprobs)
|
||||
# This is most efficient as it doesn't generate new tokens
|
||||
request_data = {
|
||||
|
|
@ -428,7 +436,7 @@ class BaseEnv(ABC):
|
|||
"logprobs": top_k,
|
||||
"echo": True, # Include prompt in response with logprobs
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
async with session.post(
|
||||
f"{self.config.teacher_base_url}/completions",
|
||||
|
|
@ -438,22 +446,24 @@ class BaseEnv(ABC):
|
|||
) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
seq_token_ids, seq_logprobs = self._parse_completion_logprobs(
|
||||
data, top_k
|
||||
seq_token_ids, seq_logprobs = (
|
||||
self._parse_completion_logprobs(data, top_k)
|
||||
)
|
||||
if seq_token_ids and seq_logprobs:
|
||||
aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens(
|
||||
seq_token_ids,
|
||||
seq_logprobs,
|
||||
target_token_len=len(tokens),
|
||||
prefix_token_len=prefix_token_len,
|
||||
aligned_ids, aligned_lps = (
|
||||
self._align_teacher_topk_to_tokens(
|
||||
seq_token_ids,
|
||||
seq_logprobs,
|
||||
target_token_len=len(tokens),
|
||||
prefix_token_len=prefix_token_len,
|
||||
)
|
||||
)
|
||||
token_id_results.append(aligned_ids)
|
||||
logprob_results.append(aligned_lps)
|
||||
continue
|
||||
except Exception:
|
||||
pass # Fall through to chat completions
|
||||
|
||||
|
||||
# Fallback: Use chat/completions with logprobs (OpenAI style)
|
||||
# This requires messages format
|
||||
if messages_list and i < len(messages_list):
|
||||
|
|
@ -476,7 +486,7 @@ class BaseEnv(ABC):
|
|||
}
|
||||
)
|
||||
messages.append({"role": "user", "content": full_text})
|
||||
|
||||
|
||||
chat_request = {
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
|
|
@ -485,7 +495,7 @@ class BaseEnv(ABC):
|
|||
"logprobs": True,
|
||||
"top_logprobs": top_k,
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
async with session.post(
|
||||
f"{self.config.teacher_base_url}/chat/completions",
|
||||
|
|
@ -501,11 +511,13 @@ class BaseEnv(ABC):
|
|||
# Chat fallback logprobs are for generated tokens, not prompt tokens.
|
||||
# To keep alignment correct for distillation, return empty per-position rows.
|
||||
if seq_token_ids and len(seq_token_ids) >= len(tokens):
|
||||
aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens(
|
||||
seq_token_ids,
|
||||
seq_logprobs,
|
||||
target_token_len=len(tokens),
|
||||
prefix_token_len=0,
|
||||
aligned_ids, aligned_lps = (
|
||||
self._align_teacher_topk_to_tokens(
|
||||
seq_token_ids,
|
||||
seq_logprobs,
|
||||
target_token_len=len(tokens),
|
||||
prefix_token_len=0,
|
||||
)
|
||||
)
|
||||
else:
|
||||
aligned_ids = [[] for _ in range(len(tokens))]
|
||||
|
|
@ -513,16 +525,20 @@ class BaseEnv(ABC):
|
|||
token_id_results.append(aligned_ids)
|
||||
logprob_results.append(aligned_lps)
|
||||
else:
|
||||
logger.warning(f"Teacher API returned {response.status}")
|
||||
token_id_results.append([[] for _ in range(len(tokens))])
|
||||
logger.warning(
|
||||
f"Teacher API returned {response.status}"
|
||||
)
|
||||
token_id_results.append(
|
||||
[[] for _ in range(len(tokens))]
|
||||
)
|
||||
logprob_results.append([[] for _ in range(len(tokens))])
|
||||
except Exception as e:
|
||||
logger.warning(f"Teacher chat request failed: {e}")
|
||||
token_id_results.append([[] for _ in range(len(tokens))])
|
||||
logprob_results.append([[] for _ in range(len(tokens))])
|
||||
|
||||
|
||||
return token_id_results, logprob_results
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching teacher logprobs: {e}")
|
||||
return [], []
|
||||
|
|
@ -556,7 +572,7 @@ class BaseEnv(ABC):
|
|||
aligned_lps.extend([[] for _ in range(pad_count)])
|
||||
|
||||
return aligned_ids, aligned_lps
|
||||
|
||||
|
||||
def _parse_completion_logprobs(
|
||||
self, data: Dict, top_k: int
|
||||
) -> Tuple[List[List[int]], List[List[float]]]:
|
||||
|
|
@ -564,10 +580,10 @@ class BaseEnv(ABC):
|
|||
try:
|
||||
choice = data.get("choices", [{}])[0]
|
||||
logprobs_data = choice.get("logprobs", {})
|
||||
|
||||
|
||||
# vLLM returns top_logprobs as list of dicts
|
||||
top_logprobs = logprobs_data.get("top_logprobs", [])
|
||||
|
||||
|
||||
if not top_logprobs:
|
||||
return [], []
|
||||
|
||||
|
|
@ -580,15 +596,15 @@ class BaseEnv(ABC):
|
|||
elif isinstance(pos_logprobs, dict):
|
||||
# Format: {token_str: logprob, ...}
|
||||
sorted_items = sorted(
|
||||
pos_logprobs.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True
|
||||
pos_logprobs.items(), key=lambda x: x[1], reverse=True
|
||||
)[:top_k]
|
||||
pos_ids: List[int] = []
|
||||
pos_lps: List[float] = []
|
||||
for token_str, logprob in sorted_items:
|
||||
# Convert token string to ID
|
||||
token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
|
||||
token_ids = self.tokenizer.encode(
|
||||
token_str, add_special_tokens=False
|
||||
)
|
||||
if token_ids:
|
||||
pos_ids.append(int(token_ids[0]))
|
||||
pos_lps.append(float(logprob))
|
||||
|
|
@ -602,7 +618,7 @@ class BaseEnv(ABC):
|
|||
except Exception as e:
|
||||
logger.warning(f"Error parsing completion logprobs: {e}")
|
||||
return [], []
|
||||
|
||||
|
||||
def _parse_chat_logprobs(
|
||||
self, data: Dict, top_k: int
|
||||
) -> Tuple[List[List[int]], List[List[float]]]:
|
||||
|
|
@ -610,14 +626,14 @@ class BaseEnv(ABC):
|
|||
try:
|
||||
choice = data.get("choices", [{}])[0]
|
||||
logprobs_data = choice.get("logprobs", {})
|
||||
|
||||
|
||||
if not logprobs_data:
|
||||
return [], []
|
||||
|
||||
|
||||
content = logprobs_data.get("content", [])
|
||||
seq_token_ids: List[List[int]] = []
|
||||
seq_logprobs: List[List[float]] = []
|
||||
|
||||
|
||||
for token_data in content:
|
||||
top_logprobs = token_data.get("top_logprobs", [])
|
||||
pos_ids: List[int] = []
|
||||
|
|
@ -626,7 +642,9 @@ class BaseEnv(ABC):
|
|||
token_str = item.get("token", "")
|
||||
logprob = item.get("logprob", 0.0)
|
||||
# Convert token string to ID
|
||||
token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
|
||||
token_ids = self.tokenizer.encode(
|
||||
token_str, add_special_tokens=False
|
||||
)
|
||||
if token_ids:
|
||||
pos_ids.append(int(token_ids[0]))
|
||||
pos_lps.append(float(logprob))
|
||||
|
|
@ -1251,7 +1269,9 @@ class BaseEnv(ABC):
|
|||
if valid_groups and do_send_to_api:
|
||||
# On-policy distillation: fetch teacher logprobs if enabled
|
||||
if self.config.distillation_enabled and self.config.teacher_base_url:
|
||||
logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups")
|
||||
logger.info(
|
||||
f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups"
|
||||
)
|
||||
for group in valid_groups:
|
||||
has_new_format = (
|
||||
group.get("distill_token_ids") is not None
|
||||
|
|
@ -1259,9 +1279,11 @@ class BaseEnv(ABC):
|
|||
)
|
||||
if not has_new_format:
|
||||
try:
|
||||
teacher_token_ids, teacher_logprobs = await self.get_teacher_logprobs(
|
||||
token_sequences=group["tokens"],
|
||||
messages_list=group.get("messages"),
|
||||
teacher_token_ids, teacher_logprobs = (
|
||||
await self.get_teacher_logprobs(
|
||||
token_sequences=group["tokens"],
|
||||
messages_list=group.get("messages"),
|
||||
)
|
||||
)
|
||||
if teacher_token_ids and teacher_logprobs:
|
||||
group["distill_token_ids"] = teacher_token_ids
|
||||
|
|
@ -1270,10 +1292,15 @@ class BaseEnv(ABC):
|
|||
f"[DISTILL] Added teacher distill arrays for {len(teacher_token_ids)} sequences"
|
||||
)
|
||||
else:
|
||||
logger.warning("[DISTILL] get_teacher_logprobs returned empty")
|
||||
logger.warning(
|
||||
"[DISTILL] get_teacher_logprobs returned empty"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[DISTILL] Failed to fetch teacher logprobs: {e}")
|
||||
logger.error(
|
||||
f"[DISTILL] Failed to fetch teacher logprobs: {e}"
|
||||
)
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
else:
|
||||
logger.debug(
|
||||
|
|
@ -1788,13 +1815,13 @@ class BaseEnv(ABC):
|
|||
cli_passed_flags, openai_full_prefix
|
||||
) # CLI args
|
||||
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
|
||||
|
||||
|
||||
# Debug logging for CLI args
|
||||
print(f"[CLI DEBUG] cli_passed_flags = {cli_passed_flags}")
|
||||
print(f"[CLI DEBUG] openai_full_prefix = {openai_full_prefix}")
|
||||
print(f"[CLI DEBUG] oai_cli_passed_args = {oai_cli_passed_args}")
|
||||
print(f"[CLI DEBUG] yaml_oai_config = {yaml_oai_config}")
|
||||
|
||||
|
||||
# Auto-convert ServerBaseline to APIServerConfig when CLI/YAML overrides are provided
|
||||
# This allows any environment to use --openai.* CLI args without modifying config_init
|
||||
# Use a new variable to avoid UnboundLocalError from closure scoping
|
||||
|
|
@ -1808,7 +1835,7 @@ class BaseEnv(ABC):
|
|||
logger.info(
|
||||
"Auto-converted ServerBaseline to APIServerConfig for CLI/YAML overrides"
|
||||
)
|
||||
|
||||
|
||||
if (
|
||||
isinstance(effective_server_configs, list)
|
||||
and len(effective_server_configs) == 1
|
||||
|
|
@ -1822,13 +1849,17 @@ class BaseEnv(ABC):
|
|||
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
||||
yaml_oai_config, dict
|
||||
):
|
||||
print(f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}")
|
||||
print(
|
||||
f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}"
|
||||
)
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
|
||||
yaml_oai_config,
|
||||
oai_cli_passed_args,
|
||||
)
|
||||
print(f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}")
|
||||
print(
|
||||
f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"[CLI DEBUG] Not merging: default_openai_config_ "
|
||||
|
|
|
|||
|
|
@ -165,7 +165,9 @@ def resolve_openai_configs(
|
|||
"""
|
||||
from atroposlib.envs.server_handling.server_manager import ServerBaseline
|
||||
|
||||
print(f"[RESOLVE DEBUG] default_server_configs type = {type(default_server_configs)}")
|
||||
print(
|
||||
f"[RESOLVE DEBUG] default_server_configs type = {type(default_server_configs)}"
|
||||
)
|
||||
print(f"[RESOLVE DEBUG] openai_config_dict = {openai_config_dict}")
|
||||
|
||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
|
@ -216,7 +218,9 @@ def resolve_openai_configs(
|
|||
elif isinstance(default_server_configs, APIServerConfig):
|
||||
# Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline
|
||||
print("[RESOLVE DEBUG] Taking APIServerConfig merged path")
|
||||
logger.info("Using single OpenAI server configuration based on merged settings (default/YAML/CLI).")
|
||||
logger.info(
|
||||
"Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
|
||||
)
|
||||
try:
|
||||
final_openai_config = APIServerConfig(**openai_config_dict)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -193,7 +193,9 @@ class VLLMServer(APIServer):
|
|||
debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
|
||||
if debug_requests:
|
||||
base = self.config.base_url.replace("/v1", "")
|
||||
prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace("\n", "\\n")
|
||||
prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace(
|
||||
"\n", "\\n"
|
||||
)
|
||||
print(
|
||||
f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate "
|
||||
f"prompt_token_len={len(prompt_tokens)}",
|
||||
|
|
@ -211,7 +213,7 @@ class VLLMServer(APIServer):
|
|||
)
|
||||
print(
|
||||
f"[ATROPOS_REQ_DEBUG] curl_base=curl -s -X POST {base}/generate "
|
||||
'-H "Content-Type: application/json" -d \'<JSON_PAYLOAD>\'',
|
||||
"-H \"Content-Type: application/json\" -d '<JSON_PAYLOAD>'",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -289,7 +289,9 @@ class GSM8kEnv(BaseEnv):
|
|||
}
|
||||
)
|
||||
to_postprocess = await self.score(to_score)
|
||||
accepted = 0 if to_postprocess is None else len(to_postprocess.get("tokens", []))
|
||||
accepted = (
|
||||
0 if to_postprocess is None else len(to_postprocess.get("tokens", []))
|
||||
)
|
||||
print(
|
||||
f"[GSM8K_DEBUG] collect_done accepted={accepted} submitted={len(to_score)}",
|
||||
flush=True,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from typing import Dict, List, Optional, Tuple, Union
|
|||
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
from math_verify.errors import TimeoutException
|
||||
|
|
@ -146,7 +145,7 @@ class MathEnv(BaseEnv):
|
|||
wandb_name = os.environ.get("MATH_ENV_WANDB_NAME", "math-zero-env")
|
||||
max_token_length = int(os.environ.get("MATH_ENV_MAX_TOKENS", "32000"))
|
||||
worker_timeout = float(os.environ.get("MATH_ENV_WORKER_TIMEOUT", "1500"))
|
||||
|
||||
|
||||
env_config = RSConfig(
|
||||
tokenizer_name=model_name,
|
||||
group_size=8,
|
||||
|
|
@ -524,7 +523,7 @@ class MathEnv(BaseEnv):
|
|||
and (not scores["overrides"][i].get("set_advantage_to_zero", False))
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
return scores
|
||||
|
||||
async def get_next_item(self):
|
||||
|
|
|
|||
|
|
@ -194,7 +194,7 @@ python -m example_trainer.grpo --weight-bridge-mode lora_only ...
|
|||
|
||||
---
|
||||
|
||||
## Shared vLLM Mode
|
||||
## Shared vLLM Mode
|
||||
|
||||
Single-copy mode shares GPU memory between vLLM and the trainer - zero model duplication!
|
||||
|
||||
|
|
|
|||
|
|
@ -136,15 +136,17 @@ class TrainingConfig(BaseModel):
|
|||
wandb_group: Optional[str] = Field(None, description="Wandb group name")
|
||||
|
||||
# === Training Mode Configuration ===
|
||||
weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = Field(
|
||||
"none",
|
||||
description=(
|
||||
"How to synchronize weights with inference server. "
|
||||
"'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
|
||||
"'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). "
|
||||
"'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). "
|
||||
"'none': legacy mode, restart vLLM with new checkpoint files."
|
||||
),
|
||||
weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = (
|
||||
Field(
|
||||
"none",
|
||||
description=(
|
||||
"How to synchronize weights with inference server. "
|
||||
"'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
|
||||
"'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). "
|
||||
"'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). "
|
||||
"'none': legacy mode, restart vLLM with new checkpoint files."
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# === Distributed Training Configuration ===
|
||||
|
|
|
|||
|
|
@ -258,10 +258,14 @@ def pad_data_to_good_offset(
|
|||
else None
|
||||
)
|
||||
final_distill_token_id_batches = (
|
||||
distill_token_id_batches if (has_any_distill and distill_token_id_batches) else None
|
||||
distill_token_id_batches
|
||||
if (has_any_distill and distill_token_id_batches)
|
||||
else None
|
||||
)
|
||||
final_distill_logprob_batches = (
|
||||
distill_logprob_batches if (has_any_distill and distill_logprob_batches) else None
|
||||
distill_logprob_batches
|
||||
if (has_any_distill and distill_logprob_batches)
|
||||
else None
|
||||
)
|
||||
|
||||
return (
|
||||
|
|
|
|||
|
|
@ -735,19 +735,23 @@ def train_lora_restart(config: TrainingConfig):
|
|||
# Periodic adapter save + vLLM restart
|
||||
sync_time = 0
|
||||
should_sync = (step + 1) % config.vllm_restart_interval == 0
|
||||
if should_sync and (step + 1) < config.training_steps: # Don't restart on last step
|
||||
if (
|
||||
should_sync and (step + 1) < config.training_steps
|
||||
): # Don't restart on last step
|
||||
sync_start = time.time()
|
||||
|
||||
|
||||
# Save new adapter
|
||||
current_adapter_path = save_lora_checkpoint(model, config.save_path, step + 1)
|
||||
|
||||
current_adapter_path = save_lora_checkpoint(
|
||||
model, config.save_path, step + 1
|
||||
)
|
||||
|
||||
# Restart vLLM with new adapter
|
||||
print(" [RESTART] Restarting vLLM with new adapter...")
|
||||
_terminate_vllm(vllm_proc, config.vllm_port)
|
||||
vllm_proc = _launch_vllm_with_lora(config, current_adapter_path)
|
||||
if vllm_proc is None:
|
||||
raise RuntimeError("Failed to restart vLLM")
|
||||
|
||||
|
||||
sync_time = time.time() - sync_start
|
||||
benchmark_stats["sync_times"].append(sync_time)
|
||||
benchmark_stats["restart_times"].append(sync_time)
|
||||
|
|
@ -798,45 +802,53 @@ def train_lora_restart(config: TrainingConfig):
|
|||
_vllm_restart_counter = 0
|
||||
|
||||
|
||||
def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optional[subprocess.Popen]:
|
||||
def _launch_vllm_with_lora(
|
||||
config: TrainingConfig, adapter_path: str
|
||||
) -> Optional[subprocess.Popen]:
|
||||
"""
|
||||
Launch vLLM with a LoRA adapter (no --enforce-eager for faster inference).
|
||||
|
||||
|
||||
Unlike lora_only mode, this does NOT use --enforce-eager, so we get
|
||||
~108 TPS instead of ~13 TPS (8x faster).
|
||||
"""
|
||||
global _vllm_restart_counter
|
||||
from .vllm_manager import kill_process_on_port, wait_for_vllm_ready
|
||||
|
||||
|
||||
# Kill any existing process on the port
|
||||
print(f" Cleaning up port {config.vllm_port}...")
|
||||
kill_process_on_port(config.vllm_port)
|
||||
|
||||
|
||||
# Clear CUDA cache before starting new vLLM
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
# Wait for port and GPU memory to be fully released
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
# Find the vllm_api_server.py script
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
server_script = os.path.join(script_dir, "vllm_api_server.py")
|
||||
|
||||
|
||||
# Build command - NO --enforce-eager for faster inference (~108 TPS vs ~13 TPS)
|
||||
cmd = [
|
||||
sys.executable, server_script,
|
||||
"--model", config.model_name,
|
||||
"--port", str(config.vllm_port),
|
||||
"--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization),
|
||||
"--max-model-len", str(config.max_model_len),
|
||||
sys.executable,
|
||||
server_script,
|
||||
"--model",
|
||||
config.model_name,
|
||||
"--port",
|
||||
str(config.vllm_port),
|
||||
"--gpu-memory-utilization",
|
||||
str(config.vllm_gpu_memory_utilization),
|
||||
"--max-model-len",
|
||||
str(config.max_model_len),
|
||||
"--enable-lora",
|
||||
"--max-lora-rank", str(max(config.lora_r * 2, 32)),
|
||||
"--max-lora-rank",
|
||||
str(max(config.lora_r * 2, 32)),
|
||||
# Note: NOT adding --enforce-eager - this gives us ~8x faster inference!
|
||||
# Without --enforce-eager, vLLM can use more optimizations.
|
||||
]
|
||||
|
||||
|
||||
# Set environment for GPU selection
|
||||
env = os.environ.copy()
|
||||
if config.vllm_gpu is not None:
|
||||
|
|
@ -844,32 +856,39 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
|
|||
print(f" GPU: {config.vllm_gpu} (via CUDA_VISIBLE_DEVICES)")
|
||||
else:
|
||||
print(" GPU: Same as trainer (inherited CUDA_VISIBLE_DEVICES)")
|
||||
|
||||
|
||||
print(f" Launching: {' '.join(cmd)}")
|
||||
print(f" Adapter: {adapter_path}")
|
||||
|
||||
|
||||
# Log vLLM output to file for debugging (unique file per restart)
|
||||
vllm_log_path = os.path.join(config.save_path, f"vllm_restart_{_vllm_restart_counter}.log")
|
||||
vllm_log_path = os.path.join(
|
||||
config.save_path, f"vllm_restart_{_vllm_restart_counter}.log"
|
||||
)
|
||||
_vllm_restart_counter += 1
|
||||
print(f" vLLM log: {vllm_log_path}")
|
||||
|
||||
|
||||
try:
|
||||
vllm_log_file = open(vllm_log_path, "w")
|
||||
# Start in new session so we can kill entire process group later
|
||||
proc = subprocess.Popen(
|
||||
cmd, env=env, stdout=vllm_log_file, stderr=subprocess.STDOUT,
|
||||
start_new_session=True # Creates new process group for easy cleanup
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=vllm_log_file,
|
||||
stderr=subprocess.STDOUT,
|
||||
start_new_session=True, # Creates new process group for easy cleanup
|
||||
)
|
||||
print(f" vLLM PID: {proc.pid} (process group: {os.getpgid(proc.pid)})")
|
||||
print(" NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)...")
|
||||
|
||||
print(
|
||||
" NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)..."
|
||||
)
|
||||
|
||||
# Wait for server to be ready (longer timeout for CUDA graph compilation)
|
||||
if not wait_for_vllm_ready(config.vllm_port, timeout=300):
|
||||
print(" ERROR: vLLM failed to start after 300s")
|
||||
print(f" Check log: {vllm_log_path}")
|
||||
# Print last 30 lines of the log
|
||||
try:
|
||||
with open(vllm_log_path, 'r') as f:
|
||||
with open(vllm_log_path, "r") as f:
|
||||
lines = f.readlines()
|
||||
print(" Last 30 lines of vLLM log:")
|
||||
for line in lines[-30:]:
|
||||
|
|
@ -878,7 +897,7 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
|
|||
print(f" Could not read log: {e}")
|
||||
proc.terminate()
|
||||
return None
|
||||
|
||||
|
||||
# Load the LoRA adapter
|
||||
print(" Loading LoRA adapter...")
|
||||
try:
|
||||
|
|
@ -890,13 +909,15 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
|
|||
if resp.status_code == 200:
|
||||
print(" ✓ Adapter loaded successfully")
|
||||
else:
|
||||
print(f" WARNING: Adapter load returned {resp.status_code}: {resp.text}")
|
||||
print(
|
||||
f" WARNING: Adapter load returned {resp.status_code}: {resp.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f" WARNING: Could not load adapter: {e}")
|
||||
# Continue anyway - base model inference still works
|
||||
|
||||
|
||||
return proc
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}")
|
||||
return None
|
||||
|
|
@ -906,12 +927,12 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
|
|||
"""Terminate a vLLM process and release GPU resources."""
|
||||
import signal
|
||||
import subprocess as sp
|
||||
|
||||
|
||||
print(f" Terminating vLLM on port {port}...")
|
||||
|
||||
|
||||
# Get current GPU device
|
||||
gpu_id = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
|
||||
|
||||
|
||||
# Phase 1: Kill the process group if we have a handle (kills all children too)
|
||||
main_pid = None
|
||||
if proc is not None:
|
||||
|
|
@ -927,12 +948,13 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
|
|||
proc.wait(timeout=5)
|
||||
except Exception as e:
|
||||
print(f" Warning: {e}")
|
||||
|
||||
|
||||
# Phase 2: Kill by port (catches anything still running)
|
||||
from .vllm_manager import kill_process_on_port
|
||||
|
||||
kill_process_on_port(port)
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
# Phase 3: Aggressively kill ALL vLLM-related processes
|
||||
print(" Killing all vLLM-related processes...")
|
||||
kill_commands = [
|
||||
|
|
@ -948,19 +970,22 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
|
|||
sp.run(cmd, shell=True, capture_output=True, timeout=5)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Phase 4: Use nvidia-smi to find and kill GPU processes (nuclear option)
|
||||
print(f" Checking for zombie GPU processes on GPU {gpu_id}...")
|
||||
try:
|
||||
result = sp.run(
|
||||
f"nvidia-smi --query-compute-apps=pid,used_memory --format=csv,noheader,nounits -i {gpu_id}",
|
||||
shell=True, capture_output=True, text=True, timeout=10
|
||||
shell=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.stdout.strip():
|
||||
print(f" Found GPU processes:\n{result.stdout}")
|
||||
for line in result.stdout.strip().split('\n'):
|
||||
for line in result.stdout.strip().split("\n"):
|
||||
if line.strip():
|
||||
parts = line.split(',')
|
||||
parts = line.split(",")
|
||||
if len(parts) >= 1:
|
||||
pid = parts[0].strip()
|
||||
# Don't kill the current Python process (trainer)
|
||||
|
|
@ -972,7 +997,7 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
|
|||
pass
|
||||
except Exception as e:
|
||||
print(f" Warning: nvidia-smi check failed: {e}")
|
||||
|
||||
|
||||
# Phase 5: Wait for GPU memory release - CRITICAL
|
||||
# The CUDA driver needs time to actually free memory after process death
|
||||
print(" Waiting for GPU memory release...")
|
||||
|
|
@ -982,24 +1007,26 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
|
|||
torch.cuda.empty_cache()
|
||||
free_mem = torch.cuda.mem_get_info()[0] / 1e9
|
||||
total_mem = torch.cuda.mem_get_info()[1] / 1e9
|
||||
print(f" [{(i+1)*5}s] GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)")
|
||||
print(
|
||||
f" [{(i+1)*5}s] GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)"
|
||||
)
|
||||
# If we have enough memory (>50% free), break early
|
||||
if free_mem > total_mem * 0.5:
|
||||
print(f" ✓ Sufficient memory available ({free_mem:.1f} GB)")
|
||||
break
|
||||
|
||||
|
||||
# Final cleanup
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
free_mem = torch.cuda.mem_get_info()[0] / 1e9
|
||||
total_mem = torch.cuda.mem_get_info()[1] / 1e9
|
||||
print(f" Final GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)")
|
||||
|
||||
print(
|
||||
f" Final GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)"
|
||||
)
|
||||
|
||||
if free_mem < total_mem * 0.3:
|
||||
print(" WARNING: Low GPU memory! May fail to restart vLLM.")
|
||||
print(" Consider reducing --vllm-gpu-memory-utilization")
|
||||
|
||||
|
||||
print(" vLLM terminated")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -63,9 +63,10 @@ def compute_distillation_loss(
|
|||
continue
|
||||
|
||||
ids_tensor = torch.tensor(pos_ids, device=logits.device, dtype=torch.long)
|
||||
teacher_lps = torch.tensor(
|
||||
pos_lps, device=logits.device, dtype=logits.dtype
|
||||
) / temperature
|
||||
teacher_lps = (
|
||||
torch.tensor(pos_lps, device=logits.device, dtype=logits.dtype)
|
||||
/ temperature
|
||||
)
|
||||
|
||||
student_log_probs = F.log_softmax(logits[b, t] / temperature, dim=-1)
|
||||
student_subset = student_log_probs[ids_tensor]
|
||||
|
|
@ -75,7 +76,9 @@ def compute_distillation_loss(
|
|||
token_loss = -(teacher_probs * student_subset).sum()
|
||||
else:
|
||||
teacher_log_probs = F.log_softmax(teacher_lps, dim=-1)
|
||||
token_loss = (teacher_probs * (teacher_log_probs - student_subset)).sum()
|
||||
token_loss = (
|
||||
teacher_probs * (teacher_log_probs - student_subset)
|
||||
).sum()
|
||||
|
||||
total = total + token_loss
|
||||
count = count + 1.0
|
||||
|
|
@ -323,7 +326,11 @@ def compute_grpo_loss(
|
|||
interpretable_loss = (avg_logp * advantages.squeeze()).mean().item()
|
||||
|
||||
distill_loss_val = 0.0
|
||||
if distillation_enabled and distill_token_ids is not None and distill_logprobs is not None:
|
||||
if (
|
||||
distillation_enabled
|
||||
and distill_token_ids is not None
|
||||
and distill_logprobs is not None
|
||||
):
|
||||
distill_loss = compute_distillation_loss(
|
||||
logits=scaled_logits,
|
||||
mask=mask,
|
||||
|
|
@ -332,7 +339,10 @@ def compute_grpo_loss(
|
|||
temperature=distillation_temperature,
|
||||
loss_type=distillation_loss_type,
|
||||
)
|
||||
total_loss = total_loss + (distillation_coef * distill_loss) / gradient_accumulation_steps
|
||||
total_loss = (
|
||||
total_loss
|
||||
+ (distillation_coef * distill_loss) / gradient_accumulation_steps
|
||||
)
|
||||
distill_loss_val = distill_loss.item()
|
||||
|
||||
metrics = {
|
||||
|
|
@ -437,9 +447,13 @@ def run_training_step(
|
|||
inf_logprobs = inference_logprob_batches[batch_idx]
|
||||
distill_ids = None
|
||||
distill_lps = None
|
||||
if distill_token_id_batches is not None and batch_idx < len(distill_token_id_batches):
|
||||
if distill_token_id_batches is not None and batch_idx < len(
|
||||
distill_token_id_batches
|
||||
):
|
||||
distill_ids = distill_token_id_batches[batch_idx]
|
||||
if distill_logprob_batches is not None and batch_idx < len(distill_logprob_batches):
|
||||
if distill_logprob_batches is not None and batch_idx < len(
|
||||
distill_logprob_batches
|
||||
):
|
||||
distill_lps = distill_logprob_batches[batch_idx]
|
||||
|
||||
loss, metrics = compute_grpo_loss(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue