Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-23 16:54:56 +00:00

Commit 83ebfe349f (parent 347f9ea363)
1 changed file with 45 additions and 4 deletions
@@ -128,6 +128,15 @@ except ImportError:
 from vllm.outputs import RequestOutput  # noqa: F401, E402
 from vllm.version import __version__ as VLLM_VERSION  # noqa: E402
 
+# Try to import LoRARequest for adapter support
+try:
+    from vllm.lora.request import LoRARequest  # noqa: E402
+
+    LORA_AVAILABLE = True
+except ImportError:
+    LORA_AVAILABLE = False
+    LoRARequest = None  # type: ignore
+
 logger = init_logger("vllm.entrypoints.api_server")
 
 
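One caveat worth noting: the guarded import above only determines whether LoRARequest objects can be constructed; per-request adapters also require the engine itself to be built with LoRA support enabled. A minimal sketch of that prerequisite, assuming vLLM's standard AsyncEngineArgs/AsyncLLMEngine entry points (the model name and LoRA limits below are placeholders, not taken from this commit):

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Sketch: the bridge's engine must be constructed with enable_lora=True,
# otherwise passing lora_request to engine.generate() is rejected.
engine_args = AsyncEngineArgs(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    enable_lora=True,
    max_loras=1,        # adapters resident at once (illustrative)
    max_lora_rank=16,   # must cover the rank of loaded adapters (illustrative)
)
engine = AsyncLLMEngine.from_engine_args(engine_args)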
@@ -149,12 +158,28 @@ class BridgeState:
 
+    # LoRA state
+    active_lora_path: Optional[str] = None
+    active_lora_name: Optional[str] = None
+    active_lora_id: int = 0  # vLLM requires unique integer ID per adapter
+    lora_load_count: int = 0
 
 
 bridge_state = BridgeState()
 
 
+def _get_lora_request() -> Optional["LoRARequest"]:
+    """Get the current LoRA request if an adapter is active."""
+    if not LORA_AVAILABLE:
+        return None
+    if bridge_state.active_lora_path is None:
+        return None
+    return LoRARequest(
+        lora_name=bridge_state.active_lora_name or "default_adapter",
+        lora_int_id=bridge_state.active_lora_id,
+        lora_path=bridge_state.active_lora_path,
+    )
 
 
 # =============================================================================
 # Pydantic Models for API
 # =============================================================================
 
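For context, a load endpoint only needs to fill the three bridge_state fields that _get_lora_request() reads. A hedged sketch of such a handler; the route, the LoraLoadRequest field names, and the FastAPI app object are assumptions, since this hunk does not show them:

@app.post("/lora/load")
async def load_lora(req: LoraLoadRequest):
    # Assumed request fields; LoraLoadRequest's body is not part of this hunk.
    bridge_state.lora_load_count += 1
    # vLLM caches adapters by integer ID, so each load gets a fresh one.
    bridge_state.active_lora_id = bridge_state.lora_load_count
    bridge_state.active_lora_path = req.lora_path
    bridge_state.active_lora_name = req.lora_name
    return {"loaded": req.lora_name, "id": bridge_state.active_lora_id}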
@@ -174,7 +199,10 @@ class LoraLoadRequest(BaseModel):
 
 
 class LoraStatusResponse(BaseModel):
-    active_adapter: Optional[str]
     lora_available: bool
+    active_adapter_path: Optional[str]
+    active_adapter_name: Optional[str]
+    active_adapter_id: Optional[int]
     load_count: int
    available_adapters: List[str]
 
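Given those fields, a status response after a single successful load would look roughly like this (all values illustrative):

LoraStatusResponse(
    lora_available=True,
    active_adapter_path="/adapters/my-task",  # placeholder path
    active_adapter_name="my-task",            # placeholder name
    active_adapter_id=1,
    load_count=1,
    available_adapters=["my-task"],
)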
@@ -243,7 +271,12 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
-    results_generator = engine.generate(prompt, sampling_params, request_id)
+    # Get active LoRA adapter if any
+    lora_request = _get_lora_request()
+
+    results_generator = engine.generate(
+        prompt, sampling_params, request_id, lora_request=lora_request
+    )
 
     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
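Because the adapter is resolved server-side from bridge_state, /generate clients are unchanged; whichever adapter is active is applied transparently. A hedged client-side sketch (the request shape follows vLLM's demo api_server; host, port, and sampling values are placeholders):

import requests

resp = requests.post(
    "http://localhost:8000/generate",  # placeholder address
    json={"prompt": "Hello, world", "max_tokens": 32, "temperature": 0.7},
)
print(resp.json())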
@@ -347,8 +380,13 @@ async def openai_chat_completions(request: Request) -> Response:
 
     request_id = random_uuid()
 
+    # Get active LoRA adapter if any
+    lora_request = _get_lora_request()
+
     final_output = None
-    async for request_output in engine.generate(prompt, sampling_params, request_id):
+    async for request_output in engine.generate(
+        prompt, sampling_params, request_id, lora_request=lora_request
+    ):
         final_output = request_output
 
     if final_output is None:
@@ -434,12 +472,15 @@ async def openai_completions(request: Request) -> Response:
     total_prompt_tokens = 0
     total_completion_tokens = 0
 
+    # Get active LoRA adapter if any
+    lora_request = _get_lora_request()
+
     for prompt_text in prompts:
         request_id = random_uuid()
 
         final_output = None
         async for request_output in engine.generate(
-            prompt_text, sampling_params, request_id
+            prompt_text, sampling_params, request_id, lora_request=lora_request
         ):
             final_output = request_output
 
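Putting the pieces together, a plausible end-to-end flow against this bridge: activate an adapter, then issue an OpenAI-style completion that now runs through it. The /lora/load route is the hypothetical one sketched earlier; /v1/completions matches the handler modified above:

import requests

BASE = "http://localhost:8000"  # placeholder address

# 1. Activate an adapter (hypothetical route; see the earlier sketch).
requests.post(f"{BASE}/lora/load", json={"lora_path": "/adapters/my-task"})

# 2. Subsequent completions pick up the adapter via _get_lora_request().
out = requests.post(
    f"{BASE}/v1/completions",
    json={"model": "base", "prompt": "Hello", "max_tokens": 16},  # placeholder body
)
print(out.json()["choices"][0]["text"])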