mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
clearing more bloat
This commit is contained in:
parent
ab8d2f2dac
commit
036b87e921
4 changed files with 27 additions and 682 deletions
|
|
@@ -1,13 +1,15 @@
|
|||
"""
|
||||
Patched GPU Model Runner - Enables shared memory weight updates.
|
||||
Patched GPU Model Runner - Enables CUDA IPC for single-copy training.
|
||||
|
||||
This patches vLLM's GPUModelRunner to:
|
||||
1. Call share_memory_() on model weights after loading
|
||||
2. Spawn a daemon process that receives NCCL weight updates from trainers
|
||||
2. Export CUDA IPC handles to vllm_bridge_config.json
|
||||
|
||||
The key insight is that share_memory_() makes tensors accessible from
|
||||
multiple processes. The daemon receives updates via NCCL and copies them
|
||||
directly into the shared tensors, which vLLM reads for inference.
|
||||
The key insight is that CUDA IPC handles allow the trainer process to
|
||||
attach to the EXACT SAME GPU memory that vLLM uses. This means:
|
||||
- ONE copy of model weights in GPU memory
|
||||
- Trainer's optimizer.step() updates vLLM's weights directly
|
||||
- No synchronization needed - vLLM immediately sees new weights
|
||||
|
||||
CRITICAL: This module must be imported and apply_patches() called BEFORE
|
||||
any vLLM imports. The patches MUST happen before vLLM caches module references.
|
||||
|
|
@@ -119,28 +121,24 @@ def _create_patched_runner(BaseRunner: type) -> type:
|
|||
Create a patched GPUModelRunner class.
|
||||
|
||||
Returns a new class that inherits from the original and adds
|
||||
shared memory + daemon functionality.
|
||||
CUDA IPC export functionality for single-copy training.
|
||||
"""
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from .weight_updater import weight_updater_process
|
||||
|
||||
class PatchedGPUModelRunner(BaseRunner):
|
||||
"""
|
||||
Patched GPUModelRunner that enables shared memory weight updates.
|
||||
Patched GPUModelRunner that enables CUDA IPC for single-copy training.
|
||||
|
||||
After loading the model, this:
|
||||
1. Calls share_memory_() on all parameters to make them accessible
|
||||
from other processes
|
||||
2. Spawns a daemon process that joins NCCL groups with the trainer
|
||||
and receives weight updates
|
||||
|
||||
The daemon copies updates directly into the shared tensors, so
|
||||
vLLM immediately sees the new weights for inference.
|
||||
1. Calls share_memory_() on all parameters
|
||||
2. Exports CUDA IPC handles to vllm_bridge_config.json
|
||||
|
||||
The trainer reads these IPC handles and attaches to the SAME
|
||||
GPU memory, so optimizer.step() updates weights that vLLM
|
||||
immediately sees for inference.
|
||||
"""
|
||||
|
||||
_shared_memory_setup_done = False
|
||||
weight_updater_process = None
|
||||
|
||||
def load_model(self, *args, **kwargs) -> None:
|
||||
"""Load model and set up shared memory + update daemon."""
|
||||
|
|
@@ -171,27 +169,12 @@ def _create_patched_runner(BaseRunner: type) -> type:
|
|||
self._setup_shared_memory()
|
||||
PatchedGPUModelRunner._shared_memory_setup_done = True
|
||||
print("[vLLM Patch] ✓ Shared memory setup complete!", flush=True)
|
||||
print("[vLLM Patch] ✓ IPC handles exported - trainer can now attach!", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[vLLM Patch] ERROR in _setup_shared_memory: {e}", flush=True)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
# Spawn weight updater daemon (optional - can be skipped for HTTP-only mode)
|
||||
skip_daemon = os.environ.get("VLLM_SKIP_WEIGHT_DAEMON", "0") == "1"
|
||||
if skip_daemon:
|
||||
print("[vLLM Patch] Skipping weight updater daemon (VLLM_SKIP_WEIGHT_DAEMON=1)", flush=True)
|
||||
return
|
||||
|
||||
try:
|
||||
print("[vLLM Patch] Spawning weight updater daemon...", flush=True)
|
||||
self._spawn_weight_updater()
|
||||
print("[vLLM Patch] ✓ Weight updater daemon spawned!", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[vLLM Patch] ERROR spawning weight updater: {e}", flush=True)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
print("[vLLM Patch] Continuing without daemon (HTTP-only mode)", flush=True)
|
||||
|
||||
def _setup_shared_memory(self) -> None:
|
||||
"""Move model tensors to shared memory and export param info."""
|
||||
|
|
@@ -326,70 +309,6 @@ def _create_patched_runner(BaseRunner: type) -> type:
|
|||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def _spawn_weight_updater(self) -> None:
|
||||
"""Start the weight updater as a background thread.
|
||||
|
||||
Note: We use threading instead of multiprocessing because vLLM's
|
||||
worker processes are daemons, and daemons cannot spawn child processes.
|
||||
"""
|
||||
import threading
|
||||
|
||||
print("[vLLM Patch] _spawn_weight_updater() called", flush=True)
|
||||
|
||||
try:
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
print("[vLLM Patch] Imported get_tensor_model_parallel_rank", flush=True)
|
||||
except ImportError as e:
|
||||
print(f"[vLLM Patch] Could not import get_tensor_model_parallel_rank: {e}", flush=True)
|
||||
get_tensor_model_parallel_rank = lambda: 0
|
||||
|
||||
# Get model configuration
|
||||
state_dict = self.model.state_dict()
|
||||
print(f"[vLLM Patch] Got state_dict with {len(state_dict)} params", flush=True)
|
||||
|
||||
# Get attention head counts
|
||||
hf_config = self.model_config.hf_text_config
|
||||
num_heads = getattr(hf_config, "num_attention_heads", 0)
|
||||
num_kv_heads = self.model_config.get_total_num_kv_heads()
|
||||
print(f"[vLLM Patch] num_heads={num_heads}, num_kv_heads={num_kv_heads}", flush=True)
|
||||
|
||||
# Get parallel configuration
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
print(f"[vLLM Patch] tp_rank={tp_rank}", flush=True)
|
||||
|
||||
# Get GPU ID
|
||||
gpu_id = 0
|
||||
try:
|
||||
if hasattr(self, 'device'):
|
||||
if hasattr(self.device, 'index'):
|
||||
gpu_id = self.device.index or 0
|
||||
elif isinstance(self.device, int):
|
||||
gpu_id = self.device
|
||||
except Exception:
|
||||
gpu_id = tp_rank
|
||||
|
||||
print(f"[vLLM Patch] Starting weight updater thread: tp_rank={tp_rank}, gpu={gpu_id}", flush=True)
|
||||
|
||||
# Start as a daemon thread (threads CAN be started from daemon processes)
|
||||
self.weight_updater_thread = threading.Thread(
|
||||
target=weight_updater_process,
|
||||
args=(
|
||||
state_dict,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
tp_rank,
|
||||
self.parallel_config.tensor_parallel_size,
|
||||
gpu_id,
|
||||
),
|
||||
daemon=True,
|
||||
name=f"WeightUpdater_TP{tp_rank}",
|
||||
)
|
||||
|
||||
print("[vLLM Patch] Starting thread...", flush=True)
|
||||
self.weight_updater_thread.start()
|
||||
|
||||
print(f"[vLLM Patch] ✓ Weight updater thread started (name: {self.weight_updater_thread.name})", flush=True)
|
||||
|
||||
# Set proper class name
|
||||
PatchedGPUModelRunner.__name__ = "PatchedGPUModelRunner"
|
||||
PatchedGPUModelRunner.__qualname__ = "PatchedGPUModelRunner"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue