feat(skyrl-shm): Universal SHM transport and Raw State Injection

This commit is contained in:
RUFFY-369 2026-04-06 12:23:32 +05:30
parent 210883f0da
commit 43a5cdcdfc
3 changed files with 170 additions and 25 deletions

View file

@ -28,6 +28,7 @@ class ZeroCopySHMBuffer:
"""
High-performance circular buffer using multiprocessing.shared_memory.
Eliminates JSON serialization and HTTP overhead for trajectory transport.
Now expanded with TrajectoryID and metadata slots for universal Atropos use.
"""
def __init__(
@ -35,13 +36,20 @@ class ZeroCopySHMBuffer:
name: str,
size: int = 1000,
entry_size: int = 4096, # Max tokens per trajectory
instance_id_len: int = 64,
metadata_len: int = 256,
create: bool = False,
):
self.name = name
self.max_size = size
self.entry_size = entry_size
self.instance_id_len = instance_id_len
self.metadata_len = metadata_len
self.slot_size = 8 + 4 + (entry_size * 4) # Score (8) + Len (4) + Tokens (Size*4)
# Schema: [Score (8) | Len (4) | InstanceID (id_len) | RepID (4) | Meta (meta_len) | Tokens (Size*4)]
self.slot_size = (
8 + 4 + instance_id_len + 4 + metadata_len + (entry_size * 4)
)
# Total size = Control Block + Data Segment
self.total_size = SHMBufferConfig.SIZE + (size * self.slot_size)
@ -92,10 +100,16 @@ class ZeroCopySHMBuffer:
# We only update these two fields (Offsets: ReadIdx=8, WriteIdx=12)
struct.pack_into("II", self.buf, 8, read_idx, write_idx)
def write_trajectory(self, tokens: List[int], score: float, metadata: Dict[str, Any] = None):
def write_trajectory(
self,
tokens: List[int],
score: float,
instance_id: str = "",
repetition_id: int = 0,
metadata: Dict[str, Any] = None
):
"""
Writes a trajectory, its score, and metadata to the buffer.
Schema: [Score (8 bytes) | TokenLen (4 bytes) | Tokens (EntrySize * 4 bytes)]
Writes a trajectory and its rich metadata to the buffer.
"""
read_idx, write_idx, max_size, entry_size = self._get_control()
@ -108,15 +122,26 @@ class ZeroCopySHMBuffer:
# Calculate offset in data segment
offset = SHMBufferConfig.SIZE + (write_idx * self.slot_size)
# Write Score
# write Score (8)
struct.pack_into("d", self.buf, offset, float(score))
# Write Token Length
# write Token Length (4)
token_len = min(len(tokens), entry_size)
struct.pack_into("i", self.buf, offset + 8, token_len)
# Write Tokens (Numpy View)
token_offset = offset + 8 + 4
# write Instance ID (fixed len)
id_bytes = instance_id.encode('utf-8')[:self.instance_id_len]
struct.pack_into(f"{self.instance_id_len}s", self.buf, offset + 12, id_bytes)
# write Repetition ID (4)
struct.pack_into("i", self.buf, offset + 12 + self.instance_id_len, int(repetition_id))
# write Metadata (fixed len JSON)
meta_json = json.dumps(metadata or {}).encode('utf-8')[:self.metadata_len]
struct.pack_into(f"{self.metadata_len}s", self.buf, offset + 12 + self.instance_id_len + 4, meta_json)
# write Tokens (Numpy View)
token_offset = offset + 12 + self.instance_id_len + 4 + self.metadata_len
token_arr = np.array(tokens, dtype=np.int32)
# View the SHM as a numpy array for the specific token slot
@ -145,8 +170,22 @@ class ZeroCopySHMBuffer:
token_len = struct.unpack_from("i", self.buf, offset + 8)[0]
token_len = min(token_len, entry_size)
# Read Instance ID
id_bytes = struct.unpack_from(f"{self.instance_id_len}s", self.buf, offset + 12)[0]
instance_id = id_bytes.decode('utf-8', errors='ignore').strip('\x00')
# Read Repetition ID
repetition_id = struct.unpack_from("i", self.buf, offset + 12 + self.instance_id_len)[0]
# Read Metadata
meta_bytes = struct.unpack_from(f"{self.metadata_len}s", self.buf, offset + 12 + self.instance_id_len + 4)[0]
try:
metadata = json.loads(meta_bytes.decode('utf-8', errors='ignore').strip('\x00'))
except:
metadata = {}
# Read Tokens (Numpy View)
token_offset = offset + 8 + 4
token_offset = offset + 12 + self.instance_id_len + 4 + self.metadata_len
tokens_view = np.ndarray((token_len,), dtype=np.int32, buffer=self.buf, offset=token_offset)
# Advance read index
@ -154,10 +193,14 @@ class ZeroCopySHMBuffer:
return {
"tokens": tokens_view.tolist(),
"score": score
"score": score,
"instance_id": instance_id,
"repetition_id": repetition_id,
"metadata": metadata
}
def close(self, unlink: bool = False):
self.shm.close()
if unlink:
self.shm.unlink()