weight updates async

Jai Suphavadeeprasit 2025-12-10 18:27:47 -05:00
parent b3874b658a
commit 39e94c4278


@@ -711,157 +711,273 @@ async def bridge_disable() -> JSONResponse:
     return JSONResponse({"status": "ok"})
-def _worker_export_cuda_ipc_handles(self) -> dict:
-    """
-    Worker-side function to export CUDA IPC handles.
-    This function runs INSIDE the vLLM worker process where the model lives.
-    Called via collective_rpc - 'self' is the GPU worker instance.
-    Returns:
-        Dictionary with IPC handles for all model parameters.
-    """
-    import base64
-    import pickle
-    model = self.model_runner.model
-    ipc_handles = {}
-    failed_params = []
-    for name, param in model.named_parameters():
-        try:
-            if not param.is_cuda:
-                failed_params.append(f"{name}: not on CUDA")
-                continue
-            storage = param.data.storage()
-            handle = storage._share_cuda_()
-            handle_bytes = pickle.dumps(handle)
-            handle_b64 = base64.b64encode(handle_bytes).decode('ascii')
-            ipc_handles[name] = {
-                "ipc_handle": handle_b64,
-                "shape": list(param.shape),
-                "dtype": str(param.dtype),
-                "device_index": param.device.index if param.device.index is not None else 0,
-                "storage_offset": param.storage_offset(),
-                "numel": param.numel(),
-                "stride": list(param.stride()),
-            }
-        except Exception as e:
-            failed_params.append(f"{name}: {str(e)}")
-    return {
-        "handles": ipc_handles,
-        "failed": failed_params,
-        "model_class": model.__class__.__name__,
-    }
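
For reference, each entry produced by this (removed) helper wraps a pickled, base64-encoded CUDA IPC handle plus tensor metadata. A consumer process would first undo that wrapping; a minimal sketch of just the decode step (rebuilding an actual tensor from the handle relies on private torch internals and is only noted in a comment):

import base64
import pickle

def decode_ipc_entry(entry: dict):
    """Decode one exported parameter entry back into the raw CUDA IPC handle tuple."""
    handle = pickle.loads(base64.b64decode(entry["ipc_handle"]))
    # Rebuilding a tensor from `handle` would need torch's private CUDA IPC machinery
    # and must happen on entry["device_index"]; the metadata below is what that step needs.
    meta = {k: entry[k] for k in ("shape", "dtype", "stride", "storage_offset", "numel")}
    return handle, meta
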
+# =============================================================================
+# Weight Update Endpoints (Pause/Resume for Training)
+# =============================================================================
-@app.post("/bridge/export_cuda_ipc")
-async def bridge_export_cuda_ipc() -> JSONResponse:
+@app.post("/bridge/pause")
+async def bridge_pause() -> JSONResponse:
"""
Export CUDA IPC handles for all model parameters.
Pause generation to allow weight updates.
Uses collective_rpc to execute on the worker process where the model lives.
This is the vLLM v1 way to access the model implicitly.
This is vLLM's built-in mechanism for weight updates!
Waits for in-flight requests to finish, then pauses.
IMPORTANT: Only works when both processes are on the SAME GPU.
Returns:
JSON with IPC handles, shapes, dtypes for each parameter.
Use this BEFORE updating weights from the trainer.
"""
     assert engine is not None
     try:
-        rpc_result = None
-        method_used = None
+        await engine.pause_generation(
+            wait_for_inflight_requests=True,
+            clear_cache=True,
+        )
+        logger.info("Generation paused for weight updates")
-        # Try to find collective_rpc in AsyncLLM's internals
-        # vLLM v1 exposes this on various paths depending on version
+        return JSONResponse({
+            "status": "paused",
+            "message": "Ready for weight updates. Call /bridge/resume when done.",
+        })
+    except Exception as e:
+        logger.error(f"Failed to pause generation: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+@app.post("/bridge/resume")
+async def bridge_resume() -> JSONResponse:
+    """
+    Resume generation after weight updates.
+    Call this AFTER updating weights from the trainer.
+    """
+    assert engine is not None
+    try:
+        await engine.resume_generation()
+        logger.info("Generation resumed after weight updates")
-        paths_to_try = [
-            ("engine", lambda: engine.collective_rpc(_worker_export_cuda_ipc_handles)),
-            ("engine.llm_engine", lambda: engine.llm_engine.collective_rpc(_worker_export_cuda_ipc_handles)),
-            ("engine.engine_core", lambda: engine.engine_core.collective_rpc(_worker_export_cuda_ipc_handles)),
-            ("engine.llm_engine.model_executor", lambda: engine.llm_engine.model_executor.collective_rpc(_worker_export_cuda_ipc_handles)),
-        ]
+        return JSONResponse({
+            "status": "resumed",
+            "message": "Generation resumed with updated weights.",
+        })
+    except Exception as e:
+        logger.error(f"Failed to resume generation: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+@app.get("/bridge/is_paused")
+async def bridge_is_paused() -> JSONResponse:
+    """Check if generation is currently paused."""
+    assert engine is not None
+    paused = await engine.is_paused()
+    return JSONResponse({"paused": paused})
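
Taken together, /bridge/pause, /bridge/is_paused and /bridge/resume give the trainer a simple update protocol: pause, push weights, resume. A minimal client-side sketch (the base URL and the weight-push callback are assumptions; only the three routes come from the code above):

import requests

BASE = "http://localhost:8000"  # assumed server address/port

def update_weights(push_weights_fn):
    """Pause generation, let the trainer push new weights, then resume."""
    requests.post(f"{BASE}/bridge/pause").raise_for_status()
    assert requests.get(f"{BASE}/bridge/is_paused").json()["paused"]
    try:
        push_weights_fn()  # placeholder: e.g. an NCCL broadcast or CUDA copy from the trainer
    finally:
        requests.post(f"{BASE}/bridge/resume").raise_for_status()
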
+@app.post("/bridge/sleep")
+async def bridge_sleep(level: int = 1) -> JSONResponse:
+    """
+    Put the engine to sleep to free GPU memory.
+    Level 1: Minimal sleep, fast wake up
+    Higher levels: Deeper sleep, frees more memory
+    Use for memory-constrained environments.
+    """
+    assert engine is not None
+    try:
+        await engine.sleep(level=level)
+        logger.info(f"Engine put to sleep (level {level})")
-        for name, rpc_fn in paths_to_try:
-            try:
-                logger.info(f"Trying {name}.collective_rpc...")
-                rpc_result = rpc_fn()
-                method_used = name
-                logger.info(f"SUCCESS via {name}")
-                break
-            except AttributeError as e:
-                logger.debug(f"{name} failed: {e}")
-                continue
-            except Exception as e:
-                logger.warning(f"{name} error: {e}")
-                continue
+        return JSONResponse({
+            "status": "sleeping",
+            "level": level,
+            "message": "GPU memory freed. Call /bridge/wake_up to resume.",
+        })
+    except Exception as e:
+        logger.error(f"Failed to sleep: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+@app.post("/bridge/wake_up")
+async def bridge_wake_up() -> JSONResponse:
+    """
+    Wake up the engine from sleep.
+    Reloads the model into GPU memory.
+    """
+    assert engine is not None
+    try:
+        await engine.wake_up()
+        logger.info("Engine woken up")
-        if rpc_result is None:
-            # Log engine structure for debugging
-            logger.error("collective_rpc not found. Engine structure:")
-            logger.error(f" type: {type(engine).__name__}")
-            logger.error(f" attrs: {[a for a in dir(engine) if not a.startswith('_')][:20]}")
-            raise HTTPException(
-                status_code=500,
-                detail=(
-                    "collective_rpc not available on AsyncLLM in this vLLM version. "
-                    "Options: 1) Use vllm_sync_server.py for CUDA IPC, "
-                    "2) Use multi-GPU mode (separate GPUs for trainer/vLLM), "
-                    "3) Use legacy or LoRA mode."
-                )
-            )
+        return JSONResponse({
+            "status": "awake",
+            "message": "Model reloaded into GPU memory.",
+        })
+    except Exception as e:
+        logger.error(f"Failed to wake up: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+@app.get("/bridge/is_sleeping")
+async def bridge_is_sleeping() -> JSONResponse:
+    """Check if engine is currently sleeping."""
+    assert engine is not None
+    sleeping = await engine.is_sleeping()
+    return JSONResponse({"sleeping": sleeping})
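
On a memory-constrained single-GPU setup, /bridge/sleep and /bridge/wake_up can bracket a trainer-heavy step in the same way. A rough sketch, assuming the same base URL as above and treating the training callback as a placeholder (the `level` query parameter comes from the endpoint signature):

import requests

BASE = "http://localhost:8000"  # assumed server address/port

def train_step_with_sleep(run_training_step):
    """Free vLLM's GPU memory during a trainer-heavy step, then reload the model."""
    requests.post(f"{BASE}/bridge/sleep", params={"level": 1}).raise_for_status()
    try:
        run_training_step()  # placeholder for the trainer's work
    finally:
        requests.post(f"{BASE}/bridge/wake_up").raise_for_status()
        assert not requests.get(f"{BASE}/bridge/is_sleeping").json()["sleeping"]
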
+# =============================================================================
+# RPC Endpoints (Call Worker Methods)
+# =============================================================================
+class CollectiveRPCRequest(BaseModel):
+    """Request to call a method on all workers."""
+    method: str
+    timeout: Optional[float] = None
+    args: List[Any] = []
+    kwargs: Dict[str, Any] = {}
+@app.post("/bridge/collective_rpc")
+async def bridge_collective_rpc(request: CollectiveRPCRequest) -> JSONResponse:
+    """
+    Call a method on all workers via collective RPC.
+    The method must exist on the worker class.
+    This is an advanced endpoint for custom worker operations.
+    Example worker methods:
+    - 'save_model' - Save model weights
+    - 'get_model_info' - Get model information
+    Note: The method name is passed as a STRING, not a function.
+    """
+    assert engine is not None
+    try:
+        result = await engine.collective_rpc(
+            method=request.method,
+            timeout=request.timeout,
+            args=tuple(request.args),
+            kwargs=request.kwargs if request.kwargs else None,
+        )
-        # Process result (collective_rpc returns a list, one per worker)
-        result = rpc_result[0] if isinstance(rpc_result, list) else rpc_result
-        ipc_handles = result.get("handles", {})
-        failed_params = result.get("failed", [])
-        if failed_params:
-            logger.warning(f"Could not export {len(failed_params)} parameters")
-        if len(ipc_handles) == 0:
-            raise HTTPException(status_code=500, detail="No IPC handles exported")
-        # Save to file
-        log_dir = os.environ.get("LOGDIR", ".")
-        ipc_path = Path(log_dir) / "cuda_ipc_handles.json"
-        with open(ipc_path, "w") as f:
-            json.dump({
-                "handles": ipc_handles,
-                "model_class": result.get("model_class", "unknown"),
-                "device_count": torch.cuda.device_count(),
-                "export_time": time.time(),
-                "method": method_used,
-            }, f, indent=2)
-        logger.info(f"Exported {len(ipc_handles)} CUDA IPC handles via {method_used}")
+        logger.info(f"collective_rpc({request.method}) completed")
         return JSONResponse({
             "status": "ok",
-            "num_parameters": len(ipc_handles),
-            "ipc_path": str(ipc_path),
-            "total_params": sum(info["numel"] for info in ipc_handles.values()),
-            "method": method_used,
+            "method": request.method,
+            "result": result if isinstance(result, (dict, list, str, int, float, bool, type(None))) else str(result),
         })
     except HTTPException:
         raise
     except Exception as e:
-        logger.error(f"Failed to export CUDA IPC handles: {e}")
-        import traceback
-        logger.error(traceback.format_exc())
+        logger.error(f"collective_rpc failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
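
A request to this endpoint is just a JSON body matching CollectiveRPCRequest. For illustration, calling a hypothetical save_model worker method (the method name and its kwargs are invented; the field names come from the model above, and the base URL is an assumption):

import requests

payload = {
    "method": "save_model",              # must exist on the worker class
    "timeout": 60.0,
    "args": [],
    "kwargs": {"path": "/tmp/checkpoint"},  # illustrative worker kwargs
}
resp = requests.post("http://localhost:8000/bridge/collective_rpc", json=payload)
resp.raise_for_status()
print(resp.json()["result"])             # serialized result returned by the workers
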
+@app.get("/bridge/debug")
+async def bridge_debug() -> JSONResponse:
+    """
+    Debug endpoint to inspect AsyncLLM capabilities.
+    Lists available attributes and methods on the engine.
+    """
+    assert engine is not None
+    debug_info = {
+        "engine_type": type(engine).__name__,
+        "vllm_version": VLLM_VERSION,
+        "model_config": {
+            "model": str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown",
+            "dtype": str(engine.model_config.dtype) if hasattr(engine, "model_config") else "unknown",
+        },
+        "available_methods": {},
+        "important_attributes": {},
+    }
+    # Check for important methods
+    important_methods = [
+        "pause_generation", "resume_generation", "is_paused",
+        "sleep", "wake_up", "is_sleeping",
+        "collective_rpc", "add_lora", "remove_lora", "list_loras",
+        "generate", "encode", "abort", "check_health",
+    ]
+    for method in important_methods:
+        has_method = hasattr(engine, method) and callable(getattr(engine, method))
+        debug_info["available_methods"][method] = has_method
+    # Check important attributes
+    important_attrs = [
+        "engine_core", "model_config", "vllm_config",
+        "input_processor", "output_processor", "tokenizer",
+    ]
+    for attr in important_attrs:
+        if hasattr(engine, attr):
+            attr_val = getattr(engine, attr)
+            debug_info["important_attributes"][attr] = type(attr_val).__name__
+        else:
+            debug_info["important_attributes"][attr] = None
+    return JSONResponse(debug_info)
+@app.get("/bridge/list_endpoints")
+async def bridge_list_endpoints() -> JSONResponse:
+    """
+    List all available bridge endpoints with descriptions.
+    Use this to discover what capabilities are available.
+    """
+    endpoints = {
+        "health": {
+            "GET /health": "Basic health check",
+            "GET /health_generate": "Deep health check (sends test request)",
+        },
+        "generation": {
+            "POST /generate": "Generate text (vLLM native format)",
+            "POST /v1/completions": "Generate text (OpenAI format)",
+            "POST /v1/chat/completions": "Chat completion (OpenAI format)",
+        },
+        "bridge_control": {
+            "GET /bridge/info": "Get bridge status and rendezvous info",
+            "POST /bridge/init": "Initialize weight bridge for NCCL",
+            "POST /bridge/disable": "Disable weight bridge",
+            "GET /bridge/state_dict_info": "Get model parameter info",
+        },
+        "weight_updates": {
+            "POST /bridge/pause": "⭐ Pause generation for weight updates",
+            "POST /bridge/resume": "⭐ Resume generation after weight updates",
+            "GET /bridge/is_paused": "Check if paused",
+            "POST /bridge/notify_update": "Notify server of weight update",
+        },
+        "memory_management": {
+            "POST /bridge/sleep": "Put engine to sleep (free GPU memory)",
+            "POST /bridge/wake_up": "Wake engine up (reload model)",
+            "GET /bridge/is_sleeping": "Check if sleeping",
+        },
+        "lora_adapters": {
+            "GET /lora/status": "Get LoRA status",
+            "POST /lora/load": "Load LoRA adapter",
+            "POST /lora/unload": "Unload LoRA adapter",
+        },
+        "advanced": {
+            "POST /bridge/collective_rpc": "Call method on workers",
+            "GET /bridge/debug": "Debug engine structure",
+            "GET /bridge/list_endpoints": "This endpoint",
+        },
+    }
+    return JSONResponse(endpoints)
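
For quick discovery against a running server, the two introspection routes can be dumped directly; a small sketch assuming a local default port:

import json
import requests

BASE = "http://localhost:8000"  # assumed server address/port
print(json.dumps(requests.get(f"{BASE}/bridge/list_endpoints").json(), indent=2))
print(json.dumps(requests.get(f"{BASE}/bridge/debug").json()["available_methods"], indent=2))
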
 # =============================================================================
 # LoRA Endpoints (for adapter hot-swapping)
 # =============================================================================
@@ -1045,14 +1161,28 @@ async def run_server(
     assert engine is not None
     # Log bridge endpoints
     logger.info("=" * 60)
     logger.info("Bridge endpoints available:")
-    logger.info(" GET /bridge/info - Get bridge status")
-    logger.info(" POST /bridge/init - Initialize weight bridge")
-    logger.info(" POST /bridge/notify_update - Notify of weight update")
-    logger.info(" GET /bridge/state_dict_info - Get model parameters")
-    logger.info(" GET /lora/status - Get LoRA adapter status")
-    logger.info(" POST /lora/load - Load LoRA adapter")
-    logger.info(" POST /lora/unload - Unload LoRA adapter")
+    logger.info("-" * 60)
+    logger.info("Weight Updates (use these for training!):")
+    logger.info(" POST /bridge/pause - Pause generation for weight updates")
+    logger.info(" POST /bridge/resume - Resume after updating weights")
+    logger.info(" GET /bridge/is_paused - Check pause state")
+    logger.info("-" * 60)
+    logger.info("Memory Management:")
+    logger.info(" POST /bridge/sleep - Free GPU memory")
+    logger.info(" POST /bridge/wake_up - Reload model")
+    logger.info("-" * 60)
+    logger.info("LoRA Adapters:")
+    logger.info(" GET /lora/status - Get adapter status")
+    logger.info(" POST /lora/load - Load adapter")
+    logger.info(" POST /lora/unload - Unload adapter")
+    logger.info("-" * 60)
+    logger.info("Debug:")
+    logger.info(" GET /bridge/debug - Inspect engine")
+    logger.info(" GET /bridge/list_endpoints - List all endpoints")
+    logger.info(" POST /bridge/collective_rpc - Call worker methods")
     logger.info("=" * 60)
     shutdown_task = await serve_http(
         app,