import multiprocessing as mp
import time
import json
import uuid
import requests
import numpy as np
import struct
from typing import List, Dict, Any

from atroposlib.api.shm_buffer import ZeroCopySHMBuffer

# Configuration for Mocks
BATCH_SIZE = 128
ENTRY_SIZE = 4096
NUM_ENV_WORKERS = 4
TOTAL_TRAJECTORIES = 500


def mock_env_worker(worker_id: int, shm_name: str, barrier: mp.Barrier, stop_event: mp.Event):
    """
    Simulates a SkyRL Environment process pushing trajectories to SHM.

    Attaches to the shared-memory buffer identified by ``shm_name``, waits on
    ``barrier`` so all workers start racing at the same instant, then writes
    its share of TOTAL_TRAJECTORIES, backing off briefly whenever the buffer
    is full. Exits early if ``stop_event`` is set.
    """
    try:
        shm = ZeroCopySHMBuffer(name=shm_name, create=False)
        barrier.wait()  # Synced start

        # Distribute the total evenly across workers; the first `remainder`
        # workers take one extra so the quotas always sum to
        # TOTAL_TRAJECTORIES even when it is not divisible by
        # NUM_ENV_WORKERS (otherwise the reader would wait forever).
        quota, remainder = divmod(TOTAL_TRAJECTORIES, NUM_ENV_WORKERS)
        if worker_id < remainder:
            quota += 1

        count = 0
        while not stop_event.is_set() and count < quota:
            # Simulate REAL Reasoning model trace (4k tokens)
            tokens = [100 + i for i in range(ENTRY_SIZE)]
            score = 0.8 + (worker_id * 0.05)
            instance_id = f"task_{count}"
            repetition_id = worker_id
            metadata = {"worker": worker_id, "timestamp": time.time()}

            success = shm.write_trajectory(
                tokens=tokens,
                score=score,
                instance_id=instance_id,
                repetition_id=repetition_id,
                metadata=metadata,
            )
            if success:
                count += 1
            else:
                time.sleep(0.001)  # Buffer full, backoff
    except Exception as e:
        # Best-effort reporting: this runs in a child process, so an
        # unhandled exception would otherwise vanish silently.
        print(f"Worker {worker_id} Error: {e}")


def run_e2e_benchmark():
    """
    Main E2E logic:
    1. Setup SHM
    2. Launch Concurrency Workers
    3. Measure Reader Throughput
    4. Verify Data Integrity (IDs and Metadata)
    """
    shm_name = f"test_e2e_shm_{uuid.uuid4().hex[:8]}"
    shm = ZeroCopySHMBuffer(
        name=shm_name, size=BATCH_SIZE * 2, entry_size=ENTRY_SIZE, create=True
    )
    # NUM_ENV_WORKERS writers + this reader process all rendezvous once.
    barrier = mp.Barrier(NUM_ENV_WORKERS + 1)
    stop_event = mp.Event()

    # Concurrency Test
    print(f"🚀 Starting {NUM_ENV_WORKERS} Environment Workers (Concurrency Test)...")
    workers = []
    for i in range(NUM_ENV_WORKERS):
        p = mp.Process(target=mock_env_worker, args=(i, shm_name, barrier, stop_event))
        p.start()
        workers.append(p)

    try:
        barrier.wait()  # Start the race

        # Throughput Benchmark (SHM)
        print("📈 Measuring SHM Throughput & Integrity...")
        start_shm = time.perf_counter()
        received = 0
        verification_passed = True
        while received < TOTAL_TRAJECTORIES:
            data = shm.read_next()
            if data:
                # Verify Metadata Integrity for a sample
                if received % 100 == 0:
                    if not (
                        data["instance_id"].startswith("task_")
                        and "worker" in data["metadata"]
                    ):
                        print(f"❌ Integrity Check Failed at index {received}!")
                        verification_passed = False
                received += 1
            else:
                # Buffer empty: if every writer has already exited we will
                # never reach the target, so stop instead of spinning forever.
                if all(not p.is_alive() for p in workers) and received < TOTAL_TRAJECTORIES:
                    break  # All workers died
        end_shm = time.perf_counter()
        shm_time = end_shm - start_shm
        # Throughput is based on what was actually received (workers may have
        # died early), and guarded against a zero-duration clock reading.
        shm_tps = received / shm_time if shm_time > 0 else float("inf")
        print(f" [SHM] Received {received} trajectories in {shm_time:.4f}s ({shm_tps:.2f} traj/s)")
        print(f" [SHM] Integrity Verification: {'✅ PASSED' if verification_passed else '❌ FAILED'}")

        # HTTP Baseline Simulation
        print("📉 Measuring HTTP Baseline Simulation (JSON Tax)...")
        start_http = time.perf_counter()
        for _ in range(TOTAL_TRAJECTORIES):
            # Simulate JSON Serialization + Dummy HTTP Request
            tokens = [100 + i for i in range(ENTRY_SIZE)]
            payload = json.dumps({
                "tokens": tokens,
                "score": 0.8,
                "instance_id": "task_x",
                "repetition_id": 0,
                "metadata": {"foo": "bar"},
            })
            _ = json.loads(payload)  # Deserialization
        end_http = time.perf_counter()
        http_time = end_http - start_http
        http_tps = TOTAL_TRAJECTORIES / http_time if http_time > 0 else float("inf")
        print(f" [HTTP] Processed {TOTAL_TRAJECTORIES} trajectories in {http_time:.4f}s ({http_tps:.2f} traj/s)")

        # --- RESULTS ---
        print("\n" + "=" * 40)
        print("🏆 E2E TEST RESULTS")
        print("=" * 40)
        print(f"SHM Throughput Gain: {shm_tps / http_tps:.2f}x")
        print(f"Concurrency Load: {NUM_ENV_WORKERS} workers handled without corruption.")
        print(f"Data Integrity: {'Verified' if verification_passed else 'CORRUPT'}")
        print("=" * 40)
    finally:
        # Always stop and reap the workers and unlink the SHM segment, even
        # if the benchmark raised — otherwise /dev/shm leaks the buffer.
        stop_event.set()
        for p in workers:
            p.join()
        shm.close(unlink=True)


if __name__ == "__main__":
    run_e2e_benchmark()