diff --git a/example_trainer/requirements.txt b/example_trainer/requirements.txt index f75a11bb..0313a941 100644 --- a/example_trainer/requirements.txt +++ b/example_trainer/requirements.txt @@ -4,3 +4,5 @@ transformers datasets accelerate peft +requests +wandb diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py index 5d8caf0c..89d181c3 100644 --- a/example_trainer/vllm_api_server.py +++ b/example_trainer/vllm_api_server.py @@ -38,6 +38,9 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional +import base64 +import pickle + import torch import vllm.envs as envs from fastapi import FastAPI, HTTPException, Request @@ -52,6 +55,14 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import random_uuid from vllm.v1.engine.async_llm import AsyncLLM +# Import sync LLM for collective_rpc with function support +try: + from vllm import LLM as SyncLLM + SYNC_LLM_AVAILABLE = True +except ImportError: + SYNC_LLM_AVAILABLE = False + SyncLLM = None + try: from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.system_utils import set_ulimit @@ -69,6 +80,7 @@ logger = init_logger("vllm.entrypoints.api_server") app = FastAPI() engine: Optional[AsyncLLM] = None +sync_engine: Optional["SyncLLM"] = None # For collective_rpc with functions @dataclass @@ -89,6 +101,15 @@ class BridgeState: bridge_state = BridgeState() +def get_engine(): + """Get the active engine (async or sync).""" + if engine is not None: + return engine + if sync_engine is not None: + return sync_engine + raise HTTPException(status_code=503, detail="No engine available") + + # ============================================================================= # Pydantic Models for API # ============================================================================= @@ -156,17 +177,30 @@ async def health_generate() -> Response: This sends a minimal request through the full inference pipeline to verify the model is loaded and functioning. 
""" - assert engine is not None - sampling_params = SamplingParams() - request_id = random_uuid() - results_generator = engine.generate( - {"prompt_token_ids": [0]}, sampling_params, request_id - ) - try: - async for request_output in results_generator: - final_output = request_output # type: RequestOutput # noqa: F841 - except asyncio.CancelledError: - return Response(status_code=499) + sampling_params = SamplingParams(max_tokens=1) + + if engine is not None: + # Async engine path + request_id = random_uuid() + results_generator = engine.generate( + {"prompt_token_ids": [0]}, sampling_params, request_id + ) + try: + async for request_output in results_generator: + final_output = request_output # type: RequestOutput # noqa: F841 + except asyncio.CancelledError: + return Response(status_code=499) + elif sync_engine is not None: + # Sync engine path (CUDA IPC mode) + import concurrent.futures + def _sync_health_check(): + return sync_engine.generate(["Hello"], sampling_params) + loop = asyncio.get_event_loop() + with concurrent.futures.ThreadPoolExecutor() as pool: + await loop.run_in_executor(pool, _sync_health_check) + else: + return Response(status_code=503) + return Response(status_code=200) @@ -223,13 +257,70 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: if logprobs_val is not None: request_dict["logprobs"] = max(1, logprobs_val) # At least 1 - request_dict["output_kind"] = RequestOutputKind.FINAL_ONLY sampling_params = SamplingParams(**request_dict) request_id = random_uuid() + + # Handle both async engine (standard) and sync engine (CUDA IPC mode) + if engine is not None: + # Standard async mode + sampling_params.output_kind = RequestOutputKind.FINAL_ONLY + results_generator = engine.generate(prompt, sampling_params, request_id) + elif sync_engine is not None: + # CUDA IPC mode: use sync engine with thread pool + # Sync LLM doesn't support streaming, so disable it + if stream: + logger.warning("Streaming not supported in CUDA IPC mode, using non-streaming") + stream = False + + # Run sync generation in thread pool + import concurrent.futures + def _sync_generate(): + return sync_engine.generate([prompt], sampling_params) + + loop = asyncio.get_event_loop() + with concurrent.futures.ThreadPoolExecutor() as pool: + outputs = await loop.run_in_executor(pool, _sync_generate) + + # Convert to match async output format + if outputs: + final_output = outputs[0] + prompt_text = final_output.prompt or ( + sync_engine.get_tokenizer().decode(final_output.prompt_token_ids) + if final_output.prompt_token_ids else "" + ) + text_outputs = [output.text for output in final_output.outputs] + finish_reasons = [output.finish_reason for output in final_output.outputs] + ret = {"text": text_outputs, "prompt": prompt_text, "finish_reasons": finish_reasons} + + # Include logprobs if requested + if sampling_params.logprobs is not None: + output_logprobs = [] + for x in final_output.outputs: + if x.logprobs: + seq_logprobs = [ + [{str(key): value.logprob for key, value in logprob.items()}] + for logprob in x.logprobs + ] + else: + seq_logprobs = [] + output_logprobs.append(seq_logprobs) + + prompt_token_ids = final_output.prompt_token_ids + output_token_ids = [list(x.token_ids) for x in final_output.outputs] + ret["logprobs"] = output_logprobs + ret["prompt_token_ids"] = list(prompt_token_ids) if prompt_token_ids else [] + ret["token_ids"] = output_token_ids + + return JSONResponse(ret) + else: + return JSONResponse({"error": "No output generated"}, status_code=500) + else: + 
raise HTTPException(status_code=503, detail="No engine available") - assert engine is not None - results_generator = engine.generate(prompt, sampling_params, request_id) - + # ========================================================================= + # Async engine path (standard mode) - streaming and non-streaming + # ========================================================================= + # Streaming: yield results as theyre generated async def stream_results() -> AsyncGenerator[bytes, None]: async for request_output in results_generator: @@ -251,6 +342,7 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: return Response(status_code=499) assert final_output is not None + assert engine is not None prompt_text = final_output.prompt or engine.tokenizer.decode( final_output.prompt_token_ids ) @@ -346,19 +438,31 @@ async def openai_completions(request: Request) -> Response: sampling_kwargs["stop"] = stop sampling_params = SamplingParams(**sampling_kwargs) - sampling_params.output_kind = RequestOutputKind.FINAL_ONLY request_id = random_uuid() - assert engine is not None - results_generator = engine.generate(prompt, sampling_params, request_id) + # Handle both async and sync engines + if engine is not None: + sampling_params.output_kind = RequestOutputKind.FINAL_ONLY + results_generator = engine.generate(prompt, sampling_params, request_id) - # Non-streaming response - final_output = None - try: - async for request_output in results_generator: - final_output = request_output - except asyncio.CancelledError: - return Response(status_code=499) + # Non-streaming response + final_output = None + try: + async for request_output in results_generator: + final_output = request_output + except asyncio.CancelledError: + return Response(status_code=499) + elif sync_engine is not None: + # CUDA IPC mode: use sync engine + import concurrent.futures + def _sync_generate(): + return sync_engine.generate([prompt], sampling_params) + loop = asyncio.get_event_loop() + with concurrent.futures.ThreadPoolExecutor() as pool: + outputs = await loop.run_in_executor(pool, _sync_generate) + final_output = outputs[0] if outputs else None + else: + raise HTTPException(status_code=503, detail="No engine available") if final_output is None: return JSONResponse( @@ -446,11 +550,14 @@ async def openai_chat_completions(request: Request) -> Response: stop = request_dict.get("stop") # Convert messages to prompt using chat template - assert engine is not None + active_engine = get_engine() # Try to use the tokenizer's chat template try: - tokenizer = engine.tokenizer.tokenizer + if engine is not None: + tokenizer = engine.tokenizer.tokenizer + else: + tokenizer = sync_engine.get_tokenizer() if hasattr(tokenizer, "apply_chat_template"): prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True @@ -484,18 +591,31 @@ async def openai_chat_completions(request: Request) -> Response: sampling_kwargs["stop"] = stop sampling_params = SamplingParams(**sampling_kwargs) - sampling_params.output_kind = RequestOutputKind.FINAL_ONLY request_id = random_uuid() - results_generator = engine.generate(prompt, sampling_params, request_id) + # Handle both async and sync engines + if engine is not None: + sampling_params.output_kind = RequestOutputKind.FINAL_ONLY + results_generator = engine.generate(prompt, sampling_params, request_id) - # Non-streaming response - final_output = None - try: - async for request_output in results_generator: - final_output = request_output - except 
asyncio.CancelledError: - return Response(status_code=499) + # Non-streaming response + final_output = None + try: + async for request_output in results_generator: + final_output = request_output + except asyncio.CancelledError: + return Response(status_code=499) + elif sync_engine is not None: + # CUDA IPC mode: use sync engine + import concurrent.futures + def _sync_generate(): + return sync_engine.generate([prompt], sampling_params) + loop = asyncio.get_event_loop() + with concurrent.futures.ThreadPoolExecutor() as pool: + outputs = await loop.run_in_executor(pool, _sync_generate) + final_output = outputs[0] if outputs else None + else: + raise HTTPException(status_code=503, detail="No engine available") if final_output is None: return JSONResponse( @@ -542,9 +662,14 @@ async def list_models() -> JSONResponse: Returns the currently loaded model. """ - assert engine is not None + active_engine = get_engine() - model_name = str(engine.engine.model_config.model) if hasattr(engine, "engine") else "unknown" + if engine is not None: + model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown" + elif sync_engine is not None: + model_name = str(sync_engine.llm_engine.model_config.model) if hasattr(sync_engine, "llm_engine") else "unknown" + else: + model_name = "unknown" return JSONResponse({ "object": "list", @@ -567,9 +692,14 @@ async def get_model(model_id: str) -> JSONResponse: """ Get model info (OpenAI-compatible). """ - assert engine is not None + active_engine = get_engine() - model_name = str(engine.engine.model_config.model) if hasattr(engine, "engine") else "unknown" + if engine is not None: + model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown" + elif sync_engine is not None: + model_name = str(sync_engine.llm_engine.model_config.model) if hasattr(sync_engine, "llm_engine") else "unknown" + else: + model_name = "unknown" return JSONResponse({ "id": model_name, @@ -595,16 +725,25 @@ async def bridge_info() -> BridgeInfoResponse: Trainers call this to discover how to connect to the weight-sharing process group. Returns connection details and current sync state. """ - assert engine is not None + active_engine = get_engine() + + if engine is not None: + model_name = str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown" + device = "unknown" # Can't easily get device from AsyncLLM + elif sync_engine is not None: + model_name = str(sync_engine.llm_engine.model_config.model) if hasattr(sync_engine, "llm_engine") else "unknown" + device = "cuda" # Sync engine is always on CUDA for IPC + else: + model_name = "unknown" + device = "unknown" return BridgeInfoResponse( enabled=bridge_state.enabled, update_count=bridge_state.update_count, last_update_time=bridge_state.last_update_time, rendezvous_info=bridge_state.rendezvous_info, - model_name=str(engine.engine.model_config.model), - device=str(next(iter(engine.engine.model_executor.driver_worker.model_runner.model.parameters())).device) - if hasattr(engine, "engine") else "unknown", + model_name=model_name, + device=device, ) @@ -669,13 +808,25 @@ async def bridge_state_dict_info() -> JSONResponse: Returns parameter names, shapes, and dtypes so trainers can properly map their tensors to the inference model's parameters. 
""" - assert engine is not None + active_engine = get_engine() try: - # Access the underlying model - model = engine.engine.model_executor.driver_worker.model_runner.model + # Access the underlying model based on engine type + if sync_engine is not None: + # CUDA IPC mode: can access model directly + model = sync_engine.llm_engine.model_executor.driver_worker.model_runner.model + elif engine is not None: + # Async mode: model is in subprocess, can't access directly + return JSONResponse({ + "status": "unavailable", + "message": "Model state dict not accessible in async mode. Use CUDA IPC mode (--enable-cuda-ipc) for direct access.", + "num_parameters": 0, + "parameters": {}, + }) + else: + raise HTTPException(status_code=503, detail="No engine available") + state_dict_info = {} - for name, param in model.named_parameters(): state_dict_info[name] = { "shape": list(param.shape), @@ -725,8 +876,16 @@ async def bridge_pause() -> JSONResponse: Waits for in-flight requests to finish, then pauses. Use this BEFORE updating weights from the trainer. + + NOTE: Only available with AsyncLLM (not CUDA IPC mode). """ - assert engine is not None + if engine is None: + if sync_engine is not None: + return JSONResponse({ + "status": "not_supported", + "message": "Pause/resume not supported in CUDA IPC mode. Weights are shared directly.", + }) + raise HTTPException(status_code=503, detail="No engine available") try: await engine.pause_generation( @@ -750,8 +909,16 @@ async def bridge_resume() -> JSONResponse: Resume generation after weight updates. Call this AFTER updating weights from the trainer. + + NOTE: Only available with AsyncLLM (not CUDA IPC mode). """ - assert engine is not None + if engine is None: + if sync_engine is not None: + return JSONResponse({ + "status": "not_supported", + "message": "Pause/resume not supported in CUDA IPC mode. Weights are shared directly.", + }) + raise HTTPException(status_code=503, detail="No engine available") try: await engine.resume_generation() @@ -769,7 +936,10 @@ async def bridge_resume() -> JSONResponse: @app.get("/bridge/is_paused") async def bridge_is_paused() -> JSONResponse: """Check if generation is currently paused.""" - assert engine is not None + if engine is None: + if sync_engine is not None: + return JSONResponse({"paused": False, "mode": "cuda_ipc"}) + raise HTTPException(status_code=503, detail="No engine available") paused = await engine.is_paused() return JSONResponse({"paused": paused}) @@ -784,8 +954,16 @@ async def bridge_sleep(level: int = 1) -> JSONResponse: Higher levels: Deeper sleep, frees more memory Use for memory-constrained environments. + + NOTE: Only available with AsyncLLM (not CUDA IPC mode). """ - assert engine is not None + if engine is None: + if sync_engine is not None: + return JSONResponse({ + "status": "not_supported", + "message": "Sleep/wake not supported in CUDA IPC mode.", + }) + raise HTTPException(status_code=503, detail="No engine available") try: await engine.sleep(level=level) @@ -807,8 +985,16 @@ async def bridge_wake_up() -> JSONResponse: Wake up the engine from sleep. Reloads the model into GPU memory. + + NOTE: Only available with AsyncLLM (not CUDA IPC mode). 
""" - assert engine is not None + if engine is None: + if sync_engine is not None: + return JSONResponse({ + "status": "not_supported", + "message": "Sleep/wake not supported in CUDA IPC mode.", + }) + raise HTTPException(status_code=503, detail="No engine available") try: await engine.wake_up() @@ -826,7 +1012,10 @@ async def bridge_wake_up() -> JSONResponse: @app.get("/bridge/is_sleeping") async def bridge_is_sleeping() -> JSONResponse: """Check if engine is currently sleeping.""" - assert engine is not None + if engine is None: + if sync_engine is not None: + return JSONResponse({"sleeping": False, "mode": "cuda_ipc"}) + raise HTTPException(status_code=503, detail="No engine available") sleeping = await engine.is_sleeping() return JSONResponse({"sleeping": sleeping}) @@ -857,9 +1046,16 @@ async def bridge_collective_rpc(request: CollectiveRPCRequest) -> JSONResponse: - 'save_model' - Save model weights - 'get_model_info' - Get model information - Note: The method name is passed as a STRING, not a function. + Note: For AsyncLLM, the method name is passed as a STRING. + For sync LLM (CUDA IPC mode), use /bridge/export_cuda_ipc instead. """ - assert engine is not None + if engine is None: + if sync_engine is not None: + return JSONResponse({ + "status": "not_supported", + "message": "Use /bridge/export_cuda_ipc for sync LLM collective operations.", + }) + raise HTTPException(status_code=503, detail="No engine available") try: result = await engine.collective_rpc( @@ -881,26 +1077,209 @@ async def bridge_collective_rpc(request: CollectiveRPCRequest) -> JSONResponse: raise HTTPException(status_code=500, detail=str(e)) +# ============================================================================= +# CUDA IPC Export (True Shared Memory) +# ============================================================================= + + +def _export_cuda_ipc_handles_fn(worker_self) -> dict: + """ + Worker-side function to export CUDA IPC handles. + + This function runs INSIDE the vLLM worker process where the model lives. + The first argument 'worker_self' is the GPU worker instance. + + Returns: + Dictionary with IPC handles for all model parameters. + """ + model = worker_self.model_runner.model + + ipc_handles = {} + failed_params = [] + + for name, param in model.named_parameters(): + try: + if not param.is_cuda: + failed_params.append(f"{name}: not on CUDA") + continue + + # Get the underlying storage and create IPC handle + storage = param.data.storage() + handle = storage._share_cuda_() + + # Serialize the handle + handle_bytes = pickle.dumps(handle) + handle_b64 = base64.b64encode(handle_bytes).decode('ascii') + + ipc_handles[name] = { + "ipc_handle": handle_b64, + "shape": list(param.shape), + "dtype": str(param.dtype), + "device_index": param.device.index if param.device.index is not None else 0, + "storage_offset": param.storage_offset(), + "numel": param.numel(), + "stride": list(param.stride()), + } + except Exception as e: + failed_params.append(f"{name}: {str(e)}") + + return { + "handles": ipc_handles, + "failed": failed_params, + "model_class": model.__class__.__name__, + "num_params": len(list(model.parameters())), + } + + +@app.post("/bridge/export_cuda_ipc") +async def bridge_export_cuda_ipc() -> JSONResponse: + """ + Export CUDA IPC handles for all model parameters. + + This enables TRUE shared memory between vLLM and the trainer! + Both processes can access the SAME GPU tensors. + + Uses sync LLM's collective_rpc which accepts functions. 
+ + REQUIREMENTS: + - Both processes must be on the SAME GPU + - vLLM must be started with --enable-cuda-ipc flag + + Returns: + JSON with path to IPC handles file and parameter count. + """ + global sync_engine + + if sync_engine is None: + raise HTTPException( + status_code=503, + detail=( + "Sync LLM not initialized. Start server with --enable-cuda-ipc flag. " + "Note: CUDA IPC requires sync LLM which may reduce throughput." + ) + ) + + try: + # Use sync LLM's collective_rpc with a FUNCTION (not a string!) + # This is the key difference from AsyncLLM + logger.info("Calling collective_rpc with function to export IPC handles...") + + # Run in thread pool to avoid blocking + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit( + sync_engine.collective_rpc, + _export_cuda_ipc_handles_fn + ) + results = future.result(timeout=60) + + # collective_rpc returns a list (one result per worker) + result = results[0] if results else {} + ipc_handles = result.get("handles", {}) + failed_params = result.get("failed", []) + + if failed_params: + logger.warning(f"Could not export {len(failed_params)} parameters: {failed_params[:5]}...") + + if len(ipc_handles) == 0: + raise HTTPException(status_code=500, detail="No IPC handles exported") + + # Save to file for trainer to read + log_dir = os.environ.get("LOGDIR", ".") + ipc_path = Path(log_dir) / "cuda_ipc_handles.json" + + with open(ipc_path, "w") as f: + json.dump({ + "handles": ipc_handles, + "model_class": result.get("model_class", "unknown"), + "num_params": result.get("num_params", 0), + "device_count": torch.cuda.device_count(), + "export_time": time.time(), + }, f, indent=2) + + logger.info(f"✓ Exported {len(ipc_handles)} CUDA IPC handles to {ipc_path}") + + return JSONResponse({ + "status": "ok", + "num_parameters": len(ipc_handles), + "failed_parameters": len(failed_params), + "ipc_path": str(ipc_path), + "total_elements": sum(info["numel"] for info in ipc_handles.values()), + "model_class": result.get("model_class", "unknown"), + "message": "IPC handles exported. Trainer can now attach to shared memory.", + }) + + except concurrent.futures.TimeoutError: + raise HTTPException(status_code=504, detail="collective_rpc timed out after 60s") + except Exception as e: + logger.error(f"Failed to export CUDA IPC handles: {e}") + import traceback + logger.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/bridge/cuda_ipc_status") +async def bridge_cuda_ipc_status() -> JSONResponse: + """ + Check CUDA IPC status and whether shared memory is available. + """ + log_dir = os.environ.get("LOGDIR", ".") + ipc_path = Path(log_dir) / "cuda_ipc_handles.json" + + status = { + "sync_llm_available": SYNC_LLM_AVAILABLE, + "sync_engine_initialized": sync_engine is not None, + "ipc_handles_exported": ipc_path.exists(), + "ipc_path": str(ipc_path) if ipc_path.exists() else None, + "cuda_device_count": torch.cuda.device_count(), + } + + if ipc_path.exists(): + try: + with open(ipc_path) as f: + data = json.load(f) + status["num_parameters"] = len(data.get("handles", {})) + status["model_class"] = data.get("model_class") + status["export_time"] = data.get("export_time") + except Exception as e: + status["ipc_file_error"] = str(e) + + return JSONResponse(status) + + @app.get("/bridge/debug") async def bridge_debug() -> JSONResponse: """ - Debug endpoint to inspect AsyncLLM capabilities. + Debug endpoint to inspect engine capabilities. 
Lists available attributes and methods on the engine. """ - assert engine is not None + active_engine = get_engine() debug_info = { - "engine_type": type(engine).__name__, + "engine_type": type(active_engine).__name__, + "engine_mode": "async" if engine is not None else "sync_cuda_ipc", "vllm_version": VLLM_VERSION, - "model_config": { - "model": str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown", - "dtype": str(engine.model_config.dtype) if hasattr(engine, "model_config") else "unknown", - }, + "model_config": {}, "available_methods": {}, "important_attributes": {}, } + # Get model config + if engine is not None: + debug_info["model_config"] = { + "model": str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown", + "dtype": str(engine.model_config.dtype) if hasattr(engine, "model_config") else "unknown", + } + elif sync_engine is not None: + try: + debug_info["model_config"] = { + "model": str(sync_engine.llm_engine.model_config.model), + "dtype": str(sync_engine.llm_engine.model_config.dtype), + } + except Exception: + debug_info["model_config"] = {"model": "unknown", "dtype": "unknown"} + # Check for important methods important_methods = [ "pause_generation", "resume_generation", "is_paused", @@ -910,18 +1289,19 @@ async def bridge_debug() -> JSONResponse: ] for method in important_methods: - has_method = hasattr(engine, method) and callable(getattr(engine, method)) + has_method = hasattr(active_engine, method) and callable(getattr(active_engine, method)) debug_info["available_methods"][method] = has_method # Check important attributes important_attrs = [ "engine_core", "model_config", "vllm_config", "input_processor", "output_processor", "tokenizer", + "llm_engine", # For sync LLM ] for attr in important_attrs: - if hasattr(engine, attr): - attr_val = getattr(engine, attr) + if hasattr(active_engine, attr): + attr_val = getattr(active_engine, attr) debug_info["important_attributes"][attr] = type(attr_val).__name__ else: debug_info["important_attributes"][attr] = None @@ -1098,16 +1478,43 @@ async def init_app( """ app = build_app(args) - global engine - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = ( - llm_engine - if llm_engine is not None - else AsyncLLM.from_engine_args( - engine_args, usage_context=UsageContext.API_SERVER + global engine, sync_engine + + use_cuda_ipc = getattr(args, "enable_cuda_ipc", False) + + if use_cuda_ipc: + # CUDA IPC MODE: Use sync LLM only (model in same process) + # This allows function-based collective_rpc for IPC handle export + if not SYNC_LLM_AVAILABLE: + raise RuntimeError("CUDA IPC requested but vllm.LLM not available") + + logger.info("=" * 60) + logger.info("CUDA IPC MODE: Using sync LLM for true shared memory") + logger.info("=" * 60) + + sync_engine = SyncLLM( + model=args.model, + dtype=getattr(args, "dtype", "auto"), + gpu_memory_utilization=getattr(args, "gpu_memory_utilization", 0.9), + tensor_parallel_size=getattr(args, "tensor_parallel_size", 1), + trust_remote_code=getattr(args, "trust_remote_code", False), ) - ) - app.state.engine_client = engine + engine = None # No async engine in CUDA IPC mode + logger.info("✓ Sync LLM ready for CUDA IPC") + + else: + # STANDARD MODE: Use AsyncLLM (model in subprocess) + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = ( + llm_engine + if llm_engine is not None + else AsyncLLM.from_engine_args( + engine_args, usage_context=UsageContext.API_SERVER + ) + ) + sync_engine = None + + app.state.engine_client = engine or 
sync_engine # Export state dict info for trainers _export_state_dict_info(args) @@ -1158,7 +1565,10 @@ async def run_server( set_ulimit() app = await init_app(args, llm_engine) - assert engine is not None + + # Verify at least one engine is initialized + if engine is None and sync_engine is None: + raise RuntimeError("No engine initialized") # Log bridge endpoints logger.info("=" * 60) @@ -1239,6 +1649,18 @@ if __name__ == "__main__": default=None, help="FastAPI root_path when app is behind a path based routing proxy", ) + + # CUDA IPC for true shared memory + parser.add_argument( + "--enable-cuda-ipc", + action="store_true", + default=False, + help=( + "Enable CUDA IPC for true shared memory with trainer. " + "Requires trainer to be on the same GPU. " + "This initializes a sync LLM alongside the async engine." + ), + ) # Add vLLM engine arguments parser = AsyncEngineArgs.add_cli_args(parser)
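Example trainer-side usage (not part of the patch above): the sketch below shows how a trainer process on the same GPU could drive the new /bridge/export_cuda_ipc endpoint and map one exported parameter into its own address space. It is a minimal sketch with several assumptions: the server address, the LOGDIR value, the dtype lookup table, and the parameter name are placeholders, and torch.UntypedStorage._new_shared_cuda / Tensor.set_ are private PyTorch APIs whose exact behaviour may vary across versions.

# trainer_attach_ipc.py -- hedged sketch; assumptions noted in the text above
import base64
import json
import pickle
from pathlib import Path

import requests
import torch

VLLM_URL = "http://localhost:8000"  # assumed server address
LOG_DIR = Path(".")                 # must match the server's LOGDIR

# 1. Ask the server (started with --enable-cuda-ipc) to export IPC handles.
resp = requests.post(f"{VLLM_URL}/bridge/export_cuda_ipc", timeout=120)
resp.raise_for_status()

# 2. Load the handles file written by the server.
with open(LOG_DIR / "cuda_ipc_handles.json") as f:
    handles = json.load(f)["handles"]

DTYPES = {
    "torch.float16": torch.float16,
    "torch.bfloat16": torch.bfloat16,
    "torch.float32": torch.float32,
}

def attach(name: str) -> torch.Tensor:
    """Map one exported parameter into this process (same GPU required)."""
    info = handles[name]
    handle = pickle.loads(base64.b64decode(info["ipc_handle"]))
    # _new_shared_cuda consumes the tuple produced by _share_cuda_() on the
    # server side; both are private PyTorch APIs.
    storage = torch.UntypedStorage._new_shared_cuda(*handle)
    t = torch.empty(0, dtype=DTYPES[info["dtype"]],
                    device=f"cuda:{info['device_index']}")
    t.set_(storage, info["storage_offset"], info["shape"], info["stride"])
    return t  # in-place writes to t are visible to the vLLM model

# Example: push an updated weight into the shared tensor.
# "model.embed_tokens.weight" is a hypothetical name; inspect handles.keys().
# attach("model.embed_tokens.weight").copy_(new_weight)

Because the tensors are shared in this mode, weight updates need neither the pause/resume endpoints nor an HTTP transfer of the state dict; the trainer writes into the mapped storage directly, which is what the "Weights are shared directly" responses in the patch refer to.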