mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00

commit fe2fd3d824 (parent 4740dfa216)

[pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

5 changed files with 510 additions and 337 deletions
@@ -46,27 +46,27 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Dict, List, Optional


 # =============================================================================
 # STEP 1: Apply patches BEFORE any vLLM imports!
 # =============================================================================


 def _apply_patches_early() -> bool:
     """
     Apply vLLM patches if shared weights are enabled.

     This MUST be called before any vLLM imports!
     Returns True if patches were applied.
     """
     enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1"
     num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", "-1"))

     if not enable_shared and num_inference_nodes < 0:
         print("[vLLM Server] Shared weights not enabled, skipping patches")
         return False

     print("[vLLM Server] VLLM_ENABLE_SHARED_WEIGHTS=1, applying patches...")

     try:
         # Try relative import first (when run as module)
         try:
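For context, the import-order constraint above can be exercised with a minimal sketch: set the gating variable, patch, and only then import vLLM. This assumes `_apply_patches_early` is in scope from this server module; the surrounding script is illustrative, not part of the commit.

    import os

    # Either gate enables patching: VLLM_ENABLE_SHARED_WEIGHTS=1, or a
    # non-negative NUM_INFERENCE_NODES (mirrors _apply_patches_early above).
    os.environ["VLLM_ENABLE_SHARED_WEIGHTS"] = "1"

    patched = _apply_patches_early()

    # Only now is it safe to pull in vLLM; the patches hook classes that
    # would otherwise already be bound at import time.
    import vllm  # noqa: E402

    print("patches applied:", patched)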
@@ -77,14 +77,14 @@ def _apply_patches_early() -> bool:
             if str(script_dir) not in sys.path:
                 sys.path.insert(0, str(script_dir))
             from vllm_patching import apply_patches

         success = apply_patches()
         if success:
             print("[vLLM Server] ✓ vLLM patches applied successfully!")
         else:
             print("[vLLM Server] ✗ Failed to apply patches")
         return success

     except ImportError as e:
         print(f"[vLLM Server] Could not import vllm_patching: {e}")
         print("[vLLM Server] Shared memory weight updates will not be available")
@@ -92,6 +92,7 @@ def _apply_patches_early() -> bool:
     except Exception as e:
         print(f"[vLLM Server] Error applying patches: {e}")
         import traceback
+
         traceback.print_exc()
         return False
@@ -141,10 +142,11 @@ engine: Optional[AsyncLLM] = None
 @dataclass
 class BridgeState:
     """State for shared memory and LoRA."""
+
     update_count: int = 0
     last_update_time: float = 0.0
     lock: threading.Lock = field(default_factory=threading.Lock)

     # LoRA state
     active_lora_path: Optional[str] = None
     lora_load_count: int = 0
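A hedged usage sketch for this dataclass: any mutation of the counters should happen under `lock`, since FastAPI handlers and the weight-update path may touch the state from different threads. The helper name is hypothetical; it assumes the `BridgeState` defined above.

    import time

    bridge_state = BridgeState()  # the dataclass from the diff above

    def record_weight_update() -> None:
        # Hypothetical helper: hold the lock so update_count and
        # last_update_time always change together.
        with bridge_state.lock:
            bridge_state.update_count += 1
            bridge_state.last_update_time = time.time()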
@@ -193,10 +195,10 @@ async def health_generate() -> Response:
     """Health check that verifies model can generate."""
     if engine is None:
         raise HTTPException(status_code=503, detail="Engine not initialized")

     sampling_params = SamplingParams()
     request_id = random_uuid()

     try:
         results_generator = engine.generate(
             {"prompt_token_ids": [0]}, sampling_params, request_id
@@ -219,7 +221,7 @@
 async def generate(request: Request) -> Response:
     """
     Generate completion for the request.

     The request should be a JSON object with:
     - prompt: the prompt to use for generation
     - stream: whether to stream results
@@ -234,7 +236,7 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
     """Internal generate handler."""
     if engine is None:
         raise HTTPException(status_code=503, detail="Engine not initialized")

     prompt = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
     request_dict["output_kind"] = RequestOutputKind.FINAL_ONLY
@@ -265,15 +267,17 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response:
         prompt = final_output.prompt or engine.tokenizer.decode(
             final_output.prompt_token_ids
         )

     text_outputs = [output.text for output in final_output.outputs]
     finish_reasons = [output.finish_reason for output in final_output.outputs]
     ret = {"text": text_outputs, "prompt": prompt, "finish_reasons": finish_reasons}
     if sampling_params.logprobs is not None:
         output_logprobs = [
-            [[{key: value.logprob for key, value in logprob.items()}]
-             for logprob in x.logprobs]
+            [
+                [{key: value.logprob for key, value in logprob.items()}]
+                for logprob in x.logprobs
+            ]
             for x in final_output.outputs
         ]
         ret["logprobs"] = output_logprobs
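To make the reshaped comprehension concrete, a client-side sketch of the response, assuming the server runs on localhost:8000 and that leftover JSON fields feed SamplingParams (both assumptions, not shown in this diff). Per the new formatting, `logprobs[i][t]` is a one-element list holding a {token_id: logprob} dict for output i at position t.

    import requests

    resp = requests.post(
        "http://localhost:8000/generate",  # assumed address
        json={"prompt": "Hello", "max_tokens": 8, "logprobs": 1, "stream": False},
        timeout=60,
    )
    body = resp.json()
    print(body["text"])            # list of generated strings
    print(body["finish_reasons"])  # e.g. ["length"]
    print(body["logprobs"][0][0])  # [{token_id: logprob}] for output 0, step 0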
@@ -293,17 +297,21 @@ async def bridge_info() -> JSONResponse:
     """Get bridge status and configuration."""
     if engine is None:
         raise HTTPException(status_code=503, detail="Engine not initialized")

-    model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown"
-
-    return JSONResponse({
-        "enabled": PATCHES_APPLIED,
-        "shared_weights": PATCHES_APPLIED,
-        "update_count": bridge_state.update_count,
-        "last_update_time": bridge_state.last_update_time,
-        "model_name": model_name,
-        "device": "cuda" if torch.cuda.is_available() else "cpu",
-    })
+    model_name = (
+        str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown"
+    )
+
+    return JSONResponse(
+        {
+            "enabled": PATCHES_APPLIED,
+            "shared_weights": PATCHES_APPLIED,
+            "update_count": bridge_state.update_count,
+            "last_update_time": bridge_state.last_update_time,
+            "model_name": model_name,
+            "device": "cuda" if torch.cuda.is_available() else "cpu",
+        }
+    )


 @app.get("/bridge/state_dict_info")
@@ -311,15 +319,17 @@ async def bridge_state_dict_info() -> JSONResponse:
     """Get model parameter information."""
     if engine is None:
         raise HTTPException(status_code=503, detail="Engine not initialized")

     # Basic model info
     try:
         model_config = engine.model_config
-        return JSONResponse({
-            "model": str(model_config.model),
-            "dtype": str(model_config.dtype),
-            "shared_weights_enabled": PATCHES_APPLIED,
-        })
+        return JSONResponse(
+            {
+                "model": str(model_config.model),
+                "dtype": str(model_config.dtype),
+                "shared_weights_enabled": PATCHES_APPLIED,
+            }
+        )
     except Exception as e:
         return JSONResponse({"error": str(e)})
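A small polling sketch for these two read-only endpoints; the /bridge/info route path is inferred from the function name (only the /bridge/state_dict_info decorator is visible above), and the address is an assumption.

    import requests

    BASE = "http://localhost:8000"  # assumed server address

    info = requests.get(f"{BASE}/bridge/info", timeout=10).json()
    print(info["shared_weights"], info["update_count"], info["device"])

    sd = requests.get(f"{BASE}/bridge/state_dict_info", timeout=10).json()
    print(sd.get("model"), sd.get("dtype"), sd.get("shared_weights_enabled"))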
@@ -334,10 +344,10 @@ async def bridge_pause() -> JSONResponse:
     """Pause generation to allow weight updates."""
     if engine is None:
         raise HTTPException(status_code=503, detail="Engine not initialized")

     try:
         # vLLM v1 supports pause/resume
-        if hasattr(engine, '_pause_cond'):
+        if hasattr(engine, "_pause_cond"):
             async with engine._pause_cond:
                 engine._paused = True
                 logger.info("Engine paused")
@@ -353,9 +363,9 @@ async def bridge_resume() -> JSONResponse:
     """Resume generation after weight updates."""
     if engine is None:
         raise HTTPException(status_code=503, detail="Engine not initialized")

     try:
-        if hasattr(engine, '_pause_cond'):
+        if hasattr(engine, "_pause_cond"):
             async with engine._pause_cond:
                 engine._paused = False
                 engine._pause_cond.notify_all()
@@ -372,8 +382,8 @@ async def bridge_is_paused() -> JSONResponse:
     """Check if engine is paused."""
     if engine is None:
         raise HTTPException(status_code=503, detail="Engine not initialized")

-    paused = getattr(engine, '_paused', False)
+    paused = getattr(engine, "_paused", False)
     return JSONResponse({"paused": paused})
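The intended calling pattern for these three handlers, as a hedged trainer-side sketch: pause, swap weights, resume. Route paths and methods are assumed to be POST /bridge/pause, POST /bridge/resume, and GET /bridge/is_paused, consistent with the endpoint log lines later in this diff; the weight-swap step is a placeholder.

    import requests

    BASE = "http://localhost:8000"  # assumed

    requests.post(f"{BASE}/bridge/pause", timeout=30).raise_for_status()
    assert requests.get(f"{BASE}/bridge/is_paused", timeout=10).json()["paused"]

    # ... write updated weights into the shared-memory region here ...

    requests.post(f"{BASE}/bridge/resume", timeout=30).raise_for_status()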
@@ -387,7 +397,7 @@ async def bridge_sleep() -> JSONResponse:
     """Put engine to sleep to free GPU memory."""
     if engine is None:
         raise HTTPException(status_code=503, detail="Engine not initialized")

     try:
         await engine.sleep()
         logger.info("Engine sleeping")
@@ -401,7 +411,7 @@ async def bridge_wake_up() -> JSONResponse:
     """Wake engine and reload model."""
     if engine is None:
         raise HTTPException(status_code=503, detail="Engine not initialized")

     try:
         await engine.wake_up()
         logger.info("Engine woken up")
@@ -415,7 +425,7 @@ async def bridge_is_sleeping() -> JSONResponse:
     """Check if engine is sleeping."""
     if engine is None:
         raise HTTPException(status_code=503, detail="Engine not initialized")

     sleeping = await engine.is_sleeping()
     return JSONResponse({"sleeping": sleeping})
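Likewise, a sketch of the sleep/wake cycle for releasing GPU memory between phases; POST for sleep and wake_up plus GET for is_sleeping are assumptions made by analogy with pause/resume (the decorators are not visible in this diff), and the timeouts are illustrative since wake_up reloads the model.

    import requests

    BASE = "http://localhost:8000"  # assumed

    requests.post(f"{BASE}/bridge/sleep", timeout=120).raise_for_status()
    print(requests.get(f"{BASE}/bridge/is_sleeping", timeout=10).json())

    # ... GPU memory is free for other workloads here ...

    requests.post(f"{BASE}/bridge/wake_up", timeout=600).raise_for_status()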
@@ -436,7 +446,7 @@ async def bridge_debug() -> JSONResponse:
         "num_inference_nodes": os.environ.get("NUM_INFERENCE_NODES", "unset"),
         "logdir": os.environ.get("LOGDIR", "unset"),
     }

     if engine is not None:
         try:
             debug_info["model_config"] = {
@@ -445,7 +455,7 @@
             }
         except Exception:
             pass

     return JSONResponse(debug_info)
@@ -455,10 +465,12 @@ async def list_endpoints() -> JSONResponse:
     endpoints = []
     for route in app.routes:
         if hasattr(route, "path") and hasattr(route, "methods"):
-            endpoints.append({
-                "path": route.path,
-                "methods": list(route.methods),
-            })
+            endpoints.append(
+                {
+                    "path": route.path,
+                    "methods": list(route.methods),
+                }
+            )
     return JSONResponse({"endpoints": endpoints})
@@ -472,7 +484,7 @@ async def lora_status() -> LoraStatusResponse:
     """Get LoRA adapter status."""
     log_dir = os.environ.get("LOGDIR", ".")
     available = []

     if os.path.exists(log_dir):
         for item in os.listdir(log_dir):
             item_path = os.path.join(log_dir, item)
@@ -480,7 +492,7 @@ async def lora_status() -> LoraStatusResponse:
                 os.path.join(item_path, "adapter_config.json")
             ):
                 available.append(item)

     return LoraStatusResponse(
         active_adapter=bridge_state.active_lora_path,
         load_count=bridge_state.lora_load_count,
@@ -492,19 +504,23 @@ async def lora_status() -> LoraStatusResponse:
 async def lora_load(request: LoraLoadRequest) -> JSONResponse:
     """Load a LoRA adapter."""
     if not os.path.exists(request.adapter_path):
-        raise HTTPException(status_code=404, detail=f"Adapter not found: {request.adapter_path}")
+        raise HTTPException(
+            status_code=404, detail=f"Adapter not found: {request.adapter_path}"
+        )

     with bridge_state.lock:
         bridge_state.active_lora_path = request.adapter_path
         bridge_state.lora_load_count += 1

     logger.info(f"LoRA adapter loaded: {request.adapter_path}")

-    return JSONResponse({
-        "status": "ok",
-        "adapter_path": request.adapter_path,
-        "load_count": bridge_state.lora_load_count,
-    })
+    return JSONResponse(
+        {
+            "status": "ok",
+            "adapter_path": request.adapter_path,
+            "load_count": bridge_state.lora_load_count,
+        }
+    )


 @app.post("/lora/unload")
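A hedged driver for the LoRA endpoints: the request field name matches the LoraLoadRequest.adapter_path usage above, the /lora/load route is inferred from the function name (only @app.post("/lora/unload") is visible here), and the adapter path is illustrative.

    import requests

    BASE = "http://localhost:8000"  # assumed

    r = requests.post(
        f"{BASE}/lora/load",  # inferred route for lora_load above
        json={"adapter_path": "/logs/checkpoint-100/adapter"},  # illustrative
        timeout=30,
    )
    print(r.json())  # {"status": "ok", "adapter_path": ..., "load_count": ...}

    requests.post(f"{BASE}/lora/unload", timeout=30).raise_for_status()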
@@ -513,7 +529,7 @@ async def lora_unload() -> JSONResponse:
     with bridge_state.lock:
         prev = bridge_state.active_lora_path
         bridge_state.active_lora_path = None

     logger.info(f"LoRA adapter unloaded: {prev}")
     return JSONResponse({"status": "ok", "previous_adapter": prev})
@@ -533,18 +549,21 @@ def build_app(args: Namespace) -> FastAPI:
 async def init_app(args: Namespace, llm_engine: AsyncLLM | None = None) -> FastAPI:
     """Initialize the application and vLLM engine."""
     app = build_app(args)

     global engine
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = (
-        llm_engine if llm_engine is not None
-        else AsyncLLM.from_engine_args(engine_args, usage_context=UsageContext.API_SERVER)
+        llm_engine
+        if llm_engine is not None
+        else AsyncLLM.from_engine_args(
+            engine_args, usage_context=UsageContext.API_SERVER
+        )
     )
     app.state.engine_client = engine

     # Export basic state dict info for trainers (the patched runner exports detailed info)
     _export_state_dict_info(args)

     return app
@@ -553,17 +572,20 @@ def _export_state_dict_info(args: Namespace) -> None:
     log_dir = os.environ.get("LOGDIR", ".")
     Path(log_dir).mkdir(parents=True, exist_ok=True)
     json_path = Path(log_dir) / "vllm_bridge_config.json"

     # Only write basic info if the file doesn't exist or is empty
     # The patched runner will write complete info with param_mappings
     try:
         if json_path.exists():
-            with open(json_path, 'r') as f:
+            with open(json_path, "r") as f:
                 existing = json.load(f)
-            if existing.get("param_mappings") and len(existing["param_mappings"]) > 0:
+            if (
+                existing.get("param_mappings")
+                and len(existing["param_mappings"]) > 0
+            ):
                 logger.info(f"Config already has param_mappings, not overwriting")
                 return

         info = {
             "model": getattr(args, "model", "unknown"),
             "dtype": getattr(args, "dtype", "auto"),
@@ -572,32 +594,34 @@ def _export_state_dict_info(args: Namespace) -> None:
             "param_mappings": {},
             "shared_weights_enabled": PATCHES_APPLIED,
         }

         with open(json_path, "w") as f:
             json.dump(info, f, indent=2)

         logger.info(f"Exported basic state dict info to {json_path}")
     except Exception as e:
         logger.warning(f"Failed to export state dict info: {e}")


-async def run_server(args: Namespace, llm_engine: AsyncLLM | None = None, **uvicorn_kwargs: Any) -> None:
+async def run_server(
+    args: Namespace, llm_engine: AsyncLLM | None = None, **uvicorn_kwargs: Any
+) -> None:
     """Run the vLLM API server."""
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)

     if PATCHES_APPLIED:
         logger.info("=" * 60)
         logger.info("SHARED MEMORY MODE ENABLED")
         logger.info("Weight updates from trainer will be reflected immediately!")
         logger.info("=" * 60)

     set_ulimit()
     app = await init_app(args, llm_engine)

     if engine is None:
         raise RuntimeError("No engine initialized")

     # Log available endpoints
     logger.info("=" * 60)
     logger.info("Available endpoints:")
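Based on `_export_state_dict_info` above, a sketch of reading the exported file back; the field set mirrors the `info` dict in the diff, and a patched runner later rewrites the file with non-empty param_mappings.

    import json
    import os
    from pathlib import Path

    path = Path(os.environ.get("LOGDIR", ".")) / "vllm_bridge_config.json"
    with open(path) as f:
        cfg = json.load(f)

    # Keys written by _export_state_dict_info: model, dtype,
    # param_mappings (initially {}), shared_weights_enabled.
    print(cfg["model"], cfg["dtype"], cfg["shared_weights_enabled"])
    print(len(cfg["param_mappings"]), "param mappings")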
@@ -607,22 +631,22 @@ async def run_server(args: Namespace, llm_engine: AsyncLLM | None = None, **uvicorn_kwargs: Any) -> None:
     logger.info("  POST /bridge/resume - Resume generation")
     logger.info("  GET  /lora/status - LoRA adapter status")
     logger.info("=" * 60)

     shutdown_task = await serve_http(
         app,
         sock=None,
-        enable_ssl_refresh=getattr(args, 'enable_ssl_refresh', False),
+        enable_ssl_refresh=getattr(args, "enable_ssl_refresh", False),
         host=args.host,
         port=args.port,
-        log_level=getattr(args, 'log_level', 'info'),
+        log_level=getattr(args, "log_level", "info"),
         timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
-        ssl_keyfile=getattr(args, 'ssl_keyfile', None),
-        ssl_certfile=getattr(args, 'ssl_certfile', None),
-        ssl_ca_certs=getattr(args, 'ssl_ca_certs', None),
-        ssl_cert_reqs=getattr(args, 'ssl_cert_reqs', ssl.CERT_NONE),
+        ssl_keyfile=getattr(args, "ssl_keyfile", None),
+        ssl_certfile=getattr(args, "ssl_certfile", None),
+        ssl_ca_certs=getattr(args, "ssl_ca_certs", None),
+        ssl_cert_reqs=getattr(args, "ssl_cert_reqs", ssl.CERT_NONE),
         **uvicorn_kwargs,
     )

     await shutdown_task
@@ -637,9 +661,9 @@ if __name__ == "__main__":
     parser.add_argument("--ssl-cert-reqs", type=int, default=int(ssl.CERT_NONE))
     parser.add_argument("--root-path", type=str, default=None)
     parser.add_argument("--log-level", type=str, default="info")

     # Add vLLM engine args
     parser = AsyncEngineArgs.add_cli_args(parser)

     args = parser.parse_args()
     asyncio.run(run_server(args))
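Putting the pieces together, a hedged launch sketch: AsyncEngineArgs.add_cli_args contributes the standard vLLM engine flags (such as --model), so they can be mixed with the custom flags added above. The script filename is hypothetical; substitute the real entry point.

    import os
    import subprocess
    import sys

    env = dict(os.environ)
    env["VLLM_ENABLE_SHARED_WEIGHTS"] = "1"  # triggers the early patching path
    env["LOGDIR"] = "/tmp/bridge-logs"       # destination of vllm_bridge_config.json

    subprocess.run(
        [
            sys.executable,
            "vllm_api_server.py",  # hypothetical filename for this script
            "--model", "my-model",
            "--port", "8000",
            "--log-level", "info",
        ],
        env=env,
        check=True,
    )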