[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2026-02-20 04:58:43 +00:00
parent ccdd5a1ca6
commit 60fb6cae11
11 changed files with 221 additions and 136 deletions

View file

@ -374,7 +374,9 @@ class BaseEnv(ABC):
[batch][position][top_k]. [batch][position][top_k].
Returns ([], []) if teacher_base_url is not configured. Returns ([], []) if teacher_base_url is not configured.
""" """
logger.info(f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences") logger.info(
f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences"
)
logger.info(f"[TEACHER] teacher_base_url={self.config.teacher_base_url}") logger.info(f"[TEACHER] teacher_base_url={self.config.teacher_base_url}")
if not self.config.teacher_base_url: if not self.config.teacher_base_url:
@ -400,7 +402,9 @@ class BaseEnv(ABC):
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
for i, tokens in enumerate(token_sequences): for i, tokens in enumerate(token_sequences):
logger.info(f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens") logger.info(
f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens"
)
# Decode original sequence and optionally prepend teacher steering text. # Decode original sequence and optionally prepend teacher steering text.
base_text = self.tokenizer.decode(tokens, skip_special_tokens=False) base_text = self.tokenizer.decode(tokens, skip_special_tokens=False)
steering_prefix = "" steering_prefix = ""
@ -413,7 +417,11 @@ class BaseEnv(ABC):
steering_prefix += self.config.teacher_prefix_text steering_prefix += self.config.teacher_prefix_text
full_text = steering_prefix + base_text full_text = steering_prefix + base_text
prefix_token_len = ( prefix_token_len = (
len(self.tokenizer.encode(steering_prefix, add_special_tokens=False)) len(
self.tokenizer.encode(
steering_prefix, add_special_tokens=False
)
)
if steering_prefix if steering_prefix
else 0 else 0
) )
@ -438,15 +446,17 @@ class BaseEnv(ABC):
) as response: ) as response:
if response.status == 200: if response.status == 200:
data = await response.json() data = await response.json()
seq_token_ids, seq_logprobs = self._parse_completion_logprobs( seq_token_ids, seq_logprobs = (
data, top_k self._parse_completion_logprobs(data, top_k)
) )
if seq_token_ids and seq_logprobs: if seq_token_ids and seq_logprobs:
aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens( aligned_ids, aligned_lps = (
seq_token_ids, self._align_teacher_topk_to_tokens(
seq_logprobs, seq_token_ids,
target_token_len=len(tokens), seq_logprobs,
prefix_token_len=prefix_token_len, target_token_len=len(tokens),
prefix_token_len=prefix_token_len,
)
) )
token_id_results.append(aligned_ids) token_id_results.append(aligned_ids)
logprob_results.append(aligned_lps) logprob_results.append(aligned_lps)
@ -501,11 +511,13 @@ class BaseEnv(ABC):
# Chat fallback logprobs are for generated tokens, not prompt tokens. # Chat fallback logprobs are for generated tokens, not prompt tokens.
# To keep alignment correct for distillation, return empty per-position rows. # To keep alignment correct for distillation, return empty per-position rows.
if seq_token_ids and len(seq_token_ids) >= len(tokens): if seq_token_ids and len(seq_token_ids) >= len(tokens):
aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens( aligned_ids, aligned_lps = (
seq_token_ids, self._align_teacher_topk_to_tokens(
seq_logprobs, seq_token_ids,
target_token_len=len(tokens), seq_logprobs,
prefix_token_len=0, target_token_len=len(tokens),
prefix_token_len=0,
)
) )
else: else:
aligned_ids = [[] for _ in range(len(tokens))] aligned_ids = [[] for _ in range(len(tokens))]
@ -513,8 +525,12 @@ class BaseEnv(ABC):
token_id_results.append(aligned_ids) token_id_results.append(aligned_ids)
logprob_results.append(aligned_lps) logprob_results.append(aligned_lps)
else: else:
logger.warning(f"Teacher API returned {response.status}") logger.warning(
token_id_results.append([[] for _ in range(len(tokens))]) f"Teacher API returned {response.status}"
)
token_id_results.append(
[[] for _ in range(len(tokens))]
)
logprob_results.append([[] for _ in range(len(tokens))]) logprob_results.append([[] for _ in range(len(tokens))])
except Exception as e: except Exception as e:
logger.warning(f"Teacher chat request failed: {e}") logger.warning(f"Teacher chat request failed: {e}")
@ -580,15 +596,15 @@ class BaseEnv(ABC):
elif isinstance(pos_logprobs, dict): elif isinstance(pos_logprobs, dict):
# Format: {token_str: logprob, ...} # Format: {token_str: logprob, ...}
sorted_items = sorted( sorted_items = sorted(
pos_logprobs.items(), pos_logprobs.items(), key=lambda x: x[1], reverse=True
key=lambda x: x[1],
reverse=True
)[:top_k] )[:top_k]
pos_ids: List[int] = [] pos_ids: List[int] = []
pos_lps: List[float] = [] pos_lps: List[float] = []
for token_str, logprob in sorted_items: for token_str, logprob in sorted_items:
# Convert token string to ID # Convert token string to ID
token_ids = self.tokenizer.encode(token_str, add_special_tokens=False) token_ids = self.tokenizer.encode(
token_str, add_special_tokens=False
)
if token_ids: if token_ids:
pos_ids.append(int(token_ids[0])) pos_ids.append(int(token_ids[0]))
pos_lps.append(float(logprob)) pos_lps.append(float(logprob))
@ -626,7 +642,9 @@ class BaseEnv(ABC):
token_str = item.get("token", "") token_str = item.get("token", "")
logprob = item.get("logprob", 0.0) logprob = item.get("logprob", 0.0)
# Convert token string to ID # Convert token string to ID
token_ids = self.tokenizer.encode(token_str, add_special_tokens=False) token_ids = self.tokenizer.encode(
token_str, add_special_tokens=False
)
if token_ids: if token_ids:
pos_ids.append(int(token_ids[0])) pos_ids.append(int(token_ids[0]))
pos_lps.append(float(logprob)) pos_lps.append(float(logprob))
@ -1251,7 +1269,9 @@ class BaseEnv(ABC):
if valid_groups and do_send_to_api: if valid_groups and do_send_to_api:
# On-policy distillation: fetch teacher logprobs if enabled # On-policy distillation: fetch teacher logprobs if enabled
if self.config.distillation_enabled and self.config.teacher_base_url: if self.config.distillation_enabled and self.config.teacher_base_url:
logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups") logger.info(
f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups"
)
for group in valid_groups: for group in valid_groups:
has_new_format = ( has_new_format = (
group.get("distill_token_ids") is not None group.get("distill_token_ids") is not None
@ -1259,9 +1279,11 @@ class BaseEnv(ABC):
) )
if not has_new_format: if not has_new_format:
try: try:
teacher_token_ids, teacher_logprobs = await self.get_teacher_logprobs( teacher_token_ids, teacher_logprobs = (
token_sequences=group["tokens"], await self.get_teacher_logprobs(
messages_list=group.get("messages"), token_sequences=group["tokens"],
messages_list=group.get("messages"),
)
) )
if teacher_token_ids and teacher_logprobs: if teacher_token_ids and teacher_logprobs:
group["distill_token_ids"] = teacher_token_ids group["distill_token_ids"] = teacher_token_ids
@ -1270,10 +1292,15 @@ class BaseEnv(ABC):
f"[DISTILL] Added teacher distill arrays for {len(teacher_token_ids)} sequences" f"[DISTILL] Added teacher distill arrays for {len(teacher_token_ids)} sequences"
) )
else: else:
logger.warning("[DISTILL] get_teacher_logprobs returned empty") logger.warning(
"[DISTILL] get_teacher_logprobs returned empty"
)
except Exception as e: except Exception as e:
logger.error(f"[DISTILL] Failed to fetch teacher logprobs: {e}") logger.error(
f"[DISTILL] Failed to fetch teacher logprobs: {e}"
)
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
else: else:
logger.debug( logger.debug(
@ -1822,13 +1849,17 @@ class BaseEnv(ABC):
if isinstance(default_openai_config_, APIServerConfig) and isinstance( if isinstance(default_openai_config_, APIServerConfig) and isinstance(
yaml_oai_config, dict yaml_oai_config, dict
): ):
print(f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}") print(
f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}"
)
openai_config_dict = merge_dicts( openai_config_dict = merge_dicts(
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init) default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
yaml_oai_config, yaml_oai_config,
oai_cli_passed_args, oai_cli_passed_args,
) )
print(f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}") print(
f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}"
)
else: else:
print( print(
"[CLI DEBUG] Not merging: default_openai_config_ " "[CLI DEBUG] Not merging: default_openai_config_ "

View file

@ -165,7 +165,9 @@ def resolve_openai_configs(
""" """
from atroposlib.envs.server_handling.server_manager import ServerBaseline from atroposlib.envs.server_handling.server_manager import ServerBaseline
print(f"[RESOLVE DEBUG] default_server_configs type = {type(default_server_configs)}") print(
f"[RESOLVE DEBUG] default_server_configs type = {type(default_server_configs)}"
)
print(f"[RESOLVE DEBUG] openai_config_dict = {openai_config_dict}") print(f"[RESOLVE DEBUG] openai_config_dict = {openai_config_dict}")
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}" openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
@ -216,7 +218,9 @@ def resolve_openai_configs(
elif isinstance(default_server_configs, APIServerConfig): elif isinstance(default_server_configs, APIServerConfig):
# Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline # Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline
print("[RESOLVE DEBUG] Taking APIServerConfig merged path") print("[RESOLVE DEBUG] Taking APIServerConfig merged path")
logger.info("Using single OpenAI server configuration based on merged settings (default/YAML/CLI).") logger.info(
"Using single OpenAI server configuration based on merged settings (default/YAML/CLI)."
)
try: try:
final_openai_config = APIServerConfig(**openai_config_dict) final_openai_config = APIServerConfig(**openai_config_dict)
except Exception as e: except Exception as e:

View file

@ -193,7 +193,9 @@ class VLLMServer(APIServer):
debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1" debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1"
if debug_requests: if debug_requests:
base = self.config.base_url.replace("/v1", "") base = self.config.base_url.replace("/v1", "")
prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace("\n", "\\n") prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace(
"\n", "\\n"
)
print( print(
f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate " f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate "
f"prompt_token_len={len(prompt_tokens)}", f"prompt_token_len={len(prompt_tokens)}",
@ -211,7 +213,7 @@ class VLLMServer(APIServer):
) )
print( print(
f"[ATROPOS_REQ_DEBUG] curl_base=curl -s -X POST {base}/generate " f"[ATROPOS_REQ_DEBUG] curl_base=curl -s -X POST {base}/generate "
'-H "Content-Type: application/json" -d \'<JSON_PAYLOAD>\'', "-H \"Content-Type: application/json\" -d '<JSON_PAYLOAD>'",
flush=True, flush=True,
) )

View file

@ -289,7 +289,9 @@ class GSM8kEnv(BaseEnv):
} }
) )
to_postprocess = await self.score(to_score) to_postprocess = await self.score(to_score)
accepted = 0 if to_postprocess is None else len(to_postprocess.get("tokens", [])) accepted = (
0 if to_postprocess is None else len(to_postprocess.get("tokens", []))
)
print( print(
f"[GSM8K_DEBUG] collect_done accepted={accepted} submitted={len(to_score)}", f"[GSM8K_DEBUG] collect_done accepted={accepted} submitted={len(to_score)}",
flush=True, flush=True,

View file

@ -12,7 +12,6 @@ from typing import Dict, List, Optional, Tuple, Union
import wandb import wandb
from datasets import load_dataset from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify from math_verify import LatexExtractionConfig, parse, verify
from math_verify.errors import TimeoutException from math_verify.errors import TimeoutException

View file

@ -136,15 +136,17 @@ class TrainingConfig(BaseModel):
wandb_group: Optional[str] = Field(None, description="Wandb group name") wandb_group: Optional[str] = Field(None, description="Wandb group name")
# === Training Mode Configuration === # === Training Mode Configuration ===
weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = Field( weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = (
"none", Field(
description=( "none",
"How to synchronize weights with inference server. " description=(
"'shared_vllm': attach to vLLM's shared memory tensors and update in-place. " "How to synchronize weights with inference server. "
"'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). " "'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
"'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). " "'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). "
"'none': legacy mode, restart vLLM with new checkpoint files." "'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). "
), "'none': legacy mode, restart vLLM with new checkpoint files."
),
)
) )
# === Distributed Training Configuration === # === Distributed Training Configuration ===

View file

@ -258,10 +258,14 @@ def pad_data_to_good_offset(
else None else None
) )
final_distill_token_id_batches = ( final_distill_token_id_batches = (
distill_token_id_batches if (has_any_distill and distill_token_id_batches) else None distill_token_id_batches
if (has_any_distill and distill_token_id_batches)
else None
) )
final_distill_logprob_batches = ( final_distill_logprob_batches = (
distill_logprob_batches if (has_any_distill and distill_logprob_batches) else None distill_logprob_batches
if (has_any_distill and distill_logprob_batches)
else None
) )
return ( return (

View file

@ -735,11 +735,15 @@ def train_lora_restart(config: TrainingConfig):
# Periodic adapter save + vLLM restart # Periodic adapter save + vLLM restart
sync_time = 0 sync_time = 0
should_sync = (step + 1) % config.vllm_restart_interval == 0 should_sync = (step + 1) % config.vllm_restart_interval == 0
if should_sync and (step + 1) < config.training_steps: # Don't restart on last step if (
should_sync and (step + 1) < config.training_steps
): # Don't restart on last step
sync_start = time.time() sync_start = time.time()
# Save new adapter # Save new adapter
current_adapter_path = save_lora_checkpoint(model, config.save_path, step + 1) current_adapter_path = save_lora_checkpoint(
model, config.save_path, step + 1
)
# Restart vLLM with new adapter # Restart vLLM with new adapter
print(" [RESTART] Restarting vLLM with new adapter...") print(" [RESTART] Restarting vLLM with new adapter...")
@ -798,7 +802,9 @@ def train_lora_restart(config: TrainingConfig):
_vllm_restart_counter = 0 _vllm_restart_counter = 0
def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optional[subprocess.Popen]: def _launch_vllm_with_lora(
config: TrainingConfig, adapter_path: str
) -> Optional[subprocess.Popen]:
""" """
Launch vLLM with a LoRA adapter (no --enforce-eager for faster inference). Launch vLLM with a LoRA adapter (no --enforce-eager for faster inference).
@ -826,13 +832,19 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
# Build command - NO --enforce-eager for faster inference (~108 TPS vs ~13 TPS) # Build command - NO --enforce-eager for faster inference (~108 TPS vs ~13 TPS)
cmd = [ cmd = [
sys.executable, server_script, sys.executable,
"--model", config.model_name, server_script,
"--port", str(config.vllm_port), "--model",
"--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization), config.model_name,
"--max-model-len", str(config.max_model_len), "--port",
str(config.vllm_port),
"--gpu-memory-utilization",
str(config.vllm_gpu_memory_utilization),
"--max-model-len",
str(config.max_model_len),
"--enable-lora", "--enable-lora",
"--max-lora-rank", str(max(config.lora_r * 2, 32)), "--max-lora-rank",
str(max(config.lora_r * 2, 32)),
# Note: NOT adding --enforce-eager - this gives us ~8x faster inference! # Note: NOT adding --enforce-eager - this gives us ~8x faster inference!
# Without --enforce-eager, vLLM can use more optimizations. # Without --enforce-eager, vLLM can use more optimizations.
] ]
@ -849,7 +861,9 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
print(f" Adapter: {adapter_path}") print(f" Adapter: {adapter_path}")
# Log vLLM output to file for debugging (unique file per restart) # Log vLLM output to file for debugging (unique file per restart)
vllm_log_path = os.path.join(config.save_path, f"vllm_restart_{_vllm_restart_counter}.log") vllm_log_path = os.path.join(
config.save_path, f"vllm_restart_{_vllm_restart_counter}.log"
)
_vllm_restart_counter += 1 _vllm_restart_counter += 1
print(f" vLLM log: {vllm_log_path}") print(f" vLLM log: {vllm_log_path}")
@ -857,11 +871,16 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
vllm_log_file = open(vllm_log_path, "w") vllm_log_file = open(vllm_log_path, "w")
# Start in new session so we can kill entire process group later # Start in new session so we can kill entire process group later
proc = subprocess.Popen( proc = subprocess.Popen(
cmd, env=env, stdout=vllm_log_file, stderr=subprocess.STDOUT, cmd,
start_new_session=True # Creates new process group for easy cleanup env=env,
stdout=vllm_log_file,
stderr=subprocess.STDOUT,
start_new_session=True, # Creates new process group for easy cleanup
) )
print(f" vLLM PID: {proc.pid} (process group: {os.getpgid(proc.pid)})") print(f" vLLM PID: {proc.pid} (process group: {os.getpgid(proc.pid)})")
print(" NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)...") print(
" NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)..."
)
# Wait for server to be ready (longer timeout for CUDA graph compilation) # Wait for server to be ready (longer timeout for CUDA graph compilation)
if not wait_for_vllm_ready(config.vllm_port, timeout=300): if not wait_for_vllm_ready(config.vllm_port, timeout=300):
@ -869,7 +888,7 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
print(f" Check log: {vllm_log_path}") print(f" Check log: {vllm_log_path}")
# Print last 30 lines of the log # Print last 30 lines of the log
try: try:
with open(vllm_log_path, 'r') as f: with open(vllm_log_path, "r") as f:
lines = f.readlines() lines = f.readlines()
print(" Last 30 lines of vLLM log:") print(" Last 30 lines of vLLM log:")
for line in lines[-30:]: for line in lines[-30:]:
@ -890,7 +909,9 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
if resp.status_code == 200: if resp.status_code == 200:
print(" ✓ Adapter loaded successfully") print(" ✓ Adapter loaded successfully")
else: else:
print(f" WARNING: Adapter load returned {resp.status_code}: {resp.text}") print(
f" WARNING: Adapter load returned {resp.status_code}: {resp.text}"
)
except Exception as e: except Exception as e:
print(f" WARNING: Could not load adapter: {e}") print(f" WARNING: Could not load adapter: {e}")
# Continue anyway - base model inference still works # Continue anyway - base model inference still works
@ -930,6 +951,7 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
# Phase 2: Kill by port (catches anything still running) # Phase 2: Kill by port (catches anything still running)
from .vllm_manager import kill_process_on_port from .vllm_manager import kill_process_on_port
kill_process_on_port(port) kill_process_on_port(port)
time.sleep(2) time.sleep(2)
@ -954,13 +976,16 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
try: try:
result = sp.run( result = sp.run(
f"nvidia-smi --query-compute-apps=pid,used_memory --format=csv,noheader,nounits -i {gpu_id}", f"nvidia-smi --query-compute-apps=pid,used_memory --format=csv,noheader,nounits -i {gpu_id}",
shell=True, capture_output=True, text=True, timeout=10 shell=True,
capture_output=True,
text=True,
timeout=10,
) )
if result.stdout.strip(): if result.stdout.strip():
print(f" Found GPU processes:\n{result.stdout}") print(f" Found GPU processes:\n{result.stdout}")
for line in result.stdout.strip().split('\n'): for line in result.stdout.strip().split("\n"):
if line.strip(): if line.strip():
parts = line.split(',') parts = line.split(",")
if len(parts) >= 1: if len(parts) >= 1:
pid = parts[0].strip() pid = parts[0].strip()
# Don't kill the current Python process (trainer) # Don't kill the current Python process (trainer)
@ -982,7 +1007,9 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
torch.cuda.empty_cache() torch.cuda.empty_cache()
free_mem = torch.cuda.mem_get_info()[0] / 1e9 free_mem = torch.cuda.mem_get_info()[0] / 1e9
total_mem = torch.cuda.mem_get_info()[1] / 1e9 total_mem = torch.cuda.mem_get_info()[1] / 1e9
print(f" [{(i+1)*5}s] GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)") print(
f" [{(i+1)*5}s] GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)"
)
# If we have enough memory (>50% free), break early # If we have enough memory (>50% free), break early
if free_mem > total_mem * 0.5: if free_mem > total_mem * 0.5:
print(f" ✓ Sufficient memory available ({free_mem:.1f} GB)") print(f" ✓ Sufficient memory available ({free_mem:.1f} GB)")
@ -994,12 +1021,12 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
torch.cuda.synchronize() torch.cuda.synchronize()
free_mem = torch.cuda.mem_get_info()[0] / 1e9 free_mem = torch.cuda.mem_get_info()[0] / 1e9
total_mem = torch.cuda.mem_get_info()[1] / 1e9 total_mem = torch.cuda.mem_get_info()[1] / 1e9
print(f" Final GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)") print(
f" Final GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)"
)
if free_mem < total_mem * 0.3: if free_mem < total_mem * 0.3:
print(" WARNING: Low GPU memory! May fail to restart vLLM.") print(" WARNING: Low GPU memory! May fail to restart vLLM.")
print(" Consider reducing --vllm-gpu-memory-utilization") print(" Consider reducing --vllm-gpu-memory-utilization")
print(" vLLM terminated") print(" vLLM terminated")

View file

@ -63,9 +63,10 @@ def compute_distillation_loss(
continue continue
ids_tensor = torch.tensor(pos_ids, device=logits.device, dtype=torch.long) ids_tensor = torch.tensor(pos_ids, device=logits.device, dtype=torch.long)
teacher_lps = torch.tensor( teacher_lps = (
pos_lps, device=logits.device, dtype=logits.dtype torch.tensor(pos_lps, device=logits.device, dtype=logits.dtype)
) / temperature / temperature
)
student_log_probs = F.log_softmax(logits[b, t] / temperature, dim=-1) student_log_probs = F.log_softmax(logits[b, t] / temperature, dim=-1)
student_subset = student_log_probs[ids_tensor] student_subset = student_log_probs[ids_tensor]
@ -75,7 +76,9 @@ def compute_distillation_loss(
token_loss = -(teacher_probs * student_subset).sum() token_loss = -(teacher_probs * student_subset).sum()
else: else:
teacher_log_probs = F.log_softmax(teacher_lps, dim=-1) teacher_log_probs = F.log_softmax(teacher_lps, dim=-1)
token_loss = (teacher_probs * (teacher_log_probs - student_subset)).sum() token_loss = (
teacher_probs * (teacher_log_probs - student_subset)
).sum()
total = total + token_loss total = total + token_loss
count = count + 1.0 count = count + 1.0
@ -323,7 +326,11 @@ def compute_grpo_loss(
interpretable_loss = (avg_logp * advantages.squeeze()).mean().item() interpretable_loss = (avg_logp * advantages.squeeze()).mean().item()
distill_loss_val = 0.0 distill_loss_val = 0.0
if distillation_enabled and distill_token_ids is not None and distill_logprobs is not None: if (
distillation_enabled
and distill_token_ids is not None
and distill_logprobs is not None
):
distill_loss = compute_distillation_loss( distill_loss = compute_distillation_loss(
logits=scaled_logits, logits=scaled_logits,
mask=mask, mask=mask,
@ -332,7 +339,10 @@ def compute_grpo_loss(
temperature=distillation_temperature, temperature=distillation_temperature,
loss_type=distillation_loss_type, loss_type=distillation_loss_type,
) )
total_loss = total_loss + (distillation_coef * distill_loss) / gradient_accumulation_steps total_loss = (
total_loss
+ (distillation_coef * distill_loss) / gradient_accumulation_steps
)
distill_loss_val = distill_loss.item() distill_loss_val = distill_loss.item()
metrics = { metrics = {
@ -437,9 +447,13 @@ def run_training_step(
inf_logprobs = inference_logprob_batches[batch_idx] inf_logprobs = inference_logprob_batches[batch_idx]
distill_ids = None distill_ids = None
distill_lps = None distill_lps = None
if distill_token_id_batches is not None and batch_idx < len(distill_token_id_batches): if distill_token_id_batches is not None and batch_idx < len(
distill_token_id_batches
):
distill_ids = distill_token_id_batches[batch_idx] distill_ids = distill_token_id_batches[batch_idx]
if distill_logprob_batches is not None and batch_idx < len(distill_logprob_batches): if distill_logprob_batches is not None and batch_idx < len(
distill_logprob_batches
):
distill_lps = distill_logprob_batches[batch_idx] distill_lps = distill_logprob_batches[batch_idx]
loss, metrics = compute_grpo_loss( loss, metrics = compute_grpo_loss(