mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Merge pull request #204 from NousResearch/multienv-enforce-mins
Multienv with enforced minimum samples in a batch
This commit is contained in:
commit
58446dbcb1
11 changed files with 1670 additions and 91 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
|
|
@ -162,6 +163,14 @@ class BaseEnvConfig(BaseModel):
|
|||
default=False,
|
||||
description="Whether to include messages in the output transmitted to the trainer",
|
||||
)
|
||||
min_batch_allocation: Optional[float] = Field(
|
||||
default=None,
|
||||
description="Minimum proportion of a batch this environment should be allocated (0.0-1.0)",
|
||||
)
|
||||
worker_timeout: float = Field(
|
||||
default=600,
|
||||
description="Timeout for a a task, in seconds, if -1, no timeout",
|
||||
)
|
||||
|
||||
|
||||
class BaseEnv(ABC):
|
||||
|
|
@ -237,6 +246,26 @@ class BaseEnv(ABC):
|
|||
else:
|
||||
self.jsonl_writer = None
|
||||
|
||||
@property
|
||||
def derived_batch_size(self):
|
||||
"""Calculate the effective batch size for this environment based on minimum allocations."""
|
||||
# If batch_size is not set or no status yet, return the config batch_size
|
||||
if not hasattr(self, "status_dict") or self.config.batch_size == -1:
|
||||
return self.config.batch_size
|
||||
|
||||
# Get unallocated fraction from status
|
||||
unallocated_fraction = self.status_dict.get("unallocated_fraction", 1.0)
|
||||
|
||||
# If this env has a minimum allocation, add it to the unallocated portion
|
||||
if self.config.min_batch_allocation is not None:
|
||||
effective_fraction = unallocated_fraction + self.config.min_batch_allocation
|
||||
else:
|
||||
# This env competes for the unallocated portion based on its weight
|
||||
effective_fraction = unallocated_fraction
|
||||
|
||||
# Calculate derived batch size
|
||||
return int(self.config.batch_size * effective_fraction)
|
||||
|
||||
@classmethod
|
||||
def config_init(
|
||||
cls,
|
||||
|
|
@ -434,6 +463,8 @@ class BaseEnv(ABC):
|
|||
"max_token_length": self.config.max_token_length,
|
||||
"desired_name": self.config.wandb_name,
|
||||
"weight": self.config.inference_weight,
|
||||
"min_batch_allocation": self.config.min_batch_allocation,
|
||||
"group_size": self.config.group_size,
|
||||
},
|
||||
) as resp:
|
||||
data = await parse_http_response(resp, logger)
|
||||
|
|
@ -614,6 +645,13 @@ class BaseEnv(ABC):
|
|||
"""
|
||||
Send scored data to the API with retry logic for timeouts and server errors.
|
||||
"""
|
||||
# Add env_id to the data
|
||||
if isinstance(scored_data, list):
|
||||
for item in scored_data:
|
||||
item["env_id"] = getattr(self, "env_id", None)
|
||||
else:
|
||||
scored_data["env_id"] = getattr(self, "env_id", None)
|
||||
|
||||
url = (
|
||||
f"{self.config.rollout_server_url}/scored_data_list"
|
||||
if isinstance(scored_data, list)
|
||||
|
|
@ -736,7 +774,7 @@ class BaseEnv(ABC):
|
|||
"""
|
||||
Handle the rollout of an item
|
||||
"""
|
||||
item = self.running_items.get(item_uuid)
|
||||
item = self.running_items.get(item_uuid)["item"]
|
||||
if item is None:
|
||||
print(f"item {item_uuid} not found... returning")
|
||||
return None
|
||||
|
|
@ -813,7 +851,9 @@ class BaseEnv(ABC):
|
|||
self.eval_runner = eval_task
|
||||
if self.config.eval_handling == EvalHandlingEnum.STOP_TRAIN:
|
||||
# Stop training if eval is running
|
||||
self.backlog.extend(self.running_items.values())
|
||||
self.backlog.extend(
|
||||
[x["item"] for x in self.running_items.values()]
|
||||
)
|
||||
for worker in self.workers:
|
||||
worker.cancel()
|
||||
self.workers = set()
|
||||
|
|
@ -852,16 +892,72 @@ class BaseEnv(ABC):
|
|||
max_num_workers,
|
||||
(
|
||||
self.config.max_batches_offpolicy
|
||||
* self.config.batch_size
|
||||
* self.derived_batch_size
|
||||
// self.config.group_size
|
||||
)
|
||||
- (self.status_dict["queue_size"]),
|
||||
)
|
||||
# Now if we have a minimum batch allocation, we need to add workers to fill the self queue, in case of
|
||||
# overruns by other environments
|
||||
if self.config.min_batch_allocation is not None:
|
||||
min_workers_to_fill_self_queue = max(
|
||||
0,
|
||||
math.ceil(
|
||||
(
|
||||
(
|
||||
(
|
||||
math.ceil(
|
||||
self.config.min_batch_allocation
|
||||
* self.config.batch_size
|
||||
* self.config.max_batches_offpolicy
|
||||
/ self.status_dict["max_group_size"]
|
||||
)
|
||||
+ (
|
||||
self.status_dict["max_group_size"]
|
||||
// self.config.group_size
|
||||
)
|
||||
)
|
||||
* self.status_dict["max_group_size"]
|
||||
)
|
||||
- (
|
||||
(
|
||||
self.status_dict["max_group_size"]
|
||||
* self.status_dict["self_queue_size"]
|
||||
// (
|
||||
self.status_dict["max_group_size"]
|
||||
/ self.config.group_size
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
/ self.config.group_size
|
||||
),
|
||||
)
|
||||
max_num_workers = max(max_num_workers, min_workers_to_fill_self_queue)
|
||||
print(
|
||||
f"max_num_workers: {max_num_workers}, queue size: {self.status_dict['queue_size']}, "
|
||||
f"workers: {len(self.workers)}, self_queue_size: {self.status_dict['self_queue_size']}",
|
||||
flush=True,
|
||||
)
|
||||
if (self.curr_step == 0) and (len(self.workers) == 0):
|
||||
# We are starting up, so we should just skip the append to the list
|
||||
pass
|
||||
else:
|
||||
self.workers_added_list.append(max_num_workers - len(self.workers))
|
||||
if len(self.workers) > max_num_workers:
|
||||
print(
|
||||
f"len(self.workers) > max_num_workers: {len(self.workers)} > {max_num_workers}, "
|
||||
"sending workers to backlog",
|
||||
flush=True,
|
||||
)
|
||||
num_to_reduce = len(self.workers) - max_num_workers
|
||||
running_items_to_remove = list(self.running_items.keys())[:num_to_reduce]
|
||||
for item_uuid in running_items_to_remove:
|
||||
self.backlog.append(self.running_items[item_uuid]["item"])
|
||||
self.running_items[item_uuid]["worker"].cancel()
|
||||
self.workers.discard(self.running_items[item_uuid]["worker"])
|
||||
self.running_items.pop(item_uuid)
|
||||
|
||||
while len(self.workers) < max_num_workers:
|
||||
# Generate a UUID for tracking this item
|
||||
item_uuid = str(uuid.uuid4())
|
||||
|
|
@ -871,8 +967,12 @@ class BaseEnv(ABC):
|
|||
item = await self.get_next_item()
|
||||
if item is None:
|
||||
break
|
||||
self.running_items[item_uuid] = item
|
||||
worker = asyncio.create_task(self.handle_env(item_uuid))
|
||||
self.running_items[item_uuid] = {
|
||||
"item": item,
|
||||
"worker": worker,
|
||||
"start_time": time.time(),
|
||||
}
|
||||
self.workers.add(worker)
|
||||
worker.add_done_callback(
|
||||
lambda fut, i=item: (
|
||||
|
|
@ -926,9 +1026,32 @@ class BaseEnv(ABC):
|
|||
>= self.config.max_batches_offpolicy * self.config.batch_size
|
||||
)
|
||||
and (self.config.max_batches_offpolicy > 0)
|
||||
) or (self.config.batch_size == -1):
|
||||
and (
|
||||
(self.config.min_batch_allocation is None)
|
||||
or (
|
||||
(
|
||||
(
|
||||
(
|
||||
math.ceil(
|
||||
self.config.min_batch_allocation
|
||||
* self.config.batch_size
|
||||
* self.config.max_batches_offpolicy
|
||||
/ self.status_dict["max_group_size"]
|
||||
)
|
||||
* (
|
||||
self.status_dict["max_group_size"]
|
||||
// self.config.group_size
|
||||
)
|
||||
)
|
||||
)
|
||||
- (self.status_dict["self_queue_size"])
|
||||
)
|
||||
<= 0
|
||||
)
|
||||
)
|
||||
) or (self.derived_batch_size == -1):
|
||||
# We have too many, lets cleanup the tasks and wait a bit
|
||||
self.backlog.extend(self.running_items.values())
|
||||
self.backlog.extend([x["item"] for x in self.running_items.values()])
|
||||
for worker in self.workers:
|
||||
worker.cancel()
|
||||
self.running_items = dict()
|
||||
|
|
@ -937,6 +1060,18 @@ class BaseEnv(ABC):
|
|||
pass
|
||||
else:
|
||||
await self.add_train_workers()
|
||||
# cleanup workers that have timed out
|
||||
if self.config.worker_timeout > 0:
|
||||
for item_uuid, item in list(self.running_items.items()):
|
||||
if time.time() - item["start_time"] > self.config.worker_timeout:
|
||||
logger.warning(
|
||||
f"Worker {item_uuid} has timed out after {time.time() - item['start_time']} seconds"
|
||||
)
|
||||
item["worker"].cancel()
|
||||
self.workers.discard(item["worker"])
|
||||
self.running_items.pop(item_uuid)
|
||||
# Do we want to retry? probably not...
|
||||
# self.backlog.append(item["item"])
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def process_manager(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue