feat: add minimum batch allocation support for environments

- Add min_batch_allocation parameter to ensure environments contribute a minimum proportion to each batch
- Implement grab_batch_with_minimum_allocations function with proper scaling when allocations exceed 100%
- Add mixed-size group buffering to handle variable-sized data submissions
- Update server to use minimum allocation logic when any env has min_batch_allocation set
- Add comprehensive tests for minimum allocation scenarios
- Update documentation in API README and CONFIG.md
- Update example environments to demonstrate the feature

This feature allows critical environments to guarantee they contribute at least a specified proportion (0.0-1.0) to each training batch, ensuring important data sources are always represented during training.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
Dakota 2025-07-07 08:50:28 -05:00
parent 4769eeb4a6
commit 08e14cc745
11 changed files with 1670 additions and 91 deletions

View file

@@ -18,6 +18,7 @@ This service specifically handles the **experience data pathway**:
* Registration endpoints for Trainers and Rollout Handlers.
* Serves batches of aggregated experience data to Trainers.
* Supports heterogeneous environments with weighting (via `/register-env` weight and internal batching).
* Minimum batch allocation support to guarantee certain environments contribute a minimum proportion to each batch.
* Provides status endpoints for monitoring queue size and training step count.
* Basic integration with Weights & Biases (W&B) project/group info.
* Endpoints for Rollout Handlers to disconnect gracefully.
@@ -97,6 +98,8 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
max_token_length: int # Max length this env produces
desired_name: str # Base name for identification/logging
weight: float # Weight for sampling/batching (e.g., 1.0)
group_size: int # Expected number of sequences per data submission
min_batch_allocation: Optional[float] = None # Minimum proportion of batch (0.0-1.0)
```
* **Response:** Provides assigned ID, unique W&B name, checkpoint info.
```json
@@ -135,14 +138,17 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
overrides: Optional[List[dict]] = None # Per-item logging overrides
group_overrides: Optional[dict] = None # Group logging overrides
images: Optional[Any] = None # Image data (if applicable)
env_id: Optional[int] = None # ID of the environment that generated this data
```
* **Response:** `{"status": "received"}`
* **Response:**
* Normal submission: `{"status": "received"}`
* Mixed-size group buffered: `{"status": "buffered", "buffer_size": <sequences_in_buffer>}`
* `POST /scored_data_list`
* **Description:** Endpoint for Rollout Handlers to push a list of `ScoredData` chunks.
* **Request Body:** `List[ScoredData]`
* **Response:** `{"status": "received", "groups_processed": <count>}`
* `GET /batch`
* **Description:** Called by the Trainer to request a batch of data for training. The server uses internal logic (`grab_exact_from_heterogeneous_queue`) to form a batch of the configured size from the available data in the queue, potentially respecting environment weights. The server increments its internal step counter when a batch is successfully formed and returned.
* **Description:** Called by the Trainer to request a batch of data for training. The server uses internal logic to form a batch of the configured size from the available data in the queue. If any environments have minimum batch allocations specified, it uses `grab_batch_with_minimum_allocations` to ensure each environment gets at least its minimum proportion of the batch. Otherwise, it uses `grab_exact_from_heterogeneous_queue` to form batches respecting environment weights. The server increments its internal step counter when a batch is successfully formed and returned.
* **Response:**
* Success: `{"batch": [<data_item_1>, ..., <data_item_N>]}` where each `data_item` matches the structure pushed via `/scored_data`.
* Not enough data: `{"batch": null}`
@@ -178,6 +184,57 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
* mermaid diagram of how a rollout handler interacts with the api is located [here](env_interaction.md).
6. **Shutdown:** Handlers may call `POST /disconnect-env`.
## Minimum Batch Allocation Feature
The API supports ensuring minimum batch allocations for specific environments. This feature is useful when you want to guarantee that certain environments contribute at least a minimum proportion of sequences to each training batch.
### How It Works
1. **Environment Registration**: When registering an environment via `/register-env`, you can specify:
- `min_batch_allocation` (Optional[float]): A value between 0.0 and 1.0 representing the minimum proportion of the batch this environment should contribute
- `group_size` (int): The expected number of sequences per data submission from this environment
2. **Batch Formation**: When the trainer requests a batch via `/batch`:
- If any environment has a `min_batch_allocation` specified, the system uses special logic to ensure minimums are met
- The system attempts to allocate at least `min_batch_allocation * batch_size` sequences from each environment with a minimum
- If the sum of all minimum allocations exceeds 1.0, they are proportionally scaled down
- If an environment with a minimum allocation has no data available, the batch formation fails (returns null)
3. **Mixed-Size Group Handling**: When an environment submits data with a different number of sequences than its declared `group_size`:
- The data is buffered separately for that environment
- The system attempts to combine buffered groups to match the expected `group_size`
- Once combined, the data is added to the main queue
- Response includes `{"status": "buffered", "buffer_size": <sequences_in_buffer>}`
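The allocation arithmetic in step 2 can be sketched in a few lines. This is a simplified model: the real server additionally clamps each environment's minimum to at least one full pack of groups, and `min_sequences_per_env` is an illustrative name, not part of the server's API.

```python
def min_sequences_per_env(batch_size, min_allocations):
    """Turn per-env minimum proportions into sequence counts, scaling
    proportionally when the requested minimums exceed the whole batch."""
    mins = {env: int(batch_size * frac) for env, frac in min_allocations.items()}
    total = sum(mins.values())
    if total > batch_size:  # minimums over-subscribe the batch: scale down
        scale = batch_size / total
        mins = {env: max(1, int(n * scale)) for env, n in mins.items()}
    return mins
```

Two environments each requesting 60% of a 100-sequence batch over-subscribe it (120%) and are scaled down to 50 sequences apiece.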
### Example Configuration
```python
# Environment 1: Requires at least 30% of each batch
{
"max_token_length": 512,
"desired_name": "critical_env",
"weight": 1.0,
"group_size": 4,
"min_batch_allocation": 0.3 # 30% minimum
}
# Environment 2: No minimum requirement
{
"max_token_length": 512,
"desired_name": "standard_env",
"weight": 1.0,
"group_size": 2,
"min_batch_allocation": None # No minimum
}
```
### Important Notes
- Minimum allocations are enforced per batch, not globally
- If minimum allocations cannot be satisfied (e.g., not enough data from a required environment), batch formation fails
- Environments without `min_batch_allocation` fill the remaining batch space after minimums are satisfied
- The feature respects heterogeneous packing constraints when forming batches
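A toy model of the mixed-size buffering flow described above. It is simplified: the real server runs a subset-sum search over the whole buffer, while this sketch only flushes a FIFO prefix, and all names here are illustrative.

```python
def submit_group(buffer, incoming, group_size):
    """Stash the submission, then flush the FIFO prefix of the buffer if its
    sequence counts sum exactly to group_size; return the flushed groups or None."""
    buffer.append(incoming)
    total = 0
    for prefix, item in enumerate(buffer, start=1):
        total += len(item["tokens"])
        if total == group_size:
            flushed = buffer[:prefix]
            del buffer[:prefix]
            return flushed
        if total > group_size:
            break
    return None  # still buffered, mirroring the {"status": "buffered", ...} response

buf = []
first = submit_group(buf, {"tokens": [[1]]}, 4)             # 1 of 4 sequences: buffered
second = submit_group(buf, {"tokens": [[2], [3], [4]]}, 4)  # 1 + 3 == 4: flushed
```

The first call leaves one group in the buffer; the second completes the expected `group_size` of 4, so both groups are flushed together and the buffer empties.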
## Limitations & TODOs
* **In-Memory State:** The primary limitation is that all queues, configurations, and states are stored in the FastAPI application's memory (`app.state`).

View file

@@ -7,7 +7,11 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, field_validator
from atroposlib.api.utils import grab_exact_from_heterogeneous_queue
from atroposlib.api.utils import (
find_groups_summing_to_target,
grab_batch_with_minimum_allocations,
grab_exact_from_heterogeneous_queue,
)
# Message import removed - using Dict[str, Any] for more flexible validation
@@ -42,6 +46,10 @@ class RegisterEnv(BaseModel):
max_token_length: int
desired_name: str
weight: float
group_size: int
min_batch_allocation: Optional[float] = (
None # Minimum proportion of a batch this env should be allocated (0.0-1.0)
)
class EnvIdentifier(BaseModel):
@@ -60,6 +68,7 @@ class ScoredData(BaseModel):
overrides: Optional[List[dict]] = None
group_overrides: Optional[dict] = None
images: Optional[Any] = None
env_id: Optional[int] = None # ID of the environment that generated this data
@field_validator("messages", mode="before")
@classmethod
@@ -115,6 +124,7 @@ async def register(registration: Registration):
app.state.curr_batch = []
app.state.started = False
app.state.envs = []
app.state.buffer = {} # Buffer for mixed-size groups per environment
try:
app.state.requesters.append(uuid.uuid4().int)
except AttributeError:
@@ -157,6 +167,8 @@ async def register_env_url(register_env: RegisterEnv):
"registered_id": registered_id,
"last_update": time.time(),
"connected": True,
"min_batch_allocation": register_env.min_batch_allocation,
"group_size": register_env.group_size,
}
)
return {
@@ -207,14 +219,31 @@ async def get_batch():
return {"batch": app.state.curr_batch.pop()}
else:
new_batches = []
batch, app.state.queue = grab_exact_from_heterogeneous_queue(
app.state.queue, app.state.batchsize
# Check if any envs have minimum allocations
has_min_allocations = any(
env.get("min_batch_allocation") is not None
for env in getattr(app.state, "envs", [])
)
while batch is not None:
new_batches.append(batch)
if has_min_allocations:
batch, app.state.queue = grab_batch_with_minimum_allocations(
app.state.queue, app.state.batchsize, app.state.envs
)
else:
batch, app.state.queue = grab_exact_from_heterogeneous_queue(
app.state.queue, app.state.batchsize
)
while batch is not None:
new_batches.append(batch)
if has_min_allocations:
batch, app.state.queue = grab_batch_with_minimum_allocations(
app.state.queue, app.state.batchsize, app.state.envs
)
else:
batch, app.state.queue = grab_exact_from_heterogeneous_queue(
app.state.queue, app.state.batchsize
)
steps_to_take = len(new_batches)
if steps_to_take == 0:
return {"batch": None}
@@ -224,7 +253,7 @@ async def get_batch():
app.state.curr_batch.append(batch)
curr_batch = app.state.curr_batch.pop()
# check length before sending
print(f"Sending batch of length {sum(len(x['tokens']) for x in curr_batch)}")
print(f"Sending batch of {sum(len(x['tokens']) for x in curr_batch)} sequences")
return {"batch": curr_batch}
@@ -246,20 +275,58 @@ async def get_latest_example():
@app.post("/scored_data")
async def scored_data(scored_data: ScoredData):
app.state.queue.append(
{
"tokens": scored_data.tokens,
"masks": scored_data.masks,
"scores": scored_data.scores,
"advantages": scored_data.advantages,
"ref_logprobs": scored_data.ref_logprobs,
"messages": scored_data.messages,
"overrides": scored_data.overrides,
"group_overrides": scored_data.group_overrides,
"images": scored_data.images,
}
)
app.state.latest = app.state.queue[-1]
data_dict = {
"tokens": scored_data.tokens,
"masks": scored_data.masks,
"scores": scored_data.scores,
"advantages": scored_data.advantages,
"ref_logprobs": scored_data.ref_logprobs,
"messages": scored_data.messages,
"overrides": scored_data.overrides,
"group_overrides": scored_data.group_overrides,
"images": scored_data.images,
"env_id": scored_data.env_id,
}
# Check if this is a mixed-size group
env_id = scored_data.env_id
if env_id is not None and env_id < len(app.state.envs):
expected_group_size = app.state.envs[env_id].get("group_size", 1)
actual_group_size = len(scored_data.tokens)
if actual_group_size != expected_group_size:
# Mixed size group - add to buffer
if env_id not in app.state.buffer:
app.state.buffer[env_id] = []
app.state.buffer[env_id].append(data_dict)
# Try to find groups that sum to expected_group_size
indices = find_groups_summing_to_target(
app.state.buffer[env_id], expected_group_size
)
if indices:
# Add these groups to queue in order
groups_to_add = []
for idx in sorted(indices, reverse=True):
groups_to_add.append(app.state.buffer[env_id].pop(idx))
# Add in FIFO order
for group in reversed(groups_to_add):
app.state.queue.append(group)
app.state.latest = group
return {
"status": "buffered",
"buffer_size": sum(
len(g["tokens"]) for g in app.state.buffer.get(env_id, [])
),
}
# Normal path - correct size or no env info
app.state.queue.append(data_dict)
app.state.latest = data_dict
return {"status": "received"}
@@ -267,24 +334,57 @@ async def scored_data(scored_data: ScoredData):
async def scored_data_list(scored_data_list: List[ScoredData]):
"""Handle a list of ScoredData objects for step-based learning"""
for idx, scored_data in enumerate(scored_data_list):
# Process each scored data item
for scored_data in scored_data_list:
data_dict = {
"tokens": scored_data.tokens,
"masks": scored_data.masks,
"scores": scored_data.scores,
"advantages": scored_data.advantages,
"ref_logprobs": scored_data.ref_logprobs,
"images": scored_data.images,
"messages": scored_data.messages,
"overrides": scored_data.overrides,
"group_overrides": scored_data.group_overrides,
"env_id": scored_data.env_id,
}
app.state.queue.append(
{
"tokens": scored_data.tokens,
"masks": scored_data.masks,
"scores": scored_data.scores,
"advantages": scored_data.advantages,
"ref_logprobs": scored_data.ref_logprobs,
"images": scored_data.images,
"messages": scored_data.messages,
"overrides": scored_data.overrides,
"group_overrides": scored_data.group_overrides,
}
)
# Check if this is a mixed-size group
env_id = scored_data.env_id
if env_id is not None and env_id < len(app.state.envs):
expected_group_size = app.state.envs[env_id].get("group_size", 1)
actual_group_size = len(scored_data.tokens)
if scored_data_list:
app.state.latest = app.state.queue[-1]
if actual_group_size != expected_group_size:
# Mixed size group - add to buffer
if env_id not in app.state.buffer:
app.state.buffer[env_id] = []
app.state.buffer[env_id].append(data_dict)
# Try to find groups that sum to expected_group_size
indices = find_groups_summing_to_target(
app.state.buffer[env_id], expected_group_size
)
if indices:
# Add these groups to queue in order
groups_to_add = []
for idx in sorted(indices, reverse=True):
groups_to_add.append(app.state.buffer[env_id].pop(idx))
# Add in FIFO order
for group in reversed(groups_to_add):
app.state.queue.append(group)
app.state.latest = group
else:
# Normal size - add directly to queue
app.state.queue.append(data_dict)
app.state.latest = data_dict
else:
# No env info or normal path - add directly to queue
app.state.queue.append(data_dict)
app.state.latest = data_dict
return {"status": "received", "groups_processed": len(scored_data_list)}
@@ -309,6 +409,7 @@ async def get_status_env(env: EnvIdentifier):
if x["connected"]
]
)
env_group_size = app.state.envs[env.env_id]["group_size"]
env_weight = (
app.state.envs[env.env_id]["max_context_len"]
* app.state.envs[env.env_id]["weight"]
@@ -318,13 +419,95 @@ async def get_status_env(env: EnvIdentifier):
0.01, env_weight
) # Minimum weight of 0.01 :) TODO: try to figure out a better way to do this
# Calculate total minimum allocations
total_min_allocation = 0.0
for env_config in app.state.envs:
if (
env_config.get("connected", False)
and env_config.get("min_batch_allocation") is not None
):
total_min_allocation += env_config["min_batch_allocation"]
# Calculate unallocated fraction
unallocated_fraction = 1.0 - min(total_min_allocation, 1.0)
# Find the maximum group size across all items in queue
queue = getattr(app.state, "queue", [])
max_group_size = 1
num_self_sequences_in_queue = 0
for item in queue:
group_size = len(item.get("tokens", []))
if group_size > max_group_size:
max_group_size = group_size
if item.get("env_id") == env.env_id:
# update the group size for the requesting env, handle cases where the group size may be dynamic with max
env_group_size = max(env_group_size, group_size)
num_self_sequences_in_queue += group_size
# update the group size for the requesting env
app.state.envs[env.env_id]["group_size"] = env_group_size
# Calculate minimum sequences allocated to each environment
batch_size = getattr(app.state, "batchsize", 0)
min_sequences_by_env = {}
for env_config in app.state.envs:
if (
env_config.get("connected", False)
and env_config.get("min_batch_allocation") is not None
):
env_id = env_config["registered_id"]
min_sequences = int(batch_size * env_config["min_batch_allocation"])
min_sequences_by_env[env_id] = min_sequences
# Count sequences and calculate packed groups for each environment
import math
sequences_by_env = {}
packed_groups_by_env = {}
curr_env_total_sequences = 0
for item in queue:
env_id = item.get("env_id")
seq_count = len(item.get("tokens", []))
# Special handling for the requesting environment
if env_id == env.env_id:
curr_env_total_sequences += seq_count
else:
if env_id not in sequences_by_env:
sequences_by_env[env_id] = 0
sequences_by_env[env_id] += seq_count
# Calculate packed groups for each environment (excluding the requesting env)
if max_group_size > 1:
for env_id, seq_count in sequences_by_env.items():
packed_groups_by_env[env_id] = math.ceil(seq_count / max_group_size)
# Calculate adjusted queue size
# (curr_env_total_sequences + sum of available sequences from other envs after their minimums)
available_from_others = 0
for env_id in packed_groups_by_env:
packed_sequences = packed_groups_by_env[env_id] * max_group_size
min_sequences = min_sequences_by_env.get(env_id, 0)
available_from_others += max(0, packed_sequences - min_sequences)
env_queue_size = curr_env_total_sequences + available_from_others
try:
ret_dict = {
"current_step": app.state.status_dict["step"],
"queue_size": len(app.state.queue),
"queue_size": env_queue_size // env_group_size,
"unallocated_fraction": unallocated_fraction,
"self_queue_size": num_self_sequences_in_queue // env_group_size,
"max_group_size": max_group_size,
}
except AttributeError:
ret_dict = {"current_step": 0, "queue_size": 0}
ret_dict = {
    "current_step": 0,
    "queue_size": 0,
    "unallocated_fraction": 1.0,
    "self_queue_size": 0,
    "max_group_size": 1,
}
ret_dict["env_weight"] = env_weight
return ret_dict
@@ -342,6 +525,7 @@ async def reset_data():
app.state.started = False
app.state.requesters = []
app.state.envs = []
app.state.buffer = {}
except KeyError:
pass
return PlainTextResponse("Reset successful", status_code=status.HTTP_200_OK)

View file

@@ -1,47 +1,104 @@
from typing import Dict, List, Optional, Tuple
def find_groups_summing_to_target(buffer: List[Dict], target_size: int) -> List[int]:
"""
Find indices of groups in buffer that sum exactly to target_size.
Prioritizes FIFO order.
:param buffer: Buffer of groups from same env
:param target_size: Target sum of group sizes
:return: List of indices that sum to target_size, or empty list if impossible
"""
if not buffer:
return []
# First try simple FIFO
current_sum = 0
indices = []
for i, group in enumerate(buffer):
size = len(group["tokens"])
if current_sum + size <= target_size:
indices.append(i)
current_sum += size
if current_sum == target_size:
return indices
# If FIFO doesn't work exactly, try dynamic programming
# to find any valid combination (still preferring earlier indices)
n = len(buffer)
sizes = [len(g["tokens"]) for g in buffer]
# dp[i][j] = can we make sum j using first i groups
dp = [[False] * (target_size + 1) for _ in range(n + 1)]
dp[0][0] = True
for i in range(1, n + 1):
for j in range(target_size + 1):
# Don't take group i-1
dp[i][j] = dp[i - 1][j]
# Take group i-1 if possible
if j >= sizes[i - 1]:
dp[i][j] = dp[i][j] or dp[i - 1][j - sizes[i - 1]]
if not dp[n][target_size]:
return []
# Backtrack to find indices, preferring earlier ones
indices = []
j = target_size
for i in range(n, 0, -1):
if j >= sizes[i - 1] and dp[i - 1][j - sizes[i - 1]]:
indices.append(i - 1)
j -= sizes[i - 1]
return sorted(indices) # Return in FIFO order
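The DP fallback above is a standard subset-sum over group sizes. A self-contained sketch of just that search, operating on raw sizes rather than buffer dicts (`subset_summing_to` is an illustrative name):

```python
def subset_summing_to(sizes, target):
    """Return indices of `sizes` summing exactly to `target`, or [] if impossible.
    Backtracking keeps the FIFO preference of the buffer logic."""
    n = len(sizes)
    # dp[i][j]: can the first i sizes form a subset summing to j?
    dp = [[False] * (target + 1) for _ in range(n + 1)]
    dp[0][0] = True
    for i in range(1, n + 1):
        for j in range(target + 1):
            dp[i][j] = dp[i - 1][j] or (
                j >= sizes[i - 1] and dp[i - 1][j - sizes[i - 1]]
            )
    if not dp[n][target]:
        return []
    indices, j = [], target
    for i in range(n, 0, -1):
        if j >= sizes[i - 1] and dp[i - 1][j - sizes[i - 1]]:
            indices.append(i - 1)
            j -= sizes[i - 1]
    return sorted(indices)
```

With buffered groups of 3, 1, and 2 sequences and a target `group_size` of 4, the search selects the first two groups; two groups of 3 can never reach 4, so they stay buffered.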
def grab_exact_from_heterogeneous_queue(
queue: List[Dict[str, List]], batch_size: int
) -> Tuple[Optional[List], List]:
"""
Grabs a batch of size batchsize from a queue of different sized items
Grabs a batch of exactly batch_size sequences from a queue of items with different group sizes.
Each item in the queue has a 'tokens' field containing a list of sequences.
e.g. queue = [{"tokens": [[1, 2, 3],[4, 5, 6, 7, 8]]}, {"tokens": [[9, 10]]}]
where the first item has 2 sequences and the second has 1 sequence.
without going over the batchsize. This function will return a batch of size batchsize, and the new queue.
This function returns a batch containing exactly batch_size sequences total, and the remaining queue.
Because every group size evenly divides the batch size and is a power of 2, we can simplify
by packing smaller groups together into packs equal to the maximum group size.
Note that we cannot drop items from groups, so we must grab the entire group if we grab it.
Note that we cannot split items, so we must take the entire item with all its sequences.
There may be a more efficient clearing mechanism by grouping these smaller groups heterogeneously, but
forcing them all into powers of two groups is a simple way to ensure we can grab a batch of the correct size.
:param queue:
:param batch_size:
:param queue: List of items, each with a 'tokens' field containing sequences
:param batch_size: Target number of sequences for the batch
:return: batch, new_queue
"""
# Pass 1: precompute group sizes, total tokens and early exit if not enough tokens.
# Pass 1: precompute group sizes, total sequences and early exit if not enough sequences.
total_groups = len(queue)
if total_groups == 0:
return None, queue
group_sizes = []
lengths = []
total_tokens = 0
total_sequences = 0
max_group_size = 0
for item in queue:
length = len(item["tokens"])
length = len(item["tokens"]) # Number of sequences in this group
lengths.append(length)
group_sizes.append(length)
total_tokens += length
total_sequences += length
if length > max_group_size:
max_group_size = length
if total_tokens < batch_size:
if total_sequences < batch_size:
return None, queue
group_sizes_set = set(group_sizes)
@@ -55,30 +112,168 @@ def grab_exact_from_heterogeneous_queue(
potential_batch_indices.extend(group_batching_storage[group_size])
group_batching_storage[group_size].clear() # much faster than = []
# Calculate total batch tokens only once (avoid repeated sums)
potential_batch_token_total = sum(lengths[i] for i in potential_batch_indices)
if potential_batch_token_total < batch_size:
# Calculate total sequences in potential batch only once (avoid repeated sums)
potential_batch_sequences_total = sum(lengths[i] for i in potential_batch_indices)
if potential_batch_sequences_total < batch_size:
return None, queue
# Batch selection
batch = []
batch_indices = []
running_tokens = 0
running_seqs = 0
for idx in potential_batch_indices:
group = queue[idx]
batch.append(group)
batch_indices.append(idx)
running_tokens += lengths[idx]
if running_tokens == batch_size:
running_seqs += lengths[idx]
if running_seqs == batch_size:
break
elif running_tokens > batch_size:
elif running_seqs > batch_size:
# Should never happen due to problem constraints, but sanity check
return None, queue
if running_tokens != batch_size:
if running_seqs != batch_size:
return None, queue
# Construct new_queue with a single pass, using a set for O(1) lookup
batch_indices_set = set(batch_indices)
new_queue = [item for i, item in enumerate(queue) if i not in batch_indices_set]
return batch, new_queue
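For readers without `atroposlib` on hand, here is a deliberately simplified, self-contained sketch of the contract this function provides: a plain FIFO greedy without the power-of-two pack handling above (`grab_exact_sketch` is an illustrative name).

```python
def grab_exact_sketch(queue, batch_size):
    """Take whole groups in FIFO order, skipping any group that would
    overshoot, until exactly batch_size sequences are collected."""
    taken, seqs = set(), 0
    for i, item in enumerate(queue):
        size = len(item["tokens"])
        if seqs + size <= batch_size:
            taken.add(i)
            seqs += size
        if seqs == batch_size:
            break
    if seqs != batch_size:
        return None, queue  # can't form a full batch; leave the queue untouched
    batch = [queue[i] for i in sorted(taken)]
    return batch, [it for i, it in enumerate(queue) if i not in taken]

queue = [
    {"tokens": [[1], [2]]},            # group of 2 sequences
    {"tokens": [[3], [4], [5], [6]]},  # group of 4 sequences
    {"tokens": [[7], [8]]},            # group of 2 sequences
]
batch, rest = grab_exact_sketch(queue, 4)
```

The 4-sequence group would overshoot after the first group is taken, so the batch is formed from the two 2-sequence groups and the queue retains the group of 4; a request for 5 sequences fails outright and returns the queue unchanged.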
def grab_batch_with_minimum_allocations(
queue: List[Dict], batch_size: int, env_configs: List[Dict]
) -> Tuple[Optional[List], List]:
"""
Grabs a batch from the queue while respecting minimum allocation requirements for environments.
This function works with groups where each group contains multiple sequences.
:param queue: List of groups with env_id field and sequences (stored in 'tokens' field)
:param batch_size: Target batch size in sequences
:param env_configs: List of environment configs with min_batch_allocation field
:return: batch, new_queue
"""
if not queue:
return None, queue
# Build env_id to min allocation mapping
env_min_allocations = {}
for env in env_configs:
if env.get("connected", False) and env.get("min_batch_allocation") is not None:
env_min_allocations[env["registered_id"]] = env["min_batch_allocation"]
# If no minimum allocations, fall back to original function
if not env_min_allocations:
return grab_exact_from_heterogeneous_queue(queue, batch_size)
# First, find the maximum group size across all items
max_group_size = 0
for item in queue:
group_size = len(item.get("tokens", []))
if group_size > max_group_size:
max_group_size = group_size
# Group queue items by env_id and calculate which can form complete packs
items_by_env = {}
packable_items_by_env = {}
for i, item in enumerate(queue):
env_id = item.get("env_id")
group_size = len(item.get("tokens", []))
if env_id is not None:
if env_id not in items_by_env:
items_by_env[env_id] = {}
packable_items_by_env[env_id] = []
if group_size not in items_by_env[env_id]:
items_by_env[env_id][group_size] = []
items_by_env[env_id][group_size].append((i, item, group_size))
# Check if we can form a complete pack
items_of_size = items_by_env[env_id][group_size]
if len(items_of_size) * group_size == max_group_size:
# We have a complete pack!
packable_items_by_env[env_id].extend(items_of_size)
items_by_env[env_id][group_size] = []
# Calculate minimum sequences needed per env
min_sequences_per_env = {}
total_min_sequences = 0
for env_id, min_proportion in env_min_allocations.items():
min_sequences = int(batch_size * min_proportion)
if min_sequences > 0:
# Check if this env has any items in the queue at all
if env_id not in items_by_env:
# This env has a minimum but no items - can't satisfy minimum
return None, queue
# Check if this env has any packable items
if env_id not in packable_items_by_env or not packable_items_by_env[env_id]:
# This env has items but no packable items - can't satisfy minimum
return None, queue
min_sequences_per_env[env_id] = min_sequences
total_min_sequences += min_sequences
# If minimums exceed batch size, scale them down proportionally
if total_min_sequences > batch_size:
scale_factor = batch_size / total_min_sequences
for env_id in min_sequences_per_env:
# Ensure at least one pack from each env with minimum
if packable_items_by_env.get(env_id):
min_group_size = min(g[2] for g in packable_items_by_env[env_id])
min_sequences_per_env[env_id] = max(
min_group_size,
int(min_sequences_per_env[env_id] * scale_factor),
)
# Build batch ensuring minimums are met
batch = []
batch_indices = []
sequences_taken_per_env = {env_id: 0 for env_id in packable_items_by_env}
total_sequences = 0
# First pass: satisfy minimum requirements using packable items
for env_id, min_sequences in min_sequences_per_env.items():
if env_id in packable_items_by_env:
# Take packable items in order (FIFO)
for idx, item, group_size in packable_items_by_env[env_id]:
if sequences_taken_per_env[env_id] >= min_sequences:
break
if total_sequences + group_size <= batch_size:
batch.append(item)
batch_indices.append(idx)
sequences_taken_per_env[env_id] += group_size
total_sequences += group_size
# Second pass: fill remaining slots with packable items from any env
if total_sequences < batch_size:
# Collect all remaining packable items in queue order
all_packable = []
for i, item in enumerate(queue):
if i not in batch_indices:
# Check if this item is in any env's packable list
env_id = item.get("env_id")
if env_id in packable_items_by_env:
for idx, packable_item, size in packable_items_by_env[env_id]:
if idx == i:
all_packable.append((i, item, size))
break
# Take packable items in queue order
for idx, item, group_size in all_packable:
if total_sequences + group_size <= batch_size:
batch.append(item)
batch_indices.append(idx)
total_sequences += group_size
if total_sequences == batch_size:
break
# If we couldn't form a full batch, return None
if total_sequences != batch_size:
return None, queue
# Construct new queue
batch_indices_set = set(batch_indices)
new_queue = [item for i, item in enumerate(queue) if i not in batch_indices_set]
return batch, new_queue
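A toy two-pass version of this selection logic, as a sketch only: it ignores the pack-alignment bookkeeping above and just illustrates "satisfy minimums first, then fill FIFO" (`grab_with_minimums_sketch` and its `minimums` mapping of env_id to sequence count are illustrative, not the module's API).

```python
def grab_with_minimums_sketch(queue, batch_size, minimums):
    """Pass 1 reserves each env's minimum sequence count; pass 2 fills the
    remainder FIFO from any env. Returns (batch, new_queue) or (None, queue)."""
    taken, seqs = set(), 0
    reserved = {env: 0 for env in minimums}
    for i, item in enumerate(queue):  # pass 1: satisfy minimums
        env, size = item.get("env_id"), len(item["tokens"])
        if env in minimums and reserved[env] < minimums[env] and seqs + size <= batch_size:
            taken.add(i)
            reserved[env] += size
            seqs += size
    for i, item in enumerate(queue):  # pass 2: fill the rest in FIFO order
        size = len(item["tokens"])
        if i not in taken and seqs + size <= batch_size:
            taken.add(i)
            seqs += size
        if seqs == batch_size:
            break
    if seqs != batch_size:
        return None, queue
    return [queue[i] for i in sorted(taken)], [
        it for i, it in enumerate(queue) if i not in taken
    ]

queue = [
    {"env_id": 0, "tokens": [[1], [2]]},
    {"env_id": 0, "tokens": [[3], [4]]},
    {"env_id": 1, "tokens": [[5], [6]]},
]
batch, rest = grab_with_minimums_sketch(queue, 4, {1: 2})
```

Even though env 1's group sits last in the queue, its 2-sequence minimum is reserved first, and the remaining 2 slots go to env 0's earliest group.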