diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index b9b9def7..31bb733b 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -365,35 +365,52 @@ class BaseEnv(ABC): ) break + @retry( + stop=stop_after_attempt(3), + wait=wait_random_exponential(multiplier=1, max=10), + ) + async def _register_env(self): + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.config.rollout_server_url}/register-env", + json={ + "max_token_length": self.config.max_token_length, + "desired_name": self.config.wandb_name, + "weight": self.config.inference_weight, + }, + ) as resp: + data = await resp.json() + return data + except Exception as e: + logger.error(f"Error registering env: {e}") + raise e + async def register_env(self): # Now register the env... - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.config.rollout_server_url}/register-env", - json={ - "max_token_length": self.config.max_token_length, - "desired_name": self.config.wandb_name, - "weight": self.config.inference_weight, - }, - ) as resp: - data = await resp.json() - self.env_id = data["env_id"] - self.wandb_prepend = data["wandb_name"] - self.curr_step = data["starting_step"] - self.checkpoint_dir = data["checkpoint_dir"] - self.checkpoint_interval = data["checkpoint_interval"] + while True: + data = await self._register_env() + if data['status'] != "success": + logging.warning(f"Waiting to register the env due to status {data['status']}") + await asyncio.sleep(1) + continue + self.env_id = data["env_id"] + self.wandb_prepend = data["wandb_name"] + self.curr_step = data["starting_step"] + self.checkpoint_dir = data["checkpoint_dir"] + self.checkpoint_interval = data["checkpoint_interval"] + if self.config.total_steps == -1: + self.config.total_steps = data["num_steps"] if self.config.total_steps == -1: - self.config.total_steps = data["num_steps"] - if self.config.total_steps == -1: - raise ValueError("Total steps not set in config or server!") - print( - f"Initialized env with id {self.env_id}: " - f"curr_step: {self.curr_step}, " - f"checkpoint_dir: {self.checkpoint_dir}, " - f"checkpoint_interval: {self.checkpoint_interval}" - ) - if self.curr_step > 0: - self.load_checkpoint() + raise ValueError("Total steps not set in config or server!") + print( + f"Initialized env with id {self.env_id}: " + f"curr_step: {self.curr_step}, " + f"checkpoint_dir: {self.checkpoint_dir}, " + f"checkpoint_interval: {self.checkpoint_interval}" + ) + if self.curr_step > 0: + self.load_checkpoint() async def get_server_info(self): """