[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2026-01-19 02:35:28 +00:00 committed by Jai Suphavadeeprasit
parent 4740dfa216
commit fe2fd3d824
5 changed files with 510 additions and 337 deletions


@@ -46,27 +46,27 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
# =============================================================================
# STEP 1: Apply patches BEFORE any vLLM imports!
# =============================================================================
def _apply_patches_early() -> bool:
"""
Apply vLLM patches if shared weights are enabled.
This MUST be called before any vLLM imports!
Returns True if patches were applied.
"""
enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1"
num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", "-1"))
if not enable_shared and num_inference_nodes < 0:
print("[vLLM Server] Shared weights not enabled, skipping patches")
return False
print("[vLLM Server] VLLM_ENABLE_SHARED_WEIGHTS=1, applying patches...")
try:
# Try relative import first (when run as module)
try:
@@ -77,14 +77,14 @@ def _apply_patches_early() -> bool:
if str(script_dir) not in sys.path:
sys.path.insert(0, str(script_dir))
from vllm_patching import apply_patches
success = apply_patches()
if success:
print("[vLLM Server] ✓ vLLM patches applied successfully!")
else:
print("[vLLM Server] ✗ Failed to apply patches")
return success
except ImportError as e:
print(f"[vLLM Server] Could not import vllm_patching: {e}")
print("[vLLM Server] Shared memory weight updates will not be available")
@@ -92,6 +92,7 @@ def _apply_patches_early() -> bool:
except Exception as e:
print(f"[vLLM Server] Error applying patches: {e}")
import traceback
traceback.print_exc()
return False
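The gate above is driven entirely by environment variables, which must be set before the server process starts importing vLLM. A minimal launcher sketch, assuming the server script is started as a separate process (the script name, model, and flags below are illustrative, not taken from this repo):

import os
import subprocess

# Enable the shared-weights patch path before the server imports vLLM.
env = dict(os.environ)
env["VLLM_ENABLE_SHARED_WEIGHTS"] = "1"  # checked by _apply_patches_early()
env["NUM_INFERENCE_NODES"] = "2"         # any value >= 0 also enables patching
env["LOGDIR"] = "/tmp/vllm_bridge"       # illustrative shared log directory

# Hypothetical invocation; adjust the script path and flags to the repo layout.
subprocess.run(
    ["python", "vllm_server.py", "--model", "facebook/opt-125m", "--port", "8000"],
    env=env,
    check=True,
)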
@@ -141,10 +142,11 @@ engine: Optional[AsyncLLM] = None
@dataclass
class BridgeState:
"""State for shared memory and LoRA."""
update_count: int = 0
last_update_time: float = 0.0
lock: threading.Lock = field(default_factory=threading.Lock)
# LoRA state
active_lora_path: Optional[str] = None
lora_load_count: int = 0
@@ -193,10 +195,10 @@ async def health_generate() -> Response:
"""Health check that verifies model can generate."""
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
sampling_params = SamplingParams()
request_id = random_uuid()
try:
results_generator = engine.generate(
{"prompt_token_ids": [0]}, sampling_params, request_id
@@ -219,7 +221,7 @@
async def generate(request: Request) -> Response:
"""
Generate completion for the request.
The request should be a JSON object with:
- prompt: the prompt to use for generation
- stream: whether to stream results
@@ -234,7 +236,7 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
"""Internal generate handler."""
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
request_dict["output_kind"] = RequestOutputKind.FINAL_ONLY
@@ -265,15 +267,17 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
prompt = final_output.prompt or engine.tokenizer.decode(
final_output.prompt_token_ids
)
text_outputs = [output.text for output in final_output.outputs]
finish_reasons = [output.finish_reason for output in final_output.outputs]
ret = {"text": text_outputs, "prompt": prompt, "finish_reasons": finish_reasons}
if sampling_params.logprobs is not None:
output_logprobs = [
-[[{key: value.logprob for key, value in logprob.items()}]
-for logprob in x.logprobs]
+[
+[{key: value.logprob for key, value in logprob.items()}]
+for logprob in x.logprobs
+]
for x in final_output.outputs
]
ret["logprobs"] = output_logprobs
@@ -293,17 +297,21 @@ async def bridge_info() -> JSONResponse:
"""Get bridge status and configuration."""
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
-model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown"
-return JSONResponse({
-"enabled": PATCHES_APPLIED,
-"shared_weights": PATCHES_APPLIED,
-"update_count": bridge_state.update_count,
-"last_update_time": bridge_state.last_update_time,
-"model_name": model_name,
-"device": "cuda" if torch.cuda.is_available() else "cpu",
-})
+model_name = (
+str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown"
+)
+return JSONResponse(
+{
+"enabled": PATCHES_APPLIED,
+"shared_weights": PATCHES_APPLIED,
+"update_count": bridge_state.update_count,
+"last_update_time": bridge_state.last_update_time,
+"model_name": model_name,
+"device": "cuda" if torch.cuda.is_available() else "cpu",
+}
+)
@app.get("/bridge/state_dict_info")
@@ -311,15 +319,17 @@ async def bridge_state_dict_info() -> JSONResponse:
"""Get model parameter information."""
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Basic model info
try:
model_config = engine.model_config
-return JSONResponse({
-"model": str(model_config.model),
-"dtype": str(model_config.dtype),
-"shared_weights_enabled": PATCHES_APPLIED,
-})
+return JSONResponse(
+{
+"model": str(model_config.model),
+"dtype": str(model_config.dtype),
+"shared_weights_enabled": PATCHES_APPLIED,
+}
+)
except Exception as e:
return JSONResponse({"error": str(e)})
@@ -334,10 +344,10 @@ async def bridge_pause() -> JSONResponse:
"""Pause generation to allow weight updates."""
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
try:
# vLLM v1 supports pause/resume
-if hasattr(engine, '_pause_cond'):
+if hasattr(engine, "_pause_cond"):
async with engine._pause_cond:
engine._paused = True
logger.info("Engine paused")
@@ -353,9 +363,9 @@ async def bridge_resume() -> JSONResponse:
"""Resume generation after weight updates."""
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
try:
-if hasattr(engine, '_pause_cond'):
+if hasattr(engine, "_pause_cond"):
async with engine._pause_cond:
engine._paused = False
engine._pause_cond.notify_all()
@@ -372,8 +382,8 @@ async def bridge_is_paused() -> JSONResponse:
"""Check if engine is paused."""
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
-paused = getattr(engine, '_paused', False)
+paused = getattr(engine, "_paused", False)
return JSONResponse({"paused": paused})
@@ -387,7 +397,7 @@ async def bridge_sleep() -> JSONResponse:
"""Put engine to sleep to free GPU memory."""
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
try:
await engine.sleep()
logger.info("Engine sleeping")
@@ -401,7 +411,7 @@ async def bridge_wake_up() -> JSONResponse:
"""Wake engine and reload model."""
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
try:
await engine.wake_up()
logger.info("Engine woken up")
@@ -415,7 +425,7 @@ async def bridge_is_sleeping() -> JSONResponse:
"""Check if engine is sleeping."""
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
sleeping = await engine.is_sleeping()
return JSONResponse({"sleeping": sleeping})
@@ -436,7 +446,7 @@ async def bridge_debug() -> JSONResponse:
"num_inference_nodes": os.environ.get("NUM_INFERENCE_NODES", "unset"),
"logdir": os.environ.get("LOGDIR", "unset"),
}
if engine is not None:
try:
debug_info["model_config"] = {
@@ -445,7 +455,7 @@ async def bridge_debug() -> JSONResponse:
}
except Exception:
pass
return JSONResponse(debug_info)
@@ -455,10 +465,12 @@ async def list_endpoints() -> JSONResponse:
endpoints = []
for route in app.routes:
if hasattr(route, "path") and hasattr(route, "methods"):
-endpoints.append({
-"path": route.path,
-"methods": list(route.methods),
-})
+endpoints.append(
+{
+"path": route.path,
+"methods": list(route.methods),
+}
+)
return JSONResponse({"endpoints": endpoints})
@@ -472,7 +484,7 @@ async def lora_status() -> LoraStatusResponse:
"""Get LoRA adapter status."""
log_dir = os.environ.get("LOGDIR", ".")
available = []
if os.path.exists(log_dir):
for item in os.listdir(log_dir):
item_path = os.path.join(log_dir, item)
@@ -480,7 +492,7 @@ async def lora_status() -> LoraStatusResponse:
os.path.join(item_path, "adapter_config.json")
):
available.append(item)
return LoraStatusResponse(
active_adapter=bridge_state.active_lora_path,
load_count=bridge_state.lora_load_count,
@@ -492,19 +504,23 @@ async def lora_status() -> LoraStatusResponse:
async def lora_load(request: LoraLoadRequest) -> JSONResponse:
"""Load a LoRA adapter."""
if not os.path.exists(request.adapter_path):
-raise HTTPException(status_code=404, detail=f"Adapter not found: {request.adapter_path}")
+raise HTTPException(
+status_code=404, detail=f"Adapter not found: {request.adapter_path}"
+)
with bridge_state.lock:
bridge_state.active_lora_path = request.adapter_path
bridge_state.lora_load_count += 1
logger.info(f"LoRA adapter loaded: {request.adapter_path}")
-return JSONResponse({
-"status": "ok",
-"adapter_path": request.adapter_path,
-"load_count": bridge_state.lora_load_count,
-})
+return JSONResponse(
+{
+"status": "ok",
+"adapter_path": request.adapter_path,
+"load_count": bridge_state.lora_load_count,
+}
+)
@app.post("/lora/unload")
@@ -513,7 +529,7 @@ async def lora_unload() -> JSONResponse:
with bridge_state.lock:
prev = bridge_state.active_lora_path
bridge_state.active_lora_path = None
logger.info(f"LoRA adapter unloaded: {prev}")
return JSONResponse({"status": "ok", "previous_adapter": prev})
@@ -533,18 +549,21 @@ def build_app(args: Namespace) -> FastAPI:
async def init_app(args: Namespace, llm_engine: AsyncLLM | None = None) -> FastAPI:
"""Initialize the application and vLLM engine."""
app = build_app(args)
global engine
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = (
-llm_engine if llm_engine is not None
-else AsyncLLM.from_engine_args(engine_args, usage_context=UsageContext.API_SERVER)
+llm_engine
+if llm_engine is not None
+else AsyncLLM.from_engine_args(
+engine_args, usage_context=UsageContext.API_SERVER
+)
)
app.state.engine_client = engine
# Export basic state dict info for trainers (the patched runner exports detailed info)
_export_state_dict_info(args)
return app
@@ -553,17 +572,20 @@ def _export_state_dict_info(args: Namespace) -> None:
log_dir = os.environ.get("LOGDIR", ".")
Path(log_dir).mkdir(parents=True, exist_ok=True)
json_path = Path(log_dir) / "vllm_bridge_config.json"
# Only write basic info if the file doesn't exist or is empty
# The patched runner will write complete info with param_mappings
try:
if json_path.exists():
-with open(json_path, 'r') as f:
+with open(json_path, "r") as f:
existing = json.load(f)
if existing.get("param_mappings") and len(existing["param_mappings"]) > 0:
if (
existing.get("param_mappings")
and len(existing["param_mappings"]) > 0
):
logger.info(f"Config already has param_mappings, not overwriting")
return
info = {
"model": getattr(args, "model", "unknown"),
"dtype": getattr(args, "dtype", "auto"),
@@ -572,32 +594,34 @@ def _export_state_dict_info(args: Namespace) -> None:
"param_mappings": {},
"shared_weights_enabled": PATCHES_APPLIED,
}
with open(json_path, "w") as f:
json.dump(info, f, indent=2)
logger.info(f"Exported basic state dict info to {json_path}")
except Exception as e:
logger.warning(f"Failed to export state dict info: {e}")
-async def run_server(args: Namespace, llm_engine: AsyncLLM | None = None, **uvicorn_kwargs: Any) -> None:
+async def run_server(
+args: Namespace, llm_engine: AsyncLLM | None = None, **uvicorn_kwargs: Any
+) -> None:
"""Run the vLLM API server."""
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
if PATCHES_APPLIED:
logger.info("=" * 60)
logger.info("SHARED MEMORY MODE ENABLED")
logger.info("Weight updates from trainer will be reflected immediately!")
logger.info("=" * 60)
set_ulimit()
app = await init_app(args, llm_engine)
if engine is None:
raise RuntimeError("No engine initialized")
# Log available endpoints
logger.info("=" * 60)
logger.info("Available endpoints:")
@@ -607,22 +631,22 @@ async def run_server(args: Namespace, llm_engine: AsyncLLM | None = None, **uvic
logger.info(" POST /bridge/resume - Resume generation")
logger.info(" GET /lora/status - LoRA adapter status")
logger.info("=" * 60)
shutdown_task = await serve_http(
app,
sock=None,
-enable_ssl_refresh=getattr(args, 'enable_ssl_refresh', False),
+enable_ssl_refresh=getattr(args, "enable_ssl_refresh", False),
host=args.host,
port=args.port,
-log_level=getattr(args, 'log_level', 'info'),
+log_level=getattr(args, "log_level", "info"),
timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
-ssl_keyfile=getattr(args, 'ssl_keyfile', None),
-ssl_certfile=getattr(args, 'ssl_certfile', None),
-ssl_ca_certs=getattr(args, 'ssl_ca_certs', None),
-ssl_cert_reqs=getattr(args, 'ssl_cert_reqs', ssl.CERT_NONE),
+ssl_keyfile=getattr(args, "ssl_keyfile", None),
+ssl_certfile=getattr(args, "ssl_certfile", None),
+ssl_ca_certs=getattr(args, "ssl_ca_certs", None),
+ssl_cert_reqs=getattr(args, "ssl_cert_reqs", ssl.CERT_NONE),
**uvicorn_kwargs,
)
await shutdown_task
@@ -637,9 +661,9 @@ if __name__ == "__main__":
parser.add_argument("--ssl-cert-reqs", type=int, default=int(ssl.CERT_NONE))
parser.add_argument("--root-path", type=str, default=None)
parser.add_argument("--log-level", type=str, default="info")
# Add vLLM engine args
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
asyncio.run(run_server(args))