better debugging

This commit is contained in:
Jai Suphavadeeprasit 2026-01-13 00:36:16 -05:00
parent fad8e77be2
commit 9df62a8f64

View file

@@ -165,17 +165,33 @@ def _create_patched_runner(BaseRunner: type) -> type:
print("[vLLM Patch] Shared memory already set up, skipping")
return
print("[vLLM Patch] Setting up shared memory weight updates...")
print("[vLLM Patch] Setting up shared memory weight updates...", flush=True)
try:
self._setup_shared_memory()
self._spawn_weight_updater()
PatchedGPUModelRunner._shared_memory_setup_done = True
print("[vLLM Patch] ✓ Shared memory updates enabled successfully!")
print("[vLLM Patch] ✓ Shared memory setup complete!", flush=True)
except Exception as e:
print(f"[vLLM Patch] Warning: Failed to set up shared memory: {e}")
print(f"[vLLM Patch] ERROR in _setup_shared_memory: {e}", flush=True)
import traceback
traceback.print_exc()
return
# Spawn weight updater daemon (optional - can be skipped for HTTP-only mode)
skip_daemon = os.environ.get("VLLM_SKIP_WEIGHT_DAEMON", "0") == "1"
if skip_daemon:
print("[vLLM Patch] Skipping weight updater daemon (VLLM_SKIP_WEIGHT_DAEMON=1)", flush=True)
return
try:
print("[vLLM Patch] Spawning weight updater daemon...", flush=True)
self._spawn_weight_updater()
print("[vLLM Patch] ✓ Weight updater daemon spawned!", flush=True)
except Exception as e:
print(f"[vLLM Patch] ERROR spawning weight updater: {e}", flush=True)
import traceback
traceback.print_exc()
print("[vLLM Patch] Continuing without daemon (HTTP-only mode)", flush=True)
def _setup_shared_memory(self) -> None:
"""Move model tensors to shared memory and export param info."""
@@ -252,22 +268,28 @@ def _create_patched_runner(BaseRunner: type) -> type:
def _spawn_weight_updater(self) -> None:
"""Spawn the daemon process for receiving weight updates."""
print("[vLLM Patch] _spawn_weight_updater() called", flush=True)
try:
from vllm.distributed import get_tensor_model_parallel_rank
except ImportError:
# Fallback for older vLLM versions
print("[vLLM Patch] Imported get_tensor_model_parallel_rank", flush=True)
except ImportError as e:
print(f"[vLLM Patch] Could not import get_tensor_model_parallel_rank: {e}", flush=True)
get_tensor_model_parallel_rank = lambda: 0
# Get model configuration
state_dict = self.model.state_dict()
print(f"[vLLM Patch] Got state_dict with {len(state_dict)} params", flush=True)
# Get attention head counts
hf_config = self.model_config.hf_text_config
num_heads = getattr(hf_config, "num_attention_heads", 0)
num_kv_heads = self.model_config.get_total_num_kv_heads()
print(f"[vLLM Patch] num_heads={num_heads}, num_kv_heads={num_kv_heads}", flush=True)
# Get parallel configuration
tp_rank = get_tensor_model_parallel_rank()
print(f"[vLLM Patch] tp_rank={tp_rank}", flush=True)
# Get GPU ID
gpu_id = 0
@@ -280,10 +302,13 @@ def _create_patched_runner(BaseRunner: type) -> type:
except Exception:
gpu_id = tp_rank
print(f"[vLLM Patch] Spawning weight updater: tp_rank={tp_rank}, gpu={gpu_id}")
print(f"[vLLM Patch] Spawning weight updater: tp_rank={tp_rank}, gpu={gpu_id}", flush=True)
# Spawn daemon process
print("[vLLM Patch] Creating spawn context...", flush=True)
ctx = mp.get_context("spawn")
print("[vLLM Patch] Creating Process...", flush=True)
self.weight_updater_process = ctx.Process(
target=weight_updater_process,
args=(
@@ -296,9 +321,11 @@ def _create_patched_runner(BaseRunner: type) -> type:
),
daemon=True,
)
print("[vLLM Patch] Starting daemon process...", flush=True)
self.weight_updater_process.start()
print(f"[vLLM Patch] ✓ Weight updater daemon started (PID: {self.weight_updater_process.pid})")
print(f"[vLLM Patch] ✓ Weight updater daemon started (PID: {self.weight_updater_process.pid})", flush=True)
# Set proper class name
PatchedGPUModelRunner.__name__ = "PatchedGPUModelRunner"