mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
Include run name in wandb initialization in BaseEnv
This commit is contained in:
parent
2c8340bece
commit
14c70c0e68
1 changed files with 17 additions and 10 deletions
|
|
@ -394,21 +394,28 @@ class BaseEnv(ABC):
|
|||
# Setup wandb getting the group and project via the server
|
||||
while self.wandb_project is None:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.config.rollout_server_url}/wandb_info"
|
||||
) as resp:
|
||||
async with session.get(f"{self.config.rollout_server_url}/wandb_info") as resp:
|
||||
data = await parse_http_response(resp, logger)
|
||||
self.wandb_group = data["group"]
|
||||
self.wandb_project = data["project"]
|
||||
|
||||
if self.wandb_project is None:
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
wandb.init(
|
||||
project=self.wandb_project,
|
||||
group=self.wandb_group,
|
||||
config=self.config.model_dump(),
|
||||
)
|
||||
break
|
||||
continue
|
||||
|
||||
wandb_run_name = None
|
||||
if self.config.wandb_name:
|
||||
random_id = "".join(random.choices(string.ascii_lowercase, k=6))
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
wandb_run_name = f"{self.config.wandb_name}-{current_date}-{random_id}"
|
||||
|
||||
wandb.init(
|
||||
name=wandb_run_name,
|
||||
project=self.wandb_project,
|
||||
group=self.wandb_group,
|
||||
config=self.config.model_dump(),
|
||||
)
|
||||
break
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue