# atropos/atroposlib/tests/test_skyrl_shm_e2e.py
# (file-viewer metadata from paste: 133 lines, 4.5 KiB, Python)
import multiprocessing as mp
import time
import json
import uuid
import requests
import numpy as np
import struct
from typing import List, Dict, Any
from atroposlib.api.shm_buffer import ZeroCopySHMBuffer
# Configuration for Mocks
BATCH_SIZE = 128  # SHM buffer is created with 2x this many entry slots
ENTRY_SIZE = 4096  # tokens per simulated trajectory (mimics a 4k-token reasoning trace)
NUM_ENV_WORKERS = 4  # number of concurrent mock environment processes
TOTAL_TRAJECTORIES = 500  # total trajectories the reader expects across all workers
def mock_env_worker(worker_id: int, shm_name: str, barrier: mp.Barrier, stop_event: mp.Event):
    """Simulate a SkyRL environment process pushing trajectories to SHM.

    Attaches to the pre-created shared-memory buffer named `shm_name`, waits
    on `barrier` so all workers (and the reader) start simultaneously, then
    writes its share of TOTAL_TRAJECTORIES entries, backing off briefly
    whenever the buffer is full.

    Args:
        worker_id: Index of this worker; baked into the score, metadata and
            repetition_id so the reader can verify provenance.
        shm_name: Name of an existing ZeroCopySHMBuffer segment to attach to.
        barrier: Start-line barrier shared with the parent/reader process.
        stop_event: Set by the parent to request early shutdown.
    """
    try:
        shm = ZeroCopySHMBuffer(name=shm_name, create=False)
        barrier.wait()  # Synced start across all workers and the reader

        # The simulated 4k-token trace and the per-worker score are
        # loop-invariant: build them once instead of re-allocating a
        # 4096-element list on every iteration of the hot write loop.
        tokens = [100 + i for i in range(ENTRY_SIZE)]
        score = 0.8 + (worker_id * 0.05)
        quota = TOTAL_TRAJECTORIES // NUM_ENV_WORKERS

        count = 0
        while not stop_event.is_set() and count < quota:
            success = shm.write_trajectory(
                tokens=tokens,
                score=score,
                instance_id=f"task_{count}",
                repetition_id=worker_id,
                metadata={"worker": worker_id, "timestamp": time.time()},
            )
            if success:
                count += 1
            else:
                time.sleep(0.001)  # Buffer full, backoff
    except Exception as e:
        # Child processes cannot propagate exceptions to the parent, so
        # report best-effort to stdout rather than dying silently.
        print(f"Worker {worker_id} Error: {e}")
def run_e2e_benchmark():
    """Run the end-to-end SHM benchmark.

    1. Set up the shared-memory buffer.
    2. Launch NUM_ENV_WORKERS concurrent writer processes.
    3. Measure reader throughput while spot-checking data integrity
       (instance IDs and metadata keys).
    4. Compare against a simulated HTTP/JSON baseline and print results.

    Cleanup (stopping workers, joining them, unlinking the SHM segment) is
    performed in a `finally` block so that neither child processes nor the
    /dev/shm segment leak if the benchmark raises part-way through.
    """
    shm_name = f"test_e2e_shm_{uuid.uuid4().hex[:8]}"
    shm = ZeroCopySHMBuffer(name=shm_name, size=BATCH_SIZE * 2, entry_size=ENTRY_SIZE, create=True)
    barrier = mp.Barrier(NUM_ENV_WORKERS + 1)  # +1 for this reader process
    stop_event = mp.Event()

    # Concurrency Test
    print(f"🚀 Starting {NUM_ENV_WORKERS} Environment Workers (Concurrency Test)...")
    workers = []
    try:
        for i in range(NUM_ENV_WORKERS):
            p = mp.Process(target=mock_env_worker, args=(i, shm_name, barrier, stop_event))
            p.start()
            workers.append(p)
        barrier.wait()  # Start the race

        # Throughput Benchmark (SHM)
        print("📈 Measuring SHM Throughput & Integrity...")
        start_shm = time.perf_counter()
        received = 0
        verification_passed = True
        while received < TOTAL_TRAJECTORIES:
            data = shm.read_next()
            if data:
                # Spot-check integrity on every 100th trajectory.
                if received % 100 == 0:
                    if not (data["instance_id"].startswith("task_") and "worker" in data["metadata"]):
                        print(f"❌ Integrity Check Failed at index {received}!")
                        verification_passed = False
                received += 1
            elif all(not p.is_alive() for p in workers):
                # All workers exited and the buffer is drained: no more data
                # is coming, so stop instead of spinning forever.
                break
        shm_time = time.perf_counter() - start_shm
        # Base throughput on what was actually received (workers may have
        # died early), and guard against a zero-length timing window.
        shm_tps = received / shm_time if shm_time > 0 else float("inf")
        print(f" [SHM] Received {received} trajectories in {shm_time:.4f}s ({shm_tps:.2f} traj/s)")
        print(f" [SHM] Integrity Verification: {'✅ PASSED' if verification_passed else '❌ FAILED'}")

        # HTTP Baseline Simulation: measures only the JSON encode/decode
        # "tax" a real HTTP transport would pay per trajectory.
        print("📉 Measuring HTTP Baseline Simulation (JSON Tax)...")
        start_http = time.perf_counter()
        for _ in range(TOTAL_TRAJECTORIES):
            # Simulate JSON Serialization + Dummy HTTP Request
            tokens = [100 + i for i in range(ENTRY_SIZE)]
            payload = json.dumps({
                "tokens": tokens,
                "score": 0.8,
                "instance_id": "task_x",
                "repetition_id": 0,
                "metadata": {"foo": "bar"}
            })
            _ = json.loads(payload)  # Deserialization
        http_time = time.perf_counter() - start_http
        http_tps = TOTAL_TRAJECTORIES / http_time
        print(f" [HTTP] Processed {TOTAL_TRAJECTORIES} trajectories in {http_time:.4f}s ({http_tps:.2f} traj/s)")

        # --- RESULTS ---
        print("\n" + "=" * 40)
        print("🏆 E2E TEST RESULTS")
        print("=" * 40)
        print(f"SHM Throughput Gain: {shm_tps / http_tps:.2f}x")
        print(f"Concurrency Load: {NUM_ENV_WORKERS} workers handled without corruption.")
        print(f"Data Integrity: {'Verified' if verification_passed else 'CORRUPT'}")
        print("=" * 40)
    finally:
        # Always shut workers down and unlink the segment, even on error.
        stop_event.set()
        for p in workers:
            p.join()
        shm.close(unlink=True)
# Script entry point: run the benchmark only when executed directly,
# not when imported (important because this module spawns subprocesses).
if __name__ == "__main__":
    run_e2e_benchmark()