feat(skyrl-shm): Universal SHM transport and Raw State Injection

This commit is contained in:
RUFFY-369 2026-04-06 12:23:32 +05:30
parent 210883f0da
commit 43a5cdcdfc
3 changed files with 170 additions and 25 deletions

View file

@ -49,6 +49,7 @@ from .server_handling.server_manager import (
ServerManager,
ServerManagerConfig,
)
from atroposlib.api.shm_buffer import ZeroCopySHMBuffer
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@ -97,6 +98,15 @@ class EvalHandlingEnum(Enum):
NONE = "NONE"
class TransportType(Enum):
"""
Enum for trajectory transport types.
"""
HTTP = "HTTP"
SHM = "SHM"
class BaseEnvConfig(BaseModel):
"""
Basic env configuration.
@ -211,6 +221,22 @@ class BaseEnvConfig(BaseModel):
"no thinking prompt is injected. Use HERMES_REASONING_PROMPT from "
"eval_helpers for the standard Hermes reasoning prompt.",
)
transport: TransportType = Field(
default=TransportType.HTTP,
description="Transport protocol for trajectories (HTTP or SHM).",
)
shm_name: str = Field(
default="atropos_shm",
description="Name of the Shared Memory segment (if transport is SHM).",
)
shm_size: int = Field(
default=1000,
description="Number of slots in the SHM buffer.",
)
state_injection_template: Optional[str] = Field(
default=None,
description="Template for state injection (e.g. 'Terminal output: {state}').",
)
class BaseEnv(ABC):
@ -296,6 +322,17 @@ class BaseEnv(ABC):
else:
self.jsonl_writer = None
# Initialize SHM buffer if configured
self.shm_buffer = None
if self.config.transport == TransportType.SHM:
self.shm_buffer = ZeroCopySHMBuffer(
name=self.config.shm_name,
size=self.config.shm_size,
entry_size=self.config.max_token_length,
create=True, # Env manager usually acts as the creator
)
logger.info("Universal SHM transport initialized: %s", self.config.shm_name)
@property
def derived_batch_size(self):
"""Calculate the effective batch size for this environment based on minimum allocations."""
@ -382,8 +419,30 @@ class BaseEnv(ABC):
if result[0].get("images", None) is not None:
to_postprocess["images"].append(result[0]["images"])
backlog.extend(result[1])
# Apply Raw State Injection if configured
if self.config.state_injection_template:
to_postprocess = self._inject_state(to_postprocess, item)
return to_postprocess, backlog
def _inject_state(self, group: ScoredDataGroup, item: Item) -> ScoredDataGroup:
"""
Injects raw environment/terminal state into the prompt as per Teknium's feedback.
"""
state = getattr(item, "state", str(item))
injection = self.config.state_injection_template.format(state=state)
for i in range(len(group["tokens"])):
# Decode, inject, and re-encode (or prepend tokens if possible)
decoded = self.tokenizer.decode(group["tokens"][i])
if injection not in decoded:
new_text = f"{injection}\n\n{decoded}"
group["tokens"][i] = self.tokenizer.encode(new_text)
# Adjust masks if necessary (heuristic: keep mask same length for now)
# In a real scenario, we might need to properly re-mask.
return group
async def postprocess_histories(
self,
trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
@ -795,17 +854,33 @@ class BaseEnv(ABC):
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=1, max=10),
)
async def _send_scored_data_to_api(self, scored_data):
async def _dispatch_scored_data(self, scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]]):
"""
Send scored data to the API with retry logic for timeouts and server errors.
Dispatches scored data to the configured transport (HTTP or SHM).
"""
# Add env_id to the data
if isinstance(scored_data, list):
for item in scored_data:
item["env_id"] = getattr(self, "env_id", None)
else:
scored_data["env_id"] = getattr(self, "env_id", None)
env_id = getattr(self, "env_id", None)
data_list = scored_data if isinstance(scored_data, list) else [scored_data]
for item in data_list:
item["env_id"] = env_id
if self.config.transport == TransportType.SHM and self.shm_buffer:
for group in data_list:
for i in range(len(group["tokens"])):
# Write each rollout in the group to the circular buffer
self.shm_buffer.write_trajectory(
tokens=group["tokens"][i],
score=group["scores"][i],
instance_id=str(env_id or "unknown"),
metadata={
"env": self.name,
"step": self.curr_step,
"group_idx": i
}
)
return
# Fallback to HTTP
url = (
f"{self.config.rollout_server_url}/scored_data_list"
if isinstance(scored_data, list)
@ -943,7 +1018,7 @@ class BaseEnv(ABC):
try:
self.items_sent_this_step += len(valid_groups)
await self._send_scored_data_to_api(data_to_send_to_api)
await self._dispatch_scored_data(data_to_send_to_api)
except (Exception, TimeoutError) as e:
data_type_str = (
"single ScoredDataGroup"
@ -1006,7 +1081,9 @@ class BaseEnv(ABC):
"""
Optional: Cleanup the environment
"""
pass
if self.shm_buffer:
logger.info("Cleaning up Universal SHM transport: %s", self.config.shm_name)
self.shm_buffer.close(unlink=True)
@retry(
stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)