Merge pull request #204 from NousResearch/multienv-enforce-mins

Multienv with enforced minimum samples in a batch
This commit is contained in:
dmahan93 2025-07-07 08:53:43 -05:00 committed by GitHub
commit 58446dbcb1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 1670 additions and 91 deletions


@ -18,6 +18,7 @@ This service specifically handles the **experience data pathway**:
* Registration endpoints for Trainers and Rollout Handlers.
* Serves batches of aggregated experience data to Trainers.
* Supports heterogeneous environments with weighting (via `/register-env` weight and internal batching).
* Minimum batch allocation support to guarantee certain environments contribute a minimum proportion to each batch.
* Provides status endpoints for monitoring queue size and training step count.
* Basic integration with Weights & Biases (W&B) project/group info.
* Endpoints for Rollout Handlers to disconnect gracefully.
@ -97,6 +98,8 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
max_token_length: int # Max length this env produces
desired_name: str # Base name for identification/logging
weight: float # Weight for sampling/batching (e.g., 1.0)
group_size: int # Expected number of sequences per data submission
min_batch_allocation: Optional[float] = None # Minimum proportion of batch (0.0-1.0)
```
* **Response:** Provides assigned ID, unique W&B name, checkpoint info.
```json
@ -135,14 +138,17 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
overrides: Optional[List[dict]] = None # Per-item logging overrides
group_overrides: Optional[dict] = None # Group logging overrides
images: Optional[Any] = None # Image data (if applicable)
env_id: Optional[int] = None # ID of the environment that generated this data
```
* **Response:**
* Normal submission: `{"status": "received"}`
* Mixed-size group buffered: `{"status": "buffered", "buffer_size": <sequences_in_buffer>}`
* `POST /scored_data_list`
* **Description:** Endpoint for Rollout Handlers to push a list of `ScoredData` chunks.
* **Request Body:** `List[ScoredData]`
* **Response:** `{"status": "received", "groups_processed": <count>}`
* `GET /batch`
* **Description:** Called by the Trainer to request a batch of data for training. The server uses internal logic to form a batch of the configured size from the available data in the queue. If any environments have minimum batch allocations specified, it uses `grab_batch_with_minimum_allocations` to ensure each environment gets at least its minimum proportion of the batch. Otherwise, it uses `grab_exact_from_heterogeneous_queue` to form batches respecting environment weights. The server increments its internal step counter when a batch is successfully formed and returned.
* **Response:**
* Success: `{"batch": [<data_item_1>, ..., <data_item_N>]}` where each `data_item` matches the structure pushed via `/scored_data`.
* Not enough data: `{"batch": null}`
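The batch-forming step can be pictured with a simplified stand-in for `grab_exact_from_heterogeneous_queue` (a sketch only: the real function also buckets groups into power-of-two sizes and handles more edge cases, and `take_exact` is a hypothetical name):

```python
from typing import Dict, List, Optional, Tuple

def take_exact(
    queue: List[Dict], batch_size: int
) -> Tuple[Optional[List[Dict]], List[Dict]]:
    """Greedily take whole groups until exactly batch_size sequences are collected."""
    batch: List[Dict] = []
    taken = set()
    count = 0
    for i, group in enumerate(queue):
        size = len(group["tokens"])  # a group is indivisible
        if count + size <= batch_size:
            batch.append(group)
            taken.add(i)
            count += size
        if count == batch_size:
            # Return the batch plus the remaining queue
            new_queue = [g for j, g in enumerate(queue) if j not in taken]
            return batch, new_queue
    # Not enough data: the trainer would see {"batch": null}
    return None, queue

queue = [{"tokens": [[1], [2]]}, {"tokens": [[3]]}, {"tokens": [[4]]}]
batch, remaining = take_exact(queue, 3)
```

On success the server pops such a batch and increments its step counter; on failure the queue is left untouched.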
@ -178,6 +184,57 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
* mermaid diagram of how a rollout handler interacts with the api is located [here](env_interaction.md).
6. **Shutdown:** Handlers may call `POST /disconnect-env`.
## Minimum Batch Allocation Feature
The API supports ensuring minimum batch allocations for specific environments. This feature is useful when you want to guarantee that certain environments contribute at least a minimum proportion of sequences to each training batch.
### How It Works
1. **Environment Registration**: When registering an environment via `/register-env`, you can specify:
- `min_batch_allocation` (Optional[float]): A value between 0.0 and 1.0 representing the minimum proportion of the batch this environment should contribute
- `group_size` (int): The expected number of sequences per data submission from this environment
2. **Batch Formation**: When the trainer requests a batch via `/batch`:
- If any environment has a `min_batch_allocation` specified, the system uses special logic to ensure minimums are met
- The system attempts to allocate at least `min_batch_allocation * batch_size` sequences from each environment with a minimum
- If the sum of all minimum allocations exceeds 1.0, they are proportionally scaled down
- If an environment with a minimum allocation has no data available, the batch formation fails (returns null)
3. **Mixed-Size Group Handling**: When an environment submits data with a different number of sequences than its declared `group_size`:
- The data is buffered separately for that environment
- The system attempts to combine buffered groups to match the expected `group_size`
- Once combined, the data is added to the main queue
- Response includes `{"status": "buffered", "buffer_size": <sequences_in_buffer>}`
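The buffering behavior in step 3 can be sketched as a small decision function (`route_submission` is a hypothetical helper for illustration; the actual logic lives in the `/scored_data` endpoint and `find_groups_summing_to_target`):

```python
from typing import Dict, List

def route_submission(
    submission_tokens: List[List[int]],
    expected_group_size: int,
    env_buffer: List[Dict],
) -> Dict:
    """Return the response shape described above: received vs. buffered."""
    if len(submission_tokens) == expected_group_size:
        # Correct group size: the data goes straight to the main queue.
        return {"status": "received"}
    # Mixed-size group: hold it in the per-env buffer until buffered
    # groups can be combined to sum exactly to expected_group_size.
    env_buffer.append({"tokens": submission_tokens})
    return {
        "status": "buffered",
        "buffer_size": sum(len(g["tokens"]) for g in env_buffer),
    }

buf: List[Dict] = []
# An env registered with group_size=4 submits only 2 sequences:
resp = route_submission([[1, 2], [3, 4]], expected_group_size=4, env_buffer=buf)
# resp -> {"status": "buffered", "buffer_size": 2}
```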
### Example Configuration
```python
# Environment 1: Requires at least 30% of each batch
{
"max_token_length": 512,
"desired_name": "critical_env",
"weight": 1.0,
"group_size": 4,
"min_batch_allocation": 0.3 # 30% minimum
}
# Environment 2: No minimum requirement
{
"max_token_length": 512,
"desired_name": "standard_env",
"weight": 1.0,
"group_size": 2,
"min_batch_allocation": None # No minimum
}
```
### Important Notes
- Minimum allocations are enforced per batch, not globally
- If minimum allocations cannot be satisfied (e.g., not enough data from a required environment), batch formation fails
- Environments without `min_batch_allocation` fill the remaining batch space after minimums are satisfied
- The feature respects heterogeneous packing constraints when forming batches
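To make the per-batch guarantees concrete, here is a back-of-envelope sketch using the example configuration above (the batch size of 16 and the 0.5/0.75 allocations in the scale-down case are illustrative values, not defaults):

```python
# Illustrative per-batch arithmetic (batch_size = 16 is a made-up value).
batch_size = 16

# "critical_env" from the example configuration above (min_batch_allocation=0.3):
guaranteed = int(batch_size * 0.3)   # at least 4 sequences per batch
remaining = batch_size - guaranteed  # 12 slots left for weighted filling

# If declared minimums summed past 1.0 (e.g. 0.5 + 0.75 = 1.25),
# they would be proportionally scaled down by 1.0 / 1.25:
min_allocations = {"env_a": 0.5, "env_b": 0.75}
total = sum(min_allocations.values())
scale = 1.0 / total if total > 1.0 else 1.0
scaled = {
    env: int(batch_size * alloc * scale)
    for env, alloc in min_allocations.items()
}
# The scaled minimums now fit within a single batch.
```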
## Limitations & TODOs
* **In-Memory State:** The primary limitation is that all queues, configurations, and states are stored in the FastAPI application's memory (`app.state`).


@ -7,7 +7,11 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, field_validator
from atroposlib.api.utils import (
find_groups_summing_to_target,
grab_batch_with_minimum_allocations,
grab_exact_from_heterogeneous_queue,
)
# Message import removed - using Dict[str, Any] for more flexible validation
@ -42,6 +46,10 @@ class RegisterEnv(BaseModel):
max_token_length: int
desired_name: str
weight: float
group_size: int
min_batch_allocation: Optional[float] = (
None # Minimum proportion of a batch this env should be allocated (0.0-1.0)
)
class EnvIdentifier(BaseModel):
@ -60,6 +68,7 @@ class ScoredData(BaseModel):
overrides: Optional[List[dict]] = None
group_overrides: Optional[dict] = None
images: Optional[Any] = None
env_id: Optional[int] = None # ID of the environment that generated this data
@field_validator("messages", mode="before")
@classmethod
@ -115,6 +124,7 @@ async def register(registration: Registration):
app.state.curr_batch = []
app.state.started = False
app.state.envs = []
app.state.buffer = {} # Buffer for mixed-size groups per environment
try:
app.state.requesters.append(uuid.uuid4().int)
except AttributeError:
@ -157,6 +167,8 @@ async def register_env_url(register_env: RegisterEnv):
"registered_id": registered_id,
"last_update": time.time(),
"connected": True,
"min_batch_allocation": register_env.min_batch_allocation,
"group_size": register_env.group_size,
}
)
return {
@ -207,14 +219,31 @@ async def get_batch():
return {"batch": app.state.curr_batch.pop()}
else:
new_batches = []
# Check if any envs have minimum allocations
has_min_allocations = any(
env.get("min_batch_allocation") is not None
for env in getattr(app.state, "envs", [])
)
if has_min_allocations:
batch, app.state.queue = grab_batch_with_minimum_allocations(
app.state.queue, app.state.batchsize, app.state.envs
)
else:
batch, app.state.queue = grab_exact_from_heterogeneous_queue(
app.state.queue, app.state.batchsize
)
while batch is not None:
new_batches.append(batch)
if has_min_allocations:
batch, app.state.queue = grab_batch_with_minimum_allocations(
app.state.queue, app.state.batchsize, app.state.envs
)
else:
batch, app.state.queue = grab_exact_from_heterogeneous_queue(
app.state.queue, app.state.batchsize
)
steps_to_take = len(new_batches)
if steps_to_take == 0:
return {"batch": None}
@ -224,7 +253,7 @@ async def get_batch():
app.state.curr_batch.append(batch)
curr_batch = app.state.curr_batch.pop()
# check length before sending
print(f"Sending batch of {sum(len(x['tokens']) for x in curr_batch)} sequences")
return {"batch": curr_batch}
@ -246,20 +275,58 @@ async def get_latest_example():
@app.post("/scored_data")
async def scored_data(scored_data: ScoredData):
data_dict = {
"tokens": scored_data.tokens,
"masks": scored_data.masks,
"scores": scored_data.scores,
"advantages": scored_data.advantages,
"ref_logprobs": scored_data.ref_logprobs,
"messages": scored_data.messages,
"overrides": scored_data.overrides,
"group_overrides": scored_data.group_overrides,
"images": scored_data.images,
"env_id": scored_data.env_id,
}
# Check if this is a mixed-size group
env_id = scored_data.env_id
if env_id is not None and env_id < len(app.state.envs):
expected_group_size = app.state.envs[env_id].get("group_size", 1)
actual_group_size = len(scored_data.tokens)
if actual_group_size != expected_group_size:
# Mixed size group - add to buffer
if env_id not in app.state.buffer:
app.state.buffer[env_id] = []
app.state.buffer[env_id].append(data_dict)
# Try to find groups that sum to expected_group_size
indices = find_groups_summing_to_target(
app.state.buffer[env_id], expected_group_size
)
if indices:
# Add these groups to queue in order
groups_to_add = []
for idx in sorted(indices, reverse=True):
groups_to_add.append(app.state.buffer[env_id].pop(idx))
# Add in FIFO order
for group in reversed(groups_to_add):
app.state.queue.append(group)
app.state.latest = group
return {
"status": "buffered",
"buffer_size": sum(
len(g["tokens"]) for g in app.state.buffer.get(env_id, [])
),
}
# Normal path - correct size or no env info
app.state.queue.append(data_dict)
app.state.latest = data_dict
return {"status": "received"}
@ -267,24 +334,57 @@ async def scored_data(scored_data: ScoredData):
async def scored_data_list(scored_data_list: List[ScoredData]):
"""Handle a list of ScoredData objects for step-based learning"""
# Process each scored data item
for scored_data in scored_data_list:
data_dict = {
"tokens": scored_data.tokens,
"masks": scored_data.masks,
"scores": scored_data.scores,
"advantages": scored_data.advantages,
"ref_logprobs": scored_data.ref_logprobs,
"images": scored_data.images,
"messages": scored_data.messages,
"overrides": scored_data.overrides,
"group_overrides": scored_data.group_overrides,
"env_id": scored_data.env_id,
}
# Check if this is a mixed-size group
env_id = scored_data.env_id
if env_id is not None and env_id < len(app.state.envs):
expected_group_size = app.state.envs[env_id].get("group_size", 1)
actual_group_size = len(scored_data.tokens)
if actual_group_size != expected_group_size:
# Mixed size group - add to buffer
if env_id not in app.state.buffer:
app.state.buffer[env_id] = []
app.state.buffer[env_id].append(data_dict)
# Try to find groups that sum to expected_group_size
indices = find_groups_summing_to_target(
app.state.buffer[env_id], expected_group_size
)
if indices:
# Add these groups to queue in order
groups_to_add = []
for idx in sorted(indices, reverse=True):
groups_to_add.append(app.state.buffer[env_id].pop(idx))
# Add in FIFO order
for group in reversed(groups_to_add):
app.state.queue.append(group)
app.state.latest = group
else:
# Normal size - add directly to queue
app.state.queue.append(data_dict)
app.state.latest = data_dict
else:
# No env info or normal path - add directly to queue
app.state.queue.append(data_dict)
app.state.latest = data_dict
return {"status": "received", "groups_processed": len(scored_data_list)}
@ -309,6 +409,7 @@ async def get_status_env(env: EnvIdentifier):
if x["connected"]
]
)
env_group_size = app.state.envs[env.env_id]["group_size"]
env_weight = (
app.state.envs[env.env_id]["max_context_len"]
* app.state.envs[env.env_id]["weight"]
@ -318,13 +419,95 @@ async def get_status_env(env: EnvIdentifier):
0.01, env_weight
) # Minimum weight of 0.01 :) TODO: try to figure out a better way to do this
# Calculate total minimum allocations
total_min_allocation = 0.0
for env_config in app.state.envs:
if (
env_config.get("connected", False)
and env_config.get("min_batch_allocation") is not None
):
total_min_allocation += env_config["min_batch_allocation"]
# Calculate unallocated fraction
unallocated_fraction = 1.0 - min(total_min_allocation, 1.0)
# Find the maximum group size across all items in queue
queue = getattr(app.state, "queue", [])
max_group_size = 1
num_self_sequences_in_queue = 0
for item in queue:
group_size = len(item.get("tokens", []))
if group_size > max_group_size:
max_group_size = group_size
if item.get("env_id") == env.env_id:
# update the requesting env's group size via max(), since its group size may be dynamic
env_group_size = max(env_group_size, group_size)
num_self_sequences_in_queue += group_size
# update the group size for the requesting env
app.state.envs[env.env_id]["group_size"] = env_group_size
# Calculate minimum sequences allocated to each environment
batch_size = getattr(app.state, "batchsize", 0)
min_sequences_by_env = {}
for env_config in app.state.envs:
if (
env_config.get("connected", False)
and env_config.get("min_batch_allocation") is not None
):
env_id = env_config["registered_id"]
min_sequences = int(batch_size * env_config["min_batch_allocation"])
min_sequences_by_env[env_id] = min_sequences
# Count sequences and calculate packed groups for each environment
import math
sequences_by_env = {}
packed_groups_by_env = {}
curr_env_total_sequences = 0
for item in queue:
env_id = item.get("env_id")
seq_count = len(item.get("tokens", []))
# Special handling for the requesting environment
if env_id == env.env_id:
curr_env_total_sequences += seq_count
else:
if env_id not in sequences_by_env:
sequences_by_env[env_id] = 0
sequences_by_env[env_id] += seq_count
# Calculate packed groups for each environment (excluding the requesting env)
if max_group_size > 1:
for env_id, seq_count in sequences_by_env.items():
packed_groups_by_env[env_id] = math.ceil(seq_count / max_group_size)
# Calculate adjusted queue size
# (curr_env_total_sequences + sum of available sequences from other envs after their minimums)
available_from_others = 0
for env_id in packed_groups_by_env:
packed_sequences = packed_groups_by_env[env_id] * max_group_size
min_sequences = min_sequences_by_env.get(env_id, 0)
available_from_others += max(0, packed_sequences - min_sequences)
env_queue_size = curr_env_total_sequences + available_from_others
try:
ret_dict = {
"current_step": app.state.status_dict["step"],
"queue_size": env_queue_size // env_group_size,
"unallocated_fraction": unallocated_fraction,
"self_queue_size": num_self_sequences_in_queue // env_group_size,
"max_group_size": max_group_size,
}
except AttributeError:
ret_dict = {
"current_step": 0,
"queue_size": 0,
"unallocated_fraction": 1.0,
"self_queue_size": 0,
"max_group_size": 1,
}
ret_dict["env_weight"] = env_weight
return ret_dict
@ -342,6 +525,7 @@ async def reset_data():
app.state.started = False
app.state.requesters = []
app.state.envs = []
app.state.buffer = {}
except KeyError:
pass
return PlainTextResponse("Reset successful", status_code=status.HTTP_200_OK)


@ -1,47 +1,104 @@
from typing import Dict, List, Optional, Tuple
def find_groups_summing_to_target(buffer: List[Dict], target_size: int) -> List[int]:
"""
Find indices of groups in buffer that sum exactly to target_size.
Prioritizes FIFO order.
:param buffer: Buffer of groups from same env
:param target_size: Target sum of group sizes
:return: List of indices that sum to target_size, or empty list if impossible
"""
if not buffer:
return []
# First try simple FIFO
current_sum = 0
indices = []
for i, group in enumerate(buffer):
size = len(group["tokens"])
if current_sum + size <= target_size:
indices.append(i)
current_sum += size
if current_sum == target_size:
return indices
# If FIFO doesn't work exactly, try dynamic programming
# to find any valid combination (still preferring earlier indices)
n = len(buffer)
sizes = [len(g["tokens"]) for g in buffer]
# dp[i][j] = can we make sum j using first i groups
dp = [[False] * (target_size + 1) for _ in range(n + 1)]
dp[0][0] = True
for i in range(1, n + 1):
for j in range(target_size + 1):
# Don't take group i-1
dp[i][j] = dp[i - 1][j]
# Take group i-1 if possible
if j >= sizes[i - 1]:
dp[i][j] = dp[i][j] or dp[i - 1][j - sizes[i - 1]]
if not dp[n][target_size]:
return []
# Backtrack to find indices, preferring earlier ones
indices = []
j = target_size
for i in range(n, 0, -1):
if j >= sizes[i - 1] and dp[i - 1][j - sizes[i - 1]]:
indices.append(i - 1)
j -= sizes[i - 1]
return sorted(indices) # Return in FIFO order
def grab_exact_from_heterogeneous_queue(
queue: List[Dict[str, List]], batch_size: int
) -> Tuple[Optional[List], List]:
"""
Grabs a batch of exactly batch_size sequences from a queue of items with different group sizes.
Each item in the queue has a 'tokens' field containing a list of sequences.
e.g. queue = [{"tokens": [[1, 2, 3],[4, 5, 6, 7, 8]]}, {"tokens": [[9, 10]]}]
where the first item has 2 sequences and the second has 1 sequence.
This function returns a batch containing exactly batch_size sequences total, and the remaining queue.
Because all group sizes evenly divide the batch size and are powers of 2,
we can simplify a bit by grabbing groups of groups equal to the maximum group size.
Note that we cannot split items, so we must take the entire item with all its sequences.
There may be a more efficient clearing mechanism by grouping these smaller groups heterogeneously, but
forcing them all into powers of two groups is a simple way to ensure we can grab a batch of the correct size.
:param queue: List of items, each with a 'tokens' field containing sequences
:param batch_size: Target number of sequences for the batch
:return: batch, new_queue
"""
# Pass 1: precompute group sizes, total sequences and early exit if not enough sequences.
total_groups = len(queue)
if total_groups == 0:
return None, queue
group_sizes = []
lengths = []
total_sequences = 0
max_group_size = 0
for item in queue:
length = len(item["tokens"]) # Number of sequences in this group
lengths.append(length)
group_sizes.append(length)
total_sequences += length
if length > max_group_size:
max_group_size = length
if total_sequences < batch_size:
return None, queue
group_sizes_set = set(group_sizes)
@ -55,30 +112,168 @@ def grab_exact_from_heterogeneous_queue(
potential_batch_indices.extend(group_batching_storage[group_size])
group_batching_storage[group_size].clear() # much faster than = []
# Calculate total sequences in potential batch only once (avoid repeated sums)
potential_batch_sequences_total = sum(lengths[i] for i in potential_batch_indices)
if potential_batch_sequences_total < batch_size:
return None, queue
# Batch selection
batch = []
batch_indices = []
running_seqs = 0
for idx in potential_batch_indices:
group = queue[idx]
batch.append(group)
batch_indices.append(idx)
running_seqs += lengths[idx]
if running_seqs == batch_size:
break
elif running_seqs > batch_size:
# Should never happen due to problem constraints, but sanity check
return None, queue
if running_seqs != batch_size:
return None, queue
# Construct new_queue with a single pass, using a set for O(1) lookup
batch_indices_set = set(batch_indices)
new_queue = [item for i, item in enumerate(queue) if i not in batch_indices_set]
return batch, new_queue
def grab_batch_with_minimum_allocations(
queue: List[Dict[str, any]], batch_size: int, env_configs: List[Dict[str, any]]
) -> Tuple[Optional[List], List]:
"""
Grabs a batch from the queue while respecting minimum allocation requirements for environments.
This function works with groups where each group contains multiple sequences.
:param queue: List of groups with env_id field and sequences (stored in 'tokens' field)
:param batch_size: Target batch size in sequences
:param env_configs: List of environment configs with min_batch_allocation field
:return: batch, new_queue
"""
if not queue:
return None, queue
# Build env_id to min allocation mapping
env_min_allocations = {}
for env in env_configs:
if env.get("connected", False) and env.get("min_batch_allocation") is not None:
env_min_allocations[env["registered_id"]] = env["min_batch_allocation"]
# If no minimum allocations, fall back to original function
if not env_min_allocations:
return grab_exact_from_heterogeneous_queue(queue, batch_size)
# First, find the maximum group size across all items
max_group_size = 0
for item in queue:
group_size = len(item.get("tokens", []))
if group_size > max_group_size:
max_group_size = group_size
# Group queue items by env_id and calculate which can form complete packs
items_by_env = {}
packable_items_by_env = {}
for i, item in enumerate(queue):
env_id = item.get("env_id")
group_size = len(item.get("tokens", []))
if env_id is not None:
if env_id not in items_by_env:
items_by_env[env_id] = {}
packable_items_by_env[env_id] = []
if group_size not in items_by_env[env_id]:
items_by_env[env_id][group_size] = []
items_by_env[env_id][group_size].append((i, item, group_size))
# Check if we can form a complete pack
items_of_size = items_by_env[env_id][group_size]
if len(items_of_size) * group_size == max_group_size:
# We have a complete pack!
packable_items_by_env[env_id].extend(items_of_size)
items_by_env[env_id][group_size] = []
# Calculate minimum sequences needed per env
min_sequences_per_env = {}
total_min_sequences = 0
for env_id, min_proportion in env_min_allocations.items():
min_sequences = int(batch_size * min_proportion)
if min_sequences > 0:
# Check if this env has any items in the queue at all
if env_id not in items_by_env:
# This env has a minimum but no items - can't satisfy minimum
return None, queue
# Check if this env has any packable items
if env_id not in packable_items_by_env or not packable_items_by_env[env_id]:
# This env has items but no packable items - can't satisfy minimum
return None, queue
min_sequences_per_env[env_id] = min_sequences
total_min_sequences += min_sequences
# If minimums exceed batch size, scale them down proportionally
if total_min_sequences > batch_size:
scale_factor = batch_size / total_min_sequences
for env_id in min_sequences_per_env:
# Ensure at least one pack from each env with minimum
if packable_items_by_env.get(env_id):
min_group_size = min(g[2] for g in packable_items_by_env[env_id])
min_sequences_per_env[env_id] = max(
min_group_size,
int(min_sequences_per_env[env_id] * scale_factor),
)
# Build batch ensuring minimums are met
batch = []
batch_indices = []
sequences_taken_per_env = {env_id: 0 for env_id in packable_items_by_env}
total_sequences = 0
# First pass: satisfy minimum requirements using packable items
for env_id, min_sequences in min_sequences_per_env.items():
if env_id in packable_items_by_env:
# Take packable items in order (FIFO)
for idx, item, group_size in packable_items_by_env[env_id]:
if sequences_taken_per_env[env_id] >= min_sequences:
break
if total_sequences + group_size <= batch_size:
batch.append(item)
batch_indices.append(idx)
sequences_taken_per_env[env_id] += group_size
total_sequences += group_size
# Second pass: fill remaining slots with packable items from any env
if total_sequences < batch_size:
# Collect all remaining packable items in queue order
all_packable = []
for i, item in enumerate(queue):
if i not in batch_indices:
# Check if this item is in any env's packable list
env_id = item.get("env_id")
if env_id in packable_items_by_env:
for idx, packable_item, size in packable_items_by_env[env_id]:
if idx == i:
all_packable.append((i, item, size))
break
# Take packable items in queue order
for idx, item, group_size in all_packable:
if total_sequences + group_size <= batch_size:
batch.append(item)
batch_indices.append(idx)
total_sequences += group_size
if total_sequences == batch_size:
break
# If we couldn't form a full batch, return None
if total_sequences != batch_size:
return None, queue
# Construct new queue
batch_indices_set = set(batch_indices)
new_queue = [item for i, item in enumerate(queue) if i not in batch_indices_set]
return batch, new_queue


@ -1,6 +1,7 @@
import asyncio
import json
import logging
import math
import os
import random
import string
@ -162,6 +163,14 @@ class BaseEnvConfig(BaseModel):
default=False,
description="Whether to include messages in the output transmitted to the trainer",
)
min_batch_allocation: Optional[float] = Field(
default=None,
description="Minimum proportion of a batch this environment should be allocated (0.0-1.0)",
)
worker_timeout: float = Field(
default=600,
description="Timeout for a task, in seconds; if -1, no timeout",
)
class BaseEnv(ABC):
@ -237,6 +246,26 @@ class BaseEnv(ABC):
else:
self.jsonl_writer = None
@property
def derived_batch_size(self):
"""Calculate the effective batch size for this environment based on minimum allocations."""
# If batch_size is not set or no status yet, return the config batch_size
if not hasattr(self, "status_dict") or self.config.batch_size == -1:
return self.config.batch_size
# Get unallocated fraction from status
unallocated_fraction = self.status_dict.get("unallocated_fraction", 1.0)
# If this env has a minimum allocation, add it to the unallocated portion
if self.config.min_batch_allocation is not None:
effective_fraction = unallocated_fraction + self.config.min_batch_allocation
else:
# This env competes for the unallocated portion based on its weight
effective_fraction = unallocated_fraction
# Calculate derived batch size
return int(self.config.batch_size * effective_fraction)
@classmethod
def config_init(
cls,
@ -434,6 +463,8 @@ class BaseEnv(ABC):
"max_token_length": self.config.max_token_length,
"desired_name": self.config.wandb_name,
"weight": self.config.inference_weight,
"min_batch_allocation": self.config.min_batch_allocation,
"group_size": self.config.group_size,
},
) as resp:
data = await parse_http_response(resp, logger)
@ -614,6 +645,13 @@ class BaseEnv(ABC):
"""
Send scored data to the API with retry logic for timeouts and server errors.
"""
# Add env_id to the data
if isinstance(scored_data, list):
for item in scored_data:
item["env_id"] = getattr(self, "env_id", None)
else:
scored_data["env_id"] = getattr(self, "env_id", None)
url = (
f"{self.config.rollout_server_url}/scored_data_list"
if isinstance(scored_data, list)
@ -736,7 +774,7 @@ class BaseEnv(ABC):
"""
Handle the rollout of an item
"""
entry = self.running_items.get(item_uuid)
item = entry["item"] if entry is not None else None
if item is None:
print(f"item {item_uuid} not found... returning")
return None
@ -813,7 +851,9 @@ class BaseEnv(ABC):
self.eval_runner = eval_task
if self.config.eval_handling == EvalHandlingEnum.STOP_TRAIN:
# Stop training if eval is running
self.backlog.extend(
[x["item"] for x in self.running_items.values()]
)
for worker in self.workers:
worker.cancel()
self.workers = set()
@ -852,16 +892,72 @@ class BaseEnv(ABC):
max_num_workers,
(
self.config.max_batches_offpolicy
* self.derived_batch_size
// self.config.group_size
)
- (self.status_dict["queue_size"]),
)
# Now if we have a minimum batch allocation, we need to add workers to fill the self queue, in case of
# overruns by other environments
if self.config.min_batch_allocation is not None:
min_workers_to_fill_self_queue = max(
0,
math.ceil(
(
(
(
math.ceil(
self.config.min_batch_allocation
* self.config.batch_size
* self.config.max_batches_offpolicy
/ self.status_dict["max_group_size"]
)
+ (
self.status_dict["max_group_size"]
// self.config.group_size
)
)
* self.status_dict["max_group_size"]
)
- (
(
self.status_dict["max_group_size"]
* self.status_dict["self_queue_size"]
// (
self.status_dict["max_group_size"]
/ self.config.group_size
)
)
)
)
/ self.config.group_size
),
)
max_num_workers = max(max_num_workers, min_workers_to_fill_self_queue)
print(
f"max_num_workers: {max_num_workers}, queue size: {self.status_dict['queue_size']}, "
f"workers: {len(self.workers)}, self_queue_size: {self.status_dict['self_queue_size']}",
flush=True,
)
if (self.curr_step == 0) and (len(self.workers) == 0):
# We are starting up, so we should just skip the append to the list
pass
else:
self.workers_added_list.append(max_num_workers - len(self.workers))
if len(self.workers) > max_num_workers:
print(
f"len(self.workers) > max_num_workers: {len(self.workers)} > {max_num_workers}, "
"sending workers to backlog",
flush=True,
)
num_to_reduce = len(self.workers) - max_num_workers
running_items_to_remove = list(self.running_items.keys())[:num_to_reduce]
for item_uuid in running_items_to_remove:
self.backlog.append(self.running_items[item_uuid]["item"])
self.running_items[item_uuid]["worker"].cancel()
self.workers.discard(self.running_items[item_uuid]["worker"])
self.running_items.pop(item_uuid)
while len(self.workers) < max_num_workers:
# Generate a UUID for tracking this item
item_uuid = str(uuid.uuid4())
@ -871,8 +967,12 @@ class BaseEnv(ABC):
item = await self.get_next_item()
if item is None:
break
worker = asyncio.create_task(self.handle_env(item_uuid))
self.running_items[item_uuid] = {
"item": item,
"worker": worker,
"start_time": time.time(),
}
self.workers.add(worker)
worker.add_done_callback(
lambda fut, i=item: (
@@ -926,9 +1026,32 @@ class BaseEnv(ABC):
>= self.config.max_batches_offpolicy * self.config.batch_size
)
and (self.config.max_batches_offpolicy > 0)
and (
    (self.config.min_batch_allocation is None)
    or (
        math.ceil(
            self.config.min_batch_allocation
            * self.config.batch_size
            * self.config.max_batches_offpolicy
            / self.status_dict["max_group_size"]
        )
        * (self.status_dict["max_group_size"] // self.config.group_size)
        - self.status_dict["self_queue_size"]
        <= 0
    )
)
) or (self.derived_batch_size == -1):
# We have too many, lets cleanup the tasks and wait a bit
self.backlog.extend([x["item"] for x in self.running_items.values()])
for worker in self.workers:
worker.cancel()
self.running_items = dict()
@@ -937,6 +1060,18 @@ class BaseEnv(ABC):
pass
else:
await self.add_train_workers()
# cleanup workers that have timed out
if self.config.worker_timeout > 0:
for item_uuid, item in list(self.running_items.items()):
if time.time() - item["start_time"] > self.config.worker_timeout:
logger.warning(
f"Worker {item_uuid} has timed out after {time.time() - item['start_time']} seconds"
)
item["worker"].cancel()
self.workers.discard(item["worker"])
self.running_items.pop(item_uuid)
# Do we want to retry? probably not...
# self.backlog.append(item["item"])
await asyncio.sleep(0.1)
async def process_manager(self):
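The worker-count arithmetic in `add_train_workers` can be checked in isolation. A minimal standalone sketch of the same computation, with illustrative numbers (function name and all values are hypothetical, not part of the diff):

```python
import math


def min_workers_to_fill(min_alloc, batch_size, max_off, max_group, group_size, queued_groups):
    """How many more group-sized rollouts this env needs in flight so it can
    cover its minimum share of the off-policy batch window."""
    # whole packs of max_group sequences needed across the window
    packs = math.ceil(min_alloc * batch_size * max_off / max_group)
    needed = (packs + max_group // group_size) * max_group
    # sequences already sitting in this env's queue, in pack units
    have = max_group * queued_groups // (max_group / group_size)
    return max(0, math.ceil((needed - have) / group_size))


# hypothetical values: 25% minimum of a 64-sequence batch, 2 off-policy batches,
# packs of 16 sequences, groups of 4, 3 groups already queued
print(min_workers_to_fill(0.25, 64, 2, 16, 4, 3))  # → 21
```

With a full queue the deficit goes negative and the `max(0, ...)` clamp returns 0, so a saturated env spawns no extra workers.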


@@ -0,0 +1,154 @@
"""Tests for heterogeneous group packing utility."""
import pytest
from atroposlib.api.utils import find_groups_summing_to_target
class TestHeterogeneousPacking:
"""Test cases for finding groups that sum to target size."""
def test_simple_fifo_exact_match(self):
"""Test when FIFO order gives exact match."""
buffer = [
{"tokens": [[1, 2]], "scores": [0.5]}, # size 1
{"tokens": [[3, 4], [5, 6]], "scores": [0.6, 0.7]}, # size 2
{"tokens": [[7, 8]], "scores": [0.8]}, # size 1
]
indices = find_groups_summing_to_target(buffer, 4)
assert indices == [0, 1, 2]
def test_fifo_partial_match(self):
"""Test when FIFO can match with subset."""
buffer = [
{"tokens": [[1, 2], [3, 4]], "scores": [0.5, 0.6]}, # size 2
{"tokens": [[5, 6], [7, 8]], "scores": [0.7, 0.8]}, # size 2
{
"tokens": [[9, 10], [11, 12], [13, 14], [15, 16]],
"scores": [0.9, 1.0, 1.1, 1.2],
}, # size 4
]
indices = find_groups_summing_to_target(buffer, 4)
assert indices == [0, 1] # First two groups sum to 4
def test_need_dynamic_programming(self):
"""Test when FIFO doesn't work but other combinations do."""
buffer = [
{"tokens": [[1, 2], [3, 4], [5, 6]], "scores": [0.5, 0.6, 0.7]}, # size 3
{"tokens": [[7, 8]], "scores": [0.8]}, # size 1
{
"tokens": [[9, 10], [11, 12], [13, 14], [15, 16]],
"scores": [0.9, 1.0, 1.1, 1.2],
}, # size 4
]
indices = find_groups_summing_to_target(buffer, 5)
assert indices == [1, 2] # Groups at index 1 (size 1) and 2 (size 4)
def test_impossible_target(self):
"""Test when no combination can reach target."""
buffer = [
{"tokens": [[1, 2], [3, 4]], "scores": [0.5, 0.6]}, # size 2
{
"tokens": [[5, 6], [7, 8], [9, 10], [11, 12]],
"scores": [0.7, 0.8, 0.9, 1.0],
}, # size 4
]
indices = find_groups_summing_to_target(buffer, 3)
assert indices == [] # Can't make 3 from groups of size 2 and 4
def test_empty_buffer(self):
"""Test with empty buffer."""
indices = find_groups_summing_to_target([], 4)
assert indices == []
def test_single_group_exact(self):
"""Test when single group matches exactly."""
buffer = [
{
"tokens": [[1, 2], [3, 4], [5, 6], [7, 8]],
"scores": [0.5, 0.6, 0.7, 0.8],
}, # size 4
]
indices = find_groups_summing_to_target(buffer, 4)
assert indices == [0]
def test_bradley_terry_pairs(self):
"""Test RLAIF use case with Bradley-Terry pairs."""
buffer = [
{"tokens": [[1, 2], [3, 4]], "scores": [0.7, 0.3]}, # size 2 (BT pair)
{"tokens": [[5, 6], [7, 8]], "scores": [0.6, 0.4]}, # size 2 (BT pair)
{"tokens": [[9, 10], [11, 12]], "scores": [0.8, 0.2]}, # size 2 (BT pair)
{"tokens": [[13, 14], [15, 16]], "scores": [0.5, 0.5]}, # size 2 (BT pair)
]
indices = find_groups_summing_to_target(buffer, 8)
assert indices == [0, 1, 2, 3] # All 4 pairs
def test_mixed_sizes_complex(self):
"""Test with various power-of-2 sizes."""
buffer = [
{"tokens": [[1]], "scores": [0.5]}, # size 1
{"tokens": [[2], [3]], "scores": [0.6, 0.7]}, # size 2
{"tokens": [[4]], "scores": [0.8]}, # size 1
{"tokens": [[5], [6], [7], [8]], "scores": [0.9, 1.0, 1.1, 1.2]}, # size 4
{"tokens": [[9], [10]], "scores": [1.3, 1.4]}, # size 2
]
# Target 8: should find combination that sums to 8
indices = find_groups_summing_to_target(buffer, 8)
assert len(indices) > 0
assert sum(len(buffer[i]["tokens"]) for i in indices) == 8
def test_large_groups(self):
"""Test with larger group sizes."""
buffer = [
{"tokens": [[i] for i in range(16)], "scores": [0.5] * 16}, # size 16
{"tokens": [[i] for i in range(8)], "scores": [0.6] * 8}, # size 8
{"tokens": [[i] for i in range(8)], "scores": [0.7] * 8}, # size 8
]
indices = find_groups_summing_to_target(buffer, 32)
assert indices == [0, 1, 2] # All groups needed
def test_prefer_earlier_indices(self):
"""Test that algorithm prefers earlier indices when multiple solutions exist."""
buffer = [
{"tokens": [[1], [2]], "scores": [0.5, 0.6]}, # size 2
{"tokens": [[3], [4]], "scores": [0.7, 0.8]}, # size 2
{"tokens": [[5], [6], [7], [8]], "scores": [0.9, 1.0, 1.1, 1.2]}, # size 4
{"tokens": [[9], [10]], "scores": [1.3, 1.4]}, # size 2
{"tokens": [[11], [12]], "scores": [1.5, 1.6]}, # size 2
]
indices = find_groups_summing_to_target(buffer, 4)
assert indices == [0, 1] # Should prefer first two groups over later ones
def test_exact_fit_with_remainder(self):
"""Test when we can form exact target but have leftover groups."""
buffer = [
{"tokens": [[1], [2]], "scores": [0.5, 0.6]}, # size 2
{"tokens": [[3], [4], [5], [6]], "scores": [0.7, 0.8, 0.9, 1.0]}, # size 4
{"tokens": [[7], [8]], "scores": [1.1, 1.2]}, # size 2
{"tokens": [[9]], "scores": [1.3]}, # size 1
]
indices = find_groups_summing_to_target(buffer, 6)
assert sorted(indices) == [0, 1] # First two groups sum to 6
def test_stress_many_small_groups(self):
"""Test with many small groups."""
# Create 16 groups of size 1
buffer = [{"tokens": [[i]], "scores": [i * 0.1]} for i in range(16)]
indices = find_groups_summing_to_target(buffer, 8)
assert len(indices) == 8
assert indices == list(range(8)) # Should take first 8
if __name__ == "__main__":
pytest.main([__file__, "-v"])
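The tests above pin down the contract of `find_groups_summing_to_target`: try FIFO order first, fall back to a subset-sum search when the front of the buffer does not fit, and prefer earlier indices. A minimal sketch of one implementation satisfying that contract (`pack_groups` is a hypothetical stand-in, not the actual atroposlib function):

```python
def pack_groups(sizes, target):
    """Return indices of groups whose sizes sum to target, preferring FIFO order."""
    # FIFO fast path: take groups front-to-back while they still fit
    picked, total = [], 0
    for i, s in enumerate(sizes):
        if total + s <= target:
            picked.append(i)
            total += s
            if total == target:
                return picked
    # fall back to subset-sum DP over all groups, favoring earlier indices
    best = {0: []}  # reachable sum -> earliest index list achieving it
    for i, s in enumerate(sizes):
        for tot in sorted(best, reverse=True):
            new = tot + s
            if new <= target and new not in best:
                best[new] = best[tot] + [i]
    return best.get(target, [])


print(pack_groups([3, 1, 4], 5))  # → [1, 2]
```

Iterating the reachable sums in descending order over a snapshot (`sorted(best, ...)`) ensures each group is used at most once per pass, and seeding sums from the front of the buffer first is what makes earlier indices win ties.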


@@ -0,0 +1,806 @@
"""Tests for minimum batch allocation functionality."""
import random
from atroposlib.api.utils import grab_batch_with_minimum_allocations
class TestMinBatchAllocation:
"""Test cases for minimum batch allocation feature."""
def test_basic_minimum_allocation(self):
"""Test that basic minimum allocations are respected."""
# Each item represents a group with multiple token sequences
queue = [
{
"tokens": [[1, 2], [3, 4]],
"env_id": 0,
"masks": [[1, 1], [1, 1]],
"scores": [0.5, 0.6],
}, # 2 groups
{
"tokens": [[5, 6], [7, 8]],
"env_id": 1,
"masks": [[1, 1], [1, 1]],
"scores": [0.7, 0.8],
}, # 2 groups
{
"tokens": [[9, 10], [11, 12]],
"env_id": 0,
"masks": [[1, 1], [1, 1]],
"scores": [0.5, 0.6],
}, # 2 groups
{
"tokens": [[13, 14], [15, 16]],
"env_id": 1,
"masks": [[1, 1], [1, 1]],
"scores": [0.7, 0.8],
}, # 2 groups
]
env_configs = [
{
"registered_id": 0,
"connected": True,
"min_batch_allocation": 0.25,
}, # 25% = 2 groups min
{
"registered_id": 1,
"connected": True,
"min_batch_allocation": 0.5,
}, # 50% = 4 groups min
]
batch_size = 8 # 8 token groups total
batch, new_queue = grab_batch_with_minimum_allocations(
queue, batch_size, env_configs
)
assert batch is not None
# Count groups (not items) per environment
env_groups = {}
total_groups = 0
for item in batch:
env_id = item["env_id"]
groups = len(item["tokens"])
env_groups[env_id] = env_groups.get(env_id, 0) + groups
total_groups += groups
assert total_groups == batch_size
# Env 1 should have at least 50% (4 groups)
assert env_groups.get(1, 0) >= 4
# Env 0 should have at least 25% (2 groups)
assert env_groups.get(0, 0) >= 2
def test_no_minimum_allocation_fallback(self):
"""Test fallback to original function when no minimums specified."""
queue = [
{"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]},
{"tokens": [[3, 4]], "env_id": 1, "masks": [[1, 1]], "scores": [0.7]},
{"tokens": [[5, 6]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]},
{"tokens": [[7, 8]], "env_id": 1, "masks": [[1, 1]], "scores": [0.7]},
]
env_configs = [
{"registered_id": 0, "connected": True},
{"registered_id": 1, "connected": True},
]
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
# Should still form a batch using original logic
assert batch is not None
assert len(new_queue) < len(queue)
def test_conflicting_minimums_scale_down(self):
"""Test that conflicting minimums > 100% are scaled down."""
queue = [
{
"tokens": [[1, 2], [3, 4]],
"env_id": 0,
"masks": [[1, 1], [1, 1]],
"scores": [0.5, 0.6],
}, # 2 groups
{
"tokens": [[5, 6], [7, 8]],
"env_id": 1,
"masks": [[1, 1], [1, 1]],
"scores": [0.7, 0.8],
}, # 2 groups
]
env_configs = [
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.7}, # 70%
{
"registered_id": 1,
"connected": True,
"min_batch_allocation": 0.6,
}, # 60% = 130% total
]
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
# Should still form a batch with scaled allocations
assert batch is not None
assert len(batch) == 2 # Both items needed to form batch of 4 groups
def test_missing_env_in_queue(self):
"""Test handling when an env has minimum but no items in queue."""
queue = [
{
"tokens": [[1, 2], [3, 4]],
"env_id": 0,
"masks": [[1, 1], [1, 1]],
"scores": [0.5, 0.6],
}, # 2 groups
{
"tokens": [[5, 6], [7, 8]],
"env_id": 0,
"masks": [[1, 1], [1, 1]],
"scores": [0.5, 0.6],
}, # 2 groups
]
env_configs = [
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.3},
{
"registered_id": 1,
"connected": True,
"min_batch_allocation": 0.5,
}, # No items!
]
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
# Should return None because env 1 has minimum allocation but no items
assert batch is None
def test_disconnected_env_ignored(self):
"""Test that disconnected environments are ignored."""
queue = [
{"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]},
{"tokens": [[3, 4]], "env_id": 1, "masks": [[1, 1]], "scores": [0.7]},
]
env_configs = [
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.25},
{
"registered_id": 1,
"connected": False,
"min_batch_allocation": 0.75,
}, # Disconnected!
]
batch, new_queue = grab_batch_with_minimum_allocations(queue, 2, env_configs)
# Should only consider connected env
assert batch is not None
# May include env 1 items but won't enforce its minimum
def test_mixed_group_sizes(self):
"""Test handling of different group sizes."""
queue = [
{"tokens": [[1]], "env_id": 0, "masks": [[1]], "scores": [0.5]}, # size 1
{
    "tokens": [[2], [3], [4], [5]],
    "env_id": 0,
    "masks": [[1], [1], [1], [1]],
    "scores": [0.6, 0.7, 0.8, 0.9],
},  # size 4
{
    "tokens": [[6], [7]],
    "env_id": 1,
    "masks": [[1], [1]],
    "scores": [0.5, 0.6],
},  # size 2
]
env_configs = [
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.5},
{"registered_id": 1, "connected": True, "min_batch_allocation": 0.25},
]
# Try to form batch of size 7 (which would need all items)
batch, new_queue = grab_batch_with_minimum_allocations(queue, 7, env_configs)
if batch is not None:
total_tokens = sum(len(item["tokens"]) for item in batch)
assert total_tokens == 7
def test_empty_queue(self):
"""Test handling of empty queue."""
queue = []
env_configs = [
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.5},
]
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
assert batch is None
assert new_queue == []
def test_insufficient_items_for_batch(self):
"""Test when there aren't enough items to form a full batch."""
queue = [
{"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]},
]
env_configs = [
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.5},
]
# Request batch size 4 but only have 2 tokens
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
assert batch is None
assert len(new_queue) == 1 # Original queue unchanged
def test_heterogeneous_envs(self):
"""Test envs with individual group sizes."""
# Env 0: all groups have size 2
# Env 6: filler groups of size 2 (shows why greedy packing fails)
# Env 2: all groups have size 4
# Env 3: all groups have size 8
queue = []
# Add items for env 0 (group size 2)
for i in range(1):
queue.append(
{
"tokens": [[i * 2, i * 2 + 1] for _ in range(2)],
"env_id": 0,
"masks": [[1, 1] for _ in range(2)],
"scores": [0.5, 0.6],
}
)
# Add 3 items of group size 2 to show why greedy packing doesn't work
for i in range(3):
queue.append(
{
"tokens": [[i * 2, i * 2 + 1] for _ in range(2)],
"env_id": 6,
"masks": [[1, 1] for _ in range(2)],
"scores": [0.5, 0.6],
}
)
# Add items for env 2 (group size 4)
for i in range(5):
queue.append(
{
"tokens": [
[i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3] for _ in range(4)
],
"env_id": 2,
"masks": [[1, 1, 1, 1] for _ in range(4)],
"scores": [0.7, 0.8, 0.9, 1.0],
}
)
# Add items for env 3 (group size 8)
for i in range(3):
queue.append(
{
"tokens": [[i * 8 + j] for j in range(8)],
"env_id": 3,
"masks": [[1] for _ in range(8)],
"scores": [0.5] * 8,
}
)
env_configs = [
{
"registered_id": 0,
"connected": True,
"min_batch_allocation": 0.1,
}, # min 2 sequences
{
"registered_id": 3,
"connected": True,
"min_batch_allocation": 0.5,
}, # min 8 sequences
]
batch_size = 16
batch, new_queue = grab_batch_with_minimum_allocations(
queue, batch_size, env_configs
)
# Since env 0 has min allocation of 10% but can't form any complete packs
# (has 1 item of size 2, needs 4 to make pack of 8), the function should
# return None as it cannot satisfy the minimum allocation requirement
assert batch is None
# Queue should be unchanged
assert len(new_queue) == len(queue)
def test_packing_constraint_enforcement(self):
"""Test that packing to max group size is properly enforced."""
# Create queue with items that can't form complete packs
queue = [
{
"tokens": [[1, 2]],
"env_id": 0,
"masks": [[1, 1]],
"scores": [0.5],
}, # size 1
{
"tokens": [[3, 4]],
"env_id": 0,
"masks": [[1, 1]],
"scores": [0.5],
}, # size 1
{
"tokens": [[5, 6]],
"env_id": 0,
"masks": [[1, 1]],
"scores": [0.5],
}, # size 1
# Need 4 items of size 1 to make a pack of 4, only have 3
{
"tokens": [[7, 8], [9, 10], [11, 12], [13, 14]],
"env_id": 1,
"masks": [[1, 1]] * 4,
"scores": [0.7] * 4,
}, # size 4
]
env_configs = [
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.25},
{"registered_id": 1, "connected": True, "min_batch_allocation": None},
]
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
# Should return None because env 0 can't form complete packs
assert batch is None
def test_fifo_order_preservation(self):
"""Test that FIFO order is preserved when forming batches."""
queue = []
# Add items with sequential scores to track order
for i in range(8):
queue.append(
{
"tokens": [[i, i + 1]],
"env_id": 0,
"masks": [[1, 1]],
"scores": [float(i)], # Use score to track original order
}
)
env_configs = [
{"registered_id": 0, "connected": True, "min_batch_allocation": None},
]
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
if batch is not None:
# Check that we got the first 4 items (scores 0-3)
batch_scores = [item["scores"][0] for item in batch]
assert sorted(batch_scores) == [0.0, 1.0, 2.0, 3.0]
def test_exact_minimum_boundary(self):
"""Test behavior at exact minimum allocation boundaries."""
queue = [
{
"tokens": [[1, 2], [3, 4]],
"env_id": 0,
"masks": [[1, 1], [1, 1]],
"scores": [0.5, 0.6],
},
{
"tokens": [[5, 6], [7, 8]],
"env_id": 0,
"masks": [[1, 1], [1, 1]],
"scores": [0.5, 0.6],
},
{
"tokens": [[9, 10], [11, 12]],
"env_id": 1,
"masks": [[1, 1], [1, 1]],
"scores": [0.7, 0.8],
},
{
"tokens": [[13, 14], [15, 16]],
"env_id": 1,
"masks": [[1, 1], [1, 1]],
"scores": [0.7, 0.8],
},
]
env_configs = [
{
"registered_id": 0,
"connected": True,
"min_batch_allocation": 0.5,
}, # Exactly 50%
{
"registered_id": 1,
"connected": True,
"min_batch_allocation": 0.5,
}, # Exactly 50%
]
batch, new_queue = grab_batch_with_minimum_allocations(queue, 8, env_configs)
assert batch is not None
env_counts = {}
for item in batch:
env_id = item["env_id"]
count = len(item["tokens"])
env_counts[env_id] = env_counts.get(env_id, 0) + count
# Both envs should get exactly 4 sequences (50%)
assert env_counts[0] == 4
assert env_counts[1] == 4
def test_zero_minimum_allocation(self):
"""Test that zero minimum allocation is handled correctly."""
queue = [
{"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]},
{"tokens": [[3, 4]], "env_id": 1, "masks": [[1, 1]], "scores": [0.7]},
]
env_configs = [
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.0}, # 0%
{"registered_id": 1, "connected": True, "min_batch_allocation": 0.5}, # 50%
]
batch, new_queue = grab_batch_with_minimum_allocations(queue, 2, env_configs)
# Should work fine - env 0 has no minimum requirement
assert batch is not None
def test_multiple_complete_packs(self):
"""Test forming multiple complete packs from same environment."""
queue = []
# Add 16 items of size 1 from env 0 (can form 4 complete packs of 4)
for i in range(16):
queue.append(
{
"tokens": [[i]],
"env_id": 0,
"masks": [[1]],
"scores": [0.5],
}
)
# Add 2 items of size 4 from env 1
for i in range(2):
queue.append(
{
"tokens": [[100 + i * 4, 101 + i * 4, 102 + i * 4, 103 + i * 4]],
"env_id": 1,
"masks": [[1, 1, 1, 1]],
"scores": [0.7],
}
)
env_configs = [
{
"registered_id": 0,
"connected": True,
"min_batch_allocation": 0.75,
}, # 12 sequences
{"registered_id": 1, "connected": True, "min_batch_allocation": None},
]
batch, new_queue = grab_batch_with_minimum_allocations(queue, 16, env_configs)
assert batch is not None
env_counts = {}
for item in batch:
env_id = item["env_id"]
count = len(item["tokens"])
env_counts[env_id] = env_counts.get(env_id, 0) + count
# Env 0 should get at least 12 sequences
assert env_counts.get(0, 0) >= 12
assert sum(env_counts.values()) == 16
def test_no_packable_items(self):
"""Test when no items can form complete packs."""
queue = [
{
"tokens": [[1, 2], [3, 4]],
"env_id": 0,
"masks": [[1, 1], [1, 1]],
"scores": [0.5, 0.6],
}, # size 2
{
"tokens": [[5, 6], [7, 8]],
"env_id": 0,
"masks": [[1, 1], [1, 1]],
"scores": [0.5, 0.6],
}, # size 2
{
"tokens": [[9, 10], [11, 12]],
"env_id": 0,
"masks": [[1, 1], [1, 1]],
"scores": [0.5, 0.6],
}, # size 2
# Only 3 items of size 2, need 4 to make complete pack of 8
{
"tokens": [
[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24],
[25, 26, 27, 28],
[29, 30, 31, 32],
[33, 34, 35, 36],
[37, 38, 39, 40],
[41, 42, 43, 44],
],
"env_id": 1,
"masks": [[1, 1, 1, 1]] * 8,
"scores": [0.7] * 8,
}, # size 8
]
env_configs = [
{
"registered_id": 0,
"connected": True,
"min_batch_allocation": 0.25,
}, # Can't form complete packs
{"registered_id": 1, "connected": True, "min_batch_allocation": None},
]
batch, new_queue = grab_batch_with_minimum_allocations(queue, 8, env_configs)
# Env 0 can't form complete packs (has 3 items, needs 4)
assert batch is None
def test_env_without_items(self):
"""Test env config without any items in queue."""
queue = [
{"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]},
{"tokens": [[3, 4]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]},
]
env_configs = [
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.5},
{
"registered_id": 1,
"connected": True,
"min_batch_allocation": None,
}, # No items
{
"registered_id": 2,
"connected": True,
"min_batch_allocation": 0.3,
}, # No items
]
batch, new_queue = grab_batch_with_minimum_allocations(queue, 2, env_configs)
# Should work - env 2 has no items so its minimum is ignored
assert batch is not None
def test_scaling_with_single_env(self):
"""Test scaling behavior with only one env having minimum."""
queue = []
for i in range(8):
queue.append(
{
"tokens": [[i, i + 1]],
"env_id": 0,
"masks": [[1, 1]],
"scores": [0.5],
}
)
env_configs = [
{
"registered_id": 0,
"connected": True,
"min_batch_allocation": 1.5,
}, # 150% - impossible
]
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
# Should scale down to 100% and work
assert batch is not None
assert len(batch) == 4
def test_mixed_null_and_set_minimums(self):
"""Test mix of environments with and without minimum allocations."""
queue = []
# Env 0: 4 items of size 2
for i in range(4):
queue.append(
{
"tokens": [[i * 2, i * 2 + 1], [i * 2 + 10, i * 2 + 11]],
"env_id": 0,
"masks": [[1, 1], [1, 1]],
"scores": [0.5, 0.6],
}
)
# Env 1: 2 items of size 2
for i in range(2):
queue.append(
{
"tokens": [[i * 2 + 20, i * 2 + 21], [i * 2 + 30, i * 2 + 31]],
"env_id": 1,
"masks": [[1, 1], [1, 1]],
"scores": [0.7, 0.8],
}
)
# Env 2: 2 items of size 2 (no minimum)
for i in range(2):
queue.append(
{
"tokens": [[i * 2 + 40, i * 2 + 41], [i * 2 + 50, i * 2 + 51]],
"env_id": 2,
"masks": [[1, 1], [1, 1]],
"scores": [0.9, 1.0],
}
)
env_configs = [
{
"registered_id": 0,
"connected": True,
"min_batch_allocation": 0.4,
}, # 40% = 6.4 ≈ 6
{
"registered_id": 1,
"connected": True,
"min_batch_allocation": 0.2,
}, # 20% = 3.2 ≈ 3
{
"registered_id": 2,
"connected": True,
"min_batch_allocation": None,
}, # No minimum
]
batch, new_queue = grab_batch_with_minimum_allocations(queue, 16, env_configs)
assert batch is not None
env_counts = {}
for item in batch:
env_id = item["env_id"]
count = len(item["tokens"])
env_counts[env_id] = env_counts.get(env_id, 0) + count
# Check minimums are satisfied
assert env_counts.get(0, 0) >= 6 # At least 40% of 16
assert env_counts.get(1, 0) >= 3 # At least 20% of 16
assert sum(env_counts.values()) == 16
def test_random_consistent_group_sizes(self):
"""Random test where each env has a consistent power-of-2 group size."""
for _ in range(100):
batch_size = 64 * random.randint(1, 4)
num_envs = random.randint(2, 4)
# Assign each env a consistent group size
env_group_sizes = {}
for env_id in range(num_envs):
env_group_sizes[env_id] = 2 ** random.randint(0, 3) # 1, 2, 4, or 8
# Create queue
queue = []
for env_id in range(num_envs):
group_size = env_group_sizes[env_id]
num_items = random.randint(5, 20)
for i in range(num_items):
queue.append(
{
"tokens": [
[env_id * 1000 + i * 10 + j] for j in range(group_size)
],
"env_id": env_id,
"masks": [[1] for _ in range(group_size)],
"scores": [0.5 + env_id * 0.1] * group_size,
}
)
# Random minimum allocations that sum to less than 1.0
env_configs = []
remaining = 0.9
for env_id in range(num_envs):
if env_id == num_envs - 1:
min_alloc = remaining
else:
min_alloc = random.uniform(0.1, min(0.4, remaining))
remaining -= min_alloc
env_configs.append(
{
"registered_id": env_id,
"connected": True,
"min_batch_allocation": min_alloc,
}
)
batch, new_queue = grab_batch_with_minimum_allocations(
queue, batch_size, env_configs
)
if batch is not None:
# Verify batch size
total_sequences = sum(len(item["tokens"]) for item in batch)
assert total_sequences == batch_size
# Verify all items from same env have same group size
env_group_sizes_seen = {}
for item in batch:
env_id = item["env_id"]
group_size = len(item["tokens"])
if env_id in env_group_sizes_seen:
assert group_size == env_group_sizes_seen[env_id]
else:
env_group_sizes_seen[env_id] = group_size
def test_queue_dominated_by_one_env(self):
"""Test minimum allocation when one env dominates the queue."""
queue = []
# Only env 1 items in queue
for i in range(100):
queue.append(
{
"tokens": [[1000 + i, 1001 + i]],
"env_id": 1,
"masks": [[1, 1]],
"scores": [0.7],
}
)
env_configs = [
{
"registered_id": 0,
"connected": True,
"min_batch_allocation": 0.5,
}, # 50% but no items!
{"registered_id": 1, "connected": True, "min_batch_allocation": 0.3}, # 30%
]
batch, new_queue = grab_batch_with_minimum_allocations(queue, 10, env_configs)
# Should return None because env 0 has minimum allocation but no items
assert batch is None
# Test with env 0 having no minimum - should work
env_configs[0]["min_batch_allocation"] = None
batch, new_queue = grab_batch_with_minimum_allocations(queue, 10, env_configs)
assert batch is not None
env_counts = {}
for item in batch:
env_id = item["env_id"]
count = len(item["tokens"])
env_counts[env_id] = env_counts.get(env_id, 0) + count
# Should all be from env 1
assert env_counts.get(1, 0) == 10
assert sum(env_counts.values()) == 10
if __name__ == "__main__":
    import pytest

    pytest.main([__file__, "-v"])
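Several tests above (`test_conflicting_minimums_scale_down`, `test_scaling_with_single_env`) imply that minimum allocations totalling more than 100% are rescaled proportionally rather than rejected. A minimal sketch of that normalization step, assuming proportional scaling (the helper name is illustrative, not the library's API):

```python
def scaled_minimums(allocs):
    """Rescale per-env minimum batch fractions so they never sum past 1.0.

    allocs maps env_id -> min_batch_allocation (None means no minimum).
    """
    total = sum(a for a in allocs.values() if a)
    if total <= 1.0:
        return dict(allocs)
    # scale every set minimum down proportionally; None and 0.0 pass through
    return {env: (a / total if a else a) for env, a in allocs.items()}


# 70% + 60% = 130% requested; rescaled to 70/130 and 60/130 of the batch
print(scaled_minimums({0: 0.7, 1: 0.6}))
```

Under this scheme a single env asking for 150% (as in `test_scaling_with_single_env`) collapses to 100%, which is why that test still expects a full batch.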