mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
feat(shm): bugfix and code cleanup
This commit is contained in:
parent
43a5cdcdfc
commit
48dcd64299
5 changed files with 38 additions and 83 deletions
|
|
@ -211,10 +211,8 @@ def _process_scored_data(scored_data: ScoredData) -> Dict[str, Any]:
|
|||
actual_group_size = len(scored_data.tokens)
|
||||
|
||||
if actual_group_size != expected_group_size:
|
||||
# Buffer mixed-size groups if necessary (TBD)
|
||||
buffer = app.state.buffer.setdefault(env_id, [])
|
||||
buffer.append(data_dict)
|
||||
pass
|
||||
|
||||
if hasattr(app.state, "shm_buffer") and app.state.shm_buffer:
|
||||
for i in range(len(scored_data.tokens)):
|
||||
|
|
@ -269,18 +267,17 @@ async def register(registration: Registration):
|
|||
|
||||
app.state.requesters.append(uuid.uuid4().int)
|
||||
|
||||
# Initialize Pinhole SHM Buffer
|
||||
# Pin-hole SHM initialization
|
||||
shm_name = f"atropos_shm_{app.state.group}"
|
||||
try:
|
||||
app.state.shm_buffer = ZeroCopySHMBuffer(
|
||||
name=shm_name,
|
||||
size=app.state.batchsize * 10, # Keep 10 batches in flight
|
||||
size=app.state.batchsize * 10,
|
||||
entry_size=app.state.max_token_len,
|
||||
create=True
|
||||
)
|
||||
logger.info(f"Initialized Zero-Copy SHM Pinhole: {shm_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize SHM Pinhole: {e}")
|
||||
logger.error(f"SHM Buffer Init Failed: {e}")
|
||||
app.state.shm_buffer = None
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -27,8 +27,7 @@ class SHMBufferConfig:
|
|||
class ZeroCopySHMBuffer:
|
||||
"""
|
||||
High-performance circular buffer using multiprocessing.shared_memory.
|
||||
Eliminates JSON serialization and HTTP overhead for trajectory transport.
|
||||
Now expanded with TrajectoryID and metadata slots for universal Atropos use.
|
||||
Eliminates serialization and HTTP overhead for trajectory transport.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -96,9 +95,11 @@ class ZeroCopySHMBuffer:
|
|||
raise ValueError("Invalid SHM Magic")
|
||||
return read_idx, write_idx, max_size, entry_size
|
||||
|
||||
def _set_indices(self, read_idx: int, write_idx: int):
|
||||
# We only update these two fields (Offsets: ReadIdx=8, WriteIdx=12)
|
||||
struct.pack_into("II", self.buf, 8, read_idx, write_idx)
|
||||
def _set_read_idx(self, idx: int):
|
||||
struct.pack_into("I", self.buf, 8, idx)
|
||||
|
||||
def _set_write_idx(self, idx: int):
|
||||
struct.pack_into("I", self.buf, 12, idx)
|
||||
|
||||
def write_trajectory(
|
||||
self,
|
||||
|
|
@ -122,36 +123,29 @@ class ZeroCopySHMBuffer:
|
|||
# Calculate offset in data segment
|
||||
offset = SHMBufferConfig.SIZE + (write_idx * self.slot_size)
|
||||
|
||||
# write Score (8)
|
||||
# Pack Metadata and Rich attributes
|
||||
struct.pack_into("d", self.buf, offset, float(score))
|
||||
|
||||
# write Token Length (4)
|
||||
token_len = min(len(tokens), entry_size)
|
||||
struct.pack_into("i", self.buf, offset + 8, token_len)
|
||||
|
||||
# write Instance ID (fixed len)
|
||||
id_bytes = instance_id.encode('utf-8')[:self.instance_id_len]
|
||||
struct.pack_into(f"{self.instance_id_len}s", self.buf, offset + 12, id_bytes)
|
||||
|
||||
# write Repetition ID (4)
|
||||
struct.pack_into("i", self.buf, offset + 12 + self.instance_id_len, int(repetition_id))
|
||||
|
||||
# write Metadata (fixed len JSON)
|
||||
meta_json = json.dumps(metadata or {}).encode('utf-8')[:self.metadata_len]
|
||||
struct.pack_into(f"{self.metadata_len}s", self.buf, offset + 12 + self.instance_id_len + 4, meta_json)
|
||||
|
||||
# write Tokens (Numpy View)
|
||||
# Copy tokens via Numpy View directly into SHM slot
|
||||
token_offset = offset + 12 + self.instance_id_len + 4 + self.metadata_len
|
||||
token_arr = np.array(tokens, dtype=np.int32)
|
||||
|
||||
# View the SHM as a numpy array for the specific token slot
|
||||
shm_slot = np.ndarray((entry_size,), dtype=np.int32, buffer=self.buf, offset=token_offset)
|
||||
shm_slot[:token_len] = token_arr[:token_len]
|
||||
if token_len < entry_size:
|
||||
shm_slot[token_len:] = 0 # Padding
|
||||
shm_slot[token_len:] = 0
|
||||
|
||||
# Update write index
|
||||
self._set_indices(read_idx, next_write)
|
||||
self._set_write_idx(next_write)
|
||||
return True
|
||||
|
||||
def read_next(self) -> Optional[Dict[str, Any]]:
|
||||
|
|
@ -165,31 +159,25 @@ class ZeroCopySHMBuffer:
|
|||
|
||||
offset = SHMBufferConfig.SIZE + (read_idx * self.slot_size)
|
||||
|
||||
# Read Score and Token Length
|
||||
# Unpack Metadata and Rich attributes
|
||||
score = struct.unpack_from("d", self.buf, offset)[0]
|
||||
token_len = struct.unpack_from("i", self.buf, offset + 8)[0]
|
||||
token_len = min(token_len, entry_size)
|
||||
token_len = min(struct.unpack_from("i", self.buf, offset + 8)[0], entry_size)
|
||||
|
||||
# Read Instance ID
|
||||
id_bytes = struct.unpack_from(f"{self.instance_id_len}s", self.buf, offset + 12)[0]
|
||||
instance_id = id_bytes.decode('utf-8', errors='ignore').strip('\x00')
|
||||
|
||||
# Read Repetition ID
|
||||
repetition_id = struct.unpack_from("i", self.buf, offset + 12 + self.instance_id_len)[0]
|
||||
|
||||
# Read Metadata
|
||||
meta_bytes = struct.unpack_from(f"{self.metadata_len}s", self.buf, offset + 12 + self.instance_id_len + 4)[0]
|
||||
try:
|
||||
metadata = json.loads(meta_bytes.decode('utf-8', errors='ignore').strip('\x00'))
|
||||
except:
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
metadata = {}
|
||||
|
||||
# Read Tokens (Numpy View)
|
||||
token_offset = offset + 12 + self.instance_id_len + 4 + self.metadata_len
|
||||
tokens_view = np.ndarray((token_len,), dtype=np.int32, buffer=self.buf, offset=token_offset)
|
||||
|
||||
# Advance read index
|
||||
self._set_indices((read_idx + 1) % max_size, write_idx)
|
||||
self._set_read_idx((read_idx + 1) % max_size)
|
||||
|
||||
return {
|
||||
"tokens": tokens_view.tolist(),
|
||||
|
|
|
|||
|
|
@ -99,10 +99,6 @@ class EvalHandlingEnum(Enum):
|
|||
|
||||
|
||||
class TransportType(Enum):
|
||||
"""
|
||||
Enum for trajectory transport types.
|
||||
"""
|
||||
|
||||
HTTP = "HTTP"
|
||||
SHM = "SHM"
|
||||
|
||||
|
|
@ -866,17 +862,15 @@ class BaseEnv(ABC):
|
|||
|
||||
if self.config.transport == TransportType.SHM and self.shm_buffer:
|
||||
for group in data_list:
|
||||
# Use the provided instance_id (Task ID) if available, fallback to env_id
|
||||
inst_id = str(group.get("instance_id") or env_id or "unknown")
|
||||
for i in range(len(group["tokens"])):
|
||||
# Write each rollout in the group to the circular buffer
|
||||
self.shm_buffer.write_trajectory(
|
||||
tokens=group["tokens"][i],
|
||||
score=group["scores"][i],
|
||||
instance_id=str(env_id or "unknown"),
|
||||
metadata={
|
||||
"env": self.name,
|
||||
"step": self.curr_step,
|
||||
"group_idx": i
|
||||
}
|
||||
score=group["scores"][i] if i < len(group["scores"]) else 0.0,
|
||||
instance_id=inst_id,
|
||||
repetition_id=i,
|
||||
metadata={"env": self.name, "env_id": env_id}
|
||||
)
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -31,10 +31,8 @@ class SkyRLConfig(BaseEnvConfig):
|
|||
|
||||
class SkyRLAdapter(BaseEnv):
|
||||
"""
|
||||
Atropos Adapter for Berkeley's SkyRL (NovaSky-AI) environments.
|
||||
|
||||
This adapter bridges the SkyRL-gym trajectory format (Thinking Traces + PRM)
|
||||
into the Atropos orchestration layer.
|
||||
Atropos Adapter for SkyRL (NovaSky-AI) environments.
|
||||
Bridges reasoning traces and step-wise rewards into the Atropos layer.
|
||||
"""
|
||||
name = "skyrl"
|
||||
env_config_cls = SkyRLConfig
|
||||
|
|
|
|||
|
|
@ -15,62 +15,46 @@ NUM_ENV_WORKERS = 4
|
|||
TOTAL_TRAJECTORIES = 500
|
||||
|
||||
def mock_env_worker(worker_id: int, shm_name: str, barrier: mp.Barrier, stop_event: mp.Event):
|
||||
"""
|
||||
Simulates a SkyRL Environment process pushing trajectories to SHM.
|
||||
"""
|
||||
"""Simulates a SkyRL Environment process pushing trajectories to SHM."""
|
||||
try:
|
||||
shm = ZeroCopySHMBuffer(name=shm_name, create=False)
|
||||
barrier.wait() # Synced start
|
||||
barrier.wait()
|
||||
|
||||
count = 0
|
||||
while not stop_event.is_set() and count < (TOTAL_TRAJECTORIES // NUM_ENV_WORKERS):
|
||||
# Simulate REAL Reasoning model trace (4k tokens)
|
||||
tokens = [100 + i for i in range(ENTRY_SIZE)]
|
||||
score = 0.8 + (worker_id * 0.05)
|
||||
instance_id = f"task_{count}"
|
||||
repetition_id = worker_id
|
||||
metadata = {"worker": worker_id, "timestamp": time.time()}
|
||||
|
||||
success = shm.write_trajectory(
|
||||
tokens=tokens,
|
||||
score=score,
|
||||
instance_id=instance_id,
|
||||
repetition_id=repetition_id,
|
||||
metadata=metadata
|
||||
instance_id=f"task_{count}",
|
||||
repetition_id=worker_id,
|
||||
metadata={"worker": worker_id}
|
||||
)
|
||||
if success:
|
||||
count += 1
|
||||
else:
|
||||
time.sleep(0.001) # Buffer full, backoff
|
||||
time.sleep(0.001)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Worker {worker_id} Error: {e}")
|
||||
|
||||
def run_e2e_benchmark():
|
||||
"""
|
||||
Main E2E logic:
|
||||
1. Setup SHM
|
||||
2. Launch Concurrency Workers
|
||||
3. Measure Reader Throughput
|
||||
4. Verify Data Integrity (IDs and Metadata)
|
||||
"""
|
||||
shm_name = f"test_e2e_shm_{uuid.uuid4().hex[:8]}"
|
||||
shm = ZeroCopySHMBuffer(name=shm_name, size=BATCH_SIZE * 2, entry_size=ENTRY_SIZE, create=True)
|
||||
|
||||
barrier = mp.Barrier(NUM_ENV_WORKERS + 1)
|
||||
stop_event = mp.Event()
|
||||
|
||||
# Concurrency Test
|
||||
print(f"🚀 Starting {NUM_ENV_WORKERS} Environment Workers (Concurrency Test)...")
|
||||
workers = []
|
||||
for i in range(NUM_ENV_WORKERS):
|
||||
p = mp.Process(target=mock_env_worker, args=(i, shm_name, barrier, stop_event))
|
||||
p.start()
|
||||
workers.append(p)
|
||||
barrier.wait()
|
||||
|
||||
barrier.wait() # Start the race
|
||||
|
||||
# Throughput Benchmark (SHM)
|
||||
print("📈 Measuring SHM Throughput & Integrity...")
|
||||
start_shm = time.perf_counter()
|
||||
received = 0
|
||||
|
|
@ -79,7 +63,6 @@ def run_e2e_benchmark():
|
|||
while received < TOTAL_TRAJECTORIES:
|
||||
data = shm.read_next()
|
||||
if data:
|
||||
# Verify Metadata Integrity for a sample
|
||||
if received % 100 == 0:
|
||||
if not (data["instance_id"].startswith("task_") and "worker" in data["metadata"]):
|
||||
print(f"❌ Integrity Check Failed at index {received}!")
|
||||
|
|
@ -87,19 +70,16 @@ def run_e2e_benchmark():
|
|||
received += 1
|
||||
else:
|
||||
if all(not p.is_alive() for p in workers) and received < TOTAL_TRAJECTORIES:
|
||||
break # All workers died
|
||||
break
|
||||
|
||||
end_shm = time.perf_counter()
|
||||
shm_time = end_shm - start_shm
|
||||
shm_tps = TOTAL_TRAJECTORIES / shm_time
|
||||
print(f" [SHM] Received {received} trajectories in {shm_time:.4f}s ({shm_tps:.2f} traj/s)")
|
||||
shm_tps = TOTAL_TRAJECTORIES / (time.perf_counter() - start_shm)
|
||||
print(f" [SHM] Received {received} trajectories ({shm_tps:.2f} traj/s)")
|
||||
print(f" [SHM] Integrity Verification: {'✅ PASSED' if verification_passed else '❌ FAILED'}")
|
||||
|
||||
# HTTP Baseline Simulation
|
||||
print("📉 Measuring HTTP Baseline Simulation (JSON Tax)...")
|
||||
start_http = time.perf_counter()
|
||||
for _ in range(TOTAL_TRAJECTORIES):
|
||||
# Simulate JSON Serialization + Dummy HTTP Request
|
||||
tokens = [100 + i for i in range(ENTRY_SIZE)]
|
||||
payload = json.dumps({
|
||||
"tokens": tokens,
|
||||
|
|
@ -108,12 +88,10 @@ def run_e2e_benchmark():
|
|||
"repetition_id": 0,
|
||||
"metadata": {"foo": "bar"}
|
||||
})
|
||||
_ = json.loads(payload) # Deserialization
|
||||
_ = json.loads(payload)
|
||||
|
||||
end_http = time.perf_counter()
|
||||
http_time = end_http - start_http
|
||||
http_tps = TOTAL_TRAJECTORIES / http_time
|
||||
print(f" [HTTP] Processed {TOTAL_TRAJECTORIES} trajectories in {http_time:.4f}s ({http_tps:.2f} traj/s)")
|
||||
http_tps = TOTAL_TRAJECTORIES / (time.perf_counter() - start_http)
|
||||
print(f" [HTTP] Processed {TOTAL_TRAJECTORIES} trajectories ({http_tps:.2f} traj/s)")
|
||||
|
||||
# --- RESULTS ---
|
||||
print("\n" + "="*40)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue