patching problem

Jai Suphavadeeprasit 2026-01-13 00:09:40 -05:00
parent 5af1a4a974
commit 3ac4a64f6f
4 changed files with 234 additions and 76 deletions

@@ -1,3 +1,4 @@
#!/usr/bin/env python3
"""
Custom vLLM API server with shared memory weight updates.
@@ -30,12 +31,18 @@ ARCHITECTURE:
Trainer Process
CRITICAL: Patches must be applied BEFORE importing vLLM!
"""
# =============================================================================
# STEP 0: Standard library imports ONLY (no vLLM yet!)
# =============================================================================
import asyncio
import json
import os
import ssl
import sys
import threading
import time
from argparse import Namespace
@@ -44,78 +51,88 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
import vllm.envs as envs
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from pydantic import BaseModel
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid
from vllm.v1.engine.async_llm import AsyncLLM
try:
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.system_utils import set_ulimit
except ImportError:
from vllm.utils import FlexibleArgumentParser, set_ulimit
from vllm.outputs import RequestOutput # noqa: F401
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger("vllm.entrypoints.api_server")
# =============================================================================
# Apply vLLM Patches for Shared Memory
# STEP 1: Apply patches BEFORE any vLLM imports!
# =============================================================================
def _maybe_apply_patches() -> bool:
def _apply_patches_early() -> bool:
"""
Apply vLLM patches if shared weights are enabled.
This MUST be called before any vLLM imports!
Returns True if patches were applied.
"""
enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1"
num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", -1))
num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", "-1"))
if not enable_shared and num_inference_nodes < 0:
print("[vLLM Server] Shared weights not enabled, skipping patches")
return False
print("[vLLM Server] VLLM_ENABLE_SHARED_WEIGHTS=1, applying patches...")
try:
# Try relative import first (when run as module)
try:
from .vllm_patching import apply_patches
except ImportError:
# Fall back to absolute import (when run as script)
import sys
from pathlib import Path
# Add the example_trainer directory to path
script_dir = Path(__file__).parent
if str(script_dir) not in sys.path:
sys.path.insert(0, str(script_dir))
from vllm_patching import apply_patches
apply_patches()
logger.info("✓ vLLM patches applied for shared memory weights")
return True
success = apply_patches()
if success:
print("[vLLM Server] ✓ vLLM patches applied successfully!")
else:
print("[vLLM Server] ✗ Failed to apply patches")
return success
except ImportError as e:
logger.warning(f"Could not import vllm_patching: {e}")
logger.warning("Shared memory weight updates will not be available")
print(f"[vLLM Server] Could not import vllm_patching: {e}")
print("[vLLM Server] Shared memory weight updates will not be available")
return False
except Exception as e:
logger.warning(f"Failed to apply patches: {e}")
print(f"[vLLM Server] Error applying patches: {e}")
import traceback
traceback.print_exc()
return False
# Apply patches before any other vLLM imports
PATCHES_APPLIED = _maybe_apply_patches()
# Apply patches NOW, before any vLLM imports below!
PATCHES_APPLIED = _apply_patches_early()
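For context, the toggles read by _apply_patches_early() must be set in the environment before this module is imported. A launch sketch, assuming the server is started as a module; the module path example_trainer.vllm_api_server is a guess based on the fallback import logic above:

import os
import subprocess

# Enable the shared-weight patches before the server process imports vLLM.
env = dict(
    os.environ,
    VLLM_ENABLE_SHARED_WEIGHTS="1",  # checked by _apply_patches_early()
    NUM_INFERENCE_NODES="1",         # a value >= 0 also enables patching
    LOGDIR="/tmp/atropos_bridge",    # where vllm_bridge_config.json lands
)
subprocess.run(
    ["python", "-m", "example_trainer.vllm_api_server"],  # hypothetical module path
    env=env,
    check=True,
)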
# =============================================================================
# STEP 2: Now safe to import vLLM (patches are already in place)
# =============================================================================
import torch # noqa: E402
import vllm.envs as envs # noqa: E402
from fastapi import FastAPI, HTTPException, Request # noqa: E402
from fastapi.responses import JSONResponse, Response, StreamingResponse # noqa: E402
from pydantic import BaseModel # noqa: E402
from vllm.engine.arg_utils import AsyncEngineArgs # noqa: E402
from vllm.entrypoints.launcher import serve_http # noqa: E402
from vllm.entrypoints.utils import with_cancellation # noqa: E402
from vllm.logger import init_logger # noqa: E402
from vllm.sampling_params import RequestOutputKind, SamplingParams # noqa: E402
from vllm.usage.usage_lib import UsageContext # noqa: E402
from vllm.utils import random_uuid # noqa: E402
from vllm.v1.engine.async_llm import AsyncLLM # noqa: E402
try:
from vllm.utils.argparse_utils import FlexibleArgumentParser # noqa: E402
from vllm.utils.system_utils import set_ulimit # noqa: E402
except ImportError:
from vllm.utils import FlexibleArgumentParser, set_ulimit # noqa: E402
from vllm.outputs import RequestOutput # noqa: F401, E402
from vllm.version import __version__ as VLLM_VERSION # noqa: E402
logger = init_logger("vllm.entrypoints.api_server")
# =============================================================================
@@ -493,6 +510,7 @@ async def bridge_debug() -> JSONResponse:
"patches_applied": PATCHES_APPLIED,
"shared_weights_env": os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0"),
"num_inference_nodes": os.environ.get("NUM_INFERENCE_NODES", "unset"),
"logdir": os.environ.get("LOGDIR", "unset"),
}
if engine is not None:
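Once the server is up, the payload above makes a quick sanity check for whether the patches actually took effect. A fetch sketch; the /bridge_debug route path and port 8000 are assumptions, not confirmed by this diff:

import json
import urllib.request

# Confirm the patches were applied in the running server.
with urllib.request.urlopen("http://localhost:8000/bridge_debug") as resp:
    info = json.load(resp)

print("patches_applied:", info["patches_applied"])
print("shared_weights_env:", info["shared_weights_env"])
print("logdir:", info["logdir"])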
@@ -600,18 +618,28 @@ async def init_app(args: Namespace, llm_engine: AsyncLLM | None = None) -> FastAPI:
)
app.state.engine_client = engine
# Export state dict info for trainers
# Export basic state dict info for trainers (the patched runner exports detailed info)
_export_state_dict_info(args)
return app
def _export_state_dict_info(args: Namespace) -> None:
"""Export model parameter mapping to JSON for trainer."""
log_dir = os.environ.get("LOGDIR", ".")
"""Export basic model info to JSON for trainer (backup if patches don't run)."""
log_dir = os.environ.get("LOGDIR", "/tmp/atropos_bridge")
Path(log_dir).mkdir(parents=True, exist_ok=True)
json_path = Path(log_dir) / "vllm_bridge_config.json"
# Only write basic info if the file doesn't exist or is empty
# The patched runner will write complete info with param_mappings
try:
        if json_path.exists():
            try:
                with open(json_path, "r") as f:
                    existing = json.load(f)
            except (OSError, json.JSONDecodeError):
                # Treat an unreadable or corrupt file as empty so the
                # basic info below still gets written.
                existing = {}
            if existing.get("param_mappings"):
                logger.info("Config already has param_mappings, not overwriting")
                return
info = {
"model": getattr(args, "model", "unknown"),
"dtype": getattr(args, "dtype", "auto"),
@@ -624,7 +652,7 @@ def _export_state_dict_info(args: Namespace) -> None:
with open(json_path, "w") as f:
json.dump(info, f, indent=2)
logger.info(f"Exported state dict info to {json_path}")
logger.info(f"Exported basic state dict info to {json_path}")
except Exception as e:
logger.warning(f"Failed to export state dict info: {e}")
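For orientation, vllm_bridge_config.json ends up with roughly the following shape once the patched runner has filled in param_mappings. All values below are illustrative guesses; only model and dtype come from the fallback written here, and the entry format under param_mappings is not specified by this diff:

# Illustrative contents of vllm_bridge_config.json -- values are made up.
example_config = {
    "model": "meta-llama/Llama-3.1-8B",  # from args.model
    "dtype": "bfloat16",                 # from args.dtype
    "param_mappings": {                  # written by the patched runner, not this fallback
        "model.embed_tokens.weight": {"shape": [128256, 4096]},  # entry format is a guess
    },
}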