diff --git a/example_trainer/nccl_weight_bridge.py b/example_trainer/nccl_weight_bridge.py
index 9b205be4..20185daa 100644
--- a/example_trainer/nccl_weight_bridge.py
+++ b/example_trainer/nccl_weight_bridge.py
@@ -318,6 +318,69 @@ class NCCLWeightBridge:
 
         return elapsed
 
+    def receive_lora_weights(
+        self,
+        on_receive: Optional[Callable[[int, Dict[str, torch.Tensor]], None]] = None,
+    ) -> Tuple[int, Dict[str, torch.Tensor]]:
+        """
+        Receive LoRA weights from trainer via NCCL broadcast.
+
+        This is a BLOCKING call that waits for the trainer to send weights.
+        For non-blocking continuous receive, use start_receiver().
+
+        Args:
+            on_receive: Optional callback with (step, weights_dict)
+
+        Returns:
+            Tuple of (step_number, dict of param_name -> tensor)
+        """
+        if not self.is_initialized:
+            raise RuntimeError("NCCLBridge not initialized. Call setup() first.")
+
+        if self.rank == 0:
+            raise RuntimeError("receive_lora_weights() should only be called from rank > 0 (vLLM)")
+
+        device = "cuda"
+
+        # Receive step index
+        step_tensor = torch.zeros(1, dtype=torch.long, device=device)
+        dist.broadcast(step_tensor, src=0, group=self.nccl_group)
+        step = step_tensor.item()
+
+        if step < 0:
+            # Negative step means shutdown signal
+            return step, {}
+
+        # Receive each parameter
+        received_weights = {}
+        for name in self.param_names:
+            shape = self.param_shapes[name]
+            dtype_str = self.param_dtypes[name]
+            # Handle dtype string conversion
+            if isinstance(dtype_str, str):
+                dtype = getattr(torch, dtype_str.replace("torch.", ""))
+            else:
+                dtype = dtype_str
+
+            # Create buffer and receive
+            buffer = torch.zeros(shape, dtype=dtype, device=device)
+            dist.broadcast(buffer, src=0, group=self.nccl_group)
+            received_weights[name] = buffer
+
+        # Receive completion signal
+        done_tensor = torch.zeros(1, dtype=torch.long, device=device)
+        dist.broadcast(done_tensor, src=0, group=self.nccl_group)
+
+        self.update_count += 1
+        self.last_update_time = time.time()
+
+        print(f"[NCCLBridge] Received LoRA weights (step {step})")
+
+        if on_receive:
+            on_receive(step, received_weights)
+
+        return step, received_weights
+
     def start_receiver(
         self,
         state_dict: Dict[str, torch.Tensor],
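The receive path above pins down a strict wire protocol: one step tensor, then every registered tensor in `param_names` order, then a completion tensor. The matching rank-0 sender is not part of this diff; a sketch of what it has to do follows (assumptions: `send_lora_weights` is an illustrative name rather than the bridge's actual API, and `torch` / `torch.distributed as dist` are imported as in the module):

import torch
import torch.distributed as dist

def send_lora_weights(bridge, step, weights):
    """Illustrative rank-0 counterpart to receive_lora_weights().

    Broadcast order must mirror the receiver exactly:
    step tensor -> each param in bridge.param_names order -> done tensor.
    """
    device = "cuda"
    step_tensor = torch.tensor([step], dtype=torch.long, device=device)
    dist.broadcast(step_tensor, src=0, group=bridge.nccl_group)
    if step < 0:
        return  # shutdown signal; the receiver returns early as well
    for name in bridge.param_names:
        # The receiver allocates by registered shape/dtype, so send
        # contiguous tensors in exactly the registered order.
        tensor = weights[name].detach().to(device).contiguous()
        dist.broadcast(tensor, src=0, group=bridge.nccl_group)
    done_tensor = torch.ones(1, dtype=torch.long, device=device)
    dist.broadcast(done_tensor, src=0, group=bridge.nccl_group)

Any drift between the two sides (ordering, shapes, dtypes) shows up as a hang or a corrupted tensor rather than an error, which is why both ranks derive the schedule from the same registered metadata.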
diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py
index 5e0e6371..0cfa149f 100644
--- a/example_trainer/trainers.py
+++ b/example_trainer/trainers.py
@@ -698,34 +698,91 @@ def train_lora_nccl(config: TrainingConfig):
     print("=" * 60 + "\n")
 
     # Check external vLLM server
-    print("[1/4] Checking external vLLM server...")
+    print("[1/5] Checking external vLLM server...")
     if not check_vllm_health(config.vllm_port):
         print(f"\nERROR: vLLM server not running on port {config.vllm_port}")
         print("\nLoRA NCCL mode requires an external vLLM server. Start it first:")
         print(
-            f"  NCCL_LORA_ENABLED=1 python example_trainer/vllm_api_server.py "
+            f"  python example_trainer/vllm_api_server.py "
             f"--model {config.model_name} --port {config.vllm_port} --enable-lora --enforce-eager"
         )
         raise RuntimeError(f"External vLLM server required on port {config.vllm_port}")
     print(f"vLLM server healthy on port {config.vllm_port}")
 
     # Load model with LoRA adapters
-    print("[2/4] Loading model with LoRA adapters...")
+    print("[2/5] Loading model with LoRA adapters...")
     model, tokenizer = load_model_and_tokenizer(config)
 
     # Only optimize LoRA parameters
     trainable_params = [p for p in model.parameters() if p.requires_grad]
     optimizer = AdamW(trainable_params, lr=config.lr)
 
-    # Setup NCCL bridge
-    print("[3/4] Setting up NCCL weight bridge...")
+    # Import NCCL bridge components
     from .nccl_weight_bridge import (
         NCCLBridgeConfig,
         NCCLWeightBridge,
         create_trainer_param_to_vllm_mapping,
         export_bridge_config,
+        get_lora_params,
     )
+
+    # Pre-register params to get metadata for vLLM
+    lora_params = get_lora_params(model)
+    param_names = sorted(lora_params.keys())
+    param_shapes = {name: list(p.shape) for name, p in lora_params.items()}
+    param_dtypes = {name: str(p.dtype) for name, p in lora_params.items()}
+
+    param_metadata = {
+        "param_names": param_names,
+        "param_shapes": param_shapes,
+        "param_dtypes": param_dtypes,
+        "num_params": len(param_names),
+    }
+
+    param_mappings = create_trainer_param_to_vllm_mapping(
+        param_names,
+        model_name=config.model_name
+    )
+
+    # Tell vLLM to start its NCCL receiver FIRST (it will join as rank 1)
+    print("[3/5] Starting NCCL receiver on vLLM server...")
+    vllm_base_url = f"http://localhost:{config.vllm_port}"
+    try:
+        response = requests.post(
+            f"{vllm_base_url}/nccl/start_receiver",
+            json={
+                "init_method": config.nccl_init_method,
+                "world_size": config.nccl_world_size,
+                "param_metadata": param_metadata,
+                "param_mappings": param_mappings,
+            },
+            timeout=30,
+        )
+        resp_data = response.json()
+        if response.status_code != 200 or resp_data.get("status") == "error":
+            raise RuntimeError(f"Failed to start NCCL receiver on vLLM: {resp_data}")
+        print(f"  vLLM NCCL receiver started: {resp_data}")
+    except requests.exceptions.RequestException as e:
+        raise RuntimeError(f"Failed to contact vLLM server: {e}")
+
+    # Wait for vLLM to be in "connecting" state
+    import time as time_module
+    print("  Waiting for vLLM NCCL receiver to initialize...")
+    for i in range(10):
+        time_module.sleep(1)
+        try:
+            status_resp = requests.get(f"{vllm_base_url}/nccl/status", timeout=5)
+            status = status_resp.json()
+            print(f"  vLLM NCCL status: {status.get('status', 'unknown')}")
+            if status.get("status") == "error":
+                raise RuntimeError(f"vLLM NCCL setup failed: {status.get('error')}")
+            if status.get("status") in ["connecting", "connected"]:
+                break
+        except requests.exceptions.RequestException as e:
+            # Only swallow transport errors; a "status == error" RuntimeError
+            # raised above must propagate instead of being silently retried.
+            print(f"  Status check error: {e}")
+
+    # Now setup trainer's NCCL bridge (joins as rank 0)
+    print("[4/5] Setting up trainer NCCL weight bridge...")
 
     nccl_config = NCCLBridgeConfig(
         rank=0,  # Trainer is always rank 0
         world_size=config.nccl_world_size,
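The ordering in the hunk above is load-bearing: NCCL process-group initialization is a blocking rendezvous, so with world_size=2 neither rank returns until the other has joined. Starting vLLM's receiver first, on a background thread in the server, is what lets the trainer's own `bridge.setup()` complete. Assuming `setup()` wraps `torch.distributed.init_process_group` in the usual way (the bridge's internals are not shown in this diff), the handshake reduces to:

import torch.distributed as dist

# Rank 1 (vLLM) runs this first, on a daemon thread; the call blocks...
dist.init_process_group(
    backend="nccl",
    init_method="tcp://localhost:29500",  # i.e. config.nccl_init_method
    rank=1,
    world_size=2,
)
# ...and only returns once rank 0 (the trainer) issues the matching call
# with rank=0. Reversing the order deadlocks: the trainer would block in
# setup() before it ever asks vLLM to join.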
@@ -734,16 +791,19 @@ def train_lora_nccl(config: TrainingConfig):
     bridge = NCCLWeightBridge(nccl_config)
 
     if not bridge.setup():
+        # Try to stop vLLM receiver on failure
+        try:
+            requests.post(f"{vllm_base_url}/nccl/stop_receiver", timeout=5)
+        except Exception:
+            pass
         raise RuntimeError("Failed to setup NCCL bridge")
 
-    # Register parameters and create mappings
-    param_metadata = bridge.register_params(model)
-    param_mappings = create_trainer_param_to_vllm_mapping(
-        bridge.param_names,
-        model_name=config.model_name
-    )
+    # Register parameters with the bridge (we already have the metadata)
+    bridge.param_names = param_names
+    bridge.param_shapes = {name: tuple(shape) for name, shape in param_shapes.items()}
+    bridge.param_dtypes = param_dtypes
 
-    # Export config for vLLM
+    # Export config for debugging/recovery
     bridge_config_path = os.path.join(config.save_path, "nccl_bridge_config.json")
     os.makedirs(config.save_path, exist_ok=True)
     export_bridge_config(
@@ -754,7 +814,7 @@ def train_lora_nccl(config: TrainingConfig):
         config.nccl_world_size,
     )
 
-    print(f"[4/4] Starting training for {config.training_steps} steps")
+    print(f"[5/5] Starting training for {config.training_steps} steps")
     print("-" * 60)
 
     # Check Atropos API
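The status-polling loop above can also be driven by hand when debugging a stuck handshake. A minimal check against the same endpoint (host and port are assumptions; use whatever the vLLM server was started with):

import requests

status = requests.get("http://localhost:8000/nccl/status", timeout=5).json()
# Per the endpoint further down, "status" is one of: "not_started",
# "connecting", "connected", "disconnected", or "error".
print(status)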
diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py
index 4336f2c7..465743bc 100644
--- a/example_trainer/vllm_api_server.py
+++ b/example_trainer/vllm_api_server.py
@@ -648,6 +648,112 @@ async def lora_unload() -> JSONResponse:
 # =============================================================================
 
 nccl_bridge: Optional[Any] = None  # Will hold NCCLWeightBridge instance
+nccl_setup_thread: Optional[threading.Thread] = None
+nccl_setup_error: Optional[str] = None
+
+
+def _setup_nccl_receiver_thread(
+    init_method: str,
+    world_size: int,
+    param_metadata: dict,
+    param_mappings: dict,
+):
+    """Background thread to setup NCCL receiver and wait for weight updates."""
+    global nccl_bridge, nccl_setup_error
+
+    logger.info("[NCCL] Receiver thread started, attempting to import nccl_weight_bridge...")
+
+    NCCLBridgeConfig = None
+    NCCLWeightBridge = None
+
+    # Try multiple import methods
+    try:
+        from .nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
+        logger.info("[NCCL] Imported via relative import")
+    except ImportError:
+        pass
+
+    if NCCLBridgeConfig is None:
+        try:
+            from nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
+            logger.info("[NCCL] Imported via direct import")
+        except ImportError:
+            pass
+
+    if NCCLBridgeConfig is None:
+        try:
+            import sys
+            from pathlib import Path
+            script_dir = Path(__file__).parent
+            if str(script_dir) not in sys.path:
+                sys.path.insert(0, str(script_dir))
+            from nccl_weight_bridge import NCCLBridgeConfig, NCCLWeightBridge
+            logger.info("[NCCL] Imported via sys.path manipulation")
+        except ImportError as e:
+            nccl_setup_error = f"NCCL weight bridge module not available: {e}"
+            logger.error(nccl_setup_error)
+            return
+
+    if NCCLBridgeConfig is None:
+        nccl_setup_error = "Failed to import NCCLBridgeConfig"
+        logger.error(nccl_setup_error)
+        return
+
+    try:
+        config = NCCLBridgeConfig(
+            rank=1,  # vLLM is rank 1 (trainer is rank 0)
+            world_size=world_size,
+            init_method=init_method,
+        )
+
+        logger.info(f"[NCCL] Starting setup as rank 1, world_size={world_size}")
+        nccl_bridge = NCCLWeightBridge(config)
+
+        if not nccl_bridge.setup():
+            nccl_setup_error = "Failed to setup NCCL bridge"
+            logger.error(nccl_setup_error)
+            return
+
+        # Set param metadata from trainer
+        nccl_bridge.param_names = param_metadata.get("param_names", [])
+        nccl_bridge.param_shapes = {
+            k: tuple(v) for k, v in
+            param_metadata.get("param_shapes", {}).items()
+        }
+        nccl_bridge.param_dtypes = param_metadata.get("param_dtypes", {})
+
+        logger.info(f"[NCCL] ✓ Bridge setup complete, {len(nccl_bridge.param_names)} params registered")
+        logger.info("[NCCL] Starting receiver loop to wait for weight updates...")
+
+        # Start receiver loop - wait for weight updates from trainer
+        while True:
+            try:
+                step, weights = nccl_bridge.receive_lora_weights()
+
+                if step < 0:
+                    logger.info("[NCCL] Received shutdown signal, exiting receiver loop")
+                    break
+
+                logger.info(f"[NCCL] ✓ Received weights for step {step} ({len(weights)} params)")
+
+                # TODO: Apply weights to vLLM's LoRA adapter
+                # For now, we just log receipt - actual weight application
+                # would require access to vLLM's internal model state
+
+            except Exception as e:
+                if "shutdown" in str(e).lower() or nccl_bridge is None:
+                    logger.info("[NCCL] Receiver loop terminated")
+                    break
+                logger.error(f"[NCCL] Error receiving weights: {e}")
+                import traceback
+                traceback.print_exc()
+                break
+
+    except Exception as e:
+        nccl_setup_error = f"NCCL setup error: {e}"
+        logger.error(nccl_setup_error)
+        import traceback
+        traceback.print_exc()
 
 
 @app.post("/nccl/start_receiver")
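The TODO inside the receiver loop above is the open end of this change: tensors arrive on rank 1 but are not yet applied to the serving model. For reference, folding a received LoRA pair into a base weight is the standard low-rank update; the sketch below uses the usual PEFT conventions (`lora_A`, `lora_B`, `alpha`, `r`), none of which this diff wires up yet:

import torch

def merge_lora_delta(base, lora_A, lora_B, alpha, r):
    # base: (out_features, in_features); lora_A: (r, in_features);
    # lora_B: (out_features, r). The low-rank delta is scaled by alpha / r.
    return base + (alpha / r) * (lora_B @ lora_A)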
params") + + # Start NCCL setup in background thread + # This allows the HTTP response to return immediately while NCCL joins + nccl_setup_thread = threading.Thread( + target=_setup_nccl_receiver_thread, + args=(init_method, world_size, param_metadata, param_mappings), + daemon=True, + name="nccl_receiver_thread", ) + nccl_setup_thread.start() + + # Wait a moment to catch immediate errors + import time as time_mod + time_mod.sleep(0.5) + + if nccl_setup_error: + return JSONResponse({ + "status": "error", + "message": nccl_setup_error, + }, status_code=500) + + logger.info(f"[NCCL] Receiver thread started, waiting for trainer to connect...") return JSONResponse({ "status": "ok", - "message": "NCCL receiver started", + "message": "NCCL receiver setup started - waiting for trainer to connect", "rank": 1, - "world_size": config.world_size, + "world_size": world_size, }) @app.post("/nccl/stop_receiver") async def nccl_stop_receiver() -> JSONResponse: """Stop NCCL weight receiver.""" - global nccl_bridge + global nccl_bridge, nccl_setup_thread if nccl_bridge is None: return JSONResponse({"status": "ok", "message": "No receiver running"}) - nccl_bridge.stop_receiver() - nccl_bridge.cleanup() + try: + nccl_bridge.cleanup() + except Exception as e: + logger.warning(f"Error during NCCL cleanup: {e}") + nccl_bridge = None + nccl_setup_thread = None return JSONResponse({"status": "ok", "message": "NCCL receiver stopped"}) @@ -740,14 +849,32 @@ async def nccl_stop_receiver() -> JSONResponse: @app.get("/nccl/status") async def nccl_status() -> JSONResponse: """Get NCCL receiver status.""" + global nccl_setup_error + + if nccl_setup_thread is not None and nccl_setup_thread.is_alive(): + return JSONResponse({ + "active": False, + "status": "connecting", + "message": "NCCL setup in progress...", + }) + + if nccl_setup_error is not None: + return JSONResponse({ + "active": False, + "status": "error", + "error": nccl_setup_error, + }) + if nccl_bridge is None: return JSONResponse({ "active": False, + "status": "not_started", "update_count": 0, }) return JSONResponse({ "active": nccl_bridge.is_initialized, + "status": "connected" if nccl_bridge.is_initialized else "disconnected", "update_count": nccl_bridge.update_count, "last_update_time": nccl_bridge.last_update_time, "num_params": len(nccl_bridge.param_names),