mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
hot swap adapter
This commit is contained in:
parent
c86b36844b
commit
347f9ea363
3 changed files with 121 additions and 86 deletions
|
|
@ -119,15 +119,6 @@ from vllm.usage.usage_lib import UsageContext # noqa: E402
|
|||
from vllm.utils import random_uuid # noqa: E402
|
||||
from vllm.v1.engine.async_llm import AsyncLLM # noqa: E402
|
||||
|
||||
# Try to import LoRARequest for adapter support
|
||||
try:
|
||||
from vllm.lora.request import LoRARequest # noqa: E402
|
||||
|
||||
LORA_AVAILABLE = True
|
||||
except ImportError:
|
||||
LORA_AVAILABLE = False
|
||||
LoRARequest = None # type: ignore
|
||||
|
||||
try:
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser # noqa: E402
|
||||
from vllm.utils.system_utils import set_ulimit # noqa: E402
|
||||
|
|
@ -158,28 +149,12 @@ class BridgeState:
|
|||
|
||||
# LoRA state
|
||||
active_lora_path: Optional[str] = None
|
||||
active_lora_name: Optional[str] = None
|
||||
active_lora_id: int = 0 # vLLM requires unique integer ID per adapter
|
||||
lora_load_count: int = 0
|
||||
|
||||
|
||||
bridge_state = BridgeState()
|
||||
|
||||
|
||||
def _get_lora_request() -> Optional["LoRARequest"]:
|
||||
"""Get the current LoRA request if an adapter is active."""
|
||||
if not LORA_AVAILABLE:
|
||||
return None
|
||||
if bridge_state.active_lora_path is None:
|
||||
return None
|
||||
|
||||
return LoRARequest(
|
||||
lora_name=bridge_state.active_lora_name or "default_adapter",
|
||||
lora_int_id=bridge_state.active_lora_id,
|
||||
lora_path=bridge_state.active_lora_path,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pydantic Models for API
|
||||
# =============================================================================
|
||||
|
|
@ -199,10 +174,7 @@ class LoraLoadRequest(BaseModel):
|
|||
|
||||
|
||||
class LoraStatusResponse(BaseModel):
|
||||
lora_available: bool
|
||||
active_adapter_path: Optional[str]
|
||||
active_adapter_name: Optional[str]
|
||||
active_adapter_id: Optional[int]
|
||||
active_adapter: Optional[str]
|
||||
load_count: int
|
||||
available_adapters: List[str]
|
||||
|
||||
|
|
@ -271,12 +243,7 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
|
|||
sampling_params = SamplingParams(**request_dict)
|
||||
request_id = random_uuid()
|
||||
|
||||
# Get active LoRA adapter if any
|
||||
lora_request = _get_lora_request()
|
||||
|
||||
results_generator = engine.generate(
|
||||
prompt, sampling_params, request_id, lora_request=lora_request
|
||||
)
|
||||
results_generator = engine.generate(prompt, sampling_params, request_id)
|
||||
|
||||
async def stream_results() -> AsyncGenerator[bytes, None]:
|
||||
async for request_output in results_generator:
|
||||
|
|
@ -380,13 +347,8 @@ async def openai_chat_completions(request: Request) -> Response:
|
|||
|
||||
request_id = random_uuid()
|
||||
|
||||
# Get active LoRA adapter if any
|
||||
lora_request = _get_lora_request()
|
||||
|
||||
final_output = None
|
||||
async for request_output in engine.generate(
|
||||
prompt, sampling_params, request_id, lora_request=lora_request
|
||||
):
|
||||
async for request_output in engine.generate(prompt, sampling_params, request_id):
|
||||
final_output = request_output
|
||||
|
||||
if final_output is None:
|
||||
|
|
@ -472,15 +434,12 @@ async def openai_completions(request: Request) -> Response:
|
|||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
|
||||
# Get active LoRA adapter if any
|
||||
lora_request = _get_lora_request()
|
||||
|
||||
for prompt_text in prompts:
|
||||
request_id = random_uuid()
|
||||
|
||||
final_output = None
|
||||
async for request_output in engine.generate(
|
||||
prompt_text, sampling_params, request_id, lora_request=lora_request
|
||||
prompt_text, sampling_params, request_id
|
||||
):
|
||||
final_output = request_output
|
||||
|
||||
|
|
@ -732,14 +691,13 @@ async def lora_status() -> LoraStatusResponse:
|
|||
log_dir = os.environ.get("LOGDIR", ".")
|
||||
available = []
|
||||
|
||||
# Scan for available adapters in LOGDIR
|
||||
if os.path.exists(log_dir):
|
||||
for item in os.listdir(log_dir):
|
||||
item_path = os.path.join(log_dir, item)
|
||||
if os.path.isdir(item_path) and os.path.exists(
|
||||
os.path.join(item_path, "adapter_config.json")
|
||||
):
|
||||
available.append(item_path) # Return full path
|
||||
available.append(item)
|
||||
|
||||
return LoraStatusResponse(
|
||||
lora_available=LORA_AVAILABLE,
|
||||
|
|
@ -755,26 +713,12 @@ async def lora_status() -> LoraStatusResponse:
|
|||
|
||||
@app.post("/lora/load")
|
||||
async def lora_load(request: LoraLoadRequest) -> JSONResponse:
|
||||
"""Load a LoRA adapter for use in generation requests."""
|
||||
if not LORA_AVAILABLE:
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail="LoRA not available. Start vLLM with --enable-lora",
|
||||
)
|
||||
|
||||
"""Load a LoRA adapter."""
|
||||
if not os.path.exists(request.adapter_path):
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Adapter not found: {request.adapter_path}"
|
||||
)
|
||||
|
||||
# Verify adapter_config.json exists
|
||||
adapter_config = os.path.join(request.adapter_path, "adapter_config.json")
|
||||
if not os.path.exists(adapter_config):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid adapter: {request.adapter_path} (missing adapter_config.json)",
|
||||
)
|
||||
|
||||
with bridge_state.lock:
|
||||
bridge_state.active_lora_path = request.adapter_path
|
||||
bridge_state.active_lora_name = (
|
||||
|
|
@ -785,17 +729,12 @@ async def lora_load(request: LoraLoadRequest) -> JSONResponse:
|
|||
) # vLLM needs unique int ID
|
||||
bridge_state.lora_load_count += 1
|
||||
|
||||
logger.info(
|
||||
f"LoRA adapter loaded: {request.adapter_path} "
|
||||
f"(name={bridge_state.active_lora_name}, id={bridge_state.active_lora_id})"
|
||||
)
|
||||
logger.info(f"LoRA adapter loaded: {request.adapter_path}")
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"status": "ok",
|
||||
"adapter_path": request.adapter_path,
|
||||
"adapter_name": bridge_state.active_lora_name,
|
||||
"adapter_id": bridge_state.active_lora_id,
|
||||
"load_count": bridge_state.lora_load_count,
|
||||
}
|
||||
)
|
||||
|
|
@ -803,13 +742,10 @@ async def lora_load(request: LoraLoadRequest) -> JSONResponse:
|
|||
|
||||
@app.post("/lora/unload")
|
||||
async def lora_unload() -> JSONResponse:
|
||||
"""Unload current LoRA adapter (revert to base model)."""
|
||||
"""Unload current LoRA adapter."""
|
||||
with bridge_state.lock:
|
||||
prev_path = bridge_state.active_lora_path
|
||||
prev_name = bridge_state.active_lora_name
|
||||
prev = bridge_state.active_lora_path
|
||||
bridge_state.active_lora_path = None
|
||||
bridge_state.active_lora_name = None
|
||||
# Note: Keep active_lora_id - vLLM may cache adapters by ID
|
||||
|
||||
logger.info(f"LoRA adapter unloaded: {prev_path} ({prev_name})")
|
||||
return JSONResponse(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue