hot swap adapter

This commit is contained in:
Jai Suphavadeeprasit 2026-01-20 22:11:53 -05:00
parent c86b36844b
commit 347f9ea363
3 changed files with 121 additions and 86 deletions

View file

@ -119,15 +119,6 @@ from vllm.usage.usage_lib import UsageContext # noqa: E402
from vllm.utils import random_uuid # noqa: E402
from vllm.v1.engine.async_llm import AsyncLLM # noqa: E402
# Try to import LoRARequest for adapter support
try:
from vllm.lora.request import LoRARequest # noqa: E402
LORA_AVAILABLE = True
except ImportError:
LORA_AVAILABLE = False
LoRARequest = None # type: ignore
try:
from vllm.utils.argparse_utils import FlexibleArgumentParser # noqa: E402
from vllm.utils.system_utils import set_ulimit # noqa: E402
@ -158,28 +149,12 @@ class BridgeState:
# LoRA state
active_lora_path: Optional[str] = None
active_lora_name: Optional[str] = None
active_lora_id: int = 0 # vLLM requires unique integer ID per adapter
lora_load_count: int = 0
bridge_state = BridgeState()
def _get_lora_request() -> Optional["LoRARequest"]:
    """Build the LoRA request describing the currently active adapter.

    Returns:
        ``None`` when ``LORA_AVAILABLE`` is false or when
        ``bridge_state.active_lora_path`` is unset; otherwise a
        ``LoRARequest`` populated from the adapter name, integer id and
        path currently stored on ``bridge_state``.
    """
    # Nothing to attach when LoRA support is missing or no adapter is selected.
    if not LORA_AVAILABLE or bridge_state.active_lora_path is None:
        return None

    # An unset or empty adapter name falls back to a fixed placeholder.
    adapter_name = bridge_state.active_lora_name or "default_adapter"
    return LoRARequest(
        lora_name=adapter_name,
        lora_int_id=bridge_state.active_lora_id,
        lora_path=bridge_state.active_lora_path,
    )
# =============================================================================
# Pydantic Models for API
# =============================================================================
@ -199,10 +174,7 @@ class LoraLoadRequest(BaseModel):
class LoraStatusResponse(BaseModel):
lora_available: bool
active_adapter_path: Optional[str]
active_adapter_name: Optional[str]
active_adapter_id: Optional[int]
active_adapter: Optional[str]
load_count: int
available_adapters: List[str]
@ -271,12 +243,7 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
# Get active LoRA adapter if any
lora_request = _get_lora_request()
results_generator = engine.generate(
prompt, sampling_params, request_id, lora_request=lora_request
)
results_generator = engine.generate(prompt, sampling_params, request_id)
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
@ -380,13 +347,8 @@ async def openai_chat_completions(request: Request) -> Response:
request_id = random_uuid()
# Get active LoRA adapter if any
lora_request = _get_lora_request()
final_output = None
async for request_output in engine.generate(
prompt, sampling_params, request_id, lora_request=lora_request
):
async for request_output in engine.generate(prompt, sampling_params, request_id):
final_output = request_output
if final_output is None:
@ -472,15 +434,12 @@ async def openai_completions(request: Request) -> Response:
total_prompt_tokens = 0
total_completion_tokens = 0
# Get active LoRA adapter if any
lora_request = _get_lora_request()
for prompt_text in prompts:
request_id = random_uuid()
final_output = None
async for request_output in engine.generate(
prompt_text, sampling_params, request_id, lora_request=lora_request
prompt_text, sampling_params, request_id
):
final_output = request_output
@ -732,14 +691,13 @@ async def lora_status() -> LoraStatusResponse:
log_dir = os.environ.get("LOGDIR", ".")
available = []
# Scan for available adapters in LOGDIR
if os.path.exists(log_dir):
for item in os.listdir(log_dir):
item_path = os.path.join(log_dir, item)
if os.path.isdir(item_path) and os.path.exists(
os.path.join(item_path, "adapter_config.json")
):
available.append(item_path) # Return full path
available.append(item)
return LoraStatusResponse(
lora_available=LORA_AVAILABLE,
@ -755,26 +713,12 @@ async def lora_status() -> LoraStatusResponse:
@app.post("/lora/load")
async def lora_load(request: LoraLoadRequest) -> JSONResponse:
"""Load a LoRA adapter for use in generation requests."""
if not LORA_AVAILABLE:
raise HTTPException(
status_code=501,
detail="LoRA not available. Start vLLM with --enable-lora",
)
"""Load a LoRA adapter."""
if not os.path.exists(request.adapter_path):
raise HTTPException(
status_code=404, detail=f"Adapter not found: {request.adapter_path}"
)
# Verify adapter_config.json exists
adapter_config = os.path.join(request.adapter_path, "adapter_config.json")
if not os.path.exists(adapter_config):
raise HTTPException(
status_code=400,
detail=f"Invalid adapter: {request.adapter_path} (missing adapter_config.json)",
)
with bridge_state.lock:
bridge_state.active_lora_path = request.adapter_path
bridge_state.active_lora_name = (
@ -785,17 +729,12 @@ async def lora_load(request: LoraLoadRequest) -> JSONResponse:
) # vLLM needs unique int ID
bridge_state.lora_load_count += 1
logger.info(
f"LoRA adapter loaded: {request.adapter_path} "
f"(name={bridge_state.active_lora_name}, id={bridge_state.active_lora_id})"
)
logger.info(f"LoRA adapter loaded: {request.adapter_path}")
return JSONResponse(
{
"status": "ok",
"adapter_path": request.adapter_path,
"adapter_name": bridge_state.active_lora_name,
"adapter_id": bridge_state.active_lora_id,
"load_count": bridge_state.lora_load_count,
}
)
@ -803,13 +742,10 @@ async def lora_load(request: LoraLoadRequest) -> JSONResponse:
@app.post("/lora/unload")
async def lora_unload() -> JSONResponse:
"""Unload current LoRA adapter (revert to base model)."""
"""Unload current LoRA adapter."""
with bridge_state.lock:
prev_path = bridge_state.active_lora_path
prev_name = bridge_state.active_lora_name
prev = bridge_state.active_lora_path
bridge_state.active_lora_path = None
bridge_state.active_lora_name = None
# Note: Keep active_lora_id - vLLM may cache adapters by ID
logger.info(f"LoRA adapter unloaded: {prev_path} ({prev_name})")
return JSONResponse(