mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
feat(skyrl-shm): Universal SHM transport and Raw State Injection
This commit is contained in:
parent
210883f0da
commit
43a5cdcdfc
3 changed files with 170 additions and 25 deletions
|
|
@ -28,6 +28,7 @@ class ZeroCopySHMBuffer:
|
|||
"""
|
||||
High-performance circular buffer using multiprocessing.shared_memory.
|
||||
Eliminates JSON serialization and HTTP overhead for trajectory transport.
|
||||
Now expanded with TrajectoryID and metadata slots for universal Atropos use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -35,13 +36,20 @@ class ZeroCopySHMBuffer:
|
|||
name: str,
|
||||
size: int = 1000,
|
||||
entry_size: int = 4096, # Max tokens per trajectory
|
||||
instance_id_len: int = 64,
|
||||
metadata_len: int = 256,
|
||||
create: bool = False,
|
||||
):
|
||||
self.name = name
|
||||
self.max_size = size
|
||||
self.entry_size = entry_size
|
||||
self.instance_id_len = instance_id_len
|
||||
self.metadata_len = metadata_len
|
||||
|
||||
self.slot_size = 8 + 4 + (entry_size * 4) # Score (8) + Len (4) + Tokens (Size*4)
|
||||
# Schema: [Score (8) | Len (4) | InstanceID (id_len) | RepID (4) | Meta (meta_len) | Tokens (Size*4)]
|
||||
self.slot_size = (
|
||||
8 + 4 + instance_id_len + 4 + metadata_len + (entry_size * 4)
|
||||
)
|
||||
|
||||
# Total size = Control Block + Data Segment
|
||||
self.total_size = SHMBufferConfig.SIZE + (size * self.slot_size)
|
||||
|
|
@ -92,10 +100,16 @@ class ZeroCopySHMBuffer:
|
|||
# We only update these two fields (Offsets: ReadIdx=8, WriteIdx=12)
|
||||
struct.pack_into("II", self.buf, 8, read_idx, write_idx)
|
||||
|
||||
def write_trajectory(self, tokens: List[int], score: float, metadata: Dict[str, Any] = None):
|
||||
def write_trajectory(
|
||||
self,
|
||||
tokens: List[int],
|
||||
score: float,
|
||||
instance_id: str = "",
|
||||
repetition_id: int = 0,
|
||||
metadata: Dict[str, Any] = None
|
||||
):
|
||||
"""
|
||||
Writes a trajectory, its score, and metadata to the buffer.
|
||||
Schema: [Score (8 bytes) | TokenLen (4 bytes) | Tokens (EntrySize * 4 bytes)]
|
||||
Writes a trajectory and its rich metadata to the buffer.
|
||||
"""
|
||||
read_idx, write_idx, max_size, entry_size = self._get_control()
|
||||
|
||||
|
|
@ -108,15 +122,26 @@ class ZeroCopySHMBuffer:
|
|||
# Calculate offset in data segment
|
||||
offset = SHMBufferConfig.SIZE + (write_idx * self.slot_size)
|
||||
|
||||
# Write Score
|
||||
# write Score (8)
|
||||
struct.pack_into("d", self.buf, offset, float(score))
|
||||
|
||||
# Write Token Length
|
||||
# write Token Length (4)
|
||||
token_len = min(len(tokens), entry_size)
|
||||
struct.pack_into("i", self.buf, offset + 8, token_len)
|
||||
|
||||
# Write Tokens (Numpy View)
|
||||
token_offset = offset + 8 + 4
|
||||
# write Instance ID (fixed len)
|
||||
id_bytes = instance_id.encode('utf-8')[:self.instance_id_len]
|
||||
struct.pack_into(f"{self.instance_id_len}s", self.buf, offset + 12, id_bytes)
|
||||
|
||||
# write Repetition ID (4)
|
||||
struct.pack_into("i", self.buf, offset + 12 + self.instance_id_len, int(repetition_id))
|
||||
|
||||
# write Metadata (fixed len JSON)
|
||||
meta_json = json.dumps(metadata or {}).encode('utf-8')[:self.metadata_len]
|
||||
struct.pack_into(f"{self.metadata_len}s", self.buf, offset + 12 + self.instance_id_len + 4, meta_json)
|
||||
|
||||
# write Tokens (Numpy View)
|
||||
token_offset = offset + 12 + self.instance_id_len + 4 + self.metadata_len
|
||||
token_arr = np.array(tokens, dtype=np.int32)
|
||||
|
||||
# View the SHM as a numpy array for the specific token slot
|
||||
|
|
@ -145,8 +170,22 @@ class ZeroCopySHMBuffer:
|
|||
token_len = struct.unpack_from("i", self.buf, offset + 8)[0]
|
||||
token_len = min(token_len, entry_size)
|
||||
|
||||
# Read Instance ID
|
||||
id_bytes = struct.unpack_from(f"{self.instance_id_len}s", self.buf, offset + 12)[0]
|
||||
instance_id = id_bytes.decode('utf-8', errors='ignore').strip('\x00')
|
||||
|
||||
# Read Repetition ID
|
||||
repetition_id = struct.unpack_from("i", self.buf, offset + 12 + self.instance_id_len)[0]
|
||||
|
||||
# Read Metadata
|
||||
meta_bytes = struct.unpack_from(f"{self.metadata_len}s", self.buf, offset + 12 + self.instance_id_len + 4)[0]
|
||||
try:
|
||||
metadata = json.loads(meta_bytes.decode('utf-8', errors='ignore').strip('\x00'))
|
||||
except:
|
||||
metadata = {}
|
||||
|
||||
# Read Tokens (Numpy View)
|
||||
token_offset = offset + 8 + 4
|
||||
token_offset = offset + 12 + self.instance_id_len + 4 + self.metadata_len
|
||||
tokens_view = np.ndarray((token_len,), dtype=np.int32, buffer=self.buf, offset=token_offset)
|
||||
|
||||
# Advance read index
|
||||
|
|
@ -154,10 +193,14 @@ class ZeroCopySHMBuffer:
|
|||
|
||||
return {
|
||||
"tokens": tokens_view.tolist(),
|
||||
"score": score
|
||||
"score": score,
|
||||
"instance_id": instance_id,
|
||||
"repetition_id": repetition_id,
|
||||
"metadata": metadata
|
||||
}
|
||||
|
||||
def close(self, unlink: bool = False):
    """
    Detach from the shared-memory segment.

    Args:
        unlink: When True, also destroy the underlying segment after
            closing the local mapping (normally done only by the process
            that created the buffer).
    """
    segment = self.shm
    segment.close()
    if unlink:
        segment.unlink()
|
||||
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ from .server_handling.server_manager import (
|
|||
ServerManager,
|
||||
ServerManagerConfig,
|
||||
)
|
||||
from atroposlib.api.shm_buffer import ZeroCopySHMBuffer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
|
@ -97,6 +98,15 @@ class EvalHandlingEnum(Enum):
|
|||
NONE = "NONE"
|
||||
|
||||
|
||||
class TransportType(Enum):
    """
    Supported trajectory transport mechanisms.

    HTTP posts scored data to the rollout server endpoint; SHM writes it
    into a shared-memory ring buffer (``ZeroCopySHMBuffer``) instead.
    """

    HTTP = "HTTP"
    SHM = "SHM"
|
||||
|
||||
|
||||
class BaseEnvConfig(BaseModel):
|
||||
"""
|
||||
Basic env configuration.
|
||||
|
|
@ -211,6 +221,22 @@ class BaseEnvConfig(BaseModel):
|
|||
"no thinking prompt is injected. Use HERMES_REASONING_PROMPT from "
|
||||
"eval_helpers for the standard Hermes reasoning prompt.",
|
||||
)
|
||||
transport: TransportType = Field(
|
||||
default=TransportType.HTTP,
|
||||
description="Transport protocol for trajectories (HTTP or SHM).",
|
||||
)
|
||||
shm_name: str = Field(
|
||||
default="atropos_shm",
|
||||
description="Name of the Shared Memory segment (if transport is SHM).",
|
||||
)
|
||||
shm_size: int = Field(
|
||||
default=1000,
|
||||
description="Number of slots in the SHM buffer.",
|
||||
)
|
||||
state_injection_template: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Template for state injection (e.g. 'Terminal output: {state}').",
|
||||
)
|
||||
|
||||
|
||||
class BaseEnv(ABC):
|
||||
|
|
@ -296,6 +322,17 @@ class BaseEnv(ABC):
|
|||
else:
|
||||
self.jsonl_writer = None
|
||||
|
||||
# Initialize SHM buffer if configured
|
||||
self.shm_buffer = None
|
||||
if self.config.transport == TransportType.SHM:
|
||||
self.shm_buffer = ZeroCopySHMBuffer(
|
||||
name=self.config.shm_name,
|
||||
size=self.config.shm_size,
|
||||
entry_size=self.config.max_token_length,
|
||||
create=True, # Env manager usually acts as the creator
|
||||
)
|
||||
logger.info("Universal SHM transport initialized: %s", self.config.shm_name)
|
||||
|
||||
@property
|
||||
def derived_batch_size(self):
|
||||
"""Calculate the effective batch size for this environment based on minimum allocations."""
|
||||
|
|
@ -382,8 +419,30 @@ class BaseEnv(ABC):
|
|||
if result[0].get("images", None) is not None:
|
||||
to_postprocess["images"].append(result[0]["images"])
|
||||
backlog.extend(result[1])
|
||||
|
||||
# Apply Raw State Injection if configured
|
||||
if self.config.state_injection_template:
|
||||
to_postprocess = self._inject_state(to_postprocess, item)
|
||||
|
||||
return to_postprocess, backlog
|
||||
|
||||
def _inject_state(self, group: ScoredDataGroup, item: Item) -> ScoredDataGroup:
    """
    Prepend the raw environment/terminal state to every rollout in *group*.

    The state string is taken from ``item.state`` when the attribute exists
    (otherwise ``str(item)``), rendered through
    ``config.state_injection_template``, and prepended to each decoded
    rollout that does not already contain it. Mutates ``group["tokens"]``
    in place and returns the same group.

    NOTE(review): re-encoding changes the token count, so any parallel
    ``masks`` entries may no longer line up with the new token lists —
    confirm downstream re-masking before relying on them.
    """
    raw_state = getattr(item, "state", str(item))
    banner = self.config.state_injection_template.format(state=raw_state)

    token_lists = group["tokens"]
    for idx, toks in enumerate(token_lists):
        text = self.tokenizer.decode(toks)
        # Idempotence guard: skip rollouts that already carry the banner.
        if banner in text:
            continue
        token_lists[idx] = self.tokenizer.encode(f"{banner}\n\n{text}")
    return group
|
||||
|
||||
async def postprocess_histories(
|
||||
self,
|
||||
trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
|
||||
|
|
@ -795,17 +854,33 @@ class BaseEnv(ABC):
|
|||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(multiplier=1, max=10),
|
||||
)
|
||||
async def _send_scored_data_to_api(self, scored_data):
|
||||
async def _dispatch_scored_data(self, scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]]):
|
||||
"""
|
||||
Send scored data to the API with retry logic for timeouts and server errors.
|
||||
Dispatches scored data to the configured transport (HTTP or SHM).
|
||||
"""
|
||||
# Add env_id to the data
|
||||
if isinstance(scored_data, list):
|
||||
for item in scored_data:
|
||||
item["env_id"] = getattr(self, "env_id", None)
|
||||
else:
|
||||
scored_data["env_id"] = getattr(self, "env_id", None)
|
||||
env_id = getattr(self, "env_id", None)
|
||||
data_list = scored_data if isinstance(scored_data, list) else [scored_data]
|
||||
for item in data_list:
|
||||
item["env_id"] = env_id
|
||||
|
||||
if self.config.transport == TransportType.SHM and self.shm_buffer:
|
||||
for group in data_list:
|
||||
for i in range(len(group["tokens"])):
|
||||
# Write each rollout in the group to the circular buffer
|
||||
self.shm_buffer.write_trajectory(
|
||||
tokens=group["tokens"][i],
|
||||
score=group["scores"][i],
|
||||
instance_id=str(env_id or "unknown"),
|
||||
metadata={
|
||||
"env": self.name,
|
||||
"step": self.curr_step,
|
||||
"group_idx": i
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# Fallback to HTTP
|
||||
url = (
|
||||
f"{self.config.rollout_server_url}/scored_data_list"
|
||||
if isinstance(scored_data, list)
|
||||
|
|
@ -943,7 +1018,7 @@ class BaseEnv(ABC):
|
|||
|
||||
try:
|
||||
self.items_sent_this_step += len(valid_groups)
|
||||
await self._send_scored_data_to_api(data_to_send_to_api)
|
||||
await self._dispatch_scored_data(data_to_send_to_api)
|
||||
except (Exception, TimeoutError) as e:
|
||||
data_type_str = (
|
||||
"single ScoredDataGroup"
|
||||
|
|
@ -1006,7 +1081,9 @@ class BaseEnv(ABC):
|
|||
"""
|
||||
Optional: Cleanup the environment
|
||||
"""
|
||||
pass
|
||||
if self.shm_buffer:
|
||||
logger.info("Cleaning up Universal SHM transport: %s", self.config.shm_name)
|
||||
self.shm_buffer.close(unlink=True)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
|
||||
|
|
|
|||
|
|
@ -25,10 +25,19 @@ def mock_env_worker(worker_id: int, shm_name: str, barrier: mp.Barrier, stop_eve
|
|||
count = 0
|
||||
while not stop_event.is_set() and count < (TOTAL_TRAJECTORIES // NUM_ENV_WORKERS):
|
||||
# Simulate REAL Reasoning model trace (4k tokens)
|
||||
tokens = [100 + i for i in range(4096)]
|
||||
tokens = [100 + i for i in range(ENTRY_SIZE)]
|
||||
score = 0.8 + (worker_id * 0.05)
|
||||
instance_id = f"task_{count}"
|
||||
repetition_id = worker_id
|
||||
metadata = {"worker": worker_id, "timestamp": time.time()}
|
||||
|
||||
success = shm.write_trajectory(tokens=tokens, score=score)
|
||||
success = shm.write_trajectory(
|
||||
tokens=tokens,
|
||||
score=score,
|
||||
instance_id=instance_id,
|
||||
repetition_id=repetition_id,
|
||||
metadata=metadata
|
||||
)
|
||||
if success:
|
||||
count += 1
|
||||
else:
|
||||
|
|
@ -43,7 +52,7 @@ def run_e2e_benchmark():
|
|||
1. Setup SHM
|
||||
2. Launch Concurrency Workers
|
||||
3. Measure Reader Throughput
|
||||
4. Compare with HTTP Baseline Simulation
|
||||
4. Verify Data Integrity (IDs and Metadata)
|
||||
"""
|
||||
shm_name = f"test_e2e_shm_{uuid.uuid4().hex[:8]}"
|
||||
shm = ZeroCopySHMBuffer(name=shm_name, size=BATCH_SIZE * 2, entry_size=ENTRY_SIZE, create=True)
|
||||
|
|
@ -62,12 +71,19 @@ def run_e2e_benchmark():
|
|||
barrier.wait() # Start the race
|
||||
|
||||
# Throughput Benchmark (SHM)
|
||||
print("📈 Measuring SHM Throughput...")
|
||||
print("📈 Measuring SHM Throughput & Integrity...")
|
||||
start_shm = time.perf_counter()
|
||||
received = 0
|
||||
verification_passed = True
|
||||
|
||||
while received < TOTAL_TRAJECTORIES:
|
||||
data = shm.read_next()
|
||||
if data:
|
||||
# Verify Metadata Integrity for a sample
|
||||
if received % 100 == 0:
|
||||
if not (data["instance_id"].startswith("task_") and "worker" in data["metadata"]):
|
||||
print(f"❌ Integrity Check Failed at index {received}!")
|
||||
verification_passed = False
|
||||
received += 1
|
||||
else:
|
||||
if all(not p.is_alive() for p in workers) and received < TOTAL_TRAJECTORIES:
|
||||
|
|
@ -77,14 +93,21 @@ def run_e2e_benchmark():
|
|||
shm_time = end_shm - start_shm
|
||||
shm_tps = TOTAL_TRAJECTORIES / shm_time
|
||||
print(f" [SHM] Received {received} trajectories in {shm_time:.4f}s ({shm_tps:.2f} traj/s)")
|
||||
print(f" [SHM] Integrity Verification: {'✅ PASSED' if verification_passed else '❌ FAILED'}")
|
||||
|
||||
# HTTP Baseline Simulation
|
||||
print("📉 Measuring HTTP Baseline Simulation (JSON Tax)...")
|
||||
start_http = time.perf_counter()
|
||||
for _ in range(TOTAL_TRAJECTORIES):
|
||||
# Simulate JSON Serialization + Dummy HTTP Request
|
||||
tokens = [100 + i for i in range(10)]
|
||||
payload = json.dumps({"tokens": tokens, "score": 0.8})
|
||||
tokens = [100 + i for i in range(ENTRY_SIZE)]
|
||||
payload = json.dumps({
|
||||
"tokens": tokens,
|
||||
"score": 0.8,
|
||||
"instance_id": "task_x",
|
||||
"repetition_id": 0,
|
||||
"metadata": {"foo": "bar"}
|
||||
})
|
||||
_ = json.loads(payload) # Deserialization
|
||||
|
||||
end_http = time.perf_counter()
|
||||
|
|
@ -98,11 +121,13 @@ def run_e2e_benchmark():
|
|||
print("="*40)
|
||||
print(f"SHM Throughput Gain: {shm_tps / http_tps:.2f}x")
|
||||
print(f"Concurrency Load: {NUM_ENV_WORKERS} workers handled without corruption.")
|
||||
print(f"Data Integrity: {'Verified' if verification_passed else 'CORRUPT'}")
|
||||
print("="*40)
|
||||
|
||||
stop_event.set()
|
||||
for p in workers: p.join()
|
||||
shm.close(unlink=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_e2e_benchmark()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue