diff --git a/environments/eval_environments/gsm8k_eval.py b/environments/eval_environments/gsm8k_eval.py index 06f4c28e..680cfe1c 100644 --- a/environments/eval_environments/gsm8k_eval.py +++ b/environments/eval_environments/gsm8k_eval.py @@ -373,7 +373,9 @@ class GSM8KEvalEnv(BaseEnv): # Create evaluation tasks async def eval_task(item): - return await self.rollout_and_score_eval(item, self.server.servers[0].config) + return await self.rollout_and_score_eval( + item, self.server.servers[0].config + ) tasks = [eval_task(item) for item in self.eval_items] diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index bbb116fd..fc332874 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -146,7 +146,7 @@ class MathEnv(BaseEnv): vllm_url = os.environ.get("MATH_ENV_VLLM_URL", "http://localhost:9001/v1") wandb_name = os.environ.get("MATH_ENV_WANDB_NAME", "math-zero-env") max_token_length = int(os.environ.get("MATH_ENV_MAX_TOKENS", "8192")) - + env_config = RSConfig( tokenizer_name=model_name, group_size=8, @@ -299,6 +299,7 @@ class MathEnv(BaseEnv): if not self.config.run_evaluation: return import time + start_time = time.time() eval_tasks = [] @@ -320,9 +321,7 @@ class MathEnv(BaseEnv): metrics[f"{subset}_accuracy"] = accuracy metrics[f"{subset}_total"] = len(scores) metrics[f"{subset}_correct"] = sum(scores) - self.eval_metrics.append( - (f"eval/{subset}_percent_correct", accuracy) - ) + self.eval_metrics.append((f"eval/{subset}_percent_correct", accuracy)) # overall score all_scores = [] @@ -332,9 +331,7 @@ class MathEnv(BaseEnv): metrics["overall_accuracy"] = overall_accuracy metrics["overall_total"] = len(all_scores) metrics["overall_correct"] = sum(all_scores) - self.eval_metrics.append( - ("eval/overall_percent_correct", overall_accuracy) - ) + self.eval_metrics.append(("eval/overall_percent_correct", overall_accuracy)) end_time = time.time() @@ -342,7 +339,9 @@ class MathEnv(BaseEnv): print("\n" + "=" * 60) print("Math Zero Evaluation Results") print("=" * 60) - print(f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})") + print( + f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})" + ) print("\nPer-subset breakdown:") for subset, scores in sorted(task_lists.items()): acc = sum(scores) / len(scores) diff --git a/example_trainer/README.md b/example_trainer/README.md index 227b99b0..b9a360c0 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -18,21 +18,21 @@ example_trainer/ ├── vllm_manager.py # vLLM process management ├── trainers.py # Training mode implementations ├── vllm_api_server.py # Custom vLLM server (streamlined for training) -├── vllm_patching/ # CUDA IPC patches for weight sharing basically overriding standard vllm for this -│ └── patched_gpu_runner.py -└── scripts/ # Helper scripts +├── vllm_patching/ # CUDA IPC patches for weight sharing basically overriding standard vllm for this +│ └── patched_gpu_runner.py +└── scripts/ # Helper scripts ├── test_lora_mode.sh └── test_single_copy_mode.sh ``` -GRPO Training Loop +GRPO Training Loop -1. Generate multiple responses to the same prompt -2. Score each response (reward) -3. Compute ADVANTAGE = reward - mean(rewards) -4. Train: increase probability of above-average responses - decrease probability of below-average responses +1. Generate multiple responses to the same prompt +2. Score each response (reward) +3. Compute ADVANTAGE = reward - mean(rewards) +4. Train: increase probability of above-average responses + decrease probability of below-average responses ``` ### Key Concepts @@ -330,7 +330,7 @@ The trainer creates **views** into vLLM's fused tensors: # Get sizes from model config q_size = num_heads * head_dim # e.g., 4096 -k_size = num_kv_heads * head_dim # e.g., 1024 +k_size = num_kv_heads * head_dim # e.g., 1024 v_size = num_kv_heads * head_dim # e.g., 1024 # Create views (no copy!) @@ -542,4 +542,3 @@ python -m example_trainer.vllm_api_server # NOT direct vllm commands | `vllm_api_server.py` | Streamlined vLLM server for training | | `vllm_manager.py` | vLLM process lifecycle management | | `checkpointing.py` | Save/load checkpoints and adapters | - diff --git a/example_trainer/__init__.py b/example_trainer/__init__.py index 01334052..3920f62a 100644 --- a/example_trainer/__init__.py +++ b/example_trainer/__init__.py @@ -20,9 +20,9 @@ Usage: train_legacy(config) """ +from .cli import config_from_args, parse_args from .config import TrainingConfig -from .trainers import train_legacy, train_shared_vllm, train_lora -from .cli import parse_args, config_from_args +from .trainers import train_legacy, train_lora, train_shared_vllm __all__ = [ "TrainingConfig", diff --git a/example_trainer/api.py b/example_trainer/api.py index cfa4ac48..f9073cc2 100644 --- a/example_trainer/api.py +++ b/example_trainer/api.py @@ -15,7 +15,9 @@ from tenacity import retry, stop_after_attempt, wait_exponential from .config import TrainingConfig -def check_atropos_api(url: str = "http://localhost:8000", timeout: float = 30.0) -> bool: +def check_atropos_api( + url: str = "http://localhost:8000", timeout: float = 30.0 +) -> bool: """ Check if the Atropos API server is reachable. @@ -82,13 +84,13 @@ def register_trainer(config: TrainingConfig): def get_batch(url: str = "http://localhost:8000"): """ Get a batch of training data from the Atropos API. - + Args: url: Base URL of the Atropos API server - + Returns: Batch data dictionary containing tokens, masks, scores, etc. - + Raises: RuntimeError: If trainer is not registered or other API error """ @@ -99,4 +101,3 @@ def get_batch(url: str = "http://localhost:8000"): raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}") return data - diff --git a/example_trainer/checkpointing.py b/example_trainer/checkpointing.py index 14d648c2..b5d60bbe 100644 --- a/example_trainer/checkpointing.py +++ b/example_trainer/checkpointing.py @@ -20,11 +20,11 @@ import torch def _ensure_contiguous_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: """ Create a state dict with contiguous tensors for safe saving. - + This is critical for shared_vllm mode where parameters are views into vLLM's fused tensors. Views may share storage and not be contiguous, which can cause issues when saving. - + Returns: State dict with all tensors made contiguous (copied if necessary) """ @@ -36,14 +36,14 @@ def _ensure_contiguous_state_dict(model: torch.nn.Module) -> Dict[str, torch.Ten state_dict[name] = param.detach().clone().contiguous() else: state_dict[name] = param.detach() - + # Also include buffers for name, buffer in model.named_buffers(): if not buffer.is_contiguous() or buffer.storage_offset() != 0: state_dict[name] = buffer.detach().clone().contiguous() else: state_dict[name] = buffer.detach() - + return state_dict @@ -86,28 +86,32 @@ def save_checkpoint( # For shared_vllm mode: ensure views are properly unfused print(" [Checkpoint] Using safe mode - ensuring contiguous tensors...") state_dict = _ensure_contiguous_state_dict(model) - + # Count how many were non-contiguous (views into fused tensors) view_count = sum( - 1 for name, param in model.named_parameters() + 1 + for name, param in model.named_parameters() if not param.is_contiguous() or param.storage_offset() != 0 ) if view_count > 0: - print(f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)") - + print( + f" [Checkpoint] Unfused {view_count} view tensors (qkv/gate_up fusions)" + ) + # Save state dict manually, then save config separately torch.save(state_dict, os.path.join(checkpoint_path, "pytorch_model.bin")) model.config.save_pretrained(checkpoint_path) - + # CRITICAL: Clean up the copied state_dict to free ~8GB GPU memory! del state_dict import gc + gc.collect() torch.cuda.empty_cache() else: # Standard save (may have issues with view tensors) model.save_pretrained(checkpoint_path) - + tokenizer.save_pretrained(checkpoint_path) print(" Checkpoint saved.") @@ -151,4 +155,3 @@ def save_lora_checkpoint( print(" Adapter saved.") return adapter_path - diff --git a/example_trainer/cli.py b/example_trainer/cli.py index c79b9716..046e3ec3 100644 --- a/example_trainer/cli.py +++ b/example_trainer/cli.py @@ -11,16 +11,17 @@ import torch from .config import TrainingConfig - # ============================================================================= # Argument Group Builders (modular, reusable) # ============================================================================= + def add_model_args(parser: argparse.ArgumentParser) -> None: """Add model-related arguments.""" group = parser.add_argument_group("Model") group.add_argument( - "--model", "--model-name", + "--model", + "--model-name", type=str, required=True, dest="model_name", @@ -67,7 +68,7 @@ def add_training_args(parser: argparse.ArgumentParser) -> None: choices=["adamw", "adamw_8bit", "adamw_cpu", "adafactor"], default="adamw_8bit", help="Optimizer: 'adamw' (full precision), 'adamw_8bit' (8-bit states), " - "'adamw_cpu' (CPU offload), 'adafactor' (no momentum)", + "'adamw_cpu' (CPU offload), 'adafactor' (no momentum)", ) group.add_argument( "--device", @@ -121,7 +122,8 @@ def add_vllm_args(parser: argparse.ArgumentParser) -> None: help="Port for the vLLM server", ) group.add_argument( - "--gpu-memory-utilization", "--vllm-gpu-memory-utilization", + "--gpu-memory-utilization", + "--vllm-gpu-memory-utilization", type=float, default=0.45, dest="gpu_memory_utilization", @@ -203,7 +205,9 @@ def add_lora_args(parser: argparse.ArgumentParser) -> None: """Add LoRA-specific arguments.""" group = parser.add_argument_group("LoRA Configuration") group.add_argument("--lora-r", type=int, default=16, help="LoRA rank") - group.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)") + group.add_argument( + "--lora-alpha", type=int, default=32, help="LoRA alpha (scaling factor)" + ) group.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout") group.add_argument( "--lora-target-modules", @@ -219,8 +223,12 @@ def add_distributed_args(parser: argparse.ArgumentParser) -> None: group = parser.add_argument_group("Distributed Training") group.add_argument("--trainer-rank", type=int, default=0, help="Trainer rank") group.add_argument("--world-size", type=int, default=1, help="World size") - group.add_argument("--init-method", type=str, default="env://", help="Distributed init method") - group.add_argument("--num-inference-nodes", type=int, default=0, help="Number of inference nodes") + group.add_argument( + "--init-method", type=str, default="env://", help="Distributed init method" + ) + group.add_argument( + "--num-inference-nodes", type=int, default=0, help="Number of inference nodes" + ) def add_debug_args(parser: argparse.ArgumentParser) -> None: @@ -248,6 +256,7 @@ def add_debug_args(parser: argparse.ArgumentParser) -> None: # Parser Builders # ============================================================================= + def create_base_parser(description: str) -> argparse.ArgumentParser: """Create a base parser with common formatting.""" return argparse.ArgumentParser( @@ -261,7 +270,7 @@ def create_full_parser() -> argparse.ArgumentParser: Create a parser with ALL arguments (for grpo.py multi-mode entry point). """ parser = create_base_parser("GRPO Trainer - Multi-mode training") - + add_model_args(parser) add_training_args(parser) add_grpo_args(parser) @@ -272,7 +281,7 @@ def create_full_parser() -> argparse.ArgumentParser: add_lora_args(parser) add_distributed_args(parser) add_debug_args(parser) - + return parser @@ -283,7 +292,7 @@ def create_unified_parser() -> argparse.ArgumentParser: parser = create_base_parser( "Unified GRPO Trainer - Starts vLLM server and trainer in one command" ) - + add_model_args(parser) add_training_args(parser) add_grpo_args(parser) @@ -291,7 +300,7 @@ def create_unified_parser() -> argparse.ArgumentParser: add_atropos_args(parser) add_wandb_args(parser) add_debug_args(parser) - + return parser @@ -299,10 +308,11 @@ def create_unified_parser() -> argparse.ArgumentParser: # Legacy API (backwards compatibility) # ============================================================================= + def parse_args() -> argparse.Namespace: """ Parse command-line arguments for the GRPO trainer (grpo.py). - + Returns: Parsed arguments namespace """ @@ -313,10 +323,10 @@ def parse_args() -> argparse.Namespace: def config_from_args(args: argparse.Namespace) -> TrainingConfig: """ Build a TrainingConfig from parsed CLI arguments. - + Args: args: Parsed argparse namespace - + Returns: TrainingConfig instance """ diff --git a/example_trainer/config.py b/example_trainer/config.py index 7e74f378..291dad43 100644 --- a/example_trainer/config.py +++ b/example_trainer/config.py @@ -14,7 +14,7 @@ from pydantic import BaseModel, Field class TrainingConfig(BaseModel): """ Training configuration for GRPO trainer. - + Supports three training modes: - 'none' (legacy): Periodic checkpoint saves + vLLM restarts - 'shared_vllm': Attach to vLLM's shared memory tensors, update in-place @@ -23,7 +23,7 @@ class TrainingConfig(BaseModel): # === Model Configuration === model_name: str = Field(..., description="Name of the base model to train") - + # === Training Hyperparameters === lr: float = Field(1e-5, description="Learning rate for the optimizer") training_steps: int = Field(10, description="Number of training steps") @@ -35,11 +35,11 @@ class TrainingConfig(BaseModel): optimizer: Literal["adamw", "adamw_8bit", "adamw_cpu", "adafactor"] = Field( "adamw_8bit", description="Optimizer to use: 'adamw' (full precision, ~32GB GPU), " - "'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), " - "'adamw_cpu' (CPU offload, ~0GB GPU, slower), " - "'adafactor' (no momentum, ~8GB GPU)" + "'adamw_8bit' (8-bit states, ~8GB GPU, requires bitsandbytes), " + "'adamw_cpu' (CPU offload, ~0GB GPU, slower), " + "'adafactor' (no momentum, ~8GB GPU)", ) - + # === GRPO/PPO Hyperparameters === kl_coef: float = Field( 0.1, @@ -66,15 +66,13 @@ class TrainingConfig(BaseModel): "When False, falls back to REINFORCE-style updates (not recommended)." ), ) - + # === Device & Storage === device: str = Field( - "cuda" if torch.cuda.is_available() else "cpu", - description="Device to train on" + "cuda" if torch.cuda.is_available() else "cpu", description="Device to train on" ) save_path: str = Field( - "trained_model_checkpoints", - description="Base path to save model checkpoints" + "trained_model_checkpoints", description="Base path to save model checkpoints" ) checkpoint_interval: int = Field( 3, @@ -83,7 +81,7 @@ class TrainingConfig(BaseModel): "Set to 0 to only save final checkpoint." ), ) - + # === vLLM Server Configuration === vllm_restart_interval: int = Field( 3, description="Restart vLLM every N training steps (legacy mode)" @@ -116,14 +114,12 @@ class TrainingConfig(BaseModel): "'none': legacy mode, restart vLLM with new checkpoint files." ), ) - + # === Distributed Training Configuration === trainer_rank: int = Field( 0, description="Rank of this trainer in the distributed group" ) - world_size: int = Field( - 1, description="Total processes in the distributed group" - ) + world_size: int = Field(1, description="Total processes in the distributed group") init_method: str = Field( "env://", description=( @@ -189,7 +185,7 @@ class TrainingConfig(BaseModel): "data fetch time, and GPU memory usage per step." ), ) - + # === Atropos API Configuration === atropos_url: str = Field( "http://localhost:8000", @@ -198,4 +194,3 @@ class TrainingConfig(BaseModel): "Default is http://localhost:8000. Change for concurrent tests." ), ) - diff --git a/example_trainer/data.py b/example_trainer/data.py index 1b11db75..1d17ffdd 100644 --- a/example_trainer/data.py +++ b/example_trainer/data.py @@ -21,7 +21,7 @@ from .api import get_batch def pad_data_to_good_offset( - data: dict, + data: dict, batch_size: int, extract_inference_logprobs: bool = True, ) -> Tuple[ @@ -33,22 +33,22 @@ def pad_data_to_good_offset( ]: """ Pad and batch data from the Atropos API. - + Processes raw batch data into properly padded tensors suitable for training: - Pads token sequences to nearest multiple of 64 - Normalizes advantage scores - Extracts temperature values - Extracts and pads inference logprobs for proper GRPO loss computation - + Args: data: Raw batch data from Atropos API batch_size: Size of each training batch extract_inference_logprobs: Whether to extract inference logprobs - + Returns: Tuple of (token_batches, label_batches, advantage_batches, temperature_batches, inference_logprob_batches) inference_logprob_batches is None if extract_inference_logprobs=False or no logprobs in data - + Note: inference_logprob_batches are padded with 0.0 at positions where labels == -100. This allows token-by-token alignment during GRPO loss computation. @@ -56,7 +56,7 @@ def pad_data_to_good_offset( max_token_len = max( [max([len(x) for x in item["tokens"]]) for item in data["batch"]] ) - + # Pad to nearest multiple of 64 for GPU efficiency good_multiple = 64 if (max_token_len - 1) % (good_multiple) != 0: @@ -65,7 +65,7 @@ def pad_data_to_good_offset( else: token_setup_len = max_token_len max_token_len = max_token_len - 1 # -1 for causal shift - + # Process all items input_ids = [] labels = [] @@ -74,7 +74,7 @@ def pad_data_to_good_offset( temperatures = [] inference_logprobs_padded: List[np.ndarray] = [] # Padded to match labels shape has_any_logprobs = False - + for item in data["batch"]: # Normalize advantage scores scores = np.array(item["scores"]) @@ -82,43 +82,45 @@ def pad_data_to_good_offset( scores = scores - scores.mean() scores = scores / max(scores.std(), 1e-8) item["scores"] = scores - + # Handle score overrides if item["overrides"] is not None: for i in range(len(item["overrides"])): if item["overrides"][i].get("set_advantage_to_zero", False): item["scores"][i] = 0 - + # Process each sample in the item for i in range(len(item["tokens"])): seq_len = len(item["tokens"][i]) - lengths.append( - math.ceil((seq_len - 1) / good_multiple) * good_multiple - ) - + lengths.append(math.ceil((seq_len - 1) / good_multiple) * good_multiple) + # Create labels with padding (-100 for masked positions) - label_item = np.concatenate([ - np.array(item["masks"][i]), - np.full( - max(0, token_setup_len - seq_len), - -100, - dtype=np.int32, - ), - ]) - + label_item = np.concatenate( + [ + np.array(item["masks"][i]), + np.full( + max(0, token_setup_len - seq_len), + -100, + dtype=np.int32, + ), + ] + ) + # Pad tokens - item["tokens"][i] = np.concatenate([ - np.array(item["tokens"][i]), - np.zeros( - max(0, token_setup_len - seq_len), - dtype=np.int32, - ), - ]) - + item["tokens"][i] = np.concatenate( + [ + np.array(item["tokens"][i]), + np.zeros( + max(0, token_setup_len - seq_len), + dtype=np.int32, + ), + ] + ) + input_ids.append(item["tokens"][i][:-1]) # Remove last for causal labels.append(label_item[1:]) # Shift by 1 for causal advantages.append(item["scores"][i]) - + # Extract and pad inference logprobs to match labels shape # IMPORTANT: inference_logprobs is ALREADY ALIGNED with tokens/masks: # - 1.0 for prompt tokens (masked positions) @@ -126,26 +128,32 @@ def pad_data_to_good_offset( # We just need to pad to match the sequence length if extract_inference_logprobs and "inference_logprobs" in item: if i < len(item["inference_logprobs"]): - raw_logprobs = np.array(item["inference_logprobs"][i], dtype=np.float32) + raw_logprobs = np.array( + item["inference_logprobs"][i], dtype=np.float32 + ) has_any_logprobs = True - + # Create padded logprobs array matching token_setup_len # Fill with 1.0 (the masked token placeholder value) for padding padded_logprobs = np.full(token_setup_len, 1.0, dtype=np.float32) - + # Copy raw_logprobs directly - they're already aligned with tokens n_to_copy = min(len(raw_logprobs), token_setup_len) padded_logprobs[:n_to_copy] = raw_logprobs[:n_to_copy] - + # Shift by 1 to match causal label shift inference_logprobs_padded.append(padded_logprobs[1:]) else: - # No logprobs for this sample, use 1.0 - inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32)) + # No logprobs for this sample, use 1.0 + inference_logprobs_padded.append( + np.full(token_setup_len - 1, 1.0, dtype=np.float32) + ) elif extract_inference_logprobs: # No inference_logprobs in item, use 1.0 - inference_logprobs_padded.append(np.full(token_setup_len - 1, 1.0, dtype=np.float32)) - + inference_logprobs_padded.append( + np.full(token_setup_len - 1, 1.0, dtype=np.float32) + ) + # Extract temperature (priority: override > generation_params > group_overrides > 1.0) t = 1.0 if ( @@ -155,48 +163,58 @@ def pad_data_to_good_offset( and ("temperature" in item["overrides"][i]) ): t = float(item["overrides"][i]["temperature"]) - elif item.get("generation_params") and ("temperature" in item["generation_params"]): + elif item.get("generation_params") and ( + "temperature" in item["generation_params"] + ): t = float(item["generation_params"]["temperature"]) - elif item.get("group_overrides") and ("temperature" in item["group_overrides"]): + elif item.get("group_overrides") and ( + "temperature" in item["group_overrides"] + ): t = float(item["group_overrides"]["temperature"]) temperatures.append(t) - + # Batch the data token_batches = [] label_batches = [] advantage_batches = [] temperature_batches = [] inference_logprob_batches = [] - + for i in range(len(input_ids) // batch_size): start = i * batch_size end = (i + 1) * batch_size - - token_batches.append( - torch.tensor(np.stack(input_ids[start:end], axis=0)) - ) - label_batches.append( - torch.tensor(np.stack(labels[start:end], axis=0)) - ) + + token_batches.append(torch.tensor(np.stack(input_ids[start:end], axis=0))) + label_batches.append(torch.tensor(np.stack(labels[start:end], axis=0))) advantage_batches.append( torch.tensor(np.stack(advantages[start:end], axis=0)).view(-1, 1) ) temperature_batches.append( - torch.tensor( - np.array(temperatures[start:end], dtype=np.float32) - ).view(-1, 1, 1) + torch.tensor(np.array(temperatures[start:end], dtype=np.float32)).view( + -1, 1, 1 + ) ) - + # Batch inference logprobs (same shape as labels) if extract_inference_logprobs and inference_logprobs_padded: inference_logprob_batches.append( torch.tensor(np.stack(inference_logprobs_padded[start:end], axis=0)) ) - + # Return inference logprob batches if we have any real logprobs - final_logprob_batches = inference_logprob_batches if (has_any_logprobs and inference_logprob_batches) else None - - return token_batches, label_batches, advantage_batches, temperature_batches, final_logprob_batches + final_logprob_batches = ( + inference_logprob_batches + if (has_any_logprobs and inference_logprob_batches) + else None + ) + + return ( + token_batches, + label_batches, + advantage_batches, + temperature_batches, + final_logprob_batches, + ) def get_data( @@ -205,27 +223,29 @@ def get_data( atropos_url: str = "http://localhost:8000", extract_inference_logprobs: bool = True, ) -> Tuple[ - List[Tuple[ - List[torch.Tensor], # token_batches - List[torch.Tensor], # label_batches - List[torch.Tensor], # advantage_batches - List[torch.Tensor], # temperature_batches - Optional[List[torch.Tensor]], # inference_logprob_batches - ]], + List[ + Tuple[ + List[torch.Tensor], # token_batches + List[torch.Tensor], # label_batches + List[torch.Tensor], # advantage_batches + List[torch.Tensor], # temperature_batches + Optional[List[torch.Tensor]], # inference_logprob_batches + ] + ], None, # Legacy return (no longer used) ]: """ Fetch and process training data from the Atropos API. - + Continuously polls the API until data is available, then processes all available batches. - + Args: batch_size: Size of each training batch seq_len: Maximum sequence length (for reference, not used directly) atropos_url: URL of the Atropos API server extract_inference_logprobs: Whether to extract inference logprobs for GRPO loss - + Returns: Tuple of (batches, None) - batches: List of processed batch tuples, each containing: @@ -234,42 +254,73 @@ def get_data( """ batches = [] _logged_logprob_warning = False - + while True: data = get_batch(url=atropos_url) - + if data["batch"] is not None: # DEBUG: Check if inference_logprobs exists in the data if not _logged_logprob_warning: - has_logprobs = any("inference_logprobs" in item for item in data["batch"]) + has_logprobs = any( + "inference_logprobs" in item for item in data["batch"] + ) if has_logprobs: # Check if they're non-empty - sample_item = next((item for item in data["batch"] if "inference_logprobs" in item), None) + sample_item = next( + ( + item + for item in data["batch"] + if "inference_logprobs" in item + ), + None, + ) if sample_item and sample_item.get("inference_logprobs"): - sample_lp = sample_item["inference_logprobs"][0] if sample_item["inference_logprobs"] else [] - print(f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})") + sample_lp = ( + sample_item["inference_logprobs"][0] + if sample_item["inference_logprobs"] + else [] + ) + print( + f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})" + ) else: - print(" [Data] ⚠ inference_logprobs key exists but is empty!") + print( + " [Data] ⚠ inference_logprobs key exists but is empty!" + ) else: print(" [Data] ⚠ NO inference_logprobs in batch data!") - print(f" [Data] Keys in first item: {list(data['batch'][0].keys())}") + print( + f" [Data] Keys in first item: {list(data['batch'][0].keys())}" + ) _logged_logprob_warning = True - + # Save batch for debugging with open("temp.json", "w", encoding="utf-8") as f: json.dump(data, f) - + # Process and accumulate batches (now includes batched inference logprobs) - token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches = \ - pad_data_to_good_offset(data, batch_size, extract_inference_logprobs) - + ( + token_batches, + label_batches, + adv_batches, + temp_batches, + inf_logprob_batches, + ) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs) + # Include inference logprob batches in the tuple - batches.append((token_batches, label_batches, adv_batches, temp_batches, inf_logprob_batches)) - + batches.append( + ( + token_batches, + label_batches, + adv_batches, + temp_batches, + inf_logprob_batches, + ) + ) + elif len(batches) > 0: # Return accumulated batches when no more data return batches, None else: # Wait for data time.sleep(1) - diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index e9293937..bca45cb8 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -9,7 +9,7 @@ Supports three training modes: Usage: # Legacy mode (manages vLLM internally) python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct - + # Shared vLLM mode (requires external vLLM with VLLM_ENABLE_SHARED_WEIGHTS=1) python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct \\ --weight-bridge-mode shared_vllm @@ -19,8 +19,8 @@ Usage: --weight-bridge-mode lora_only --lora-r 16 --lora-alpha 32 """ -from .cli import parse_args, config_from_args -from .trainers import train_legacy, train_shared_vllm, train_lora +from .cli import config_from_args, parse_args +from .trainers import train_legacy, train_lora, train_shared_vllm def main(): @@ -28,9 +28,9 @@ def main(): args = parse_args() config = config_from_args(args) - print("\n" + "="*60) + print("\n" + "=" * 60) print("GRPO TRAINER") - print("="*60) + print("=" * 60) print(f"Model: {config.model_name}") print(f"Mode: {config.weight_bridge_mode}") print(f"Training steps: {config.training_steps}") diff --git a/example_trainer/model.py b/example_trainer/model.py index b67a4313..e6152874 100644 --- a/example_trainer/model.py +++ b/example_trainer/model.py @@ -20,6 +20,7 @@ from .config import TrainingConfig # Import PEFT for LoRA training try: from peft import LoraConfig, TaskType, get_peft_model + PEFT_AVAILABLE = True except ImportError: PEFT_AVAILABLE = False @@ -31,12 +32,13 @@ def _get_attention_implementation() -> str: Priority: 1. Flash Attention 2 (if flash_attn library is available and works) 2. SDPA (PyTorch's scaled dot-product attention) - + Returns: Tuple of (attn_implementation string, human-readable name) """ try: import flash_attn # noqa: F401 + return "flash_attention_2" except ImportError: return "sdpa" @@ -49,26 +51,33 @@ def _load_model_with_attention( ) -> torch.nn.Module: """ Load a model with the best available attention implementation. - + Args: model_name_or_config: Either a model name (str) or a model config object torch_dtype: Data type for model weights from_config: If True, use from_config (for meta device loading - no weights) If False, use from_pretrained (downloads and loads weights) - + Returns: Loaded model with appropriate attention implementation """ # Select the loader function based on mode # from_config: creates empty shell (meta device), from_pretrained: loads weights - loader = AutoModelForCausalLM.from_config if from_config else AutoModelForCausalLM.from_pretrained - + loader = ( + AutoModelForCausalLM.from_config + if from_config + else AutoModelForCausalLM.from_pretrained + ) + # Try attention implementations in order of preference for attn_impl in ["flash_attention_2", "sdpa"]: # Skip flash_attention_2 if not available - if attn_impl == "flash_attention_2" and _get_attention_implementation() != "flash_attention_2": + if ( + attn_impl == "flash_attention_2" + and _get_attention_implementation() != "flash_attention_2" + ): continue - + try: model = loader( model_name_or_config, @@ -82,10 +91,11 @@ def _load_model_with_attention( print(f"[Setup] Flash Attention 2 failed ({e}), trying SDPA...") continue raise - + # Should never reach here, but just in case raise RuntimeError("Failed to load model with any attention implementation") + def load_model_and_tokenizer( config: TrainingConfig, single_copy: bool = False, @@ -106,7 +116,7 @@ def load_model_and_tokenizer( if single_copy or config.weight_bridge_mode == "shared_vllm": config_path = _find_vllm_config(config) model = _attach_to_vllm_shared_tensors(config, config_path) - + if model is not None: print("[Setup] ✓ Single-copy mode active - using vLLM's tensors directly!") # Enable gradient checkpointing to save memory (was missing before!) @@ -178,7 +188,7 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module: Returns: PEFT model with LoRA adapters applied """ - if not PEFT_AVAILABLE: # Yeah no PEFT is needed no matter what bless huggingface + if not PEFT_AVAILABLE: # Yeah no PEFT is needed no matter what bless huggingface raise RuntimeError("PEFT library not available. Install with: pip install peft") print("[Setup] Loading base model for LoRA mode...") @@ -208,7 +218,9 @@ def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module: return model -def _setup_gradient_checkpointing(model: torch.nn.Module, config: TrainingConfig) -> None: +def _setup_gradient_checkpointing( + model: torch.nn.Module, config: TrainingConfig +) -> None: """Configure gradient checkpointing for the model.""" # Disable KV cache - incompatible with gradient checkpointing model.config.use_cache = False @@ -272,8 +284,8 @@ def _attach_to_vllm_shared_tensors( print("[Setup] TRUE SINGLE-COPY MODE - No additional model memory!") # Load model config (not weights) to get architecture - # doesn't store the buffers just basically the schematics. This is the - # the blueprint for the house not the actual house + # doesn't store the buffers just basically the schematics. This is the + # the blueprint for the house not the actual house model_config = AutoConfig.from_pretrained(config.model_name) # Create empty model on meta device (no memory allocation) @@ -297,7 +309,9 @@ def _attach_to_vllm_shared_tensors( ipc_handles, vllm_to_hf_mapping, config ) - print(f"[Setup] Attached {attached_count} tensors ({fused_count} from fused layers)") + print( + f"[Setup] Attached {attached_count} tensors ({fused_count} from fused layers)" + ) if attached_count == 0: print("[Setup] Could not attach any tensors, falling back to regular loading") @@ -322,6 +336,7 @@ def _attach_to_vllm_shared_tensors( def _deserialize_ipc_handles(handles_raw: dict) -> dict: """Deserialize base64-encoded bytes in IPC handles.""" + def deserialize(handles): result = {} for k, v in handles.items(): @@ -333,6 +348,7 @@ def _deserialize_ipc_handles(handles_raw: dict) -> dict: else: result[k] = v return result + return deserialize(handles_raw) @@ -387,8 +403,14 @@ def _reconstruct_shared_tensors( event_sync_required = ipc_info["event_sync_required"] share_tuple = ( - device_index, ipc_handle, storage_size, storage_offset_orig, - ref_counter_handle, ref_counter_offset, event_handle, event_sync_required, + device_index, + ipc_handle, + storage_size, + storage_offset_orig, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, ) storage = torch.UntypedStorage._new_shared_cuda(*share_tuple) @@ -424,7 +446,9 @@ def _reconstruct_shared_tensors( if slice_dim == 0: tensor = full_tensor[slice_start:slice_end] else: - tensor = full_tensor.narrow(slice_dim, slice_start, slice_end - slice_start) + tensor = full_tensor.narrow( + slice_dim, slice_start, slice_end - slice_start + ) tensor.requires_grad_(True) hf_state_dict[hf_name] = tensor @@ -459,12 +483,16 @@ def _validate_mapping_coverage( # Note: attached_count may be > param_count because state_dict includes buffers # while named_parameters only counts trainable params - print(f"[Setup] Mapping coverage: {attached_count} tensors for {hf_param_count} parameters " - f"(>100% is OK - includes buffers)") + print( + f"[Setup] Mapping coverage: {attached_count} tensors for {hf_param_count} parameters " + f"(>100% is OK - includes buffers)" + ) if mapping_coverage < 0.90: unmapped_params = set(model.state_dict().keys()) - set(hf_state_dict.keys()) - warning_msg = f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n" + warning_msg = ( + f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n" + ) warning_msg += f"Unmapped parameters ({len(unmapped_params)}):\n" for name in list(unmapped_params)[:20]: warning_msg += f" - {name}\n" @@ -484,11 +512,17 @@ def _initialize_meta_tensors( config: TrainingConfig, ) -> None: """Initialize any remaining meta tensors after loading.""" - meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"] - meta_buffers = [name for name, b in model.named_buffers() if b.device.type == "meta"] + meta_params = [ + name for name, p in model.named_parameters() if p.device.type == "meta" + ] + meta_buffers = [ + name for name, b in model.named_buffers() if b.device.type == "meta" + ] if config.debug_loading: - print(f"\n[DIAGNOSTIC] Meta params: {len(meta_params)}, Meta buffers: {len(meta_buffers)}") + print( + f"\n[DIAGNOSTIC] Meta params: {len(meta_params)}, Meta buffers: {len(meta_buffers)}" + ) def get_parent_and_name(model, full_name): parts = full_name.split(".") @@ -526,11 +560,15 @@ def _initialize_meta_tensors( dim = buffer.shape[0] * 2 # Get rope_theta from model config (default 10000.0 for LLaMA, but Qwen3 uses 5000000!) rope_theta = getattr(model.config, "rope_theta", 10000.0) - inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + inv_freq = 1.0 / ( + rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim) + ) new_buffer = inv_freq.to(dtype=buffer.dtype, device=device) print(f"[Setup] Initialized {name} with rope_theta={rope_theta}") else: - new_buffer = torch.zeros(buffer.shape, dtype=buffer.dtype, device=device) + new_buffer = torch.zeros( + buffer.shape, dtype=buffer.dtype, device=device + ) parent, attr_name = get_parent_and_name(model, name) parent.register_buffer(attr_name, new_buffer) @@ -544,8 +582,12 @@ def _initialize_meta_tensors( def _validate_no_meta_tensors(model: torch.nn.Module) -> None: """Ensure no parameters or buffers are still on meta device.""" - final_meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"] - final_meta_buffers = [name for name, b in model.named_buffers() if b.device.type == "meta"] + final_meta_params = [ + name for name, p in model.named_parameters() if p.device.type == "meta" + ] + final_meta_buffers = [ + name for name, b in model.named_buffers() if b.device.type == "meta" + ] if final_meta_params or final_meta_buffers: error_msg = "[Setup] CRITICAL ERROR: Some tensors are still on meta device!\n" @@ -575,7 +617,7 @@ def _create_vllm_to_hf_mapping( Handles fused layers: - qkv_proj (vLLM) = q_proj + k_proj + v_proj (HF) - gate_up_proj (vLLM) = gate_proj + up_proj (HF) - + Uses actual tensor shapes from HF model to determine slice sizes, rather than calculating from config (which can be wrong for some models). """ @@ -588,20 +630,22 @@ def _create_vllm_to_hf_mapping( model_config = model.config hidden_size = getattr(model_config, "hidden_size", 4096) num_attention_heads = getattr(model_config, "num_attention_heads", 32) - num_key_value_heads = getattr(model_config, "num_key_value_heads", num_attention_heads) + num_key_value_heads = getattr( + model_config, "num_key_value_heads", num_attention_heads + ) intermediate_size = getattr(model_config, "intermediate_size", hidden_size * 4) - + # Try to get head_dim from config (some models like Qwen3 have this) head_dim = getattr(model_config, "head_dim", None) if head_dim is None: head_dim = hidden_size // num_attention_heads - # Determine QKV sizes from ACTUAL HF model tensor shapes + # Determine QKV sizes from ACTUAL HF model tensor shapes # Look for a q_proj weight in the model to get the actual size q_size = None k_size = None v_size = None - + for name, param in hf_state_dict.items(): if "q_proj.weight" in name and q_size is None: q_size = param.shape[0] # Output dimension @@ -611,7 +655,7 @@ def _create_vllm_to_hf_mapping( v_size = param.shape[0] if q_size and k_size and v_size: break - + # Fallback to calculated values if not found if q_size is None: q_size = num_attention_heads * head_dim @@ -623,7 +667,7 @@ def _create_vllm_to_hf_mapping( # Also get gate/up sizes from actual HF model gate_size = None up_size = None - + for name, param in hf_state_dict.items(): if "gate_proj.weight" in name and gate_size is None: gate_size = param.shape[0] @@ -631,7 +675,7 @@ def _create_vllm_to_hf_mapping( up_size = param.shape[0] if gate_size and up_size: break - + # Fallback if gate_size is None: gate_size = intermediate_size @@ -639,8 +683,10 @@ def _create_vllm_to_hf_mapping( up_size = intermediate_size # Always print sizes for debugging weight sharing issues - print(f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, " - f"kv_heads={num_key_value_heads}, head_dim={head_dim}") + print( + f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, " + f"kv_heads={num_key_value_heads}, head_dim={head_dim}" + ) print(f"[Mapping] QKV sizes from HF model: q={q_size}, k={k_size}, v={v_size}") print(f"[Mapping] Gate/Up sizes from HF model: gate={gate_size}, up={up_size}") @@ -714,7 +760,8 @@ def _create_vllm_to_hf_mapping( if debug: direct = sum(1 for v in mapping.values() if isinstance(v, str)) fused = sum(1 for v in mapping.values() if isinstance(v, dict)) - print(f"[Mapping] Total: {len(mapping)} mapped ({direct} direct, {fused} fused)") + print( + f"[Mapping] Total: {len(mapping)} mapped ({direct} direct, {fused} fused)" + ) return mapping - diff --git a/example_trainer/run.py b/example_trainer/run.py index e54637ad..ff60d0ba 100644 --- a/example_trainer/run.py +++ b/example_trainer/run.py @@ -33,7 +33,7 @@ def wait_for_vllm(port: int, timeout: int = 300) -> bool: """Wait for vLLM server to be ready.""" print(f"[Run] Waiting for vLLM server on port {port}...") start = time.time() - + while time.time() - start < timeout: try: response = requests.get(f"http://localhost:{port}/health", timeout=5) @@ -44,9 +44,9 @@ def wait_for_vllm(port: int, timeout: int = 300) -> bool: pass except Exception as e: print(f"[Run] Health check error: {e}") - + time.sleep(2) - + print(f"[Run] ✗ vLLM server failed to start within {timeout}s") return False @@ -55,20 +55,23 @@ def wait_for_bridge_config(config_path: str, timeout: int = 60) -> bool: """Wait for vLLM bridge config to be created.""" print(f"[Run] Waiting for bridge config at {config_path}...") start = time.time() - + while time.time() - start < timeout: if os.path.exists(config_path): try: import json - with open(config_path, 'r') as f: + + with open(config_path, "r") as f: config = json.load(f) - if config.get('ipc_handles') and len(config['ipc_handles']) > 0: - print(f"[Run] ✓ Bridge config ready with {len(config['ipc_handles'])} IPC handles") + if config.get("ipc_handles") and len(config["ipc_handles"]) > 0: + print( + f"[Run] ✓ Bridge config ready with {len(config['ipc_handles'])} IPC handles" + ) return True except Exception: pass time.sleep(1) - + print(f"[Run] ✗ Bridge config not created within {timeout}s") return False @@ -77,44 +80,44 @@ def main(): # Parse args using shared CLI module parser = create_unified_parser() args = parser.parse_args() - + # Create log directory - log_dir = getattr(args, 'log_dir', './logs') + log_dir = getattr(args, "log_dir", "./logs") os.makedirs(log_dir, exist_ok=True) - + # Bridge config path bridge_config_path = "./vllm_bridge_config.json" - + # Clean up old bridge config if os.path.exists(bridge_config_path): os.remove(bridge_config_path) print("[Run] Removed old bridge config") - + # === Print Configuration === - print("\n" + "="*60) + print("\n" + "=" * 60) print("STARTING UNIFIED GRPO TRAINER (shared_vllm mode)") - print("="*60) + print("=" * 60) print(f"Model: {args.model_name}") print(f"vLLM port: {args.vllm_port}") print(f"GPU memory utilization: {args.gpu_memory_utilization}") print(f"Training steps: {args.training_steps}") print(f"Optimizer: {args.optimizer}") print(f"GRPO: kl_coef={args.kl_coef}, clip_eps={args.clip_eps}") - print("="*60 + "\n") - + print("=" * 60 + "\n") + # Get the path to vllm_api_server.py script_dir = Path(__file__).parent vllm_server_script = script_dir / "vllm_api_server.py" - + if not vllm_server_script.exists(): print(f"[Run] ✗ vLLM server script not found at {vllm_server_script}") sys.exit(1) - + # Extract device index from args.device device_index = "0" if ":" in args.device: device_index = args.device.split(":")[1] - + # Build vLLM environment vllm_env = os.environ.copy() vllm_env["VLLM_ENABLE_SHARED_WEIGHTS"] = "1" @@ -123,21 +126,28 @@ def main(): vllm_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" vllm_env["VLLM_USE_V1"] = "0" # v0 engine required for shared weights patches vllm_env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" # Required for CUDA - + # Build vLLM command vllm_cmd = [ - sys.executable, "-u", str(vllm_server_script), - "--model", args.model_name, - "--port", str(args.vllm_port), - "--dtype", args.dtype, - "--gpu-memory-utilization", str(args.gpu_memory_utilization), - "--max-model-len", str(args.max_model_len), + sys.executable, + "-u", + str(vllm_server_script), + "--model", + args.model_name, + "--port", + str(args.vllm_port), + "--dtype", + args.dtype, + "--gpu-memory-utilization", + str(args.gpu_memory_utilization), + "--max-model-len", + str(args.max_model_len), "--enforce-eager", # Required for shared weights ] - + vllm_log_path = os.path.join(log_dir, "vllm.log") print(f"[Run] Starting vLLM server (log: {vllm_log_path})...") - + vllm_log = open(vllm_log_path, "w") vllm_process = subprocess.Popen( vllm_cmd, @@ -145,7 +155,7 @@ def main(): stdout=vllm_log, stderr=subprocess.STDOUT, ) - + # Register cleanup def cleanup(): print("\n[Run] Cleaning up...") @@ -158,24 +168,24 @@ def main(): vllm_process.kill() vllm_log.close() print("[Run] Cleanup complete.") - + atexit.register(cleanup) signal.signal(signal.SIGINT, lambda s, f: sys.exit(0)) signal.signal(signal.SIGTERM, lambda s, f: sys.exit(0)) - + # Wait for vLLM to be ready if not wait_for_vllm(args.vllm_port, timeout=500): print("[Run] ✗ vLLM server failed to start. Check logs at:", vllm_log_path) sys.exit(1) - + # Wait for bridge config if not wait_for_bridge_config(bridge_config_path, timeout=60): print("[Run] ✗ Bridge config not created. Check vLLM logs.") sys.exit(1) - + # === Start Trainer === print("\n[Run] Starting GRPO trainer...") - + # Build config - override some fields for shared_vllm mode config = TrainingConfig( model_name=args.model_name, @@ -205,13 +215,14 @@ def main(): benchmark=True, # Always show timing info for run.py debug_loading=getattr(args, "debug_loading", False), ) - + try: train_shared_vllm(config) print("\n[Run] ✓ Training completed successfully!") except Exception as e: print(f"\n[Run] ✗ Training failed: {e}") import traceback + traceback.print_exc() sys.exit(1) diff --git a/example_trainer/scripts/test_lora_mode.sh b/example_trainer/scripts/test_lora_mode.sh index 078e78e0..be672e3c 100644 --- a/example_trainer/scripts/test_lora_mode.sh +++ b/example_trainer/scripts/test_lora_mode.sh @@ -122,11 +122,11 @@ if [ -d "$LOG_DIR/checkpoints" ]; then if [ -n "$LATEST_ADAPTER" ]; then echo "" echo "Post-training test with adapter: $LATEST_ADAPTER" - + curl -s -X POST "http://localhost:${VLLM_PORT}/lora/load" \ -H "Content-Type: application/json" \ -d '{"adapter_path": "'"$LATEST_ADAPTER"'"}' | jq - + echo "" echo "Response after training:" curl -s -X POST "http://localhost:${VLLM_PORT}/generate" \ @@ -138,4 +138,3 @@ if [ -d "$LOG_DIR/checkpoints" ]; then }' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt" fi fi - diff --git a/example_trainer/scripts/test_single_copy_mode.sh b/example_trainer/scripts/test_single_copy_mode.sh index 1022ea72..68faaf1a 100644 --- a/example_trainer/scripts/test_single_copy_mode.sh +++ b/example_trainer/scripts/test_single_copy_mode.sh @@ -143,4 +143,3 @@ curl -s -X POST "http://localhost:${VLLM_PORT}/generate" \ "max_tokens": 100, "temperature": 0.1 }' | jq '.text[0]' | tee "${LOG_DIR}/trained_response.txt" - diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py index b0e976fb..77e7ec2f 100644 --- a/example_trainer/trainers.py +++ b/example_trainer/trainers.py @@ -21,75 +21,78 @@ from .api import check_atropos_api, register_trainer class CPUOffloadAdamW(torch.optim.Optimizer): """ AdamW with optimizer states offloaded to CPU. - + Full precision (no quantization), but states stay on CPU RAM instead of GPU. Trade-off: Slower (~2x) but uses ~0GB GPU memory for optimizer states. """ - def __init__(self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01): + + def __init__( + self, params, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01 + ): defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super().__init__(params, defaults) - + def _init_state(self, p): """Lazily initialize state on CPU.""" state = self.state[p] if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # Store on CPU in FP32 - state['exp_avg'] = torch.zeros_like(p, device='cpu', dtype=torch.float32) - state['exp_avg_sq'] = torch.zeros_like(p, device='cpu', dtype=torch.float32) + state["exp_avg"] = torch.zeros_like(p, device="cpu", dtype=torch.float32) + state["exp_avg_sq"] = torch.zeros_like(p, device="cpu", dtype=torch.float32) return state - + @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: with torch.enable_grad(): loss = closure() - + for group in self.param_groups: - beta1, beta2 = group['betas'] - - for p in group['params']: + beta1, beta2 = group["betas"] + + for p in group["params"]: if p.grad is None: continue - + grad = p.grad state = self._init_state(p) - - state['step'] += 1 - + + state["step"] += 1 + # Move states to GPU for computation - exp_avg = state['exp_avg'].to(p.device) - exp_avg_sq = state['exp_avg_sq'].to(p.device) - + exp_avg = state["exp_avg"].to(p.device) + exp_avg_sq = state["exp_avg_sq"].to(p.device) + # AdamW update exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - + # Bias correction - bias_correction1 = 1 - beta1 ** state['step'] - bias_correction2 = 1 - beta2 ** state['step'] - step_size = group['lr'] / bias_correction1 - + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + step_size = group["lr"] / bias_correction1 + # Update weights - denom = (exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)).add_(group['eps']) + denom = (exp_avg_sq.sqrt() / (bias_correction2**0.5)).add_(group["eps"]) p.addcdiv_(exp_avg, denom, value=-step_size) - + # Weight decay - if group['weight_decay'] != 0: - p.add_(p, alpha=-group['lr'] * group['weight_decay']) - + if group["weight_decay"] != 0: + p.add_(p, alpha=-group["lr"] * group["weight_decay"]) + # Move states back to CPU (non-blocking for better perf) - state['exp_avg'].copy_(exp_avg.cpu()) - state['exp_avg_sq'].copy_(exp_avg_sq.cpu()) - + state["exp_avg"].copy_(exp_avg.cpu()) + state["exp_avg_sq"].copy_(exp_avg_sq.cpu()) + return loss def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer: """ Create optimizer based on config.optimizer setting. - + Options: - 'adamw': Standard AdamW (full precision, ~32GB GPU for 4B model) - 'adamw_8bit': 8-bit AdamW from bitsandbytes (~8GB GPU, requires bitsandbytes) @@ -99,22 +102,28 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer: if config.optimizer == "adamw_8bit": try: import bitsandbytes as bnb + optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.lr) print("[Setup] Using 8-bit AdamW (saves ~24GB optimizer memory)") return optimizer except ImportError: print("[Setup] WARNING: bitsandbytes not installed, falling back to AdamW") print("[Setup] Install with: pip install bitsandbytes") - + if config.optimizer == "adamw_cpu": optimizer = CPUOffloadAdamW(model.parameters(), lr=config.lr) - print("[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)") - print("[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization") + print( + "[Setup] Using AdamW with CPU offload (full precision, ~0GB GPU for states)" + ) + print( + "[Setup] NOTE: ~2x slower due to CPU<->GPU transfers, but no quantization" + ) return optimizer - + if config.optimizer == "adafactor": try: from transformers.optimization import Adafactor + optimizer = Adafactor( model.parameters(), lr=config.lr, @@ -125,7 +134,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer: return optimizer except ImportError: print("[Setup] WARNING: transformers Adafactor not available, using AdamW") - + # Default: standard AdamW optimizer = AdamW(model.parameters(), lr=config.lr) print("[Setup] Using standard AdamW (requires ~32GB for optimizer states)") @@ -135,7 +144,7 @@ def create_optimizer(model: torch.nn.Module, config) -> torch.optim.Optimizer: from .checkpointing import save_checkpoint, save_lora_checkpoint # noqa: E402 from .config import TrainingConfig # noqa: E402 from .data import get_data # noqa: E402 -from .model import load_model_and_tokenizer, PEFT_AVAILABLE # noqa: E402 +from .model import PEFT_AVAILABLE, load_model_and_tokenizer # noqa: E402 from .training import ( # noqa: E402 finalize_training, log_metrics, @@ -146,8 +155,8 @@ from .vllm_manager import ( # noqa: E402 check_vllm_health, check_vllm_process_health, launch_vllm_server, - terminate_vllm_process, set_vllm_process, + terminate_vllm_process, ) @@ -171,13 +180,13 @@ def train_legacy(config: TrainingConfig): model, tokenizer = load_model_and_tokenizer(config) optimizer = create_optimizer(model, config) - print("\n" + "="*60) + print("\n" + "=" * 60) print("LEGACY MODE (checkpoint + vLLM restart)") - print("="*60) + print("=" * 60) print(f"Training for {config.training_steps} steps on {config.device}") print(f"vLLM restart interval: every {config.vllm_restart_interval} steps") print(f"Save path: {config.save_path}") - print("="*60 + "\n") + print("=" * 60 + "\n") os.makedirs(config.save_path, exist_ok=True) @@ -206,24 +215,36 @@ def train_legacy(config: TrainingConfig): # Fetch data (with inference logprobs for proper GRPO) data_fetch_start = time.time() if len(batches) == 0: - batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url, - extract_inference_logprobs=True) + batches, _ = get_data( + config.batch_size, + config.seq_len, + config.atropos_url, + extract_inference_logprobs=True, + ) batch_data = batches.pop(0) - token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4] + token_batches, label_batches, advantage_batches, temperature_batches = ( + batch_data[:4] + ) inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) # Check if we should sync (save checkpoint + restart vLLM) - should_sync = (step + 1) % config.vllm_restart_interval == 0 or step == config.training_steps - 1 + should_sync = ( + step + 1 + ) % config.vllm_restart_interval == 0 or step == config.training_steps - 1 if should_sync: terminate_vllm_process() # Training step (with proper GRPO using inference logprobs) step_start = time.time() metrics = run_training_step( - model, optimizer, - token_batches, label_batches, advantage_batches, temperature_batches, + model, + optimizer, + token_batches, + label_batches, + advantage_batches, + temperature_batches, config, inference_logprob_batches=inference_logprob_batches, ) @@ -231,15 +252,21 @@ def train_legacy(config: TrainingConfig): benchmark_stats["step_times"].append(step_time) # GPU memory tracking - gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 - gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 + gpu_mem_gb = ( + torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 + ) + gpu_mem_reserved_gb = ( + torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 + ) benchmark_stats["gpu_memories"].append(gpu_mem_gb) # Sync (checkpoint + restart) sync_time = 0 if should_sync: sync_start = time.time() - checkpoint_path = save_checkpoint(model, tokenizer, config.save_path, step + 1) + checkpoint_path = save_checkpoint( + model, tokenizer, config.save_path, step + 1 + ) torch.cuda.empty_cache() vllm_proc = launch_vllm_server(config, checkpoint_path) set_vllm_process(vllm_proc) @@ -247,20 +274,31 @@ def train_legacy(config: TrainingConfig): benchmark_stats["sync_times"].append(sync_time) # Update metrics - metrics.update({ - "step_time": step_time, - "sync_time": sync_time, - "data_fetch_time": data_fetch_time, - "gpu_memory_gb": gpu_mem_gb, - "gpu_memory_reserved_gb": gpu_mem_reserved_gb, - }) + metrics.update( + { + "step_time": step_time, + "sync_time": sync_time, + "data_fetch_time": data_fetch_time, + "gpu_memory_gb": gpu_mem_gb, + "gpu_memory_reserved_gb": gpu_mem_reserved_gb, + } + ) log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark) check_vllm_process_health() # === Cleanup === - save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True) - finalize_training(use_wandb, training_start_time, "legacy", config.training_steps, benchmark_stats, config.benchmark) + save_checkpoint( + model, tokenizer, config.save_path, config.training_steps, is_final=True + ) + finalize_training( + use_wandb, + training_start_time, + "legacy", + config.training_steps, + benchmark_stats, + config.benchmark, + ) def train_shared_vllm(config: TrainingConfig): @@ -281,13 +319,13 @@ def train_shared_vllm(config: TrainingConfig): # === Setup === use_wandb = setup_wandb(config) - print("\n" + "="*60) + print("\n" + "=" * 60) print("SINGLE-COPY MODE (CUDA IPC)") print(">>> Trainer uses vLLM's tensors directly!") - print("="*60) + print("=" * 60) print(f"Model: {config.model_name}") print(f"Save path: {config.save_path}") - print("="*60 + "\n") + print("=" * 60 + "\n") # Attach to vLLM's shared tensors print("[1/2] Attaching to vLLM's shared tensors...") @@ -331,11 +369,15 @@ def train_shared_vllm(config: TrainingConfig): data_fetch_start = time.time() if len(batches) == 0: batches, _ = get_data( - config.batch_size, config.seq_len, config.atropos_url, + config.batch_size, + config.seq_len, + config.atropos_url, extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs ) batch_data = batches.pop(0) - token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4] + token_batches, label_batches, advantage_batches, temperature_batches = ( + batch_data[:4] + ) inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) @@ -343,8 +385,12 @@ def train_shared_vllm(config: TrainingConfig): # Training step with proper GRPO (importance sampling + KL penalty) step_start = time.time() metrics = run_training_step( - model, optimizer, - token_batches, label_batches, advantage_batches, temperature_batches, + model, + optimizer, + token_batches, + label_batches, + advantage_batches, + temperature_batches, config, inference_logprob_batches=inference_logprob_batches, # Pass for GRPO ratio computation ) @@ -352,8 +398,12 @@ def train_shared_vllm(config: TrainingConfig): benchmark_stats["step_times"].append(step_time) # GPU memory tracking - gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 - gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 + gpu_mem_gb = ( + torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 + ) + gpu_mem_reserved_gb = ( + torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 + ) benchmark_stats["gpu_memories"].append(gpu_mem_gb) # In single-copy mode, weights are updated in-place (no sync needed!) @@ -362,23 +412,37 @@ def train_shared_vllm(config: TrainingConfig): benchmark_stats["sync_times"].append(sync_time) # Update metrics - metrics.update({ - "step_time": step_time, - "sync_time": sync_time, - "data_fetch_time": data_fetch_time, - "gpu_memory_gb": gpu_mem_gb, - "gpu_memory_reserved_gb": gpu_mem_reserved_gb, - }) + metrics.update( + { + "step_time": step_time, + "sync_time": sync_time, + "data_fetch_time": data_fetch_time, + "gpu_memory_gb": gpu_mem_gb, + "gpu_memory_reserved_gb": gpu_mem_reserved_gb, + } + ) log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark) # Periodic checkpoint (for recovery, not for vLLM sync) - if config.checkpoint_interval > 0 and (step + 1) % config.checkpoint_interval == 0: + if ( + config.checkpoint_interval > 0 + and (step + 1) % config.checkpoint_interval == 0 + ): save_checkpoint(model, tokenizer, config.save_path, step + 1) # === Cleanup === - save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True) - finalize_training(use_wandb, training_start_time, "shared_vllm", config.training_steps, benchmark_stats, config.benchmark) + save_checkpoint( + model, tokenizer, config.save_path, config.training_steps, is_final=True + ) + finalize_training( + use_wandb, + training_start_time, + "shared_vllm", + config.training_steps, + benchmark_stats, + config.benchmark, + ) def train_lora(config: TrainingConfig): @@ -399,29 +463,33 @@ def train_lora(config: TrainingConfig): - External vLLM server running with --enable-lora """ if not PEFT_AVAILABLE: - raise RuntimeError("PEFT library required for LoRA mode. Install with: pip install peft") + raise RuntimeError( + "PEFT library required for LoRA mode. Install with: pip install peft" + ) training_start_time = time.time() # === Setup === use_wandb = setup_wandb(config) - print("\n" + "="*60) + print("\n" + "=" * 60) print("LORA MODE (adapter-only training)") - print("="*60) + print("=" * 60) print(f"Base model: {config.model_name}") print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}") print(f"Save path: {config.save_path}") print(f"vLLM port: {config.vllm_port}") - print("="*60 + "\n") + print("=" * 60 + "\n") # Check external vLLM server print("[1/3] Checking external vLLM server...") if not check_vllm_health(config.vllm_port): print(f"\nERROR: vLLM server not running on port {config.vllm_port}") print("\nLoRA mode requires an external vLLM server. Start it first:") - print(f" python example_trainer/vllm_api_server.py --model {config.model_name} " - f"--port {config.vllm_port} --enable-lora --enforce-eager") + print( + f" python example_trainer/vllm_api_server.py --model {config.model_name} " + f"--port {config.vllm_port} --enable-lora --enforce-eager" + ) raise RuntimeError(f"External vLLM server required on port {config.vllm_port}") print(f"vLLM server healthy on port {config.vllm_port}") @@ -459,10 +527,16 @@ def train_lora(config: TrainingConfig): # Fetch data (with inference logprobs for proper GRPO) data_fetch_start = time.time() if len(batches) == 0: - batches, _ = get_data(config.batch_size, config.seq_len, config.atropos_url, - extract_inference_logprobs=True) + batches, _ = get_data( + config.batch_size, + config.seq_len, + config.atropos_url, + extract_inference_logprobs=True, + ) batch_data = batches.pop(0) - token_batches, label_batches, advantage_batches, temperature_batches = batch_data[:4] + token_batches, label_batches, advantage_batches, temperature_batches = ( + batch_data[:4] + ) inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time) @@ -470,8 +544,12 @@ def train_lora(config: TrainingConfig): # Training step with proper GRPO step_start = time.time() metrics = run_training_step( - model, optimizer, - token_batches, label_batches, advantage_batches, temperature_batches, + model, + optimizer, + token_batches, + label_batches, + advantage_batches, + temperature_batches, config, inference_logprob_batches=inference_logprob_batches, ) @@ -479,8 +557,12 @@ def train_lora(config: TrainingConfig): benchmark_stats["step_times"].append(step_time) # GPU memory tracking - gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 - gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 + gpu_mem_gb = ( + torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 + ) + gpu_mem_reserved_gb = ( + torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 + ) benchmark_stats["gpu_memories"].append(gpu_mem_gb) # Periodic adapter save + hot-swap @@ -494,24 +576,35 @@ def train_lora(config: TrainingConfig): benchmark_stats["sync_times"].append(sync_time) # Update metrics - metrics.update({ - "step_time": step_time, - "sync_time": sync_time, - "data_fetch_time": data_fetch_time, - "gpu_memory_gb": gpu_mem_gb, - "gpu_memory_reserved_gb": gpu_mem_reserved_gb, - }) + metrics.update( + { + "step_time": step_time, + "sync_time": sync_time, + "data_fetch_time": data_fetch_time, + "gpu_memory_gb": gpu_mem_gb, + "gpu_memory_reserved_gb": gpu_mem_reserved_gb, + } + ) log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark) # === Cleanup === final_sync_start = time.time() - final_adapter_path = save_lora_checkpoint(model, config.save_path, config.training_steps, is_final=True) + final_adapter_path = save_lora_checkpoint( + model, config.save_path, config.training_steps, is_final=True + ) _hotswap_lora_adapter(config.vllm_port, final_adapter_path, "final") final_sync_time = time.time() - final_sync_start benchmark_stats["sync_times"].append(final_sync_time) - finalize_training(use_wandb, training_start_time, "lora_only", config.training_steps, benchmark_stats, config.benchmark) + finalize_training( + use_wandb, + training_start_time, + "lora_only", + config.training_steps, + benchmark_stats, + config.benchmark, + ) # Save tokenizer tokenizer_path = os.path.join(config.save_path, "tokenizer") @@ -563,4 +656,3 @@ def _hotswap_lora_adapter( except Exception as e: print(f" [LORA] ✗ Hot-swap request failed: {e}") return False - diff --git a/example_trainer/training.py b/example_trainer/training.py index 6dfe7f6a..f44cfe84 100644 --- a/example_trainer/training.py +++ b/example_trainer/training.py @@ -18,10 +18,10 @@ import wandb from .config import TrainingConfig - # Global storage for logprob alignment stats _logprob_alignment_stats: Dict[str, float] = {} + def setup_wandb(config: TrainingConfig) -> bool: """ Initialize Weights & Biases logging if enabled. @@ -80,12 +80,12 @@ def compute_grpo_loss( - Importance sampling ratio: policy(a|s) / policy_old(a|s) - PPO-style clipping to prevent large updates - KL penalty to prevent reward hacking/policy collapse - + The loss encourages the model to: - Increase probability for tokens with positive advantages - Decrease probability for tokens with negative advantages - Stay close to the reference policy (inference-time policy) - + Args: model: The model to compute loss for tokens: Input token IDs [batch, seq_len] @@ -105,7 +105,7 @@ def compute_grpo_loss( outputs = model(tokens) logits = outputs.logits - # Temperature scaling for training otherwise likely ratio is off + # Temperature scaling for training otherwise likely ratio is off t = temperatures.to(logits.device, logits.dtype) t = torch.where(t <= 0, torch.ones_like(t), t) scaled_logits = logits / t @@ -130,48 +130,56 @@ def compute_grpo_loss( logprob_diff_mean = 0.0 logprob_diff_abs_mean = 0.0 logprob_diff_max = 0.0 - + # === GRPO/PPO Loss Computation === if use_reference_logprobs and inference_logprobs is not None: # Move inference logprobs to correct device/dtype - ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype) + ref_logprobs = inference_logprobs.to( + logp_per_token.device, logp_per_token.dtype + ) # NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated with torch.no_grad(): # Only look at generated positions (where mask == 1) ref_at_generated = (ref_logprobs * mask).sum() / mask.sum() train_at_generated = (logp_per_token * mask).sum() / mask.sum() - + # Extract logprobs at generated positions for alignment tracking inference_logprobs_flat = ref_logprobs[mask.bool()].detach() training_at_mask = logp_per_token[mask.bool()].detach() - + # Token-level difference: THE key metric for alignment verification # If weights are truly shared, this should be ~0 at step start token_diff = training_at_mask - inference_logprobs_flat logprob_diff_mean = token_diff.mean().item() logprob_diff_abs_mean = token_diff.abs().mean().item() logprob_diff_max = token_diff.abs().max().item() - + # Check if ref logprobs are negative (as they should be for generated tokens) # If ref_at_generated is close to 1.0, that means the 1.0 placeholder is being used if ref_at_generated > 0.5: - print(f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)") - print(" [WARNING] This suggests inference_logprobs alignment is wrong") + print( + f" [WARNING] ref_logprobs avg {ref_at_generated:.3f} (should be negative!)" + ) + print( + " [WARNING] This suggests inference_logprobs alignment is wrong" + ) elif abs(ref_at_generated - train_at_generated) > 2.0: - print(f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}") - + print( + f" [DEBUG] Logprob gap: ref={ref_at_generated:.3f}, train={train_at_generated:.3f}" + ) + # Compute importance sampling ratio: policy(a|s) / policy_old(a|s) = exp(log policy - log policy_old) log_ratio = logp_per_token - ref_logprobs ratio = torch.exp(log_ratio) - + # PPO-style clipping clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) - + # Surrogate objectives surr1 = ratio * adv_expanded surr2 = clipped_ratio * adv_expanded - + # Pessimistic bound: min for positive advantages, max for negative # This is equivalent to: -min(ratio * A, clipped_ratio * A) when A > 0 # -max(ratio * A, clipped_ratio * A) when A < 0 @@ -180,10 +188,10 @@ def compute_grpo_loss( torch.min(surr1, surr2), torch.max(surr1, surr2), ) - + # Average over tokens, then over batch policy_loss = ((policy_loss_per_token * mask).sum(dim=-1) / mask_sum).mean() - + # KL penalty: encourage staying close to reference policy # Using Schulman's unbiased KL estimator from the DeepSeek GRPO paper (Equation 4): # This estimator is guaranteed to be non-negative (unlike squared log-ratio). @@ -192,23 +200,27 @@ def compute_grpo_loss( # = exp(-log_ratio) + log_ratio - 1 kl_per_token = torch.exp(-log_ratio) + log_ratio - 1.0 kl_penalty = ((kl_per_token * mask).sum(dim=-1) / mask_sum).mean() - total_loss = (policy_loss + kl_coef * kl_penalty) / gradient_accumulation_steps + total_loss = ( + policy_loss + kl_coef * kl_penalty + ) / gradient_accumulation_steps else: kl_penalty = torch.tensor(0.0, device=logp_per_token.device) total_loss = policy_loss / gradient_accumulation_steps - + # Compute metrics for logging with torch.no_grad(): # Fraction of tokens where ratio was clipped - clipped_fraction = ((ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps)).float() + clipped_fraction = ( + (ratio < 1.0 - clip_eps) | (ratio > 1.0 + clip_eps) + ).float() clipped_fraction = (clipped_fraction * mask).sum() / mask.sum() - + # Mean ratio and KL for monitoring (using Schulman's estimator) mean_ratio = (ratio * mask).sum() / mask.sum() # Schulman KL: exp(-log_ratio) + log_ratio - 1 schulman_kl = torch.exp(-log_ratio) + log_ratio - 1.0 mean_kl = (schulman_kl * mask).sum() / mask.sum() - + # For backward compatibility: collect training logprobs raw_logp_per_token = -F.cross_entropy( outputs.logits.view(-1, outputs.logits.size(-1)), @@ -239,10 +251,10 @@ def compute_grpo_loss( avg_logp = (logp_per_token * mask_float).sum(dim=-1) / mask_sum pos_logp = (logp_per_token * pos).mean().item() neg_logp = (logp_per_token * neg).mean().item() - + # Interpretable metric: advantage-weighted average logprob interpretable_loss = (avg_logp * advantages.squeeze()).mean().item() - + metrics = { "pos_logp": pos_logp, "neg_logp": neg_logp, @@ -256,7 +268,11 @@ def compute_grpo_loss( "kl_penalty": kl_penalty.item() if torch.is_tensor(kl_penalty) else kl_penalty, "mean_ratio": mean_ratio.item() if torch.is_tensor(mean_ratio) else mean_ratio, "mean_kl": mean_kl.item() if torch.is_tensor(mean_kl) else mean_kl, - "clipped_fraction": clipped_fraction.item() if torch.is_tensor(clipped_fraction) else clipped_fraction, + "clipped_fraction": ( + clipped_fraction.item() + if torch.is_tensor(clipped_fraction) + else clipped_fraction + ), # Token-level alignment metrics (key for verifying weight sharing) "logprob_diff_mean": logprob_diff_mean, "logprob_diff_abs_mean": logprob_diff_abs_mean, @@ -284,7 +300,7 @@ def run_training_step( 2. Backward pass with gradient accumulation 3. Gradient clipping 4. Optimizer step - + Args: model: The model to train optimizer: The optimizer @@ -315,23 +331,25 @@ def run_training_step( all_inference_logprobs: List[torch.Tensor] = [] # Get GRPO hyperparameters from config - kl_coef = getattr(config, 'kl_coef', 0.1) - clip_eps = getattr(config, 'clip_eps', 0.2) - use_reference_logprobs = getattr(config, 'use_reference_logprobs', True) + kl_coef = getattr(config, "kl_coef", 0.1) + clip_eps = getattr(config, "clip_eps", 0.2) + use_reference_logprobs = getattr(config, "use_reference_logprobs", True) # Accumulate gradients over micro-batches num_batches = len(token_batches) if token_batches else 1 - - for batch_idx, (tokens, labels, advantages, temperatures) in enumerate(zip( - token_batches, label_batches, advantage_batches, temperature_batches - )): + + for batch_idx, (tokens, labels, advantages, temperatures) in enumerate( + zip(token_batches, label_batches, advantage_batches, temperature_batches) + ): tokens = tokens.to(config.device) labels = labels.to(config.device) advantages = advantages.to(config.device) - + # Get corresponding inference logprobs batch if available inf_logprobs = None - if inference_logprob_batches is not None and batch_idx < len(inference_logprob_batches): + if inference_logprob_batches is not None and batch_idx < len( + inference_logprob_batches + ): inf_logprobs = inference_logprob_batches[batch_idx] loss, metrics = compute_grpo_loss( @@ -353,29 +371,34 @@ def run_training_step( total_neg_logp += metrics["neg_logp"] total_pos += metrics["pos_count"] total_neg += metrics["neg_count"] - + # Accumulate GRPO-specific metrics total_kl_penalty += metrics.get("kl_penalty", 0.0) total_mean_ratio += metrics.get("mean_ratio", 1.0) total_mean_kl += metrics.get("mean_kl", 0.0) total_clipped_fraction += metrics.get("clipped_fraction", 0.0) - + # Accumulate token-level alignment metrics total_logprob_diff_mean += metrics.get("logprob_diff_mean", 0.0) total_logprob_diff_abs_mean += metrics.get("logprob_diff_abs_mean", 0.0) - total_logprob_diff_max = max(total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0)) - + total_logprob_diff_max = max( + total_logprob_diff_max, metrics.get("logprob_diff_max", 0.0) + ) + # Collect logprobs for alignment monitoring if "training_logprobs" in metrics and metrics["training_logprobs"] is not None: all_training_logprobs.append(metrics["training_logprobs"]) - if "inference_logprobs" in metrics and metrics["inference_logprobs"] is not None: + if ( + "inference_logprobs" in metrics + and metrics["inference_logprobs"] is not None + ): all_inference_logprobs.append(metrics["inference_logprobs"]) # Gradient clipping and optimizer step grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() optimizer.zero_grad() - + # Help prevent memory fragmentation torch.cuda.empty_cache() @@ -387,7 +410,7 @@ def run_training_step( result = { "loss": total_loss, - "grad_norm": grad_norm.item() if hasattr(grad_norm, 'item') else grad_norm, + "grad_norm": grad_norm.item() if hasattr(grad_norm, "item") else grad_norm, "pos_logp": total_pos_logp, "neg_logp": total_neg_logp, "pos_count": total_pos, @@ -398,27 +421,33 @@ def run_training_step( "mean_kl": total_mean_kl / num_batches, "clipped_fraction": total_clipped_fraction / num_batches, } - + # Compute logprob alignment stats for monitoring # This proves weight sharing is working: inference & training logprobs should converge if all_training_logprobs: train_flat = torch.cat(all_training_logprobs) if train_flat.numel() > 0: - _logprob_alignment_stats["logprobs/training_mean"] = train_flat.mean().item() + _logprob_alignment_stats["logprobs/training_mean"] = ( + train_flat.mean().item() + ) _logprob_alignment_stats["logprobs/training_std"] = train_flat.std().item() - + if all_inference_logprobs: inf_flat = torch.cat(all_inference_logprobs) if inf_flat.numel() > 0: _logprob_alignment_stats["logprobs/inference_mean"] = inf_flat.mean().item() _logprob_alignment_stats["logprobs/inference_std"] = inf_flat.std().item() - + # Token-level alignment metrics - THE key metric for verifying weight sharing # diff_abs_mean close to 0 = weights are truly shared - _logprob_alignment_stats["alignment/diff_mean"] = total_logprob_diff_mean / num_batches - _logprob_alignment_stats["alignment/diff_abs_mean"] = total_logprob_diff_abs_mean / num_batches + _logprob_alignment_stats["alignment/diff_mean"] = ( + total_logprob_diff_mean / num_batches + ) + _logprob_alignment_stats["alignment/diff_abs_mean"] = ( + total_logprob_diff_abs_mean / num_batches + ) _logprob_alignment_stats["alignment/diff_max"] = total_logprob_diff_max - + return result @@ -464,7 +493,7 @@ def log_metrics( mean_ratio = metrics.get("mean_ratio", 1.0) mean_kl = metrics.get("mean_kl", 0) clipped_frac = metrics.get("clipped_fraction", 0) - + if kl_penalty > 0 or mean_kl > 0: print( f" GRPO: KL={mean_kl:.4f}, ratio={mean_ratio:.3f}, " @@ -495,15 +524,20 @@ def log_metrics( "grpo/clipped_fraction": clipped_frac, } # Add timing metrics if present - for key in ["step_time", "sync_time", "data_fetch_time", - "gpu_memory_gb", "gpu_memory_reserved_gb"]: + for key in [ + "step_time", + "sync_time", + "data_fetch_time", + "gpu_memory_gb", + "gpu_memory_reserved_gb", + ]: if key in metrics: log_dict[f"train/{key}"] = metrics[key] - - # Add logprob alignment stats + + # Add logprob alignment stats if _logprob_alignment_stats: log_dict.update(_logprob_alignment_stats) - + if extra_metrics: log_dict.update(extra_metrics) wandb.log(log_dict, step=step) @@ -549,7 +583,9 @@ def finalize_training( total_step_time = sum(step_times) avg_sync_time = sum(sync_times) / len(sync_times) if sync_times else 0 total_sync_time = sum(sync_times) - avg_data_fetch = sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0 + avg_data_fetch = ( + sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0 + ) total_data_fetch = sum(data_fetch_times) avg_gpu_mem = sum(gpu_memories) / len(gpu_memories) if gpu_memories else 0 @@ -557,13 +593,17 @@ def finalize_training( print(f"\n{'='*70}") print(f"BENCHMARK SUMMARY ({mode})") print(f"{'='*70}") - print(f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)") + print( + f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)" + ) print(f" Total steps: {total_steps}") print(" ") print(" TIMING BREAKDOWN:") print(f" Avg step time: {avg_step_time:.2f}s") print(f" Total step time: {total_step_time:.2f}s") - print(f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)") + print( + f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)" + ) print(f" Total sync time: {total_sync_time:.2f}s") print(f" Avg data fetch time: {avg_data_fetch:.2f}s") print(f" Total data fetch time: {total_data_fetch:.2f}s") @@ -584,4 +624,3 @@ def finalize_training( wandb.finish() elif use_wandb: wandb.finish() - diff --git a/example_trainer/vllm_api_server.py b/example_trainer/vllm_api_server.py index d7f6e2e6..2846f14f 100644 --- a/example_trainer/vllm_api_server.py +++ b/example_trainer/vllm_api_server.py @@ -53,7 +53,7 @@ os.environ.setdefault("VLLM_USE_V1", "0") # Set spawn method for multiprocessing (required for CUDA) os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") try: - multiprocessing.set_start_method('spawn', force=True) + multiprocessing.set_start_method("spawn", force=True) except RuntimeError: pass # Already set @@ -86,6 +86,7 @@ def _apply_patches_early() -> bool: try: import sys from pathlib import Path + # Add parent directory to path so we can import vllm_patching script_dir = Path(__file__).parent if str(script_dir) not in sys.path: @@ -106,6 +107,7 @@ def _apply_patches_early() -> bool: except Exception as e: print(f"[vLLM Server] Error applying patches: {e}") import traceback + traceback.print_exc() return False @@ -139,23 +141,26 @@ except ImportError: # Create a compatible ArgumentParser that handles 'deprecated' kwarg # (Python 3.10 doesn't support 'deprecated' in BooleanOptionalAction) import argparse - + class FlexibleArgumentParser(argparse.ArgumentParser): """ArgumentParser that strips unsupported kwargs for Python < 3.13.""" - + def add_argument(self, *args, **kwargs): # Remove 'deprecated' kwarg if present (not supported before Python 3.13) - kwargs.pop('deprecated', None) + kwargs.pop("deprecated", None) return super().add_argument(*args, **kwargs) + # set_ulimit might not exist in all vLLM versions try: from vllm.utils import set_ulimit except ImportError: + def set_ulimit() -> None: """No-op fallback for set_ulimit.""" pass + from vllm.outputs import RequestOutput # noqa: F401, E402 from vllm.version import __version__ as VLLM_VERSION # noqa: E402 @@ -602,7 +607,9 @@ async def lora_load(request: LoraLoadRequest) -> JSONResponse: ) # vLLM needs unique int ID bridge_state.lora_load_count += 1 - logger.info(f"LoRA adapter loaded: {request.adapter_path} (id={bridge_state.active_lora_id})") + logger.info( + f"LoRA adapter loaded: {request.adapter_path} (id={bridge_state.active_lora_id})" + ) return JSONResponse( { diff --git a/example_trainer/vllm_manager.py b/example_trainer/vllm_manager.py index 3713660a..4088994c 100644 --- a/example_trainer/vllm_manager.py +++ b/example_trainer/vllm_manager.py @@ -17,7 +17,6 @@ import requests from .config import TrainingConfig - # Global variable to keep track of the vLLM process _vllm_process: Optional[subprocess.Popen] = None @@ -25,37 +24,34 @@ _vllm_process: Optional[subprocess.Popen] = None def is_port_in_use(port: int) -> bool: """Check if a port is already in use.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(('localhost', port)) == 0 + return s.connect_ex(("localhost", port)) == 0 def kill_process_on_port(port: int, timeout: float = 5.0) -> bool: """ Kill any process using the specified port. - + Returns True if no process was running or if it was successfully killed. """ if not is_port_in_use(port): return True - + print(f" Port {port} is in use, attempting to kill existing process...") - + try: # Try to find and kill the process using lsof (Linux/Mac) result = subprocess.run( - ["lsof", "-t", "-i", f":{port}"], - capture_output=True, - text=True, - timeout=5 + ["lsof", "-t", "-i", f":{port}"], capture_output=True, text=True, timeout=5 ) if result.stdout.strip(): - pids = result.stdout.strip().split('\n') + pids = result.stdout.strip().split("\n") for pid in pids: try: os.kill(int(pid), signal.SIGTERM) print(f" Sent SIGTERM to PID {pid}") except (ProcessLookupError, ValueError): pass - + # Wait for port to be free start = time.time() while time.time() - start < timeout: @@ -63,7 +59,7 @@ def kill_process_on_port(port: int, timeout: float = 5.0) -> bool: print(f" Port {port} is now free") return True time.sleep(0.5) - + # Force kill if still running for pid in pids: try: @@ -71,7 +67,7 @@ def kill_process_on_port(port: int, timeout: float = 5.0) -> bool: print(f" Sent SIGKILL to PID {pid}") except (ProcessLookupError, ValueError): pass - + time.sleep(1) return not is_port_in_use(port) except FileNotFoundError: @@ -84,7 +80,7 @@ def kill_process_on_port(port: int, timeout: float = 5.0) -> bool: pass except subprocess.TimeoutExpired: pass - + print(f" WARNING: Could not kill process on port {port}") return False @@ -135,7 +131,9 @@ def launch_vllm_server( if is_port_in_use(config.vllm_port): print(f" WARNING: Port {config.vllm_port} is already in use!") if not kill_process_on_port(config.vllm_port): - print(f" ERROR: Could not free port {config.vllm_port}. Please manually kill the process.") + print( + f" ERROR: Could not free port {config.vllm_port}. Please manually kill the process." + ) print(f" Try: lsof -i :{config.vllm_port} | grep LISTEN") print(f" Or: pkill -f 'vllm.*{config.vllm_port}'") return None @@ -155,7 +153,7 @@ def launch_vllm_server( "--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization), ] - + # Add served-model-name if using checkpoint path if model_path != config.model_name: vllm_command.extend(["--served-model-name", config.model_name]) @@ -209,7 +207,9 @@ def check_vllm_process_health() -> None: global _vllm_process if _vllm_process is not None and _vllm_process.poll() is not None: - print(f" WARNING: vLLM terminated unexpectedly (code: {_vllm_process.returncode})") + print( + f" WARNING: vLLM terminated unexpectedly (code: {_vllm_process.returncode})" + ) _vllm_process = None @@ -299,7 +299,9 @@ def hotswap_lora_adapter( print(f" [LORA] ✓ Hot-swapped adapter: {adapter_name} ({adapter_path})") return True else: - print(f" [LORA] ✗ Hot-swap failed: {response.status_code} - {response.text}") + print( + f" [LORA] ✗ Hot-swap failed: {response.status_code} - {response.text}" + ) return False except requests.exceptions.ConnectionError: @@ -308,4 +310,3 @@ def hotswap_lora_adapter( except Exception as e: print(f" [LORA] ✗ Error during hot-swap: {e}") return False - diff --git a/example_trainer/vllm_patching/patched_gpu_runner.py b/example_trainer/vllm_patching/patched_gpu_runner.py index dd448393..cd3768da 100644 --- a/example_trainer/vllm_patching/patched_gpu_runner.py +++ b/example_trainer/vllm_patching/patched_gpu_runner.py @@ -29,61 +29,62 @@ _PATCHED_RUNNER_CLASS = None def _patch_lora_triton_for_blackwell() -> bool: """ Patch vLLM's LoRA Triton kernels to disable GDC (Grid Dependency Control). - + GDC is a Blackwell-specific feature that causes Triton compilation to fail on B200 GPUs. This patches the kernel_utils.py to disable GDC. - + Returns True if patch was applied successfully. """ try: import vllm + vllm_path = vllm.__path__[0] kernel_utils_path = f"{vllm_path}/lora/ops/triton_ops/kernel_utils.py" - + # Check if file exists if not os.path.exists(kernel_utils_path): print("[vLLM Patch] LoRA kernel_utils.py not found, skipping GDC patch") return False - - with open(kernel_utils_path, 'r') as f: + + with open(kernel_utils_path, "r") as f: content = f.read() - + # Check if already patched - if 'PATCHED FOR B200' in content: + if "PATCHED FOR B200" in content: print("[vLLM Patch] LoRA GDC already patched for B200") return True - + modified = False - + # Patch USE_GDC = True -> False - if 'USE_GDC = True' in content: + if "USE_GDC = True" in content: content = content.replace( - 'USE_GDC = True', - 'USE_GDC = False # PATCHED FOR B200 - GDC causes Triton compilation failure' + "USE_GDC = True", + "USE_GDC = False # PATCHED FOR B200 - GDC causes Triton compilation failure", ) modified = True - + # Patch USE_GDC: tl.constexpr = True -> False - if 'USE_GDC: tl.constexpr = True' in content: + if "USE_GDC: tl.constexpr = True" in content: content = content.replace( - 'USE_GDC: tl.constexpr = True', - 'USE_GDC: tl.constexpr = False # PATCHED FOR B200' + "USE_GDC: tl.constexpr = True", + "USE_GDC: tl.constexpr = False # PATCHED FOR B200", ) modified = True - + # Patch the gdc_wait call itself - if 'tl.extra.cuda.gdc_wait()' in content: + if "tl.extra.cuda.gdc_wait()" in content: content = content.replace( - 'tl.extra.cuda.gdc_wait()', - 'pass # tl.extra.cuda.gdc_wait() PATCHED FOR B200 - disabled' + "tl.extra.cuda.gdc_wait()", + "pass # tl.extra.cuda.gdc_wait() PATCHED FOR B200 - disabled", ) modified = True - + if modified: - with open(kernel_utils_path, 'w') as f: + with open(kernel_utils_path, "w") as f: f.write(content) print(f"[vLLM Patch] ✓ Patched LoRA Triton GDC in {kernel_utils_path}") - + # Clear Triton cache to force recompilation triton_cache = os.path.expanduser("~/.triton/cache") if os.path.exists(triton_cache): @@ -92,12 +93,12 @@ def _patch_lora_triton_for_blackwell() -> bool: print("[vLLM Patch] ✓ Cleared Triton cache") except Exception as e: print(f"[vLLM Patch] Warning: Could not clear Triton cache: {e}") - + return True else: print("[vLLM Patch] No GDC patterns found to patch") return False - + except Exception as e: print(f"[vLLM Patch] Warning: Could not patch LoRA GDC: {e}") return False @@ -109,7 +110,7 @@ def apply_patches() -> bool: This must be called BEFORE importing vLLM's engine classes. Safe to call multiple times (idempotent). - + Also patches LoRA Triton kernels to disable GDC for B200 compatibility. Returns True if patches were applied successfully. @@ -129,7 +130,7 @@ def apply_patches() -> bool: if _PATCHES_APPLIED: return True - + # First, patch LoRA Triton for B200 compatibility _patch_lora_triton_for_blackwell()