mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
ccdd5a1ca6
commit
60fb6cae11
11 changed files with 221 additions and 136 deletions
|
|
@ -374,7 +374,9 @@ class BaseEnv(ABC):
|
||||||
[batch][position][top_k].
|
[batch][position][top_k].
|
||||||
Returns ([], []) if teacher_base_url is not configured.
|
Returns ([], []) if teacher_base_url is not configured.
|
||||||
"""
|
"""
|
||||||
logger.info(f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences")
|
logger.info(
|
||||||
|
f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences"
|
||||||
|
)
|
||||||
logger.info(f"[TEACHER] teacher_base_url={self.config.teacher_base_url}")
|
logger.info(f"[TEACHER] teacher_base_url={self.config.teacher_base_url}")
|
||||||
|
|
||||||
if not self.config.teacher_base_url:
|
if not self.config.teacher_base_url:
|
||||||
|
|
@ -400,7 +402,9 @@ class BaseEnv(ABC):
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
for i, tokens in enumerate(token_sequences):
|
for i, tokens in enumerate(token_sequences):
|
||||||
logger.info(f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens")
|
logger.info(
|
||||||
|
f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens"
|
||||||
|
)
|
||||||
# Decode original sequence and optionally prepend teacher steering text.
|
# Decode original sequence and optionally prepend teacher steering text.
|
||||||
base_text = self.tokenizer.decode(tokens, skip_special_tokens=False)
|
base_text = self.tokenizer.decode(tokens, skip_special_tokens=False)
|
||||||
steering_prefix = ""
|
steering_prefix = ""
|
||||||
|
|
@ -413,7 +417,11 @@ class BaseEnv(ABC):
|
||||||
steering_prefix += self.config.teacher_prefix_text
|
steering_prefix += self.config.teacher_prefix_text
|
||||||
full_text = steering_prefix + base_text
|
full_text = steering_prefix + base_text
|
||||||
prefix_token_len = (
|
prefix_token_len = (
|
||||||
len(self.tokenizer.encode(steering_prefix, add_special_tokens=False))
|
len(
|
||||||
|
self.tokenizer.encode(
|
||||||
|
steering_prefix, add_special_tokens=False
|
||||||
|
)
|
||||||
|
)
|
||||||
if steering_prefix
|
if steering_prefix
|
||||||
else 0
|
else 0
|
||||||
)
|
)
|
||||||
|
|
@ -438,15 +446,17 @@ class BaseEnv(ABC):
|
||||||
) as response:
|
) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
seq_token_ids, seq_logprobs = self._parse_completion_logprobs(
|
seq_token_ids, seq_logprobs = (
|
||||||
data, top_k
|
self._parse_completion_logprobs(data, top_k)
|
||||||
)
|
)
|
||||||
if seq_token_ids and seq_logprobs:
|
if seq_token_ids and seq_logprobs:
|
||||||
aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens(
|
aligned_ids, aligned_lps = (
|
||||||
seq_token_ids,
|
self._align_teacher_topk_to_tokens(
|
||||||
seq_logprobs,
|
seq_token_ids,
|
||||||
target_token_len=len(tokens),
|
seq_logprobs,
|
||||||
prefix_token_len=prefix_token_len,
|
target_token_len=len(tokens),
|
||||||
|
prefix_token_len=prefix_token_len,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
token_id_results.append(aligned_ids)
|
token_id_results.append(aligned_ids)
|
||||||
logprob_results.append(aligned_lps)
|
logprob_results.append(aligned_lps)
|
||||||
|
|
@ -501,11 +511,13 @@ class BaseEnv(ABC):
|
||||||
# Chat fallback logprobs are for generated tokens, not prompt tokens.
|
# Chat fallback logprobs are for generated tokens, not prompt tokens.
|
||||||
# To keep alignment correct for distillation, return empty per-position rows.
|
# To keep alignment correct for distillation, return empty per-position rows.
|
||||||
if seq_token_ids and len(seq_token_ids) >= len(tokens):
|
if seq_token_ids and len(seq_token_ids) >= len(tokens):
|
||||||
aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens(
|
aligned_ids, aligned_lps = (
|
||||||
seq_token_ids,
|
self._align_teacher_topk_to_tokens(
|
||||||
seq_logprobs,
|
seq_token_ids,
|
||||||
target_token_len=len(tokens),
|
seq_logprobs,
|
||||||
prefix_token_len=0,
|
target_token_len=len(tokens),
|
||||||
|
prefix_token_len=0,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
aligned_ids = [[] for _ in range(len(tokens))]
|
aligned_ids = [[] for _ in range(len(tokens))]
|
||||||
|
|
@ -513,8 +525,12 @@ class BaseEnv(ABC):
|
||||||
token_id_results.append(aligned_ids)
|
token_id_results.append(aligned_ids)
|
||||||
logprob_results.append(aligned_lps)
|
logprob_results.append(aligned_lps)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Teacher API returned {response.status}")
|
logger.warning(
|
||||||
token_id_results.append([[] for _ in range(len(tokens))])
|
f"Teacher API returned {response.status}"
|
||||||
|
)
|
||||||
|
token_id_results.append(
|
||||||
|
[[] for _ in range(len(tokens))]
|
||||||
|
)
|
||||||
logprob_results.append([[] for _ in range(len(tokens))])
|
logprob_results.append([[] for _ in range(len(tokens))])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Teacher chat request failed: {e}")
|
logger.warning(f"Teacher chat request failed: {e}")
|
||||||
|
|
@ -580,15 +596,15 @@ class BaseEnv(ABC):
|
||||||
elif isinstance(pos_logprobs, dict):
|
elif isinstance(pos_logprobs, dict):
|
||||||
# Format: {token_str: logprob, ...}
|
# Format: {token_str: logprob, ...}
|
||||||
sorted_items = sorted(
|
sorted_items = sorted(
|
||||||
pos_logprobs.items(),
|
pos_logprobs.items(), key=lambda x: x[1], reverse=True
|
||||||
key=lambda x: x[1],
|
|
||||||
reverse=True
|
|
||||||
)[:top_k]
|
)[:top_k]
|
||||||
pos_ids: List[int] = []
|
pos_ids: List[int] = []
|
||||||
pos_lps: List[float] = []
|
pos_lps: List[float] = []
|
||||||
for token_str, logprob in sorted_items:
|
for token_str, logprob in sorted_items:
|
||||||
# Convert token string to ID
|
# Convert token string to ID
|
||||||
token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
|
token_ids = self.tokenizer.encode(
|
||||||
|
token_str, add_special_tokens=False
|
||||||
|
)
|
||||||
if token_ids:
|
if token_ids:
|
||||||
pos_ids.append(int(token_ids[0]))
|
pos_ids.append(int(token_ids[0]))
|
||||||
pos_lps.append(float(logprob))
|
pos_lps.append(float(logprob))
|
||||||
|
|
@ -626,7 +642,9 @@ class BaseEnv(ABC):
|
||||||
token_str = item.get("token", "")
|
token_str = item.get("token", "")
|
||||||
logprob = item.get("logprob", 0.0)
|
logprob = item.get("logprob", 0.0)
|
||||||
# Convert token string to ID
|
# Convert token string to ID
|
||||||
token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
|
token_ids = self.tokenizer.encode(
|
||||||
|
token_str, add_special_tokens=False
|
||||||
|
)
|
||||||
if token_ids:
|
if token_ids:
|
||||||
pos_ids.append(int(token_ids[0]))
|
pos_ids.append(int(token_ids[0]))
|
||||||
pos_lps.append(float(logprob))
|
pos_lps.append(float(logprob))
|
||||||
|
|
@ -1251,7 +1269,9 @@ class BaseEnv(ABC):
|
||||||
if valid_groups and do_send_to_api:
|
if valid_groups and do_send_to_api:
|
||||||
# On-policy distillation: fetch teacher logprobs if enabled
|
# On-policy distillation: fetch teacher logprobs if enabled
|
||||||
if self.config.distillation_enabled and self.config.teacher_base_url:
|
if self.config.distillation_enabled and self.config.teacher_base_url:
|
||||||
logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups")
|
logger.info(
|
||||||
|
f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups"
|
||||||
|
)
|
||||||
for group in valid_groups:
|
for group in valid_groups:
|
||||||
has_new_format = (
|
has_new_format = (
|
||||||
group.get("distill_token_ids") is not None
|
group.get("distill_token_ids") is not None
|
||||||
|
|
@ -1259,9 +1279,11 @@ class BaseEnv(ABC):
|
||||||
)
|
)
|
||||||
if not has_new_format:
|
if not has_new_format:
|
||||||
try:
|
try:
|
||||||
teacher_token_ids, teacher_logprobs = await self.get_teacher_logprobs(
|
teacher_token_ids, teacher_logprobs = (
|
||||||
token_sequences=group["tokens"],
|
await self.get_teacher_logprobs(
|
||||||
messages_list=group.get("messages"),
|
token_sequences=group["tokens"],
|
||||||
|
messages_list=group.get("messages"),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if teacher_token_ids and teacher_logprobs:
|
if teacher_token_ids and teacher_logprobs:
|
||||||
group["distill_token_ids"] = teacher_token_ids
|
group["distill_token_ids"] = teacher_token_ids
|
||||||
|
|
@ -1270,10 +1292,15 @@ class BaseEnv(ABC):
|
||||||
f"[DISTILL] Added teacher distill arrays for {len(teacher_token_ids)} sequences"
|
f"[DISTILL] Added teacher distill arrays for {len(teacher_token_ids)} sequences"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("[DISTILL] get_teacher_logprobs returned empty")
|
logger.warning(
|
||||||
|
"[DISTILL] get_teacher_logprobs returned empty"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[DISTILL] Failed to fetch teacher logprobs: {e}")
|
logger.error(
|
||||||
|
f"[DISTILL] Failed to fetch teacher logprobs: {e}"
|
||||||
|
)
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
@ -1822,13 +1849,17 @@ class BaseEnv(ABC):
|
||||||
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
||||||
yaml_oai_config, dict
|
yaml_oai_config, dict
|
||||||
):
|
):
|
||||||
print(f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}")
|
print(
|
||||||
|
f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}"
|
||||||
|
)
|
||||||
openai_config_dict = merge_dicts(
|
openai_config_dict = merge_dicts(
|
||||||
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
|
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
|
||||||
yaml_oai_config,
|
yaml_oai_config,
|
||||||
oai_cli_passed_args,
|
oai_cli_passed_args,
|
||||||
)
|
)
|
||||||
print(f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}")
|
print(
|
||||||
|
f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
"[CLI DEBUG] Not merging: default_openai_config_ "
|
"[CLI DEBUG] Not merging: default_openai_config_ "
|
||||||
|
|
|
||||||
|
|
@ -165,7 +165,9 @@ def resolve_openai_configs(
|
||||||
"""
|
"""
|
||||||
from atroposlib.envs.server_handling.server_manager import ServerBaseline
|
from atroposlib.envs.server_handling.server_manager import ServerBaseline
|
||||||
|
|
||||||
print(f"[RESOLVE DEBUG] default_server_configs type = {type(default_server_configs)}")
|
print(
|
||||||
|
f"[RESOLVE DEBUG] default_server_configs type = {type(default_server_configs)}"
|
||||||
|
)
|
||||||
print(f"[RESOLVE DEBUG] openai_config_dict = {openai_config_dict}")
|
print(f"[RESOLVE DEBUG] openai_config_dict = {openai_config_dict}")
|
||||||
|
|
||||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||||
|
|
@ -216,7 +218,9 @@ def resolve_openai_configs(
|
||||||
elif isinstance(default_server_configs, APIServerConfig):
|
elif isinstance(default_server_configs, APIServerConfig):
|
||||||
# Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline
|
# Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline
|
||||||
print("[RESOLVE DEBUG] Taking APIServerConfig merged path")
|
print("[RESOLVE DEBUG] Taking APIServerConfig merged path")
|
||||||
logger.info("Using single OpenAI server configuration based on merged settings (default/YAML/CLI).")
|
logger.info(
|
||||||
|
"Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
final_openai_config = APIServerConfig(**openai_config_dict)
|
final_openai_config = APIServerConfig(**openai_config_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -193,7 +193,9 @@ class VLLMServer(APIServer):
|
||||||
debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
|
debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
|
||||||
if debug_requests:
|
if debug_requests:
|
||||||
base = self.config.base_url.replace("/v1", "")
|
base = self.config.base_url.replace("/v1", "")
|
||||||
prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace("\n", "\\n")
|
prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace(
|
||||||
|
"\n", "\\n"
|
||||||
|
)
|
||||||
print(
|
print(
|
||||||
f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate "
|
f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate "
|
||||||
f"prompt_token_len={len(prompt_tokens)}",
|
f"prompt_token_len={len(prompt_tokens)}",
|
||||||
|
|
@ -211,7 +213,7 @@ class VLLMServer(APIServer):
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f"[ATROPOS_REQ_DEBUG] curl_base=curl -s -X POST {base}/generate "
|
f"[ATROPOS_REQ_DEBUG] curl_base=curl -s -X POST {base}/generate "
|
||||||
'-H "Content-Type: application/json" -d \'<JSON_PAYLOAD>\'',
|
"-H \"Content-Type: application/json\" -d '<JSON_PAYLOAD>'",
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -289,7 +289,9 @@ class GSM8kEnv(BaseEnv):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
to_postprocess = await self.score(to_score)
|
to_postprocess = await self.score(to_score)
|
||||||
accepted = 0 if to_postprocess is None else len(to_postprocess.get("tokens", []))
|
accepted = (
|
||||||
|
0 if to_postprocess is None else len(to_postprocess.get("tokens", []))
|
||||||
|
)
|
||||||
print(
|
print(
|
||||||
f"[GSM8K_DEBUG] collect_done accepted={accepted} submitted={len(to_score)}",
|
f"[GSM8K_DEBUG] collect_done accepted={accepted} submitted={len(to_score)}",
|
||||||
flush=True,
|
flush=True,
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
from latex2sympy2_extended import NormalizationConfig
|
from latex2sympy2_extended import NormalizationConfig
|
||||||
from math_verify import LatexExtractionConfig, parse, verify
|
from math_verify import LatexExtractionConfig, parse, verify
|
||||||
from math_verify.errors import TimeoutException
|
from math_verify.errors import TimeoutException
|
||||||
|
|
|
||||||
|
|
@ -136,15 +136,17 @@ class TrainingConfig(BaseModel):
|
||||||
wandb_group: Optional[str] = Field(None, description="Wandb group name")
|
wandb_group: Optional[str] = Field(None, description="Wandb group name")
|
||||||
|
|
||||||
# === Training Mode Configuration ===
|
# === Training Mode Configuration ===
|
||||||
weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = Field(
|
weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = (
|
||||||
"none",
|
Field(
|
||||||
description=(
|
"none",
|
||||||
"How to synchronize weights with inference server. "
|
description=(
|
||||||
"'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
|
"How to synchronize weights with inference server. "
|
||||||
"'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). "
|
"'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
|
||||||
"'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). "
|
"'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). "
|
||||||
"'none': legacy mode, restart vLLM with new checkpoint files."
|
"'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). "
|
||||||
),
|
"'none': legacy mode, restart vLLM with new checkpoint files."
|
||||||
|
),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# === Distributed Training Configuration ===
|
# === Distributed Training Configuration ===
|
||||||
|
|
|
||||||
|
|
@ -258,10 +258,14 @@ def pad_data_to_good_offset(
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
final_distill_token_id_batches = (
|
final_distill_token_id_batches = (
|
||||||
distill_token_id_batches if (has_any_distill and distill_token_id_batches) else None
|
distill_token_id_batches
|
||||||
|
if (has_any_distill and distill_token_id_batches)
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
final_distill_logprob_batches = (
|
final_distill_logprob_batches = (
|
||||||
distill_logprob_batches if (has_any_distill and distill_logprob_batches) else None
|
distill_logprob_batches
|
||||||
|
if (has_any_distill and distill_logprob_batches)
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|
|
||||||
|
|
@ -735,11 +735,15 @@ def train_lora_restart(config: TrainingConfig):
|
||||||
# Periodic adapter save + vLLM restart
|
# Periodic adapter save + vLLM restart
|
||||||
sync_time = 0
|
sync_time = 0
|
||||||
should_sync = (step + 1) % config.vllm_restart_interval == 0
|
should_sync = (step + 1) % config.vllm_restart_interval == 0
|
||||||
if should_sync and (step + 1) < config.training_steps: # Don't restart on last step
|
if (
|
||||||
|
should_sync and (step + 1) < config.training_steps
|
||||||
|
): # Don't restart on last step
|
||||||
sync_start = time.time()
|
sync_start = time.time()
|
||||||
|
|
||||||
# Save new adapter
|
# Save new adapter
|
||||||
current_adapter_path = save_lora_checkpoint(model, config.save_path, step + 1)
|
current_adapter_path = save_lora_checkpoint(
|
||||||
|
model, config.save_path, step + 1
|
||||||
|
)
|
||||||
|
|
||||||
# Restart vLLM with new adapter
|
# Restart vLLM with new adapter
|
||||||
print(" [RESTART] Restarting vLLM with new adapter...")
|
print(" [RESTART] Restarting vLLM with new adapter...")
|
||||||
|
|
@ -798,7 +802,9 @@ def train_lora_restart(config: TrainingConfig):
|
||||||
_vllm_restart_counter = 0
|
_vllm_restart_counter = 0
|
||||||
|
|
||||||
|
|
||||||
def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optional[subprocess.Popen]:
|
def _launch_vllm_with_lora(
|
||||||
|
config: TrainingConfig, adapter_path: str
|
||||||
|
) -> Optional[subprocess.Popen]:
|
||||||
"""
|
"""
|
||||||
Launch vLLM with a LoRA adapter (no --enforce-eager for faster inference).
|
Launch vLLM with a LoRA adapter (no --enforce-eager for faster inference).
|
||||||
|
|
||||||
|
|
@ -826,13 +832,19 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
|
||||||
|
|
||||||
# Build command - NO --enforce-eager for faster inference (~108 TPS vs ~13 TPS)
|
# Build command - NO --enforce-eager for faster inference (~108 TPS vs ~13 TPS)
|
||||||
cmd = [
|
cmd = [
|
||||||
sys.executable, server_script,
|
sys.executable,
|
||||||
"--model", config.model_name,
|
server_script,
|
||||||
"--port", str(config.vllm_port),
|
"--model",
|
||||||
"--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization),
|
config.model_name,
|
||||||
"--max-model-len", str(config.max_model_len),
|
"--port",
|
||||||
|
str(config.vllm_port),
|
||||||
|
"--gpu-memory-utilization",
|
||||||
|
str(config.vllm_gpu_memory_utilization),
|
||||||
|
"--max-model-len",
|
||||||
|
str(config.max_model_len),
|
||||||
"--enable-lora",
|
"--enable-lora",
|
||||||
"--max-lora-rank", str(max(config.lora_r * 2, 32)),
|
"--max-lora-rank",
|
||||||
|
str(max(config.lora_r * 2, 32)),
|
||||||
# Note: NOT adding --enforce-eager - this gives us ~8x faster inference!
|
# Note: NOT adding --enforce-eager - this gives us ~8x faster inference!
|
||||||
# Without --enforce-eager, vLLM can use more optimizations.
|
# Without --enforce-eager, vLLM can use more optimizations.
|
||||||
]
|
]
|
||||||
|
|
@ -849,7 +861,9 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
|
||||||
print(f" Adapter: {adapter_path}")
|
print(f" Adapter: {adapter_path}")
|
||||||
|
|
||||||
# Log vLLM output to file for debugging (unique file per restart)
|
# Log vLLM output to file for debugging (unique file per restart)
|
||||||
vllm_log_path = os.path.join(config.save_path, f"vllm_restart_{_vllm_restart_counter}.log")
|
vllm_log_path = os.path.join(
|
||||||
|
config.save_path, f"vllm_restart_{_vllm_restart_counter}.log"
|
||||||
|
)
|
||||||
_vllm_restart_counter += 1
|
_vllm_restart_counter += 1
|
||||||
print(f" vLLM log: {vllm_log_path}")
|
print(f" vLLM log: {vllm_log_path}")
|
||||||
|
|
||||||
|
|
@ -857,11 +871,16 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
|
||||||
vllm_log_file = open(vllm_log_path, "w")
|
vllm_log_file = open(vllm_log_path, "w")
|
||||||
# Start in new session so we can kill entire process group later
|
# Start in new session so we can kill entire process group later
|
||||||
proc = subprocess.Popen(
|
proc = subprocess.Popen(
|
||||||
cmd, env=env, stdout=vllm_log_file, stderr=subprocess.STDOUT,
|
cmd,
|
||||||
start_new_session=True # Creates new process group for easy cleanup
|
env=env,
|
||||||
|
stdout=vllm_log_file,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
start_new_session=True, # Creates new process group for easy cleanup
|
||||||
)
|
)
|
||||||
print(f" vLLM PID: {proc.pid} (process group: {os.getpgid(proc.pid)})")
|
print(f" vLLM PID: {proc.pid} (process group: {os.getpgid(proc.pid)})")
|
||||||
print(" NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)...")
|
print(
|
||||||
|
" NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)..."
|
||||||
|
)
|
||||||
|
|
||||||
# Wait for server to be ready (longer timeout for CUDA graph compilation)
|
# Wait for server to be ready (longer timeout for CUDA graph compilation)
|
||||||
if not wait_for_vllm_ready(config.vllm_port, timeout=300):
|
if not wait_for_vllm_ready(config.vllm_port, timeout=300):
|
||||||
|
|
@ -869,7 +888,7 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
|
||||||
print(f" Check log: {vllm_log_path}")
|
print(f" Check log: {vllm_log_path}")
|
||||||
# Print last 30 lines of the log
|
# Print last 30 lines of the log
|
||||||
try:
|
try:
|
||||||
with open(vllm_log_path, 'r') as f:
|
with open(vllm_log_path, "r") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
print(" Last 30 lines of vLLM log:")
|
print(" Last 30 lines of vLLM log:")
|
||||||
for line in lines[-30:]:
|
for line in lines[-30:]:
|
||||||
|
|
@ -890,7 +909,9 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
print(" ✓ Adapter loaded successfully")
|
print(" ✓ Adapter loaded successfully")
|
||||||
else:
|
else:
|
||||||
print(f" WARNING: Adapter load returned {resp.status_code}: {resp.text}")
|
print(
|
||||||
|
f" WARNING: Adapter load returned {resp.status_code}: {resp.text}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" WARNING: Could not load adapter: {e}")
|
print(f" WARNING: Could not load adapter: {e}")
|
||||||
# Continue anyway - base model inference still works
|
# Continue anyway - base model inference still works
|
||||||
|
|
@ -930,6 +951,7 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
|
||||||
|
|
||||||
# Phase 2: Kill by port (catches anything still running)
|
# Phase 2: Kill by port (catches anything still running)
|
||||||
from .vllm_manager import kill_process_on_port
|
from .vllm_manager import kill_process_on_port
|
||||||
|
|
||||||
kill_process_on_port(port)
|
kill_process_on_port(port)
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
|
|
||||||
|
|
@ -954,13 +976,16 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
|
||||||
try:
|
try:
|
||||||
result = sp.run(
|
result = sp.run(
|
||||||
f"nvidia-smi --query-compute-apps=pid,used_memory --format=csv,noheader,nounits -i {gpu_id}",
|
f"nvidia-smi --query-compute-apps=pid,used_memory --format=csv,noheader,nounits -i {gpu_id}",
|
||||||
shell=True, capture_output=True, text=True, timeout=10
|
shell=True,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=10,
|
||||||
)
|
)
|
||||||
if result.stdout.strip():
|
if result.stdout.strip():
|
||||||
print(f" Found GPU processes:\n{result.stdout}")
|
print(f" Found GPU processes:\n{result.stdout}")
|
||||||
for line in result.stdout.strip().split('\n'):
|
for line in result.stdout.strip().split("\n"):
|
||||||
if line.strip():
|
if line.strip():
|
||||||
parts = line.split(',')
|
parts = line.split(",")
|
||||||
if len(parts) >= 1:
|
if len(parts) >= 1:
|
||||||
pid = parts[0].strip()
|
pid = parts[0].strip()
|
||||||
# Don't kill the current Python process (trainer)
|
# Don't kill the current Python process (trainer)
|
||||||
|
|
@ -982,7 +1007,9 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
free_mem = torch.cuda.mem_get_info()[0] / 1e9
|
free_mem = torch.cuda.mem_get_info()[0] / 1e9
|
||||||
total_mem = torch.cuda.mem_get_info()[1] / 1e9
|
total_mem = torch.cuda.mem_get_info()[1] / 1e9
|
||||||
print(f" [{(i+1)*5}s] GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)")
|
print(
|
||||||
|
f" [{(i+1)*5}s] GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)"
|
||||||
|
)
|
||||||
# If we have enough memory (>50% free), break early
|
# If we have enough memory (>50% free), break early
|
||||||
if free_mem > total_mem * 0.5:
|
if free_mem > total_mem * 0.5:
|
||||||
print(f" ✓ Sufficient memory available ({free_mem:.1f} GB)")
|
print(f" ✓ Sufficient memory available ({free_mem:.1f} GB)")
|
||||||
|
|
@ -994,12 +1021,12 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
free_mem = torch.cuda.mem_get_info()[0] / 1e9
|
free_mem = torch.cuda.mem_get_info()[0] / 1e9
|
||||||
total_mem = torch.cuda.mem_get_info()[1] / 1e9
|
total_mem = torch.cuda.mem_get_info()[1] / 1e9
|
||||||
print(f" Final GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)")
|
print(
|
||||||
|
f" Final GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
if free_mem < total_mem * 0.3:
|
if free_mem < total_mem * 0.3:
|
||||||
print(" WARNING: Low GPU memory! May fail to restart vLLM.")
|
print(" WARNING: Low GPU memory! May fail to restart vLLM.")
|
||||||
print(" Consider reducing --vllm-gpu-memory-utilization")
|
print(" Consider reducing --vllm-gpu-memory-utilization")
|
||||||
|
|
||||||
print(" vLLM terminated")
|
print(" vLLM terminated")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -63,9 +63,10 @@ def compute_distillation_loss(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ids_tensor = torch.tensor(pos_ids, device=logits.device, dtype=torch.long)
|
ids_tensor = torch.tensor(pos_ids, device=logits.device, dtype=torch.long)
|
||||||
teacher_lps = torch.tensor(
|
teacher_lps = (
|
||||||
pos_lps, device=logits.device, dtype=logits.dtype
|
torch.tensor(pos_lps, device=logits.device, dtype=logits.dtype)
|
||||||
) / temperature
|
/ temperature
|
||||||
|
)
|
||||||
|
|
||||||
student_log_probs = F.log_softmax(logits[b, t] / temperature, dim=-1)
|
student_log_probs = F.log_softmax(logits[b, t] / temperature, dim=-1)
|
||||||
student_subset = student_log_probs[ids_tensor]
|
student_subset = student_log_probs[ids_tensor]
|
||||||
|
|
@ -75,7 +76,9 @@ def compute_distillation_loss(
|
||||||
token_loss = -(teacher_probs * student_subset).sum()
|
token_loss = -(teacher_probs * student_subset).sum()
|
||||||
else:
|
else:
|
||||||
teacher_log_probs = F.log_softmax(teacher_lps, dim=-1)
|
teacher_log_probs = F.log_softmax(teacher_lps, dim=-1)
|
||||||
token_loss = (teacher_probs * (teacher_log_probs - student_subset)).sum()
|
token_loss = (
|
||||||
|
teacher_probs * (teacher_log_probs - student_subset)
|
||||||
|
).sum()
|
||||||
|
|
||||||
total = total + token_loss
|
total = total + token_loss
|
||||||
count = count + 1.0
|
count = count + 1.0
|
||||||
|
|
@ -323,7 +326,11 @@ def compute_grpo_loss(
|
||||||
interpretable_loss = (avg_logp * advantages.squeeze()).mean().item()
|
interpretable_loss = (avg_logp * advantages.squeeze()).mean().item()
|
||||||
|
|
||||||
distill_loss_val = 0.0
|
distill_loss_val = 0.0
|
||||||
if distillation_enabled and distill_token_ids is not None and distill_logprobs is not None:
|
if (
|
||||||
|
distillation_enabled
|
||||||
|
and distill_token_ids is not None
|
||||||
|
and distill_logprobs is not None
|
||||||
|
):
|
||||||
distill_loss = compute_distillation_loss(
|
distill_loss = compute_distillation_loss(
|
||||||
logits=scaled_logits,
|
logits=scaled_logits,
|
||||||
mask=mask,
|
mask=mask,
|
||||||
|
|
@ -332,7 +339,10 @@ def compute_grpo_loss(
|
||||||
temperature=distillation_temperature,
|
temperature=distillation_temperature,
|
||||||
loss_type=distillation_loss_type,
|
loss_type=distillation_loss_type,
|
||||||
)
|
)
|
||||||
total_loss = total_loss + (distillation_coef * distill_loss) / gradient_accumulation_steps
|
total_loss = (
|
||||||
|
total_loss
|
||||||
|
+ (distillation_coef * distill_loss) / gradient_accumulation_steps
|
||||||
|
)
|
||||||
distill_loss_val = distill_loss.item()
|
distill_loss_val = distill_loss.item()
|
||||||
|
|
||||||
metrics = {
|
metrics = {
|
||||||
|
|
@ -437,9 +447,13 @@ def run_training_step(
|
||||||
inf_logprobs = inference_logprob_batches[batch_idx]
|
inf_logprobs = inference_logprob_batches[batch_idx]
|
||||||
distill_ids = None
|
distill_ids = None
|
||||||
distill_lps = None
|
distill_lps = None
|
||||||
if distill_token_id_batches is not None and batch_idx < len(distill_token_id_batches):
|
if distill_token_id_batches is not None and batch_idx < len(
|
||||||
|
distill_token_id_batches
|
||||||
|
):
|
||||||
distill_ids = distill_token_id_batches[batch_idx]
|
distill_ids = distill_token_id_batches[batch_idx]
|
||||||
if distill_logprob_batches is not None and batch_idx < len(distill_logprob_batches):
|
if distill_logprob_batches is not None and batch_idx < len(
|
||||||
|
distill_logprob_batches
|
||||||
|
):
|
||||||
distill_lps = distill_logprob_batches[batch_idx]
|
distill_lps = distill_logprob_batches[batch_idx]
|
||||||
|
|
||||||
loss, metrics = compute_grpo_loss(
|
loss, metrics = compute_grpo_loss(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue