mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-22 16:48:57 +00:00)
weight updates async
commit 39e94c4278 (parent b3874b658a)

1 changed file with 266 additions and 136 deletions
@@ -711,157 +711,273 @@ async def bridge_disable() -> JSONResponse:
     return JSONResponse({"status": "ok"})


-def _worker_export_cuda_ipc_handles(self) -> dict:
-    """
-    Worker-side function to export CUDA IPC handles.
-
-    This function runs INSIDE the vLLM worker process where the model lives.
-    Called via collective_rpc - 'self' is the GPU worker instance.
-
-    Returns:
-        Dictionary with IPC handles for all model parameters.
-    """
-    import base64
-    import pickle
-
-    model = self.model_runner.model
-
-    ipc_handles = {}
-    failed_params = []
-
-    for name, param in model.named_parameters():
-        try:
-            if not param.is_cuda:
-                failed_params.append(f"{name}: not on CUDA")
-                continue
-
-            storage = param.data.storage()
-            handle = storage._share_cuda_()
-
-            handle_bytes = pickle.dumps(handle)
-            handle_b64 = base64.b64encode(handle_bytes).decode('ascii')
-
-            ipc_handles[name] = {
-                "ipc_handle": handle_b64,
-                "shape": list(param.shape),
-                "dtype": str(param.dtype),
-                "device_index": param.device.index if param.device.index is not None else 0,
-                "storage_offset": param.storage_offset(),
-                "numel": param.numel(),
-                "stride": list(param.stride()),
-            }
-        except Exception as e:
-            failed_params.append(f"{name}: {str(e)}")
-
-    return {
-        "handles": ipc_handles,
-        "failed": failed_params,
-        "model_class": model.__class__.__name__,
-    }
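The worker function above is the export path this commit removes. For context, a rough sketch of what the consuming side of such a handle could look like; this leans on private, version-dependent PyTorch internals (the same machinery `torch.multiprocessing.reductions` uses), and `rebuild_param` is a hypothetical helper, not part of the repo:

```python
# Hypothetical consumer-side sketch (NOT in this repo): rebuild a tensor from
# one exported entry. Relies on private, version-dependent PyTorch APIs.
import base64
import pickle

import torch


def rebuild_param(entry: dict) -> torch.Tensor:
    # entry is one value from the exported "handles" dict above
    shared_args = pickle.loads(base64.b64decode(entry["ipc_handle"]))
    # _share_cuda_() produces the args that _new_shared_cuda() consumes
    storage = torch.UntypedStorage._new_shared_cuda(*shared_args)
    dtype = getattr(torch, entry["dtype"].removeprefix("torch."))
    t = torch.empty(0, dtype=dtype, device=f"cuda:{entry['device_index']}")
    # View the shared storage with the original shape/stride/offset
    return t.set_(storage, entry["storage_offset"], entry["shape"], entry["stride"])
```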
+# =============================================================================
+# Weight Update Endpoints (Pause/Resume for Training)
+# =============================================================================
+
+
-@app.post("/bridge/export_cuda_ipc")
-async def bridge_export_cuda_ipc() -> JSONResponse:
+@app.post("/bridge/pause")
+async def bridge_pause() -> JSONResponse:
     """
-    Export CUDA IPC handles for all model parameters.
+    Pause generation to allow weight updates.
 
-    Uses collective_rpc to execute on the worker process where the model lives.
-    This is the vLLM v1 way to access the model implicitly.
+    This is vLLM's built-in mechanism for weight updates!
+    Waits for in-flight requests to finish, then pauses.
 
-    IMPORTANT: Only works when both processes are on the SAME GPU.
-
-    Returns:
-        JSON with IPC handles, shapes, dtypes for each parameter.
+    Use this BEFORE updating weights from the trainer.
     """
     assert engine is not None
 
     try:
-        rpc_result = None
-        method_used = None
+        await engine.pause_generation(
+            wait_for_inflight_requests=True,
+            clear_cache=True,
+        )
+        logger.info("Generation paused for weight updates")
 
-        # Try to find collective_rpc in AsyncLLM's internals
-        # vLLM v1 exposes this on various paths depending on version
+        return JSONResponse({
+            "status": "paused",
+            "message": "Ready for weight updates. Call /bridge/resume when done.",
+        })
+    except Exception as e:
+        logger.error(f"Failed to pause generation: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/bridge/resume")
+async def bridge_resume() -> JSONResponse:
+    """
+    Resume generation after weight updates.
+
+    Call this AFTER updating weights from the trainer.
+    """
+    assert engine is not None
+
+    try:
+        await engine.resume_generation()
+        logger.info("Generation resumed after weight updates")
 
-        paths_to_try = [
-            ("engine", lambda: engine.collective_rpc(_worker_export_cuda_ipc_handles)),
-            ("engine.llm_engine", lambda: engine.llm_engine.collective_rpc(_worker_export_cuda_ipc_handles)),
-            ("engine.engine_core", lambda: engine.engine_core.collective_rpc(_worker_export_cuda_ipc_handles)),
-            ("engine.llm_engine.model_executor", lambda: engine.llm_engine.model_executor.collective_rpc(_worker_export_cuda_ipc_handles)),
-        ]
+        return JSONResponse({
+            "status": "resumed",
+            "message": "Generation resumed with updated weights.",
+        })
+    except Exception as e:
+        logger.error(f"Failed to resume generation: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
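A minimal sketch of the trainer-side flow these two endpoints enable; the base URL and the weight-push step are assumptions for illustration, not part of this commit:

```python
# Illustrative trainer-side pause -> update -> resume cycle.
import requests

BASE = "http://localhost:8000"  # assumed bridge server address


def update_vllm_weights(push_weights) -> None:
    # Stop new generations and drain in-flight requests first.
    requests.post(f"{BASE}/bridge/pause", timeout=300).raise_for_status()
    try:
        # push_weights() is a stand-in for the actual transfer,
        # e.g. broadcasting tensors over the NCCL weight bridge.
        push_weights()
    finally:
        # Always resume, even if the push fails, so the server isn't wedged.
        requests.post(f"{BASE}/bridge/resume", timeout=60).raise_for_status()
```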
+@app.get("/bridge/is_paused")
+async def bridge_is_paused() -> JSONResponse:
+    """Check if generation is currently paused."""
+    assert engine is not None
+
+    paused = await engine.is_paused()
+    return JSONResponse({"paused": paused})
+
+
+@app.post("/bridge/sleep")
+async def bridge_sleep(level: int = 1) -> JSONResponse:
+    """
+    Put the engine to sleep to free GPU memory.
+
+    Level 1: Minimal sleep, fast wake up
+    Higher levels: Deeper sleep, frees more memory
+
+    Use for memory-constrained environments.
+    """
+    assert engine is not None
+
+    try:
+        await engine.sleep(level=level)
+        logger.info(f"Engine put to sleep (level {level})")
 
-        for name, rpc_fn in paths_to_try:
-            try:
-                logger.info(f"Trying {name}.collective_rpc...")
-                rpc_result = rpc_fn()
-                method_used = name
-                logger.info(f"SUCCESS via {name}")
-                break
-            except AttributeError as e:
-                logger.debug(f"{name} failed: {e}")
-                continue
-            except Exception as e:
-                logger.warning(f"{name} error: {e}")
-                continue
 
+        return JSONResponse({
+            "status": "sleeping",
+            "level": level,
+            "message": "GPU memory freed. Call /bridge/wake_up to resume.",
+        })
+    except Exception as e:
+        logger.error(f"Failed to sleep: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/bridge/wake_up")
+async def bridge_wake_up() -> JSONResponse:
+    """
+    Wake up the engine from sleep.
+
+    Reloads the model into GPU memory.
+    """
+    assert engine is not None
+
+    try:
+        await engine.wake_up()
+        logger.info("Engine woken up")
 
-        if rpc_result is None:
-            # Log engine structure for debugging
-            logger.error("collective_rpc not found. Engine structure:")
-            logger.error(f"  type: {type(engine).__name__}")
-            logger.error(f"  attrs: {[a for a in dir(engine) if not a.startswith('_')][:20]}")
-
-            raise HTTPException(
-                status_code=500,
-                detail=(
-                    "collective_rpc not available on AsyncLLM in this vLLM version. "
-                    "Options: 1) Use vllm_sync_server.py for CUDA IPC, "
-                    "2) Use multi-GPU mode (separate GPUs for trainer/vLLM), "
-                    "3) Use legacy or LoRA mode."
-                )
-            )
+        return JSONResponse({
+            "status": "awake",
+            "message": "Model reloaded into GPU memory.",
+        })
+    except Exception as e:
+        logger.error(f"Failed to wake up: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/bridge/is_sleeping")
+async def bridge_is_sleeping() -> JSONResponse:
+    """Check if engine is currently sleeping."""
+    assert engine is not None
+
+    sleeping = await engine.is_sleeping()
+    return JSONResponse({"sleeping": sleeping})
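And a sketch of driving the sleep/wake cycle from a client; the address is assumed, and `level` travels as a query parameter since it is a bare function argument in FastAPI:

```python
# Illustrative only: free vLLM's GPU memory around a memory-hungry phase.
import requests

BASE = "http://localhost:8000"  # assumed bridge server address

requests.post(f"{BASE}/bridge/sleep", params={"level": 1}).raise_for_status()
assert requests.get(f"{BASE}/bridge/is_sleeping").json()["sleeping"]

# ... run the memory-heavy training/optimizer step here ...

requests.post(f"{BASE}/bridge/wake_up").raise_for_status()
```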
+# =============================================================================
+# RPC Endpoints (Call Worker Methods)
+# =============================================================================
+
+
+class CollectiveRPCRequest(BaseModel):
+    """Request to call a method on all workers."""
+    method: str
+    timeout: Optional[float] = None
+    args: List[Any] = []
+    kwargs: Dict[str, Any] = {}
+
+
+@app.post("/bridge/collective_rpc")
+async def bridge_collective_rpc(request: CollectiveRPCRequest) -> JSONResponse:
+    """
+    Call a method on all workers via collective RPC.
+
+    The method must exist on the worker class.
+    This is an advanced endpoint for custom worker operations.
+
+    Example worker methods:
+    - 'save_model' - Save model weights
+    - 'get_model_info' - Get model information
+
+    Note: The method name is passed as a STRING, not a function.
+    """
+    assert engine is not None
+
+    try:
+        result = await engine.collective_rpc(
+            method=request.method,
+            timeout=request.timeout,
+            args=tuple(request.args),
+            kwargs=request.kwargs if request.kwargs else None,
+        )
 
-        # Process result (collective_rpc returns a list, one per worker)
-        result = rpc_result[0] if isinstance(rpc_result, list) else rpc_result
-        ipc_handles = result.get("handles", {})
-        failed_params = result.get("failed", [])
-
-        if failed_params:
-            logger.warning(f"Could not export {len(failed_params)} parameters")
-
-        if len(ipc_handles) == 0:
-            raise HTTPException(status_code=500, detail="No IPC handles exported")
-
-        # Save to file
-        log_dir = os.environ.get("LOGDIR", ".")
-        ipc_path = Path(log_dir) / "cuda_ipc_handles.json"
-
-        with open(ipc_path, "w") as f:
-            json.dump({
-                "handles": ipc_handles,
-                "model_class": result.get("model_class", "unknown"),
-                "device_count": torch.cuda.device_count(),
-                "export_time": time.time(),
-                "method": method_used,
-            }, f, indent=2)
-
-        logger.info(f"Exported {len(ipc_handles)} CUDA IPC handles via {method_used}")
+        logger.info(f"collective_rpc({request.method}) completed")
 
         return JSONResponse({
             "status": "ok",
-            "num_parameters": len(ipc_handles),
-            "ipc_path": str(ipc_path),
-            "total_params": sum(info["numel"] for info in ipc_handles.values()),
-            "method": method_used,
+            "method": request.method,
+            "result": result if isinstance(result, (dict, list, str, int, float, bool, type(None))) else str(result),
         })
 
     except HTTPException:
         raise
     except Exception as e:
-        logger.error(f"Failed to export CUDA IPC handles: {e}")
-        import traceback
-        logger.error(traceback.format_exc())
+        logger.error(f"collective_rpc failed: {e}")
         raise HTTPException(status_code=500, detail=str(e))
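A sketch of calling the RPC endpoint; `echo_rank` is a hypothetical worker method used only to show the request shape, and the address is assumed:

```python
# Illustrative only: invoke a worker method by name through the bridge.
import requests

resp = requests.post(
    "http://localhost:8000/bridge/collective_rpc",  # assumed address
    json={"method": "echo_rank", "args": [], "kwargs": {}, "timeout": 30.0},
)
resp.raise_for_status()
# The server serializes whatever collective_rpc returned (one entry per worker).
print(resp.json()["result"])
```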
+@app.get("/bridge/debug")
+async def bridge_debug() -> JSONResponse:
+    """
+    Debug endpoint to inspect AsyncLLM capabilities.
+
+    Lists available attributes and methods on the engine.
+    """
+    assert engine is not None
+
+    debug_info = {
+        "engine_type": type(engine).__name__,
+        "vllm_version": VLLM_VERSION,
+        "model_config": {
+            "model": str(engine.model_config.model) if hasattr(engine, "model_config") else "unknown",
+            "dtype": str(engine.model_config.dtype) if hasattr(engine, "model_config") else "unknown",
+        },
+        "available_methods": {},
+        "important_attributes": {},
+    }
+
+    # Check for important methods
+    important_methods = [
+        "pause_generation", "resume_generation", "is_paused",
+        "sleep", "wake_up", "is_sleeping",
+        "collective_rpc", "add_lora", "remove_lora", "list_loras",
+        "generate", "encode", "abort", "check_health",
+    ]
+
+    for method in important_methods:
+        has_method = hasattr(engine, method) and callable(getattr(engine, method))
+        debug_info["available_methods"][method] = has_method
+
+    # Check important attributes
+    important_attrs = [
+        "engine_core", "model_config", "vllm_config",
+        "input_processor", "output_processor", "tokenizer",
+    ]
+
+    for attr in important_attrs:
+        if hasattr(engine, attr):
+            attr_val = getattr(engine, attr)
+            debug_info["important_attributes"][attr] = type(attr_val).__name__
+        else:
+            debug_info["important_attributes"][attr] = None
+
+    return JSONResponse(debug_info)
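For example, a client can probe capabilities before choosing an update strategy (address assumed):

```python
# Illustrative only: check whether pause/resume weight updates are supported.
import requests

info = requests.get("http://localhost:8000/bridge/debug").json()
if info["available_methods"].get("pause_generation"):
    print("pause/resume weight updates available")
else:
    print("fall back to another update mode")
```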
+@app.get("/bridge/list_endpoints")
+async def bridge_list_endpoints() -> JSONResponse:
+    """
+    List all available bridge endpoints with descriptions.
+
+    Use this to discover what capabilities are available.
+    """
+    endpoints = {
+        "health": {
+            "GET /health": "Basic health check",
+            "GET /health_generate": "Deep health check (sends test request)",
+        },
+        "generation": {
+            "POST /generate": "Generate text (vLLM native format)",
+            "POST /v1/completions": "Generate text (OpenAI format)",
+            "POST /v1/chat/completions": "Chat completion (OpenAI format)",
+        },
+        "bridge_control": {
+            "GET /bridge/info": "Get bridge status and rendezvous info",
+            "POST /bridge/init": "Initialize weight bridge for NCCL",
+            "POST /bridge/disable": "Disable weight bridge",
+            "GET /bridge/state_dict_info": "Get model parameter info",
+        },
+        "weight_updates": {
+            "POST /bridge/pause": "⭐ Pause generation for weight updates",
+            "POST /bridge/resume": "⭐ Resume generation after weight updates",
+            "GET /bridge/is_paused": "Check if paused",
+            "POST /bridge/notify_update": "Notify server of weight update",
+        },
+        "memory_management": {
+            "POST /bridge/sleep": "Put engine to sleep (free GPU memory)",
+            "POST /bridge/wake_up": "Wake engine up (reload model)",
+            "GET /bridge/is_sleeping": "Check if sleeping",
+        },
+        "lora_adapters": {
+            "GET /lora/status": "Get LoRA status",
+            "POST /lora/load": "Load LoRA adapter",
+            "POST /lora/unload": "Unload LoRA adapter",
+        },
+        "advanced": {
+            "POST /bridge/collective_rpc": "Call method on workers",
+            "GET /bridge/debug": "Debug engine structure",
+            "GET /bridge/list_endpoints": "This endpoint",
+        },
+    }
+
+    return JSONResponse(endpoints)
+
+
+# =============================================================================
+# LoRA Endpoints (for adapter hot-swapping)
+# =============================================================================
@@ -1045,14 +1161,28 @@ async def run_server(
     assert engine is not None
 
     # Log bridge endpoints
     logger.info("=" * 60)
     logger.info("Bridge endpoints available:")
     logger.info("  GET  /bridge/info - Get bridge status")
     logger.info("  POST /bridge/init - Initialize weight bridge")
     logger.info("  POST /bridge/notify_update - Notify of weight update")
     logger.info("  GET  /bridge/state_dict_info - Get model parameters")
-    logger.info("  GET  /lora/status - Get LoRA adapter status")
-    logger.info("  POST /lora/load - Load LoRA adapter")
-    logger.info("  POST /lora/unload - Unload LoRA adapter")
+    logger.info("-" * 60)
+    logger.info("Weight Updates (use these for training!):")
+    logger.info("  POST /bridge/pause - Pause generation for weight updates")
+    logger.info("  POST /bridge/resume - Resume after updating weights")
+    logger.info("  GET  /bridge/is_paused - Check pause state")
+    logger.info("-" * 60)
+    logger.info("Memory Management:")
+    logger.info("  POST /bridge/sleep - Free GPU memory")
+    logger.info("  POST /bridge/wake_up - Reload model")
+    logger.info("-" * 60)
+    logger.info("LoRA Adapters:")
+    logger.info("  GET  /lora/status - Get adapter status")
+    logger.info("  POST /lora/load - Load adapter")
+    logger.info("  POST /lora/unload - Unload adapter")
+    logger.info("-" * 60)
+    logger.info("Debug:")
+    logger.info("  GET  /bridge/debug - Inspect engine")
+    logger.info("  GET  /bridge/list_endpoints - List all endpoints")
+    logger.info("  POST /bridge/collective_rpc - Call worker methods")
     logger.info("=" * 60)
 
     shutdown_task = await serve_http(
         app,