Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00

[pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci

This commit is contained in parent 4740dfa216, commit fe2fd3d824.
5 changed files with 510 additions and 337 deletions
@@ -34,76 +34,79 @@ _PATCHED_RUNNER_CLASS = None
 def apply_patches() -> bool:
     """
     Apply patches to vLLM's GPUModelRunner in ALL locations.
 
     This must be called BEFORE importing vLLM's engine classes.
     Safe to call multiple times (idempotent).
 
     Returns True if patches were applied successfully.
 
     Usage:
         # CRITICAL: Import and call BEFORE any vLLM imports!
         import os
         os.environ["VLLM_ENABLE_SHARED_WEIGHTS"] = "1"
 
         from example_trainer.vllm_patching import apply_patches
         apply_patches()
 
         # Now import vLLM
         from vllm import AsyncLLM  # Uses patched runner
     """
     global _PATCHES_APPLIED, _PATCHED_RUNNER_CLASS
 
     if _PATCHES_APPLIED:
         return True
 
     try:
         # Import the source module and get original class
         import vllm.v1.worker.gpu_model_runner as gpu_model_runner_module
         from vllm.v1.worker.gpu_model_runner import GPUModelRunner as OriginalRunner
 
         # Create the patched class
         PatchedRunner = _create_patched_runner(OriginalRunner)
         _PATCHED_RUNNER_CLASS = PatchedRunner
 
         # =================================================================
         # PATCH 1: Replace in source module
         # =================================================================
         gpu_model_runner_module.GPUModelRunner = PatchedRunner
         print("[vLLM Patch] ✓ Patched vllm.v1.worker.gpu_model_runner.GPUModelRunner")
 
         # =================================================================
         # PATCH 2: Replace in gpu_worker module (main usage location)
         # =================================================================
         try:
             import vllm.v1.worker.gpu_worker as gpu_worker_module
 
             gpu_worker_module.GPUModelRunner = PatchedRunner
             print("[vLLM Patch] ✓ Patched vllm.v1.worker.gpu_worker.GPUModelRunner")
         except ImportError:
             pass
 
         # =================================================================
         # PATCH 3: Update sys.modules entry for source module
         # =================================================================
         # This ensures new imports get the patched version
-        if 'vllm.v1.worker.gpu_model_runner' in sys.modules:
-            sys.modules['vllm.v1.worker.gpu_model_runner'].GPUModelRunner = PatchedRunner
+        if "vllm.v1.worker.gpu_model_runner" in sys.modules:
+            sys.modules["vllm.v1.worker.gpu_model_runner"].GPUModelRunner = (
+                PatchedRunner
+            )
 
         # =================================================================
         # PATCH 4: Patch GPUWorker if already imported
         # =================================================================
         try:
-            if 'vllm.v1.worker.gpu_worker' in sys.modules:
-                worker_module = sys.modules['vllm.v1.worker.gpu_worker']
-                if hasattr(worker_module, 'GPUWorker'):
+            if "vllm.v1.worker.gpu_worker" in sys.modules:
+                worker_module = sys.modules["vllm.v1.worker.gpu_worker"]
+                if hasattr(worker_module, "GPUWorker"):
                     # Update any class-level references
                     worker_module.GPUModelRunner = PatchedRunner
         except Exception:
             pass
 
         _PATCHES_APPLIED = True
         print("[vLLM Patch] ✓ GPUModelRunner patched for shared memory updates")
         return True
 
     except ImportError as e:
         print(f"[vLLM Patch] Warning: Could not apply patches: {e}")
         print("[vLLM Patch] This may be due to vLLM version incompatibility")
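Note on the four patch sites: `from ... import GPUModelRunner` copies a reference into the importing module's namespace, so rebinding only the defining module would leave earlier importers pointing at the original class; every known alias has to be reassigned. A minimal check that all sites ended up on the same patched class (hypothetical helper, not part of this commit):

    import sys

    def verify_patches() -> bool:
        # All patch sites must hold the same patched class object.
        import vllm.v1.worker.gpu_model_runner as src

        runner = src.GPUModelRunner
        ok = runner.__name__ == "PatchedGPUModelRunner"
        worker_mod = sys.modules.get("vllm.v1.worker.gpu_worker")
        if worker_mod is not None:  # may not be imported yet
            ok = ok and worker_mod.GPUModelRunner is runner
        return ok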
@@ -112,6 +115,7 @@ def apply_patches() -> bool:
     except Exception as e:
         print(f"[vLLM Patch] Error applying patches: {e}")
         import traceback
+
         traceback.print_exc()
         return False
@@ -119,74 +123,82 @@ def apply_patches() -> bool:
 def _create_patched_runner(BaseRunner: type) -> type:
     """
     Create a patched GPUModelRunner class.
 
     Returns a new class that inherits from the original and adds
     CUDA IPC export functionality for single-copy training.
     """
     import torch
 
     class PatchedGPUModelRunner(BaseRunner):
         """
         Patched GPUModelRunner that enables CUDA IPC for single-copy training.
 
         After loading the model, this:
         1. Calls share_memory_() on all parameters
         2. Exports CUDA IPC handles to vllm_bridge_config.json
 
         The trainer reads these IPC handles and attaches to the SAME
         GPU memory, so optimizer.step() updates weights that vLLM
         immediately sees for inference.
         """
 
         _shared_memory_setup_done = False
 
         def load_model(self, *args, **kwargs) -> None:
             """Load model and set up shared memory + update daemon."""
             print(f"[vLLM Patch] PatchedGPUModelRunner.load_model() called!")
 
             # Call original load_model
             super().load_model(*args, **kwargs)
 
             print(f"[vLLM Patch] Model loaded, checking shared weights setup...")
 
             # Check if shared memory updates are enabled
             enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1"
             num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", "-1"))
 
-            print(f"[vLLM Patch] VLLM_ENABLE_SHARED_WEIGHTS={enable_shared}, NUM_INFERENCE_NODES={num_inference_nodes}")
+            print(
+                f"[vLLM Patch] VLLM_ENABLE_SHARED_WEIGHTS={enable_shared}, NUM_INFERENCE_NODES={num_inference_nodes}"
+            )
 
             if not enable_shared and num_inference_nodes < 0:
-                print("[vLLM Patch] Shared weights disabled (set VLLM_ENABLE_SHARED_WEIGHTS=1 to enable)")
+                print(
+                    "[vLLM Patch] Shared weights disabled (set VLLM_ENABLE_SHARED_WEIGHTS=1 to enable)"
+                )
                 return
 
             if self._shared_memory_setup_done:
                 print("[vLLM Patch] Shared memory already set up, skipping")
                 return
 
             print("[vLLM Patch] Setting up shared memory weight updates...", flush=True)
 
             try:
                 self._setup_shared_memory()
                 PatchedGPUModelRunner._shared_memory_setup_done = True
                 print("[vLLM Patch] ✓ Shared memory setup complete!", flush=True)
-                print("[vLLM Patch] ✓ IPC handles exported - trainer can now attach!", flush=True)
+                print(
+                    "[vLLM Patch] ✓ IPC handles exported - trainer can now attach!",
+                    flush=True,
+                )
             except Exception as e:
                 print(f"[vLLM Patch] ERROR in _setup_shared_memory: {e}", flush=True)
                 import traceback
+
                 traceback.print_exc()
                 return
 
         def _setup_shared_memory(self) -> None:
             """Move model tensors to shared memory and export param info."""
             import json
             from pathlib import Path
 
             print("[vLLM Patch] _setup_shared_memory() starting...")
 
             # Get state dict
             state_dict = self.model.state_dict()
             print(f"[vLLM Patch] Model has {len(state_dict)} parameters")
 
             # Make entire model shareable via share_memory_() on each tensor
             shared_count = 0
             for key, val in state_dict.items():
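The class docstring's single-copy claim rests on plain storage aliasing: once the trainer and vLLM hold tensors over one CUDA allocation, an in-place write through either alias is immediately visible through the other, with no weight transfer. A single-process illustration of that property (the cross-process version goes through the IPC handles exported below):

    import torch

    # Two tensors over the same CUDA storage: writes through one alias
    # are visible through the other without any copy.
    a = torch.zeros(4, device="cuda")
    b = a.view(4)  # aliases a's storage
    a.copy_(torch.arange(4.0, device="cuda"))  # stand-in for optimizer.step()
    assert torch.equal(b, torch.arange(4.0, device="cuda"))  # "inference" view sees it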
@@ -196,25 +208,25 @@ def _create_patched_runner(BaseRunner: type) -> type:
                     shared_count += 1
                 except Exception as e:
                     print(f"[vLLM Patch] Warning: Could not share {key}: {e}")
 
             print(f"[vLLM Patch] Called share_memory_() on {shared_count} CUDA tensors")
 
             # Also try calling share_memory() on the model itself
             try:
                 self.model.share_memory()
                 print("[vLLM Patch] Called model.share_memory()")
             except Exception as e:
                 print(f"[vLLM Patch] Note: model.share_memory() not available: {e}")
 
             # Export parameter info to JSON for trainer
             log_dir = os.environ.get("LOGDIR", ".")
             Path(log_dir).mkdir(parents=True, exist_ok=True)
             json_path = Path(log_dir) / "vllm_bridge_config.json"
 
             param_mappings = {}
             param_names = []
             ipc_handles = {}
 
             for name, tensor in state_dict.items():
                 param_mappings[name] = {
                     "vllm_name": name,
@@ -223,14 +235,15 @@ def _create_patched_runner(BaseRunner: type) -> type:
                     "device": str(tensor.device),
                 }
                 param_names.append(name)
 
                 # Export CUDA IPC handles for true single-copy mode
                 if tensor.is_cuda:
                     try:
                         import base64
+
                         storage = tensor.untyped_storage()
                         share_data = storage._share_cuda_()
 
                         # share_data is a tuple of 8 items - we need ALL of them:
                         # [0] = device index (int)
                         # [1] = cudaIpcMemHandle_t (bytes)
@@ -240,15 +253,21 @@ def _create_patched_runner(BaseRunner: type) -> type:
                         # [5] = ref counter offset (int)
                         # [6] = event handle (bytes)
                         # [7] = event sync required (bool)
 
                         ipc_handles[name] = {
                             "device_index": share_data[0],
-                            "ipc_handle_b64": base64.b64encode(share_data[1]).decode('ascii'),
+                            "ipc_handle_b64": base64.b64encode(share_data[1]).decode(
+                                "ascii"
+                            ),
                             "storage_size": share_data[2],
                             "storage_offset_orig": share_data[3],
-                            "ref_counter_handle_b64": base64.b64encode(share_data[4]).decode('ascii'),
+                            "ref_counter_handle_b64": base64.b64encode(
+                                share_data[4]
+                            ).decode("ascii"),
                             "ref_counter_offset": share_data[5],
-                            "event_handle_b64": base64.b64encode(share_data[6]).decode('ascii'),
+                            "event_handle_b64": base64.b64encode(share_data[6]).decode(
+                                "ascii"
+                            ),
                             "event_sync_required": share_data[7],
                             # Tensor metadata for reconstruction
                             "tensor_storage_offset": tensor.storage_offset(),
@@ -257,12 +276,19 @@ def _create_patched_runner(BaseRunner: type) -> type:
                             "dtype": str(tensor.dtype),
                         }
                     except Exception as e:
-                        print(f"[vLLM Patch] Could not get IPC handle for {name}: {e}", flush=True)
+                        print(
+                            f"[vLLM Patch] Could not get IPC handle for {name}: {e}",
+                            flush=True,
+                        )
                         import traceback
+
                         traceback.print_exc()
 
-            print(f"[vLLM Patch] Exported {len(ipc_handles)} IPC handles for single-copy mode", flush=True)
+            print(
+                f"[vLLM Patch] Exported {len(ipc_handles)} IPC handles for single-copy mode",
+                flush=True,
+            )
 
             # Get model info
             model_name = "unknown"
             tp_degree = 1
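The 8-tuple documented above is the same payload PyTorch's multiprocessing reducer ships between processes, so the trainer side can plausibly rebuild an aliased tensor with `torch.multiprocessing.reductions.rebuild_cuda_tensor`. A sketch under assumptions: the JSON keys match the dict written above, the elided metadata also records geometry under hypothetical "shape" and "stride" keys, and this private API keeps its current signature (it can shift across PyTorch versions):

    import base64

    import torch
    from torch.multiprocessing.reductions import rebuild_cuda_tensor

    def attach_param(entry: dict) -> torch.Tensor:
        # Rebuild a tensor aliasing vLLM's GPU allocation from one
        # exported JSON entry; "shape"/"stride" are hypothetical keys.
        dtype = getattr(torch, entry["dtype"].split(".")[-1])
        return rebuild_cuda_tensor(
            torch.Tensor,
            tensor_size=torch.Size(entry["shape"]),
            tensor_stride=tuple(entry["stride"]),
            tensor_offset=entry["tensor_storage_offset"],
            storage_cls=torch.UntypedStorage,
            dtype=dtype,
            storage_device=entry["device_index"],
            storage_handle=base64.b64decode(entry["ipc_handle_b64"]),
            storage_size_bytes=entry["storage_size"],
            storage_offset_bytes=entry["storage_offset_orig"],
            requires_grad=False,
            ref_counter_handle=base64.b64decode(entry["ref_counter_handle_b64"]),
            ref_counter_offset=entry["ref_counter_offset"],
            event_handle=base64.b64decode(entry["event_handle_b64"]),
            event_sync_required=entry["event_sync_required"],
        )

Once attached, `attach_param(entry).copy_(new_weight)` writes directly into the allocation vLLM serves from, which is the in-place visibility the class docstring describes.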
@@ -273,21 +299,23 @@ def _create_patched_runner(BaseRunner: type) -> type:
                 print(f"[vLLM Patch] Warning: Could not get model config: {e}")
 
             import base64
 
             # Convert bytes to base64 for JSON serialization
             def serialize_ipc_handles(handles):
                 result = {}
                 for k, v in handles.items():
                     if isinstance(v, bytes):
-                        result[k] = {"_bytes_b64_": base64.b64encode(v).decode('ascii')}
+                        result[k] = {"_bytes_b64_": base64.b64encode(v).decode("ascii")}
                     elif isinstance(v, dict):
                         result[k] = serialize_ipc_handles(v)
                     else:
                         result[k] = v
                 return result
 
-            serialized_ipc_handles = serialize_ipc_handles(ipc_handles) if ipc_handles else {}
+            serialized_ipc_handles = (
+                serialize_ipc_handles(ipc_handles) if ipc_handles else {}
+            )
 
             info = {
                 "model": model_name,
                 "tp_degree": tp_degree,
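Whatever reads vllm_bridge_config.json needs the inverse transform; a minimal counterpart to serialize_ipc_handles (hypothetical helper, not part of this commit):

    import base64

    def deserialize_ipc_handles(obj):
        # Undo the {"_bytes_b64_": ...} wrapping applied during export.
        if isinstance(obj, dict):
            if set(obj) == {"_bytes_b64_"}:
                return base64.b64decode(obj["_bytes_b64_"])
            return {k: deserialize_ipc_handles(v) for k, v in obj.items()}
        return obj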
@@ -299,20 +327,23 @@ def _create_patched_runner(BaseRunner: type) -> type:
                 "num_params": len(param_names),
                 "single_copy_enabled": len(ipc_handles) > 0,
             }
 
             try:
                 with open(json_path, "w") as f:
                     json.dump(info, f, indent=2)
-                print(f"[vLLM Patch] ✓ Exported {len(param_mappings)} params to {json_path}")
+                print(
+                    f"[vLLM Patch] ✓ Exported {len(param_mappings)} params to {json_path}"
+                )
             except Exception as e:
                 print(f"[vLLM Patch] ERROR: Failed to export params: {e}")
                 import traceback
+
                 traceback.print_exc()
 
 
     # Set proper class name
     PatchedGPUModelRunner.__name__ = "PatchedGPUModelRunner"
     PatchedGPUModelRunner.__qualname__ = "PatchedGPUModelRunner"
 
     return PatchedGPUModelRunner
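The JSON file is the only rendezvous between the two processes, so a trainer has to wait for load_model() to finish exporting before attaching. A sketch following the LOGDIR convention above (the polling interval and timeout are arbitrary choices):

    import json
    import os
    import time
    from pathlib import Path

    def wait_for_bridge_config(timeout_s: float = 600.0) -> dict:
        # Block until vLLM has written the bridge config, then load it.
        path = Path(os.environ.get("LOGDIR", ".")) / "vllm_bridge_config.json"
        deadline = time.monotonic() + timeout_s
        while time.monotonic() < deadline:
            if path.exists():
                with open(path) as f:
                    return json.load(f)
            time.sleep(1.0)
        raise TimeoutError(f"{path} not written within {timeout_s}s")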
@@ -330,8 +361,9 @@ def is_patched() -> bool:
 class PatchedGPUModelRunner:
     """
     Placeholder class for type checking.
 
     The actual patched class is created dynamically by _create_patched_runner()
     to properly inherit from vLLM's GPUModelRunner.
     """
+
     pass