testing 5

Author: Jai Suphavadeeprasit
Date: 2026-02-02 13:24:24 -05:00
parent 50fcef7041
commit c00e7461aa


@@ -244,9 +244,9 @@ def run_model_test(
     result["start_time"] = datetime.now().isoformat()
     start_time = time.time()
-    vllm_process = None
     env_process = None
     run_api_process = None
+    trainer_process = None

     # Get atropos root directory (used for vLLM and gsm8k scripts)
     script_dir = Path(__file__).parent
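
The handles in this hunk are initialized to None up front so that the finally block can tell "never started" apart from "still running". A minimal sketch of that guard pattern, with a sleep stand-in for the real subprocess (the command is illustrative only):

    import subprocess
    import sys

    proc = None  # sentinel: lets finally distinguish "never started" from "started"
    try:
        # Stand-in for launching run.py / the gsm8k environment
        proc = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
        raise RuntimeError("simulated mid-test failure")
    finally:
        if proc and proc.poll() is None:  # poll() returning None => still alive
            proc.terminate()
            try:
                proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                proc.kill()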
@@ -283,86 +283,35 @@ def run_model_test(
         time.sleep(10)  # Give it time to initialize and connect
         print(f"[{model_name}] ✓ gsm8k environment started")

-        # === Start vLLM Server ===
-        vllm_server_script = script_dir / "vllm_api_server.py"
+        # === Start Unified vLLM + Trainer (run.py) ===
+        # Using run.py ensures vLLM is a CHILD of the trainer process,
+        # which is required for CUDA IPC with ptrace_scope=1
+        run_script = script_dir / "run.py"

-        vllm_env = os.environ.copy()
-        vllm_env["VLLM_ENABLE_SHARED_WEIGHTS"] = "1"
-        vllm_env["VLLM_BRIDGE_CONFIG_PATH"] = str(bridge_config_path)
-        vllm_env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
-        vllm_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+        run_env = os.environ.copy()
+        run_env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+        run_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

-        vllm_cmd = [
-            sys.executable, "-u", str(vllm_server_script),
+        run_cmd = [
+            sys.executable, "-u", str(run_script),
             "--model", model_config.model_id,
-            "--port", str(vllm_port),
             "--dtype", model_config.dtype,
+            "--vllm-port", str(vllm_port),
             "--gpu-memory-utilization", str(model_config.gpu_memory_utilization),
             "--max-model-len", str(model_config.max_model_len),
             "--enforce-eager",
         ]

-        print(f"[{model_name}] Starting vLLM server...")
-        with open(vllm_log, "w") as vlog:
-            vllm_process = subprocess.Popen(
-                vllm_cmd,
-                env=vllm_env,
-                stdout=vlog,
-                stderr=subprocess.STDOUT,
-            )
-
-        # Wait for vLLM to be ready
-        import requests
-        vllm_ready = False
-        for i in range(120):  # 2 minute timeout
-            try:
-                resp = requests.get(f"http://localhost:{vllm_port}/health", timeout=5)
-                if resp.status_code == 200:
-                    vllm_ready = True
-                    print(f"[{model_name}] ✓ vLLM server ready")
-                    break
-            except:
-                pass
-            time.sleep(2)
-        if not vllm_ready:
-            raise RuntimeError("vLLM server failed to start")
-
-        # Wait for bridge config
-        for i in range(30):
-            if bridge_config_path.exists():
-                with open(bridge_config_path) as f:
-                    cfg = json.load(f)
-                if cfg.get("ipc_handles"):
-                    print(f"[{model_name}] ✓ Bridge config ready")
-                    break
-            time.sleep(1)
-        else:
-            raise RuntimeError("Bridge config not created")
-
-        # === Run Trainer ===
-        trainer_env = os.environ.copy()
-        trainer_env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
-        trainer_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-        trainer_cmd = [
-            sys.executable, "-u", "-m", "example_trainer.grpo",
-            "--model-name", model_config.model_id,
-            "--weight-bridge-mode", "shared_vllm",
-            "--vllm-port", str(vllm_port),
-            "--vllm-config-path", str(bridge_config_path),
-            "--dtype", model_config.dtype,
-            "--atropos-url", atropos_url,
-            "--training-steps", str(training_steps),
-            "--optimizer", "adamw_8bit",
-            "--save-path", str(checkpoint_dir),
-            "--checkpoint-interval", "5",
-            "--log-dir", str(log_dir),
-        ]
-
-        print(f"[{model_name}] Starting trainer for {training_steps} steps...")
+        print(f"[{model_name}] Starting unified trainer (vLLM + GRPO) for {training_steps} steps...")
         with open(trainer_log, "w") as tlog:
             trainer_process = subprocess.Popen(
-                trainer_cmd,
-                env=trainer_env,
+                run_cmd,
+                env=run_env,
                 stdout=tlog,
                 stderr=subprocess.STDOUT,
                 cwd=str(atropos_root),  # Run from atropos root
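
The CHILD-process comment above leans on Linux's Yama LSM: with /proc/sys/kernel/yama/ptrace_scope set to 1, a process may only attach to its own descendants, so the trainer can exchange CUDA IPC/pidfd handles with vLLM only if it spawned vLLM itself. A rough pre-flight sketch, assuming a POSIX host (the sleep command stands in for the real vLLM child; nothing here is this repo's API):

    import subprocess
    import sys
    from pathlib import Path

    def yama_ptrace_scope() -> int:
        """Read Yama's ptrace restriction level; treat a missing sysctl as 0."""
        p = Path("/proc/sys/kernel/yama/ptrace_scope")
        return int(p.read_text().strip()) if p.exists() else 0

    scope = yama_ptrace_scope()
    if scope >= 2:
        # 2 = admin-only, 3 = disabled: even a parent/child pair is blocked
        sys.exit(f"ptrace_scope={scope}: CUDA IPC handle sharing would fail")

    # At scope 0 or 1 a direct child stays within our ptrace scope, which is
    # the parent/child relationship run.py establishes for vLLM
    child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(5)"])
    child.wait()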
@@ -370,7 +319,7 @@ def run_model_test(
         trainer_process.wait()

         if trainer_process.returncode != 0:
-            raise RuntimeError(f"Trainer exited with code {trainer_process.returncode}")
+            raise RuntimeError(f"Unified trainer exited with code {trainer_process.returncode}")

         result["status"] = "success"
         print(f"[{model_name}] ✓ Training completed successfully!")
@@ -403,14 +352,7 @@ def run_model_test(
         traceback.print_exc()

     finally:
-        # Cleanup vLLM
-        if vllm_process and vllm_process.poll() is None:
-            print(f"[{model_name}] Terminating vLLM server...")
-            vllm_process.terminate()
-            try:
-                vllm_process.wait(timeout=10)
-            except subprocess.TimeoutExpired:
-                vllm_process.kill()
+        # Note: vLLM is managed by run.py and cleaned up automatically

         # Cleanup gsm8k environment
         if env_process and env_process.poll() is None:
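
Delegating vLLM's lifetime to run.py does leave one gap: if run.py itself is killed before its own cleanup runs, its vLLM child could be orphaned. A defensive sketch that assumes nothing about run.py's internals is to start it as a session leader and signal the whole process group on teardown (the sleep command stands in for run_cmd):

    import os
    import signal
    import subprocess
    import sys

    proc = subprocess.Popen(
        [sys.executable, "-c", "import time; time.sleep(60)"],  # stand-in for run_cmd
        start_new_session=True,  # proc becomes its own process-group leader
    )
    try:
        proc.wait(timeout=1)  # pretend the test run ended here
    except subprocess.TimeoutExpired:
        pass
    finally:
        if proc.poll() is None:
            # Signal the whole group so children (vLLM) die with the leader
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
            try:
                proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                os.killpg(os.getpgid(proc.pid), signal.SIGKILL)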