[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2026-04-06 20:55:42 +00:00
parent 48dcd64299
commit 51b4aa4858
6 changed files with 147 additions and 99 deletions

View file

@@ -213,13 +213,13 @@ def _process_scored_data(scored_data: ScoredData) -> Dict[str, Any]:
if actual_group_size != expected_group_size:
buffer = app.state.buffer.setdefault(env_id, [])
buffer.append(data_dict)
if hasattr(app.state, "shm_buffer") and app.state.shm_buffer:
for i in range(len(scored_data.tokens)):
app.state.shm_buffer.write_trajectory(
tokens=scored_data.tokens[i],
score=scored_data.scores[i],
metadata={"env_id": env_id}
metadata={"env_id": env_id},
)
app.state.queue.append(data_dict)
@@ -266,7 +266,7 @@ async def register(registration: Registration):
app.state.requesters = []
app.state.requesters.append(uuid.uuid4().int)
# Pin-hole SHM initialization
shm_name = f"atropos_shm_{app.state.group}"
try:
@@ -274,7 +274,7 @@ async def register(registration: Registration):
name=shm_name,
size=app.state.batchsize * 10,
entry_size=app.state.max_token_len,
create=True
create=True,
)
except Exception as e:
logger.error(f"SHM Buffer Init Failed: {e}")
@@ -282,7 +282,7 @@ async def register(registration: Registration):
return {
"uuid": app.state.requesters[-1],
"shm_handle": shm_name if app.state.shm_buffer else None
"shm_handle": shm_name if app.state.shm_buffer else None,
}

View file

@@ -17,6 +17,7 @@ class SHMBufferConfig:
Control block for Shared Memory Buffer.
Stored at the beginning of the SHM segment.
"""
# [Magic (4B) | Version (4B) | ReadIdx (4B) | WriteIdx (4B) | MaxSize (4B) | EntrySize (4B)]
FORMAT = "4sIIIII"
SIZE = struct.calcsize(FORMAT)
@@ -44,15 +45,13 @@ class ZeroCopySHMBuffer:
self.entry_size = entry_size
self.instance_id_len = instance_id_len
self.metadata_len = metadata_len
# Schema: [Score (8) | Len (4) | InstanceID (id_len) | RepID (4) | Meta (meta_len) | Tokens (Size*4)]
self.slot_size = (
8 + 4 + instance_id_len + 4 + metadata_len + (entry_size * 4)
)
self.slot_size = 8 + 4 + instance_id_len + 4 + metadata_len + (entry_size * 4)
# Total size = Control Block + Data Segment
self.total_size = SHMBufferConfig.SIZE + (size * self.slot_size)
try:
if create:
# Remove existing if any (OS-level cleanup)
@@ -61,11 +60,15 @@ class ZeroCopySHMBuffer:
shm.unlink()
except FileNotFoundError:
pass
self.shm = shared_memory.SharedMemory(name=name, create=True, size=self.total_size)
self.shm = shared_memory.SharedMemory(
name=name, create=True, size=self.total_size
)
self.buf = self.shm.buf
self._init_control_block()
logger.info(f"Created SHM buffer '{name}' with size {self.total_size} bytes")
logger.info(
f"Created SHM buffer '{name}' with size {self.total_size} bytes"
)
else:
self.shm = shared_memory.SharedMemory(name=name)
self.buf = self.shm.buf
@@ -102,18 +105,18 @@ class ZeroCopySHMBuffer:
struct.pack_into("I", self.buf, 12, idx)
def write_trajectory(
self,
tokens: List[int],
score: float,
instance_id: str = "",
repetition_id: int = 0,
metadata: Dict[str, Any] = None
self,
tokens: List[int],
score: float,
instance_id: str = "",
repetition_id: int = 0,
metadata: Dict[str, Any] = None,
):
"""
Writes a trajectory and its rich metadata to the buffer.
"""
read_idx, write_idx, max_size, entry_size = self._get_control()
# Check for overflow
next_write = (write_idx + 1) % max_size
if next_write == read_idx:
@@ -122,29 +125,38 @@ class ZeroCopySHMBuffer:
# Calculate offset in data segment
offset = SHMBufferConfig.SIZE + (write_idx * self.slot_size)
# Pack Metadata and Rich attributes
struct.pack_into("d", self.buf, offset, float(score))
token_len = min(len(tokens), entry_size)
struct.pack_into("i", self.buf, offset + 8, token_len)
id_bytes = instance_id.encode('utf-8')[:self.instance_id_len]
id_bytes = instance_id.encode("utf-8")[: self.instance_id_len]
struct.pack_into(f"{self.instance_id_len}s", self.buf, offset + 12, id_bytes)
struct.pack_into("i", self.buf, offset + 12 + self.instance_id_len, int(repetition_id))
meta_json = json.dumps(metadata or {}).encode('utf-8')[:self.metadata_len]
struct.pack_into(f"{self.metadata_len}s", self.buf, offset + 12 + self.instance_id_len + 4, meta_json)
struct.pack_into(
"i", self.buf, offset + 12 + self.instance_id_len, int(repetition_id)
)
meta_json = json.dumps(metadata or {}).encode("utf-8")[: self.metadata_len]
struct.pack_into(
f"{self.metadata_len}s",
self.buf,
offset + 12 + self.instance_id_len + 4,
meta_json,
)
# Copy tokens via Numpy View directly into SHM slot
token_offset = offset + 12 + self.instance_id_len + 4 + self.metadata_len
token_arr = np.array(tokens, dtype=np.int32)
shm_slot = np.ndarray((entry_size,), dtype=np.int32, buffer=self.buf, offset=token_offset)
shm_slot = np.ndarray(
(entry_size,), dtype=np.int32, buffer=self.buf, offset=token_offset
)
shm_slot[:token_len] = token_arr[:token_len]
if token_len < entry_size:
shm_slot[token_len:] = 0
self._set_write_idx(next_write)
return True
@@ -153,42 +165,51 @@ class ZeroCopySHMBuffer:
Reads the next available trajectory with its score and metadata.
"""
read_idx, write_idx, max_size, entry_size = self._get_control()
if read_idx == write_idx:
return None # Buffer empty
return None # Buffer empty
offset = SHMBufferConfig.SIZE + (read_idx * self.slot_size)
# Unpack Metadata and Rich attributes
score = struct.unpack_from("d", self.buf, offset)[0]
token_len = min(struct.unpack_from("i", self.buf, offset + 8)[0], entry_size)
id_bytes = struct.unpack_from(f"{self.instance_id_len}s", self.buf, offset + 12)[0]
instance_id = id_bytes.decode('utf-8', errors='ignore').strip('\x00')
repetition_id = struct.unpack_from("i", self.buf, offset + 12 + self.instance_id_len)[0]
meta_bytes = struct.unpack_from(f"{self.metadata_len}s", self.buf, offset + 12 + self.instance_id_len + 4)[0]
id_bytes = struct.unpack_from(
f"{self.instance_id_len}s", self.buf, offset + 12
)[0]
instance_id = id_bytes.decode("utf-8", errors="ignore").strip("\x00")
repetition_id = struct.unpack_from(
"i", self.buf, offset + 12 + self.instance_id_len
)[0]
meta_bytes = struct.unpack_from(
f"{self.metadata_len}s", self.buf, offset + 12 + self.instance_id_len + 4
)[0]
try:
metadata = json.loads(meta_bytes.decode('utf-8', errors='ignore').strip('\x00'))
metadata = json.loads(
meta_bytes.decode("utf-8", errors="ignore").strip("\x00")
)
except (json.JSONDecodeError, UnicodeDecodeError):
metadata = {}
token_offset = offset + 12 + self.instance_id_len + 4 + self.metadata_len
tokens_view = np.ndarray((token_len,), dtype=np.int32, buffer=self.buf, offset=token_offset)
tokens_view = np.ndarray(
(token_len,), dtype=np.int32, buffer=self.buf, offset=token_offset
)
self._set_read_idx((read_idx + 1) % max_size)
return {
"tokens": tokens_view.tolist(),
"score": score,
"instance_id": instance_id,
"repetition_id": repetition_id,
"metadata": metadata
"metadata": metadata,
}
def close(self, unlink: bool = False):
self.shm.close()
if unlink:
self.shm.unlink()

View file

