diff --git a/example_trainer/test_multi_model.py b/example_trainer/test_multi_model.py index a5b98ade..ee02d0d3 100644 --- a/example_trainer/test_multi_model.py +++ b/example_trainer/test_multi_model.py @@ -244,9 +244,9 @@ def run_model_test( result["start_time"] = datetime.now().isoformat() start_time = time.time() - vllm_process = None env_process = None run_api_process = None + trainer_process = None # Get atropos root directory (used for vLLM and gsm8k scripts) script_dir = Path(__file__).parent @@ -283,86 +283,35 @@ def run_model_test( time.sleep(10) # Give it time to initialize and connect print(f"[{model_name}] ✓ gsm8k environment started") - # === Start vLLM Server === - vllm_server_script = script_dir / "vllm_api_server.py" + # === Start Unified vLLM + Trainer (run.py) === + # Using run.py ensures vLLM is a CHILD of the trainer process, + # which is required for CUDA IPC with ptrace_scope=1 + run_script = script_dir / "run.py" - vllm_env = os.environ.copy() - vllm_env["VLLM_ENABLE_SHARED_WEIGHTS"] = "1" - vllm_env["VLLM_BRIDGE_CONFIG_PATH"] = str(bridge_config_path) - vllm_env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) - vllm_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + run_env = os.environ.copy() + run_env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + run_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" - vllm_cmd = [ - sys.executable, "-u", str(vllm_server_script), + run_cmd = [ + sys.executable, "-u", str(run_script), "--model", model_config.model_id, - "--port", str(vllm_port), - "--dtype", model_config.dtype, + "--vllm-port", str(vllm_port), "--gpu-memory-utilization", str(model_config.gpu_memory_utilization), "--max-model-len", str(model_config.max_model_len), - "--enforce-eager", - ] - - print(f"[{model_name}] Starting vLLM server...") - with open(vllm_log, "w") as vlog: - vllm_process = subprocess.Popen( - vllm_cmd, - env=vllm_env, - stdout=vlog, - stderr=subprocess.STDOUT, - ) - - # Wait for vLLM to be ready - import requests - vllm_ready = False - for i in range(120): # 2 minute timeout - try: - resp = requests.get(f"http://localhost:{vllm_port}/health", timeout=5) - if resp.status_code == 200: - vllm_ready = True - print(f"[{model_name}] ✓ vLLM server ready") - break - except: - pass - time.sleep(2) - - if not vllm_ready: - raise RuntimeError("vLLM server failed to start") - - # Wait for bridge config - for i in range(30): - if bridge_config_path.exists(): - with open(bridge_config_path) as f: - cfg = json.load(f) - if cfg.get("ipc_handles"): - print(f"[{model_name}] ✓ Bridge config ready") - break - time.sleep(1) - else: - raise RuntimeError("Bridge config not created") - - # === Run Trainer === - trainer_env = os.environ.copy() - trainer_env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) - trainer_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" - - trainer_cmd = [ - sys.executable, "-u", "-m", "example_trainer.grpo", - "--model-name", model_config.model_id, - "--weight-bridge-mode", "shared_vllm", - "--vllm-port", str(vllm_port), - "--vllm-config-path", str(bridge_config_path), + "--dtype", model_config.dtype, "--atropos-url", atropos_url, "--training-steps", str(training_steps), "--optimizer", "adamw_8bit", "--save-path", str(checkpoint_dir), "--checkpoint-interval", "5", + "--log-dir", str(log_dir), ] - print(f"[{model_name}] Starting trainer for {training_steps} steps...") + print(f"[{model_name}] Starting unified trainer (vLLM + GRPO) for {training_steps} steps...") with open(trainer_log, "w") as tlog: trainer_process = subprocess.Popen( - trainer_cmd, - env=trainer_env, + run_cmd, + env=run_env, stdout=tlog, stderr=subprocess.STDOUT, cwd=str(atropos_root), # Run from atropos root @@ -370,7 +319,7 @@ def run_model_test( trainer_process.wait() if trainer_process.returncode != 0: - raise RuntimeError(f"Trainer exited with code {trainer_process.returncode}") + raise RuntimeError(f"Unified trainer exited with code {trainer_process.returncode}") result["status"] = "success" print(f"[{model_name}] ✓ Training completed successfully!") @@ -403,14 +352,7 @@ def run_model_test( traceback.print_exc() finally: - # Cleanup vLLM - if vllm_process and vllm_process.poll() is None: - print(f"[{model_name}] Terminating vLLM server...") - vllm_process.terminate() - try: - vllm_process.wait(timeout=10) - except subprocess.TimeoutExpired: - vllm_process.kill() + # Note: vLLM is managed by run.py and cleaned up automatically # Cleanup gsm8k environment if env_process and env_process.poll() is None: