replace await resp.json() with await parse_http_response(resp)

This commit is contained in:
hjc-puro 2025-05-03 06:36:05 -04:00
parent a8b59ccc9b
commit e06469f8c2
5 changed files with 52 additions and 6 deletions

View file

@ -38,6 +38,7 @@ from atroposlib.utils.cli import (
get_prefixed_pydantic_model,
merge_dicts,
)
from atroposlib.utils.io import parse_http_response
from atroposlib.utils.metrics import get_std_min_max_avg
from ..type_definitions import Item, Message
@ -376,7 +377,7 @@ class BaseEnv(ABC):
"weight": self.config.inference_weight,
},
) as resp:
data = await resp.json()
data = await parse_http_response(resp, logger)
self.env_id = data["env_id"]
self.wandb_prepend = data["wandb_name"]
self.curr_step = data["starting_step"]
@ -401,7 +402,7 @@ class BaseEnv(ABC):
"""
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.config.rollout_server_url}/info") as resp:
data = await resp.json()
data = await parse_http_response(resp, logger)
if data["batch_size"] != -1:
# update the batch size
self.config.batch_size = data["batch_size"]
@ -684,7 +685,7 @@ class BaseEnv(ABC):
f"{self.config.rollout_server_url}/status-env",
json={"env_id": self.env_id},
) as resp:
self.status_dict = await resp.json()
self.status_dict = await parse_http_response(resp, logger)
new_weight = self.status_dict["env_weight"]
max_num_workers = self.config.max_num_workers
if max_num_workers == -1: