diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py
index 8cda4abd..0aa1c70f 100644
--- a/atroposlib/api/server.py
+++ b/atroposlib/api/server.py
@@ -100,6 +100,12 @@ async def register(registration: Registration):
 
 @app.post("/register-env")
 async def register_env_url(register_env: RegisterEnv):
+    # Until app.state.started is set and truthy, tell the env to retry later.
+    # getattr's default covers the attribute not existing on app.state yet.
+    if not getattr(app.state, "started", False):
+        return {
+            "status": "wait for trainer to start",
+        }
     try:
         isinstance(app.state.envs, list)
     except AttributeError:
diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py
index b9b9def7..ba8b8ee4 100644
--- a/atroposlib/envs/base.py
+++ b/atroposlib/envs/base.py
@@ -365,35 +365,77 @@ class BaseEnv(ABC):
                 )
                 break
 
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_random_exponential(multiplier=1, max=10),
+    )
+    async def _register_env(self):
+        """
+        POST this env's registration payload to the rollout server.
+
+        The ``retry`` decorator above (presumably tenacity's; confirm against
+        the file's imports) is configured for at most three attempts with
+        randomized exponential backoff capped at 10 seconds.
+
+        :return: The decoded JSON body of the ``/register-env`` response.
+        """
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    f"{self.config.rollout_server_url}/register-env",
+                    json={
+                        "max_token_length": self.config.max_token_length,
+                        "desired_name": self.config.wandb_name,
+                        "weight": self.config.inference_weight,
+                    },
+                ) as resp:
+                    return await resp.json()
+        except Exception as e:
+            logger.error(f"Error registering env: {e}")
+            # Bare raise preserves the original traceback.
+            raise
+
     async def register_env(self):
+        """
+        Register this env with the rollout server and adopt its settings.
+
+        Re-posts the registration, sleeping one second between attempts, until
+        the response status is "success". The returned env id, wandb name,
+        starting step, and checkpoint settings are then stored on ``self``,
+        and a checkpoint is loaded when the starting step is greater than 0.
+
+        :raises ValueError: If ``total_steps`` is -1 in the config and the
+            server's ``num_steps`` is also -1.
+        """
         # Now register the env...
-        async with aiohttp.ClientSession() as session:
-            async with session.post(
-                f"{self.config.rollout_server_url}/register-env",
-                json={
-                    "max_token_length": self.config.max_token_length,
-                    "desired_name": self.config.wandb_name,
-                    "weight": self.config.inference_weight,
-                },
-            ) as resp:
-                data = await resp.json()
-                self.env_id = data["env_id"]
-                self.wandb_prepend = data["wandb_name"]
-                self.curr_step = data["starting_step"]
-                self.checkpoint_dir = data["checkpoint_dir"]
-                self.checkpoint_interval = data["checkpoint_interval"]
-                if self.config.total_steps == -1:
-                    self.config.total_steps = data["num_steps"]
-                    if self.config.total_steps == -1:
-                        raise ValueError("Total steps not set in config or server!")
-                print(
-                    f"Initialized env with id {self.env_id}: "
-                    f"curr_step: {self.curr_step}, "
-                    f"checkpoint_dir: {self.checkpoint_dir}, "
-                    f"checkpoint_interval: {self.checkpoint_interval}"
-                )
-                if self.curr_step > 0:
-                    self.load_checkpoint()
+        while True:
+            data = await self._register_env()
+            if data["status"] == "success":
+                break
+            # Use the module logger, matching _register_env, not the root logger.
+            logger.warning(
+                f"Waiting to register the env due to status {data['status']}"
+            )
+            await asyncio.sleep(1)
+
+        self.env_id = data["env_id"]
+        self.wandb_prepend = data["wandb_name"]
+        self.curr_step = data["starting_step"]
+        self.checkpoint_dir = data["checkpoint_dir"]
+        self.checkpoint_interval = data["checkpoint_interval"]
+        if self.config.total_steps == -1:
+            # Config left total_steps unset; fall back to the server's value.
+            self.config.total_steps = data["num_steps"]
+            if self.config.total_steps == -1:
+                raise ValueError("Total steps not set in config or server!")
+        print(
+            f"Initialized env with id {self.env_id}: "
+            f"curr_step: {self.curr_step}, "
+            f"checkpoint_dir: {self.checkpoint_dir}, "
+            f"checkpoint_interval: {self.checkpoint_interval}"
+        )
+        if self.curr_step > 0:
+            self.load_checkpoint()
 
     async def get_server_info(self):
         """