diff --git a/example_trainer/README.md b/example_trainer/README.md index da2c75fc..409c3c7b 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -156,7 +156,17 @@ python example_trainer/grpo.py \ --wandb-project gsm8k-grpo-shared ``` -### What Happens +### What Happens (Local Mode - num_inference_nodes=0) + +1. vLLM server starts on port 9001 +2. Trainer initializes bridge in LOCAL MODE (HTTP-based, no NCCL) +3. Trainer loads its own model copy and trains normally +4. After each `optimizer.step()`: + - `bridge.notify_update()` sends HTTP POST to vLLM + - Periodic checkpoint saves sync weights to disk +5. Much simpler than distributed mode! + +### What Happens (Distributed Mode - num_inference_nodes>0) 1. vLLM server starts, writes parameter mapping to `$LOGDIR/vllm_bridge_config.json` 2. Trainer reads mapping, joins NCCL process group with vLLM @@ -164,7 +174,7 @@ python example_trainer/grpo.py \ 4. Training loop: - Forward pass uses shared weights - `optimizer.step()` modifies shared tensors in-place - - `bridge.notify_update()` signals vLLM (optional coordination) + - `bridge.notify_update()` broadcasts via Gloo - vLLM immediately uses new weights for next inference 5. No restarts needed! diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index ef670610..f9d24608 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -408,7 +408,7 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module: if not PEFT_AVAILABLE: raise RuntimeError( "PEFT library not available. Install with: pip install peft" - ) + ) print("[Setup] Loading base model for LoRA mode...") base_model = AutoModelForCausalLM.from_pretrained( @@ -597,15 +597,15 @@ def run_training_step( total_neg += metrics["neg_count"] # Gradient clipping and optimizer step - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) - optimizer.step() - optimizer.zero_grad() + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + optimizer.zero_grad() # Normalize metrics - if total_pos > 0: - total_pos_logp /= total_pos - if total_neg > 0: - total_neg_logp /= total_neg + if total_pos > 0: + total_pos_logp /= total_pos + if total_neg > 0: + total_neg_logp /= total_neg return { "loss": total_loss, diff --git a/example_trainer/vllm_weight_bridge.py b/example_trainer/vllm_weight_bridge.py index 4a9945ea..c563c03a 100644 --- a/example_trainer/vllm_weight_bridge.py +++ b/example_trainer/vllm_weight_bridge.py @@ -1,10 +1,26 @@ """ -vLLM Weight Bridge - Shared memory integration between trainer and vLLM inference. +vLLM Weight Bridge - Integration between trainer and vLLM inference. -This module enables the trainer to directly modify vLLM's model weights in shared -GPU memory, eliminating the need for checkpoint saves and vLLM restarts. +This module provides two modes for coordinating weight updates: -Architecture: +LOCAL MODE (num_inference_nodes=0): + - Trainer and vLLM run as separate processes on the same machine + - Communication via HTTP to vLLM's /bridge/* endpoints + - No NCCL process groups needed + - Simpler setup, suitable for single-machine training + +DISTRIBUTED MODE (num_inference_nodes>0): + - Trainer and vLLM join the same NCCL process group + - Direct tensor sharing via shared GPU memory + - Lower latency, but requires coordinated setup + +Architecture (Local Mode): + ┌─────────────────┐ ┌─────────────────┐ + │ Trainer Process │ HTTP │ vLLM Process │ + │ (training) │────────▶│ (inference) │ + └─────────────────┘ └─────────────────┘ + +Architecture (Distributed Mode): ┌─────────────────────────────────────────┐ │ Shared GPU Memory (NCCL) │ │ Model weights owned by vLLM process │ @@ -13,14 +29,7 @@ Architecture: │ forward pass │ optimizer.step() ┌───────┴───────┐ ┌───────┴───────┐ │ vLLM Process │ │Trainer Process│ - │ (inference) │ │ (training) │ └───────────────┘ └───────────────┘ - -Key concepts: - 1. Process groups: Trainer joins the same NCCL group as vLLM workers - 2. Tensor attachment: Trainer's model params point to vLLM's actual buffers - 3. In-place updates: optimizer.step() modifies shared memory directly - 4. Synchronization: Barriers ensure no read-during-write races """ from __future__ import annotations @@ -303,9 +312,20 @@ class BridgeConfig: timeout_seconds: float = 300.0 log_dir: Optional[str] = None + # vLLM server URL for HTTP-based sync (local mode) + vllm_api_url: str = "http://localhost:9001" + # Derived from environment num_gpus_per_node: int = field(default_factory=lambda: torch.cuda.device_count()) + @property + def is_local_mode(self) -> bool: + """ + Local mode: single machine, no NCCL process groups needed. + Communication happens via HTTP to vLLM server. + """ + return self.num_inference_nodes == 0 + @classmethod def from_training_config(cls, config: Any) -> "BridgeConfig": """Create BridgeConfig from a TrainingConfig object.""" @@ -317,6 +337,7 @@ class BridgeConfig: model_name=config.model_name, device=config.device, log_dir=os.environ.get("LOGDIR"), + vllm_api_url=f"http://localhost:{getattr(config, 'vllm_port', 9001)}", ) @@ -374,6 +395,9 @@ class VLLMWeightBridge: """ Initialize the bridge: join process groups and load parameter mappings. + In local mode (num_inference_nodes=0), skips NCCL setup and uses HTTP. + In distributed mode, creates NCCL/Gloo process groups. + This must be called before any other methods. """ if self._initialized: @@ -381,6 +405,52 @@ class VLLMWeightBridge: print(f"[Bridge] Initializing weight bridge for rank {self.config.trainer_rank}") + if self.config.is_local_mode: + self._initialize_local_mode() + else: + self._initialize_distributed_mode() + + self._initialized = True + + def _initialize_local_mode(self) -> None: + """ + Initialize for local single-machine mode. + + In local mode: + - No NCCL process groups (trainer and vLLM are separate processes) + - Communication via HTTP to vLLM's bridge endpoints + - Trainer loads its own model copy, updates are synced via checkpoints + """ + print("[Bridge] Using LOCAL MODE (HTTP-based sync, no NCCL)") + print(f"[Bridge] vLLM API URL: {self.config.vllm_api_url}") + + # Verify vLLM server is reachable + try: + import requests + response = requests.get(f"{self.config.vllm_api_url}/health", timeout=5) + if response.status_code == 200: + print("[Bridge] vLLM server is reachable") + else: + print(f"[Bridge] Warning: vLLM health check returned {response.status_code}") + except Exception as e: + print(f"[Bridge] Warning: Could not reach vLLM server: {e}") + print("[Bridge] Training will continue, but vLLM sync may not work") + + # Load parameter mappings if available (optional in local mode) + try: + self._load_param_mappings() + except RuntimeError: + print("[Bridge] Parameter mapping file not found (optional in local mode)") + self.param_mappings = {} + + def _initialize_distributed_mode(self) -> None: + """ + Initialize for distributed multi-node mode. + + Creates NCCL and Gloo process groups for direct tensor sharing. + """ + print("[Bridge] Using DISTRIBUTED MODE (NCCL tensor sharing)") + # Get rendezvous URLs master_addr, master_gloo_addr, master_inference_addr, nodelist = get_inference_urls( self.config.num_inference_nodes @@ -400,9 +470,8 @@ class VLLMWeightBridge: # Calculate total group size (trainers + inference workers) num_training_gpus = self._get_num_training_gpus() + # In distributed mode, each inference node contributes num_gpus_per_node workers num_inference_gpus = self.config.num_inference_nodes * self.config.num_gpus_per_node - if self.config.num_inference_nodes == 0: - num_inference_gpus = self.config.num_gpus_per_node # Local mode total_group_size = num_training_gpus + num_inference_gpus trainer_rank_in_group = self.config.trainer_rank @@ -432,8 +501,6 @@ class VLLMWeightBridge: ) print("[Bridge] Gloo process group initialized") - self._initialized = True - def _load_param_mappings(self) -> None: """Load parameter name mappings from vLLM's exported JSON.""" log_dir = self.config.log_dir or os.environ.get("LOGDIR", ".") @@ -630,14 +697,37 @@ class VLLMWeightBridge: """ Notify inference workers that weights have been updated. - This is a lightweight synchronization point. Inference workers can - check for this signal before starting a new batch to ensure they - have the latest weights. + In local mode: sends HTTP request to vLLM's /bridge/notify_update endpoint + In distributed mode: broadcasts update counter via Gloo """ - if self.gloo_group is None: - return + self._update_count += 1 - # Simple approach: broadcast the update counter + if self.config.is_local_mode: + self._notify_update_http() + elif self.gloo_group is not None: + self._notify_update_distributed() + + def _notify_update_http(self) -> None: + """Notify vLLM via HTTP (local mode).""" + try: + import requests + response = requests.post( + f"{self.config.vllm_api_url}/bridge/notify_update", + json={ + "update_count": self._update_count, + "trainer_rank": self.config.trainer_rank, + "timestamp": time.time(), + }, + timeout=5, + ) + if response.status_code != 200: + print(f"[Bridge] Warning: notify_update returned {response.status_code}") + except Exception as e: + # Don't fail training if vLLM notification fails + print(f"[Bridge] Warning: Could not notify vLLM: {e}") + + def _notify_update_distributed(self) -> None: + """Notify via Gloo broadcast (distributed mode).""" update_tensor = torch.tensor([self._update_count], dtype=torch.long) dist.broadcast(update_tensor, src=0, group=self.gloo_group) @@ -657,7 +747,7 @@ class VLLMWeightBridge: self.gloo_group = None self._initialized = False - print("[Bridge] Cleaned up process groups") + print("[Bridge] Cleaned up") # =============================================================================