[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2026-04-06 20:55:42 +00:00
parent 48dcd64299
commit 51b4aa4858
6 changed files with 147 additions and 99 deletions

View file

@ -27,6 +27,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from transformers import AutoTokenizer
from typing_extensions import TypedDict
from atroposlib.api.shm_buffer import ZeroCopySHMBuffer
from atroposlib.envs.constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
from atroposlib.frontend.jsonl2html import generate_html
@ -49,7 +50,6 @@ from .server_handling.server_manager import (
ServerManager,
ServerManagerConfig,
)
from atroposlib.api.shm_buffer import ZeroCopySHMBuffer
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@ -415,11 +415,11 @@ class BaseEnv(ABC):
if result[0].get("images", None) is not None:
to_postprocess["images"].append(result[0]["images"])
backlog.extend(result[1])
# Apply Raw State Injection if configured
if self.config.state_injection_template:
to_postprocess = self._inject_state(to_postprocess, item)
return to_postprocess, backlog
def _inject_state(self, group: ScoredDataGroup, item: Item) -> ScoredDataGroup:
@ -428,7 +428,7 @@ class BaseEnv(ABC):
"""
state = getattr(item, "state", str(item))
injection = self.config.state_injection_template.format(state=state)
for i in range(len(group["tokens"])):
# Decode, inject, and re-encode (or prepend tokens if possible)
decoded = self.tokenizer.decode(group["tokens"][i])
@ -850,7 +850,9 @@ class BaseEnv(ABC):
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=1, max=10),
)
async def _dispatch_scored_data(self, scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]]):
async def _dispatch_scored_data(
self, scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]]
):
"""
Dispatches scored data to the configured transport (HTTP or SHM).
"""
@ -870,7 +872,7 @@ class BaseEnv(ABC):
score=group["scores"][i] if i < len(group["scores"]) else 0.0,
instance_id=inst_id,
repetition_id=i,
metadata={"env": self.name, "env_id": env_id}
metadata={"env": self.name, "env_id": env_id},
)
return