diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py
index c47385a8..2d5a8688 100644
--- a/example_trainer/vllm_api_server.py
+++ b/example_trainer/vllm_api_server.py
@@ -36,6 +36,7 @@ CRITICAL: Patches must be applied BEFORE importing vLLM!
 # =============================================================================
 import asyncio
 import json
+import multiprocessing
 import os
 import ssl
 import sys
@@ -46,6 +47,21 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, List, Optional
 
+# =============================================================================
+# CRITICAL: Set up multiprocessing and vLLM engine BEFORE any CUDA imports
+# =============================================================================
+
+# Default to v0 engine to avoid CUDA fork issues with v1 engine
+# Users can override with VLLM_USE_V1=1 if needed
+os.environ.setdefault("VLLM_USE_V1", "0")
+
+# Set spawn method for multiprocessing (required for CUDA)
+os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+try:
+    multiprocessing.set_start_method('spawn', force=True)
+except RuntimeError:
+    pass  # Already set
+
 # =============================================================================
 # STEP 1: Apply patches BEFORE any vLLM imports!
 # =============================================================================
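
For context (not part of the patch), the sketch below restates the ordering constraint the diff enforces: the environment defaults and the "spawn" start method must be in place before any module that initializes CUDA is imported. The sanity checks and the commented vLLM import are illustrative assumptions, not code from the repository.

    import multiprocessing
    import os

    # Mirror the defaults the patch applies in vllm_api_server.py.
    os.environ.setdefault("VLLM_USE_V1", "0")  # v0 engine by default; override by exporting VLLM_USE_V1=1
    os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

    try:
        multiprocessing.set_start_method("spawn", force=True)
    except RuntimeError:
        pass  # start method was already fixed by an earlier import

    # Sanity checks before anything touches CUDA (illustrative, not in the patch).
    assert multiprocessing.get_start_method() == "spawn"
    assert os.environ["VLLM_USE_V1"] in ("0", "1")

    # Only after this point import CUDA-backed modules, e.g. (shown for illustration only):
    # from vllm.engine.async_llm_engine import AsyncLLMEngine

Because setdefault leaves an existing value untouched, running with the v1 engine is just a matter of exporting VLLM_USE_V1=1 before launching the server.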