[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2026-02-06 06:46:14 +00:00 committed by Jai Suphavadeeprasit
parent d07ab3e3ce
commit 5cfd1929f1
19 changed files with 708 additions and 452 deletions

View file

@ -33,7 +33,7 @@ def wait_for_vllm(port: int, timeout: int = 300) -> bool:
"""Wait for vLLM server to be ready."""
print(f"[Run] Waiting for vLLM server on port {port}...")
start = time.time()
while time.time() - start < timeout:
try:
response = requests.get(f"http://localhost:{port}/health", timeout=5)
@ -44,9 +44,9 @@ def wait_for_vllm(port: int, timeout: int = 300) -> bool:
pass
except Exception as e:
print(f"[Run] Health check error: {e}")
time.sleep(2)
print(f"[Run] ✗ vLLM server failed to start within {timeout}s")
return False
@ -55,20 +55,23 @@ def wait_for_bridge_config(config_path: str, timeout: int = 60) -> bool:
"""Wait for vLLM bridge config to be created."""
print(f"[Run] Waiting for bridge config at {config_path}...")
start = time.time()
while time.time() - start < timeout:
if os.path.exists(config_path):
try:
import json
with open(config_path, 'r') as f:
with open(config_path, "r") as f:
config = json.load(f)
if config.get('ipc_handles') and len(config['ipc_handles']) > 0:
print(f"[Run] ✓ Bridge config ready with {len(config['ipc_handles'])} IPC handles")
if config.get("ipc_handles") and len(config["ipc_handles"]) > 0:
print(
f"[Run] ✓ Bridge config ready with {len(config['ipc_handles'])} IPC handles"
)
return True
except Exception:
pass
time.sleep(1)
print(f"[Run] ✗ Bridge config not created within {timeout}s")
return False
@ -77,44 +80,44 @@ def main():
# Parse args using shared CLI module
parser = create_unified_parser()
args = parser.parse_args()
# Create log directory
log_dir = getattr(args, 'log_dir', './logs')
log_dir = getattr(args, "log_dir", "./logs")
os.makedirs(log_dir, exist_ok=True)
# Bridge config path
bridge_config_path = "./vllm_bridge_config.json"
# Clean up old bridge config
if os.path.exists(bridge_config_path):
os.remove(bridge_config_path)
print("[Run] Removed old bridge config")
# === Print Configuration ===
print("\n" + "="*60)
print("\n" + "=" * 60)
print("STARTING UNIFIED GRPO TRAINER (shared_vllm mode)")
print("="*60)
print("=" * 60)
print(f"Model: {args.model_name}")
print(f"vLLM port: {args.vllm_port}")
print(f"GPU memory utilization: {args.gpu_memory_utilization}")
print(f"Training steps: {args.training_steps}")
print(f"Optimizer: {args.optimizer}")
print(f"GRPO: kl_coef={args.kl_coef}, clip_eps={args.clip_eps}")
print("="*60 + "\n")
print("=" * 60 + "\n")
# Get the path to vllm_api_server.py
script_dir = Path(__file__).parent
vllm_server_script = script_dir / "vllm_api_server.py"
if not vllm_server_script.exists():
print(f"[Run] ✗ vLLM server script not found at {vllm_server_script}")
sys.exit(1)
# Extract device index from args.device
device_index = "0"
if ":" in args.device:
device_index = args.device.split(":")[1]
# Build vLLM environment
vllm_env = os.environ.copy()
vllm_env["VLLM_ENABLE_SHARED_WEIGHTS"] = "1"
@ -123,21 +126,28 @@ def main():
vllm_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
vllm_env["VLLM_USE_V1"] = "0" # v0 engine required for shared weights patches
vllm_env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" # Required for CUDA
# Build vLLM command
vllm_cmd = [
sys.executable, "-u", str(vllm_server_script),
"--model", args.model_name,
"--port", str(args.vllm_port),
"--dtype", args.dtype,
"--gpu-memory-utilization", str(args.gpu_memory_utilization),
"--max-model-len", str(args.max_model_len),
sys.executable,
"-u",
str(vllm_server_script),
"--model",
args.model_name,
"--port",
str(args.vllm_port),
"--dtype",
args.dtype,
"--gpu-memory-utilization",
str(args.gpu_memory_utilization),
"--max-model-len",
str(args.max_model_len),
"--enforce-eager", # Required for shared weights
]
vllm_log_path = os.path.join(log_dir, "vllm.log")
print(f"[Run] Starting vLLM server (log: {vllm_log_path})...")
vllm_log = open(vllm_log_path, "w")
vllm_process = subprocess.Popen(
vllm_cmd,
@ -145,7 +155,7 @@ def main():
stdout=vllm_log,
stderr=subprocess.STDOUT,
)
# Register cleanup
def cleanup():
print("\n[Run] Cleaning up...")
@ -158,24 +168,24 @@ def main():
vllm_process.kill()
vllm_log.close()
print("[Run] Cleanup complete.")
atexit.register(cleanup)
signal.signal(signal.SIGINT, lambda s, f: sys.exit(0))
signal.signal(signal.SIGTERM, lambda s, f: sys.exit(0))
# Wait for vLLM to be ready
if not wait_for_vllm(args.vllm_port, timeout=500):
print("[Run] ✗ vLLM server failed to start. Check logs at:", vllm_log_path)
sys.exit(1)
# Wait for bridge config
if not wait_for_bridge_config(bridge_config_path, timeout=60):
print("[Run] ✗ Bridge config not created. Check vLLM logs.")
sys.exit(1)
# === Start Trainer ===
print("\n[Run] Starting GRPO trainer...")
# Build config - override some fields for shared_vllm mode
config = TrainingConfig(
model_name=args.model_name,
@ -205,13 +215,14 @@ def main():
benchmark=True, # Always show timing info for run.py
debug_loading=getattr(args, "debug_loading", False),
)
try:
train_shared_vllm(config)
print("\n[Run] ✓ Training completed successfully!")
except Exception as e:
print(f"\n[Run] ✗ Training failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)