mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
nccl loras 2
This commit is contained in:
parent
c33f9170c3
commit
a05a7dc276
3 changed files with 309 additions and 59 deletions
|
|
@ -648,6 +648,112 @@ async def lora_unload() -> JSONResponse:
|
|||
# =============================================================================
|
||||
|
||||
nccl_bridge: Optional[Any] = None # Will hold NCCLWeightBridge instance
|
||||
nccl_setup_thread: Optional[threading.Thread] = None  # Background thread running _setup_nccl_receiver_thread, if one was started
|
||||
nccl_setup_error: Optional[str] = None  # Last setup failure message recorded by the receiver thread; None when no error is recorded
|
||||
|
||||
|
||||
def _import_nccl_weight_bridge():
    """Import and return ``(NCCLBridgeConfig, NCCLWeightBridge)``.

    Three strategies are tried in order: a package-relative import, a plain
    top-level import, and a top-level import after putting this file's
    directory at the front of ``sys.path``.

    Raises:
        ImportError: if the module cannot be imported by any strategy.
    """
    try:
        from .nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
    except ImportError:
        pass
    else:
        logger.info("[NCCL] Imported via relative import")
        return NCCLBridgeConfig, NCCLWeightBridge

    try:
        from nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
    except ImportError:
        pass
    else:
        logger.info("[NCCL] Imported via direct import")
        return NCCLBridgeConfig, NCCLWeightBridge

    # Last resort: make this file's directory importable and retry. An
    # ImportError from this final attempt propagates to the caller.
    import sys
    from pathlib import Path

    script_dir = str(Path(__file__).parent)
    if script_dir not in sys.path:
        sys.path.insert(0, script_dir)
    from nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge

    logger.info("[NCCL] Imported via sys.path manipulation")
    return NCCLBridgeConfig, NCCLWeightBridge


def _setup_nccl_receiver_thread(
    init_method: str,
    world_size: int,
    param_metadata: dict,
    param_mappings: dict,
):
    """Background thread to setup NCCL receiver and wait for weight updates.

    Builds an ``NCCLWeightBridge`` as rank 1 (the trainer is rank 0), stores
    it in the module-level ``nccl_bridge``, copies the trainer-provided
    parameter metadata onto it, and then loops on
    ``receive_lora_weights()`` until a negative step arrives or an error
    occurs.

    Args:
        init_method: Passed through to ``NCCLBridgeConfig(init_method=...)``.
        world_size: Passed through to ``NCCLBridgeConfig(world_size=...)``.
        param_metadata: Mapping that may contain ``"param_names"``,
            ``"param_shapes"`` and ``"param_dtypes"``; missing keys fall
            back to empty containers. ``None`` is treated as empty.
        param_mappings: Currently unused. Received weights are only logged
            (see the TODO in the loop), so no mapping is applied yet.

    Returns:
        None. Import and setup failures are reported by writing a message
        to the module-level ``nccl_setup_error`` and logging it. Errors
        raised while receiving weights are logged only.
    """
    global nccl_bridge, nccl_setup_error

    logger.info("[NCCL] Receiver thread started, attempting to import nccl_weight_bridge...")

    try:
        NCCLBridgeConfig, NCCLWeightBridge = _import_nccl_weight_bridge()
    except ImportError as e:
        nccl_setup_error = f"NCCL weight bridge module not available: {e}"
        logger.error(nccl_setup_error)
        return

    # Treat an explicit None the same as missing metadata instead of failing
    # with AttributeError after the NCCL group has already been joined.
    param_metadata = param_metadata or {}

    try:
        config = NCCLBridgeConfig(
            rank=1,  # vLLM is rank 1 (trainer is rank 0)
            world_size=world_size,
            init_method=init_method,
        )

        logger.info(f"[NCCL] Starting setup as rank 1, world_size={world_size}")
        # Publish the bridge globally before setup() so other code that
        # reads the module-level reference can see and clean it up.
        bridge = NCCLWeightBridge(config)
        nccl_bridge = bridge

        if not bridge.setup():
            nccl_setup_error = "Failed to setup NCCL bridge"
            logger.error(nccl_setup_error)
            return

        # Set param metadata from trainer
        bridge.param_names = param_metadata.get("param_names", [])
        bridge.param_shapes = {
            name: tuple(shape)
            for name, shape in param_metadata.get("param_shapes", {}).items()
        }
        bridge.param_dtypes = param_metadata.get("param_dtypes", {})

        logger.info(f"[NCCL] ✓ Bridge setup complete, {len(bridge.param_names)} params registered")
        logger.info("[NCCL] Starting receiver loop to wait for weight updates...")

        # Start receiver loop - wait for weight updates from trainer
        while True:
            # Re-read the global each iteration: it is reset to None
            # elsewhere when the receiver is stopped.
            active_bridge = nccl_bridge
            if active_bridge is None:
                logger.info("[NCCL] Receiver loop terminated")
                break

            try:
                step, weights = active_bridge.receive_lora_weights()

                if step < 0:
                    logger.info("[NCCL] Received shutdown signal, exiting receiver loop")
                    break

                logger.info(f"[NCCL] ✓ Received weights for step {step} ({len(weights)} params)")

                # TODO: Apply weights to vLLM's LoRA adapter
                # For now, we just log receipt - actual weight application
                # would require access to vLLM's internal model state

            except Exception as e:
                # An exception that mentions shutdown, or one raised after
                # the global bridge was cleared, counts as a normal stop.
                if "shutdown" in str(e).lower() or nccl_bridge is None:
                    logger.info("[NCCL] Receiver loop terminated")
                    break
                # logger.exception keeps the traceback in the same log
                # stream as the message instead of writing it to stderr.
                logger.exception(f"[NCCL] Error receiving weights: {e}")
                break

    except Exception as e:
        nccl_setup_error = f"NCCL setup error: {e}"
        logger.exception(nccl_setup_error)
