diff --git a/example_trainer/vllm_patching/patched_gpu_runner.py b/example_trainer/vllm_patching/patched_gpu_runner.py index 3c124598..7631683d 100644 --- a/example_trainer/vllm_patching/patched_gpu_runner.py +++ b/example_trainer/vllm_patching/patched_gpu_runner.py @@ -267,7 +267,13 @@ def _create_patched_runner(BaseRunner: type) -> type: traceback.print_exc() def _spawn_weight_updater(self) -> None: - """Spawn the daemon process for receiving weight updates.""" + """Start the weight updater as a background thread. + + Note: We use threading instead of multiprocessing because vLLM's + worker processes are daemons, and daemons cannot spawn child processes. + """ + import threading + print("[vLLM Patch] _spawn_weight_updater() called", flush=True) try: @@ -302,14 +308,10 @@ def _create_patched_runner(BaseRunner: type) -> type: except Exception: gpu_id = tp_rank - print(f"[vLLM Patch] Spawning weight updater: tp_rank={tp_rank}, gpu={gpu_id}", flush=True) + print(f"[vLLM Patch] Starting weight updater thread: tp_rank={tp_rank}, gpu={gpu_id}", flush=True) - # Spawn daemon process - print("[vLLM Patch] Creating spawn context...", flush=True) - ctx = mp.get_context("spawn") - - print("[vLLM Patch] Creating Process...", flush=True) - self.weight_updater_process = ctx.Process( + # Start as a daemon thread (threads CAN be started from daemon processes) + self.weight_updater_thread = threading.Thread( target=weight_updater_process, args=( state_dict, @@ -320,12 +322,13 @@ def _create_patched_runner(BaseRunner: type) -> type: gpu_id, ), daemon=True, + name=f"WeightUpdater_TP{tp_rank}", ) - print("[vLLM Patch] Starting daemon process...", flush=True) - self.weight_updater_process.start() + print("[vLLM Patch] Starting thread...", flush=True) + self.weight_updater_thread.start() - print(f"[vLLM Patch] ✓ Weight updater daemon started (PID: {self.weight_updater_process.pid})", flush=True) + print(f"[vLLM Patch] ✓ Weight updater thread started (name: {self.weight_updater_thread.name})", flush=True) # Set proper class name PatchedGPUModelRunner.__name__ = "PatchedGPUModelRunner"