mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
daemon errors
This commit is contained in:
parent
4348345dac
commit
e2c99f7f97
2 changed files with 56 additions and 45 deletions
|
|
@ -56,21 +56,9 @@ def weight_updater_process(
|
|||
"""
|
||||
# Configuration from environment
|
||||
num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", 0))
|
||||
cuda_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "0")).split(",")
|
||||
debug = int(os.environ.get("WEIGHT_UPDATER_DEBUG", 0))
|
||||
|
||||
# Determine world size based on setup
|
||||
if num_inference_nodes > 0:
|
||||
# Multi-node: 8 GPUs per node
|
||||
world_size = num_inference_nodes * 8
|
||||
ranks_per_node = 8
|
||||
else:
|
||||
# Single node: typically 4 inference GPUs
|
||||
world_size = 4
|
||||
ranks_per_node = 4
|
||||
|
||||
# Get network info
|
||||
hostnames = get_hostnames()
|
||||
master_addr, master_gloo_addr, master_inference_addr, urls = get_inference_urls(
|
||||
num_inference_nodes
|
||||
)
|
||||
|
|
@ -87,41 +75,63 @@ def weight_updater_process(
|
|||
f"q_heads={num_q_heads}, kv_heads={num_kv_heads}, gpu_id={gpu_id}",
|
||||
flush=True,
|
||||
)
|
||||
print(f"[Updater] Master: {master_addr}, world_size={world_size}", flush=True)
|
||||
|
||||
# Determine this worker's rank within the inference group
|
||||
rank = -1
|
||||
# For single-node mode (num_inference_nodes=0):
|
||||
# - Trainer is rank 0
|
||||
# - Inference daemon is rank 1 (or tp_rank + 1 for multi-GPU)
|
||||
# Total world size = 1 trainer + 1 inference = 2
|
||||
#
|
||||
# For multi-node mode:
|
||||
# - More complex, based on SLURM node allocation
|
||||
|
||||
if num_inference_nodes == 0:
|
||||
# Single node: skip first N GPUs (used by trainer)
|
||||
rank = int(cuda_devices[gpu_id]) - (8 - ranks_per_node)
|
||||
# Single node: simple setup
|
||||
# World = [trainer (rank 0), inference daemon (rank 1)]
|
||||
num_training_ranks = 1
|
||||
num_inference_ranks = 1
|
||||
world_size = num_training_ranks + num_inference_ranks
|
||||
rank = num_training_ranks + tp_rank # Daemon is rank 1
|
||||
else:
|
||||
# Multi-node: find which inference node we're on
|
||||
for i, url in enumerate(urls):
|
||||
# Multi-node: 8 GPUs per node
|
||||
hostnames = get_hostnames()
|
||||
cuda_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "0")).split(",")
|
||||
ranks_per_node = 8
|
||||
world_size = num_inference_nodes * ranks_per_node
|
||||
|
||||
rank = -1
|
||||
for i, url in enumerate(urls or []):
|
||||
if hostnames and url in hostnames:
|
||||
rank = ranks_per_node * i + int(cuda_devices[gpu_id])
|
||||
break
|
||||
|
||||
if rank < 0:
|
||||
print(f"[Updater] Could not determine rank for multi-node, exiting", flush=True)
|
||||
return
|
||||
|
||||
if rank < 0:
|
||||
print(f"[Updater] Could not determine rank, exiting", flush=True)
|
||||
return
|
||||
print(f"[Updater] Master: {master_addr}, world_size={world_size}, my_rank={rank}", flush=True)
|
||||
|
||||
# Load config from vLLM
|
||||
# Load config from vLLM (optional - may not exist for simple setups)
|
||||
print("[Updater] Loading bridge config...", flush=True)
|
||||
json_data = {}
|
||||
param_name_list = []
|
||||
try:
|
||||
json_data = get_json_data()
|
||||
param_name_list = sorted(json_data.get("param_mappings", {}).keys())
|
||||
print(f"[Updater] Loaded {len(param_name_list)} parameter mappings", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Updater] Failed to load config: {e}", flush=True)
|
||||
return
|
||||
print(f"[Updater] No config file found (will receive all params): {e}", flush=True)
|
||||
|
||||
param_name_list = sorted(json_data.get("param_mappings", {}).keys())
|
||||
num_training_gpus = json_data.get("dp_shard_degree", 1) * json_data.get("tp_degree", 1)
|
||||
total_group_size = num_training_gpus + world_size
|
||||
# Use the world_size and rank we already calculated
|
||||
total_group_size = world_size
|
||||
|
||||
# Offset rank by training GPUs
|
||||
rank = rank + num_training_gpus
|
||||
# For single-node mode, trainer is rank 0
|
||||
if num_inference_nodes == 0:
|
||||
num_training_gpus = 1
|
||||
else:
|
||||
num_training_gpus = json_data.get("dp_shard_degree", 1) * json_data.get("tp_degree", 1)
|
||||
|
||||
print(f"[Updater] Total group size: {total_group_size}", flush=True)
|
||||
print(f"[Updater] Training GPUs: {num_training_gpus}", flush=True)
|
||||
print(f"[Updater] Training ranks: {num_training_gpus}", flush=True)
|
||||
print(f"[Updater] My rank: {rank}", flush=True)
|
||||
|
||||
# Initialize process groups
|
||||
|
|
@ -150,6 +160,8 @@ def weight_updater_process(
|
|||
|
||||
except Exception as e:
|
||||
print(f"[Updater] Failed to create process groups: {e}", flush=True)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
# Get device for tensors
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue