diff --git a/example_trainer/scripts/compare_all_modes_math_zero.sh b/example_trainer/scripts/compare_all_modes_math_zero.sh index 2c1f511e..55065d54 100755 --- a/example_trainer/scripts/compare_all_modes_math_zero.sh +++ b/example_trainer/scripts/compare_all_modes_math_zero.sh @@ -244,7 +244,7 @@ python -u environments/math_server_zero.py serve \ SHARED_ENV_PID=$! echo "[SHARED_VLLM] Starting trainer..." -CUDA_VISIBLE_DEVICES=$SHARED_GPU python -m example_trainer.grpo \ +CUDA_VISIBLE_DEVICES=$SHARED_GPU PYTHONUNBUFFERED=1 stdbuf -oL -eL python -u -m example_trainer.grpo \ --model-name "$MODEL" \ --weight-bridge-mode shared_vllm \ --vllm-port $SHARED_VLLM_PORT \ @@ -277,7 +277,7 @@ python -u environments/math_server_zero.py serve \ LORA_ONLY_ENV_PID=$! echo "[LORA_ONLY] Starting trainer..." -CUDA_VISIBLE_DEVICES=$LORA_ONLY_GPU python -m example_trainer.grpo \ +CUDA_VISIBLE_DEVICES=$LORA_ONLY_GPU PYTHONUNBUFFERED=1 stdbuf -oL -eL python -u -m example_trainer.grpo \ --model-name "$MODEL" \ --weight-bridge-mode lora_only \ --vllm-port $LORA_ONLY_VLLM_PORT \ @@ -301,7 +301,8 @@ LORA_ONLY_TRAINER_PID=$! echo "" echo "[LORA_RESTART] Starting trainer (manages vLLM internally)..." # NOTE: lora_restart shares GPU with trainer's model (~8GB), so use lower vLLM memory -CUDA_VISIBLE_DEVICES=$LORA_RESTART_GPU python -m example_trainer.grpo \ +# Use unbuffered output (-u) and stdbuf to capture crashes +CUDA_VISIBLE_DEVICES=$LORA_RESTART_GPU PYTHONUNBUFFERED=1 stdbuf -oL -eL python -u -m example_trainer.grpo \ --model-name "$MODEL" \ --weight-bridge-mode lora_restart \ --vllm-port $LORA_RESTART_VLLM_PORT \ diff --git a/example_trainer/vllm_manager.py b/example_trainer/vllm_manager.py index 4088994c..7f9f20fe 100644 --- a/example_trainer/vllm_manager.py +++ b/example_trainer/vllm_manager.py @@ -45,10 +45,10 @@ def kill_process_on_port(port: int, timeout: float = 5.0) -> bool: ) if result.stdout.strip(): pids = result.stdout.strip().split("\n") + print(f" Killing {len(pids)} processes on port {port}...") for pid in pids: try: os.kill(int(pid), signal.SIGTERM) - print(f" Sent SIGTERM to PID {pid}") except (ProcessLookupError, ValueError): pass @@ -61,12 +61,15 @@ def kill_process_on_port(port: int, timeout: float = 5.0) -> bool: time.sleep(0.5) # Force kill if still running + killed_count = 0 for pid in pids: try: os.kill(int(pid), signal.SIGKILL) - print(f" Sent SIGKILL to PID {pid}") + killed_count += 1 except (ProcessLookupError, ValueError): pass + if killed_count > 0: + print(f" Force killed {killed_count} stubborn processes") time.sleep(1) return not is_port_in_use(port)