diff --git a/atroposlib/cli/dpo.py b/atroposlib/cli/dpo.py index bd409c18..15452f24 100644 --- a/atroposlib/cli/dpo.py +++ b/atroposlib/cli/dpo.py @@ -8,6 +8,8 @@ import jsonlines from tqdm.asyncio import tqdm # Import tqdm for async from transformers import AutoTokenizer +from atroposlib.utils.io import parse_http_response + def find_common_prefix(strings): """ @@ -80,7 +82,7 @@ async def check_for_batch(api_url): while True: async with aiohttp.ClientSession() as session: async with session.get(f"{api_url}/batch") as response: - data = await response.json() + data = await parse_http_response(response) if data["batch"] is not None: return data["batch"] await asyncio.sleep(1) # Wait before polling again diff --git a/atroposlib/cli/sft.py b/atroposlib/cli/sft.py index b58badb1..c5781b32 100644 --- a/atroposlib/cli/sft.py +++ b/atroposlib/cli/sft.py @@ -7,6 +7,8 @@ import jsonlines from tqdm.asyncio import tqdm # Import tqdm for async from transformers import AutoTokenizer +from atroposlib.utils.io import parse_http_response + def find_common_prefix(strings): """ @@ -79,7 +81,7 @@ async def check_for_batch(api_url): while True: async with aiohttp.ClientSession() as session: async with session.get(f"{api_url}/batch") as response: - data = await response.json() + data = await parse_http_response(response) if data["batch"] is not None: return data["batch"] await asyncio.sleep(1) # Wait before polling again diff --git a/atroposlib/cli/view_run.py b/atroposlib/cli/view_run.py index 42462355..e22bcbc9 100644 --- a/atroposlib/cli/view_run.py +++ b/atroposlib/cli/view_run.py @@ -5,6 +5,8 @@ import aiohttp import gradio as gr from transformers import AutoTokenizer +from atroposlib.utils.io import parse_http_response + def find_common_prefix(strings): if not strings: @@ -46,7 +48,7 @@ async def check_for_batch(): while True: async with aiohttp.ClientSession() as session: async with session.get("http://localhost:8000/batch") as response: - data = await response.json() + data = await parse_http_response(response) print(data) if data["batch"] is not None: return data["batch"] diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 0577363a..87a640ab 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -35,6 +35,7 @@ from atroposlib.utils.cli import ( get_prefixed_pydantic_model, merge_dicts, ) +from atroposlib.utils.io import parse_http_response from atroposlib.utils.metrics import get_std_min_max_avg from ..type_definitions import Item, Message @@ -349,7 +350,7 @@ class BaseEnv(ABC): async with session.get( f"{self.config.rollout_server_url}/wandb_info" ) as resp: - data = await resp.json() + data = await parse_http_response(resp, logger) self.wandb_group = data["group"] self.wandb_project = data["project"] if self.wandb_project is None: @@ -377,7 +378,7 @@ class BaseEnv(ABC): "weight": self.config.inference_weight, }, ) as resp: - data = await resp.json() + data = await parse_http_response(resp, logger) return data except Exception as e: logger.error(f"Error registering env: {e}") @@ -418,7 +419,7 @@ class BaseEnv(ABC): """ async with aiohttp.ClientSession() as session: async with session.get(f"{self.config.rollout_server_url}/info") as resp: - data = await resp.json() + data = await parse_http_response(resp, logger) if data["batch_size"] != -1: # update the batch size self.config.batch_size = data["batch_size"] @@ -727,7 +728,7 @@ class BaseEnv(ABC): f"{self.config.rollout_server_url}/status-env", json={"env_id": self.env_id}, ) as resp: - self.status_dict = await resp.json() + self.status_dict = await parse_http_response(resp, logger) new_weight = self.status_dict["env_weight"] max_num_workers = self.config.max_num_workers if max_num_workers == -1: diff --git a/atroposlib/utils/io.py b/atroposlib/utils/io.py new file mode 100644 index 00000000..d8ae55bd --- /dev/null +++ b/atroposlib/utils/io.py @@ -0,0 +1,39 @@ +import logging +from typing import Any, Optional + +logger = logging.getLogger(__name__) + + +async def parse_http_response( + resp: Any, logger: Optional[logging.Logger] = None +) -> Any: + """ + Parse an HTTP response with proper error handling and logging. + + Args: + resp: The HTTP response object (must have raise_for_status() and json() methods) + logger: Optional logger instance. If not provided, uses the default module logger. + + Returns: + The parsed JSON response + + Raises: + Exception: Re-raises any exceptions that occur during parsing + """ + if logger is None: + logger = logging.getLogger(__name__) + + try: + # Raise an exception for bad status codes (4xx or 5xx) + resp.raise_for_status() + # Attempt to parse the response as JSON + return await resp.json() + except Exception as e: + # Handle HTTP errors (raised by raise_for_status) + error_text = await resp.text() # Read the response text for logging + logger.error( + f"Error fetching from server. Status: {getattr(e, 'status', 'unknown')}, " + f"Message: {getattr(e, 'message', str(e))}. Response: {error_text}" + ) + # Re-raise the exception to allow retry decorators to handle it + raise