mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Update base.py
This commit is contained in:
parent
301cc03b9d
commit
1848c7d453
1 changed files with 43 additions and 26 deletions
|
|
@ -365,35 +365,52 @@ class BaseEnv(ABC):
|
|||
)
|
||||
break
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(multiplier=1, max=10),
|
||||
)
|
||||
async def _register_env(self):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.config.rollout_server_url}/register-env",
|
||||
json={
|
||||
"max_token_length": self.config.max_token_length,
|
||||
"desired_name": self.config.wandb_name,
|
||||
"weight": self.config.inference_weight,
|
||||
},
|
||||
) as resp:
|
||||
data = await resp.json()
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering env: {e}")
|
||||
raise e
|
||||
|
||||
async def register_env(self):
|
||||
# Now register the env...
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.config.rollout_server_url}/register-env",
|
||||
json={
|
||||
"max_token_length": self.config.max_token_length,
|
||||
"desired_name": self.config.wandb_name,
|
||||
"weight": self.config.inference_weight,
|
||||
},
|
||||
) as resp:
|
||||
data = await resp.json()
|
||||
self.env_id = data["env_id"]
|
||||
self.wandb_prepend = data["wandb_name"]
|
||||
self.curr_step = data["starting_step"]
|
||||
self.checkpoint_dir = data["checkpoint_dir"]
|
||||
self.checkpoint_interval = data["checkpoint_interval"]
|
||||
while True:
|
||||
data = await self._register_env()
|
||||
if data['status'] != "success":
|
||||
logging.warning(f"Waiting to register the env due to status {data['status']}")
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
self.env_id = data["env_id"]
|
||||
self.wandb_prepend = data["wandb_name"]
|
||||
self.curr_step = data["starting_step"]
|
||||
self.checkpoint_dir = data["checkpoint_dir"]
|
||||
self.checkpoint_interval = data["checkpoint_interval"]
|
||||
if self.config.total_steps == -1:
|
||||
self.config.total_steps = data["num_steps"]
|
||||
if self.config.total_steps == -1:
|
||||
self.config.total_steps = data["num_steps"]
|
||||
if self.config.total_steps == -1:
|
||||
raise ValueError("Total steps not set in config or server!")
|
||||
print(
|
||||
f"Initialized env with id {self.env_id}: "
|
||||
f"curr_step: {self.curr_step}, "
|
||||
f"checkpoint_dir: {self.checkpoint_dir}, "
|
||||
f"checkpoint_interval: {self.checkpoint_interval}"
|
||||
)
|
||||
if self.curr_step > 0:
|
||||
self.load_checkpoint()
|
||||
raise ValueError("Total steps not set in config or server!")
|
||||
print(
|
||||
f"Initialized env with id {self.env_id}: "
|
||||
f"curr_step: {self.curr_step}, "
|
||||
f"checkpoint_dir: {self.checkpoint_dir}, "
|
||||
f"checkpoint_interval: {self.checkpoint_interval}"
|
||||
)
|
||||
if self.curr_step > 0:
|
||||
self.load_checkpoint()
|
||||
|
||||
async def get_server_info(self):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue