IPC updates

Jai Suphavadeeprasit 2025-12-10 16:59:51 -05:00
parent 78ea8bc3e7
commit 533f0bf286
3 changed files with 247 additions and 15 deletions

@@ -711,6 +711,83 @@ async def bridge_disable() -> JSONResponse:
    return JSONResponse({"status": "ok"})


@app.post("/bridge/export_cuda_ipc")
async def bridge_export_cuda_ipc() -> JSONResponse:
    """
    Export CUDA IPC handles for all model parameters.

    This enables true shared memory between vLLM and the trainer: the
    trainer can reconstruct tensors from these handles without allocating
    any new GPU memory, because both processes map the exact same
    allocation.

    IMPORTANT: This only works when both processes are on the SAME GPU.

    Returns:
        JSON with the IPC handle, shape, dtype, and layout of each parameter.
    """
    import base64
    import pickle

    assert engine is not None
    try:
        # Access the underlying model inside vLLM's worker stack.
        model = engine.engine.model_executor.driver_worker.model_runner.model

        ipc_handles = {}
        for name, param in model.named_parameters():
            try:
                # Get the underlying storage.
                storage = param.data.storage()
                # Get the CUDA IPC handle - this is the key to shared memory.
                # The handle can be sent to another process on the same GPU
                # to reconstruct a tensor pointing to the SAME memory.
                handle = storage._share_cuda_()
                # The handle tuple is not JSON-serializable, so pickle it
                # and base64-encode the bytes for transmission.
                handle_bytes = pickle.dumps(handle)
                handle_b64 = base64.b64encode(handle_bytes).decode("ascii")
                ipc_handles[name] = {
                    "ipc_handle": handle_b64,
                    "shape": list(param.shape),
                    "dtype": str(param.dtype),
                    "device_index": param.device.index,
                    "storage_offset": param.storage_offset(),
                    "numel": param.numel(),
                    "stride": list(param.stride()),
                }
            except Exception as e:
                logger.warning(f"Could not export IPC handle for {name}: {e}")
                continue

        # Save to a file for the trainer to read.
        log_dir = os.environ.get("LOGDIR", ".")
        ipc_path = Path(log_dir) / "cuda_ipc_handles.json"
        with open(ipc_path, "w") as f:
            json.dump({
                "handles": ipc_handles,
                # model_config is an object, not a dict, so read the model
                # name with getattr rather than .get().
                "model": getattr(getattr(engine, "model_config", None), "model", "unknown"),
                "device_count": torch.cuda.device_count(),
                "export_time": time.time(),
            }, f, indent=2)

        logger.info(f"Exported {len(ipc_handles)} CUDA IPC handles to {ipc_path}")
        return JSONResponse({
            "status": "ok",
            "num_parameters": len(ipc_handles),
            "ipc_path": str(ipc_path),
            "total_params": sum(info["numel"] for info in ipc_handles.values()),
        })
    except Exception as e:
        logger.error(f"Failed to export CUDA IPC handles: {e}")
        raise HTTPException(status_code=500, detail=str(e))
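
The trainer-side import is not part of this diff. As a minimal sketch, assuming the trainer runs on the same GPU with the same PyTorch build: the pickled tuple can be passed to PyTorch's private UntypedStorage._new_shared_cuda() (the inverse of _share_cuda_(), used internally by torch.multiprocessing), and the metadata recorded above re-attaches shape, stride, and offset. Only the file name comes from the export above; the rest is illustrative. These are private APIs whose tuple layout can change between PyTorch releases, and the reconstructed tensors are only valid while the exporting vLLM process is alive.

import base64
import json
import pickle

import torch

torch.cuda.init()  # CUDA must be initialized before mapping IPC memory

with open("cuda_ipc_handles.json") as f:
    exported = json.load(f)

shared_params = {}
for name, info in exported["handles"].items():
    # Decode the pickled tuple produced by storage._share_cuda_().
    handle = pickle.loads(base64.b64decode(info["ipc_handle"]))
    # Map the exporter's allocation into this process. No new GPU memory
    # is allocated; both processes now address the same bytes.
    storage = torch.UntypedStorage._new_shared_cuda(*handle)
    # "torch.float16" -> torch.float16
    dtype = getattr(torch, info["dtype"].removeprefix("torch."))
    t = torch.empty(0, dtype=dtype, device=f"cuda:{info['device_index']}")
    # Re-attach the shape, stride, and element offset recorded at export.
    t.set_(storage, info["storage_offset"], info["shape"], info["stride"])
    shared_params[name] = t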
# =============================================================================
# LoRA Endpoints (for adapter hot-swapping)
# =============================================================================