nccl loras

This commit is contained in:
Jai Suphavadeeprasit 2026-02-11 20:46:41 -05:00
parent 950be6f0d4
commit 2501e33ae3
8 changed files with 1121 additions and 16 deletions

View file

@ -643,6 +643,117 @@ async def lora_unload() -> JSONResponse:
)
# =============================================================================
# NCCL Weight Receiver (for lora_nccl mode)
# =============================================================================
nccl_bridge: Optional[Any] = None # Will hold NCCLWeightBridge instance
@app.post("/nccl/start_receiver")
async def nccl_start_receiver(request: Request) -> JSONResponse:
"""
Start NCCL weight receiver (for lora_nccl training mode).
Request JSON:
{
"init_method": "tcp://localhost:29500",
"world_size": 2,
"param_metadata": {...},
"param_mappings": {...}
}
"""
global nccl_bridge
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
request_dict = await request.json()
try:
from .nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
except ImportError:
try:
from nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
except ImportError:
raise HTTPException(
status_code=500,
detail="NCCL weight bridge module not available"
)
# Get vLLM's state dict for in-place updates
# This is tricky because vLLM's model is encapsulated
# For now, we'll need to use the engine's internal access
state_dict = {} # TODO: Get actual vLLM state dict
config = NCCLBridgeConfig(
rank=1, # vLLM is rank 1 (trainer is rank 0)
world_size=request_dict.get("world_size", 2),
init_method=request_dict.get("init_method", "tcp://localhost:29500"),
)
nccl_bridge = NCCLWeightBridge(config)
if not nccl_bridge.setup():
raise HTTPException(status_code=500, detail="Failed to setup NCCL bridge")
# Set param metadata from trainer
nccl_bridge.param_names = request_dict.get("param_metadata", {}).get("param_names", [])
nccl_bridge.param_shapes = {
k: tuple(v) for k, v in
request_dict.get("param_metadata", {}).get("param_shapes", {}).items()
}
nccl_bridge.param_dtypes = request_dict.get("param_metadata", {}).get("param_dtypes", {})
param_mappings = request_dict.get("param_mappings", {})
# Start receiver thread
nccl_bridge.start_receiver(
state_dict,
param_mappings,
on_update=lambda step: logger.info(f"NCCL weight update received: step {step}")
)
return JSONResponse({
"status": "ok",
"message": "NCCL receiver started",
"rank": 1,
"world_size": config.world_size,
})
@app.post("/nccl/stop_receiver")
async def nccl_stop_receiver() -> JSONResponse:
"""Stop NCCL weight receiver."""
global nccl_bridge
if nccl_bridge is None:
return JSONResponse({"status": "ok", "message": "No receiver running"})
nccl_bridge.stop_receiver()
nccl_bridge.cleanup()
nccl_bridge = None
return JSONResponse({"status": "ok", "message": "NCCL receiver stopped"})
@app.get("/nccl/status")
async def nccl_status() -> JSONResponse:
"""Get NCCL receiver status."""
if nccl_bridge is None:
return JSONResponse({
"active": False,
"update_count": 0,
})
return JSONResponse({
"active": nccl_bridge.is_initialized,
"update_count": nccl_bridge.update_count,
"last_update_time": nccl_bridge.last_update_time,
"num_params": len(nccl_bridge.param_names),
})
# =============================================================================
# Server Setup
# =============================================================================
@ -748,6 +859,9 @@ async def run_server(
logger.info(" GET /lora/status - LoRA adapter status")
logger.info(" POST /lora/load - Load LoRA adapter")
logger.info(" POST /lora/unload - Unload LoRA adapter")
logger.info(" POST /nccl/start_receiver - Start NCCL weight receiver (lora_nccl mode)")
logger.info(" POST /nccl/stop_receiver - Stop NCCL weight receiver")
logger.info(" GET /nccl/status - NCCL receiver status")
logger.info("=" * 60)
shutdown_task = await serve_http(