diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py
index 5d776233..73201600 100644
--- a/example_trainer/grpo.py
+++ b/example_trainer/grpo.py
@@ -8,7 +8,7 @@
 import shutil
 import string
 import subprocess
 import time
-from typing import List, Literal, Optional, Tuple
+from typing import Dict, List, Literal, Optional, Tuple
 
 import numpy as np
 import requests
@@ -175,13 +175,21 @@ class TrainingConfig(BaseModel):
             "data fetch time, and GPU memory usage per step."
         ),
     )
+    atropos_url: str = Field(
+        "http://localhost:8000",
+        description=(
+            "URL of the Atropos API server (environment server). "
+            "Default is http://localhost:8000. Change for concurrent tests."
+        ),
+    )
 
 
-def check_atropos_api(timeout: float = 30.0) -> bool:
+def check_atropos_api(url: str = "http://localhost:8000", timeout: float = 30.0) -> bool:
     """
     Check if the Atropos API server is reachable.
 
     Args:
+        url: Base URL of the Atropos API server
         timeout: Maximum time to wait for the server
 
     Returns:
@@ -192,17 +200,17 @@ def check_atropos_api(timeout: float = 30.0) -> bool:
     start = _time.time()
     while _time.time() - start < timeout:
         try:
-            response = requests.get("http://localhost:8000/info", timeout=2)
+            response = requests.get(f"{url}/info", timeout=2)
             if response.status_code == 200:
-                print("[Trainer] ✓ Atropos API server is reachable")
+                print(f"[Trainer] ✓ Atropos API server is reachable at {url}")
                 return True
         except requests.exceptions.ConnectionError:
             pass
         except Exception as e:
-            print(f"[Trainer] Waiting for Atropos API... ({e})")
+            print(f"[Trainer] Waiting for Atropos API at {url}... ({e})")
         _time.sleep(1)
 
-    print("[Trainer] ⚠ Warning: Atropos API server not reachable")
+    print(f"[Trainer] ⚠ Warning: Atropos API server not reachable at {url}")
     return False
 
 
@@ -213,8 +221,9 @@ def register_trainer(config: TrainingConfig):
 
     Verifies registration succeeded before returning.
""" + url = config.atropos_url response = requests.post( - "http://localhost:8000/register", + f"{url}/register", json={ # wandb fields are required strings - use empty string if None "wandb_group": config.wandb_group or "", @@ -237,12 +246,12 @@ def register_trainer(config: TrainingConfig): if "uuid" not in data: raise RuntimeError(f"Registration failed: {data}") - print(f"[Trainer] ✓ Registered with Atropos API (uuid: {data['uuid']})") + print(f"[Trainer] ✓ Registered with Atropos API at {url} (uuid: {data['uuid']})") @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30)) -def get_batch(): - data = requests.get("http://localhost:8000/batch", timeout=10).json() +def get_batch(url: str = "http://localhost:8000"): + data = requests.get(f"{url}/batch", timeout=10).json() # Check if there was an error (trainer not registered) if data.get("status") == "error": @@ -366,7 +375,7 @@ def pad_data_to_good_offset(data, batch_size: int): def get_data( - batch_size: int, seq_len: int + batch_size: int, seq_len: int, atropos_url: str = "http://localhost:8000" ) -> List[ Tuple[ List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor] @@ -377,7 +386,7 @@ def get_data( """ batches = [] while True: - data = get_batch() + data = get_batch(url=atropos_url) if data["batch"] is not None: # Save the batch with open("temp.json", "w", encoding="utf-8") as f: @@ -529,35 +538,27 @@ def _attach_to_vllm_shared_tensors( # Map vLLM tensor names to HuggingFace model parameter names hf_state_dict = {} - vllm_to_hf_mapping = _create_vllm_to_hf_mapping(model, ipc_handles) + vllm_to_hf_mapping = _create_vllm_to_hf_mapping( + model, ipc_handles, debug=config.debug_loading + ) + + # Cache for reconstructed vLLM tensors (to avoid reconstructing fused tensors multiple times) + vllm_tensor_cache: Dict[str, torch.Tensor] = {} + + def reconstruct_vllm_tensor(vllm_name: str) -> Optional[torch.Tensor]: + """Reconstruct a vLLM tensor from IPC handle, with caching.""" + if vllm_name in vllm_tensor_cache: + return vllm_tensor_cache[vllm_name] - attached_count = 0 - for hf_name, vllm_name in vllm_to_hf_mapping.items(): if vllm_name not in ipc_handles: - continue + return None ipc_info = ipc_handles[vllm_name] + if "ipc_handle_b64" not in ipc_info: + return None + try: - # Reconstruct tensor from IPC handle - # We need all 8 items from the original _share_cuda_() call - if "ipc_handle_b64" not in ipc_info: - print(f"[Setup] Missing ipc_handle_b64 for {hf_name}") - continue - - # DEBUG: Only try first tensor to see if IPC works at all - if attached_count == 0 and config.debug_loading: - print(f"[Setup DEBUG] Attempting first tensor: {hf_name}", flush=True) - print( - f"[Setup DEBUG] device_index: {ipc_info['device_index']}", - flush=True, - ) - print( - f"[Setup DEBUG] storage_size: {ipc_info['storage_size']}", - flush=True, - ) - print(f"[Setup DEBUG] shape: {ipc_info['shape']}", flush=True) - # Decode all the bytes fields from base64 device_index = ipc_info["device_index"] ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"]) @@ -568,13 +569,6 @@ def _attach_to_vllm_shared_tensors( event_handle = base64.b64decode(ipc_info["event_handle_b64"]) event_sync_required = ipc_info["event_sync_required"] - if attached_count == 0 and config.debug_loading: - print( - f"[Setup DEBUG] Decoded IPC handle, len={len(ipc_handle)}", - flush=True, - ) - print("[Setup DEBUG] About to call _new_shared_cuda...", flush=True) - # Reconstruct the 8-tuple that _new_shared_cuda expects share_tuple = ( 
                 device_index,
@@ -587,14 +581,9 @@
                 event_sync_required,
             )
 
-            # Create storage from IPC handle (needs all 8 items)
+            # Create storage from IPC handle
             storage = torch.UntypedStorage._new_shared_cuda(*share_tuple)
 
-            if attached_count == 0 and config.debug_loading:
-                print(
-                    f"[Setup DEBUG] Storage created! size={storage.size()}", flush=True
-                )
-
             # Reconstruct tensor
             dtype = getattr(torch, ipc_info["dtype"].replace("torch.", ""))
             tensor = torch.tensor([], dtype=dtype, device=f"cuda:{device_index}")
@@ -605,32 +594,120 @@
                 stride=ipc_info["stride"],
             )
 
-            if attached_count == 0 and config.debug_loading:
-                print(f"[Setup DEBUG] Tensor set! shape={tensor.shape}", flush=True)
+            vllm_tensor_cache[vllm_name] = tensor
+            return tensor
 
-            # Make tensor require gradients for training
-            tensor.requires_grad_(True)
+        except Exception as e:
+            print(f"[Setup] Failed to reconstruct {vllm_name}: {e}", flush=True)
+            return None
 
-            hf_state_dict[hf_name] = tensor
-            attached_count += 1
+    attached_count = 0
+    fused_count = 0
 
-            if attached_count == 1 and config.debug_loading:
-                print("[Setup DEBUG] ✓ First tensor attached successfully!", flush=True)
+    for hf_name, mapping_info in vllm_to_hf_mapping.items():
+        try:
+            # Check if this is a fused mapping or direct mapping
+            if isinstance(mapping_info, dict):
+                # Fused mapping - need to slice the source tensor
+                vllm_name = mapping_info["source"]
+                slice_start, slice_end = mapping_info["slice"]
+                slice_dim = mapping_info["dim"]
+
+                full_tensor = reconstruct_vllm_tensor(vllm_name)
+                if full_tensor is None:
+                    if config.debug_loading:
+                        print(f"[Setup] Could not get source tensor for {hf_name}")
+                    continue
+
+                # Create a VIEW (not copy!) into the fused tensor
+                # This maintains shared memory - gradients flow back to vLLM's tensor
+                if slice_dim == 0:
+                    tensor = full_tensor[slice_start:slice_end]
+                else:
+                    # For other dimensions (rare, but handle it)
+                    tensor = full_tensor.narrow(slice_dim, slice_start, slice_end - slice_start)
+
+                # Verify it's a view, not a copy
+                if tensor.storage().data_ptr() != full_tensor.storage().data_ptr():
+                    print(f"[Setup] WARNING: {hf_name} is a COPY, not a view!")
+
+                if attached_count == 0 and config.debug_loading:
+                    print(f"[Setup DEBUG] Fused tensor slice: {hf_name}")
+                    print(f"[Setup DEBUG] Source: {vllm_name} shape={full_tensor.shape}")
+                    print(f"[Setup DEBUG] Slice: [{slice_start}:{slice_end}] -> {tensor.shape}")
+
+                tensor.requires_grad_(True)
+                hf_state_dict[hf_name] = tensor
+                fused_count += 1
+                attached_count += 1
+
+            else:
+                # Direct mapping - reconstruct tensor directly
+                vllm_name = mapping_info
+
+                tensor = reconstruct_vllm_tensor(vllm_name)
+                if tensor is None:
+                    continue
+
+                if attached_count == 0 and config.debug_loading:
+                    print(f"[Setup DEBUG] Attempting first tensor: {hf_name}", flush=True)
+                    ipc_info = ipc_handles[vllm_name]
+                    print(f"[Setup DEBUG] device_index: {ipc_info['device_index']}", flush=True)
+                    print(f"[Setup DEBUG] storage_size: {ipc_info['storage_size']}", flush=True)
+                    print(f"[Setup DEBUG] shape: {ipc_info['shape']}", flush=True)
+
+                tensor.requires_grad_(True)
+                hf_state_dict[hf_name] = tensor
+                attached_count += 1
+
+                if attached_count == 1 and config.debug_loading:
+                    print("[Setup DEBUG] ✓ First tensor attached successfully!", flush=True)
 
         except Exception as e:
             print(f"[Setup] Failed to attach {hf_name}: {e}", flush=True)
             import traceback
 
             traceback.print_exc()
-            continue
+
+    print(f"[Setup] Attached {attached_count} tensors ({fused_count} from fused layers)")
 
     if attached_count == 0:
         print("[Setup] Could not attach any tensors, falling back to regular loading")
         return None
 
+    # =========================================================================
+    # EARLY VALIDATION: Check that we mapped a reasonable number of parameters
+    # This catches obvious mapping failures before we try to load
+    # =========================================================================
+    hf_param_count = len(list(model.named_parameters()))
+    mapping_coverage = attached_count / hf_param_count if hf_param_count > 0 else 0
+
+    print(f"[Setup] Mapping coverage: {attached_count}/{hf_param_count} ({mapping_coverage:.1%})")
+
+    # Expect at least 90% coverage for a valid mapping
+    # Some params like inv_freq buffers won't be in vLLM
+    if mapping_coverage < 0.90:
+        unmapped_params = set(model.state_dict().keys()) - set(hf_state_dict.keys())
+        warning_msg = f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n"
+        warning_msg += f"Unmapped parameters ({len(unmapped_params)}):\n"
+        for name in list(unmapped_params)[:20]:
+            warning_msg += f"  - {name}\n"
+        print(warning_msg)
+
+    if mapping_coverage < 0.50:
+        raise RuntimeError(
+            f"[Setup] CRITICAL: Only {mapping_coverage:.1%} of parameters mapped!\n"
+            "This indicates a serious mapping failure. Check:\n"
+            "  1. vLLM and HuggingFace use the same model architecture\n"
+            "  2. tensor-parallel-size=1 for single-copy mode\n"
+            "  3. vllm_bridge_config.json contains valid ipc_handles"
+        )
+
     print(f"[Setup] ✓ Attached {attached_count} tensors to vLLM's shared memory")
 
     # Load state dict into model
