diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py
index f00aa5a6..87823526 100644
--- a/environments/gsm8k_server.py
+++ b/environments/gsm8k_server.py
@@ -11,8 +11,8 @@ from atroposlib.envs.base import (
     APIServerConfig,
     BaseEnv,
     BaseEnvConfig,
-    ServerBaseline,
     ScoredDataGroup,
+    ServerBaseline,
 )
 from atroposlib.type_definitions import Item
 
diff --git a/example_trainer/README.md b/example_trainer/README.md
index 7286e96f..29820d66 100644
--- a/example_trainer/README.md
+++ b/example_trainer/README.md
@@ -200,7 +200,7 @@ python -m example_trainer.grpo --weight-bridge-mode lora_only ...
 
 ---
 
-## Shared vLLM Mode 
+## Shared vLLM Mode
 
 Single-copy mode shares GPU memory between vLLM and the trainer - zero model duplication!
 
diff --git a/example_trainer/config.py b/example_trainer/config.py
index 3ce86ff0..4ddeddb5 100644
--- a/example_trainer/config.py
+++ b/example_trainer/config.py
@@ -114,15 +114,17 @@ class TrainingConfig(BaseModel):
     wandb_group: Optional[str] = Field(None, description="Wandb group name")
 
     # === Training Mode Configuration ===
-    weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = Field(
-        "none",
-        description=(
-            "How to synchronize weights with inference server. "
-            "'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
-            "'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). "
-            "'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). "
-            "'none': legacy mode, restart vLLM with new checkpoint files."
-        ),
+    weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = (
+        Field(
+            "none",
+            description=(
+                "How to synchronize weights with inference server. "
+                "'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
+                "'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). "
+                "'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). "
+                "'none': legacy mode, restart vLLM with new checkpoint files."
+            ),
+        )
     )
     train_layer_indices: Optional[List[int]] = Field(
         None,
diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py
index e099ea88..4c9e2893 100644
--- a/example_trainer/trainers.py
+++ b/example_trainer/trainers.py
@@ -8,11 +8,11 @@ Contains the four main training modes:
 - train_lora_restart: LoRA training with vLLM restarts (FAST mode)
 """
 
+import logging
 import os
 import subprocess
 import sys
 import time
-import logging
 from typing import Iterable, Optional
 
 import requests
@@ -21,7 +21,6 @@ from torch.optim import AdamW
 
 from .api import check_atropos_api, register_trainer
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -56,9 +55,7 @@ def create_optimizer_for_params(
         logger.info("[Setup] Using 8-bit AdamW optimizer")
         return optimizer
     except ImportError:
-        logger.warning(
-            "[Setup] bitsandbytes not installed, falling back to AdamW"
-        )
+        logger.warning("[Setup] bitsandbytes not installed, falling back to AdamW")
         logger.info("[Setup] Install with: pip install bitsandbytes")
 
     if config.optimizer == "adafactor":
@@ -740,19 +737,23 @@ def train_lora_restart(config: TrainingConfig):
         # Periodic adapter save + vLLM restart
         sync_time = 0
         should_sync = (step + 1) % config.vllm_restart_interval == 0
-        if should_sync and (step + 1) < config.training_steps:  # Don't restart on last step
+        if (
+            should_sync and (step + 1) < config.training_steps
+        ):  # Don't restart on last step
             sync_start = time.time()
-
+
             # Save new adapter
-            current_adapter_path = save_lora_checkpoint(model, config.save_path, step + 1)
-
+            current_adapter_path = save_lora_checkpoint(
+                model, config.save_path, step + 1
+            )
+
             # Restart vLLM with new adapter
             print(" [RESTART] Restarting vLLM with new adapter...")
             _terminate_vllm(vllm_proc, config.vllm_port)
             vllm_proc = _launch_vllm_with_lora(config, current_adapter_path)
             if vllm_proc is None:
                 raise RuntimeError("Failed to restart vLLM")
-
+
             sync_time = time.time() - sync_start
             benchmark_stats["sync_times"].append(sync_time)
             benchmark_stats["restart_times"].append(sync_time)
@@ -803,45 +804,53 @@ def train_lora_restart(config: TrainingConfig):
 _vllm_restart_counter = 0
 
 
-def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optional[subprocess.Popen]:
+def _launch_vllm_with_lora(
+    config: TrainingConfig, adapter_path: str
+) -> Optional[subprocess.Popen]:
     """
     Launch vLLM with a LoRA adapter (no --enforce-eager for faster inference).
-
+
     Unlike lora_only mode, this does NOT use --enforce-eager, so we get
     ~108 TPS instead of ~13 TPS (8x faster).
""" global _vllm_restart_counter from .vllm_manager import kill_process_on_port, wait_for_vllm_ready - + # Kill any existing process on the port print(f" Cleaning up port {config.vllm_port}...") kill_process_on_port(config.vllm_port) - + # Clear CUDA cache before starting new vLLM if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() - + # Wait for port and GPU memory to be fully released time.sleep(5) - + # Find the vllm_api_server.py script script_dir = os.path.dirname(os.path.abspath(__file__)) server_script = os.path.join(script_dir, "vllm_api_server.py") - + # Build command - NO --enforce-eager for faster inference (~108 TPS vs ~13 TPS) cmd = [ - sys.executable, server_script, - "--model", config.model_name, - "--port", str(config.vllm_port), - "--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization), - "--max-model-len", str(config.max_model_len), + sys.executable, + server_script, + "--model", + config.model_name, + "--port", + str(config.vllm_port), + "--gpu-memory-utilization", + str(config.vllm_gpu_memory_utilization), + "--max-model-len", + str(config.max_model_len), "--enable-lora", - "--max-lora-rank", str(max(config.lora_r * 2, 32)), + "--max-lora-rank", + str(max(config.lora_r * 2, 32)), # Note: NOT adding --enforce-eager - this gives us ~8x faster inference! # Without --enforce-eager, vLLM can use more optimizations. ] - + # Set environment for GPU selection env = os.environ.copy() if config.vllm_gpu is not None: @@ -849,32 +858,39 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona print(f" GPU: {config.vllm_gpu} (via CUDA_VISIBLE_DEVICES)") else: print(" GPU: Same as trainer (inherited CUDA_VISIBLE_DEVICES)") - + print(f" Launching: {' '.join(cmd)}") print(f" Adapter: {adapter_path}") - + # Log vLLM output to file for debugging (unique file per restart) - vllm_log_path = os.path.join(config.save_path, f"vllm_restart_{_vllm_restart_counter}.log") + vllm_log_path = os.path.join( + config.save_path, f"vllm_restart_{_vllm_restart_counter}.log" + ) _vllm_restart_counter += 1 print(f" vLLM log: {vllm_log_path}") - + try: vllm_log_file = open(vllm_log_path, "w") # Start in new session so we can kill entire process group later proc = subprocess.Popen( - cmd, env=env, stdout=vllm_log_file, stderr=subprocess.STDOUT, - start_new_session=True # Creates new process group for easy cleanup + cmd, + env=env, + stdout=vllm_log_file, + stderr=subprocess.STDOUT, + start_new_session=True, # Creates new process group for easy cleanup ) print(f" vLLM PID: {proc.pid} (process group: {os.getpgid(proc.pid)})") - print(" NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)...") - + print( + " NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)..." 
+        )
+
         # Wait for server to be ready (longer timeout for CUDA graph compilation)
         if not wait_for_vllm_ready(config.vllm_port, timeout=300):
             print(" ERROR: vLLM failed to start after 300s")
             print(f" Check log: {vllm_log_path}")
             # Print last 30 lines of the log
             try:
-                with open(vllm_log_path, 'r') as f:
+                with open(vllm_log_path, "r") as f:
                     lines = f.readlines()
                     print(" Last 30 lines of vLLM log:")
                     for line in lines[-30:]:
@@ -883,7 +899,7 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
                 print(f" Could not read log: {e}")
             proc.terminate()
             return None
-
+
         # Load the LoRA adapter
         print(" Loading LoRA adapter...")
         try:
@@ -895,13 +911,15 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona
             if resp.status_code == 200:
                 print(" ✓ Adapter loaded successfully")
             else:
-                print(f" WARNING: Adapter load returned {resp.status_code}: {resp.text}")
+                print(
+                    f" WARNING: Adapter load returned {resp.status_code}: {resp.text}"
+                )
         except Exception as e:
             print(f" WARNING: Could not load adapter: {e}")
             # Continue anyway - base model inference still works
-
+
         return proc
-
+
     except Exception as e:
         print(f" ERROR: {e}")
         return None
@@ -911,12 +929,12 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
     """Terminate a vLLM process and release GPU resources."""
     import signal
     import subprocess as sp
-
+
     print(f" Terminating vLLM on port {port}...")
-
+
     # Get current GPU device
     gpu_id = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
-
+
     # Phase 1: Kill the process group if we have a handle (kills all children too)
     main_pid = None
     if proc is not None:
@@ -932,12 +950,13 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
             proc.wait(timeout=5)
         except Exception as e:
             print(f" Warning: {e}")
-
+
     # Phase 2: Kill by port (catches anything still running)
     from .vllm_manager import kill_process_on_port
+
     kill_process_on_port(port)
     time.sleep(2)
-
+
     # Phase 3: Aggressively kill ALL vLLM-related processes
     print(" Killing all vLLM-related processes...")
     kill_commands = [
@@ -953,19 +972,22 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
             sp.run(cmd, shell=True, capture_output=True, timeout=5)
         except Exception:
             pass
-
+
     # Phase 4: Use nvidia-smi to find and kill GPU processes (nuclear option)
     print(f" Checking for zombie GPU processes on GPU {gpu_id}...")
     try:
         result = sp.run(
             f"nvidia-smi --query-compute-apps=pid,used_memory --format=csv,noheader,nounits -i {gpu_id}",
-            shell=True, capture_output=True, text=True, timeout=10
+            shell=True,
+            capture_output=True,
+            text=True,
+            timeout=10,
         )
         if result.stdout.strip():
             print(f" Found GPU processes:\n{result.stdout}")
-            for line in result.stdout.strip().split('\n'):
+            for line in result.stdout.strip().split("\n"):
                 if line.strip():
-                    parts = line.split(',')
+                    parts = line.split(",")
                     if len(parts) >= 1:
                         pid = parts[0].strip()
                         # Don't kill the current Python process (trainer)
@@ -977,7 +999,7 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
                                 pass
     except Exception as e:
         print(f" Warning: nvidia-smi check failed: {e}")
-
+
     # Phase 5: Wait for GPU memory release - CRITICAL
     # The CUDA driver needs time to actually free memory after process death
     print(" Waiting for GPU memory release...")
@@ -987,24 +1009,26 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None:
             torch.cuda.empty_cache()
             free_mem = torch.cuda.mem_get_info()[0] / 1e9
             total_mem = torch.cuda.mem_get_info()[1] / 1e9
-            print(f" [{(i+1)*5}s] GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)")
+            print(
+                f" [{(i+1)*5}s] GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)"
+            )
             # If we have enough memory (>50% free), break early
             if free_mem > total_mem * 0.5:
                 print(f" ✓ Sufficient memory available ({free_mem:.1f} GB)")
                 break
-
+
     # Final cleanup
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
         free_mem = torch.cuda.mem_get_info()[0] / 1e9
         total_mem = torch.cuda.mem_get_info()[1] / 1e9
-        print(f" Final GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)")
-
+        print(
+            f" Final GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)"
+        )
+
         if free_mem < total_mem * 0.3:
             print(" WARNING: Low GPU memory! May fail to restart vLLM.")
             print(" Consider reducing --vllm-gpu-memory-utilization")
-
+
     print(" vLLM terminated")
-
-