@@ -27,6 +27,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from transformers import AutoTokenizer
from typing_extensions import TypedDict
from atroposlib.api.shm_buffer import ZeroCopySHMBuffer
from atroposlib.envs.constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
from atroposlib.frontend.jsonl2html import generate_html
@@ -49,7 +50,6 @@ from .server_handling.server_manager import (
ServerManager,
ServerManagerConfig,
)
from atroposlib.api.shm_buffer import ZeroCopySHMBuffer
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -415,11 +415,11 @@ class BaseEnv(ABC):
if result[0].get("images", None) is not None:
to_postprocess["images"].append(result[0]["images"])
backlog.extend(result[1])
# Apply Raw State Injection if configured
if self.config.state_injection_template:
to_postprocess = self._inject_state(to_postprocess, item)
return to_postprocess, backlog
def _inject_state(self, group: ScoredDataGroup, item: Item) -> ScoredDataGroup:
@@ -428,7 +428,7 @@ class BaseEnv(ABC):
"""
state = getattr(item, "state", str(item))
injection = self.config.state_injection_template.format(state=state)
for i in range(len(group["tokens"])):
# Decode, inject, and re-encode (or prepend tokens if possible)
decoded = self.tokenizer.decode(group["tokens"][i])
@@ -850,7 +850,9 @@ class BaseEnv(ABC):
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=1, max=10),
)
async def _dispatch_scored_data(self, scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]]):
async def _dispatch_scored_data(
self, scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]]
):
"""
Dispatches scored data to the configured transport (HTTP or SHM).
"""
@@ -870,7 +872,7 @@ class BaseEnv(ABC):
score=group["scores"][i] if i < len(group["scores"]) else 0.0,
instance_id=inst_id,
repetition_id=i,
metadata={"env": self.name, "env_id": env_id}
metadata={"env": self.name, "env_id": env_id},
)
return

View file

@@ -8,10 +8,12 @@ from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
logger = logging.getLogger(__name__)
class SkyRLConfig(BaseEnvConfig):
"""
Configuration for the Berkeley SkyRL adapter.
"""
skyrl_repo_id: str = Field(
default="NovaSky-AI/Sky-AIME-5K",
description="The SkyRL-gym repository ID or local path to the reasoning environment.",
@@ -29,11 +31,13 @@ class SkyRLConfig(BaseEnvConfig):
description="The closing tag for reasoning/thinking traces.",
)
class SkyRLAdapter(BaseEnv):
"""
Atropos Adapter for SkyRL (NovaSky-AI) environments.
Bridges reasoning traces and step-wise rewards into the Atropos layer.
"""
name = "skyrl"
env_config_cls = SkyRLConfig
@@ -67,7 +71,7 @@ class SkyRLAdapter(BaseEnv):
self.config.thought_start_tag
)
end_idx = content.find(self.config.thought_end_tag)
if end_idx != -1:
thinking_trace = content[start_idx:end_idx].strip()
if "reasoning_traces" not in group["env_metrics"]:
@@ -92,7 +96,7 @@ class SkyRLAdapter(BaseEnv):
advantages=None,
ref_logprobs=None,
messages=None,
meta={"source": "skyrl_dummy"}
meta={"source": "skyrl_dummy"},
)
async def evaluate(self, *args, **kwargs) -> Dict[str, float]:

View file

@@ -1,11 +1,13 @@
import multiprocessing as mp
import time
import json
import uuid
import requests
import numpy as np
import multiprocessing as mp
import struct
from typing import List, Dict, Any
import time
import uuid
from typing import Any, Dict, List
import numpy as np
import requests
from atroposlib.api.shm_buffer import ZeroCopySHMBuffer
# Configuration for Mocks
@@ -14,96 +16,114 @@ ENTRY_SIZE = 4096
NUM_ENV_WORKERS = 4
TOTAL_TRAJECTORIES = 500
def mock_env_worker(worker_id: int, shm_name: str, barrier: mp.Barrier, stop_event: mp.Event):
def mock_env_worker(
worker_id: int, shm_name: str, barrier: mp.Barrier, stop_event: mp.Event
):
"""Simulates a SkyRL Environment process pushing trajectories to SHM."""
try:
shm = ZeroCopySHMBuffer(name=shm_name, create=False)
barrier.wait()
count = 0
while not stop_event.is_set() and count < (TOTAL_TRAJECTORIES // NUM_ENV_WORKERS):
tokens = [100 + i for i in range(ENTRY_SIZE)]
while not stop_event.is_set() and count < (
TOTAL_TRAJECTORIES // NUM_ENV_WORKERS
):
tokens = [100 + i for i in range(ENTRY_SIZE)]
score = 0.8 + (worker_id * 0.05)
success = shm.write_trajectory(
tokens=tokens,
score=score,
tokens=tokens,
score=score,
instance_id=f"task_{count}",
repetition_id=worker_id,
metadata={"worker": worker_id}
metadata={"worker": worker_id},
)
if success:
count += 1
else:
time.sleep(0.001)
except Exception as e:
print(f"Worker {worker_id} Error: {e}")
def run_e2e_benchmark():
shm_name = f"test_e2e_shm_{uuid.uuid4().hex[:8]}"
shm = ZeroCopySHMBuffer(name=shm_name, size=BATCH_SIZE * 2, entry_size=ENTRY_SIZE, create=True)
shm = ZeroCopySHMBuffer(
name=shm_name, size=BATCH_SIZE * 2, entry_size=ENTRY_SIZE, create=True
)
barrier = mp.Barrier(NUM_ENV_WORKERS + 1)
stop_event = mp.Event()
print(f"🚀 Starting {NUM_ENV_WORKERS} Environment Workers (Concurrency Test)...")
workers = []
for i in range(NUM_ENV_WORKERS):
p = mp.Process(target=mock_env_worker, args=(i, shm_name, barrier, stop_event))
p.start()
workers.append(p)
barrier.wait()
barrier.wait()
print("📈 Measuring SHM Throughput & Integrity...")
start_shm = time.perf_counter()
received = 0
verification_passed = True
while received < TOTAL_TRAJECTORIES:
data = shm.read_next()
if data:
if received % 100 == 0:
if not (data["instance_id"].startswith("task_") and "worker" in data["metadata"]):
if not (
data["instance_id"].startswith("task_")
and "worker" in data["metadata"]
):
print(f"❌ Integrity Check Failed at index {received}!")
verification_passed = False
received += 1
else:
if all(not p.is_alive() for p in workers) and received < TOTAL_TRAJECTORIES:
break
shm_tps = TOTAL_TRAJECTORIES / (time.perf_counter() - start_shm)
print(f" [SHM] Received {received} trajectories ({shm_tps:.2f} traj/s)")
print(f" [SHM] Integrity Verification: {'✅ PASSED' if verification_passed else '❌ FAILED'}")
print(
f" [SHM] Integrity Verification: {'✅ PASSED' if verification_passed else '❌ FAILED'}"
)
# HTTP Baseline Simulation
print("📉 Measuring HTTP Baseline Simulation (JSON Tax)...")
start_http = time.perf_counter()
for _ in range(TOTAL_TRAJECTORIES):
tokens = [100 + i for i in range(ENTRY_SIZE)]
payload = json.dumps({
"tokens": tokens,
"score": 0.8,
"instance_id": "task_x",
"repetition_id": 0,
"metadata": {"foo": "bar"}
})
_ = json.loads(payload)
payload = json.dumps(
{
"tokens": tokens,
"score": 0.8,
"instance_id": "task_x",
"repetition_id": 0,
"metadata": {"foo": "bar"},
}
)
_ = json.loads(payload)
http_tps = TOTAL_TRAJECTORIES / (time.perf_counter() - start_http)
print(f" [HTTP] Processed {TOTAL_TRAJECTORIES} trajectories ({http_tps:.2f} traj/s)")
print(
f" [HTTP] Processed {TOTAL_TRAJECTORIES} trajectories ({http_tps:.2f} traj/s)"
)
# --- RESULTS ---
print("\n" + "="*40)
print("\n" + "=" * 40)
print("🏆 E2E TEST RESULTS")
print("="*40)
print("=" * 40)
print(f"SHM Throughput Gain: {shm_tps / http_tps:.2f}x")
print(f"Concurrency Load: {NUM_ENV_WORKERS} workers handled without corruption.")
print(f"Data Integrity: {'Verified' if verification_passed else 'CORRUPT'}")
print("="*40)
print("=" * 40)
stop_event.set()
for p in workers: p.join()
for p in workers:
p.join()
shm.close(unlink=True)

View file

@@ -13,8 +13,8 @@ Usage:
import logging
from typing import Any, Dict, List, Optional, Tuple
from atroposlib.envs.skyrl_adapter import SkyRLAdapter, SkyRLConfig
from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.skyrl_adapter import SkyRLAdapter, SkyRLConfig
logger = logging.getLogger(__name__)
@@ -23,7 +23,7 @@ class SkyRLServerEnv(SkyRLAdapter):
"""
User-facing environment for SkyRL reasoning tasks.
"""
@classmethod
def config_init(cls) -> Tuple[SkyRLConfig, List[APIServerConfig]]:
env_config = SkyRLConfig(
@@ -54,5 +54,6 @@ class SkyRLServerEnv(SkyRLAdapter):
await super().setup()
logger.info("SkyRL environment setup complete.")
if __name__ == "__main__":
SkyRLServerEnv.cli()