+    # NOTE: We use strict=False because some buffers (like inv_freq) won't be in vLLM
+    # but we VALIDATE after loading to ensure nothing critical is left on meta
     model.load_state_dict(hf_state_dict, strict=False, assign=True)
 
     # Initialize any remaining meta tensors (buffers like rotary embeddings)
@@ -791,39 +867,184 @@
 
     print(f"\n[Setup] Initialized {meta_count} remaining meta tensors")
 
+    # =========================================================================
+    # CRITICAL VALIDATION: Ensure no parameters/buffers are still on meta device
+    # This catches mapping bugs that would otherwise cause garbage output
+    # =========================================================================
+    final_meta_params = []
+    final_meta_buffers = []
+
+    for name, param in model.named_parameters():
+        if param.device.type == "meta":
+            final_meta_params.append(name)
+
+    for name, buffer in model.named_buffers():
+        if buffer.device.type == "meta":
+            final_meta_buffers.append(name)
+
+    if final_meta_params or final_meta_buffers:
+        error_msg = "[Setup] CRITICAL ERROR: Some tensors are still on meta device!\n"
+        error_msg += "This means they were NOT properly mapped from vLLM or initialized.\n"
+        error_msg += "The model would produce GARBAGE output.\n\n"
+
+        if final_meta_params:
+            error_msg += f"Meta parameters ({len(final_meta_params)}):\n"
+            for name in final_meta_params[:20]:
+                error_msg += f"  - {name}\n"
+            if len(final_meta_params) > 20:
+                error_msg += f"  ... and {len(final_meta_params) - 20} more\n"
+
+        if final_meta_buffers:
+            error_msg += f"\nMeta buffers ({len(final_meta_buffers)}):\n"
+            for name in final_meta_buffers[:20]:
+                error_msg += f"  - {name}\n"
+            if len(final_meta_buffers) > 20:
+                error_msg += f"  ... and {len(final_meta_buffers) - 20} more\n"
+
+        error_msg += "\nPossible causes:\n"
+        error_msg += "  1. vLLM parameter names don't match HuggingFace names\n"
+        error_msg += "  2. QKV/Gate-Up fusion mapping failed\n"
+        error_msg += "  3. vLLM running with tensor-parallel-size > 1 (not supported)\n"
+
+        raise RuntimeError(error_msg)
+
+    print("[Setup] ✓ All tensors successfully initialized on CUDA")
+
     return model
 
 
-def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dict:
+def _create_vllm_to_hf_mapping(
+    model: torch.nn.Module, ipc_handles: dict, debug: bool = False
+) -> dict:
     """
     Create mapping from HuggingFace parameter names to vLLM tensor names.
 
-    vLLM uses slightly different naming conventions than HuggingFace.
-    This function creates the bidirectional mapping.
+    vLLM uses different naming conventions and fuses certain layers:
+    - qkv_proj (vLLM) = q_proj + k_proj + v_proj (HF)
+    - gate_up_proj (vLLM) = gate_proj + up_proj (HF)
+
+    Returns a dict where:
+    - Simple mappings: {"hf_name": "vllm_name"}
+    - Fused mappings: {"hf_name": {"source": "vllm_name", "slice": (start, end), "dim": 0}}
     """
     hf_params = set(model.state_dict().keys())
     vllm_params = set(ipc_handles.keys())
 
+    # Get model config for dimension calculations
+    model_config = model.config
+    hidden_size = getattr(model_config, "hidden_size", 4096)
+    num_attention_heads = getattr(model_config, "num_attention_heads", 32)
+    num_key_value_heads = getattr(
+        model_config, "num_key_value_heads", num_attention_heads
+    )
+    intermediate_size = getattr(model_config, "intermediate_size", hidden_size * 4)
+    # Prefer an explicit head_dim from the config; some architectures use a
+    # head_dim that differs from hidden_size // num_attention_heads
+    head_dim = getattr(model_config, "head_dim", None) or (
+        hidden_size // num_attention_heads
+    )
+
+    # Calculate sizes for QKV split
+    q_size = num_attention_heads * head_dim
+    k_size = num_key_value_heads * head_dim
+    v_size = num_key_value_heads * head_dim
+
+    if debug:
+        print(f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, "
+              f"kv_heads={num_key_value_heads}, intermediate={intermediate_size}")
+        print(f"[Mapping] QKV sizes: q={q_size}, k={k_size}, v={v_size}")
+
     mapping = {}
 
+    def find_vllm_name(hf_name: str) -> Optional[str]:
+        """Try to find the corresponding vLLM parameter name."""
+        # Direct match
+        if hf_name in vllm_params:
+            return hf_name
+
+        # Add 'model.' prefix
+        if not hf_name.startswith("model."):
+            candidate = f"model.{hf_name}"
+            if candidate in vllm_params:
+                return candidate
+
+        # Remove 'model.' prefix
+        if hf_name.startswith("model."):
+            candidate = hf_name[6:]
+            if candidate in vllm_params:
+                return candidate
+
+        return None
+
+    def find_fused_source(hf_name: str, fused_suffix: str) -> Optional[str]:
+        """Try to find the fused layer that contains this parameter."""
+        # e.g., "model.layers.0.self_attn.q_proj.weight" -> "model.layers.0.self_attn.qkv_proj.weight"
+        for unfused in ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"]:
+            if unfused in hf_name:
+                fused_name = hf_name.replace(unfused, fused_suffix)
+                found = find_vllm_name(fused_name)
+                if found:
+                    return found
+        return None
+
     for hf_name in hf_params:
         # Try direct match first
-        if hf_name in vllm_params:
-            mapping[hf_name] = hf_name
-            continue
-
-        # Try common transformations
-        # vLLM often uses 'model.' prefix
-        vllm_name = f"model.{hf_name}" if not hf_name.startswith("model.") else hf_name
-        if vllm_name in vllm_params:
+        vllm_name = find_vllm_name(hf_name)
+        if vllm_name:
             mapping[hf_name] = vllm_name
             continue
 
-        # Remove 'model.' prefix if present
-        if hf_name.startswith("model."):
-            vllm_name = hf_name[6:]
-            if vllm_name in vllm_params:
-                mapping[hf_name] = vllm_name
+        # Check for QKV fusion: q_proj, k_proj, v_proj -> qkv_proj
+        if any(x in hf_name for x in ["q_proj", "k_proj", "v_proj"]):
+            fused_name = find_fused_source(hf_name, "qkv_proj")
+            if fused_name:
+                # Determine which part of the fused tensor this is
+                if "q_proj" in hf_name:
+                    start, end = 0, q_size
+                elif "k_proj" in hf_name:
+                    start, end = q_size, q_size + k_size
+                else:  # v_proj
+                    start, end = q_size + k_size, q_size + k_size + v_size
+
+                mapping[hf_name] = {
+                    "source": fused_name,
+                    "slice": (start, end),
+                    "dim": 0,  # Split along output dimension
+                    "type": "qkv_fusion",
+                }
+                if debug:
+                    print(f"[Mapping] QKV fusion: {hf_name} -> {fused_name}[{start}:{end}]")
+                continue
+
+        # Check for Gate/Up fusion: gate_proj, up_proj -> gate_up_proj
+        if any(x in hf_name for x in ["gate_proj", "up_proj"]):
+            fused_name = find_fused_source(hf_name, "gate_up_proj")
+            if fused_name:
+                # Determine which part of the fused tensor this is
+                if "gate_proj" in hf_name:
+                    start, end = 0, intermediate_size
+                else:  # up_proj
+                    start, end = intermediate_size, intermediate_size * 2
+
+                mapping[hf_name] = {
+                    "source": fused_name,
+                    "slice": (start, end),
+                    "dim": 0,  # Split along output dimension
+                    "type": "gate_up_fusion",
+                }
+                if debug:
+                    print(f"[Mapping] Gate/Up fusion: {hf_name} -> {fused_name}[{start}:{end}]")
+                continue
+
+        # No mapping found - this parameter will need to be handled specially
+        if debug and "inv_freq" not in hf_name:  # inv_freq is expected to be missing
+            print(f"[Mapping] No mapping for: {hf_name}")
+
+    if debug:
+        direct = sum(1 for v in mapping.values() if isinstance(v, str))
+        fused = sum(1 for v in mapping.values() if isinstance(v, dict))
+        print(f"[Mapping] Total: {len(mapping)} mapped ({direct} direct, {fused} fused)")
+        print(f"[Mapping] Unmapped: {len(hf_params) - len(mapping)}")
 
     return mapping
 
 
@@ -1404,7 +1621,7 @@ def train(config: TrainingConfig):
         # Track data fetch time
         data_fetch_start = time.time()
         if len(batches) == 0:
-            batches = get_data(config.batch_size, config.seq_len)
+            batches = get_data(config.batch_size, config.seq_len, config.atropos_url)
         token_batches, label_batches, advantage_batches, temperature_batches = (
             batches.pop(0)
         )
@@ -1634,10 +1851,11 @@ def train_shared_vllm(config: TrainingConfig):
     os.makedirs(config.save_path, exist_ok=True)
 
     # Check Atropos API and register BEFORE training loop
-    print("\n[Setup] Connecting to Atropos API...")
-    if not check_atropos_api(timeout=30):
+    print(f"\n[Setup] Connecting to Atropos API at {config.atropos_url}...")
+    if not check_atropos_api(url=config.atropos_url, timeout=30):
         raise RuntimeError(
-            "Atropos API server not reachable. " "Please start it with: run-api"
+            f"Atropos API server not reachable at {config.atropos_url}. "
+            "Please start the environment server (e.g., gsm8k_server.py serve)"
         )
" + "Please start the environment server (e.g., gsm8k_server.py serve)" ) register_trainer(config) @@ -1657,7 +1875,7 @@ def train_shared_vllm(config: TrainingConfig): # Track data fetch time data_fetch_start = time.time() if len(batches) == 0: - batches = get_data(config.batch_size, config.seq_len) + batches = get_data(config.batch_size, config.seq_len, config.atropos_url) token_batches, label_batches, advantage_batches, temperature_batches = ( batches.pop(0) ) @@ -1870,7 +2088,7 @@ def train_lora(config: TrainingConfig): # Track data fetch time data_fetch_start = time.time() if len(batches) == 0: - batches = get_data(config.batch_size, config.seq_len) + batches = get_data(config.batch_size, config.seq_len, config.atropos_url) token_batches, label_batches, advantage_batches, temperature_batches = ( batches.pop(0) ) @@ -2032,6 +2250,12 @@ def parse_args() -> argparse.Namespace: default=9001, help="Port for the vLLM server", ) + parser.add_argument( + "--atropos-url", + type=str, + default="http://localhost:8000", + help="URL of the Atropos API/environment server (e.g., gsm8k_server)", + ) parser.add_argument( "--vllm-gpu-memory-utilization", type=float, @@ -2197,6 +2421,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig: vllm_config_path=getattr(args, "vllm_config_path", None), debug_loading=getattr(args, "debug_loading", False), benchmark=getattr(args, "benchmark", False), + atropos_url=getattr(args, "atropos_url", "http://localhost:8000"), ) diff --git a/example_trainer/scripts/run_concurrent_tests.sh b/example_trainer/scripts/run_concurrent_tests.sh new file mode 100644 index 00000000..3b8b96a4 --- /dev/null +++ b/example_trainer/scripts/run_concurrent_tests.sh @@ -0,0 +1,244 @@ +#!/bin/bash +# ============================================================================= +# Concurrent GSM8k Training Test Script +# ============================================================================= +# +# This script runs BOTH LoRA and Single-Copy modes concurrently on an 8-GPU node: +# - GPUs 0-1: LoRA mode (vLLM on GPU 0, trainer on GPU 1) +# - GPUs 4-5: Single-Copy mode (vLLM+trainer share GPU 4) +# +# Usage: +# ./scripts/run_concurrent_tests.sh [MODEL] [STEPS] +# +# Example: +# ./scripts/run_concurrent_tests.sh Qwen/Qwen2.5-3B-Instruct 100 +# +# ============================================================================= + +set -e + +# Configuration +MODEL="${1:-Qwen/Qwen2.5-3B-Instruct}" +TRAINING_STEPS="${2:-100}" +BATCH_SIZE=4 +LORA_SAVE_INTERVAL=20 + +# Ports (separate for each mode) +LORA_VLLM_PORT=9001 +LORA_GSM8K_PORT=8001 + +SINGLE_COPY_VLLM_PORT=9002 +SINGLE_COPY_GSM8K_PORT=8002 + +# Directories +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TRAINER_DIR="$(dirname "$SCRIPT_DIR")" +REPO_DIR="$(dirname "$TRAINER_DIR")" + +LOG_DIR="${REPO_DIR}/test_logs_$(date +%Y%m%d_%H%M%S)" +mkdir -p "$LOG_DIR" + +LORA_CHECKPOINT_DIR="${LOG_DIR}/lora_checkpoints" +SINGLE_COPY_CHECKPOINT_DIR="${LOG_DIR}/single_copy_checkpoints" +mkdir -p "$LORA_CHECKPOINT_DIR" "$SINGLE_COPY_CHECKPOINT_DIR" + +echo "============================================================" +echo "Concurrent GSM8k Training Test" +echo "============================================================" +echo "Model: $MODEL" +echo "Training Steps: $TRAINING_STEPS" +echo "Batch Size: $BATCH_SIZE" +echo "Log Directory: $LOG_DIR" +echo "" +echo "LoRA Mode: GPUs 0-1, ports ${LORA_VLLM_PORT}/${LORA_GSM8K_PORT}" +echo "Single-Copy Mode: GPU 4, ports ${SINGLE_COPY_VLLM_PORT}/${SINGLE_COPY_GSM8K_PORT}" +echo 
"============================================================" + +# Cleanup function +cleanup() { + echo "" + echo "Cleaning up processes..." + pkill -u $USER -f "vllm_api_server.*port.${LORA_VLLM_PORT}" 2>/dev/null || true + pkill -u $USER -f "vllm_api_server.*port.${SINGLE_COPY_VLLM_PORT}" 2>/dev/null || true + pkill -u $USER -f "gsm8k_server.*${LORA_GSM8K_PORT}" 2>/dev/null || true + pkill -u $USER -f "gsm8k_server.*${SINGLE_COPY_GSM8K_PORT}" 2>/dev/null || true + pkill -u $USER -f "grpo.py.*lora_only" 2>/dev/null || true + pkill -u $USER -f "grpo.py.*shared_vllm" 2>/dev/null || true + echo "Cleanup complete." +} + +trap cleanup EXIT + +# Kill any existing processes +cleanup + +# Clear Triton cache (for LoRA B200 compatibility) +rm -rf ~/.triton/cache + +cd "$REPO_DIR" + +echo "" +echo "[1/6] Starting LoRA vLLM server (GPUs 0)..." +CUDA_VISIBLE_DEVICES=0 \ +VLLM_ENABLE_SHARED_WEIGHTS=1 \ +python -u example_trainer/vllm_api_server.py \ + --model "$MODEL" \ + --tensor-parallel-size 1 \ + --port $LORA_VLLM_PORT \ + --dtype bfloat16 \ + --gpu-memory-utilization 0.7 \ + --enable-lora \ + --max-loras 2 \ + --max-lora-rank 64 \ + --enforce-eager \ + > "${LOG_DIR}/lora_vllm.log" 2>&1 & +LORA_VLLM_PID=$! +echo " PID: $LORA_VLLM_PID" + +echo "" +echo "[2/6] Starting Single-Copy vLLM server (GPU 4)..." +CUDA_VISIBLE_DEVICES=4 \ +VLLM_ENABLE_SHARED_WEIGHTS=1 \ +LOGDIR="$SINGLE_COPY_CHECKPOINT_DIR" \ +python -u example_trainer/vllm_api_server.py \ + --model "$MODEL" \ + --tensor-parallel-size 1 \ + --port $SINGLE_COPY_VLLM_PORT \ + --dtype bfloat16 \ + --gpu-memory-utilization 0.5 \ + > "${LOG_DIR}/single_copy_vllm.log" 2>&1 & +SINGLE_COPY_VLLM_PID=$! +echo " PID: $SINGLE_COPY_VLLM_PID" + +echo "" +echo "Waiting for vLLM servers to initialize (60s)..." +sleep 60 + +# Verify servers are running +echo "" +echo "Verifying vLLM servers..." + +if curl -s "http://localhost:${LORA_VLLM_PORT}/health" > /dev/null; then + echo " ✓ LoRA vLLM server healthy" +else + echo " ✗ LoRA vLLM server failed to start" + cat "${LOG_DIR}/lora_vllm.log" | tail -50 + exit 1 +fi + +if curl -s "http://localhost:${SINGLE_COPY_VLLM_PORT}/health" > /dev/null; then + echo " ✓ Single-Copy vLLM server healthy" +else + echo " ✗ Single-Copy vLLM server failed to start" + cat "${LOG_DIR}/single_copy_vllm.log" | tail -50 + exit 1 +fi + +echo "" +echo "[3/6] Starting LoRA GSM8k environment..." +python -u environments/gsm8k_server.py serve \ + --env.tokenizer_name "$MODEL" \ + --env.use_wandb=False \ + --openai.model_name "$MODEL" \ + --openai.base_url "http://localhost:${LORA_VLLM_PORT}/v1" \ + --openai.server_type vllm \ + --server.port $LORA_GSM8K_PORT \ + > "${LOG_DIR}/lora_gsm8k.log" 2>&1 & +LORA_GSM8K_PID=$! +echo " PID: $LORA_GSM8K_PID" + +echo "" +echo "[4/6] Starting Single-Copy GSM8k environment..." +python -u environments/gsm8k_server.py serve \ + --env.tokenizer_name "$MODEL" \ + --env.use_wandb=False \ + --openai.model_name "$MODEL" \ + --openai.base_url "http://localhost:${SINGLE_COPY_VLLM_PORT}/v1" \ + --openai.server_type vllm \ + --server.port $SINGLE_COPY_GSM8K_PORT \ + > "${LOG_DIR}/single_copy_gsm8k.log" 2>&1 & +SINGLE_COPY_GSM8K_PID=$! +echo " PID: $SINGLE_COPY_GSM8K_PID" + +echo "" +echo "Waiting for GSM8k environments to initialize (15s)..." +sleep 15 + +echo "" +echo "[5/6] Starting LoRA trainer (GPU 1)..." 
+echo ""
+echo "[5/6] Starting LoRA trainer (GPU 1)..."
+CUDA_VISIBLE_DEVICES=1 \
+python -u example_trainer/grpo.py \
+    --model-name "$MODEL" \
+    --weight-bridge-mode lora_only \
+    --vllm-port $LORA_VLLM_PORT \
+    --atropos-url "http://localhost:${LORA_GSM8K_PORT}" \
+    --batch-size $BATCH_SIZE \
+    --training-steps $TRAINING_STEPS \
+    --vllm-restart-interval $LORA_SAVE_INTERVAL \
+    --save-path "$LORA_CHECKPOINT_DIR" \
+    --benchmark \
+    > "${LOG_DIR}/lora_trainer.log" 2>&1 &
+LORA_TRAINER_PID=$!
+echo "  PID: $LORA_TRAINER_PID"
+
+echo ""
+echo "[6/6] Starting Single-Copy trainer (GPU 4 - shared with vLLM)..."
+CUDA_VISIBLE_DEVICES=4 \
+python -u example_trainer/grpo.py \
+    --model-name "$MODEL" \
+    --weight-bridge-mode shared_vllm \
+    --vllm-port $SINGLE_COPY_VLLM_PORT \
+    --atropos-url "http://localhost:${SINGLE_COPY_GSM8K_PORT}" \
+    --batch-size $BATCH_SIZE \
+    --training-steps $TRAINING_STEPS \
+    --save-path "$SINGLE_COPY_CHECKPOINT_DIR" \
+    --vllm-config-path "${SINGLE_COPY_CHECKPOINT_DIR}/vllm_bridge_config.json" \
+    --benchmark \
+    > "${LOG_DIR}/single_copy_trainer.log" 2>&1 &
+SINGLE_COPY_TRAINER_PID=$!
+echo "  PID: $SINGLE_COPY_TRAINER_PID"
+
+echo ""
+echo "============================================================"
+echo "Both trainers started!"
+echo ""
+echo "Monitor logs:"
+echo "  tail -f ${LOG_DIR}/lora_trainer.log"
+echo "  tail -f ${LOG_DIR}/single_copy_trainer.log"
+echo ""
+echo "Or watch both:"
+echo "  tail -f ${LOG_DIR}/*.log"
+echo ""
+echo "Waiting for training to complete..."
+echo "============================================================"
+
+# Wait for both trainers to complete. Guard the waits so a failing trainer
+# doesn't kill the script through `set -e` before we can report exit codes.
+LORA_EXIT=0
+wait $LORA_TRAINER_PID || LORA_EXIT=$?
+
+SINGLE_COPY_EXIT=0
+wait $SINGLE_COPY_TRAINER_PID || SINGLE_COPY_EXIT=$?
+
+echo ""
+echo "============================================================"
+echo "TRAINING COMPLETE"
+echo "============================================================"
+echo "LoRA Trainer Exit Code: $LORA_EXIT"
+echo "Single-Copy Trainer Exit Code: $SINGLE_COPY_EXIT"
+echo ""
+echo "Results saved to: $LOG_DIR"
+echo ""
+echo "Checkpoints:"
+echo "  LoRA: $LORA_CHECKPOINT_DIR"
+echo "  Single-Copy: $SINGLE_COPY_CHECKPOINT_DIR"
+echo "============================================================"
+
+# Generate summary
+echo ""
+echo "=== LoRA Training Summary ===" | tee "${LOG_DIR}/summary.txt"
+grep -E "Step|Loss|Accuracy" "${LOG_DIR}/lora_trainer.log" | tail -20 | tee -a "${LOG_DIR}/summary.txt"
+
+echo "" | tee -a "${LOG_DIR}/summary.txt"
+echo "=== Single-Copy Training Summary ===" | tee -a "${LOG_DIR}/summary.txt"
+grep -E "Step|Loss|Accuracy" "${LOG_DIR}/single_copy_trainer.log" | tail -20 | tee -a "${LOG_DIR}/summary.txt"
diff --git a/example_trainer/scripts/test_lora_mode.sh b/example_trainer/scripts/test_lora_mode.sh
new file mode 100644
index 00000000..4e7a93fd
--- /dev/null
+++ b/example_trainer/scripts/test_lora_mode.sh
@@ -0,0 +1,142 @@
+#!/bin/bash
+# =============================================================================
+# LoRA Mode GSM8k Training Test
+# =============================================================================
+#
+# Tests the LoRA training pipeline with GSM8k environment.
+# Uses separate GPUs for vLLM and trainer.
+#
+# Usage:
+#   CUDA_VISIBLE_DEVICES=0,1 ./scripts/test_lora_mode.sh [MODEL] [STEPS]
+#
+# =============================================================================
+
+set -e
+
+MODEL="${1:-Qwen/Qwen2.5-3B-Instruct}"
+TRAINING_STEPS="${2:-50}"
+BATCH_SIZE=4
+SAVE_INTERVAL=10
+
+VLLM_PORT=9001
+GSM8K_PORT=8001
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+TRAINER_DIR="$(dirname "$SCRIPT_DIR")"
+REPO_DIR="$(dirname "$TRAINER_DIR")"
+
+LOG_DIR="${REPO_DIR}/lora_test_$(date +%Y%m%d_%H%M%S)"
+mkdir -p "$LOG_DIR"
+
+echo "============================================================"
+echo "LoRA Mode GSM8k Training Test"
+echo "============================================================"
+echo "Model: $MODEL"
+echo "Steps: $TRAINING_STEPS"
+echo "Log Dir: $LOG_DIR"
+echo "============================================================"
+
+cleanup() {
+    echo "Cleaning up..."
+    pkill -u $USER -f "vllm_api_server.*port.*${VLLM_PORT}" 2>/dev/null || true
+    pkill -u $USER -f "gsm8k_server" 2>/dev/null || true
+    pkill -u $USER -f "grpo.py" 2>/dev/null || true
+}
+trap cleanup EXIT
+cleanup
+
+# Clear Triton cache for B200 compatibility
+rm -rf ~/.triton/cache
+
+cd "$REPO_DIR"
+
+echo ""
+echo "[1/4] Starting vLLM with LoRA support..."
+VLLM_ENABLE_SHARED_WEIGHTS=1 \
+python -u example_trainer/vllm_api_server.py \
+    --model "$MODEL" \
+    --tensor-parallel-size 1 \
+    --port $VLLM_PORT \
+    --dtype bfloat16 \
+    --gpu-memory-utilization 0.6 \
+    --enable-lora \
+    --max-loras 2 \
+    --max-lora-rank 64 \
+    --enforce-eager \
+    > "${LOG_DIR}/vllm.log" 2>&1 &
+
+echo "Waiting for vLLM (45s)..."
+sleep 45
+
+curl -s "http://localhost:${VLLM_PORT}/health" && echo "  ✓ vLLM ready" || { echo "  ✗ vLLM failed"; exit 1; }
+
+echo ""
+echo "[2/4] Starting GSM8k environment..."
+python -u environments/gsm8k_server.py serve \
+    --env.tokenizer_name "$MODEL" \
+    --env.use_wandb=False \
+    --openai.model_name "$MODEL" \
+    --openai.base_url "http://localhost:${VLLM_PORT}/v1" \
+    --openai.server_type vllm \
+    --server.port $GSM8K_PORT \
+    > "${LOG_DIR}/gsm8k.log" 2>&1 &
+
+echo "Waiting for GSM8k (10s)..."
+sleep 10
+
+echo ""
+echo "[3/4] Baseline test (before training)..."
+curl -s -X POST "http://localhost:${VLLM_PORT}/v1/chat/completions" \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "'"$MODEL"'",
+        "messages": [{"role": "user", "content": "What is 123 + 456?"}],
+        "max_tokens": 100,
+        "temperature": 0.1
+    }' | jq '.choices[0].message.content' | tee "${LOG_DIR}/baseline_response.txt"
+
+echo ""
+echo "[4/4] Starting LoRA trainer..."
+python -u example_trainer/grpo.py \
+    --model-name "$MODEL" \
+    --weight-bridge-mode lora_only \
+    --vllm-port $VLLM_PORT \
+    --atropos-url "http://localhost:${GSM8K_PORT}" \
+    --batch-size $BATCH_SIZE \
+    --training-steps $TRAINING_STEPS \
+    --vllm-restart-interval $SAVE_INTERVAL \
+    --save-path "$LOG_DIR/checkpoints" \
+    --benchmark \
+    2>&1 | tee "${LOG_DIR}/trainer.log"
+
+echo ""
+echo "============================================================"
+echo "Training Complete!"
+echo "Logs: $LOG_DIR" +echo "Checkpoints: $LOG_DIR/checkpoints" +echo "============================================================" + +# Post-training test +if [ -d "$LOG_DIR/checkpoints" ]; then + LATEST_ADAPTER=$(ls -td "$LOG_DIR/checkpoints/adapter_"* 2>/dev/null | head -1) + if [ -n "$LATEST_ADAPTER" ]; then + echo "" + echo "Post-training test with adapter: $LATEST_ADAPTER" + + curl -s -X POST "http://localhost:${VLLM_PORT}/lora/load" \ + -H "Content-Type: application/json" \ + -d '{"adapter_path": "'"$LATEST_ADAPTER"'"}' | jq + + echo "" + echo "Response after training:" + curl -s -X POST "http://localhost:${VLLM_PORT}/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "'"$MODEL"'", + "messages": [{"role": "user", "content": "What is 123 + 456?"}], + "max_tokens": 100, + "temperature": 0.1 + }' | jq '.choices[0].message.content' | tee "${LOG_DIR}/trained_response.txt" + fi +fi + diff --git a/example_trainer/scripts/test_single_copy_mode.sh b/example_trainer/scripts/test_single_copy_mode.sh new file mode 100644 index 00000000..08bbc630 --- /dev/null +++ b/example_trainer/scripts/test_single_copy_mode.sh @@ -0,0 +1,144 @@ +#!/bin/bash +# ============================================================================= +# Single-Copy Mode GSM8k Training Test +# ============================================================================= +# +# Tests the single-copy (shared_vllm) training pipeline with GSM8k environment. +# vLLM and trainer share the SAME GPU memory - true single-copy architecture. +# +# Usage: +# CUDA_VISIBLE_DEVICES=0 ./scripts/test_single_copy_mode.sh [MODEL] [STEPS] +# +# Note: Single-copy mode requires tensor-parallel-size=1 +# +# ============================================================================= + +set -e + +MODEL="${1:-Qwen/Qwen2.5-3B-Instruct}" +TRAINING_STEPS="${2:-50}" +BATCH_SIZE=4 + +VLLM_PORT=9002 +GSM8K_PORT=8002 + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TRAINER_DIR="$(dirname "$SCRIPT_DIR")" +REPO_DIR="$(dirname "$TRAINER_DIR")" + +LOG_DIR="${REPO_DIR}/single_copy_test_$(date +%Y%m%d_%H%M%S)" +mkdir -p "$LOG_DIR" + +echo "============================================================" +echo "Single-Copy Mode GSM8k Training Test" +echo "============================================================" +echo "Model: $MODEL" +echo "Steps: $TRAINING_STEPS" +echo "Log Dir: $LOG_DIR" +echo "" +echo "NOTE: vLLM and trainer share the SAME GPU memory!" +echo " Weight updates are INSTANT (no copying)." +echo "============================================================" + +cleanup() { + echo "Cleaning up..." + pkill -u $USER -f "vllm_api_server.*port.*${VLLM_PORT}" 2>/dev/null || true + pkill -u $USER -f "gsm8k_server.*${GSM8K_PORT}" 2>/dev/null || true + pkill -u $USER -f "grpo.py.*shared_vllm" 2>/dev/null || true +} +trap cleanup EXIT +cleanup + +cd "$REPO_DIR" + +echo "" +echo "[1/4] Starting vLLM with shared memory enabled..." +VLLM_ENABLE_SHARED_WEIGHTS=1 \ +LOGDIR="$LOG_DIR" \ +python -u example_trainer/vllm_api_server.py \ + --model "$MODEL" \ + --tensor-parallel-size 1 \ + --port $VLLM_PORT \ + --dtype bfloat16 \ + --gpu-memory-utilization 0.5 \ + > "${LOG_DIR}/vllm.log" 2>&1 & + +echo "Waiting for vLLM (45s)..." 
+sleep 45
+
+curl -s "http://localhost:${VLLM_PORT}/health" && echo "  ✓ vLLM ready" || { echo "  ✗ vLLM failed"; exit 1; }
+
+# Verify IPC handles are exported
+if [ -f "${LOG_DIR}/vllm_bridge_config.json" ]; then
+    echo "  ✓ vllm_bridge_config.json created"
+    PARAM_COUNT=$(jq '.ipc_handles | keys | length' "${LOG_DIR}/vllm_bridge_config.json" 2>/dev/null || echo "0")
+    echo "  Exported parameters: $PARAM_COUNT"
+else
+    echo "  ✗ vllm_bridge_config.json not found - shared memory may not work"
+fi
+
+echo ""
+echo "[2/4] Starting GSM8k environment..."
+python -u environments/gsm8k_server.py serve \
+    --env.tokenizer_name "$MODEL" \
+    --env.use_wandb=False \
+    --openai.model_name "$MODEL" \
+    --openai.base_url "http://localhost:${VLLM_PORT}/v1" \
+    --openai.server_type vllm \
+    --server.port $GSM8K_PORT \
+    > "${LOG_DIR}/gsm8k.log" 2>&1 &
+
+echo "Waiting for GSM8k (10s)..."
+sleep 10
+
+echo ""
+echo "[3/4] Baseline test (before training)..."
+curl -s -X POST "http://localhost:${VLLM_PORT}/v1/chat/completions" \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "'"$MODEL"'",
+        "messages": [{"role": "user", "content": "What is 123 + 456?"}],
+        "max_tokens": 100,
+        "temperature": 0.1
+    }' | jq '.choices[0].message.content' | tee "${LOG_DIR}/baseline_response.txt"
+
+echo ""
+echo "[4/4] Starting Single-Copy trainer..."
+echo "The trainer will attach to vLLM's GPU memory via CUDA IPC."
+echo ""
+
+python -u example_trainer/grpo.py \
+    --model-name "$MODEL" \
+    --weight-bridge-mode shared_vllm \
+    --vllm-port $VLLM_PORT \
+    --atropos-url "http://localhost:${GSM8K_PORT}" \
+    --batch-size $BATCH_SIZE \
+    --training-steps $TRAINING_STEPS \
+    --save-path "$LOG_DIR/checkpoints" \
+    --vllm-config-path "${LOG_DIR}/vllm_bridge_config.json" \
+    --benchmark \
+    --debug-loading \
+    2>&1 | tee "${LOG_DIR}/trainer.log"
+
+echo ""
+echo "============================================================"
+echo "Training Complete!"
+echo "============================================================" +echo "Logs: $LOG_DIR" +echo "" +echo "Key Metrics:" +grep -E "Attached|fused|Step.*Loss" "${LOG_DIR}/trainer.log" | tail -20 +echo "============================================================" + +# Post-training test +echo "" +echo "Post-training test (weights are already updated in vLLM):" +curl -s -X POST "http://localhost:${VLLM_PORT}/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "'"$MODEL"'", + "messages": [{"role": "user", "content": "What is 123 + 456?"}], + "max_tokens": 100, + "temperature": 0.1 + }' | jq '.choices[0].message.content' | tee "${LOG_DIR}/trained_response.txt" + diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py index de34aafb..e6f021bd 100644 --- a/example_trainer/vllm_api_server.py +++ b/example_trainer/vllm_api_server.py @@ -760,6 +760,26 @@ async def lora_load(request: LoraLoadRequest) -> JSONResponse: status_code=404, detail=f"Adapter not found: {request.adapter_path}" ) + # Read adapter config to validate and log details + adapter_config_path = os.path.join(request.adapter_path, "adapter_config.json") + adapter_info = {} + + if os.path.exists(adapter_config_path): + try: + with open(adapter_config_path, "r") as f: + adapter_config = json.load(f) + adapter_info = { + "r": adapter_config.get("r"), + "lora_alpha": adapter_config.get("lora_alpha"), + "target_modules": adapter_config.get("target_modules"), + "base_model": adapter_config.get("base_model_name_or_path"), + } + logger.info(f"LoRA adapter config: {adapter_info}") + except Exception as e: + logger.warning(f"Could not read adapter_config.json: {e}") + else: + logger.warning(f"No adapter_config.json found at {adapter_config_path}") + with bridge_state.lock: bridge_state.active_lora_path = request.adapter_path bridge_state.active_lora_name = ( @@ -770,13 +790,16 @@ async def lora_load(request: LoraLoadRequest) -> JSONResponse: ) # vLLM needs unique int ID bridge_state.lora_load_count += 1 - logger.info(f"LoRA adapter loaded: {request.adapter_path}") + logger.info(f"LoRA adapter loaded: {request.adapter_path} (id={bridge_state.active_lora_id})") return JSONResponse( { "status": "ok", "adapter_path": request.adapter_path, + "adapter_name": bridge_state.active_lora_name, + "adapter_id": bridge_state.active_lora_id, "load_count": bridge_state.lora_load_count, + "adapter_config": adapter_info, } )