Merge pull request #204 from NousResearch/multienv-enforce-mins

Multienv with enforced minimum samples in a batch
This commit is contained in:
dmahan93 2025-07-07 08:53:43 -05:00 committed by GitHub
commit 58446dbcb1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 1670 additions and 91 deletions

View file

@ -1,6 +1,7 @@
import asyncio
import json
import logging
import math
import os
import random
import string
@ -162,6 +163,14 @@ class BaseEnvConfig(BaseModel):
default=False,
description="Whether to include messages in the output transmitted to the trainer",
)
min_batch_allocation: Optional[float] = Field(
default=None,
description="Minimum proportion of a batch this environment should be allocated (0.0-1.0)",
)
worker_timeout: float = Field(
default=600,
description="Timeout for a task, in seconds, if -1, no timeout",
)
class BaseEnv(ABC):
@ -237,6 +246,26 @@ class BaseEnv(ABC):
else:
self.jsonl_writer = None
@property
def derived_batch_size(self):
"""Calculate the effective batch size for this environment based on minimum allocations."""
# If batch_size is not set or no status yet, return the config batch_size
if not hasattr(self, "status_dict") or self.config.batch_size == -1:
return self.config.batch_size
# Get unallocated fraction from status
unallocated_fraction = self.status_dict.get("unallocated_fraction", 1.0)
# If this env has a minimum allocation, add it to the unallocated portion
if self.config.min_batch_allocation is not None:
effective_fraction = unallocated_fraction + self.config.min_batch_allocation
else:
# This env competes for the unallocated portion based on its weight
effective_fraction = unallocated_fraction
# Calculate derived batch size
return int(self.config.batch_size * effective_fraction)
@classmethod
def config_init(
cls,
@ -434,6 +463,8 @@ class BaseEnv(ABC):
"max_token_length": self.config.max_token_length,
"desired_name": self.config.wandb_name,
"weight": self.config.inference_weight,
"min_batch_allocation": self.config.min_batch_allocation,
"group_size": self.config.group_size,
},
) as resp:
data = await parse_http_response(resp, logger)
@ -614,6 +645,13 @@ class BaseEnv(ABC):
"""
Send scored data to the API with retry logic for timeouts and server errors.
"""
# Add env_id to the data
if isinstance(scored_data, list):
for item in scored_data:
item["env_id"] = getattr(self, "env_id", None)
else:
scored_data["env_id"] = getattr(self, "env_id", None)
url = (
f"{self.config.rollout_server_url}/scored_data_list"
if isinstance(scored_data, list)
@ -736,7 +774,7 @@ class BaseEnv(ABC):
"""
Handle the rollout of an item
"""
item = self.running_items.get(item_uuid)
item = self.running_items.get(item_uuid)["item"]
if item is None:
print(f"item {item_uuid} not found... returning")
return None
@ -813,7 +851,9 @@ class BaseEnv(ABC):
self.eval_runner = eval_task
if self.config.eval_handling == EvalHandlingEnum.STOP_TRAIN:
# Stop training if eval is running
self.backlog.extend(self.running_items.values())
self.backlog.extend(
[x["item"] for x in self.running_items.values()]
)
for worker in self.workers:
worker.cancel()
self.workers = set()
@ -852,16 +892,72 @@ class BaseEnv(ABC):
max_num_workers,
(
self.config.max_batches_offpolicy
* self.config.batch_size
* self.derived_batch_size
// self.config.group_size
)
- (self.status_dict["queue_size"]),
)
# Now if we have a minimum batch allocation, we need to add workers to fill the self queue, in case of
# overruns by other environments
if self.config.min_batch_allocation is not None:
min_workers_to_fill_self_queue = max(
0,
math.ceil(
(
(
(
math.ceil(
self.config.min_batch_allocation
* self.config.batch_size
* self.config.max_batches_offpolicy
/ self.status_dict["max_group_size"]
)
+ (
self.status_dict["max_group_size"]
// self.config.group_size
)
)
* self.status_dict["max_group_size"]
)
- (
(
self.status_dict["max_group_size"]
* self.status_dict["self_queue_size"]
// (
self.status_dict["max_group_size"]
/ self.config.group_size
)
)
)
)
/ self.config.group_size
),
)
max_num_workers = max(max_num_workers, min_workers_to_fill_self_queue)
print(
f"max_num_workers: {max_num_workers}, queue size: {self.status_dict['queue_size']}, "
f"workers: {len(self.workers)}, self_queue_size: {self.status_dict['self_queue_size']}",
flush=True,
)
if (self.curr_step == 0) and (len(self.workers) == 0):
# We are starting up, so we should just skip the append to the list
pass
else:
self.workers_added_list.append(max_num_workers - len(self.workers))
if len(self.workers) > max_num_workers:
print(
f"len(self.workers) > max_num_workers: {len(self.workers)} > {max_num_workers}, "
"sending workers to backlog",
flush=True,
)
num_to_reduce = len(self.workers) - max_num_workers
running_items_to_remove = list(self.running_items.keys())[:num_to_reduce]
for item_uuid in running_items_to_remove:
self.backlog.append(self.running_items[item_uuid]["item"])
self.running_items[item_uuid]["worker"].cancel()
self.workers.discard(self.running_items[item_uuid]["worker"])
self.running_items.pop(item_uuid)
while len(self.workers) < max_num_workers:
# Generate a UUID for tracking this item
item_uuid = str(uuid.uuid4())
@ -871,8 +967,12 @@ class BaseEnv(ABC):
item = await self.get_next_item()
if item is None:
break
self.running_items[item_uuid] = item
worker = asyncio.create_task(self.handle_env(item_uuid))
self.running_items[item_uuid] = {
"item": item,
"worker": worker,
"start_time": time.time(),
}
self.workers.add(worker)
worker.add_done_callback(
lambda fut, i=item: (
@ -926,9 +1026,32 @@ class BaseEnv(ABC):
>= self.config.max_batches_offpolicy * self.config.batch_size
)
and (self.config.max_batches_offpolicy > 0)
) or (self.config.batch_size == -1):
and (
(self.config.min_batch_allocation is None)
or (
(
(
(
math.ceil(
self.config.min_batch_allocation
* self.config.batch_size
* self.config.max_batches_offpolicy
/ self.status_dict["max_group_size"]
)
* (
self.status_dict["max_group_size"]
// self.config.group_size
)
)
)
- (self.status_dict["self_queue_size"])
)
<= 0
)
)
) or (self.derived_batch_size == -1):
# We have too many, lets cleanup the tasks and wait a bit
self.backlog.extend(self.running_items.values())
self.backlog.extend([x["item"] for x in self.running_items.values()])
for worker in self.workers:
worker.cancel()
self.running_items = dict()
@ -937,6 +1060,18 @@ class BaseEnv(ABC):
pass
else:
await self.add_train_workers()
# cleanup workers that have timed out
if self.config.worker_timeout > 0:
for item_uuid, item in list(self.running_items.items()):
if time.time() - item["start_time"] > self.config.worker_timeout:
logger.warning(
f"Worker {item_uuid} has timed out after {time.time() - item['start_time']} seconds"
)
item["worker"].cancel()
self.workers.discard(item["worker"])
self.running_items.pop(item_uuid)
# Do we want to retry? probably not...
# self.backlog.append(item["item"])
await asyncio.sleep(0.1)
async def process_manager(self):