mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
nccl loras
This commit is contained in:
parent
950be6f0d4
commit
2501e33ae3
8 changed files with 1121 additions and 16 deletions
|
|
@ -643,6 +643,117 @@ async def lora_unload() -> JSONResponse:
|
|||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# NCCL Weight Receiver (for lora_nccl mode)
|
||||
# =============================================================================
|
||||
|
||||
# Module-level handle to the active NCCL weight bridge; None while no receiver
# exists. Assigned by the /nccl/start_receiver handler and reset to None by the
# /nccl/stop_receiver handler; /nccl/status only reads it.
nccl_bridge: Optional[Any] = None # Will hold NCCLWeightBridge instance
|
||||
|
||||
|
||||
@app.post("/nccl/start_receiver")
async def nccl_start_receiver(request: Request) -> JSONResponse:
    """
    Start NCCL weight receiver (for lora_nccl training mode).

    Request JSON:
    {
        "init_method": "tcp://localhost:29500",
        "world_size": 2,
        "param_metadata": {...},
        "param_mappings": {...}
    }

    ``param_metadata`` may carry ``param_names``, ``param_shapes`` and
    ``param_dtypes``; each one defaults to an empty container when absent.

    If a bridge from an earlier call is still registered it is stopped and
    cleaned up before the new one is created, and the module-level
    ``nccl_bridge`` is only published once ``setup()`` has succeeded.

    Returns:
        JSON with ``status``, ``message``, ``rank`` and ``world_size``.

    Raises:
        HTTPException: 503 if the engine is not initialized, 400 if the body
            is not a JSON object, 500 if the bridge module cannot be imported
            or the bridge fails to set up.
    """
    global nccl_bridge

    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")

    request_dict = await request.json()
    if not isinstance(request_dict, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    # The module is importable as a package sibling or as a top-level module
    # depending on how the server script is launched, so try both.
    try:
        from .nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
    except ImportError:
        try:
            from nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
        except ImportError as exc:
            raise HTTPException(
                status_code=500,
                detail="NCCL weight bridge module not available",
            ) from exc

    # Tear down any bridge left over from a previous start call so it is not
    # silently orphaned when the global is rebound below.
    previous_bridge = nccl_bridge
    if previous_bridge is not None:
        nccl_bridge = None
        try:
            previous_bridge.stop_receiver()
        finally:
            previous_bridge.cleanup()

    # Get vLLM's state dict for in-place updates
    # This is tricky because vLLM's model is encapsulated
    # For now, we'll need to use the engine's internal access
    state_dict = {}  # TODO: Get actual vLLM state dict

    receiver_rank = 1  # vLLM is rank 1 (trainer is rank 0)
    config = NCCLBridgeConfig(
        rank=receiver_rank,
        world_size=request_dict.get("world_size", 2),
        init_method=request_dict.get("init_method", "tcp://localhost:29500"),
    )

    bridge = NCCLWeightBridge(config)

    # NOTE(review): setup() is called synchronously inside an async handler;
    # if it blocks waiting for the trainer to join, the event loop stalls
    # until it returns. Confirm whether that is acceptable here.
    if not bridge.setup():
        # The failed bridge is deliberately never stored in the global, so
        # /nccl/status keeps reporting "no receiver".
        raise HTTPException(status_code=500, detail="Failed to setup NCCL bridge")

    nccl_bridge = bridge

    # Set param metadata from trainer (treat a missing or null entry as empty).
    param_metadata = request_dict.get("param_metadata") or {}
    bridge.param_names = param_metadata.get("param_names", [])
    bridge.param_shapes = {
        name: tuple(shape)
        for name, shape in param_metadata.get("param_shapes", {}).items()
    }
    bridge.param_dtypes = param_metadata.get("param_dtypes", {})

    param_mappings = request_dict.get("param_mappings") or {}

    # Start receiver thread
    bridge.start_receiver(
        state_dict,
        param_mappings,
        on_update=lambda step: logger.info(f"NCCL weight update received: step {step}"),
    )

    return JSONResponse({
        "status": "ok",
        "message": "NCCL receiver started",
        "rank": receiver_rank,
        "world_size": config.world_size,
    })
|
||||
|
||||
|
||||
@app.post("/nccl/stop_receiver")
async def nccl_stop_receiver() -> JSONResponse:
    """
    Stop NCCL weight receiver.

    Stops the receiver, cleans up the bridge and resets the module-level
    ``nccl_bridge`` to None. Calling this when no bridge is registered is a
    no-op that still returns ``status: ok``.

    The global is cleared before teardown and ``cleanup()`` runs in a
    ``finally`` clause, so an exception from ``stop_receiver()`` neither
    skips cleanup nor leaves a half-stopped bridge registered. Any such
    exception still propagates to the caller.
    """
    global nccl_bridge

    bridge = nccl_bridge
    if bridge is None:
        return JSONResponse({"status": "ok", "message": "No receiver running"})

    # Unpublish first so /nccl/status never reports a bridge that is being
    # (or has failed to be) torn down.
    nccl_bridge = None
    try:
        bridge.stop_receiver()
    finally:
        bridge.cleanup()

    return JSONResponse({"status": "ok", "message": "NCCL receiver stopped"})
|
||||
|
||||
|
||||
@app.get("/nccl/status")
async def nccl_status() -> JSONResponse:
    """
    Get NCCL receiver status.

    With a bridge registered, reports its ``is_initialized`` flag as
    ``active`` together with ``update_count``, ``last_update_time`` and the
    number of entries in ``param_names``. Without one, reports an inactive
    receiver with zero updates.
    """
    if nccl_bridge is not None:
        payload = {
            "active": nccl_bridge.is_initialized,
            "update_count": nccl_bridge.update_count,
            "last_update_time": nccl_bridge.last_update_time,
            "num_params": len(nccl_bridge.param_names),
        }
    else:
        # No bridge registered: only the two baseline fields are returned.
        payload = {
            "active": False,
            "update_count": 0,
        }

    return JSONResponse(payload)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Server Setup
|
||||
# =============================================================================
|
||||
|
|
@ -748,6 +859,9 @@ async def run_server(
|
|||
logger.info(" GET /lora/status - LoRA adapter status")
|
||||
logger.info(" POST /lora/load - Load LoRA adapter")
|
||||
logger.info(" POST /lora/unload - Unload LoRA adapter")
|
||||
logger.info(" POST /nccl/start_receiver - Start NCCL weight receiver (lora_nccl mode)")
|
||||
logger.info(" POST /nccl/stop_receiver - Stop NCCL weight receiver")
|
||||
logger.info(" GET /nccl/status - NCCL receiver status")
|
||||
logger.info("=" * 60)
|
||||
|
||||
shutdown_task = await serve_http(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue