daemon errors

This commit is contained in:
Jai Suphavadeeprasit 2025-12-29 19:58:31 -05:00
parent 4348345dac
commit e2c99f7f97
2 changed files with 56 additions and 45 deletions

View file

@@ -295,25 +295,24 @@ class VLLMWeightBridge:
self._load_param_mappings()
# Calculate group sizes
self._num_training_gpus = (
self.config.world_size *
(1 if self.config.num_inference_nodes == 0 else 8) # Assume 8 GPUs/node
)
# For single-node mode (num_inference_nodes=0):
# - Simple setup: 1 trainer + 1 inference daemon = 2 ranks
# For multi-node mode:
# - More complex based on SLURM allocation
if self.config.num_inference_nodes == 0:
# Single node: some GPUs for training, some for inference
num_inference_gpus = 4 # Default: 4 GPUs for inference
self._num_training_gpus = torch.cuda.device_count() - num_inference_gpus
# Single node: simple 2-rank setup
self._num_training_gpus = 1
num_inference_gpus = 1
else:
# Multi-node: 8 GPUs per node
self._num_training_gpus = self.config.world_size * 8
num_inference_gpus = self.config.num_inference_nodes * 8
num_inference_gpus = (
self.config.num_inference_nodes * 8
if self.config.num_inference_nodes > 0
else 4
)
self._total_group_size = self._num_training_gpus + num_inference_gpus
print(f"[Bridge] Training GPUs: {self._num_training_gpus}")
print(f"[Bridge] Inference GPUs: {num_inference_gpus}")
print(f"[Bridge] Training ranks: {self._num_training_gpus}")
print(f"[Bridge] Inference ranks: {num_inference_gpus}")
print(f"[Bridge] Total group size: {self._total_group_size}")
# Create Gloo group (for coordination)