diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py index b436342d..d7f6e2e6 100644 --- a/example_trainer/vllm_api_server.py +++ b/example_trainer/vllm_api_server.py @@ -80,25 +80,32 @@ def _apply_patches_early() -> bool: try: # Try relative import first (when run as module) - from .vllm_patching import apply_patches + except ImportError: + # Fall back to absolute import (when run as script) + try: + import sys + from pathlib import Path + # Add parent directory to path so we can import vllm_patching + script_dir = Path(__file__).parent + if str(script_dir) not in sys.path: + sys.path.insert(0, str(script_dir)) + from vllm_patching import apply_patches + except ImportError as e: + print(f"[vLLM Server] Could not import vllm_patching: {e}") + print("[vLLM Server] Shared memory weight updates will not be available") + return False - + try: success = apply_patches() if success: print("[vLLM Server] ✓ vLLM patches applied successfully!") else: print("[vLLM Server] ✗ Failed to apply patches") return success - - except ImportError as e: - print(f"[vLLM Server] Could not import vllm_patching: {e}") - print("[vLLM Server] Shared memory weight updates will not be available") - return False except Exception as e: print(f"[vLLM Server] Error applying patches: {e}") import traceback - traceback.print_exc() return False