diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py
index 4a126ed2..5d8caf0c 100644
--- a/example_trainer/vllm_api_server.py
+++ b/example_trainer/vllm_api_server.py
@@ -711,157 +711,273 @@ async def bridge_disable() -> JSONResponse:
     return JSONResponse({"status": "ok"})


-def _worker_export_cuda_ipc_handles(self) -> dict:
-    """
-    Worker-side function to export CUDA IPC handles.
-
-    This function runs INSIDE the vLLM worker process where the model lives.
-    Called via collective_rpc - 'self' is the GPU worker instance.
-
-    Returns:
-        Dictionary with IPC handles for all model parameters.
-    """
-    import base64
-    import pickle
-
-    model = self.model_runner.model
-
-    ipc_handles = {}
-    failed_params = []
-
-    for name, param in model.named_parameters():
-        try:
-            if not param.is_cuda:
-                failed_params.append(f"{name}: not on CUDA")
-                continue
-
-            storage = param.data.storage()
-            handle = storage._share_cuda_()
-
-            handle_bytes = pickle.dumps(handle)
-            handle_b64 = base64.b64encode(handle_bytes).decode('ascii')
-
-            ipc_handles[name] = {
-                "ipc_handle": handle_b64,
-                "shape": list(param.shape),
-                "dtype": str(param.dtype),
-                "device_index": param.device.index if param.device.index is not None else 0,
-                "storage_offset": param.storage_offset(),
-                "numel": param.numel(),
-                "stride": list(param.stride()),
-            }
-        except Exception as e:
-            failed_params.append(f"{name}: {str(e)}")
-
-    return {
-        "handles": ipc_handles,
-        "failed": failed_params,
-        "model_class": model.__class__.__name__,
-    }
+# =============================================================================
+# Weight Update Endpoints (Pause/Resume for Training)
+# =============================================================================


-@app.post("/bridge/export_cuda_ipc")
-async def bridge_export_cuda_ipc() -> JSONResponse:
+@app.post("/bridge/pause")
+async def bridge_pause() -> JSONResponse:
     """
-    Export CUDA IPC handles for all model parameters.
+    Pause generation to allow weight updates.

-    Uses collective_rpc to execute on the worker process where the model lives.
-    This is the vLLM v1 way to access the model implicitly.
+    This uses vLLM's built-in pause mechanism for weight updates.
+    It waits for in-flight requests to finish, then pauses.

-    IMPORTANT: Only works when both processes are on the SAME GPU.
-
-    Returns:
-        JSON with IPC handles, shapes, dtypes for each parameter.
+    Use this BEFORE updating weights from the trainer.
     """
     assert engine is not None

     try:
-        rpc_result = None
-        method_used = None
+        await engine.pause_generation(
+            wait_for_inflight_requests=True,
+            clear_cache=True,
+        )
+        logger.info("Generation paused for weight updates")

-        # Try to find collective_rpc in AsyncLLM's internals
-        # vLLM v1 exposes this on various paths depending on version
+        return JSONResponse({
+            "status": "paused",
+            "message": "Ready for weight updates. Call /bridge/resume when done.",
+        })
+    except Exception as e:
+        logger.error(f"Failed to pause generation: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/bridge/resume")
+async def bridge_resume() -> JSONResponse:
+    """
+    Resume generation after weight updates.
+
+    Call this AFTER updating weights from the trainer.
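+
+    A minimal trainer-side sketch of the full cycle (illustrative only:
+    VLLM_URL and broadcast_updated_weights are placeholders for your setup):
+
+        requests.post(f"{VLLM_URL}/bridge/pause")   # drain and pause generation
+        broadcast_updated_weights()                 # e.g. NCCL via /bridge/init
+        requests.post(f"{VLLM_URL}/bridge/resume")  # serve with updated weights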
+ """ + assert engine is not None + + try: + await engine.resume_generation() + logger.info("Generation resumed after weight updates") - paths_to_try = [ - ("engine", lambda: engine.collective_rpc(_worker_export_cuda_ipc_handles)), - ("engine.llm_engine", lambda: engine.llm_engine.collective_rpc(_worker_export_cuda_ipc_handles)), - ("engine.engine_core", lambda: engine.engine_core.collective_rpc(_worker_export_cuda_ipc_handles)), - ("engine.llm_engine.model_executor", lambda: engine.llm_engine.model_executor.collective_rpc(_worker_export_cuda_ipc_handles)), - ] + return JSONResponse({ + "status": "resumed", + "message": "Generation resumed with updated weights.", + }) + except Exception as e: + logger.error(f"Failed to resume generation: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/bridge/is_paused") +async def bridge_is_paused() -> JSONResponse: + """Check if generation is currently paused.""" + assert engine is not None + + paused = await engine.is_paused() + return JSONResponse({"paused": paused}) + + +@app.post("/bridge/sleep") +async def bridge_sleep(level: int = 1) -> JSONResponse: + """ + Put the engine to sleep to free GPU memory. + + Level 1: Minimal sleep, fast wake up + Higher levels: Deeper sleep, frees more memory + + Use for memory-constrained environments. + """ + assert engine is not None + + try: + await engine.sleep(level=level) + logger.info(f"Engine put to sleep (level {level})") - for name, rpc_fn in paths_to_try: - try: - logger.info(f"Trying {name}.collective_rpc...") - rpc_result = rpc_fn() - method_used = name - logger.info(f"SUCCESS via {name}") - break - except AttributeError as e: - logger.debug(f"{name} failed: {e}") - continue - except Exception as e: - logger.warning(f"{name} error: {e}") - continue + return JSONResponse({ + "status": "sleeping", + "level": level, + "message": "GPU memory freed. Call /bridge/wake_up to resume.", + }) + except Exception as e: + logger.error(f"Failed to sleep: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/bridge/wake_up") +async def bridge_wake_up() -> JSONResponse: + """ + Wake up the engine from sleep. + + Reloads the model into GPU memory. + """ + assert engine is not None + + try: + await engine.wake_up() + logger.info("Engine woken up") - if rpc_result is None: - # Log engine structure for debugging - logger.error("collective_rpc not found. Engine structure:") - logger.error(f" type: {type(engine).__name__}") - logger.error(f" attrs: {[a for a in dir(engine) if not a.startswith('_')][:20]}") - - raise HTTPException( - status_code=500, - detail=( - "collective_rpc not available on AsyncLLM in this vLLM version. " - "Options: 1) Use vllm_sync_server.py for CUDA IPC, " - "2) Use multi-GPU mode (separate GPUs for trainer/vLLM), " - "3) Use legacy or LoRA mode." 
-                )
-            )
+        return JSONResponse({
+            "status": "awake",
+            "message": "Model reloaded into GPU memory.",
+        })
+    except Exception as e:
+        logger.error(f"Failed to wake up: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/bridge/is_sleeping")
+async def bridge_is_sleeping() -> JSONResponse:
+    """Check if engine is currently sleeping."""
+    assert engine is not None
+
+    sleeping = await engine.is_sleeping()
+    return JSONResponse({"sleeping": sleeping})
+
+
+# =============================================================================
+# RPC Endpoints (Call Worker Methods)
+# =============================================================================
+
+
+class CollectiveRPCRequest(BaseModel):
+    """Request to call a method on all workers."""
+    method: str
+    timeout: Optional[float] = None
+    args: List[Any] = []
+    kwargs: Dict[str, Any] = {}
+
+
+@app.post("/bridge/collective_rpc")
+async def bridge_collective_rpc(request: CollectiveRPCRequest) -> JSONResponse:
+    """
+    Call a method on all workers via collective RPC.
+
+    The method must exist on the worker class.
+    This is an advanced endpoint for custom worker operations.
+
+    Example worker methods:
+    - 'save_model' - Save model weights
+    - 'get_model_info' - Get model information
+
+    Note: The method name is passed as a STRING, not a function.
+    """
+    assert engine is not None
+
+    try:
+        result = await engine.collective_rpc(
+            method=request.method,
+            timeout=request.timeout,
+            args=tuple(request.args),
+            kwargs=request.kwargs if request.kwargs else None,
+        )

-        # Process result (collective_rpc returns a list, one per worker)
-        result = rpc_result[0] if isinstance(rpc_result, list) else rpc_result
-        ipc_handles = result.get("handles", {})
-        failed_params = result.get("failed", [])
-
-        if failed_params:
-            logger.warning(f"Could not export {len(failed_params)} parameters")
-
-        if len(ipc_handles) == 0:
-            raise HTTPException(status_code=500, detail="No IPC handles exported")
-
-        # Save to file
-        log_dir = os.environ.get("LOGDIR", ".")
-        ipc_path = Path(log_dir) / "cuda_ipc_handles.json"
-
-        with open(ipc_path, "w") as f:
-            json.dump({
-                "handles": ipc_handles,
-                "model_class": result.get("model_class", "unknown"),
-                "device_count": torch.cuda.device_count(),
-                "export_time": time.time(),
-                "method": method_used,
-            }, f, indent=2)
-
-        logger.info(f"Exported {len(ipc_handles)} CUDA IPC handles via {method_used}")
+        logger.info(f"collective_rpc({request.method}) completed")

         return JSONResponse({
             "status": "ok",
-            "num_parameters": len(ipc_handles),
-            "ipc_path": str(ipc_path),
-            "total_params": sum(info["numel"] for info in ipc_handles.values()),
-            "method": method_used,
+            "method": request.method,
+            "result": result if isinstance(result, (dict, list, str, int, float, bool, type(None))) else str(result),
         })
-
-    except HTTPException:
-        raise
     except Exception as e:
-        logger.error(f"Failed to export CUDA IPC handles: {e}")
-        import traceback
-        logger.error(traceback.format_exc())
+        logger.error(f"collective_rpc failed: {e}")
         raise HTTPException(status_code=500, detail=str(e))


+@app.get("/bridge/debug")
+async def bridge_debug() -> JSONResponse:
+    """
+    Debug endpoint to inspect AsyncLLM capabilities.
+
+    Lists available attributes and methods on the engine.
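+
+    Example response shape (illustrative; the exact attribute types vary
+    by vLLM version and engine configuration):
+
+        {"engine_type": "AsyncLLM",
+         "available_methods": {"pause_generation": true, "sleep": true},
+         "important_attributes": {"engine_core": "EngineCoreClient"}}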
+ """ + assert engine is not None + + debug_info = { + "engine_type": type(engine).__name__, + "vllm_version": VLLM_VERSION, + "model_config": { + "model": str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown", + "dtype": str(engine.model_config.dtype) if hasattr(engine, "model_config") else "unknown", + }, + "available_methods": {}, + "important_attributes": {}, + } + + # Check for important methods + important_methods = [ + "pause_generation", "resume_generation", "is_paused", + "sleep", "wake_up", "is_sleeping", + "collective_rpc", "add_lora", "remove_lora", "list_loras", + "generate", "encode", "abort", "check_health", + ] + + for method in important_methods: + has_method = hasattr(engine, method) and callable(getattr(engine, method)) + debug_info["available_methods"][method] = has_method + + # Check important attributes + important_attrs = [ + "engine_core", "model_config", "vllm_config", + "input_processor", "output_processor", "tokenizer", + ] + + for attr in important_attrs: + if hasattr(engine, attr): + attr_val = getattr(engine, attr) + debug_info["important_attributes"][attr] = type(attr_val).__name__ + else: + debug_info["important_attributes"][attr] = None + + return JSONResponse(debug_info) + + +@app.get("/bridge/list_endpoints") +async def bridge_list_endpoints() -> JSONResponse: + """ + List all available bridge endpoints with descriptions. + + Use this to discover what capabilities are available. + """ + endpoints = { + "health": { + "GET /health": "Basic health check", + "GET /health_generate": "Deep health check (sends test request)", + }, + "generation": { + "POST /generate": "Generate text (vLLM native format)", + "POST /v1/completions": "Generate text (OpenAI format)", + "POST /v1/chat/completions": "Chat completion (OpenAI format)", + }, + "bridge_control": { + "GET /bridge/info": "Get bridge status and rendezvous info", + "POST /bridge/init": "Initialize weight bridge for NCCL", + "POST /bridge/disable": "Disable weight bridge", + "GET /bridge/state_dict_info": "Get model parameter info", + }, + "weight_updates": { + "POST /bridge/pause": "⭐ Pause generation for weight updates", + "POST /bridge/resume": "⭐ Resume generation after weight updates", + "GET /bridge/is_paused": "Check if paused", + "POST /bridge/notify_update": "Notify server of weight update", + }, + "memory_management": { + "POST /bridge/sleep": "Put engine to sleep (free GPU memory)", + "POST /bridge/wake_up": "Wake engine up (reload model)", + "GET /bridge/is_sleeping": "Check if sleeping", + }, + "lora_adapters": { + "GET /lora/status": "Get LoRA status", + "POST /lora/load": "Load LoRA adapter", + "POST /lora/unload": "Unload LoRA adapter", + }, + "advanced": { + "POST /bridge/collective_rpc": "Call method on workers", + "GET /bridge/debug": "Debug engine structure", + "GET /bridge/list_endpoints": "This endpoint", + }, + } + + return JSONResponse(endpoints) + + # ============================================================================= # LoRA Endpoints (for adapter hot-swapping) # ============================================================================= @@ -1045,14 +1161,28 @@ async def run_server( assert engine is not None # Log bridge endpoints + logger.info("=" * 60) logger.info("Bridge endpoints available:") - logger.info(" GET /bridge/info - Get bridge status") - logger.info(" POST /bridge/init - Initialize weight bridge") - logger.info(" POST /bridge/notify_update - Notify of weight update") - logger.info(" GET /bridge/state_dict_info - Get model 
parameters") - logger.info(" GET /lora/status - Get LoRA adapter status") - logger.info(" POST /lora/load - Load LoRA adapter") - logger.info(" POST /lora/unload - Unload LoRA adapter") + logger.info("-" * 60) + logger.info("Weight Updates (use these for training!):") + logger.info(" POST /bridge/pause - Pause generation for weight updates") + logger.info(" POST /bridge/resume - Resume after updating weights") + logger.info(" GET /bridge/is_paused - Check pause state") + logger.info("-" * 60) + logger.info("Memory Management:") + logger.info(" POST /bridge/sleep - Free GPU memory") + logger.info(" POST /bridge/wake_up - Reload model") + logger.info("-" * 60) + logger.info("LoRA Adapters:") + logger.info(" GET /lora/status - Get adapter status") + logger.info(" POST /lora/load - Load adapter") + logger.info(" POST /lora/unload - Unload adapter") + logger.info("-" * 60) + logger.info("Debug:") + logger.info(" GET /bridge/debug - Inspect engine") + logger.info(" GET /bridge/list_endpoints - List all endpoints") + logger.info(" POST /bridge/collective_rpc - Call worker methods") + logger.info("=" * 60) shutdown_task = await serve_http( app,