feat(skyrl-shm): Universal SHM transport and Raw State Injection

This commit is contained in:
RUFFY-369 2026-04-06 12:23:32 +05:30
parent 210883f0da
commit 43a5cdcdfc
3 changed files with 170 additions and 25 deletions

View file

@ -25,10 +25,19 @@ def mock_env_worker(worker_id: int, shm_name: str, barrier: mp.Barrier, stop_eve
count = 0
while not stop_event.is_set() and count < (TOTAL_TRAJECTORIES // NUM_ENV_WORKERS):
# Simulate REAL Reasoning model trace (4k tokens)
tokens = [100 + i for i in range(4096)]
tokens = [100 + i for i in range(ENTRY_SIZE)]
score = 0.8 + (worker_id * 0.05)
instance_id = f"task_{count}"
repetition_id = worker_id
metadata = {"worker": worker_id, "timestamp": time.time()}
success = shm.write_trajectory(tokens=tokens, score=score)
success = shm.write_trajectory(
tokens=tokens,
score=score,
instance_id=instance_id,
repetition_id=repetition_id,
metadata=metadata
)
if success:
count += 1
else:
@ -43,7 +52,7 @@ def run_e2e_benchmark():
1. Setup SHM
2. Launch Concurrency Workers
3. Measure Reader Throughput
4. Compare with HTTP Baseline Simulation
4. Verify Data Integrity (IDs and Metadata)
"""
shm_name = f"test_e2e_shm_{uuid.uuid4().hex[:8]}"
shm = ZeroCopySHMBuffer(name=shm_name, size=BATCH_SIZE * 2, entry_size=ENTRY_SIZE, create=True)
@ -62,12 +71,19 @@ def run_e2e_benchmark():
barrier.wait() # Start the race
# Throughput Benchmark (SHM)
print("📈 Measuring SHM Throughput...")
print("📈 Measuring SHM Throughput & Integrity...")
start_shm = time.perf_counter()
received = 0
verification_passed = True
while received < TOTAL_TRAJECTORIES:
data = shm.read_next()
if data:
# Verify Metadata Integrity for a sample
if received % 100 == 0:
if not (data["instance_id"].startswith("task_") and "worker" in data["metadata"]):
print(f"❌ Integrity Check Failed at index {received}!")
verification_passed = False
received += 1
else:
if all(not p.is_alive() for p in workers) and received < TOTAL_TRAJECTORIES:
@ -77,14 +93,21 @@ def run_e2e_benchmark():
shm_time = end_shm - start_shm
shm_tps = TOTAL_TRAJECTORIES / shm_time
print(f" [SHM] Received {received} trajectories in {shm_time:.4f}s ({shm_tps:.2f} traj/s)")
print(f" [SHM] Integrity Verification: {'✅ PASSED' if verification_passed else '❌ FAILED'}")
# HTTP Baseline Simulation
print("📉 Measuring HTTP Baseline Simulation (JSON Tax)...")
start_http = time.perf_counter()
for _ in range(TOTAL_TRAJECTORIES):
# Simulate JSON Serialization + Dummy HTTP Request
tokens = [100 + i for i in range(10)]
payload = json.dumps({"tokens": tokens, "score": 0.8})
tokens = [100 + i for i in range(ENTRY_SIZE)]
payload = json.dumps({
"tokens": tokens,
"score": 0.8,
"instance_id": "task_x",
"repetition_id": 0,
"metadata": {"foo": "bar"}
})
_ = json.loads(payload) # Deserialization
end_http = time.perf_counter()
@ -98,11 +121,13 @@ def run_e2e_benchmark():
print("="*40)
print(f"SHM Throughput Gain: {shm_tps / http_tps:.2f}x")
print(f"Concurrency Load: {NUM_ENV_WORKERS} workers handled without corruption.")
print(f"Data Integrity: {'Verified' if verification_passed else 'CORRUPT'}")
print("="*40)
stop_event.set()
for p in workers: p.join()
shm.close(unlink=True)
if __name__ == "__main__":
run_e2e_benchmark()