mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-01 17:45:16 +00:00
feat(shm): bugfix and code cleanup
This commit is contained in:
parent
43a5cdcdfc
commit
48dcd64299
5 changed files with 38 additions and 83 deletions
|
|
@ -99,10 +99,6 @@ class EvalHandlingEnum(Enum):
|
|||
|
||||
|
||||
class TransportType(Enum):
|
||||
"""
|
||||
Enum for trajectory transport types.
|
||||
"""
|
||||
|
||||
HTTP = "HTTP"
|
||||
SHM = "SHM"
|
||||
|
||||
|
|
@ -866,17 +862,15 @@ class BaseEnv(ABC):
|
|||
|
||||
if self.config.transport == TransportType.SHM and self.shm_buffer:
|
||||
for group in data_list:
|
||||
# Use the provided instance_id (Task ID) if available, fallback to env_id
|
||||
inst_id = str(group.get("instance_id") or env_id or "unknown")
|
||||
for i in range(len(group["tokens"])):
|
||||
# Write each rollout in the group to the circular buffer
|
||||
self.shm_buffer.write_trajectory(
|
||||
tokens=group["tokens"][i],
|
||||
score=group["scores"][i],
|
||||
instance_id=str(env_id or "unknown"),
|
||||
metadata={
|
||||
"env": self.name,
|
||||
"step": self.curr_step,
|
||||
"group_idx": i
|
||||
}
|
||||
score=group["scores"][i] if i < len(group["scores"]) else 0.0,
|
||||
instance_id=inst_id,
|
||||
repetition_id=i,
|
||||
metadata={"env": self.name, "env_id": env_id}
|
||||
)
|
||||
return
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue