Update base.py

This commit is contained in:
dmahan93 2025-05-08 11:29:29 -05:00 committed by GitHub
parent 301cc03b9d
commit 1848c7d453
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -365,35 +365,52 @@ class BaseEnv(ABC):
)
break
@retry(
    wait=wait_random_exponential(multiplier=1, max=10),
    stop=stop_after_attempt(3),
)
async def _register_env(self):
    """Send one registration request to the rollout server and return its reply.

    POSTs ``max_token_length``, ``desired_name`` and ``weight`` (all read from
    ``self.config``) to ``<rollout_server_url>/register-env`` and returns the
    decoded JSON body.

    Any exception is logged and re-raised, so the ``@retry`` decorator can
    attempt the call again (stops after 3 attempts, random exponential wait
    capped at 10).
    """
    try:
        async with aiohttp.ClientSession() as http, http.post(
            f"{self.config.rollout_server_url}/register-env",
            json={
                "max_token_length": self.config.max_token_length,
                "desired_name": self.config.wandb_name,
                "weight": self.config.inference_weight,
            },
        ) as reply:
            return await reply.json()
    except Exception as err:
        logger.error(f"Error registering env: {err}")
        # Re-raise so the retry decorator sees the failure.
        raise
async def register_env(self):
    """Register this environment with the rollout server and store its reply.

    Calls ``self._register_env()`` repeatedly, sleeping one second between
    attempts, until the reply's ``status`` is ``"success"``. The successful
    reply is then used to set ``env_id``, ``wandb_prepend``, ``curr_step``,
    ``checkpoint_dir`` and ``checkpoint_interval`` on ``self``.

    If ``self.config.total_steps`` is ``-1`` it is replaced by the server's
    ``num_steps``. When ``curr_step`` is greater than zero,
    ``self.load_checkpoint()`` is called.

    Raises:
        ValueError: if ``total_steps`` is ``-1`` in both the config and the
            server reply.
    """
    # Now register the env...
    while True:
        data = await self._register_env()
        if data["status"] == "success":
            # Leave the loop once registered; without this exit the
            # environment would keep re-registering forever.
            break
        logger.warning(
            f"Waiting to register the env due to status {data['status']}"
        )
        await asyncio.sleep(1)

    self.env_id = data["env_id"]
    self.wandb_prepend = data["wandb_name"]
    self.curr_step = data["starting_step"]
    self.checkpoint_dir = data["checkpoint_dir"]
    self.checkpoint_interval = data["checkpoint_interval"]

    # -1 means "not set locally": fall back to the server's step count.
    if self.config.total_steps == -1:
        self.config.total_steps = data["num_steps"]
        if self.config.total_steps == -1:
            raise ValueError("Total steps not set in config or server!")

    print(
        f"Initialized env with id {self.env_id}: "
        f"curr_step: {self.curr_step}, "
        f"checkpoint_dir: {self.checkpoint_dir}, "
        f"checkpoint_interval: {self.checkpoint_interval}"
    )
    if self.curr_step > 0:
        self.load_checkpoint()
async def get_server_info(self):
"""