Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)

Commit 9df62a8f64: better debugging
Parent: fad8e77be2
1 changed file with 35 additions and 8 deletions
@@ -165,17 +165,33 @@ def _create_patched_runner(BaseRunner: type) -> type:
                 print("[vLLM Patch] Shared memory already set up, skipping")
                 return
 
-            print("[vLLM Patch] Setting up shared memory weight updates...")
+            print("[vLLM Patch] Setting up shared memory weight updates...", flush=True)
 
             try:
                 self._setup_shared_memory()
-                self._spawn_weight_updater()
                 PatchedGPUModelRunner._shared_memory_setup_done = True
-                print("[vLLM Patch] ✓ Shared memory updates enabled successfully!")
+                print("[vLLM Patch] ✓ Shared memory setup complete!", flush=True)
             except Exception as e:
-                print(f"[vLLM Patch] Warning: Failed to set up shared memory: {e}")
+                print(f"[vLLM Patch] ERROR in _setup_shared_memory: {e}", flush=True)
                 import traceback
                 traceback.print_exc()
+                return
+
+            # Spawn weight updater daemon (optional - can be skipped for HTTP-only mode)
+            skip_daemon = os.environ.get("VLLM_SKIP_WEIGHT_DAEMON", "0") == "1"
+            if skip_daemon:
+                print("[vLLM Patch] Skipping weight updater daemon (VLLM_SKIP_WEIGHT_DAEMON=1)", flush=True)
+                return
+
+            try:
+                print("[vLLM Patch] Spawning weight updater daemon...", flush=True)
+                self._spawn_weight_updater()
+                print("[vLLM Patch] ✓ Weight updater daemon spawned!", flush=True)
+            except Exception as e:
+                print(f"[vLLM Patch] ERROR spawning weight updater: {e}", flush=True)
+                import traceback
+                traceback.print_exc()
+                print("[vLLM Patch] Continuing without daemon (HTTP-only mode)", flush=True)
 
         def _setup_shared_memory(self) -> None:
             """Move model tensors to shared memory and export param info."""
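This hunk splits one try block into two, so a daemon failure can no longer mask a shared-memory failure, and adds an environment-variable kill switch. A minimal standalone sketch of that pattern (the helper names setup_fn and spawn_fn are hypothetical, not from this file):

import os
import traceback

def setup_with_optional_daemon(setup_fn, spawn_fn) -> None:
    # Critical step: if shared-memory setup fails, log and bail out.
    try:
        setup_fn()
    except Exception as e:
        print(f"setup failed: {e}", flush=True)
        traceback.print_exc()
        return

    # Kill switch: VLLM_SKIP_WEIGHT_DAEMON=1 opts out of the daemon entirely.
    if os.environ.get("VLLM_SKIP_WEIGHT_DAEMON", "0") == "1":
        print("skipping weight updater daemon", flush=True)
        return

    # Best-effort step: a spawn failure degrades to HTTP-only mode
    # instead of propagating out of the runner.
    try:
        spawn_fn()
    except Exception as e:
        print(f"daemon spawn failed, continuing without it: {e}", flush=True)
        traceback.print_exc()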
@@ -252,22 +268,28 @@ def _create_patched_runner(BaseRunner: type) -> type:
 
         def _spawn_weight_updater(self) -> None:
             """Spawn the daemon process for receiving weight updates."""
+            print("[vLLM Patch] _spawn_weight_updater() called", flush=True)
+
             try:
                 from vllm.distributed import get_tensor_model_parallel_rank
-            except ImportError:
-                # Fallback for older vLLM versions
+                print("[vLLM Patch] Imported get_tensor_model_parallel_rank", flush=True)
+            except ImportError as e:
+                print(f"[vLLM Patch] Could not import get_tensor_model_parallel_rank: {e}", flush=True)
                 get_tensor_model_parallel_rank = lambda: 0
 
             # Get model configuration
             state_dict = self.model.state_dict()
+            print(f"[vLLM Patch] Got state_dict with {len(state_dict)} params", flush=True)
 
             # Get attention head counts
             hf_config = self.model_config.hf_text_config
             num_heads = getattr(hf_config, "num_attention_heads", 0)
             num_kv_heads = self.model_config.get_total_num_kv_heads()
+            print(f"[vLLM Patch] num_heads={num_heads}, num_kv_heads={num_kv_heads}", flush=True)
 
             # Get parallel configuration
             tp_rank = get_tensor_model_parallel_rank()
+            print(f"[vLLM Patch] tp_rank={tp_rank}", flush=True)
 
             # Get GPU ID
             gpu_id = 0
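The import fallback now logs both outcomes, which makes version mismatches visible in the console. A self-contained sketch of the same pattern; the fallback assumes a single-process setup where the tensor-parallel rank is 0:

try:
    from vllm.distributed import get_tensor_model_parallel_rank
    print("[patch] imported get_tensor_model_parallel_rank", flush=True)
except ImportError as e:
    print(f"[patch] falling back to rank 0: {e}", flush=True)

    def get_tensor_model_parallel_rank() -> int:
        # Without vLLM's distributed state, assume tensor-parallel rank 0.
        return 0

print(f"tp_rank={get_tensor_model_parallel_rank()}", flush=True)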
@@ -280,10 +302,13 @@ def _create_patched_runner(BaseRunner: type) -> type:
             except Exception:
                 gpu_id = tp_rank
 
-            print(f"[vLLM Patch] Spawning weight updater: tp_rank={tp_rank}, gpu={gpu_id}")
+            print(f"[vLLM Patch] Spawning weight updater: tp_rank={tp_rank}, gpu={gpu_id}", flush=True)
 
             # Spawn daemon process
+            print("[vLLM Patch] Creating spawn context...", flush=True)
             ctx = mp.get_context("spawn")
+
+            print("[vLLM Patch] Creating Process...", flush=True)
             self.weight_updater_process = ctx.Process(
                 target=weight_updater_process,
                 args=(
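For context, mp.get_context("spawn") starts the child in a fresh interpreter rather than forking, which matters here because forking a process that already holds CUDA state is unsafe. A runnable sketch of the same spawn-and-daemonize sequence, with a placeholder target function standing in for the real weight_updater_process:

import multiprocessing as mp

def weight_updater_process(tp_rank: int, gpu_id: int) -> None:
    # Placeholder for the real updater loop.
    print(f"updater alive: tp_rank={tp_rank}, gpu={gpu_id}", flush=True)

if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # fresh interpreter, CUDA-safe
    proc = ctx.Process(
        target=weight_updater_process,
        args=(0, 0),
        daemon=True,               # terminated automatically with the parent
    )
    proc.start()
    print(f"daemon started (PID: {proc.pid})", flush=True)
    proc.join(timeout=5)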
@@ -296,9 +321,11 @@ def _create_patched_runner(BaseRunner: type) -> type:
                 ),
                 daemon=True,
             )
 
+            print("[vLLM Patch] Starting daemon process...", flush=True)
             self.weight_updater_process.start()
-            print(f"[vLLM Patch] ✓ Weight updater daemon started (PID: {self.weight_updater_process.pid})")
+
+            print(f"[vLLM Patch] ✓ Weight updater daemon started (PID: {self.weight_updater_process.pid})", flush=True)
 
     # Set proper class name
     PatchedGPUModelRunner.__name__ = "PatchedGPUModelRunner"
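Nearly every print in this commit gains flush=True. When stdout is a pipe (as it is under most launchers and log collectors) it is block-buffered, so unflushed lines vanish if the process crashes or spawns a child before the buffer drains. A minimal illustration:

import sys

print("may be lost on a crash")        # sits in the block buffer
print("survives a crash", flush=True)  # pushed to the pipe immediately
sys.stdout.flush()                     # manual equivalent for pending output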