Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-29 17:35:07 +00:00

[pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci

parent d07ab3e3ce, commit 5cfd1929f1
19 changed files with 708 additions and 452 deletions
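The hunks below, apparently all from the project's shared-vLLM run script, show the usual autoformatter normalizations: single quotes become double quotes, binary operators get spaces, and long argument lists are split one element per line. The diff does not show which hooks the repository configures; the fixes are characteristic of black, so here is a sketch of reproducing two of them with black's Python API (assumes `pip install black`; the hook choice is an inference, not confirmed by this commit):

import black

# Two lines from the diff below, in their pre-fix form.
src = "log_dir = getattr(args, 'log_dir', './logs')\nprint('\\n' + '='*60)\n"
print(black.format_str(src, mode=black.Mode()))
# Output:
#   log_dir = getattr(args, "log_dir", "./logs")
#   print("\n" + "=" * 60)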
@@ -33,7 +33,7 @@ def wait_for_vllm(port: int, timeout: int = 300) -> bool:
     """Wait for vLLM server to be ready."""
     print(f"[Run] Waiting for vLLM server on port {port}...")
     start = time.time()

     while time.time() - start < timeout:
         try:
             response = requests.get(f"http://localhost:{port}/health", timeout=5)
@@ -44,9 +44,9 @@ def wait_for_vllm(port: int, timeout: int = 300) -> bool:
             pass
         except Exception as e:
             print(f"[Run] Health check error: {e}")

         time.sleep(2)

     print(f"[Run] ✗ vLLM server failed to start within {timeout}s")
     return False
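Stitched together for readability, the two hunks above imply a wait_for_vllm of roughly this shape. This is a sketch, not the file's exact contents: the diff skips the file's lines 40-43, so the first except clause and the success branch that returns True on HTTP 200 are assumptions.

import time

import requests


def wait_for_vllm(port: int, timeout: int = 300) -> bool:
    """Wait for vLLM server to be ready."""
    print(f"[Run] Waiting for vLLM server on port {port}...")
    start = time.time()

    while time.time() - start < timeout:
        try:
            response = requests.get(f"http://localhost:{port}/health", timeout=5)
            if response.status_code == 200:  # assumed: elided lines treat 200 as ready
                return True
        except requests.exceptions.ConnectionError:  # assumed clause; its `pass` body is in the diff
            pass
        except Exception as e:
            print(f"[Run] Health check error: {e}")

        time.sleep(2)

    print(f"[Run] ✗ vLLM server failed to start within {timeout}s")
    return False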
@@ -55,20 +55,23 @@ def wait_for_bridge_config(config_path: str, timeout: int = 60) -> bool:
     """Wait for vLLM bridge config to be created."""
     print(f"[Run] Waiting for bridge config at {config_path}...")
     start = time.time()

     while time.time() - start < timeout:
         if os.path.exists(config_path):
             try:
                 import json
-                with open(config_path, 'r') as f:
+
+                with open(config_path, "r") as f:
                     config = json.load(f)
-                if config.get('ipc_handles') and len(config['ipc_handles']) > 0:
-                    print(f"[Run] ✓ Bridge config ready with {len(config['ipc_handles'])} IPC handles")
+                if config.get("ipc_handles") and len(config["ipc_handles"]) > 0:
+                    print(
+                        f"[Run] ✓ Bridge config ready with {len(config['ipc_handles'])} IPC handles"
+                    )
                     return True
             except Exception:
                 pass
         time.sleep(1)

     print(f"[Run] ✗ Bridge config not created within {timeout}s")
     return False
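The waiter above inspects a single key, so the minimal file shape it accepts follows directly from the code: a JSON object whose "ipc_handles" is a non-empty list. A sketch of that shape (the real file is written by the vLLM server process, and these handle values are placeholders, not real CUDA IPC handles):

import json

# Smallest config that wait_for_bridge_config() treats as ready.
bridge_config = {"ipc_handles": ["<handle-for-param-0>", "<handle-for-param-1>"]}

with open("./vllm_bridge_config.json", "w") as f:
    json.dump(bridge_config, f)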
@@ -77,44 +80,44 @@ def main():
     # Parse args using shared CLI module
     parser = create_unified_parser()
     args = parser.parse_args()

     # Create log directory
-    log_dir = getattr(args, 'log_dir', './logs')
+    log_dir = getattr(args, "log_dir", "./logs")
     os.makedirs(log_dir, exist_ok=True)

     # Bridge config path
     bridge_config_path = "./vllm_bridge_config.json"

     # Clean up old bridge config
     if os.path.exists(bridge_config_path):
         os.remove(bridge_config_path)
         print("[Run] Removed old bridge config")

     # === Print Configuration ===
-    print("\n" + "="*60)
+    print("\n" + "=" * 60)
     print("STARTING UNIFIED GRPO TRAINER (shared_vllm mode)")
-    print("="*60)
+    print("=" * 60)
     print(f"Model: {args.model_name}")
     print(f"vLLM port: {args.vllm_port}")
     print(f"GPU memory utilization: {args.gpu_memory_utilization}")
     print(f"Training steps: {args.training_steps}")
     print(f"Optimizer: {args.optimizer}")
     print(f"GRPO: kl_coef={args.kl_coef}, clip_eps={args.clip_eps}")
-    print("="*60 + "\n")
+    print("=" * 60 + "\n")

     # Get the path to vllm_api_server.py
     script_dir = Path(__file__).parent
     vllm_server_script = script_dir / "vllm_api_server.py"

     if not vllm_server_script.exists():
         print(f"[Run] ✗ vLLM server script not found at {vllm_server_script}")
         sys.exit(1)

     # Extract device index from args.device
     device_index = "0"
     if ":" in args.device:
         device_index = args.device.split(":")[1]

     # Build vLLM environment
     vllm_env = os.environ.copy()
     vllm_env["VLLM_ENABLE_SHARED_WEIGHTS"] = "1"
@@ -123,21 +126,28 @@ def main():
     vllm_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
     vllm_env["VLLM_USE_V1"] = "0"  # v0 engine required for shared weights patches
     vllm_env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"  # Required for CUDA

     # Build vLLM command
     vllm_cmd = [
-        sys.executable, "-u", str(vllm_server_script),
-        "--model", args.model_name,
-        "--port", str(args.vllm_port),
-        "--dtype", args.dtype,
-        "--gpu-memory-utilization", str(args.gpu_memory_utilization),
-        "--max-model-len", str(args.max_model_len),
+        sys.executable,
+        "-u",
+        str(vllm_server_script),
+        "--model",
+        args.model_name,
+        "--port",
+        str(args.vllm_port),
+        "--dtype",
+        args.dtype,
+        "--gpu-memory-utilization",
+        str(args.gpu_memory_utilization),
+        "--max-model-len",
+        str(args.max_model_len),
         "--enforce-eager",  # Required for shared weights
     ]

     vllm_log_path = os.path.join(log_dir, "vllm.log")
     print(f"[Run] Starting vLLM server (log: {vllm_log_path})...")

     vllm_log = open(vllm_log_path, "w")
     vllm_process = subprocess.Popen(
         vllm_cmd,
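The vllm_cmd rewrite above is black's list-splitting rule at work: once a collection cannot fit on one line, or ends in a "magic" trailing comma, black emits exactly one element per line, so hand-grouped flag/value pairs like "--model", args.model_name do not survive. A small demonstration, under the same assumed black dependency as above:

import black

# The trailing comma alone forces the split, even though the list fits on one line.
src = "cmd = [sys.executable, '-u', 'server.py',]\n"
print(black.format_str(src, mode=black.Mode()))
# Output:
#   cmd = [
#       sys.executable,
#       "-u",
#       "server.py",
#   ]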
@@ -145,7 +155,7 @@ def main():
         stdout=vllm_log,
         stderr=subprocess.STDOUT,
     )

     # Register cleanup
     def cleanup():
         print("\n[Run] Cleaning up...")
@@ -158,24 +168,24 @@ def main():
             vllm_process.kill()
         vllm_log.close()
         print("[Run] Cleanup complete.")

     atexit.register(cleanup)
     signal.signal(signal.SIGINT, lambda s, f: sys.exit(0))
     signal.signal(signal.SIGTERM, lambda s, f: sys.exit(0))

     # Wait for vLLM to be ready
     if not wait_for_vllm(args.vllm_port, timeout=500):
         print("[Run] ✗ vLLM server failed to start. Check logs at:", vllm_log_path)
         sys.exit(1)

     # Wait for bridge config
     if not wait_for_bridge_config(bridge_config_path, timeout=60):
         print("[Run] ✗ Bridge config not created. Check vLLM logs.")
         sys.exit(1)

     # === Start Trainer ===
     print("\n[Run] Starting GRPO trainer...")

     # Build config - override some fields for shared_vllm mode
     config = TrainingConfig(
         model_name=args.model_name,
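The shutdown wiring in this hunk leans on a detail worth spelling out: sys.exit() raises SystemExit, so routing SIGINT and SIGTERM through it lets the interpreter unwind normally and run the atexit-registered cleanup(); a hard os._exit() would skip it. A self-contained demonstration of the same pattern:

import atexit
import os
import signal
import sys
import time


def cleanup():
    print("cleanup ran")


atexit.register(cleanup)
# sys.exit() raises SystemExit in the main thread; the process then
# exits normally, and atexit handlers fire.
signal.signal(signal.SIGTERM, lambda s, f: sys.exit(0))

print(f"PID {os.getpid()}: send SIGTERM to see cleanup run")
time.sleep(60)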
@@ -205,13 +215,14 @@ def main():
         benchmark=True,  # Always show timing info for run.py
         debug_loading=getattr(args, "debug_loading", False),
     )

     try:
         train_shared_vllm(config)
         print("\n[Run] ✓ Training completed successfully!")
     except Exception as e:
         print(f"\n[Run] ✗ Training failed: {e}")
         import traceback
+
         traceback.print_exc()
         sys.exit(1)