diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py
index 3c7942bb..39836de7 100644
--- a/example_trainer/grpo.py
+++ b/example_trainer/grpo.py
@@ -1102,6 +1102,9 @@ def train_shared_vllm(config: TrainingConfig):
     # Load model with bridge attachment
     print("[2/3] Loading model with shared weights...")
     model, tokenizer = load_model_and_tokenizer(config, bridge=bridge)
+
+    # TODO: make the optimizer choice configurable instead of hard-coding AdamW
+    optimizer = AdamW(model.parameters(), lr=config.lr)

     # For NCCL mode, build mapping between trainer's and vLLM's param names
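
The TODO in the hunk above asks for a selectable optimizer. A minimal sketch of what that could look like, assuming a hypothetical `config.optimizer` string field alongside the `config.lr` the hunk already uses:

    # Hypothetical helper: `config.optimizer` is an assumed field, not part of this PR.
    from torch.optim import SGD, AdamW, Optimizer

    def build_optimizer(model, config) -> Optimizer:
        """Pick the optimizer named in the config, defaulting to AdamW."""
        name = getattr(config, "optimizer", "adamw").lower()
        if name == "adamw":
            return AdamW(model.parameters(), lr=config.lr)
        if name == "sgd":
            return SGD(model.parameters(), lr=config.lr, momentum=0.9)
        raise ValueError(f"Unknown optimizer: {name}")
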
""" enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1" - num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", -1)) + num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", "-1")) if not enable_shared and num_inference_nodes < 0: + print("[vLLM Server] Shared weights not enabled, skipping patches") return False + print("[vLLM Server] VLLM_ENABLE_SHARED_WEIGHTS=1, applying patches...") + try: # Try relative import first (when run as module) try: from .vllm_patching import apply_patches except ImportError: # Fall back to absolute import (when run as script) - import sys - from pathlib import Path - # Add the example_trainer directory to path script_dir = Path(__file__).parent if str(script_dir) not in sys.path: sys.path.insert(0, str(script_dir)) from vllm_patching import apply_patches - apply_patches() - logger.info("✓ vLLM patches applied for shared memory weights") - return True + success = apply_patches() + if success: + print("[vLLM Server] ✓ vLLM patches applied successfully!") + else: + print("[vLLM Server] ✗ Failed to apply patches") + return success + except ImportError as e: - logger.warning(f"Could not import vllm_patching: {e}") - logger.warning("Shared memory weight updates will not be available") + print(f"[vLLM Server] Could not import vllm_patching: {e}") + print("[vLLM Server] Shared memory weight updates will not be available") return False except Exception as e: - logger.warning(f"Failed to apply patches: {e}") + print(f"[vLLM Server] Error applying patches: {e}") import traceback traceback.print_exc() return False -# Apply patches before any other vLLM imports -PATCHES_APPLIED = _maybe_apply_patches() +# Apply patches NOW, before any vLLM imports below! +PATCHES_APPLIED = _apply_patches_early() + + +# ============================================================================= +# STEP 2: Now safe to import vLLM (patches are already in place) +# ============================================================================= + +import torch # noqa: E402 +import vllm.envs as envs # noqa: E402 +from fastapi import FastAPI, HTTPException, Request # noqa: E402 +from fastapi.responses import JSONResponse, Response, StreamingResponse # noqa: E402 +from pydantic import BaseModel # noqa: E402 +from vllm.engine.arg_utils import AsyncEngineArgs # noqa: E402 +from vllm.entrypoints.launcher import serve_http # noqa: E402 +from vllm.entrypoints.utils import with_cancellation # noqa: E402 +from vllm.logger import init_logger # noqa: E402 +from vllm.sampling_params import RequestOutputKind, SamplingParams # noqa: E402 +from vllm.usage.usage_lib import UsageContext # noqa: E402 +from vllm.utils import random_uuid # noqa: E402 +from vllm.v1.engine.async_llm import AsyncLLM # noqa: E402 + +try: + from vllm.utils.argparse_utils import FlexibleArgumentParser # noqa: E402 + from vllm.utils.system_utils import set_ulimit # noqa: E402 +except ImportError: + from vllm.utils import FlexibleArgumentParser, set_ulimit # noqa: E402 + +from vllm.outputs import RequestOutput # noqa: F401, E402 +from vllm.version import __version__ as VLLM_VERSION # noqa: E402 + +logger = init_logger("vllm.entrypoints.api_server") # ============================================================================= @@ -493,6 +510,7 @@ async def bridge_debug() -> JSONResponse: "patches_applied": PATCHES_APPLIED, "shared_weights_env": os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0"), "num_inference_nodes": os.environ.get("NUM_INFERENCE_NODES", "unset"), + "logdir": 
os.environ.get("LOGDIR", "unset"), } if engine is not None: @@ -600,18 +618,28 @@ async def init_app(args: Namespace, llm_engine: AsyncLLM | None = None) -> FastA ) app.state.engine_client = engine - # Export state dict info for trainers + # Export basic state dict info for trainers (the patched runner exports detailed info) _export_state_dict_info(args) return app def _export_state_dict_info(args: Namespace) -> None: - """Export model parameter mapping to JSON for trainer.""" - log_dir = os.environ.get("LOGDIR", ".") + """Export basic model info to JSON for trainer (backup if patches don't run).""" + log_dir = os.environ.get("LOGDIR", "/tmp/atropos_bridge") + Path(log_dir).mkdir(parents=True, exist_ok=True) json_path = Path(log_dir) / "vllm_bridge_config.json" + # Only write basic info if the file doesn't exist or is empty + # The patched runner will write complete info with param_mappings try: + if json_path.exists(): + with open(json_path, 'r') as f: + existing = json.load(f) + if existing.get("param_mappings") and len(existing["param_mappings"]) > 0: + logger.info(f"Config already has param_mappings, not overwriting") + return + info = { "model": getattr(args, "model", "unknown"), "dtype": getattr(args, "dtype", "auto"), @@ -624,7 +652,7 @@ def _export_state_dict_info(args: Namespace) -> None: with open(json_path, "w") as f: json.dump(info, f, indent=2) - logger.info(f"Exported state dict info to {json_path}") + logger.info(f"Exported basic state dict info to {json_path}") except Exception as e: logger.warning(f"Failed to export state dict info: {e}") diff --git a/example_trainer/vllm_patching/__init__.py b/example_trainer/vllm_patching/__init__.py index 83545151..7a4ab38b 100644 --- a/example_trainer/vllm_patching/__init__.py +++ b/example_trainer/vllm_patching/__init__.py @@ -15,7 +15,12 @@ Usage: from vllm import AsyncLLM """ -from .patched_gpu_runner import PatchedGPUModelRunner, apply_patches +from .patched_gpu_runner import ( + PatchedGPUModelRunner, + apply_patches, + get_patched_runner, + is_patched, +) from .weight_updater import weight_updater_process from .distributed_utils import ( init_process_group, @@ -27,6 +32,8 @@ from .distributed_utils import ( __all__ = [ "PatchedGPUModelRunner", "apply_patches", + "get_patched_runner", + "is_patched", "weight_updater_process", "init_process_group", "broadcast_object_list", diff --git a/example_trainer/vllm_patching/patched_gpu_runner.py b/example_trainer/vllm_patching/patched_gpu_runner.py index 76fd5dca..e45adef4 100644 --- a/example_trainer/vllm_patching/patched_gpu_runner.py +++ b/example_trainer/vllm_patching/patched_gpu_runner.py @@ -8,16 +8,17 @@ This patches vLLM's GPUModelRunner to: The key insight is that share_memory_() makes tensors accessible from multiple processes. The daemon receives updates via NCCL and copies them directly into the shared tensors, which vLLM reads for inference. + +CRITICAL: This module must be imported and apply_patches() called BEFORE +any vLLM imports. The patches MUST happen before vLLM caches module references. 
""" from __future__ import annotations import os +import sys from typing import TYPE_CHECKING -import torch -import torch.multiprocessing as mp - # Lazy imports to avoid circular dependencies if TYPE_CHECKING: from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -25,42 +26,92 @@ if TYPE_CHECKING: # Flag to track if patches have been applied _PATCHES_APPLIED = False +_PATCHED_RUNNER_CLASS = None -def apply_patches() -> None: +def apply_patches() -> bool: """ - Apply patches to vLLM's GPUModelRunner. + Apply patches to vLLM's GPUModelRunner in ALL locations. This must be called BEFORE importing vLLM's engine classes. Safe to call multiple times (idempotent). + Returns True if patches were applied successfully. + Usage: + # CRITICAL: Import and call BEFORE any vLLM imports! + import os + os.environ["VLLM_ENABLE_SHARED_WEIGHTS"] = "1" + from example_trainer.vllm_patching import apply_patches apply_patches() - from vllm import AsyncLLM # Now uses patched runner + # Now import vLLM + from vllm import AsyncLLM # Uses patched runner """ - global _PATCHES_APPLIED + global _PATCHES_APPLIED, _PATCHED_RUNNER_CLASS if _PATCHES_APPLIED: - return + return True try: - import vllm.v1.worker.gpu_worker + # Import the source module and get original class + import vllm.v1.worker.gpu_model_runner as gpu_model_runner_module from vllm.v1.worker.gpu_model_runner import GPUModelRunner as OriginalRunner - # Create patched class + # Create the patched class PatchedRunner = _create_patched_runner(OriginalRunner) + _PATCHED_RUNNER_CLASS = PatchedRunner - # Replace in vllm module - vllm.v1.worker.gpu_worker.GPUModelRunner = PatchedRunner + # ================================================================= + # PATCH 1: Replace in source module + # ================================================================= + gpu_model_runner_module.GPUModelRunner = PatchedRunner + print("[vLLM Patch] ✓ Patched vllm.v1.worker.gpu_model_runner.GPUModelRunner") + + # ================================================================= + # PATCH 2: Replace in gpu_worker module (main usage location) + # ================================================================= + try: + import vllm.v1.worker.gpu_worker as gpu_worker_module + gpu_worker_module.GPUModelRunner = PatchedRunner + print("[vLLM Patch] ✓ Patched vllm.v1.worker.gpu_worker.GPUModelRunner") + except ImportError: + pass + + # ================================================================= + # PATCH 3: Update sys.modules entry for source module + # ================================================================= + # This ensures new imports get the patched version + if 'vllm.v1.worker.gpu_model_runner' in sys.modules: + sys.modules['vllm.v1.worker.gpu_model_runner'].GPUModelRunner = PatchedRunner + + # ================================================================= + # PATCH 4: Patch GPUWorker if already imported + # ================================================================= + try: + if 'vllm.v1.worker.gpu_worker' in sys.modules: + worker_module = sys.modules['vllm.v1.worker.gpu_worker'] + if hasattr(worker_module, 'GPUWorker'): + # Update any class-level references + worker_module.GPUModelRunner = PatchedRunner + except Exception: + pass _PATCHES_APPLIED = True print("[vLLM Patch] ✓ GPUModelRunner patched for shared memory updates") + return True except ImportError as e: print(f"[vLLM Patch] Warning: Could not apply patches: {e}") + print("[vLLM Patch] This may be due to vLLM version incompatibility") print("[vLLM Patch] Shared memory updates will not be 
available") + return False + except Exception as e: + print(f"[vLLM Patch] Error applying patches: {e}") + import traceback + traceback.print_exc() + return False def _create_patched_runner(BaseRunner: type) -> type: @@ -70,6 +121,8 @@ def _create_patched_runner(BaseRunner: type) -> type: Returns a new class that inherits from the original and adds shared memory + daemon functionality. """ + import torch + import torch.multiprocessing as mp from .weight_updater import weight_updater_process class PatchedGPUModelRunner(BaseRunner): @@ -86,25 +139,39 @@ def _create_patched_runner(BaseRunner: type) -> type: vLLM immediately sees the new weights for inference. """ + _shared_memory_setup_done = False + weight_updater_process = None + def load_model(self, *args, **kwargs) -> None: """Load model and set up shared memory + update daemon.""" + print(f"[vLLM Patch] PatchedGPUModelRunner.load_model() called!") + # Call original load_model super().load_model(*args, **kwargs) + print(f"[vLLM Patch] Model loaded, checking shared weights setup...") + # Check if shared memory updates are enabled enable_shared = os.environ.get("VLLM_ENABLE_SHARED_WEIGHTS", "0") == "1" - num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", -1)) + num_inference_nodes = int(os.environ.get("NUM_INFERENCE_NODES", "-1")) + + print(f"[vLLM Patch] VLLM_ENABLE_SHARED_WEIGHTS={enable_shared}, NUM_INFERENCE_NODES={num_inference_nodes}") if not enable_shared and num_inference_nodes < 0: print("[vLLM Patch] Shared weights disabled (set VLLM_ENABLE_SHARED_WEIGHTS=1 to enable)") return + if self._shared_memory_setup_done: + print("[vLLM Patch] Shared memory already set up, skipping") + return + print("[vLLM Patch] Setting up shared memory weight updates...") try: self._setup_shared_memory() self._spawn_weight_updater() - print("[vLLM Patch] ✓ Shared memory updates enabled") + PatchedGPUModelRunner._shared_memory_setup_done = True + print("[vLLM Patch] ✓ Shared memory updates enabled successfully!") except Exception as e: print(f"[vLLM Patch] Warning: Failed to set up shared memory: {e}") import traceback @@ -115,43 +182,73 @@ def _create_patched_runner(BaseRunner: type) -> type: import json from pathlib import Path - # Make entire model shareable - self.model.share_memory() + print("[vLLM Patch] _setup_shared_memory() starting...") - # Also share_memory_() on each parameter individually + # Get state dict state_dict = self.model.state_dict() - for key, val in state_dict.items(): - if val.is_cuda or val.device.type == 'cuda': - val.share_memory_() + print(f"[vLLM Patch] Model has {len(state_dict)} parameters") - print(f"[vLLM Patch] Shared {len(state_dict)} tensors in model") + # Make entire model shareable via share_memory_() on each tensor + shared_count = 0 + for key, val in state_dict.items(): + try: + if val.is_cuda: + val.share_memory_() + shared_count += 1 + except Exception as e: + print(f"[vLLM Patch] Warning: Could not share {key}: {e}") + + print(f"[vLLM Patch] Called share_memory_() on {shared_count} CUDA tensors") + + # Also try calling share_memory() on the model itself + try: + self.model.share_memory() + print("[vLLM Patch] Called model.share_memory()") + except Exception as e: + print(f"[vLLM Patch] Note: model.share_memory() not available: {e}") # Export parameter info to JSON for trainer - log_dir = os.environ.get("LOGDIR", ".") + log_dir = os.environ.get("LOGDIR", "/tmp/atropos_bridge") + Path(log_dir).mkdir(parents=True, exist_ok=True) json_path = Path(log_dir) / "vllm_bridge_config.json" param_mappings = 
@@ -171,9 +268,19 @@
             # Get parallel configuration
             tp_rank = get_tensor_model_parallel_rank()
-            gpu_id = torch.cuda.device(self.device).idx if hasattr(self.device, 'idx') else 0

-            print(f"[vLLM Patch] Spawning updater: tp_rank={tp_rank}, gpu={gpu_id}")
+            # Get GPU ID
+            gpu_id = 0
+            try:
+                if hasattr(self, 'device'):
+                    if hasattr(self.device, 'index'):
+                        gpu_id = self.device.index or 0
+                    elif isinstance(self.device, int):
+                        gpu_id = self.device
+            except Exception:
+                gpu_id = tp_rank
+
+            print(f"[vLLM Patch] Spawning weight updater: tp_rank={tp_rank}, gpu={gpu_id}")

             # Spawn daemon process
             ctx = mp.get_context("spawn")
@@ -191,11 +298,26 @@
             )
             self.weight_updater_process.start()

-            print(f"[vLLM Patch] Weight updater daemon started (PID: {self.weight_updater_process.pid})")
+            print(f"[vLLM Patch] ✓ Weight updater daemon started (PID: {self.weight_updater_process.pid})")
+
+    # Set proper class name
+    PatchedGPUModelRunner.__name__ = "PatchedGPUModelRunner"
+    PatchedGPUModelRunner.__qualname__ = "PatchedGPUModelRunner"

     return PatchedGPUModelRunner


+def get_patched_runner() -> type | None:
+    """Get the patched runner class if patches have been applied."""
+    return _PATCHED_RUNNER_CLASS
+
+
+def is_patched() -> bool:
+    """Check if patches have been applied."""
+    return _PATCHES_APPLIED
+
+
+# Placeholder class for type checking
 class PatchedGPUModelRunner:
     """
     Placeholder class for type checking.
@@ -204,5 +326,3 @@ class PatchedGPUModelRunner:
     to properly inherit from vLLM's GPUModelRunner.
     """
     pass
-
-
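
On the trainer side, the exported JSON file is the contract. A sketch of how a consumer might read it, using only the path and keys this diff writes (`LOGDIR`, defaulting to `/tmp/atropos_bridge`, and the `model` / `tp_degree` / `param_mappings` / `param_names` fields):

    import json
    import os
    from pathlib import Path

    log_dir = Path(os.environ.get("LOGDIR", "/tmp/atropos_bridge"))
    config = json.loads((log_dir / "vllm_bridge_config.json").read_text())

    print(config["model"], "tp_degree:", config["tp_degree"])
    for name in config["param_names"][:5]:
        mapping = config["param_mappings"][name]
        print(name, mapping["shape"], mapping["dtype"])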