diff --git a/example_trainer/vllm_patching/weight_updater.py b/example_trainer/vllm_patching/weight_updater.py
index da5e2bfe..d9a67cde 100644
--- a/example_trainer/vllm_patching/weight_updater.py
+++ b/example_trainer/vllm_patching/weight_updater.py
@@ -56,21 +56,9 @@ def weight_updater_process(
     """
     # Configuration from environment
    num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", 0))
-    cuda_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "0")).split(",")
     debug = int(os.environ.get("WEIGHT_UPDATER_DEBUG", 0))
 
-    # Determine world size based on setup
-    if num_inference_nodes > 0:
-        # Multi-node: 8 GPUs per node
-        world_size = num_inference_nodes * 8
-        ranks_per_node = 8
-    else:
-        # Single node: typically 4 inference GPUs
-        world_size = 4
-        ranks_per_node = 4
-
     # Get network info
-    hostnames = get_hostnames()
     master_addr, master_gloo_addr, master_inference_addr, urls = get_inference_urls(
         num_inference_nodes
     )
@@ -87,41 +75,63 @@ def weight_updater_process(
         f"q_heads={num_q_heads}, kv_heads={num_kv_heads}, gpu_id={gpu_id}",
         flush=True,
     )
-    print(f"[Updater] Master: {master_addr}, world_size={world_size}", flush=True)
 
-    # Determine this worker's rank within the inference group
-    rank = -1
+    # For single-node mode (num_inference_nodes=0):
+    # - Trainer is rank 0
+    # - Inference daemon is rank 1 (or tp_rank + 1 for multi-GPU)
+    # Total world size = 1 trainer + 1 inference = 2
+    #
+    # For multi-node mode:
+    # - More complex, based on SLURM node allocation
     if num_inference_nodes == 0:
-        # Single node: skip first N GPUs (used by trainer)
-        rank = int(cuda_devices[gpu_id]) - (8 - ranks_per_node)
+        # Single node: simple setup
+        # World = [trainer (rank 0), inference daemon (rank 1)]
+        num_training_ranks = 1
+        num_inference_ranks = 1
+        world_size = num_training_ranks + num_inference_ranks
+        rank = num_training_ranks + tp_rank  # Daemon is rank 1
     else:
-        # Multi-node: find which inference node we're on
-        for i, url in enumerate(urls):
+        # Multi-node: 8 GPUs per node
+        hostnames = get_hostnames()
+        cuda_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "0")).split(",")
+        ranks_per_node = 8
+        world_size = num_inference_nodes * ranks_per_node
+
+        rank = -1
+        for i, url in enumerate(urls or []):
             if hostnames and url in hostnames:
                 rank = ranks_per_node * i + int(cuda_devices[gpu_id])
                 break
+
+        if rank < 0:
+            print(f"[Updater] Could not determine rank for multi-node, exiting", flush=True)
+            return
 
-    if rank < 0:
-        print(f"[Updater] Could not determine rank, exiting", flush=True)
-        return
+    print(f"[Updater] Master: {master_addr}, world_size={world_size}, my_rank={rank}", flush=True)
 
-    # Load config from vLLM
+    # Load config from vLLM (optional - may not exist for simple setups)
     print("[Updater] Loading bridge config...", flush=True)
+    json_data = {}
+    param_name_list = []
     try:
         json_data = get_json_data()
+        param_name_list = sorted(json_data.get("param_mappings", {}).keys())
+        print(f"[Updater] Loaded {len(param_name_list)} parameter mappings", flush=True)
     except Exception as e:
-        print(f"[Updater] Failed to load config: {e}", flush=True)
-        return
+        print(f"[Updater] No config file found (will receive all params): {e}", flush=True)
 
-    param_name_list = sorted(json_data.get("param_mappings", {}).keys())
-    num_training_gpus = json_data.get("dp_shard_degree", 1) * json_data.get("tp_degree", 1)
-    total_group_size = num_training_gpus + world_size
+    # Use the world_size and rank we already calculated
+    total_group_size = world_size
 
-    # Offset rank by training GPUs
-    rank = rank + num_training_gpus
+    # For single-node mode, trainer is rank 0
+    if num_inference_nodes == 0:
+        num_training_gpus = 1
+    else:
+        num_training_gpus = json_data.get("dp_shard_degree", 1) * json_data.get("tp_degree", 1)
 
     print(f"[Updater] Total group size: {total_group_size}", flush=True)
-    print(f"[Updater] Training GPUs: {num_training_gpus}", flush=True)
+    print(f"[Updater] Training ranks: {num_training_gpus}", flush=True)
     print(f"[Updater] My rank: {rank}", flush=True)
 
     # Initialize process groups
@@ -150,6 +160,8 @@ def weight_updater_process(
 
     except Exception as e:
         print(f"[Updater] Failed to create process groups: {e}", flush=True)
+        import traceback
+        traceback.print_exc()
         return
 
     # Get device for tensors
diff --git a/example_trainer/vllm_weight_bridge.py b/example_trainer/vllm_weight_bridge.py
index 6ea2c512..e8c494a6 100644
--- a/example_trainer/vllm_weight_bridge.py
+++ b/example_trainer/vllm_weight_bridge.py
@@ -295,25 +295,24 @@ class VLLMWeightBridge:
         self._load_param_mappings()
 
         # Calculate group sizes
-        self._num_training_gpus = (
-            self.config.world_size *
-            (1 if self.config.num_inference_nodes == 0 else 8)  # Assume 8 GPUs/node
-        )
+        # For single-node mode (num_inference_nodes=0):
+        # - Simple setup: 1 trainer + 1 inference daemon = 2 ranks
+        # For multi-node mode:
+        # - More complex based on SLURM allocation
         if self.config.num_inference_nodes == 0:
-            # Single node: some GPUs for training, some for inference
-            num_inference_gpus = 4  # Default: 4 GPUs for inference
-            self._num_training_gpus = torch.cuda.device_count() - num_inference_gpus
+            # Single node: simple 2-rank setup
+            self._num_training_gpus = 1
+            num_inference_gpus = 1
+        else:
+            # Multi-node: 8 GPUs per node
+            self._num_training_gpus = self.config.world_size * 8
+            num_inference_gpus = self.config.num_inference_nodes * 8
 
-        num_inference_gpus = (
-            self.config.num_inference_nodes * 8
-            if self.config.num_inference_nodes > 0
-            else 4
-        )
         self._total_group_size = self._num_training_gpus + num_inference_gpus
 
-        print(f"[Bridge] Training GPUs: {self._num_training_gpus}")
-        print(f"[Bridge] Inference GPUs: {num_inference_gpus}")
+        print(f"[Bridge] Training ranks: {self._num_training_gpus}")
+        print(f"[Bridge] Inference ranks: {num_inference_gpus}")
         print(f"[Bridge] Total group size: {self._total_group_size}")
 
         # Create Gloo group (for coordination)
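
Note on the multi-node path in weight_updater.py: each inference node contributes 8 ranks, and a daemon's global rank is its node index times ranks_per_node plus its local GPU index (taken from CUDA_VISIBLE_DEVICES). A minimal sketch of that arithmetic, with hypothetical standalone inputs (the real code derives the node index by matching hostnames against the inference URLs):

def multi_node_rank(node_index: int, local_gpu: int, ranks_per_node: int = 8) -> int:
    # Global rank = node offset + position within the node.
    return ranks_per_node * node_index + local_gpu

# Node 0, GPU 3 -> rank 3; node 2, GPU 5 -> rank 21.
assert multi_node_rank(0, 3) == 3
assert multi_node_rank(2, 5) == 21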
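
For the single-node path, both files now agree on a fixed two-rank world: trainer at rank 0, inference daemon at rank 1. A minimal, self-contained sketch of that rendezvous using torch.distributed; MASTER_ADDR/MASTER_PORT and the single broadcast tensor are placeholders for illustration, not the bridge's actual wiring (get_inference_urls, get_json_data, and the param mappings are omitted):

import os

import torch
import torch.distributed as dist

def join_weight_group(is_trainer: bool) -> None:
    # Trainer is rank 0; the inference daemon is rank 1. World size is 2.
    rank = 0 if is_trainer else 1
    dist.init_process_group(
        backend="gloo",  # Gloo group for coordination, as in the bridge
        init_method=f"tcp://{os.environ['MASTER_ADDR']}:{os.environ.get('MASTER_PORT', '29500')}",
        world_size=2,
        rank=rank,
    )
    weights = torch.randn(16) if is_trainer else torch.zeros(16)
    dist.broadcast(weights, src=0)  # rank 0 pushes updated weights; rank 1 receives
    dist.destroy_process_group()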