[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2026-02-20 04:58:43 +00:00
parent ccdd5a1ca6
commit 60fb6cae11
11 changed files with 221 additions and 136 deletions


@ -194,7 +194,7 @@ python -m example_trainer.grpo --weight-bridge-mode lora_only ...
---
## Shared vLLM Mode
Single-copy mode shares GPU memory between vLLM and the trainer - zero model duplication!
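
That sentence is the whole point of the mode: per the `weight_bridge_mode` description further down in this commit, the trainer attaches to vLLM's shared-memory tensors and updates them in place instead of shipping checkpoints. A minimal sketch of that idea (the `shared_params` handle and the function name are hypothetical illustrations, not this repo's API):

```python
import torch

@torch.no_grad()
def push_weights_inplace(trainer_model: torch.nn.Module, shared_params: dict[str, torch.Tensor]) -> None:
    """Overwrite the tensors vLLM serves from with the trainer's current weights."""
    for name, param in trainer_model.named_parameters():
        dst = shared_params.get(name)  # hypothetically the same CUDA storage vLLM reads from
        if dst is not None:
            dst.copy_(param.detach().to(dtype=dst.dtype))  # in-place copy keeps the shared storage valid
```

Because only the storage contents change, no restart or adapter reload is needed, which is why the mode is described as zero-duplication.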


@ -136,15 +136,17 @@ class TrainingConfig(BaseModel):
wandb_group: Optional[str] = Field(None, description="Wandb group name")
# === Training Mode Configuration ===
weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = Field(
"none",
description=(
"How to synchronize weights with inference server. "
"'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
"'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). "
"'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). "
"'none': legacy mode, restart vLLM with new checkpoint files."
),
weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = (
Field(
"none",
description=(
"How to synchronize weights with inference server. "
"'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
"'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). "
"'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). "
"'none': legacy mode, restart vLLM with new checkpoint files."
),
)
)
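For orientation, a hedged usage sketch of this field (the import path and the non-`weight_bridge_mode` values are illustrative, and other required fields of `TrainingConfig` may also need to be supplied):

```python
from example_trainer.config import TrainingConfig  # module path assumed, not confirmed by this diff

config = TrainingConfig(
    model_name="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model id
    weight_bridge_mode="lora_restart",        # LoRA training, vLLM restarted with CUDA graphs enabled
    vllm_restart_interval=10,                 # how often (in steps) to swap the new adapter into vLLM
)
```

The same choice is exposed on the command line as `--weight-bridge-mode` (see the `python -m example_trainer.grpo --weight-bridge-mode lora_only ...` context in the first hunk above).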
# === Distributed Training Configuration ===


@ -258,10 +258,14 @@ def pad_data_to_good_offset(
else None
)
final_distill_token_id_batches = (
distill_token_id_batches if (has_any_distill and distill_token_id_batches) else None
distill_token_id_batches
if (has_any_distill and distill_token_id_batches)
else None
)
final_distill_logprob_batches = (
distill_logprob_batches if (has_any_distill and distill_logprob_batches) else None
distill_logprob_batches
if (has_any_distill and distill_logprob_batches)
else None
)
return (


@ -735,19 +735,23 @@ def train_lora_restart(config: TrainingConfig):
# Periodic adapter save + vLLM restart
sync_time = 0
should_sync = (step + 1) % config.vllm_restart_interval == 0
if should_sync and (step + 1) < config.training_steps: # Don't restart on last step
if (
should_sync and (step + 1) < config.training_steps
): # Don't restart on last step
sync_start = time.time()
# Save new adapter
current_adapter_path = save_lora_checkpoint(model, config.save_path, step + 1)
current_adapter_path = save_lora_checkpoint(
model, config.save_path, step + 1
)
# Restart vLLM with new adapter
print(" [RESTART] Restarting vLLM with new adapter...")
_terminate_vllm(vllm_proc, config.vllm_port)
vllm_proc = _launch_vllm_with_lora(config, current_adapter_path)
if vllm_proc is None:
raise RuntimeError("Failed to restart vLLM")
sync_time = time.time() - sync_start
benchmark_stats["sync_times"].append(sync_time)
benchmark_stats["restart_times"].append(sync_time)
@ -798,45 +802,53 @@ def train_lora_restart(config: TrainingConfig):
_vllm_restart_counter = 0
def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optional[subprocess.Popen]:
def _launch_vllm_with_lora(
config: TrainingConfig, adapter_path: str
) -> Optional[subprocess.Popen]:
"""
Launch vLLM with a LoRA adapter (no --enforce-eager for faster inference).
Unlike lora_only mode, this does NOT use --enforce-eager, so we get
~108 TPS instead of ~13 TPS (8x faster).
"""
global _vllm_restart_counter
from .vllm_manager import kill_process_on_port, wait_for_vllm_ready
# Kill any existing process on the port
print(f" Cleaning up port {config.vllm_port}...")
kill_process_on_port(config.vllm_port)
# Clear CUDA cache before starting new vLLM
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Wait for port and GPU memory to be fully released
time.sleep(5)
# Find the vllm_api_server.py script
script_dir = os.path.dirname(os.path.abspath(__file__))
server_script = os.path.join(script_dir, "vllm_api_server.py")
# Build command - NO --enforce-eager for faster inference (~108 TPS vs ~13 TPS)
cmd = [
sys.executable, server_script,
"--model", config.model_name,
"--port", str(config.vllm_port),
"--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization),
"--max-model-len", str(config.max_model_len),
sys.executable,
server_script,
"--model",
config.model_name,
"--port",
str(config.vllm_port),
"--gpu-memory-utilization",
str(config.vllm_gpu_memory_utilization),
"--max-model-len",
str(config.max_model_len),
"--enable-lora",
"--max-lora-rank", str(max(config.lora_r * 2, 32)),
"--max-lora-rank",
str(max(config.lora_r * 2, 32)),
# Note: NOT adding --enforce-eager - this gives us ~8x faster inference!
# Without --enforce-eager, vLLM can use more optimizations.
]
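One detail in the flag list above: `--max-lora-rank` is derived from the training rank as `max(config.lora_r * 2, 32)`, so the server accepts adapters with at least twice the trained rank, with a floor of 32. A quick check with illustrative ranks:

```python
def served_max_lora_rank(lora_r: int) -> int:
    # mirrors the expression used when building the vLLM command above
    return max(lora_r * 2, 32)

assert served_max_lora_rank(8) == 32    # floor kicks in for small ranks
assert served_max_lora_rank(16) == 32
assert served_max_lora_rank(64) == 128  # doubling dominates for larger ranks
```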
# Set environment for GPU selection
env = os.environ.copy()
if config.vllm_gpu is not None:
@ -844,32 +856,39 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
print(f" GPU: {config.vllm_gpu} (via CUDA_VISIBLE_DEVICES)")
else:
print(" GPU: Same as trainer (inherited CUDA_VISIBLE_DEVICES)")
print(f" Launching: {' '.join(cmd)}")
print(f" Adapter: {adapter_path}")
# Log vLLM output to file for debugging (unique file per restart)
vllm_log_path = os.path.join(config.save_path, f"vllm_restart_{_vllm_restart_counter}.log")
vllm_log_path = os.path.join(
config.save_path, f"vllm_restart_{_vllm_restart_counter}.log"
)
_vllm_restart_counter += 1
print(f" vLLM log: {vllm_log_path}")
try:
vllm_log_file = open(vllm_log_path, "w")
# Start in new session so we can kill entire process group later
proc = subprocess.Popen(
cmd, env=env, stdout=vllm_log_file, stderr=subprocess.STDOUT,
start_new_session=True # Creates new process group for easy cleanup
cmd,
env=env,
stdout=vllm_log_file,
stderr=subprocess.STDOUT,
start_new_session=True, # Creates new process group for easy cleanup
)
print(f" vLLM PID: {proc.pid} (process group: {os.getpgid(proc.pid)})")
print(" NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)...")
print(
" NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)..."
)
# Wait for server to be ready (longer timeout for CUDA graph compilation)
if not wait_for_vllm_ready(config.vllm_port, timeout=300):
print(" ERROR: vLLM failed to start after 300s")
print(f" Check log: {vllm_log_path}")
# Print last 30 lines of the log
try:
with open(vllm_log_path, 'r') as f:
with open(vllm_log_path, "r") as f:
lines = f.readlines()
print(" Last 30 lines of vLLM log:")
for line in lines[-30:]:
@ -878,7 +897,7 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
print(f" Could not read log: {e}")
proc.terminate()
return None
# Load the LoRA adapter
print(" Loading LoRA adapter...")
try:
@ -890,13 +909,15 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
if resp.status_code == 200:
print(" ✓ Adapter loaded successfully")
else:
print(f" WARNING: Adapter load returned {resp.status_code}: {resp.text}")
print(
f" WARNING: Adapter load returned {resp.status_code}: {resp.text}"
)
except Exception as e:
print(f" WARNING: Could not load adapter: {e}")
# Continue anyway - base model inference still works
return proc
except Exception as e:
print(f" ERROR: {e}")
return None
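Because the server is launched with `start_new_session=True`, it owns its own process group, which is what the Phase 1 cleanup in `_terminate_vllm` below relies on. A standalone sketch of that pattern (not the project's exact code; the `sleep` command stands in for the vLLM server):

```python
import os
import signal
import subprocess

proc = subprocess.Popen(["sleep", "60"], start_new_session=True)  # stand-in for the vLLM server
try:
    # Signal the whole group so worker children die along with the parent.
    os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
    proc.wait(timeout=5)
except (ProcessLookupError, subprocess.TimeoutExpired):
    pass  # group already gone, or it ignored SIGTERM and needs a harder kill
```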
@ -906,12 +927,12 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
"""Terminate a vLLM process and release GPU resources."""
import signal
import subprocess as sp
print(f" Terminating vLLM on port {port}...")
# Get current GPU device
gpu_id = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
# Phase 1: Kill the process group if we have a handle (kills all children too)
main_pid = None
if proc is not None:
@ -927,12 +948,13 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
proc.wait(timeout=5)
except Exception as e:
print(f" Warning: {e}")
# Phase 2: Kill by port (catches anything still running)
from .vllm_manager import kill_process_on_port
kill_process_on_port(port)
time.sleep(2)
# Phase 3: Aggressively kill ALL vLLM-related processes
print(" Killing all vLLM-related processes...")
kill_commands = [
@ -948,19 +970,22 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
sp.run(cmd, shell=True, capture_output=True, timeout=5)
except Exception:
pass
# Phase 4: Use nvidia-smi to find and kill GPU processes (nuclear option)
print(f" Checking for zombie GPU processes on GPU {gpu_id}...")
try:
result = sp.run(
f"nvidia-smi --query-compute-apps=pid,used_memory --format=csv,noheader,nounits -i {gpu_id}",
shell=True, capture_output=True, text=True, timeout=10
shell=True,
capture_output=True,
text=True,
timeout=10,
)
if result.stdout.strip():
print(f" Found GPU processes:\n{result.stdout}")
for line in result.stdout.strip().split('\n'):
for line in result.stdout.strip().split("\n"):
if line.strip():
parts = line.split(',')
parts = line.split(",")
if len(parts) >= 1:
pid = parts[0].strip()
# Don't kill the current Python process (trainer)
@ -972,7 +997,7 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
pass
except Exception as e:
print(f" Warning: nvidia-smi check failed: {e}")
# Phase 5: Wait for GPU memory release - CRITICAL
# The CUDA driver needs time to actually free memory after process death
print(" Waiting for GPU memory release...")
@ -982,24 +1007,26 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
torch.cuda.empty_cache()
free_mem = torch.cuda.mem_get_info()[0] / 1e9
total_mem = torch.cuda.mem_get_info()[1] / 1e9
print(f" [{(i+1)*5}s] GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)")
print(
f" [{(i+1)*5}s] GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)"
)
# If we have enough memory (>50% free), break early
if free_mem > total_mem * 0.5:
print(f" ✓ Sufficient memory available ({free_mem:.1f} GB)")
break
# Final cleanup
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
free_mem = torch.cuda.mem_get_info()[0] / 1e9
total_mem = torch.cuda.mem_get_info()[1] / 1e9
print(f" Final GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)")
print(
f" Final GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)"
)
if free_mem < total_mem * 0.3:
print(" WARNING: Low GPU memory! May fail to restart vLLM.")
print(" Consider reducing --vllm-gpu-memory-utilization")
print(" vLLM terminated")


@ -63,9 +63,10 @@ def compute_distillation_loss(
continue
ids_tensor = torch.tensor(pos_ids, device=logits.device, dtype=torch.long)
teacher_lps = torch.tensor(
pos_lps, device=logits.device, dtype=logits.dtype
) / temperature
teacher_lps = (
torch.tensor(pos_lps, device=logits.device, dtype=logits.dtype)
/ temperature
)
student_log_probs = F.log_softmax(logits[b, t] / temperature, dim=-1)
student_subset = student_log_probs[ids_tensor]
@ -75,7 +76,9 @@ def compute_distillation_loss(
token_loss = -(teacher_probs * student_subset).sum()
else:
teacher_log_probs = F.log_softmax(teacher_lps, dim=-1)
token_loss = (teacher_probs * (teacher_log_probs - student_subset)).sum()
token_loss = (
teacher_probs * (teacher_log_probs - student_subset)
).sum()
total = total + token_loss
count = count + 1.0
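Reading the two branches together (inferred from the code above, not stated elsewhere in the diff): with $S$ the stored teacher token ids at a position, $p^{(T)} = \mathrm{softmax}(\ell_{\mathrm{teacher}}/T)$ the teacher distribution renormalized over that subset, and $q^{(T)}$ the student's temperature-scaled distribution over the full vocabulary (so $\log q^{(T)}_i$ is the `student_subset` value at id $i$), the per-token loss is either a cross-entropy or a forward-KL term:

$$
\mathcal{L}_{\text{ce}} = -\sum_{i \in S} p^{(T)}_i \,\log q^{(T)}_i ,
\qquad
\mathcal{L}_{\text{kl}} = \sum_{i \in S} p^{(T)}_i \left( \log p^{(T)}_i - \log q^{(T)}_i \right),
$$

with each token's loss accumulated into `total` and `count` tracking the number of contributing tokens.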
@ -323,7 +326,11 @@ def compute_grpo_loss(
interpretable_loss = (avg_logp * advantages.squeeze()).mean().item()
distill_loss_val = 0.0
if distillation_enabled and distill_token_ids is not None and distill_logprobs is not None:
if (
distillation_enabled
and distill_token_ids is not None
and distill_logprobs is not None
):
distill_loss = compute_distillation_loss(
logits=scaled_logits,
mask=mask,
@ -332,7 +339,10 @@ def compute_grpo_loss(
temperature=distillation_temperature,
loss_type=distillation_loss_type,
)
total_loss = total_loss + (distillation_coef * distill_loss) / gradient_accumulation_steps
total_loss = (
total_loss
+ (distillation_coef * distill_loss) / gradient_accumulation_steps
)
distill_loss_val = distill_loss.item()
metrics = {
@ -437,9 +447,13 @@ def run_training_step(
inf_logprobs = inference_logprob_batches[batch_idx]
distill_ids = None
distill_lps = None
if distill_token_id_batches is not None and batch_idx < len(distill_token_id_batches):
if distill_token_id_batches is not None and batch_idx < len(
distill_token_id_batches
):
distill_ids = distill_token_id_batches[batch_idx]
if distill_logprob_batches is not None and batch_idx < len(distill_logprob_batches):
if distill_logprob_batches is not None and batch_idx < len(
distill_logprob_batches
):
distill_lps = distill_logprob_batches[batch_idx]
loss, metrics = compute_grpo_loss(