mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
feat: add minimum batch allocation support for environments
- Add min_batch_allocation parameter to ensure environments contribute minimum proportion to each batch - Implement grab_batch_with_minimum_allocations function with proper scaling when allocations exceed 100% - Add mixed-size group buffering to handle variable-sized data submissions - Update server to use minimum allocation logic when any env has min_batch_allocation set - Add comprehensive tests for minimum allocation scenarios - Update documentation in API README and CONFIG.md - Update example environments to demonstrate the feature This feature allows critical environments to guarantee they contribute at least a specified proportion (0.0-1.0) to each training batch, ensuring important data sources are always represented during training. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
4769eeb4a6
commit
08e14cc745
11 changed files with 1670 additions and 91 deletions
|
|
@ -1,47 +1,104 @@
|
|||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
def find_groups_summing_to_target(buffer: List[Dict], target_size: int) -> List[int]:
|
||||
"""
|
||||
Find indices of groups in buffer that sum exactly to target_size.
|
||||
Prioritizes FIFO order.
|
||||
|
||||
:param buffer: Buffer of groups from same env
|
||||
:param target_size: Target sum of group sizes
|
||||
:return: List of indices that sum to target_size, or empty list if impossible
|
||||
"""
|
||||
if not buffer:
|
||||
return []
|
||||
|
||||
# First try simple FIFO
|
||||
current_sum = 0
|
||||
indices = []
|
||||
|
||||
for i, group in enumerate(buffer):
|
||||
size = len(group["tokens"])
|
||||
if current_sum + size <= target_size:
|
||||
indices.append(i)
|
||||
current_sum += size
|
||||
if current_sum == target_size:
|
||||
return indices
|
||||
|
||||
# If FIFO doesn't work exactly, try dynamic programming
|
||||
# to find any valid combination (still preferring earlier indices)
|
||||
n = len(buffer)
|
||||
sizes = [len(g["tokens"]) for g in buffer]
|
||||
|
||||
# dp[i][j] = can we make sum j using first i groups
|
||||
dp = [[False] * (target_size + 1) for _ in range(n + 1)]
|
||||
dp[0][0] = True
|
||||
|
||||
for i in range(1, n + 1):
|
||||
for j in range(target_size + 1):
|
||||
# Don't take group i-1
|
||||
dp[i][j] = dp[i - 1][j]
|
||||
# Take group i-1 if possible
|
||||
if j >= sizes[i - 1]:
|
||||
dp[i][j] = dp[i][j] or dp[i - 1][j - sizes[i - 1]]
|
||||
|
||||
if not dp[n][target_size]:
|
||||
return []
|
||||
|
||||
# Backtrack to find indices, preferring earlier ones
|
||||
indices = []
|
||||
j = target_size
|
||||
for i in range(n, 0, -1):
|
||||
if j >= sizes[i - 1] and dp[i - 1][j - sizes[i - 1]]:
|
||||
indices.append(i - 1)
|
||||
j -= sizes[i - 1]
|
||||
|
||||
return sorted(indices) # Return in FIFO order
|
||||
|
||||
|
||||
def grab_exact_from_heterogeneous_queue(
|
||||
queue: List[Dict[str, List]], batch_size: int
|
||||
) -> Tuple[Optional[List], List]:
|
||||
"""
|
||||
Grabs a batch of size batchsize from a queue of different sized items
|
||||
Grabs a batch of exactly batch_size sequences from a queue of items with different group sizes.
|
||||
|
||||
Each item in the queue has a 'tokens' field containing a list of sequences.
|
||||
e.g. queue = [{"tokens": [[1, 2, 3],[4, 5, 6, 7, 8]]}, {"tokens": [[9, 10]]}]
|
||||
where the first item has 2 sequences and the second has 1 sequence.
|
||||
|
||||
without going over the batchsize. This function will return a batch of size batchsize, and the new queue.
|
||||
This function returns a batch containing exactly batch_size sequences total, and the remaining queue.
|
||||
|
||||
Because all groups are a common denominator of the batchsize, and all groups are a power of 2,
|
||||
we can simplify a bit by assuming we can grab groups of groups to be equal to the maximum group size.
|
||||
Note that we cannot drop items from groups, so we must grab the entire group if we grab it.
|
||||
Note that we cannot split items, so we must take the entire item with all its sequences.
|
||||
|
||||
There may be a more efficient clearing mechanism by grouping these smaller groups heterogeneously, but
|
||||
forcing them all into powers of two groups is a simple way to ensure we can grab a batch of the correct size.
|
||||
|
||||
:param queue:
|
||||
:param batch_size:
|
||||
:param queue: List of items, each with a 'tokens' field containing sequences
|
||||
:param batch_size: Target number of sequences for the batch
|
||||
:return: batch, new_queue
|
||||
"""
|
||||
|
||||
# Pass 1: precompute group sizes, total tokens and early exit if not enough tokens.
|
||||
# Pass 1: precompute group sizes, total sequences and early exit if not enough sequences.
|
||||
total_groups = len(queue)
|
||||
if total_groups == 0:
|
||||
return None, queue
|
||||
|
||||
group_sizes = []
|
||||
lengths = []
|
||||
total_tokens = 0
|
||||
total_sequences = 0
|
||||
max_group_size = 0
|
||||
|
||||
for item in queue:
|
||||
length = len(item["tokens"])
|
||||
length = len(item["tokens"]) # Number of sequences in this group
|
||||
lengths.append(length)
|
||||
group_sizes.append(length)
|
||||
total_tokens += length
|
||||
total_sequences += length
|
||||
if length > max_group_size:
|
||||
max_group_size = length
|
||||
|
||||
if total_tokens < batch_size:
|
||||
if total_sequences < batch_size:
|
||||
return None, queue
|
||||
|
||||
group_sizes_set = set(group_sizes)
|
||||
|
|
@ -55,30 +112,168 @@ def grab_exact_from_heterogeneous_queue(
|
|||
potential_batch_indices.extend(group_batching_storage[group_size])
|
||||
group_batching_storage[group_size].clear() # much faster than = []
|
||||
|
||||
# Calculate total batch tokens only once (avoid repeated sums)
|
||||
potential_batch_token_total = sum(lengths[i] for i in potential_batch_indices)
|
||||
if potential_batch_token_total < batch_size:
|
||||
# Calculate total sequences in potential batch only once (avoid repeated sums)
|
||||
potential_batch_sequences_total = sum(lengths[i] for i in potential_batch_indices)
|
||||
if potential_batch_sequences_total < batch_size:
|
||||
return None, queue
|
||||
|
||||
# Batch selection
|
||||
batch = []
|
||||
batch_indices = []
|
||||
running_tokens = 0
|
||||
running_seqs = 0
|
||||
for idx in potential_batch_indices:
|
||||
group = queue[idx]
|
||||
batch.append(group)
|
||||
batch_indices.append(idx)
|
||||
running_tokens += lengths[idx]
|
||||
if running_tokens == batch_size:
|
||||
running_seqs += lengths[idx]
|
||||
if running_seqs == batch_size:
|
||||
break
|
||||
elif running_tokens > batch_size:
|
||||
elif running_seqs > batch_size:
|
||||
# Should never happen due to problem constraints, but sanity check
|
||||
return None, queue
|
||||
|
||||
if running_tokens != batch_size:
|
||||
if running_seqs != batch_size:
|
||||
return None, queue
|
||||
|
||||
# Construct new_queue with a single pass, using a set for O(1) lookup
|
||||
batch_indices_set = set(batch_indices)
|
||||
new_queue = [item for i, item in enumerate(queue) if i not in batch_indices_set]
|
||||
return batch, new_queue
|
||||
|
||||
|
||||
def grab_batch_with_minimum_allocations(
|
||||
queue: List[Dict[str, any]], batch_size: int, env_configs: List[Dict[str, any]]
|
||||
) -> Tuple[Optional[List], List]:
|
||||
"""
|
||||
Grabs a batch from the queue while respecting minimum allocation requirements for environments.
|
||||
This function works with groups where each group contains multiple sequences.
|
||||
|
||||
:param queue: List of groups with env_id field and sequences (stored in 'tokens' field)
|
||||
:param batch_size: Target batch size in sequences
|
||||
:param env_configs: List of environment configs with min_batch_allocation field
|
||||
:return: batch, new_queue
|
||||
"""
|
||||
if not queue:
|
||||
return None, queue
|
||||
|
||||
# Build env_id to min allocation mapping
|
||||
env_min_allocations = {}
|
||||
for env in env_configs:
|
||||
if env.get("connected", False) and env.get("min_batch_allocation") is not None:
|
||||
env_min_allocations[env["registered_id"]] = env["min_batch_allocation"]
|
||||
|
||||
# If no minimum allocations, fall back to original function
|
||||
if not env_min_allocations:
|
||||
return grab_exact_from_heterogeneous_queue(queue, batch_size)
|
||||
|
||||
# First, find the maximum group size across all items
|
||||
max_group_size = 0
|
||||
for item in queue:
|
||||
group_size = len(item.get("tokens", []))
|
||||
if group_size > max_group_size:
|
||||
max_group_size = group_size
|
||||
|
||||
# Group queue items by env_id and calculate which can form complete packs
|
||||
items_by_env = {}
|
||||
packable_items_by_env = {}
|
||||
|
||||
for i, item in enumerate(queue):
|
||||
env_id = item.get("env_id")
|
||||
group_size = len(item.get("tokens", []))
|
||||
|
||||
if env_id is not None:
|
||||
if env_id not in items_by_env:
|
||||
items_by_env[env_id] = {}
|
||||
packable_items_by_env[env_id] = []
|
||||
|
||||
if group_size not in items_by_env[env_id]:
|
||||
items_by_env[env_id][group_size] = []
|
||||
|
||||
items_by_env[env_id][group_size].append((i, item, group_size))
|
||||
|
||||
# Check if we can form a complete pack
|
||||
items_of_size = items_by_env[env_id][group_size]
|
||||
if len(items_of_size) * group_size == max_group_size:
|
||||
# We have a complete pack!
|
||||
packable_items_by_env[env_id].extend(items_of_size)
|
||||
items_by_env[env_id][group_size] = []
|
||||
|
||||
# Calculate minimum sequences needed per env
|
||||
min_sequences_per_env = {}
|
||||
total_min_sequences = 0
|
||||
for env_id, min_proportion in env_min_allocations.items():
|
||||
min_sequences = int(batch_size * min_proportion)
|
||||
if min_sequences > 0:
|
||||
# Check if this env has any items in the queue at all
|
||||
if env_id not in items_by_env:
|
||||
# This env has a minimum but no items - can't satisfy minimum
|
||||
return None, queue
|
||||
# Check if this env has any packable items
|
||||
if env_id not in packable_items_by_env or not packable_items_by_env[env_id]:
|
||||
# This env has items but no packable items - can't satisfy minimum
|
||||
return None, queue
|
||||
min_sequences_per_env[env_id] = min_sequences
|
||||
total_min_sequences += min_sequences
|
||||
|
||||
# If minimums exceed batch size, scale them down proportionally
|
||||
if total_min_sequences > batch_size:
|
||||
scale_factor = batch_size / total_min_sequences
|
||||
for env_id in min_sequences_per_env:
|
||||
# Ensure at least one pack from each env with minimum
|
||||
if packable_items_by_env.get(env_id):
|
||||
min_group_size = min(g[2] for g in packable_items_by_env[env_id])
|
||||
min_sequences_per_env[env_id] = max(
|
||||
min_group_size,
|
||||
int(min_sequences_per_env[env_id] * scale_factor),
|
||||
)
|
||||
|
||||
# Build batch ensuring minimums are met
|
||||
batch = []
|
||||
batch_indices = []
|
||||
sequences_taken_per_env = {env_id: 0 for env_id in packable_items_by_env}
|
||||
total_sequences = 0
|
||||
|
||||
# First pass: satisfy minimum requirements using packable items
|
||||
for env_id, min_sequences in min_sequences_per_env.items():
|
||||
if env_id in packable_items_by_env:
|
||||
# Take packable items in order (FIFO)
|
||||
for idx, item, group_size in packable_items_by_env[env_id]:
|
||||
if sequences_taken_per_env[env_id] >= min_sequences:
|
||||
break
|
||||
if total_sequences + group_size <= batch_size:
|
||||
batch.append(item)
|
||||
batch_indices.append(idx)
|
||||
sequences_taken_per_env[env_id] += group_size
|
||||
total_sequences += group_size
|
||||
|
||||
# Second pass: fill remaining slots with packable items from any env
|
||||
if total_sequences < batch_size:
|
||||
# Collect all remaining packable items in queue order
|
||||
all_packable = []
|
||||
for i, item in enumerate(queue):
|
||||
if i not in batch_indices:
|
||||
# Check if this item is in any env's packable list
|
||||
env_id = item.get("env_id")
|
||||
if env_id in packable_items_by_env:
|
||||
for idx, packable_item, size in packable_items_by_env[env_id]:
|
||||
if idx == i:
|
||||
all_packable.append((i, item, size))
|
||||
break
|
||||
|
||||
# Take packable items in queue order
|
||||
for idx, item, group_size in all_packable:
|
||||
if total_sequences + group_size <= batch_size:
|
||||
batch.append(item)
|
||||
batch_indices.append(idx)
|
||||
total_sequences += group_size
|
||||
if total_sequences == batch_size:
|
||||
break
|
||||
|
||||
# If we couldn't form a full batch, return None
|
||||
if total_sequences != batch_size:
|
||||
return None, queue
|
||||
|
||||
# Construct new queue
|
||||
batch_indices_set = set(batch_indices)
|
||||
new_queue = [item for i, item in enumerate(queue) if i not in batch_indices_set]
|
||||
return batch, new_queue
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue