mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
48dcd64299
commit
51b4aa4858
6 changed files with 147 additions and 99 deletions
|
|
@ -17,6 +17,7 @@ class SHMBufferConfig:
|
|||
Control block for Shared Memory Buffer.
|
||||
Stored at the beginning of the SHM segment.
|
||||
"""
|
||||
|
||||
# [Magic (4B) | Version (4B) | ReadIdx (4B) | WriteIdx (4B) | MaxSize (4B) | EntrySize (4B)]
|
||||
FORMAT = "4sIIIII"
|
||||
SIZE = struct.calcsize(FORMAT)
|
||||
|
|
@ -44,15 +45,13 @@ class ZeroCopySHMBuffer:
|
|||
self.entry_size = entry_size
|
||||
self.instance_id_len = instance_id_len
|
||||
self.metadata_len = metadata_len
|
||||
|
||||
|
||||
# Schema: [Score (8) | Len (4) | InstanceID (id_len) | RepID (4) | Meta (meta_len) | Tokens (Size*4)]
|
||||
self.slot_size = (
|
||||
8 + 4 + instance_id_len + 4 + metadata_len + (entry_size * 4)
|
||||
)
|
||||
|
||||
self.slot_size = 8 + 4 + instance_id_len + 4 + metadata_len + (entry_size * 4)
|
||||
|
||||
# Total size = Control Block + Data Segment
|
||||
self.total_size = SHMBufferConfig.SIZE + (size * self.slot_size)
|
||||
|
||||
|
||||
try:
|
||||
if create:
|
||||
# Remove existing if any (OS-level cleanup)
|
||||
|
|
@ -61,11 +60,15 @@ class ZeroCopySHMBuffer:
|
|||
shm.unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
self.shm = shared_memory.SharedMemory(name=name, create=True, size=self.total_size)
|
||||
|
||||
self.shm = shared_memory.SharedMemory(
|
||||
name=name, create=True, size=self.total_size
|
||||
)
|
||||
self.buf = self.shm.buf
|
||||
self._init_control_block()
|
||||
logger.info(f"Created SHM buffer '{name}' with size {self.total_size} bytes")
|
||||
logger.info(
|
||||
f"Created SHM buffer '{name}' with size {self.total_size} bytes"
|
||||
)
|
||||
else:
|
||||
self.shm = shared_memory.SharedMemory(name=name)
|
||||
self.buf = self.shm.buf
|
||||
|
|
@ -102,18 +105,18 @@ class ZeroCopySHMBuffer:
|
|||
struct.pack_into("I", self.buf, 12, idx)
|
||||
|
||||
def write_trajectory(
|
||||
self,
|
||||
tokens: List[int],
|
||||
score: float,
|
||||
instance_id: str = "",
|
||||
repetition_id: int = 0,
|
||||
metadata: Dict[str, Any] = None
|
||||
self,
|
||||
tokens: List[int],
|
||||
score: float,
|
||||
instance_id: str = "",
|
||||
repetition_id: int = 0,
|
||||
metadata: Dict[str, Any] = None,
|
||||
):
|
||||
"""
|
||||
Writes a trajectory and its rich metadata to the buffer.
|
||||
"""
|
||||
read_idx, write_idx, max_size, entry_size = self._get_control()
|
||||
|
||||
|
||||
# Check for overflow
|
||||
next_write = (write_idx + 1) % max_size
|
||||
if next_write == read_idx:
|
||||
|
|
@ -122,29 +125,38 @@ class ZeroCopySHMBuffer:
|
|||
|
||||
# Calculate offset in data segment
|
||||
offset = SHMBufferConfig.SIZE + (write_idx * self.slot_size)
|
||||
|
||||
|
||||
# Pack Metadata and Rich attributes
|
||||
struct.pack_into("d", self.buf, offset, float(score))
|
||||
|
||||
|
||||
token_len = min(len(tokens), entry_size)
|
||||
struct.pack_into("i", self.buf, offset + 8, token_len)
|
||||
|
||||
id_bytes = instance_id.encode('utf-8')[:self.instance_id_len]
|
||||
|
||||
id_bytes = instance_id.encode("utf-8")[: self.instance_id_len]
|
||||
struct.pack_into(f"{self.instance_id_len}s", self.buf, offset + 12, id_bytes)
|
||||
|
||||
struct.pack_into("i", self.buf, offset + 12 + self.instance_id_len, int(repetition_id))
|
||||
|
||||
meta_json = json.dumps(metadata or {}).encode('utf-8')[:self.metadata_len]
|
||||
struct.pack_into(f"{self.metadata_len}s", self.buf, offset + 12 + self.instance_id_len + 4, meta_json)
|
||||
|
||||
|
||||
struct.pack_into(
|
||||
"i", self.buf, offset + 12 + self.instance_id_len, int(repetition_id)
|
||||
)
|
||||
|
||||
meta_json = json.dumps(metadata or {}).encode("utf-8")[: self.metadata_len]
|
||||
struct.pack_into(
|
||||
f"{self.metadata_len}s",
|
||||
self.buf,
|
||||
offset + 12 + self.instance_id_len + 4,
|
||||
meta_json,
|
||||
)
|
||||
|
||||
# Copy tokens via Numpy View directly into SHM slot
|
||||
token_offset = offset + 12 + self.instance_id_len + 4 + self.metadata_len
|
||||
token_arr = np.array(tokens, dtype=np.int32)
|
||||
shm_slot = np.ndarray((entry_size,), dtype=np.int32, buffer=self.buf, offset=token_offset)
|
||||
shm_slot = np.ndarray(
|
||||
(entry_size,), dtype=np.int32, buffer=self.buf, offset=token_offset
|
||||
)
|
||||
shm_slot[:token_len] = token_arr[:token_len]
|
||||
if token_len < entry_size:
|
||||
shm_slot[token_len:] = 0
|
||||
|
||||
|
||||
self._set_write_idx(next_write)
|
||||
return True
|
||||
|
||||
|
|
@ -153,42 +165,51 @@ class ZeroCopySHMBuffer:
|
|||
Reads the next available trajectory with its score and metadata.
|
||||
"""
|
||||
read_idx, write_idx, max_size, entry_size = self._get_control()
|
||||
|
||||
|
||||
if read_idx == write_idx:
|
||||
return None # Buffer empty
|
||||
|
||||
return None # Buffer empty
|
||||
|
||||
offset = SHMBufferConfig.SIZE + (read_idx * self.slot_size)
|
||||
|
||||
|
||||
# Unpack Metadata and Rich attributes
|
||||
score = struct.unpack_from("d", self.buf, offset)[0]
|
||||
token_len = min(struct.unpack_from("i", self.buf, offset + 8)[0], entry_size)
|
||||
|
||||
id_bytes = struct.unpack_from(f"{self.instance_id_len}s", self.buf, offset + 12)[0]
|
||||
instance_id = id_bytes.decode('utf-8', errors='ignore').strip('\x00')
|
||||
|
||||
repetition_id = struct.unpack_from("i", self.buf, offset + 12 + self.instance_id_len)[0]
|
||||
|
||||
meta_bytes = struct.unpack_from(f"{self.metadata_len}s", self.buf, offset + 12 + self.instance_id_len + 4)[0]
|
||||
|
||||
id_bytes = struct.unpack_from(
|
||||
f"{self.instance_id_len}s", self.buf, offset + 12
|
||||
)[0]
|
||||
instance_id = id_bytes.decode("utf-8", errors="ignore").strip("\x00")
|
||||
|
||||
repetition_id = struct.unpack_from(
|
||||
"i", self.buf, offset + 12 + self.instance_id_len
|
||||
)[0]
|
||||
|
||||
meta_bytes = struct.unpack_from(
|
||||
f"{self.metadata_len}s", self.buf, offset + 12 + self.instance_id_len + 4
|
||||
)[0]
|
||||
try:
|
||||
metadata = json.loads(meta_bytes.decode('utf-8', errors='ignore').strip('\x00'))
|
||||
metadata = json.loads(
|
||||
meta_bytes.decode("utf-8", errors="ignore").strip("\x00")
|
||||
)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
metadata = {}
|
||||
|
||||
|
||||
token_offset = offset + 12 + self.instance_id_len + 4 + self.metadata_len
|
||||
tokens_view = np.ndarray((token_len,), dtype=np.int32, buffer=self.buf, offset=token_offset)
|
||||
|
||||
tokens_view = np.ndarray(
|
||||
(token_len,), dtype=np.int32, buffer=self.buf, offset=token_offset
|
||||
)
|
||||
|
||||
self._set_read_idx((read_idx + 1) % max_size)
|
||||
|
||||
|
||||
return {
|
||||
"tokens": tokens_view.tolist(),
|
||||
"score": score,
|
||||
"instance_id": instance_id,
|
||||
"repetition_id": repetition_id,
|
||||
"metadata": metadata
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
def close(self, unlink: bool = False):
|
||||
self.shm.close()
|
||||
if unlink:
|
||||
self.shm.unlink()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue