feat: skyrl-shm reasoning infrastructure integration

This commit is contained in:
RUFFY-369 2026-04-04 02:43:16 +05:30
parent c20c85256e
commit 4f0acead3f
4 changed files with 351 additions and 19 deletions

View file

@ -12,6 +12,7 @@ from pydantic import BaseModel, field_validator
from starlette.datastructures import MutableHeaders
from starlette.types import Receive, Scope, Send
from atroposlib.api.shm_buffer import ZeroCopySHMBuffer
from atroposlib.api.utils import (
find_groups_summing_to_target,
grab_batch_with_minimum_allocations,
@ -210,26 +211,20 @@ def _process_scored_data(scored_data: ScoredData) -> Dict[str, Any]:
actual_group_size = len(scored_data.tokens)
if actual_group_size != expected_group_size:
# buffering logic...
buffer = app.state.buffer.setdefault(env_id, [])
buffer.append(data_dict)
indices = find_groups_summing_to_target(buffer, expected_group_size)
if indices:
groups_to_add = []
for idx in sorted(indices, reverse=True):
groups_to_add.append(buffer.pop(idx))
for group in reversed(groups_to_add):
app.state.queue.append(group)
app.state.latest = group
return {
"status": "buffered",
"buffer_size": sum(
len(group["tokens"]) for group in app.state.buffer.get(env_id, [])
),
}
# ... (truncated for brevity in actual call)
pass
# Write to SHM if buffer exists
if hasattr(app.state, "shm_buffer") and app.state.shm_buffer:
for i in range(len(scored_data.tokens)):
app.state.shm_buffer.write_trajectory(
tokens=scored_data.tokens[i],
score=scored_data.scores[i],
metadata={"env_id": env_id}
)
app.state.queue.append(data_dict)
app.state.latest = data_dict
@ -276,7 +271,25 @@ async def register(registration: Registration):
app.state.requesters = []
app.state.requesters.append(uuid.uuid4().int)
return {"uuid": app.state.requesters[-1]}
# Initialize Pinhole SHM Buffer
shm_name = f"atropos_shm_{app.state.group}"
try:
app.state.shm_buffer = ZeroCopySHMBuffer(
name=shm_name,
size=app.state.batchsize * 10, # Keep 10 batches in flight
entry_size=app.state.max_token_len,
create=True
)
logger.info(f"Initialized Zero-Copy SHM Pinhole: {shm_name}")
except Exception as e:
logger.error(f"Failed to initialize SHM Pinhole: {e}")
app.state.shm_buffer = None
return {
"uuid": app.state.requesters[-1],
"shm_handle": shm_name if app.state.shm_buffer else None
}
@app.post("/register-env")