mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
patching problem
This commit is contained in:
parent
5af1a4a974
commit
3ac4a64f6f
4 changed files with 234 additions and 76 deletions
|
|
@ -1,3 +1,4 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Custom vLLM API server with shared memory weight updates.
|
||||
|
||||
|
|
@ -30,12 +31,18 @@ ARCHITECTURE:
|
|||
│ │ Trainer Process │ │
|
||||
│ └─────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
|
||||
CRITICAL: Patches must be applied BEFORE importing vLLM!
|
||||
"""
|
||||
|
||||
# =============================================================================
|
||||
# STEP 0: Standard library imports ONLY (no vLLM yet!)
|
||||
# =============================================================================
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import ssl
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from argparse import Namespace
|
||||
|
|
@ -44,78 +51,88 @@ from dataclasses import dataclass, field
|
|||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import vllm.envs as envs
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.utils import with_cancellation
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
try:
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.system_utils import set_ulimit
|
||||
except ImportError:
|
||||
from vllm.utils import FlexibleArgumentParser, set_ulimit
|
||||
|
||||
from vllm.outputs import RequestOutput # noqa: F401
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger("vllm.entrypoints.api_server")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Apply vLLM Patches for Shared Memory
|
||||
# STEP 1: Apply patches BEFORE any vLLM imports!
|
||||
# =============================================================================
|
||||
|
||||
def _apply_patches_early() -> bool:
    """
    Apply vLLM patches if shared weights are enabled.

    This MUST be called before any vLLM imports! Progress is reported with
    ``print`` because the module-level vLLM logger is only created after the
    vLLM imports have run.

    Patching is requested when ``VLLM_ENABLE_SHARED_WEIGHTS`` equals ``"1"``
    or when ``NUM_INFERENCE_NODES`` parses to a non-negative integer. An
    empty or non-integer ``NUM_INFERENCE_NODES`` is treated as unset instead
    of raising at import time.

    Returns:
        True if ``apply_patches()`` reported success. False if patching was
        not requested, ``vllm_patching`` could not be imported, or applying
        the patches failed or raised.
    """
    enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1"

    raw_nodes = os.environ.get("NUM_INFERENCE_NODES", "-1")
    try:
        num_inference_nodes = int(raw_nodes)
    except ValueError:
        # A malformed value must not crash the server at import time.
        print(
            f"[vLLM Server] Ignoring invalid NUM_INFERENCE_NODES={raw_nodes!r}"
        )
        num_inference_nodes = -1

    if not enable_shared and num_inference_nodes < 0:
        print("[vLLM Server] Shared weights not enabled, skipping patches")
        return False

    # Report the setting that actually triggered patching.
    if enable_shared:
        print("[vLLM Server] VLLM_ENABLE_SHARED_WEIGHTS=1, applying patches...")
    else:
        print(
            f"[vLLM Server] NUM_INFERENCE_NODES={num_inference_nodes}, "
            "applying patches..."
        )

    try:
        # Try relative import first (when run as module)
        try:
            from .vllm_patching import apply_patches
        except ImportError:
            # Fall back to absolute import (when run as script)
            import sys
            from pathlib import Path

            # Make the directory containing this file importable.
            script_dir = Path(__file__).parent
            if str(script_dir) not in sys.path:
                sys.path.insert(0, str(script_dir))
            from vllm_patching import apply_patches

        # Normalise to a real bool so the declared return type holds even if
        # apply_patches() returns None or another non-bool value.
        success = bool(apply_patches())
        if success:
            print("[vLLM Server] ✓ vLLM patches applied successfully!")
        else:
            print("[vLLM Server] ✗ Failed to apply patches")
        return success

    except ImportError as e:
        print(f"[vLLM Server] Could not import vllm_patching: {e}")
        print("[vLLM Server] Shared memory weight updates will not be available")
        return False
    except Exception as e:
        # Best-effort: a patching failure is reported but never aborts startup.
        print(f"[vLLM Server] Error applying patches: {e}")
        import traceback

        traceback.print_exc()
        return False
|
||||
|
||||
|
||||
# Apply patches before any other vLLM imports
|
||||
PATCHES_APPLIED = _maybe_apply_patches()
|
||||
# Apply patches NOW, before any vLLM imports below!
|
||||
PATCHES_APPLIED = _apply_patches_early()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# STEP 2: Now safe to import vLLM (patches are already in place)
|
||||
# =============================================================================
|
||||
|
||||
import torch # noqa: E402
|
||||
import vllm.envs as envs # noqa: E402
|
||||
from fastapi import FastAPI, HTTPException, Request # noqa: E402
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse # noqa: E402
|
||||
from pydantic import BaseModel # noqa: E402
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs # noqa: E402
|
||||
from vllm.entrypoints.launcher import serve_http # noqa: E402
|
||||
from vllm.entrypoints.utils import with_cancellation # noqa: E402
|
||||
from vllm.logger import init_logger # noqa: E402
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams # noqa: E402
|
||||
from vllm.usage.usage_lib import UsageContext # noqa: E402
|
||||
from vllm.utils import random_uuid # noqa: E402
|
||||
from vllm.v1.engine.async_llm import AsyncLLM # noqa: E402
|
||||
|
||||
try:
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser # noqa: E402
|
||||
from vllm.utils.system_utils import set_ulimit # noqa: E402
|
||||
except ImportError:
|
||||
from vllm.utils import FlexibleArgumentParser, set_ulimit # noqa: E402
|
||||
|
||||
from vllm.outputs import RequestOutput # noqa: F401, E402
|
||||
from vllm.version import __version__ as VLLM_VERSION # noqa: E402
|
||||
|
||||
logger = init_logger("vllm.entrypoints.api_server")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -493,6 +510,7 @@ async def bridge_debug() -> JSONResponse:
|
|||
"patches_applied": PATCHES_APPLIED,
|
||||
"shared_weights_env": os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0"),
|
||||
"num_inference_nodes": os.environ.get("NUM_INFERENCE_NODES", "unset"),
|
||||
"logdir": os.environ.get("LOGDIR", "unset"),
|
||||
}
|
||||
|
||||
if engine is not None:
|
||||
|
|
@ -600,18 +618,28 @@ async def init_app(args: Namespace, llm_engine: AsyncLLM | None = None) -> FastA
|
|||
)
|
||||
app.state.engine_client = engine
|
||||
|
||||
# Export state dict info for trainers
|
||||
# Export basic state dict info for trainers (the patched runner exports detailed info)
|
||||
_export_state_dict_info(args)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _export_state_dict_info(args: Namespace) -> None:
|
||||
"""Export model parameter mapping to JSON for trainer."""
|
||||
log_dir = os.environ.get("LOGDIR", ".")
|
||||
"""Export basic model info to JSON for trainer (backup if patches don't run)."""
|
||||
log_dir = os.environ.get("LOGDIR", "/tmp/atropos_bridge")
|
||||
Path(log_dir).mkdir(parents=True, exist_ok=True)
|
||||
json_path = Path(log_dir) / "vllm_bridge_config.json"
|
||||
|
||||
# Only write basic info if the file doesn't exist or is empty
|
||||
# The patched runner will write complete info with param_mappings
|
||||
try:
|
||||
if json_path.exists():
|
||||
with open(json_path, 'r') as f:
|
||||
existing = json.load(f)
|
||||
if existing.get("param_mappings") and len(existing["param_mappings"]) > 0:
|
||||
logger.info(f"Config already has param_mappings, not overwriting")
|
||||
return
|
||||
|
||||
info = {
|
||||
"model": getattr(args, "model", "unknown"),
|
||||
"dtype": getattr(args, "dtype", "auto"),
|
||||
|
|
@ -624,7 +652,7 @@ def _export_state_dict_info(args: Namespace) -> None:
|
|||
with open(json_path, "w") as f:
|
||||
json.dump(info, f, indent=2)
|
||||
|
||||
logger.info(f"Exported state dict info to {json_path}")
|
||||
logger.info(f"Exported basic state dict info to {json_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to export state dict info: {e}")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue