nccl loras 2

This commit is contained in:
Jai Suphavadeeprasit 2026-02-11 21:29:43 -05:00
parent c33f9170c3
commit a05a7dc276
3 changed files with 309 additions and 59 deletions

View file

@ -648,6 +648,112 @@ async def lora_unload() -> JSONResponse:
# =============================================================================
nccl_bridge: Optional[Any] = None # Will hold NCCLWeightBridge instance
nccl_setup_thread: Optional[threading.Thread] = None
nccl_setup_error: Optional[str] = None
def _setup_nccl_receiver_thread(
init_method: str,
world_size: int,
param_metadata: dict,
param_mappings: dict,
):
"""Background thread to setup NCCL receiver and wait for weight updates."""
global nccl_bridge, nccl_setup_error
logger.info(f"[NCCL] Receiver thread started, attempting to import nccl_weight_bridge...")
NCCLBridgeConfig = None
NCCLWeightBridge = None
# Try multiple import methods
try:
from .nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
logger.info("[NCCL] Imported via relative import")
except ImportError:
pass
if NCCLBridgeConfig is None:
try:
from nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
logger.info("[NCCL] Imported via direct import")
except ImportError:
pass
if NCCLBridgeConfig is None:
try:
import sys
from pathlib import Path
script_dir = Path(__file__).parent
if str(script_dir) not in sys.path:
sys.path.insert(0, str(script_dir))
from nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
logger.info("[NCCL] Imported via sys.path manipulation")
except ImportError as e:
nccl_setup_error = f"NCCL weight bridge module not available: {e}"
logger.error(nccl_setup_error)
return
if NCCLBridgeConfig is None:
nccl_setup_error = "Failed to import NCCLBridgeConfig"
logger.error(nccl_setup_error)
return
try:
config = NCCLBridgeConfig(
rank=1, # vLLM is rank 1 (trainer is rank 0)
world_size=world_size,
init_method=init_method,
)
logger.info(f"[NCCL] Starting setup as rank 1, world_size={world_size}")
nccl_bridge = NCCLWeightBridge(config)
if not nccl_bridge.setup():
nccl_setup_error = "Failed to setup NCCL bridge"
logger.error(nccl_setup_error)
return
# Set param metadata from trainer
nccl_bridge.param_names = param_metadata.get("param_names", [])
nccl_bridge.param_shapes = {
k: tuple(v) for k, v in
param_metadata.get("param_shapes", {}).items()
}
nccl_bridge.param_dtypes = param_metadata.get("param_dtypes", {})
logger.info(f"[NCCL] ✓ Bridge setup complete, {len(nccl_bridge.param_names)} params registered")
logger.info(f"[NCCL] Starting receiver loop to wait for weight updates...")
# Start receiver loop - wait for weight updates from trainer
while True:
try:
step, weights = nccl_bridge.receive_lora_weights()
if step < 0:
logger.info("[NCCL] Received shutdown signal, exiting receiver loop")
break
logger.info(f"[NCCL] ✓ Received weights for step {step} ({len(weights)} params)")
# TODO: Apply weights to vLLM's LoRA adapter
# For now, we just log receipt - actual weight application
# would require access to vLLM's internal model state
except Exception as e:
if "shutdown" in str(e).lower() or nccl_bridge is None:
logger.info("[NCCL] Receiver loop terminated")
break
logger.error(f"[NCCL] Error receiving weights: {e}")
import traceback
traceback.print_exc()
break
except Exception as e:
nccl_setup_error = f"NCCL setup error: {e}"
logger.error(nccl_setup_error)
import traceback
traceback.print_exc()
@app.post("/nccl/start_receiver")
@ -655,6 +761,10 @@ async def nccl_start_receiver(request: Request) -> JSONResponse:
"""
Start NCCL weight receiver (for lora_nccl training mode).
This endpoint starts the NCCL setup in a background thread so that
both trainer (rank 0) and vLLM (rank 1) can join the NCCL group
simultaneously.
Request JSON:
{
"init_method": "tcp://localhost:29500",
@ -663,76 +773,75 @@ async def nccl_start_receiver(request: Request) -> JSONResponse:
"param_mappings": {...}
}
"""
global nccl_bridge
global nccl_bridge, nccl_setup_thread, nccl_setup_error
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Stop any existing bridge
if nccl_bridge is not None:
try:
nccl_bridge.cleanup()
except Exception:
pass
nccl_bridge = None
nccl_setup_error = None
request_dict = await request.json()
try:
from .nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
except ImportError:
try:
from nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
except ImportError:
raise HTTPException(
status_code=500,
detail="NCCL weight bridge module not available"
)
# Get vLLM's state dict for in-place updates
# This is tricky because vLLM's model is encapsulated
# For now, we'll need to use the engine's internal access
state_dict = {} # TODO: Get actual vLLM state dict
config = NCCLBridgeConfig(
rank=1, # vLLM is rank 1 (trainer is rank 0)
world_size=request_dict.get("world_size", 2),
init_method=request_dict.get("init_method", "tcp://localhost:29500"),
)
nccl_bridge = NCCLWeightBridge(config)
if not nccl_bridge.setup():
raise HTTPException(status_code=500, detail="Failed to setup NCCL bridge")
# Set param metadata from trainer
nccl_bridge.param_names = request_dict.get("param_metadata", {}).get("param_names", [])
nccl_bridge.param_shapes = {
k: tuple(v) for k, v in
request_dict.get("param_metadata", {}).get("param_shapes", {}).items()
}
nccl_bridge.param_dtypes = request_dict.get("param_metadata", {}).get("param_dtypes", {})
init_method = request_dict.get("init_method", "tcp://localhost:29500")
world_size = request_dict.get("world_size", 2)
param_metadata = request_dict.get("param_metadata", {})
param_mappings = request_dict.get("param_mappings", {})
# Start receiver thread
nccl_bridge.start_receiver(
state_dict,
param_mappings,
on_update=lambda step: logger.info(f"NCCL weight update received: step {step}")
logger.info(f"[NCCL] Received start_receiver request: init_method={init_method}, world_size={world_size}")
logger.info(f"[NCCL] Param metadata: {len(param_metadata.get('param_names', []))} params")
# Start NCCL setup in background thread
# This allows the HTTP response to return immediately while NCCL joins
nccl_setup_thread = threading.Thread(
target=_setup_nccl_receiver_thread,
args=(init_method, world_size, param_metadata, param_mappings),
daemon=True,
name="nccl_receiver_thread",
)
nccl_setup_thread.start()
# Wait a moment to catch immediate errors
import time as time_mod
time_mod.sleep(0.5)
if nccl_setup_error:
return JSONResponse({
"status": "error",
"message": nccl_setup_error,
}, status_code=500)
logger.info(f"[NCCL] Receiver thread started, waiting for trainer to connect...")
return JSONResponse({
"status": "ok",
"message": "NCCL receiver started",
"message": "NCCL receiver setup started - waiting for trainer to connect",
"rank": 1,
"world_size": config.world_size,
"world_size": world_size,
})
@app.post("/nccl/stop_receiver")
async def nccl_stop_receiver() -> JSONResponse:
"""Stop NCCL weight receiver."""
global nccl_bridge
global nccl_bridge, nccl_setup_thread
if nccl_bridge is None:
return JSONResponse({"status": "ok", "message": "No receiver running"})
nccl_bridge.stop_receiver()
nccl_bridge.cleanup()
try:
nccl_bridge.cleanup()
except Exception as e:
logger.warning(f"Error during NCCL cleanup: {e}")
nccl_bridge = None
nccl_setup_thread = None
return JSONResponse({"status": "ok", "message": "NCCL receiver stopped"})
@ -740,14 +849,32 @@ async def nccl_stop_receiver() -> JSONResponse:
@app.get("/nccl/status")
async def nccl_status() -> JSONResponse:
"""Get NCCL receiver status."""
global nccl_setup_error
if nccl_setup_thread is not None and nccl_setup_thread.is_alive():
return JSONResponse({
"active": False,
"status": "connecting",
"message": "NCCL setup in progress...",
})
if nccl_setup_error is not None:
return JSONResponse({
"active": False,
"status": "error",
"error": nccl_setup_error,
})
if nccl_bridge is None:
return JSONResponse({
"active": False,
"status": "not_started",
"update_count": 0,
})
return JSONResponse({
"active": nccl_bridge.is_initialized,
"status": "connected" if nccl_bridge.is_initialized else "disconnected",
"update_count": nccl_bridge.update_count,
"last_update_time": nccl_bridge.last_update_time,
"num_params": len(nccl_bridge.param_names),