|
||||
|
||||
|
||||
@app.post("/nccl/start_receiver")
|
||||
|
|
@ -655,6 +761,10 @@ async def nccl_start_receiver(request: Request) -> JSONResponse:
|
|||
"""
|
||||
Start NCCL weight receiver (for lora_nccl training mode).
|
||||
|
||||
This endpoint starts the NCCL setup in a background thread so that
|
||||
both trainer (rank 0) and vLLM (rank 1) can join the NCCL group
|
||||
simultaneously.
|
||||
|
||||
Request JSON:
|
||||
{
|
||||
"init_method": "tcp://localhost:29500",
|
||||
|
|
@ -663,76 +773,75 @@ async def nccl_start_receiver(request: Request) -> JSONResponse:
|
|||
"param_mappings": {...}
|
||||
}
|
||||
"""
|
||||
global nccl_bridge
|
||||
global nccl_bridge, nccl_setup_thread, nccl_setup_error
|
||||
|
||||
if engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
|
||||
# Stop any existing bridge
|
||||
if nccl_bridge is not None:
|
||||
try:
|
||||
nccl_bridge.cleanup()
|
||||
except Exception:
|
||||
pass
|
||||
nccl_bridge = None
|
||||
|
||||
nccl_setup_error = None
|
||||
request_dict = await request.json()
|
||||
|
||||
try:
|
||||
from .nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
|
||||
except ImportError:
|
||||
try:
|
||||
from nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
|
||||
except ImportError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="NCCL weight bridge module not available"
|
||||
)
|
||||
|
||||
# Get vLLM's state dict for in-place updates
|
||||
# This is tricky because vLLM's model is encapsulated
|
||||
# For now, we'll need to use the engine's internal access
|
||||
state_dict = {} # TODO: Get actual vLLM state dict
|
||||
|
||||
config = NCCLBridgeConfig(
|
||||
rank=1, # vLLM is rank 1 (trainer is rank 0)
|
||||
world_size=request_dict.get("world_size", 2),
|
||||
init_method=request_dict.get("init_method", "tcp://localhost:29500"),
|
||||
)
|
||||
|
||||
nccl_bridge = NCCLWeightBridge(config)
|
||||
|
||||
if not nccl_bridge.setup():
|
||||
raise HTTPException(status_code=500, detail="Failed to setup NCCL bridge")
|
||||
|
||||
# Set param metadata from trainer
|
||||
nccl_bridge.param_names = request_dict.get("param_metadata", {}).get("param_names", [])
|
||||
nccl_bridge.param_shapes = {
|
||||
k: tuple(v) for k, v in
|
||||
request_dict.get("param_metadata", {}).get("param_shapes", {}).items()
|
||||
}
|
||||
nccl_bridge.param_dtypes = request_dict.get("param_metadata", {}).get("param_dtypes", {})
|
||||
|
||||
init_method = request_dict.get("init_method", "tcp://localhost:29500")
|
||||
world_size = request_dict.get("world_size", 2)
|
||||
param_metadata = request_dict.get("param_metadata", {})
|
||||
param_mappings = request_dict.get("param_mappings", {})
|
||||
|
||||
# Start receiver thread
|
||||
nccl_bridge.start_receiver(
|
||||
state_dict,
|
||||
param_mappings,
|
||||
on_update=lambda step: logger.info(f"NCCL weight update received: step {step}")
|
||||
logger.info(f"[NCCL] Received start_receiver request: init_method={init_method}, world_size={world_size}")
|
||||
logger.info(f"[NCCL] Param metadata: {len(param_metadata.get('param_names', []))} params")
|
||||
|
||||
# Start NCCL setup in background thread
|
||||
# This allows the HTTP response to return immediately while NCCL joins
|
||||
nccl_setup_thread = threading.Thread(
|
||||
target=_setup_nccl_receiver_thread,
|
||||
args=(init_method, world_size, param_metadata, param_mappings),
|
||||
daemon=True,
|
||||
name="nccl_receiver_thread",
|
||||
)
|
||||
nccl_setup_thread.start()
|
||||
|
||||
# Wait a moment to catch immediate errors
|
||||
import time as time_mod
|
||||
time_mod.sleep(0.5)
|
||||
|
||||
if nccl_setup_error:
|
||||
return JSONResponse({
|
||||
"status": "error",
|
||||
"message": nccl_setup_error,
|
||||
}, status_code=500)
|
||||
|
||||
logger.info(f"[NCCL] Receiver thread started, waiting for trainer to connect...")
|
||||
|
||||
return JSONResponse({
|
||||
"status": "ok",
|
||||
"message": "NCCL receiver started",
|
||||
"message": "NCCL receiver setup started - waiting for trainer to connect",
|
||||
"rank": 1,
|
||||
"world_size": config.world_size,
|
||||
"world_size": world_size,
|
||||
})
|
||||
|
||||
|
||||
@app.post("/nccl/stop_receiver")
|
||||
async def nccl_stop_receiver() -> JSONResponse:
|
||||
"""Stop NCCL weight receiver."""
|
||||
global nccl_bridge
|
||||
global nccl_bridge, nccl_setup_thread
|
||||
|
||||
if nccl_bridge is None:
|
||||
return JSONResponse({"status": "ok", "message": "No receiver running"})
|
||||
|
||||
nccl_bridge.stop_receiver()
|
||||
nccl_bridge.cleanup()
|
||||
try:
|
||||
nccl_bridge.cleanup()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during NCCL cleanup: {e}")
|
||||
|
||||
nccl_bridge = None
|
||||
nccl_setup_thread = None
|
||||
|
||||
return JSONResponse({"status": "ok", "message": "NCCL receiver stopped"})
|
||||
|
||||
|
|
@ -740,14 +849,32 @@ async def nccl_stop_receiver() -> JSONResponse:
|
|||
@app.get("/nccl/status")
|
||||
async def nccl_status() -> JSONResponse:
|
||||
"""Get NCCL receiver status."""
|
||||
global nccl_setup_error
|
||||
|
||||
if nccl_setup_thread is not None and nccl_setup_thread.is_alive():
|
||||
return JSONResponse({
|
||||
"active": False,
|
||||
"status": "connecting",
|
||||
"message": "NCCL setup in progress...",
|
||||
})
|
||||
|
||||
if nccl_setup_error is not None:
|
||||
return JSONResponse({
|
||||
"active": False,
|
||||
"status": "error",
|
||||
"error": nccl_setup_error,
|
||||
})
|
||||
|
||||
if nccl_bridge is None:
|
||||
return JSONResponse({
|
||||
"active": False,
|
||||
"status": "not_started",
|
||||
"update_count": 0,
|
||||
})
|
||||
|
||||
return JSONResponse({
|
||||
"active": nccl_bridge.is_initialized,
|
||||
"status": "connected" if nccl_bridge.is_initialized else "disconnected",
|
||||
"update_count": nccl_bridge.update_count,
|
||||
"last_update_time": nccl_bridge.last_update_time,
|
||||
"num_params": len(nccl_bridge.param_names),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue