mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-01 17:45:16 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
48dcd64299
commit
51b4aa4858
6 changed files with 147 additions and 99 deletions
|
|
@ -27,6 +27,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||
from transformers import AutoTokenizer
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from atroposlib.api.shm_buffer import ZeroCopySHMBuffer
|
||||
from atroposlib.envs.constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
|
||||
from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
|
||||
from atroposlib.frontend.jsonl2html import generate_html
|
||||
|
|
@ -49,7 +50,6 @@ from .server_handling.server_manager import (
|
|||
ServerManager,
|
||||
ServerManagerConfig,
|
||||
)
|
||||
from atroposlib.api.shm_buffer import ZeroCopySHMBuffer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
|
@ -415,11 +415,11 @@ class BaseEnv(ABC):
|
|||
if result[0].get("images", None) is not None:
|
||||
to_postprocess["images"].append(result[0]["images"])
|
||||
backlog.extend(result[1])
|
||||
|
||||
|
||||
# Apply Raw State Injection if configured
|
||||
if self.config.state_injection_template:
|
||||
to_postprocess = self._inject_state(to_postprocess, item)
|
||||
|
||||
|
||||
return to_postprocess, backlog
|
||||
|
||||
def _inject_state(self, group: ScoredDataGroup, item: Item) -> ScoredDataGroup:
|
||||
|
|
@ -428,7 +428,7 @@ class BaseEnv(ABC):
|
|||
"""
|
||||
state = getattr(item, "state", str(item))
|
||||
injection = self.config.state_injection_template.format(state=state)
|
||||
|
||||
|
||||
for i in range(len(group["tokens"])):
|
||||
# Decode, inject, and re-encode (or prepend tokens if possible)
|
||||
decoded = self.tokenizer.decode(group["tokens"][i])
|
||||
|
|
@ -850,7 +850,9 @@ class BaseEnv(ABC):
|
|||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(multiplier=1, max=10),
|
||||
)
|
||||
async def _dispatch_scored_data(self, scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]]):
|
||||
async def _dispatch_scored_data(
|
||||
self, scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]]
|
||||
):
|
||||
"""
|
||||
Dispatches scored data to the configured transport (HTTP or SHM).
|
||||
"""
|
||||
|
|
@ -870,7 +872,7 @@ class BaseEnv(ABC):
|
|||
score=group["scores"][i] if i < len(group["scores"]) else 0.0,
|
||||
instance_id=inst_id,
|
||||
repetition_id=i,
|
||||
metadata={"env": self.name, "env_id": env_id}
|
||||
metadata={"env": self.name, "env_id": env_id},
|
||||
)
|
||||
return
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue