diff --git a/README.md b/README.md index ddb2bac7..b842d5ab 100644 --- a/README.md +++ b/README.md @@ -294,10 +294,10 @@ Distillation is configured in `BaseEnvConfig` and available via CLI under `--env Both setups are supported: -- **Self-distillation (same model family for teacher and student)** +- **Self-distillation (same model family for teacher and student)** Point `teacher_base_url` to a server running the same model (or equivalent checkpoint family) as the student. This is the most stable setup for token-level alignment. -- **Cross-model distillation (different teacher and student models)** +- **Cross-model distillation (different teacher and student models)** Also supported, but tokenization compatibility becomes more important. If token vocabularies/template behavior differ significantly, alignment quality may degrade. In practice, self-distillation is usually easiest to bring up first, then cross-model can be layered in once your pipeline is stable. diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 551eb5b0..96800d66 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -360,47 +360,51 @@ class BaseEnv(ABC): ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]: """ Fetch top-K logprobs from teacher model for given sequences. - + Supports any OpenAI-compatible API (vLLM, OpenAI, Together, etc.). - + Args: token_sequences: List of token ID sequences to get logprobs for messages_list: Optional list of message histories (for chat APIs). If provided, uses chat/completions with logprobs. top_k: Number of top logprobs to fetch (defaults to config.teacher_top_k) - + Returns: Tuple of (distill_token_ids, distill_logprobs), both shaped as: [batch][position][top_k]. Returns ([], []) if teacher_base_url is not configured. """ - logger.info(f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences") + logger.info( + f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences" + ) logger.info(f"[TEACHER] teacher_base_url={self.config.teacher_base_url}") - + if not self.config.teacher_base_url: logger.warning("[TEACHER] No teacher_base_url configured, returning empty") return [], [] - + if top_k is None: top_k = self.config.teacher_top_k - + # Get API key from config or environment api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "") model_name = self.config.teacher_model_name or "default" - + logger.info(f"[TEACHER] Using model={model_name}, top_k={top_k}") - + headers = {"Content-Type": "application/json"} if api_key: headers["Authorization"] = f"Bearer {api_key}" - + token_id_results: List[List[List[int]]] = [] logprob_results: List[List[List[float]]] = [] - + try: async with aiohttp.ClientSession() as session: for i, tokens in enumerate(token_sequences): - logger.info(f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens") + logger.info( + f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens" + ) # Decode original sequence and optionally prepend teacher steering text. base_text = self.tokenizer.decode(tokens, skip_special_tokens=False) steering_prefix = "" @@ -413,11 +417,15 @@ class BaseEnv(ABC): steering_prefix += self.config.teacher_prefix_text full_text = steering_prefix + base_text prefix_token_len = ( - len(self.tokenizer.encode(steering_prefix, add_special_tokens=False)) + len( + self.tokenizer.encode( + steering_prefix, add_special_tokens=False + ) + ) if steering_prefix else 0 ) - + # Try vLLM-style completions first (supports prompt_logprobs) # This is most efficient as it doesn't generate new tokens request_data = { @@ -428,7 +436,7 @@ class BaseEnv(ABC): "logprobs": top_k, "echo": True, # Include prompt in response with logprobs } - + try: async with session.post( f"{self.config.teacher_base_url}/completions", @@ -438,22 +446,24 @@ class BaseEnv(ABC): ) as response: if response.status == 200: data = await response.json() - seq_token_ids, seq_logprobs = self._parse_completion_logprobs( - data, top_k + seq_token_ids, seq_logprobs = ( + self._parse_completion_logprobs(data, top_k) ) if seq_token_ids and seq_logprobs: - aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens( - seq_token_ids, - seq_logprobs, - target_token_len=len(tokens), - prefix_token_len=prefix_token_len, + aligned_ids, aligned_lps = ( + self._align_teacher_topk_to_tokens( + seq_token_ids, + seq_logprobs, + target_token_len=len(tokens), + prefix_token_len=prefix_token_len, + ) ) token_id_results.append(aligned_ids) logprob_results.append(aligned_lps) continue except Exception: pass # Fall through to chat completions - + # Fallback: Use chat/completions with logprobs (OpenAI style) # This requires messages format if messages_list and i < len(messages_list): @@ -476,7 +486,7 @@ class BaseEnv(ABC): } ) messages.append({"role": "user", "content": full_text}) - + chat_request = { "model": model_name, "messages": messages, @@ -485,7 +495,7 @@ class BaseEnv(ABC): "logprobs": True, "top_logprobs": top_k, } - + try: async with session.post( f"{self.config.teacher_base_url}/chat/completions", @@ -501,11 +511,13 @@ class BaseEnv(ABC): # Chat fallback logprobs are for generated tokens, not prompt tokens. # To keep alignment correct for distillation, return empty per-position rows. if seq_token_ids and len(seq_token_ids) >= len(tokens): - aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens( - seq_token_ids, - seq_logprobs, - target_token_len=len(tokens), - prefix_token_len=0, + aligned_ids, aligned_lps = ( + self._align_teacher_topk_to_tokens( + seq_token_ids, + seq_logprobs, + target_token_len=len(tokens), + prefix_token_len=0, + ) ) else: aligned_ids = [[] for _ in range(len(tokens))] @@ -513,16 +525,20 @@ class BaseEnv(ABC): token_id_results.append(aligned_ids) logprob_results.append(aligned_lps) else: - logger.warning(f"Teacher API returned {response.status}") - token_id_results.append([[] for _ in range(len(tokens))]) + logger.warning( + f"Teacher API returned {response.status}" + ) + token_id_results.append( + [[] for _ in range(len(tokens))] + ) logprob_results.append([[] for _ in range(len(tokens))]) except Exception as e: logger.warning(f"Teacher chat request failed: {e}") token_id_results.append([[] for _ in range(len(tokens))]) logprob_results.append([[] for _ in range(len(tokens))]) - + return token_id_results, logprob_results - + except Exception as e: logger.error(f"Error fetching teacher logprobs: {e}") return [], [] @@ -556,7 +572,7 @@ class BaseEnv(ABC): aligned_lps.extend([[] for _ in range(pad_count)]) return aligned_ids, aligned_lps - + def _parse_completion_logprobs( self, data: Dict, top_k: int ) -> Tuple[List[List[int]], List[List[float]]]: @@ -564,10 +580,10 @@ class BaseEnv(ABC): try: choice = data.get("choices", [{}])[0] logprobs_data = choice.get("logprobs", {}) - + # vLLM returns top_logprobs as list of dicts top_logprobs = logprobs_data.get("top_logprobs", []) - + if not top_logprobs: return [], [] @@ -580,15 +596,15 @@ class BaseEnv(ABC): elif isinstance(pos_logprobs, dict): # Format: {token_str: logprob, ...} sorted_items = sorted( - pos_logprobs.items(), - key=lambda x: x[1], - reverse=True + pos_logprobs.items(), key=lambda x: x[1], reverse=True )[:top_k] pos_ids: List[int] = [] pos_lps: List[float] = [] for token_str, logprob in sorted_items: # Convert token string to ID - token_ids = self.tokenizer.encode(token_str, add_special_tokens=False) + token_ids = self.tokenizer.encode( + token_str, add_special_tokens=False + ) if token_ids: pos_ids.append(int(token_ids[0])) pos_lps.append(float(logprob)) @@ -602,7 +618,7 @@ class BaseEnv(ABC): except Exception as e: logger.warning(f"Error parsing completion logprobs: {e}") return [], [] - + def _parse_chat_logprobs( self, data: Dict, top_k: int ) -> Tuple[List[List[int]], List[List[float]]]: @@ -610,14 +626,14 @@ class BaseEnv(ABC): try: choice = data.get("choices", [{}])[0] logprobs_data = choice.get("logprobs", {}) - + if not logprobs_data: return [], [] - + content = logprobs_data.get("content", []) seq_token_ids: List[List[int]] = [] seq_logprobs: List[List[float]] = [] - + for token_data in content: top_logprobs = token_data.get("top_logprobs", []) pos_ids: List[int] = [] @@ -626,7 +642,9 @@ class BaseEnv(ABC): token_str = item.get("token", "") logprob = item.get("logprob", 0.0) # Convert token string to ID - token_ids = self.tokenizer.encode(token_str, add_special_tokens=False) + token_ids = self.tokenizer.encode( + token_str, add_special_tokens=False + ) if token_ids: pos_ids.append(int(token_ids[0])) pos_lps.append(float(logprob)) @@ -1251,7 +1269,9 @@ class BaseEnv(ABC): if valid_groups and do_send_to_api: # On-policy distillation: fetch teacher logprobs if enabled if self.config.distillation_enabled and self.config.teacher_base_url: - logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups") + logger.info( + f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups" + ) for group in valid_groups: has_new_format = ( group.get("distill_token_ids") is not None @@ -1259,9 +1279,11 @@ class BaseEnv(ABC): ) if not has_new_format: try: - teacher_token_ids, teacher_logprobs = await self.get_teacher_logprobs( - token_sequences=group["tokens"], - messages_list=group.get("messages"), + teacher_token_ids, teacher_logprobs = ( + await self.get_teacher_logprobs( + token_sequences=group["tokens"], + messages_list=group.get("messages"), + ) ) if teacher_token_ids and teacher_logprobs: group["distill_token_ids"] = teacher_token_ids @@ -1270,10 +1292,15 @@ class BaseEnv(ABC): f"[DISTILL] Added teacher distill arrays for {len(teacher_token_ids)} sequences" ) else: - logger.warning("[DISTILL] get_teacher_logprobs returned empty") + logger.warning( + "[DISTILL] get_teacher_logprobs returned empty" + ) except Exception as e: - logger.error(f"[DISTILL] Failed to fetch teacher logprobs: {e}") + logger.error( + f"[DISTILL] Failed to fetch teacher logprobs: {e}" + ) import traceback + logger.error(traceback.format_exc()) else: logger.debug( @@ -1788,13 +1815,13 @@ class BaseEnv(ABC): cli_passed_flags, openai_full_prefix ) # CLI args yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {}) - + # Debug logging for CLI args print(f"[CLI DEBUG] cli_passed_flags = {cli_passed_flags}") print(f"[CLI DEBUG] openai_full_prefix = {openai_full_prefix}") print(f"[CLI DEBUG] oai_cli_passed_args = {oai_cli_passed_args}") print(f"[CLI DEBUG] yaml_oai_config = {yaml_oai_config}") - + # Auto-convert ServerBaseline to APIServerConfig when CLI/YAML overrides are provided # This allows any environment to use --openai.* CLI args without modifying config_init # Use a new variable to avoid UnboundLocalError from closure scoping @@ -1808,7 +1835,7 @@ class BaseEnv(ABC): logger.info( "Auto-converted ServerBaseline to APIServerConfig for CLI/YAML overrides" ) - + if ( isinstance(effective_server_configs, list) and len(effective_server_configs) == 1 @@ -1822,13 +1849,17 @@ class BaseEnv(ABC): if isinstance(default_openai_config_, APIServerConfig) and isinstance( yaml_oai_config, dict ): - print(f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}") + print( + f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}" + ) openai_config_dict = merge_dicts( default_openai_config_.model_dump(), # Default APIServerConfig (or from class init) yaml_oai_config, oai_cli_passed_args, ) - print(f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}") + print( + f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}" + ) else: print( "[CLI DEBUG] Not merging: default_openai_config_ " diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index f84558c4..4f8b90f2 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -165,7 +165,9 @@ def resolve_openai_configs( """ from atroposlib.envs.server_handling.server_manager import ServerBaseline - print(f"[RESOLVE DEBUG] default_server_configs type = {type(default_server_configs)}") + print( + f"[RESOLVE DEBUG] default_server_configs type = {type(default_server_configs)}" + ) print(f"[RESOLVE DEBUG] openai_config_dict = {openai_config_dict}") openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}" @@ -216,7 +218,9 @@ def resolve_openai_configs( elif isinstance(default_server_configs, APIServerConfig): # Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline print("[RESOLVE DEBUG] Taking APIServerConfig merged path") - logger.info("Using single OpenAI server configuration based on merged settings (default/YAML/CLI).") + logger.info( + "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." + ) try: final_openai_config = APIServerConfig(**openai_config_dict) except Exception as e: diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index d3e7d2ea..b06e4e2f 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -193,7 +193,9 @@ class VLLMServer(APIServer): debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1" if debug_requests: base = self.config.base_url.replace("/v1", "") - prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace("\n", "\\n") + prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace( + "\n", "\\n" + ) print( f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate " f"prompt_token_len={len(prompt_tokens)}", @@ -211,7 +213,7 @@ class VLLMServer(APIServer): ) print( f"[ATROPOS_REQ_DEBUG] curl_base=curl -s -X POST {base}/generate " - '-H "Content-Type: application/json" -d \'\'', + "-H \"Content-Type: application/json\" -d ''", flush=True, ) diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 0fa53431..615a3a36 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -289,7 +289,9 @@ class GSM8kEnv(BaseEnv): } ) to_postprocess = await self.score(to_score) - accepted = 0 if to_postprocess is None else len(to_postprocess.get("tokens", [])) + accepted = ( + 0 if to_postprocess is None else len(to_postprocess.get("tokens", [])) + ) print( f"[GSM8K_DEBUG] collect_done accepted={accepted} submitted={len(to_score)}", flush=True, diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index ee1346ab..48e995b9 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -12,7 +12,6 @@ from typing import Dict, List, Optional, Tuple, Union import wandb from datasets import load_dataset - from latex2sympy2_extended import NormalizationConfig from math_verify import LatexExtractionConfig, parse, verify from math_verify.errors import TimeoutException @@ -146,7 +145,7 @@ class MathEnv(BaseEnv): wandb_name = os.environ.get("MATH_ENV_WANDB_NAME", "math-zero-env") max_token_length = int(os.environ.get("MATH_ENV_MAX_TOKENS", "32000")) worker_timeout = float(os.environ.get("MATH_ENV_WORKER_TIMEOUT", "1500")) - + env_config = RSConfig( tokenizer_name=model_name, group_size=8, @@ -524,7 +523,7 @@ class MathEnv(BaseEnv): and (not scores["overrides"][i].get("set_advantage_to_zero", False)) ] ) - + return scores async def get_next_item(self): diff --git a/example_trainer/README.md b/example_trainer/README.md index 9be1cf66..a04f0cfd 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -194,7 +194,7 @@ python -m example_trainer.grpo --weight-bridge-mode lora_only ... --- -## Shared vLLM Mode +## Shared vLLM Mode Single-copy mode shares GPU memory between vLLM and the trainer - zero model duplication! diff --git a/example_trainer/config.py b/example_trainer/config.py index 4a020643..b30473c2 100644 --- a/example_trainer/config.py +++ b/example_trainer/config.py @@ -136,15 +136,17 @@ class TrainingConfig(BaseModel): wandb_group: Optional[str] = Field(None, description="Wandb group name") # === Training Mode Configuration === - weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = Field( - "none", - description=( - "How to synchronize weights with inference server. " - "'shared_vllm': attach to vLLM's shared memory tensors and update in-place. " - "'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). " - "'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). " - "'none': legacy mode, restart vLLM with new checkpoint files." - ), + weight_bridge_mode: Literal["shared_vllm", "lora_only", "lora_restart", "none"] = ( + Field( + "none", + description=( + "How to synchronize weights with inference server. " + "'shared_vllm': attach to vLLM's shared memory tensors and update in-place. " + "'lora_only': keep base model frozen, train/swap LoRA adapters via HTTP (slow, needs --enforce-eager). " + "'lora_restart': LoRA training with vLLM restarts (fast, CUDA graphs enabled). " + "'none': legacy mode, restart vLLM with new checkpoint files." + ), + ) ) # === Distributed Training Configuration === diff --git a/example_trainer/data.py b/example_trainer/data.py index 2834552e..129a4d2b 100644 --- a/example_trainer/data.py +++ b/example_trainer/data.py @@ -258,10 +258,14 @@ def pad_data_to_good_offset( else None ) final_distill_token_id_batches = ( - distill_token_id_batches if (has_any_distill and distill_token_id_batches) else None + distill_token_id_batches + if (has_any_distill and distill_token_id_batches) + else None ) final_distill_logprob_batches = ( - distill_logprob_batches if (has_any_distill and distill_logprob_batches) else None + distill_logprob_batches + if (has_any_distill and distill_logprob_batches) + else None ) return ( diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py index 0641ee16..dcf96634 100644 --- a/example_trainer/trainers.py +++ b/example_trainer/trainers.py @@ -735,19 +735,23 @@ def train_lora_restart(config: TrainingConfig): # Periodic adapter save + vLLM restart sync_time = 0 should_sync = (step + 1) % config.vllm_restart_interval == 0 - if should_sync and (step + 1) < config.training_steps: # Don't restart on last step + if ( + should_sync and (step + 1) < config.training_steps + ): # Don't restart on last step sync_start = time.time() - + # Save new adapter - current_adapter_path = save_lora_checkpoint(model, config.save_path, step + 1) - + current_adapter_path = save_lora_checkpoint( + model, config.save_path, step + 1 + ) + # Restart vLLM with new adapter print(" [RESTART] Restarting vLLM with new adapter...") _terminate_vllm(vllm_proc, config.vllm_port) vllm_proc = _launch_vllm_with_lora(config, current_adapter_path) if vllm_proc is None: raise RuntimeError("Failed to restart vLLM") - + sync_time = time.time() - sync_start benchmark_stats["sync_times"].append(sync_time) benchmark_stats["restart_times"].append(sync_time) @@ -798,45 +802,53 @@ def train_lora_restart(config: TrainingConfig): _vllm_restart_counter = 0 -def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optional[subprocess.Popen]: +def _launch_vllm_with_lora( + config: TrainingConfig, adapter_path: str +) -> Optional[subprocess.Popen]: """ Launch vLLM with a LoRA adapter (no --enforce-eager for faster inference). - + Unlike lora_only mode, this does NOT use --enforce-eager, so we get ~108 TPS instead of ~13 TPS (8x faster). """ global _vllm_restart_counter from .vllm_manager import kill_process_on_port, wait_for_vllm_ready - + # Kill any existing process on the port print(f" Cleaning up port {config.vllm_port}...") kill_process_on_port(config.vllm_port) - + # Clear CUDA cache before starting new vLLM if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() - + # Wait for port and GPU memory to be fully released time.sleep(5) - + # Find the vllm_api_server.py script script_dir = os.path.dirname(os.path.abspath(__file__)) server_script = os.path.join(script_dir, "vllm_api_server.py") - + # Build command - NO --enforce-eager for faster inference (~108 TPS vs ~13 TPS) cmd = [ - sys.executable, server_script, - "--model", config.model_name, - "--port", str(config.vllm_port), - "--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization), - "--max-model-len", str(config.max_model_len), + sys.executable, + server_script, + "--model", + config.model_name, + "--port", + str(config.vllm_port), + "--gpu-memory-utilization", + str(config.vllm_gpu_memory_utilization), + "--max-model-len", + str(config.max_model_len), "--enable-lora", - "--max-lora-rank", str(max(config.lora_r * 2, 32)), + "--max-lora-rank", + str(max(config.lora_r * 2, 32)), # Note: NOT adding --enforce-eager - this gives us ~8x faster inference! # Without --enforce-eager, vLLM can use more optimizations. ] - + # Set environment for GPU selection env = os.environ.copy() if config.vllm_gpu is not None: @@ -844,32 +856,39 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona print(f" GPU: {config.vllm_gpu} (via CUDA_VISIBLE_DEVICES)") else: print(" GPU: Same as trainer (inherited CUDA_VISIBLE_DEVICES)") - + print(f" Launching: {' '.join(cmd)}") print(f" Adapter: {adapter_path}") - + # Log vLLM output to file for debugging (unique file per restart) - vllm_log_path = os.path.join(config.save_path, f"vllm_restart_{_vllm_restart_counter}.log") + vllm_log_path = os.path.join( + config.save_path, f"vllm_restart_{_vllm_restart_counter}.log" + ) _vllm_restart_counter += 1 print(f" vLLM log: {vllm_log_path}") - + try: vllm_log_file = open(vllm_log_path, "w") # Start in new session so we can kill entire process group later proc = subprocess.Popen( - cmd, env=env, stdout=vllm_log_file, stderr=subprocess.STDOUT, - start_new_session=True # Creates new process group for easy cleanup + cmd, + env=env, + stdout=vllm_log_file, + stderr=subprocess.STDOUT, + start_new_session=True, # Creates new process group for easy cleanup ) print(f" vLLM PID: {proc.pid} (process group: {os.getpgid(proc.pid)})") - print(" NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)...") - + print( + " NOTE: vLLM without --enforce-eager compiles CUDA graphs on startup (takes 1-3 min)..." + ) + # Wait for server to be ready (longer timeout for CUDA graph compilation) if not wait_for_vllm_ready(config.vllm_port, timeout=300): print(" ERROR: vLLM failed to start after 300s") print(f" Check log: {vllm_log_path}") # Print last 30 lines of the log try: - with open(vllm_log_path, 'r') as f: + with open(vllm_log_path, "r") as f: lines = f.readlines() print(" Last 30 lines of vLLM log:") for line in lines[-30:]: @@ -878,7 +897,7 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona print(f" Could not read log: {e}") proc.terminate() return None - + # Load the LoRA adapter print(" Loading LoRA adapter...") try: @@ -890,13 +909,15 @@ def _launch_vllm_with_lora(config: TrainingConfig, adapter_path: str) -> Optiona if resp.status_code == 200: print(" ✓ Adapter loaded successfully") else: - print(f" WARNING: Adapter load returned {resp.status_code}: {resp.text}") + print( + f" WARNING: Adapter load returned {resp.status_code}: {resp.text}" + ) except Exception as e: print(f" WARNING: Could not load adapter: {e}") # Continue anyway - base model inference still works - + return proc - + except Exception as e: print(f" ERROR: {e}") return None @@ -906,12 +927,12 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None: """Terminate a vLLM process and release GPU resources.""" import signal import subprocess as sp - + print(f" Terminating vLLM on port {port}...") - + # Get current GPU device gpu_id = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0] - + # Phase 1: Kill the process group if we have a handle (kills all children too) main_pid = None if proc is not None: @@ -927,12 +948,13 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None: proc.wait(timeout=5) except Exception as e: print(f" Warning: {e}") - + # Phase 2: Kill by port (catches anything still running) from .vllm_manager import kill_process_on_port + kill_process_on_port(port) time.sleep(2) - + # Phase 3: Aggressively kill ALL vLLM-related processes print(" Killing all vLLM-related processes...") kill_commands = [ @@ -948,19 +970,22 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None: sp.run(cmd, shell=True, capture_output=True, timeout=5) except Exception: pass - + # Phase 4: Use nvidia-smi to find and kill GPU processes (nuclear option) print(f" Checking for zombie GPU processes on GPU {gpu_id}...") try: result = sp.run( f"nvidia-smi --query-compute-apps=pid,used_memory --format=csv,noheader,nounits -i {gpu_id}", - shell=True, capture_output=True, text=True, timeout=10 + shell=True, + capture_output=True, + text=True, + timeout=10, ) if result.stdout.strip(): print(f" Found GPU processes:\n{result.stdout}") - for line in result.stdout.strip().split('\n'): + for line in result.stdout.strip().split("\n"): if line.strip(): - parts = line.split(',') + parts = line.split(",") if len(parts) >= 1: pid = parts[0].strip() # Don't kill the current Python process (trainer) @@ -972,7 +997,7 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None: pass except Exception as e: print(f" Warning: nvidia-smi check failed: {e}") - + # Phase 5: Wait for GPU memory release - CRITICAL # The CUDA driver needs time to actually free memory after process death print(" Waiting for GPU memory release...") @@ -982,24 +1007,26 @@ def _terminate_vllm(proc: Optional[subprocess.Popen], port: int = 9001) -> None: torch.cuda.empty_cache() free_mem = torch.cuda.mem_get_info()[0] / 1e9 total_mem = torch.cuda.mem_get_info()[1] / 1e9 - print(f" [{(i+1)*5}s] GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)") + print( + f" [{(i+1)*5}s] GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)" + ) # If we have enough memory (>50% free), break early if free_mem > total_mem * 0.5: print(f" ✓ Sufficient memory available ({free_mem:.1f} GB)") break - + # Final cleanup if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() free_mem = torch.cuda.mem_get_info()[0] / 1e9 total_mem = torch.cuda.mem_get_info()[1] / 1e9 - print(f" Final GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)") - + print( + f" Final GPU memory: {free_mem:.1f}/{total_mem:.1f} GB free ({100*free_mem/total_mem:.0f}%)" + ) + if free_mem < total_mem * 0.3: print(" WARNING: Low GPU memory! May fail to restart vLLM.") print(" Consider reducing --vllm-gpu-memory-utilization") - + print(" vLLM terminated") - - diff --git a/example_trainer/training.py b/example_trainer/training.py index 098ba173..0e1fd1f0 100644 --- a/example_trainer/training.py +++ b/example_trainer/training.py @@ -63,9 +63,10 @@ def compute_distillation_loss( continue ids_tensor = torch.tensor(pos_ids, device=logits.device, dtype=torch.long) - teacher_lps = torch.tensor( - pos_lps, device=logits.device, dtype=logits.dtype - ) / temperature + teacher_lps = ( + torch.tensor(pos_lps, device=logits.device, dtype=logits.dtype) + / temperature + ) student_log_probs = F.log_softmax(logits[b, t] / temperature, dim=-1) student_subset = student_log_probs[ids_tensor] @@ -75,7 +76,9 @@ def compute_distillation_loss( token_loss = -(teacher_probs * student_subset).sum() else: teacher_log_probs = F.log_softmax(teacher_lps, dim=-1) - token_loss = (teacher_probs * (teacher_log_probs - student_subset)).sum() + token_loss = ( + teacher_probs * (teacher_log_probs - student_subset) + ).sum() total = total + token_loss count = count + 1.0 @@ -323,7 +326,11 @@ def compute_grpo_loss( interpretable_loss = (avg_logp * advantages.squeeze()).mean().item() distill_loss_val = 0.0 - if distillation_enabled and distill_token_ids is not None and distill_logprobs is not None: + if ( + distillation_enabled + and distill_token_ids is not None + and distill_logprobs is not None + ): distill_loss = compute_distillation_loss( logits=scaled_logits, mask=mask, @@ -332,7 +339,10 @@ def compute_grpo_loss( temperature=distillation_temperature, loss_type=distillation_loss_type, ) - total_loss = total_loss + (distillation_coef * distill_loss) / gradient_accumulation_steps + total_loss = ( + total_loss + + (distillation_coef * distill_loss) / gradient_accumulation_steps + ) distill_loss_val = distill_loss.item() metrics = { @@ -437,9 +447,13 @@ def run_training_step( inf_logprobs = inference_logprob_batches[batch_idx] distill_ids = None distill_lps = None - if distill_token_id_batches is not None and batch_idx < len(distill_token_id_batches): + if distill_token_id_batches is not None and batch_idx < len( + distill_token_id_batches + ): distill_ids = distill_token_id_batches[batch_idx] - if distill_logprob_batches is not None and batch_idx < len(distill_logprob_batches): + if distill_logprob_batches is not None and batch_idx < len( + distill_logprob_batches + ): distill_lps = distill_logprob_batches[batch_idx] loss, metrics = compute_grpo_loss(