mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-29 17:35:07 +00:00

Merge pull request #204 from NousResearch/multienv-enforce-mins

Multienv with enforced minimum samples in a batch

This commit is contained in: commit 58446dbcb1

11 changed files with 1670 additions and 91 deletions
@@ -18,6 +18,7 @@ This service specifically handles the **experience data pathway**:
 * Registration endpoints for Trainers and Rollout Handlers.
 * Serves batches of aggregated experience data to Trainers.
 * Supports heterogeneous environments with weighting (via `/register-env` weight and internal batching).
+* Minimum batch allocation support to guarantee certain environments contribute a minimum proportion to each batch.
 * Provides status endpoints for monitoring queue size and training step count.
 * Basic integration with Weights & Biases (W&B) project/group info.
 * Endpoints for Rollout Handlers to disconnect gracefully.
@@ -97,6 +98,8 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
     max_token_length: int  # Max length this env produces
     desired_name: str      # Base name for identification/logging
     weight: float          # Weight for sampling/batching (e.g., 1.0)
+    group_size: int        # Expected number of sequences per data submission
+    min_batch_allocation: Optional[float] = None  # Minimum proportion of batch (0.0-1.0)
     ```
 * **Response:** Provides assigned ID, unique W&B name, checkpoint info.
     ```json
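For reference, a registration payload with the fields above can be assembled client-side as follows. This is an illustrative sketch only: `build_registration` is not part of the API, and the range check on `min_batch_allocation` is an assumed client-side convenience.

```python
from typing import Any, Dict, Optional

def build_registration(
    max_token_length: int,
    desired_name: str,
    weight: float = 1.0,
    group_size: int = 1,
    min_batch_allocation: Optional[float] = None,
) -> Dict[str, Any]:
    """Assemble a /register-env request body (illustrative helper)."""
    # The docs state min_batch_allocation is a proportion between 0.0 and 1.0.
    if min_batch_allocation is not None and not 0.0 <= min_batch_allocation <= 1.0:
        raise ValueError("min_batch_allocation must be between 0.0 and 1.0")
    return {
        "max_token_length": max_token_length,
        "desired_name": desired_name,
        "weight": weight,
        "group_size": group_size,
        "min_batch_allocation": min_batch_allocation,
    }

payload = build_registration(
    512, "critical_env", weight=1.0, group_size=4, min_batch_allocation=0.3
)
```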
@@ -135,14 +138,17 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
     overrides: Optional[List[dict]] = None  # Per-item logging overrides
     group_overrides: Optional[dict] = None  # Group logging overrides
     images: Optional[Any] = None            # Image data (if applicable)
+    env_id: Optional[int] = None            # ID of the environment that generated this data
     ```
-  * **Response:** `{"status": "received"}`
+  * **Response:**
+    * Normal submission: `{"status": "received"}`
+    * Mixed-size group buffered: `{"status": "buffered", "buffer_size": <sequences_in_buffer>}`
 * `POST /scored_data_list`
   * **Description:** Endpoint for Rollout Handlers to push a list of `ScoredData` chunks.
   * **Request Body:** `List[ScoredData]`
   * **Response:** `{"status": "received", "groups_processed": <count>}`
 * `GET /batch`
-  * **Description:** Called by the Trainer to request a batch of data for training. The server uses internal logic (`grab_exact_from_heterogeneous_queue`) to form a batch of the configured size from the available data in the queue, potentially respecting environment weights. The server increments its internal step counter when a batch is successfully formed and returned.
+  * **Description:** Called by the Trainer to request a batch of data for training. The server uses internal logic to form a batch of the configured size from the available data in the queue. If any environments have minimum batch allocations specified, it uses `grab_batch_with_minimum_allocations` to ensure each environment gets at least its minimum proportion of the batch. Otherwise, it uses `grab_exact_from_heterogeneous_queue` to form batches respecting environment weights. The server increments its internal step counter when a batch is successfully formed and returned.
   * **Response:**
     * Success: `{"batch": [<data_item_1>, ..., <data_item_N>]}` where each `data_item` matches the structure pushed via `/scored_data`.
     * Not enough data: `{"batch": null}`
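On the trainer side, `{"batch": null}` simply means "not enough data yet; try again". A minimal polling sketch, where the `fetch_batch` callable, attempt count, and wait time are illustrative assumptions rather than part of the API:

```python
import time
from typing import Any, Callable, Dict, List, Optional

def poll_for_batch(
    fetch_batch: Callable[[], Dict[str, Any]],
    max_attempts: int = 5,
    wait_seconds: float = 0.0,
) -> Optional[List[Any]]:
    """Poll GET /batch until data is available; {"batch": null} means retry."""
    for _ in range(max_attempts):
        response = fetch_batch()  # stand-in for an HTTP GET to /batch
        if response.get("batch") is not None:
            return response["batch"]
        time.sleep(wait_seconds)  # back off before asking again
    return None

# Simulated server: not enough data on the first poll, then a batch appears.
responses = iter([{"batch": None}, {"batch": [{"tokens": [[1, 2]]}]}])
batch = poll_for_batch(lambda: next(responses))
```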
@@ -178,6 +184,57 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
 * mermaid diagram of how a rollout handler interacts with the api is located [here](env_interaction.md).
 6. **Shutdown:** Handlers may call `POST /disconnect-env`.
 
+## Minimum Batch Allocation Feature
+
+The API supports ensuring minimum batch allocations for specific environments. This feature is useful when you want to guarantee that certain environments contribute at least a minimum proportion of sequences to each training batch.
+
+### How It Works
+
+1. **Environment Registration**: When registering an environment via `/register-env`, you can specify:
+   - `min_batch_allocation` (Optional[float]): A value between 0.0 and 1.0 representing the minimum proportion of the batch this environment should contribute
+   - `group_size` (int): The expected number of sequences per data submission from this environment
+
+2. **Batch Formation**: When the trainer requests a batch via `/batch`:
+   - If any environment has a `min_batch_allocation` specified, the system uses special logic to ensure minimums are met
+   - The system attempts to allocate at least `min_batch_allocation * batch_size` sequences from each environment with a minimum
+   - If the sum of all minimum allocations exceeds 1.0, they are proportionally scaled down
+   - If an environment with a minimum allocation has no data available, the batch formation fails (returns null)
+
+3. **Mixed-Size Group Handling**: When an environment submits data with a different number of sequences than its declared `group_size`:
+   - The data is buffered separately for that environment
+   - The system attempts to combine buffered groups to match the expected `group_size`
+   - Once combined, the data is added to the main queue
+   - The response includes `{"status": "buffered", "buffer_size": <sequences_in_buffer>}`
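The batch-formation arithmetic described above can be checked with a small standalone sketch. `minimum_sequences` is an illustrative name, not an API function; it only mirrors the `int(batch_size * proportion)` and proportional scale-down rules stated in the docs:

```python
from typing import Dict

def minimum_sequences(batch_size: int, min_allocations: Dict[int, float]) -> Dict[int, int]:
    """Per-env minimum sequence counts; scaled down proportionally if they over-subscribe the batch."""
    mins = {env: int(batch_size * prop) for env, prop in min_allocations.items()}
    total = sum(mins.values())
    if total > batch_size:  # sum of minimums exceeds the batch: scale proportionally
        scale = batch_size / total
        mins = {env: int(count * scale) for env, count in mins.items()}
    return mins

# Batch of 16 sequences: env 0 guaranteed 30%, env 1 guaranteed 20%.
print(minimum_sequences(16, {0: 0.3, 1: 0.2}))  # {0: 4, 1: 3}
# Over-subscribed minimums (80% + 60%) get scaled back to fit.
print(minimum_sequences(10, {0: 0.8, 1: 0.6}))  # {0: 5, 1: 4}
```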
+### Example Configuration
+
+```python
+# Environment 1: Requires at least 30% of each batch
+{
+    "max_token_length": 512,
+    "desired_name": "critical_env",
+    "weight": 1.0,
+    "group_size": 4,
+    "min_batch_allocation": 0.3  # 30% minimum
+}
+
+# Environment 2: No minimum requirement
+{
+    "max_token_length": 512,
+    "desired_name": "standard_env",
+    "weight": 1.0,
+    "group_size": 2,
+    "min_batch_allocation": None  # No minimum
+}
+```
+
+### Important Notes
+
+- Minimum allocations are enforced per batch, not globally
+- If minimum allocations cannot be satisfied (e.g., not enough data from a required environment), batch formation fails
+- Environments without `min_batch_allocation` fill the remaining batch space after minimums are satisfied
+- The feature respects heterogeneous packing constraints when forming batches
+
 ## Limitations & TODOs
 
 * **In-Memory State:** The primary limitation is that all queues, configurations, and states are stored in the FastAPI application's memory (`app.state`).
@@ -7,7 +7,11 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import PlainTextResponse
 from pydantic import BaseModel, field_validator
 
-from atroposlib.api.utils import grab_exact_from_heterogeneous_queue
+from atroposlib.api.utils import (
+    find_groups_summing_to_target,
+    grab_batch_with_minimum_allocations,
+    grab_exact_from_heterogeneous_queue,
+)
 
 # Message import removed - using Dict[str, Any] for more flexible validation
@@ -42,6 +46,10 @@ class RegisterEnv(BaseModel):
     max_token_length: int
     desired_name: str
     weight: float
+    group_size: int
+    min_batch_allocation: Optional[float] = (
+        None  # Minimum proportion of a batch this env should be allocated (0.0-1.0)
+    )
 
 
 class EnvIdentifier(BaseModel):
@@ -60,6 +68,7 @@ class ScoredData(BaseModel):
     overrides: Optional[List[dict]] = None
     group_overrides: Optional[dict] = None
     images: Optional[Any] = None
+    env_id: Optional[int] = None  # ID of the environment that generated this data
 
     @field_validator("messages", mode="before")
     @classmethod
@@ -115,6 +124,7 @@ async def register(registration: Registration):
     app.state.curr_batch = []
     app.state.started = False
     app.state.envs = []
+    app.state.buffer = {}  # Buffer for mixed-size groups per environment
     try:
         app.state.requesters.append(uuid.uuid4().int)
     except AttributeError:
@@ -157,6 +167,8 @@ async def register_env_url(register_env: RegisterEnv):
             "registered_id": registered_id,
             "last_update": time.time(),
             "connected": True,
+            "min_batch_allocation": register_env.min_batch_allocation,
+            "group_size": register_env.group_size,
         }
     )
     return {
@@ -207,14 +219,31 @@ async def get_batch():
         return {"batch": app.state.curr_batch.pop()}
     else:
         new_batches = []
-        batch, app.state.queue = grab_exact_from_heterogeneous_queue(
-            app.state.queue, app.state.batchsize
-        )
-        while batch is not None:
-            new_batches.append(batch)
+        # Check if any envs have minimum allocations
+        has_min_allocations = any(
+            env.get("min_batch_allocation") is not None
+            for env in getattr(app.state, "envs", [])
+        )
+
+        if has_min_allocations:
+            batch, app.state.queue = grab_batch_with_minimum_allocations(
+                app.state.queue, app.state.batchsize, app.state.envs
+            )
+        else:
             batch, app.state.queue = grab_exact_from_heterogeneous_queue(
                 app.state.queue, app.state.batchsize
             )
+
+        while batch is not None:
+            new_batches.append(batch)
+            if has_min_allocations:
+                batch, app.state.queue = grab_batch_with_minimum_allocations(
+                    app.state.queue, app.state.batchsize, app.state.envs
+                )
+            else:
+                batch, app.state.queue = grab_exact_from_heterogeneous_queue(
+                    app.state.queue, app.state.batchsize
+                )
         steps_to_take = len(new_batches)
         if steps_to_take == 0:
            return {"batch": None}
|
@ -224,7 +253,7 @@ async def get_batch():
|
|||
app.state.curr_batch.append(batch)
|
||||
curr_batch = app.state.curr_batch.pop()
|
||||
# check length before sending
|
||||
print(f"Sending batch of length {sum(len(x['tokens']) for x in curr_batch)}")
|
||||
print(f"Sending batch of {sum(len(x['tokens']) for x in curr_batch)} sequences")
|
||||
return {"batch": curr_batch}
|
||||
|
||||
|
||||
|
|
@@ -246,20 +275,58 @@ async def get_latest_example():
 
 @app.post("/scored_data")
 async def scored_data(scored_data: ScoredData):
-    app.state.queue.append(
-        {
-            "tokens": scored_data.tokens,
-            "masks": scored_data.masks,
-            "scores": scored_data.scores,
-            "advantages": scored_data.advantages,
-            "ref_logprobs": scored_data.ref_logprobs,
-            "messages": scored_data.messages,
-            "overrides": scored_data.overrides,
-            "group_overrides": scored_data.group_overrides,
-            "images": scored_data.images,
-        }
-    )
-    app.state.latest = app.state.queue[-1]
+    data_dict = {
+        "tokens": scored_data.tokens,
+        "masks": scored_data.masks,
+        "scores": scored_data.scores,
+        "advantages": scored_data.advantages,
+        "ref_logprobs": scored_data.ref_logprobs,
+        "messages": scored_data.messages,
+        "overrides": scored_data.overrides,
+        "group_overrides": scored_data.group_overrides,
+        "images": scored_data.images,
+        "env_id": scored_data.env_id,
+    }
+
+    # Check if this is a mixed-size group
+    env_id = scored_data.env_id
+    if env_id is not None and env_id < len(app.state.envs):
+        expected_group_size = app.state.envs[env_id].get("group_size", 1)
+        actual_group_size = len(scored_data.tokens)
+
+        if actual_group_size != expected_group_size:
+            # Mixed size group - add to buffer
+            if env_id not in app.state.buffer:
+                app.state.buffer[env_id] = []
+
+            app.state.buffer[env_id].append(data_dict)
+
+            # Try to find groups that sum to expected_group_size
+            indices = find_groups_summing_to_target(
+                app.state.buffer[env_id], expected_group_size
+            )
+
+            if indices:
+                # Add these groups to queue in order
+                groups_to_add = []
+                for idx in sorted(indices, reverse=True):
+                    groups_to_add.append(app.state.buffer[env_id].pop(idx))
+
+                # Add in FIFO order
+                for group in reversed(groups_to_add):
+                    app.state.queue.append(group)
+                    app.state.latest = group
+
+            return {
+                "status": "buffered",
+                "buffer_size": sum(
+                    len(g["tokens"]) for g in app.state.buffer.get(env_id, [])
+                ),
+            }
+
+    # Normal path - correct size or no env info
+    app.state.queue.append(data_dict)
+    app.state.latest = data_dict
+    return {"status": "received"}
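The buffering path above can be illustrated with a toy model. The `submit` helper and the plain-list `buffer`/`queue` are illustrative stand-ins for the endpoint and `app.state`, and only the simple FIFO flush is modeled (the committed code also handles non-FIFO combinations via `find_groups_summing_to_target`):

```python
from typing import List

def submit(buffer: List[dict], queue: List[dict], group: dict, expected_group_size: int) -> dict:
    """Toy /scored_data: buffer mixed-size groups, flush once sizes sum to the expected size."""
    if len(group["tokens"]) == expected_group_size:
        queue.append(group)  # normal path
        return {"status": "received"}
    buffer.append(group)  # mixed-size: hold it back
    # FIFO-only check: flush when the buffered sequence count matches the expected size.
    if sum(len(g["tokens"]) for g in buffer) == expected_group_size:
        queue.extend(buffer)
        buffer.clear()
        return {"status": "buffered", "buffer_size": 0}
    return {"status": "buffered", "buffer_size": sum(len(g["tokens"]) for g in buffer)}

buffer, queue = [], []
r1 = submit(buffer, queue, {"tokens": [[1], [2]]}, expected_group_size=4)  # 2 of 4: buffered
r2 = submit(buffer, queue, {"tokens": [[3], [4]]}, expected_group_size=4)  # completes the group
```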
@@ -267,24 +334,57 @@ async def scored_data(scored_data: ScoredData):
 async def scored_data_list(scored_data_list: List[ScoredData]):
     """Handle a list of ScoredData objects for step-based learning"""
 
-    for idx, scored_data in enumerate(scored_data_list):
-        app.state.queue.append(
-            {
-                "tokens": scored_data.tokens,
-                "masks": scored_data.masks,
-                "scores": scored_data.scores,
-                "advantages": scored_data.advantages,
-                "ref_logprobs": scored_data.ref_logprobs,
-                "images": scored_data.images,
-                "messages": scored_data.messages,
-                "overrides": scored_data.overrides,
-                "group_overrides": scored_data.group_overrides,
-            }
-        )
-
-    if scored_data_list:
-        app.state.latest = app.state.queue[-1]
+    # Process each scored data item
+    for scored_data in scored_data_list:
+        data_dict = {
+            "tokens": scored_data.tokens,
+            "masks": scored_data.masks,
+            "scores": scored_data.scores,
+            "advantages": scored_data.advantages,
+            "ref_logprobs": scored_data.ref_logprobs,
+            "images": scored_data.images,
+            "messages": scored_data.messages,
+            "overrides": scored_data.overrides,
+            "group_overrides": scored_data.group_overrides,
+            "env_id": scored_data.env_id,
+        }
+
+        # Check if this is a mixed-size group
+        env_id = scored_data.env_id
+        if env_id is not None and env_id < len(app.state.envs):
+            expected_group_size = app.state.envs[env_id].get("group_size", 1)
+            actual_group_size = len(scored_data.tokens)
+
+            if actual_group_size != expected_group_size:
+                # Mixed size group - add to buffer
+                if env_id not in app.state.buffer:
+                    app.state.buffer[env_id] = []
+
+                app.state.buffer[env_id].append(data_dict)
+
+                # Try to find groups that sum to expected_group_size
+                indices = find_groups_summing_to_target(
+                    app.state.buffer[env_id], expected_group_size
+                )
+
+                if indices:
+                    # Add these groups to queue in order
+                    groups_to_add = []
+                    for idx in sorted(indices, reverse=True):
+                        groups_to_add.append(app.state.buffer[env_id].pop(idx))
+
+                    # Add in FIFO order
+                    for group in reversed(groups_to_add):
+                        app.state.queue.append(group)
+                        app.state.latest = group
+            else:
+                # Normal size - add directly to queue
+                app.state.queue.append(data_dict)
+                app.state.latest = data_dict
+        else:
+            # No env info or normal path - add directly to queue
+            app.state.queue.append(data_dict)
+            app.state.latest = data_dict
 
     return {"status": "received", "groups_processed": len(scored_data_list)}
@@ -309,6 +409,7 @@ async def get_status_env(env: EnvIdentifier):
             if x["connected"]
         ]
     )
+    env_group_size = app.state.envs[env.env_id]["group_size"]
    env_weight = (
        app.state.envs[env.env_id]["max_context_len"]
        * app.state.envs[env.env_id]["weight"]
@@ -318,13 +419,95 @@ async def get_status_env(env: EnvIdentifier):
         0.01, env_weight
     )  # Minimum weight of 0.01 :) TODO: try to figure out a better way to do this
+
+    # Calculate total minimum allocations
+    total_min_allocation = 0.0
+    for env_config in app.state.envs:
+        if (
+            env_config.get("connected", False)
+            and env_config.get("min_batch_allocation") is not None
+        ):
+            total_min_allocation += env_config["min_batch_allocation"]
+
+    # Calculate unallocated fraction
+    unallocated_fraction = 1.0 - min(total_min_allocation, 1.0)
+
+    # Find the maximum group size across all items in queue
+    queue = getattr(app.state, "queue", [])
+    max_group_size = 1
+    num_self_sequences_in_queue = 0
+    for item in queue:
+        group_size = len(item.get("tokens", []))
+        if group_size > max_group_size:
+            max_group_size = group_size
+        if item.get("env_id") == env.env_id:
+            # update the group size for the requesting env, handle cases where the group size may be dynamic with max
+            env_group_size = max(env_group_size, group_size)
+            num_self_sequences_in_queue += group_size
+
+    # update the group size for the requesting env
+    app.state.envs[env.env_id]["group_size"] = env_group_size
+
+    # Calculate minimum sequences allocated to each environment
+    batch_size = getattr(app.state, "batchsize", 0)
+    min_sequences_by_env = {}
+    for env_config in app.state.envs:
+        if (
+            env_config.get("connected", False)
+            and env_config.get("min_batch_allocation") is not None
+        ):
+            env_id = env_config["registered_id"]
+            min_sequences = int(batch_size * env_config["min_batch_allocation"])
+            min_sequences_by_env[env_id] = min_sequences
+
+    # Count sequences and calculate packed groups for each environment
+    import math
+
+    sequences_by_env = {}
+    packed_groups_by_env = {}
+    curr_env_total_sequences = 0
+
+    for item in queue:
+        env_id = item.get("env_id")
+        seq_count = len(item.get("tokens", []))
+
+        # Special handling for the requesting environment
+        if env_id == env.env_id:
+            curr_env_total_sequences += seq_count
+        else:
+            if env_id not in sequences_by_env:
+                sequences_by_env[env_id] = 0
+            sequences_by_env[env_id] += seq_count
+
+    # Calculate packed groups for each environment (excluding the requesting env)
+    if max_group_size > 1:
+        for env_id, seq_count in sequences_by_env.items():
+            packed_groups_by_env[env_id] = math.ceil(seq_count / max_group_size)
+
+    # Calculate adjusted queue size
+    # (curr_env_total_sequences + sum of available sequences from other envs after their minimums)
+    available_from_others = 0
+    for env_id in packed_groups_by_env:
+        packed_sequences = packed_groups_by_env[env_id] * max_group_size
+        min_sequences = min_sequences_by_env.get(env_id, 0)
+        available_from_others += max(0, packed_sequences - min_sequences)
+
+    env_queue_size = curr_env_total_sequences + available_from_others
+
     try:
         ret_dict = {
             "current_step": app.state.status_dict["step"],
-            "queue_size": len(app.state.queue),
+            "queue_size": env_queue_size // env_group_size,
+            "unallocated_fraction": unallocated_fraction,
+            "self_queue_size": num_self_sequences_in_queue // env_group_size,
+            "max_group_size": max_group_size,
         }
     except AttributeError:
-        ret_dict = {"current_step": 0, "queue_size": 0}
+        ret_dict = {
+            "current_step": 0,
+            "queue_size": 0,
+            "unallocated_fraction": 1.0,
+            "num_self_sequences_in_queue": 0,
+        }
     ret_dict["env_weight"] = env_weight
     return ret_dict
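A worked example of the adjusted queue-size arithmetic above, using assumed toy numbers (one other env holding a 0.25 minimum); this reproduces the `ceil`-packing and minimum-subtraction steps outside the server:

```python
import math

# Assumed numbers, mirroring the /status-env arithmetic.
batch_size = 8
max_group_size = 4
env_group_size = 4            # requesting env's group size
self_sequences = 8            # sequences in the queue from the requesting env
other_env_sequences = {1: 6}  # sequences queued by other envs
min_sequences_by_env = {1: int(batch_size * 0.25)}  # env 1 has min_batch_allocation=0.25

available_from_others = 0
for env_id, seqs in other_env_sequences.items():
    # Round the other env's sequences up to full packs of max_group_size,
    # then subtract the sequences reserved for its minimum allocation.
    packed = math.ceil(seqs / max_group_size) * max_group_size
    available_from_others += max(0, packed - min_sequences_by_env.get(env_id, 0))

# Reported queue size, in units of the requesting env's groups.
queue_size = (self_sequences + available_from_others) // env_group_size  # -> 3
```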
@@ -342,6 +525,7 @@ async def reset_data():
         app.state.started = False
         app.state.requesters = []
         app.state.envs = []
+        app.state.buffer = {}
     except KeyError:
         pass
     return PlainTextResponse("Reset successful", status_code=status.HTTP_200_OK)
@@ -1,47 +1,104 @@
 from typing import Dict, List, Optional, Tuple
 
 
+def find_groups_summing_to_target(buffer: List[Dict], target_size: int) -> List[int]:
+    """
+    Find indices of groups in buffer that sum exactly to target_size.
+    Prioritizes FIFO order.
+
+    :param buffer: Buffer of groups from same env
+    :param target_size: Target sum of group sizes
+    :return: List of indices that sum to target_size, or empty list if impossible
+    """
+    if not buffer:
+        return []
+
+    # First try simple FIFO
+    current_sum = 0
+    indices = []
+
+    for i, group in enumerate(buffer):
+        size = len(group["tokens"])
+        if current_sum + size <= target_size:
+            indices.append(i)
+            current_sum += size
+            if current_sum == target_size:
+                return indices
+
+    # If FIFO doesn't work exactly, try dynamic programming
+    # to find any valid combination (still preferring earlier indices)
+    n = len(buffer)
+    sizes = [len(g["tokens"]) for g in buffer]
+
+    # dp[i][j] = can we make sum j using first i groups
+    dp = [[False] * (target_size + 1) for _ in range(n + 1)]
+    dp[0][0] = True
+
+    for i in range(1, n + 1):
+        for j in range(target_size + 1):
+            # Don't take group i-1
+            dp[i][j] = dp[i - 1][j]
+            # Take group i-1 if possible
+            if j >= sizes[i - 1]:
+                dp[i][j] = dp[i][j] or dp[i - 1][j - sizes[i - 1]]
+
+    if not dp[n][target_size]:
+        return []
+
+    # Backtrack to find indices, preferring earlier ones
+    indices = []
+    j = target_size
+    for i in range(n, 0, -1):
+        if j >= sizes[i - 1] and dp[i - 1][j - sizes[i - 1]]:
+            indices.append(i - 1)
+            j -= sizes[i - 1]
+
+    return sorted(indices)  # Return in FIFO order
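A quick sanity check of the FIFO-then-fallback behavior. `pick_indices` is a condensed reimplementation for illustration only — it swaps the DP for a brute-force subset search, which agrees with the committed function on small buffers like these:

```python
from itertools import combinations
from typing import Dict, List

def pick_indices(buffer: List[Dict], target: int) -> List[int]:
    """Condensed equivalent: FIFO greedy pass first, exhaustive subset search as fallback."""
    sizes = [len(g["tokens"]) for g in buffer]
    total, chosen = 0, []
    for i, s in enumerate(sizes):  # FIFO greedy pass
        if total + s <= target:
            chosen.append(i)
            total += s
            if total == target:
                return chosen
    for r in range(1, len(sizes) + 1):  # exhaustive fallback (fine for small buffers)
        for combo in combinations(range(len(sizes)), r):
            if sum(sizes[i] for i in combo) == target:
                return list(combo)
    return []

buffer = [{"tokens": [[1], [2], [3]]}, {"tokens": [[4], [5]]}, {"tokens": [[6], [7], [8]]}]
print(pick_indices(buffer, 5))  # [0, 1]: FIFO succeeds
print(pick_indices(buffer, 6))  # [0, 2]: FIFO overshoots, fallback finds a combination
print(pick_indices(buffer, 4))  # []: no subset of {3, 2, 3} sums to 4
```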
 def grab_exact_from_heterogeneous_queue(
     queue: List[Dict[str, List]], batch_size: int
 ) -> Tuple[Optional[List], List]:
     """
-    Grabs a batch of size batchsize from a queue of different sized items
+    Grabs a batch of exactly batch_size sequences from a queue of items with different group sizes.
+
+    Each item in the queue has a 'tokens' field containing a list of sequences.
+    e.g. queue = [{"tokens": [[1, 2, 3], [4, 5, 6, 7, 8]]}, {"tokens": [[9, 10]]}]
+    where the first item has 2 sequences and the second has 1 sequence.
 
-    without going over the batchsize. This function will return a batch of size batchsize, and the new queue.
+    This function returns a batch containing exactly batch_size sequences total, and the remaining queue.
 
     Because all groups are a common denominator of the batchsize, and all groups are a power of 2,
     we can simplify a bit by assuming we can grab groups of groups to be equal to the maximum group size.
-    Note that we cannot drop items from groups, so we must grab the entire group if we grab it.
+    Note that we cannot split items, so we must take the entire item with all its sequences.
 
     There may be a more efficient clearing mechanism by grouping these smaller groups heterogeneously, but
     forcing them all into powers of two groups is a simple way to ensure we can grab a batch of the correct size.
 
-    :param queue:
-    :param batch_size:
+    :param queue: List of items, each with a 'tokens' field containing sequences
+    :param batch_size: Target number of sequences for the batch
     :return: batch, new_queue
     """
 
-    # Pass 1: precompute group sizes, total tokens and early exit if not enough tokens.
+    # Pass 1: precompute group sizes, total sequences and early exit if not enough sequences.
     total_groups = len(queue)
     if total_groups == 0:
         return None, queue
 
     group_sizes = []
     lengths = []
-    total_tokens = 0
+    total_sequences = 0
     max_group_size = 0
 
     for item in queue:
-        length = len(item["tokens"])
+        length = len(item["tokens"])  # Number of sequences in this group
         lengths.append(length)
         group_sizes.append(length)
-        total_tokens += length
+        total_sequences += length
         if length > max_group_size:
             max_group_size = length
 
-    if total_tokens < batch_size:
+    if total_sequences < batch_size:
         return None, queue
 
     group_sizes_set = set(group_sizes)
@ -55,30 +112,168 @@ def grab_exact_from_heterogeneous_queue(
|
|||
potential_batch_indices.extend(group_batching_storage[group_size])
|
||||
group_batching_storage[group_size].clear() # much faster than = []
|
||||
|
||||
# Calculate total batch tokens only once (avoid repeated sums)
|
||||
potential_batch_token_total = sum(lengths[i] for i in potential_batch_indices)
|
||||
if potential_batch_token_total < batch_size:
|
||||
# Calculate total sequences in potential batch only once (avoid repeated sums)
|
||||
potential_batch_sequences_total = sum(lengths[i] for i in potential_batch_indices)
|
||||
if potential_batch_sequences_total < batch_size:
|
||||
return None, queue
|
||||
|
||||
# Batch selection
|
||||
batch = []
|
||||
batch_indices = []
|
||||
running_tokens = 0
|
||||
running_seqs = 0
|
||||
for idx in potential_batch_indices:
|
||||
group = queue[idx]
|
||||
batch.append(group)
|
||||
batch_indices.append(idx)
|
||||
running_tokens += lengths[idx]
|
||||
if running_tokens == batch_size:
|
||||
running_seqs += lengths[idx]
|
||||
if running_seqs == batch_size:
|
||||
break
|
||||
elif running_tokens > batch_size:
|
||||
elif running_seqs > batch_size:
|
||||
# Should never happen due to problem constraints, but sanity check
|
||||
return None, queue
|
||||
|
||||
if running_tokens != batch_size:
|
||||
if running_seqs != batch_size:
|
||||
return None, queue
|
||||
|
||||
# Construct new_queue with a single pass, using a set for O(1) lookup
|
||||
batch_indices_set = set(batch_indices)
|
||||
new_queue = [item for i, item in enumerate(queue) if i not in batch_indices_set]
|
||||
return batch, new_queue
|
||||
|
||||
|
||||
def grab_batch_with_minimum_allocations(
|
||||
queue: List[Dict[str, any]], batch_size: int, env_configs: List[Dict[str, any]]
|
||||
) -> Tuple[Optional[List], List]:
|
||||
"""
|
||||
Grabs a batch from the queue while respecting minimum allocation requirements for environments.
|
||||
This function works with groups where each group contains multiple sequences.
|
||||
|
||||
:param queue: List of groups with env_id field and sequences (stored in 'tokens' field)
|
||||
:param batch_size: Target batch size in sequences
|
||||
:param env_configs: List of environment configs with min_batch_allocation field
|
||||
:return: batch, new_queue
|
||||
"""
|
||||
if not queue:
|
||||
return None, queue
|
||||
|
||||
# Build env_id to min allocation mapping
|
||||
env_min_allocations = {}
|
||||
for env in env_configs:
|
||||
if env.get("connected", False) and env.get("min_batch_allocation") is not None:
|
||||
env_min_allocations[env["registered_id"]] = env["min_batch_allocation"]
|
||||
|
||||
# If no minimum allocations, fall back to original function
|
||||
if not env_min_allocations:
|
||||
return grab_exact_from_heterogeneous_queue(queue, batch_size)
|
||||
|
||||
# First, find the maximum group size across all items
|
||||
max_group_size = 0
|
||||
for item in queue:
|
||||
group_size = len(item.get("tokens", []))
|
||||
if group_size > max_group_size:
|
||||
max_group_size = group_size
|
||||
|
||||
# Group queue items by env_id and calculate which can form complete packs
|
||||
items_by_env = {}
|
||||
packable_items_by_env = {}
|
||||
|
||||
for i, item in enumerate(queue):
|
||||
env_id = item.get("env_id")
|
||||
group_size = len(item.get("tokens", []))
|
||||
|
||||
if env_id is not None:
|
||||
if env_id not in items_by_env:
|
||||
items_by_env[env_id] = {}
|
||||
packable_items_by_env[env_id] = []
|
||||
|
||||
if group_size not in items_by_env[env_id]:
|
||||
items_by_env[env_id][group_size] = []
|
||||
|
||||
items_by_env[env_id][group_size].append((i, item, group_size))
|
||||
|
||||
# Check if we can form a complete pack
|
||||
items_of_size = items_by_env[env_id][group_size]
|
||||
if len(items_of_size) * group_size == max_group_size:
|
||||
# We have a complete pack!
|
||||
packable_items_by_env[env_id].extend(items_of_size)
|
||||
items_by_env[env_id][group_size] = []
|
||||
|
||||
# Calculate minimum sequences needed per env
|
||||
min_sequences_per_env = {}
|
||||
total_min_sequences = 0
|
||||
for env_id, min_proportion in env_min_allocations.items():
|
||||
min_sequences = int(batch_size * min_proportion)
|
||||
if min_sequences > 0:
|
||||
# Check if this env has any items in the queue at all
|
||||
if env_id not in items_by_env:
|
||||
# This env has a minimum but no items - can't satisfy minimum
|
||||
return None, queue
|
||||
# Check if this env has any packable items
|
||||
if env_id not in packable_items_by_env or not packable_items_by_env[env_id]:
|
||||
# This env has items but no packable items - can't satisfy minimum
|
||||
return None, queue
|
||||
min_sequences_per_env[env_id] = min_sequences
|
||||
total_min_sequences += min_sequences
|
||||
|
||||
# If minimums exceed batch size, scale them down proportionally
|
||||
if total_min_sequences > batch_size:
|
||||
scale_factor = batch_size / total_min_sequences
|
||||
for env_id in min_sequences_per_env:
|
||||
# Ensure at least one pack from each env with minimum
|
||||
if packable_items_by_env.get(env_id):
|
||||
min_group_size = min(g[2] for g in packable_items_by_env[env_id])
|
||||
min_sequences_per_env[env_id] = max(
|
||||
min_group_size,
|
||||
            int(min_sequences_per_env[env_id] * scale_factor),
        )

    # Build batch ensuring minimums are met
    batch = []
    batch_indices = []
    sequences_taken_per_env = {env_id: 0 for env_id in packable_items_by_env}
    total_sequences = 0

    # First pass: satisfy minimum requirements using packable items
    for env_id, min_sequences in min_sequences_per_env.items():
        if env_id in packable_items_by_env:
            # Take packable items in order (FIFO)
            for idx, item, group_size in packable_items_by_env[env_id]:
                if sequences_taken_per_env[env_id] >= min_sequences:
                    break
                if total_sequences + group_size <= batch_size:
                    batch.append(item)
                    batch_indices.append(idx)
                    sequences_taken_per_env[env_id] += group_size
                    total_sequences += group_size

    # Second pass: fill remaining slots with packable items from any env
    if total_sequences < batch_size:
        # Collect all remaining packable items in queue order
        all_packable = []
        for i, item in enumerate(queue):
            if i not in batch_indices:
                # Check if this item is in any env's packable list
                env_id = item.get("env_id")
                if env_id in packable_items_by_env:
                    for idx, packable_item, size in packable_items_by_env[env_id]:
                        if idx == i:
                            all_packable.append((i, item, size))
                            break

        # Take packable items in queue order
        for idx, item, group_size in all_packable:
            if total_sequences + group_size <= batch_size:
                batch.append(item)
                batch_indices.append(idx)
                total_sequences += group_size
                if total_sequences == batch_size:
                    break

    # If we couldn't form a full batch, return None
    if total_sequences != batch_size:
        return None, queue

    # Construct new queue
    batch_indices_set = set(batch_indices)
    new_queue = [item for i, item in enumerate(queue) if i not in batch_indices_set]
    return batch, new_queue
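The first-pass/second-pass flow above can be condensed into a tiny standalone helper. This is a simplified sketch of the idea only (a hypothetical `build_batch`, not the library function, and it ignores the packability pre-filtering):

```python
# Minimal sketch of the two-pass minimum-allocation batch builder:
# first reserve sequences for envs with minimums, then fill the rest FIFO.
def build_batch(queue, batch_size, min_sequences_per_env):
    batch, used, total = [], set(), 0
    taken = {env: 0 for env in min_sequences_per_env}
    # First pass: satisfy each environment's minimum sequence count
    for i, item in enumerate(queue):
        env, size = item["env_id"], len(item["tokens"])
        if (
            env in taken
            and taken[env] < min_sequences_per_env[env]
            and total + size <= batch_size
        ):
            batch.append(item)
            used.add(i)
            taken[env] += size
            total += size
    # Second pass: fill remaining slots from any environment, FIFO
    for i, item in enumerate(queue):
        size = len(item["tokens"])
        if i not in used and total + size <= batch_size:
            batch.append(item)
            used.add(i)
            total += size
    if total != batch_size:
        return None, queue  # could not form a full batch
    return batch, [item for i, item in enumerate(queue) if i not in used]
```

With a minimum of 4 sequences for env 1 and a batch size of 4, the helper picks the env-1 items first even though env-0 items are ahead of them in the queue.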
@@ -1,6 +1,7 @@
import asyncio
import json
import logging
import math
import os
import random
import string

@@ -162,6 +163,14 @@ class BaseEnvConfig(BaseModel):
        default=False,
        description="Whether to include messages in the output transmitted to the trainer",
    )
    min_batch_allocation: Optional[float] = Field(
        default=None,
        description="Minimum proportion of a batch this environment should be allocated (0.0-1.0)",
    )
    worker_timeout: float = Field(
        default=600,
        description="Timeout for a task, in seconds; if -1, no timeout",
    )


class BaseEnv(ABC):

@@ -237,6 +246,26 @@ class BaseEnv(ABC):
        else:
            self.jsonl_writer = None

    @property
    def derived_batch_size(self):
        """Calculate the effective batch size for this environment based on minimum allocations."""
        # If batch_size is not set or no status yet, return the config batch_size
        if not hasattr(self, "status_dict") or self.config.batch_size == -1:
            return self.config.batch_size

        # Get unallocated fraction from status
        unallocated_fraction = self.status_dict.get("unallocated_fraction", 1.0)

        # If this env has a minimum allocation, add it to the unallocated portion
        if self.config.min_batch_allocation is not None:
            effective_fraction = unallocated_fraction + self.config.min_batch_allocation
        else:
            # This env competes for the unallocated portion based on its weight
            effective_fraction = unallocated_fraction

        # Calculate derived batch size
        return int(self.config.batch_size * effective_fraction)
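The arithmetic in `derived_batch_size` can be restated as a standalone function (a sketch; the example numbers below are hypothetical):

```python
# Restatement of the derived-batch-size arithmetic: an env plans for the
# unallocated share of the batch plus its own guaranteed minimum share.
def derived_batch_size(batch_size, unallocated_fraction, min_batch_allocation=None):
    if batch_size == -1:  # batch size not configured yet
        return batch_size
    effective_fraction = unallocated_fraction + (min_batch_allocation or 0.0)
    return int(batch_size * effective_fraction)
```

For example, with a batch size of 1024, 25% of the batch unallocated, and a 50% minimum, the env sizes itself for `int(1024 * 0.75) = 768` sequences; the same env without a minimum would size itself for 256.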
    @classmethod
    def config_init(
        cls,

@@ -434,6 +463,8 @@ class BaseEnv(ABC):
                    "max_token_length": self.config.max_token_length,
                    "desired_name": self.config.wandb_name,
                    "weight": self.config.inference_weight,
                    "min_batch_allocation": self.config.min_batch_allocation,
                    "group_size": self.config.group_size,
                },
            ) as resp:
                data = await parse_http_response(resp, logger)

@@ -614,6 +645,13 @@ class BaseEnv(ABC):
        """
        Send scored data to the API with retry logic for timeouts and server errors.
        """
        # Add env_id to the data
        if isinstance(scored_data, list):
            for item in scored_data:
                item["env_id"] = getattr(self, "env_id", None)
        else:
            scored_data["env_id"] = getattr(self, "env_id", None)

        url = (
            f"{self.config.rollout_server_url}/scored_data_list"
            if isinstance(scored_data, list)

@@ -736,7 +774,7 @@ class BaseEnv(ABC):
        """
        Handle the rollout of an item
        """
        item = self.running_items.get(item_uuid)
        item = self.running_items.get(item_uuid)["item"]
        if item is None:
            print(f"item {item_uuid} not found... returning")
            return None

@@ -813,7 +851,9 @@ class BaseEnv(ABC):
                self.eval_runner = eval_task
                if self.config.eval_handling == EvalHandlingEnum.STOP_TRAIN:
                    # Stop training if eval is running
                    self.backlog.extend(self.running_items.values())
                    self.backlog.extend(
                        [x["item"] for x in self.running_items.values()]
                    )
                    for worker in self.workers:
                        worker.cancel()
                    self.workers = set()

@@ -852,16 +892,72 @@ class BaseEnv(ABC):
            max_num_workers,
            (
                self.config.max_batches_offpolicy
                * self.config.batch_size
                * self.derived_batch_size
                // self.config.group_size
            )
            - (self.status_dict["queue_size"]),
        )
        # Now if we have a minimum batch allocation, we need to add workers to fill the self queue, in case of
        # overruns by other environments
        if self.config.min_batch_allocation is not None:
            min_workers_to_fill_self_queue = max(
                0,
                math.ceil(
                    (
                        (
                            (
                                math.ceil(
                                    self.config.min_batch_allocation
                                    * self.config.batch_size
                                    * self.config.max_batches_offpolicy
                                    / self.status_dict["max_group_size"]
                                )
                                + (
                                    self.status_dict["max_group_size"]
                                    // self.config.group_size
                                )
                            )
                            * self.status_dict["max_group_size"]
                        )
                        - (
                            (
                                self.status_dict["max_group_size"]
                                * self.status_dict["self_queue_size"]
                                // (
                                    self.status_dict["max_group_size"]
                                    / self.config.group_size
                                )
                            )
                        )
                    )
                    / self.config.group_size
                ),
            )
            max_num_workers = max(max_num_workers, min_workers_to_fill_self_queue)
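The deeply nested expression above is easier to follow with named intermediates. The following is a sketch with hypothetical inputs (the parameter names are assumptions, not the library's API): sequences owed equal the env's minimum share of `max_batches_offpolicy` batches, rounded up to whole packs of `max_group_size`, minus what the self-queue already holds, converted to workers via this env's `group_size`.

```python
import math

# Named-intermediate restatement of the min_workers_to_fill_self_queue math.
def min_workers_needed(min_alloc, batch_size, max_batches_offpolicy,
                       group_size, max_group_size, self_queue_size):
    # Whole packs of max_group_size needed to cover the minimum share
    packs = math.ceil(
        min_alloc * batch_size * max_batches_offpolicy / max_group_size
    )
    # One extra pack's worth of groups as slack, expressed in sequences
    target_sequences = (packs + max_group_size // group_size) * max_group_size
    # Sequences already sitting in this env's self-queue
    queued_sequences = (
        max_group_size * self_queue_size // (max_group_size / group_size)
    )
    return max(0, math.ceil((target_sequences - queued_sequences) / group_size))
```

With `min_alloc=0.5`, `batch_size=16`, two off-policy batches, `group_size=2`, and `max_group_size=8`, an empty self-queue asks for 24 workers, while a self-queue of 24 items already covers the target and asks for none.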
        print(
            f"max_num_workers: {max_num_workers}, queue size: {self.status_dict['queue_size']}, "
            f"workers: {len(self.workers)}, self_queue_size: {self.status_dict['self_queue_size']}",
            flush=True,
        )
        if (self.curr_step == 0) and (len(self.workers) == 0):
            # We are starting up, so we should just skip the append to the list
            pass
        else:
            self.workers_added_list.append(max_num_workers - len(self.workers))
        if len(self.workers) > max_num_workers:
            print(
                f"len(self.workers) > max_num_workers: {len(self.workers)} > {max_num_workers}, "
                "sending workers to backlog",
                flush=True,
            )
            num_to_reduce = len(self.workers) - max_num_workers
            running_items_to_remove = list(self.running_items.keys())[:num_to_reduce]
            for item_uuid in running_items_to_remove:
                self.backlog.append(self.running_items[item_uuid]["item"])
                self.running_items[item_uuid]["worker"].cancel()
                self.workers.discard(self.running_items[item_uuid]["worker"])
                self.running_items.pop(item_uuid)

        while len(self.workers) < max_num_workers:
            # Generate a UUID for tracking this item
            item_uuid = str(uuid.uuid4())

@@ -871,8 +967,12 @@ class BaseEnv(ABC):
            item = await self.get_next_item()
            if item is None:
                break
            self.running_items[item_uuid] = item
            worker = asyncio.create_task(self.handle_env(item_uuid))
            self.running_items[item_uuid] = {
                "item": item,
                "worker": worker,
                "start_time": time.time(),
            }
            self.workers.add(worker)
            worker.add_done_callback(
                lambda fut, i=item: (

@@ -926,9 +1026,32 @@ class BaseEnv(ABC):
                >= self.config.max_batches_offpolicy * self.config.batch_size
            )
            and (self.config.max_batches_offpolicy > 0)
        ) or (self.config.batch_size == -1):
            and (
                (self.config.min_batch_allocation is None)
                or (
                    (
                        (
                            (
                                math.ceil(
                                    self.config.min_batch_allocation
                                    * self.config.batch_size
                                    * self.config.max_batches_offpolicy
                                    / self.status_dict["max_group_size"]
                                )
                                * (
                                    self.status_dict["max_group_size"]
                                    // self.config.group_size
                                )
                            )
                        )
                        - (self.status_dict["self_queue_size"])
                    )
                    <= 0
                )
            )
        ) or (self.derived_batch_size == -1):
            # We have too many, let's clean up the tasks and wait a bit
            self.backlog.extend(self.running_items.values())
            self.backlog.extend([x["item"] for x in self.running_items.values()])
            for worker in self.workers:
                worker.cancel()
            self.running_items = dict()

@@ -937,6 +1060,18 @@ class BaseEnv(ABC):
                pass
            else:
                await self.add_train_workers()
            # cleanup workers that have timed out
            if self.config.worker_timeout > 0:
                for item_uuid, item in list(self.running_items.items()):
                    if time.time() - item["start_time"] > self.config.worker_timeout:
                        logger.warning(
                            f"Worker {item_uuid} has timed out after {time.time() - item['start_time']} seconds"
                        )
                        item["worker"].cancel()
                        self.workers.discard(item["worker"])
                        self.running_items.pop(item_uuid)
                        # Do we want to retry? probably not...
                        # self.backlog.append(item["item"])
            await asyncio.sleep(0.1)

    async def process_manager(self):
154  atroposlib/tests/test_utils/test_heterogeneous_packing.py  Normal file

@@ -0,0 +1,154 @@
"""Tests for heterogeneous group packing utility."""

import pytest

from atroposlib.api.utils import find_groups_summing_to_target


class TestHeterogeneousPacking:
    """Test cases for finding groups that sum to target size."""

    def test_simple_fifo_exact_match(self):
        """Test when FIFO order gives exact match."""
        buffer = [
            {"tokens": [[1, 2]], "scores": [0.5]},  # size 1
            {"tokens": [[3, 4], [5, 6]], "scores": [0.6, 0.7]},  # size 2
            {"tokens": [[7, 8]], "scores": [0.8]},  # size 1
        ]

        indices = find_groups_summing_to_target(buffer, 4)
        assert indices == [0, 1, 2]

    def test_fifo_partial_match(self):
        """Test when FIFO can match with subset."""
        buffer = [
            {"tokens": [[1, 2], [3, 4]], "scores": [0.5, 0.6]},  # size 2
            {"tokens": [[5, 6], [7, 8]], "scores": [0.7, 0.8]},  # size 2
            {
                "tokens": [[9, 10], [11, 12], [13, 14], [15, 16]],
                "scores": [0.9, 1.0, 1.1, 1.2],
            },  # size 4
        ]

        indices = find_groups_summing_to_target(buffer, 4)
        assert indices == [0, 1]  # First two groups sum to 4

    def test_need_dynamic_programming(self):
        """Test when FIFO doesn't work but other combinations do."""
        buffer = [
            {"tokens": [[1, 2], [3, 4], [5, 6]], "scores": [0.5, 0.6, 0.7]},  # size 3
            {"tokens": [[7, 8]], "scores": [0.8]},  # size 1
            {
                "tokens": [[9, 10], [11, 12], [13, 14], [15, 16]],
                "scores": [0.9, 1.0, 1.1, 1.2],
            },  # size 4
        ]

        indices = find_groups_summing_to_target(buffer, 5)
        assert indices == [1, 2]  # Groups at index 1 (size 1) and 2 (size 4)

    def test_impossible_target(self):
        """Test when no combination can reach target."""
        buffer = [
            {"tokens": [[1, 2], [3, 4]], "scores": [0.5, 0.6]},  # size 2
            {
                "tokens": [[5, 6], [7, 8], [9, 10], [11, 12]],
                "scores": [0.7, 0.8, 0.9, 1.0],
            },  # size 4
        ]

        indices = find_groups_summing_to_target(buffer, 3)
        assert indices == []  # Can't make 3 from groups of size 2 and 4

    def test_empty_buffer(self):
        """Test with empty buffer."""
        indices = find_groups_summing_to_target([], 4)
        assert indices == []

    def test_single_group_exact(self):
        """Test when single group matches exactly."""
        buffer = [
            {
                "tokens": [[1, 2], [3, 4], [5, 6], [7, 8]],
                "scores": [0.5, 0.6, 0.7, 0.8],
            },  # size 4
        ]

        indices = find_groups_summing_to_target(buffer, 4)
        assert indices == [0]

    def test_bradley_terry_pairs(self):
        """Test RLAIF use case with Bradley-Terry pairs."""
        buffer = [
            {"tokens": [[1, 2], [3, 4]], "scores": [0.7, 0.3]},  # size 2 (BT pair)
            {"tokens": [[5, 6], [7, 8]], "scores": [0.6, 0.4]},  # size 2 (BT pair)
            {"tokens": [[9, 10], [11, 12]], "scores": [0.8, 0.2]},  # size 2 (BT pair)
            {"tokens": [[13, 14], [15, 16]], "scores": [0.5, 0.5]},  # size 2 (BT pair)
        ]

        indices = find_groups_summing_to_target(buffer, 8)
        assert indices == [0, 1, 2, 3]  # All 4 pairs

    def test_mixed_sizes_complex(self):
        """Test with various power-of-2 sizes."""
        buffer = [
            {"tokens": [[1]], "scores": [0.5]},  # size 1
            {"tokens": [[2], [3]], "scores": [0.6, 0.7]},  # size 2
            {"tokens": [[4]], "scores": [0.8]},  # size 1
            {"tokens": [[5], [6], [7], [8]], "scores": [0.9, 1.0, 1.1, 1.2]},  # size 4
            {"tokens": [[9], [10]], "scores": [1.3, 1.4]},  # size 2
        ]

        # Target 8: should find combination that sums to 8
        indices = find_groups_summing_to_target(buffer, 8)
        assert len(indices) > 0
        assert sum(len(buffer[i]["tokens"]) for i in indices) == 8

    def test_large_groups(self):
        """Test with larger group sizes."""
        buffer = [
            {"tokens": [[i] for i in range(16)], "scores": [0.5] * 16},  # size 16
            {"tokens": [[i] for i in range(8)], "scores": [0.6] * 8},  # size 8
            {"tokens": [[i] for i in range(8)], "scores": [0.7] * 8},  # size 8
        ]

        indices = find_groups_summing_to_target(buffer, 32)
        assert indices == [0, 1, 2]  # All groups needed

    def test_prefer_earlier_indices(self):
        """Test that algorithm prefers earlier indices when multiple solutions exist."""
        buffer = [
            {"tokens": [[1], [2]], "scores": [0.5, 0.6]},  # size 2
            {"tokens": [[3], [4]], "scores": [0.7, 0.8]},  # size 2
            {"tokens": [[5], [6], [7], [8]], "scores": [0.9, 1.0, 1.1, 1.2]},  # size 4
            {"tokens": [[9], [10]], "scores": [1.3, 1.4]},  # size 2
            {"tokens": [[11], [12]], "scores": [1.5, 1.6]},  # size 2
        ]

        indices = find_groups_summing_to_target(buffer, 4)
        assert indices == [0, 1]  # Should prefer first two groups over later ones

    def test_exact_fit_with_remainder(self):
        """Test when we can form exact target but have leftover groups."""
        buffer = [
            {"tokens": [[1], [2]], "scores": [0.5, 0.6]},  # size 2
            {"tokens": [[3], [4], [5], [6]], "scores": [0.7, 0.8, 0.9, 1.0]},  # size 4
            {"tokens": [[7], [8]], "scores": [1.1, 1.2]},  # size 2
            {"tokens": [[9]], "scores": [1.3]},  # size 1
        ]

        indices = find_groups_summing_to_target(buffer, 6)
        assert sorted(indices) == [0, 1]  # First two groups sum to 6

    def test_stress_many_small_groups(self):
        """Test with many small groups."""
        # Create 16 groups of size 1
        buffer = [{"tokens": [[i]], "scores": [i * 0.1]} for i in range(16)]

        indices = find_groups_summing_to_target(buffer, 8)
        assert len(indices) == 8
        assert indices == list(range(8))  # Should take first 8


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
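Based on the behavior these tests pin down, `find_groups_summing_to_target` can be reconstructed as a FIFO fast path plus a subset-sum fallback. This is a hedged sketch that satisfies the tests above, not the library's actual implementation:

```python
# Sketch: pick group indices whose sizes sum exactly to `target`,
# preferring queue (FIFO) order, and returning [] when impossible.
def find_groups_summing_to_target(buffer, target):
    sizes = [len(g["tokens"]) for g in buffer]
    # FIFO fast path: greedily take groups in order while they fit
    chosen, total = [], 0
    for i, s in enumerate(sizes):
        if total + s <= target:
            chosen.append(i)
            total += s
        if total == target:
            return chosen
    # Fallback: subset-sum DP over reachable totals; recording each sum
    # only the first time it is reached keeps earlier indices preferred.
    best = {0: []}
    for i, s in enumerate(sizes):
        for acc in sorted(best, reverse=True):
            nxt = acc + s
            if nxt <= target and nxt not in best:
                best[nxt] = best[acc] + [i]
    return best.get(target, [])
```

For instance, with group sizes [3, 1, 4] and target 5, the FIFO pass stalls at 4, and the fallback finds indices [1, 2].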
806  atroposlib/tests/test_utils/test_min_batch_allocation.py  Normal file

@@ -0,0 +1,806 @@
"""Tests for minimum batch allocation functionality."""

import random

from atroposlib.api.utils import grab_batch_with_minimum_allocations


class TestMinBatchAllocation:
    """Test cases for minimum batch allocation feature."""

    def test_basic_minimum_allocation(self):
        """Test that basic minimum allocations are respected."""
        # Each item represents a group with multiple token sequences
        queue = [
            {
                "tokens": [[1, 2], [3, 4]],
                "env_id": 0,
                "masks": [[1, 1], [1, 1]],
                "scores": [0.5, 0.6],
            },  # 2 groups
            {
                "tokens": [[5, 6], [7, 8]],
                "env_id": 1,
                "masks": [[1, 1], [1, 1]],
                "scores": [0.7, 0.8],
            },  # 2 groups
            {
                "tokens": [[9, 10], [11, 12]],
                "env_id": 0,
                "masks": [[1, 1], [1, 1]],
                "scores": [0.5, 0.6],
            },  # 2 groups
            {
                "tokens": [[13, 14], [15, 16]],
                "env_id": 1,
                "masks": [[1, 1], [1, 1]],
                "scores": [0.7, 0.8],
            },  # 2 groups
        ]

        env_configs = [
            {
                "registered_id": 0,
                "connected": True,
                "min_batch_allocation": 0.25,
            },  # 25% = 2 groups min
            {
                "registered_id": 1,
                "connected": True,
                "min_batch_allocation": 0.5,
            },  # 50% = 4 groups min
        ]

        batch_size = 8  # 8 token groups total
        batch, new_queue = grab_batch_with_minimum_allocations(
            queue, batch_size, env_configs
        )

        assert batch is not None

        # Count groups (not items) per environment
        env_groups = {}
        total_groups = 0
        for item in batch:
            env_id = item["env_id"]
            groups = len(item["tokens"])
            env_groups[env_id] = env_groups.get(env_id, 0) + groups
            total_groups += groups

        assert total_groups == batch_size

        # Env 1 should have at least 50% (4 groups)
        assert env_groups.get(1, 0) >= 4
        # Env 0 should have at least 25% (2 groups)
        assert env_groups.get(0, 0) >= 2

    def test_no_minimum_allocation_fallback(self):
        """Test fallback to original function when no minimums specified."""
        queue = [
            {"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5, 0.6]},
            {"tokens": [[3, 4]], "env_id": 1, "masks": [[1, 1]], "scores": [0.7, 0.8]},
            {"tokens": [[5, 6]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5, 0.6]},
            {"tokens": [[7, 8]], "env_id": 1, "masks": [[1, 1]], "scores": [0.7, 0.8]},
        ]

        env_configs = [
            {"registered_id": 0, "connected": True},
            {"registered_id": 1, "connected": True},
        ]

        batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)

        # Should still form a batch using original logic
        assert batch is not None
        assert len(new_queue) < len(queue)

    def test_conflicting_minimums_scale_down(self):
        """Test that conflicting minimums > 100% are scaled down."""
        queue = [
            {
                "tokens": [[1, 2], [3, 4]],
                "env_id": 0,
                "masks": [[1, 1], [1, 1]],
                "scores": [0.5, 0.6],
            },  # 2 groups
            {
                "tokens": [[5, 6], [7, 8]],
                "env_id": 1,
                "masks": [[1, 1], [1, 1]],
                "scores": [0.7, 0.8],
            },  # 2 groups
        ]

        env_configs = [
            {"registered_id": 0, "connected": True, "min_batch_allocation": 0.7},  # 70%
            {
                "registered_id": 1,
                "connected": True,
                "min_batch_allocation": 0.6,
            },  # 60% = 130% total
        ]

        batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)

        # Should still form a batch with scaled allocations
        assert batch is not None
        assert len(batch) == 2  # Both items needed to form batch of 4 groups

    def test_missing_env_in_queue(self):
        """Test handling when an env has minimum but no items in queue."""
        queue = [
            {
                "tokens": [[1, 2], [3, 4]],
                "env_id": 0,
                "masks": [[1, 1], [1, 1]],
                "scores": [0.5, 0.6],
            },  # 2 groups
            {
                "tokens": [[5, 6], [7, 8]],
                "env_id": 0,
                "masks": [[1, 1], [1, 1]],
                "scores": [0.5, 0.6],
            },  # 2 groups
        ]

        env_configs = [
            {"registered_id": 0, "connected": True, "min_batch_allocation": 0.3},
            {
                "registered_id": 1,
                "connected": True,
                "min_batch_allocation": 0.5,
            },  # No items!
        ]

        batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)

        # Should return None because env 1 has minimum allocation but no items
        assert batch is None

    def test_disconnected_env_ignored(self):
        """Test that disconnected environments are ignored."""
        queue = [
            {"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5, 0.6]},
            {"tokens": [[3, 4]], "env_id": 1, "masks": [[1, 1]], "scores": [0.7, 0.8]},
        ]

        env_configs = [
            {"registered_id": 0, "connected": True, "min_batch_allocation": 0.25},
            {
                "registered_id": 1,
                "connected": False,
                "min_batch_allocation": 0.75,
            },  # Disconnected!
        ]

        batch, new_queue = grab_batch_with_minimum_allocations(queue, 2, env_configs)

        # Should only consider connected env
        assert batch is not None
        # May include env 1 items but won't enforce its minimum

    def test_mixed_group_sizes(self):
        """Test handling of different group sizes."""
        queue = [
            {"tokens": [[1]], "env_id": 0, "masks": [[1]], "scores": [0.5]},  # size 1
            {
                "tokens": [[2, 3, 4, 5]],
                "env_id": 0,
                "masks": [[1, 1, 1, 1]],
                "scores": [0.6, 0.7, 0.8, 0.9],
            },  # size 4
            {
                "tokens": [[6, 7]],
                "env_id": 1,
                "masks": [[1, 1]],
                "scores": [0.5, 0.6],
            },  # size 2
        ]

        env_configs = [
            {"registered_id": 0, "connected": True, "min_batch_allocation": 0.5},
            {"registered_id": 1, "connected": True, "min_batch_allocation": 0.25},
        ]

        # Try to form batch of size 7 (which would need all items)
        batch, new_queue = grab_batch_with_minimum_allocations(queue, 7, env_configs)

        if batch is not None:
            total_tokens = sum(len(item["tokens"]) for item in batch)
            assert total_tokens == 7

    def test_empty_queue(self):
        """Test handling of empty queue."""
        queue = []
        env_configs = [
            {"registered_id": 0, "connected": True, "min_batch_allocation": 0.5},
        ]

        batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)

        assert batch is None
        assert new_queue == []

    def test_insufficient_items_for_batch(self):
        """Test when there aren't enough items to form a full batch."""
        queue = [
            {"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5, 0.6]},
        ]

        env_configs = [
            {"registered_id": 0, "connected": True, "min_batch_allocation": 0.5},
        ]

        # Request batch size 4 but only have 2 tokens
        batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)

        assert batch is None
        assert len(new_queue) == 1  # Original queue unchanged

    def test_heterogeneous_envs(self):
        """Test envs with individual group sizes."""
        # Env 0: all groups have size 2
        # Env 1: all groups have size 4
        # Env 2: all groups have size 8
        queue = []

        # Add items for env 0 (group size 2)
        for i in range(1):
            queue.append(
                {
                    "tokens": [[i * 2, i * 2 + 1] for _ in range(2)],
                    "env_id": 0,
                    "masks": [[1, 1] for _ in range(2)],
                    "scores": [0.5, 0.6],
                }
            )
        # for i in range(1):
        #     queue.append(
        #         {
        #             "tokens": [[i * 2, i * 2 + 1] for _ in range(2)],
        #             "env_id": 1,
        #             "masks": [[1, 1] for _ in range(2)],
        #             "scores": [0.5, 0.6],
        #         }
        #     )
        # Add 3 items of group size 2 to show why greedy packing doesn't work
        for i in range(3):
            queue.append(
                {
                    "tokens": [[i * 2, i * 2 + 1] for _ in range(2)],
                    "env_id": 6,
                    "masks": [[1, 1] for _ in range(2)],
                    "scores": [0.5, 0.6],
                }
            )

        # Add items for env 1 (group size 4)
        for i in range(5):
            queue.append(
                {
                    "tokens": [
                        [i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3] for _ in range(4)
                    ],
                    "env_id": 2,
                    "masks": [[1, 1, 1, 1] for _ in range(4)],
                    "scores": [0.7, 0.8, 0.9, 1.0],
                }
            )

        # Add items for env 2 (group size 8)
        for i in range(3):
            queue.append(
                {
                    "tokens": [[i * 8 + j] for j in range(8)],
                    "env_id": 3,
                    "masks": [[1] for _ in range(8)],
                    "scores": [0.5] * 8,
                }
            )

        env_configs = [
            {
                "registered_id": 0,
                "connected": True,
                "min_batch_allocation": 0.1,
            },  # min 2 sequences
            # {
            #     "registered_id": 1,
            #     "connected": True,
            #     "min_batch_allocation": 0.1,
            # },  # min 2 sequences
            # {
            #     "registered_id": 2,
            #     "connected": True,
            #     "min_batch_allocation": 0.25,
            # },  # min 4 sequences
            {
                "registered_id": 3,
                "connected": True,
                "min_batch_allocation": 0.5,
            },  # min 8 sequences
        ]

        batch_size = 16
        batch, new_queue = grab_batch_with_minimum_allocations(
            queue, batch_size, env_configs
        )

        # Since env 0 has min allocation of 10% but can't form any complete packs
        # (has 1 item of size 2, needs 4 to make pack of 8), the function should
        # return None as it cannot satisfy the minimum allocation requirement
        assert batch is None

        # Queue should be unchanged
        assert len(new_queue) == len(queue)
def test_packing_constraint_enforcement(self):
|
||||
"""Test that packing to max group size is properly enforced."""
|
||||
# Create queue with items that can't form complete packs
|
||||
queue = [
|
||||
{
|
||||
"tokens": [[1, 2]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1]],
|
||||
"scores": [0.5],
|
||||
}, # size 1
|
||||
{
|
||||
"tokens": [[3, 4]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1]],
|
||||
"scores": [0.5],
|
||||
}, # size 1
|
||||
{
|
||||
"tokens": [[5, 6]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1]],
|
||||
"scores": [0.5],
|
||||
}, # size 1
|
||||
# Need 4 items of size 1 to make a pack of 4, only have 3
|
||||
{
|
||||
"tokens": [[7, 8], [9, 10], [11, 12], [13, 14]],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1]] * 4,
|
||||
"scores": [0.7] * 4,
|
||||
}, # size 4
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.25},
|
||||
{"registered_id": 1, "connected": True, "min_batch_allocation": None},
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
|
||||
|
||||
# Should return None because env 0 can't form complete packs
|
||||
assert batch is None
|
||||
|
||||
def test_fifo_order_preservation(self):
|
||||
"""Test that FIFO order is preserved when forming batches."""
|
||||
queue = []
|
||||
# Add items with sequential scores to track order
|
||||
for i in range(8):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [[i, i + 1]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1]],
|
||||
"scores": [float(i)], # Use score to track original order
|
||||
}
|
||||
)
|
||||
|
||||
env_configs = [
|
||||
{"registered_id": 0, "connected": True, "min_batch_allocation": None},
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
|
||||
|
||||
if batch is not None:
|
||||
# Check that we got the first 4 items (scores 0-3)
|
||||
batch_scores = [item["scores"][0] for item in batch]
|
||||
assert sorted(batch_scores) == [0.0, 1.0, 2.0, 3.0]
|
||||
|
||||
def test_exact_minimum_boundary(self):
|
||||
"""Test behavior at exact minimum allocation boundaries."""
|
||||
queue = [
|
||||
{
|
||||
"tokens": [[1, 2], [3, 4]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
},
|
||||
{
|
||||
"tokens": [[5, 6], [7, 8]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
},
|
||||
{
|
||||
"tokens": [[9, 10], [11, 12]],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.7, 0.8],
|
||||
},
|
||||
{
|
||||
"tokens": [[13, 14], [15, 16]],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.7, 0.8],
|
||||
},
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{
|
||||
"registered_id": 0,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.5,
|
||||
}, # Exactly 50%
|
||||
{
|
||||
"registered_id": 1,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.5,
|
||||
}, # Exactly 50%
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 8, env_configs)
|
||||
|
||||
assert batch is not None
|
||||
env_counts = {}
|
||||
for item in batch:
|
||||
env_id = item["env_id"]
|
||||
count = len(item["tokens"])
|
||||
env_counts[env_id] = env_counts.get(env_id, 0) + count
|
||||
|
||||
# Both envs should get exactly 4 sequences (50%)
|
||||
assert env_counts[0] == 4
|
||||
assert env_counts[1] == 4
|
||||
|
||||
def test_zero_minimum_allocation(self):
|
||||
"""Test that zero minimum allocation is handled correctly."""
|
||||
queue = [
|
||||
{"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]},
|
||||
{"tokens": [[3, 4]], "env_id": 1, "masks": [[1, 1]], "scores": [0.7]},
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.0}, # 0%
|
||||
{"registered_id": 1, "connected": True, "min_batch_allocation": 0.5}, # 50%
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 2, env_configs)
|
||||
|
||||
# Should work fine - env 0 has no minimum requirement
|
||||
assert batch is not None
|
||||
|
||||
def test_multiple_complete_packs(self):
|
||||
"""Test forming multiple complete packs from same environment."""
|
||||
queue = []
|
||||
# Add 16 items of size 1 from env 0 (can form 4 complete packs of 4)
|
||||
for i in range(16):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [[i]],
|
||||
"env_id": 0,
|
||||
"masks": [[1]],
|
||||
"scores": [0.5],
|
||||
}
|
||||
)
|
||||
|
||||
# Add 2 items of size 4 from env 1
|
||||
for i in range(2):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [[100 + i * 4, 101 + i * 4, 102 + i * 4, 103 + i * 4]],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1, 1, 1]],
|
||||
"scores": [0.7],
|
||||
}
|
||||
)
|
||||
|
||||
env_configs = [
|
||||
{
|
||||
"registered_id": 0,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.75,
|
||||
}, # 12 sequences
|
||||
{"registered_id": 1, "connected": True, "min_batch_allocation": None},
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 16, env_configs)
|
||||
|
||||
assert batch is not None
|
||||
env_counts = {}
|
||||
for item in batch:
|
||||
env_id = item["env_id"]
|
||||
count = len(item["tokens"])
|
||||
env_counts[env_id] = env_counts.get(env_id, 0) + count
|
||||
|
||||
# Env 0 should get at least 12 sequences
|
||||
assert env_counts.get(0, 0) >= 12
|
||||
assert sum(env_counts.values()) == 16
|
||||
|
||||
def test_no_packable_items(self):
|
||||
"""Test when no items can form complete packs."""
|
||||
queue = [
|
||||
{
|
||||
"tokens": [[1, 2], [3, 4]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
}, # size 2
|
||||
{
|
||||
"tokens": [[5, 6], [7, 8]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
}, # size 2
|
||||
{
|
||||
"tokens": [[9, 10], [11, 12]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
}, # size 2
|
||||
# Only 3 items of size 2, need 4 to make complete pack of 8
|
||||
{
|
||||
"tokens": [
|
||||
[13, 14, 15, 16],
|
||||
[17, 18, 19, 20],
|
||||
[21, 22, 23, 24],
|
||||
[25, 26, 27, 28],
|
||||
[29, 30, 31, 32],
|
||||
[33, 34, 35, 36],
|
||||
[37, 38, 39, 40],
|
||||
[41, 42, 43, 44],
|
||||
],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1, 1, 1]] * 8,
|
||||
"scores": [0.7] * 8,
|
||||
}, # size 8
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{
|
||||
"registered_id": 0,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.25,
|
||||
}, # Can't form complete packs
|
||||
{"registered_id": 1, "connected": True, "min_batch_allocation": None},
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 8, env_configs)
|
||||
|
||||
# Env 0 can't form complete packs (has 3 items, needs 4)
|
||||
assert batch is None
|
||||
|
||||
    def test_env_without_items(self):
        """Test env config without any items in queue."""
        queue = [
            {"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]},
            {"tokens": [[3, 4]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]},
        ]

        env_configs = [
            {"registered_id": 0, "connected": True, "min_batch_allocation": 0.5},
            {
                "registered_id": 1,
                "connected": True,
                "min_batch_allocation": None,
            },  # No items
            {
                "registered_id": 2,
                "connected": True,
                "min_batch_allocation": 0.3,
            },  # No items
        ]

        batch, new_queue = grab_batch_with_minimum_allocations(queue, 2, env_configs)

        # Should work - env 2's 30% minimum of a 2-sequence batch rounds down to zero
        assert batch is not None

    def test_scaling_with_single_env(self):
        """Test scaling behavior with only one env having a minimum."""
        queue = []
        for i in range(8):
            queue.append(
                {
                    "tokens": [[i, i + 1]],
                    "env_id": 0,
                    "masks": [[1, 1]],
                    "scores": [0.5],
                }
            )

        env_configs = [
            {
                "registered_id": 0,
                "connected": True,
                "min_batch_allocation": 1.5,
            },  # 150% - impossible
        ]

        batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)

        # Should scale down to 100% and work
        assert batch is not None
        assert len(batch) == 4

    def test_mixed_null_and_set_minimums(self):
        """Test mix of environments with and without minimum allocations."""
        queue = []
        # Env 0: 4 items of size 2
        for i in range(4):
            queue.append(
                {
                    "tokens": [[i * 2, i * 2 + 1], [i * 2 + 10, i * 2 + 11]],
                    "env_id": 0,
                    "masks": [[1, 1], [1, 1]],
                    "scores": [0.5, 0.6],
                }
            )
        # Env 1: 2 items of size 2
        for i in range(2):
            queue.append(
                {
                    "tokens": [[i * 2 + 20, i * 2 + 21], [i * 2 + 30, i * 2 + 31]],
                    "env_id": 1,
                    "masks": [[1, 1], [1, 1]],
                    "scores": [0.7, 0.8],
                }
            )
        # Env 2: 2 items of size 2 (no minimum)
        for i in range(2):
            queue.append(
                {
                    "tokens": [[i * 2 + 40, i * 2 + 41], [i * 2 + 50, i * 2 + 51]],
                    "env_id": 2,
                    "masks": [[1, 1], [1, 1]],
                    "scores": [0.9, 1.0],
                }
            )

        env_configs = [
            {
                "registered_id": 0,
                "connected": True,
                "min_batch_allocation": 0.4,
            },  # 40% of 16 = 6.4, floors to 6
            {
                "registered_id": 1,
                "connected": True,
                "min_batch_allocation": 0.2,
            },  # 20% of 16 = 3.2, floors to 3
            {
                "registered_id": 2,
                "connected": True,
                "min_batch_allocation": None,
            },  # No minimum
        ]

        batch, new_queue = grab_batch_with_minimum_allocations(queue, 16, env_configs)

        assert batch is not None
        env_counts = {}
        for item in batch:
            env_id = item["env_id"]
            count = len(item["tokens"])
            env_counts[env_id] = env_counts.get(env_id, 0) + count

        # Check minimums are satisfied
        assert env_counts.get(0, 0) >= 6  # At least 40% of 16
        assert env_counts.get(1, 0) >= 3  # At least 20% of 16
        assert sum(env_counts.values()) == 16

    def test_random_consistent_group_sizes(self):
        """Random test where each env has a consistent power-of-2 group size."""
        for _ in range(100):
            batch_size = 64 * random.randint(1, 4)
            num_envs = random.randint(2, 4)

            # Assign each env a consistent group size
            env_group_sizes = {}
            for env_id in range(num_envs):
                env_group_sizes[env_id] = 2 ** random.randint(0, 3)  # 1, 2, 4, or 8

            # Create queue
            queue = []
            for env_id in range(num_envs):
                group_size = env_group_sizes[env_id]
                num_items = random.randint(5, 20)
                for i in range(num_items):
                    queue.append(
                        {
                            "tokens": [
                                [env_id * 1000 + i * 10 + j] for j in range(group_size)
                            ],
                            "env_id": env_id,
                            "masks": [[1] for _ in range(group_size)],
                            "scores": [0.5 + env_id * 0.1] * group_size,
                        }
                    )

            # Random minimum allocations that sum to less than 1.0
            env_configs = []
            remaining = 0.9
            for env_id in range(num_envs):
                if env_id == num_envs - 1:
                    min_alloc = remaining
                else:
                    min_alloc = random.uniform(0.1, min(0.4, remaining))
                remaining -= min_alloc

                env_configs.append(
                    {
                        "registered_id": env_id,
                        "connected": True,
                        "min_batch_allocation": min_alloc,
                    }
                )

            batch, new_queue = grab_batch_with_minimum_allocations(
                queue, batch_size, env_configs
            )

            if batch is not None:
                # Verify batch size
                total_sequences = sum(len(item["tokens"]) for item in batch)
                assert total_sequences == batch_size

                # Verify all items from the same env have the same group size
                env_group_sizes_seen = {}
                for item in batch:
                    env_id = item["env_id"]
                    group_size = len(item["tokens"])
                    if env_id in env_group_sizes_seen:
                        assert group_size == env_group_sizes_seen[env_id]
                    else:
                        env_group_sizes_seen[env_id] = group_size

    def test_queue_dominated_by_one_env(self):
        """Test minimum allocation when one env dominates the queue."""
        queue = []

        # Only env 1 items in queue
        for i in range(100):
            queue.append(
                {
                    "tokens": [[1000 + i, 1001 + i]],
                    "env_id": 1,
                    "masks": [[1, 1]],
                    "scores": [0.7],
                }
            )

        env_configs = [
            {
                "registered_id": 0,
                "connected": True,
                "min_batch_allocation": 0.5,
            },  # 50% but no items!
            {"registered_id": 1, "connected": True, "min_batch_allocation": 0.3},  # 30%
        ]

        batch, new_queue = grab_batch_with_minimum_allocations(queue, 10, env_configs)

        # Should return None because env 0 has a minimum allocation but no items
        assert batch is None

        # Test with env 0 having no minimum - should work
        env_configs[0]["min_batch_allocation"] = None
        batch, new_queue = grab_batch_with_minimum_allocations(queue, 10, env_configs)

        assert batch is not None
        env_counts = {}
        for item in batch:
            env_id = item["env_id"]
            count = len(item["tokens"])
            env_counts[env_id] = env_counts.get(env_id, 0) + count

        # Should all be from env 1
        assert env_counts.get(1, 0) == 10
        assert sum(env_counts.values()) == 10


if __name__ == "__main__":
    test = TestMinBatchAllocation()
    test.test_queue_dominated_by_one_env()