Jai Suphavadeeprasit 2026-01-20 22:19:20 -05:00
parent 347f9ea363
commit 83ebfe349f


@@ -128,6 +128,15 @@ except ImportError:
 from vllm.outputs import RequestOutput  # noqa: F401, E402
 from vllm.version import __version__ as VLLM_VERSION  # noqa: E402
+
+# Try to import LoRARequest for adapter support
+try:
+    from vllm.lora.request import LoRARequest  # noqa: E402
+    LORA_AVAILABLE = True
+except ImportError:
+    LORA_AVAILABLE = False
+    LoRARequest = None  # type: ignore
+
 logger = init_logger("vllm.entrypoints.api_server")
@@ -149,12 +158,28 @@ class BridgeState:
+    # LoRA state
+    active_lora_path: Optional[str] = None
+    active_lora_name: Optional[str] = None
+    active_lora_id: int = 0  # vLLM requires unique integer ID per adapter
+    lora_load_count: int = 0

 bridge_state = BridgeState()


+def _get_lora_request() -> Optional["LoRARequest"]:
+    """Get the current LoRA request if an adapter is active."""
+    if not LORA_AVAILABLE:
+        return None
+    if bridge_state.active_lora_path is None:
+        return None
+    return LoRARequest(
+        lora_name=bridge_state.active_lora_name or "default_adapter",
+        lora_int_id=bridge_state.active_lora_id,
+        lora_path=bridge_state.active_lora_path,
+    )


 # =============================================================================
 # Pydantic Models for API
 # =============================================================================
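
The endpoint that populates this LoRA state is outside the hunks shown. For context, a minimal activation helper consistent with the fields above could look like the sketch below; the helper name is hypothetical and not part of this commit.

def _activate_adapter(path: str, name: str) -> None:
    """Hypothetical helper: activate a LoRA adapter for subsequent requests."""
    bridge_state.lora_load_count += 1
    # Never reuse an ID: vLLM identifies adapters by lora_int_id, so a fresh
    # integer per load keeps each adapter distinct to the engine.
    bridge_state.active_lora_id = bridge_state.lora_load_count
    bridge_state.active_lora_path = path
    bridge_state.active_lora_name = name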
@@ -174,7 +199,10 @@ class LoraLoadRequest(BaseModel):

 class LoraStatusResponse(BaseModel):
     active_adapter: Optional[str]
     lora_available: bool
+    active_adapter_path: Optional[str]
+    active_adapter_name: Optional[str]
+    active_adapter_id: Optional[int]
     load_count: int
     available_adapters: List[str]
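
Serialized, a populated status response would look something like the following; all values are illustrative, and the /lora/status route name is an assumption.

example_status = {
    "active_adapter": "my-adapter",
    "lora_available": True,
    "active_adapter_path": "/adapters/my-adapter",
    "active_adapter_name": "my-adapter",
    "active_adapter_id": 3,
    "load_count": 3,
    "available_adapters": ["my-adapter", "base-tuned-v2"],
}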
@@ -243,7 +271,12 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
-    results_generator = engine.generate(prompt, sampling_params, request_id)
+
+    # Get active LoRA adapter if any
+    lora_request = _get_lora_request()
+    results_generator = engine.generate(
+        prompt, sampling_params, request_id, lora_request=lora_request
+    )

     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
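
Because adapter selection lives in server-side bridge state, clients call the endpoint unchanged and generation routes through the active adapter transparently. A quick smoke test, assuming the stock /generate route and a server on localhost:8000:

import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "Hello, my name is", "max_tokens": 32, "temperature": 0.0},
)
print(resp.json())  # served with the active LoRA adapter, if one is loaded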
@@ -347,8 +380,13 @@ async def openai_chat_completions(request: Request) -> Response:
     request_id = random_uuid()

+    # Get active LoRA adapter if any
+    lora_request = _get_lora_request()
+
     final_output = None
-    async for request_output in engine.generate(prompt, sampling_params, request_id):
+    async for request_output in engine.generate(
+        prompt, sampling_params, request_id, lora_request=lora_request
+    ):
         final_output = request_output

     if final_output is None:
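
The same holds for the OpenAI-compatible chat route: the request body keeps the standard shape, and as this diff reads, the adapter is applied implicitly rather than selected via the model field. Path and port below are assumptions:

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "base",  # not used for adapter selection in this bridge
        "messages": [{"role": "user", "content": "Summarize LoRA in one line."}],
        "max_tokens": 64,
    },
)
print(resp.json()["choices"][0]["message"]["content"])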
@@ -434,12 +472,15 @@ async def openai_completions(request: Request) -> Response:
     total_prompt_tokens = 0
     total_completion_tokens = 0

+    # Get active LoRA adapter if any
+    lora_request = _get_lora_request()
+
     for prompt_text in prompts:
         request_id = random_uuid()
         final_output = None
         async for request_output in engine.generate(
-            prompt_text, sampling_params, request_id
+            prompt_text, sampling_params, request_id, lora_request=lora_request
         ):
             final_output = request_output
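
Since this handler loops over a list of prompts and passes the same lora_request to every generate() call, a batched request is served consistently by one adapter. A sketch of exercising that, with the route and response shape assumed to follow the OpenAI completions API:

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "base",
        "prompt": ["The capital of France is", "2 + 2 ="],  # one generate() per prompt
        "max_tokens": 8,
    },
)
for choice in resp.json()["choices"]:
    print(choice["text"])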