mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
feat: add minimum batch allocation support for environments
- Add min_batch_allocation parameter to ensure environments contribute minimum proportion to each batch - Implement grab_batch_with_minimum_allocations function with proper scaling when allocations exceed 100% - Add mixed-size group buffering to handle variable-sized data submissions - Update server to use minimum allocation logic when any env has min_batch_allocation set - Add comprehensive tests for minimum allocation scenarios - Update documentation in API README and CONFIG.md - Update example environments to demonstrate the feature This feature allows critical environments to guarantee they contribute at least a specified proportion (0.0-1.0) to each training batch, ensuring important data sources are always represented during training. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
4769eeb4a6
commit
08e14cc745
11 changed files with 1670 additions and 91 deletions
154
atroposlib/tests/test_utils/test_heterogeneous_packing.py
Normal file
154
atroposlib/tests/test_utils/test_heterogeneous_packing.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
"""Tests for heterogeneous group packing utility."""
|
||||
|
||||
import pytest
|
||||
|
||||
from atroposlib.api.utils import find_groups_summing_to_target
|
||||
|
||||
|
||||
class TestHeterogeneousPacking:
|
||||
"""Test cases for finding groups that sum to target size."""
|
||||
|
||||
def test_simple_fifo_exact_match(self):
|
||||
"""Test when FIFO order gives exact match."""
|
||||
buffer = [
|
||||
{"tokens": [[1, 2]], "scores": [0.5]}, # size 1
|
||||
{"tokens": [[3, 4], [5, 6]], "scores": [0.6, 0.7]}, # size 2
|
||||
{"tokens": [[7, 8]], "scores": [0.8]}, # size 1
|
||||
]
|
||||
|
||||
indices = find_groups_summing_to_target(buffer, 4)
|
||||
assert indices == [0, 1, 2]
|
||||
|
||||
def test_fifo_partial_match(self):
|
||||
"""Test when FIFO can match with subset."""
|
||||
buffer = [
|
||||
{"tokens": [[1, 2], [3, 4]], "scores": [0.5, 0.6]}, # size 2
|
||||
{"tokens": [[5, 6], [7, 8]], "scores": [0.7, 0.8]}, # size 2
|
||||
{
|
||||
"tokens": [[9, 10], [11, 12], [13, 14], [15, 16]],
|
||||
"scores": [0.9, 1.0, 1.1, 1.2],
|
||||
}, # size 4
|
||||
]
|
||||
|
||||
indices = find_groups_summing_to_target(buffer, 4)
|
||||
assert indices == [0, 1] # First two groups sum to 4
|
||||
|
||||
def test_need_dynamic_programming(self):
|
||||
"""Test when FIFO doesn't work but other combinations do."""
|
||||
buffer = [
|
||||
{"tokens": [[1, 2], [3, 4], [5, 6]], "scores": [0.5, 0.6, 0.7]}, # size 3
|
||||
{"tokens": [[7, 8]], "scores": [0.8]}, # size 1
|
||||
{
|
||||
"tokens": [[9, 10], [11, 12], [13, 14], [15, 16]],
|
||||
"scores": [0.9, 1.0, 1.1, 1.2],
|
||||
}, # size 4
|
||||
]
|
||||
|
||||
indices = find_groups_summing_to_target(buffer, 5)
|
||||
assert indices == [1, 2] # Groups at index 1 (size 1) and 2 (size 4)
|
||||
|
||||
def test_impossible_target(self):
|
||||
"""Test when no combination can reach target."""
|
||||
buffer = [
|
||||
{"tokens": [[1, 2], [3, 4]], "scores": [0.5, 0.6]}, # size 2
|
||||
{
|
||||
"tokens": [[5, 6], [7, 8], [9, 10], [11, 12]],
|
||||
"scores": [0.7, 0.8, 0.9, 1.0],
|
||||
}, # size 4
|
||||
]
|
||||
|
||||
indices = find_groups_summing_to_target(buffer, 3)
|
||||
assert indices == [] # Can't make 3 from groups of size 2 and 4
|
||||
|
||||
def test_empty_buffer(self):
|
||||
"""Test with empty buffer."""
|
||||
indices = find_groups_summing_to_target([], 4)
|
||||
assert indices == []
|
||||
|
||||
def test_single_group_exact(self):
|
||||
"""Test when single group matches exactly."""
|
||||
buffer = [
|
||||
{
|
||||
"tokens": [[1, 2], [3, 4], [5, 6], [7, 8]],
|
||||
"scores": [0.5, 0.6, 0.7, 0.8],
|
||||
}, # size 4
|
||||
]
|
||||
|
||||
indices = find_groups_summing_to_target(buffer, 4)
|
||||
assert indices == [0]
|
||||
|
||||
def test_bradley_terry_pairs(self):
|
||||
"""Test RLAIF use case with Bradley-Terry pairs."""
|
||||
buffer = [
|
||||
{"tokens": [[1, 2], [3, 4]], "scores": [0.7, 0.3]}, # size 2 (BT pair)
|
||||
{"tokens": [[5, 6], [7, 8]], "scores": [0.6, 0.4]}, # size 2 (BT pair)
|
||||
{"tokens": [[9, 10], [11, 12]], "scores": [0.8, 0.2]}, # size 2 (BT pair)
|
||||
{"tokens": [[13, 14], [15, 16]], "scores": [0.5, 0.5]}, # size 2 (BT pair)
|
||||
]
|
||||
|
||||
indices = find_groups_summing_to_target(buffer, 8)
|
||||
assert indices == [0, 1, 2, 3] # All 4 pairs
|
||||
|
||||
def test_mixed_sizes_complex(self):
|
||||
"""Test with various power-of-2 sizes."""
|
||||
buffer = [
|
||||
{"tokens": [[1]], "scores": [0.5]}, # size 1
|
||||
{"tokens": [[2], [3]], "scores": [0.6, 0.7]}, # size 2
|
||||
{"tokens": [[4]], "scores": [0.8]}, # size 1
|
||||
{"tokens": [[5], [6], [7], [8]], "scores": [0.9, 1.0, 1.1, 1.2]}, # size 4
|
||||
{"tokens": [[9], [10]], "scores": [1.3, 1.4]}, # size 2
|
||||
]
|
||||
|
||||
# Target 8: should find combination that sums to 8
|
||||
indices = find_groups_summing_to_target(buffer, 8)
|
||||
assert len(indices) > 0
|
||||
assert sum(len(buffer[i]["tokens"]) for i in indices) == 8
|
||||
|
||||
def test_large_groups(self):
|
||||
"""Test with larger group sizes."""
|
||||
buffer = [
|
||||
{"tokens": [[i] for i in range(16)], "scores": [0.5] * 16}, # size 16
|
||||
{"tokens": [[i] for i in range(8)], "scores": [0.6] * 8}, # size 8
|
||||
{"tokens": [[i] for i in range(8)], "scores": [0.7] * 8}, # size 8
|
||||
]
|
||||
|
||||
indices = find_groups_summing_to_target(buffer, 32)
|
||||
assert indices == [0, 1, 2] # All groups needed
|
||||
|
||||
def test_prefer_earlier_indices(self):
|
||||
"""Test that algorithm prefers earlier indices when multiple solutions exist."""
|
||||
buffer = [
|
||||
{"tokens": [[1], [2]], "scores": [0.5, 0.6]}, # size 2
|
||||
{"tokens": [[3], [4]], "scores": [0.7, 0.8]}, # size 2
|
||||
{"tokens": [[5], [6], [7], [8]], "scores": [0.9, 1.0, 1.1, 1.2]}, # size 4
|
||||
{"tokens": [[9], [10]], "scores": [1.3, 1.4]}, # size 2
|
||||
{"tokens": [[11], [12]], "scores": [1.5, 1.6]}, # size 2
|
||||
]
|
||||
|
||||
indices = find_groups_summing_to_target(buffer, 4)
|
||||
assert indices == [0, 1] # Should prefer first two groups over later ones
|
||||
|
||||
def test_exact_fit_with_remainder(self):
|
||||
"""Test when we can form exact target but have leftover groups."""
|
||||
buffer = [
|
||||
{"tokens": [[1], [2]], "scores": [0.5, 0.6]}, # size 2
|
||||
{"tokens": [[3], [4], [5], [6]], "scores": [0.7, 0.8, 0.9, 1.0]}, # size 4
|
||||
{"tokens": [[7], [8]], "scores": [1.1, 1.2]}, # size 2
|
||||
{"tokens": [[9]], "scores": [1.3]}, # size 1
|
||||
]
|
||||
|
||||
indices = find_groups_summing_to_target(buffer, 6)
|
||||
assert sorted(indices) == [0, 1] # First two groups sum to 6
|
||||
|
||||
def test_stress_many_small_groups(self):
|
||||
"""Test with many small groups."""
|
||||
# Create 16 groups of size 1
|
||||
buffer = [{"tokens": [[i]], "scores": [i * 0.1]} for i in range(16)]
|
||||
|
||||
indices = find_groups_summing_to_target(buffer, 8)
|
||||
assert len(indices) == 8
|
||||
assert indices == list(range(8)) # Should take first 8
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
806
atroposlib/tests/test_utils/test_min_batch_allocation.py
Normal file
806
atroposlib/tests/test_utils/test_min_batch_allocation.py
Normal file
|
|
@ -0,0 +1,806 @@
|
|||
"""Tests for minimum batch allocation functionality."""
|
||||
|
||||
import random
|
||||
|
||||
from atroposlib.api.utils import grab_batch_with_minimum_allocations
|
||||
|
||||
|
||||
class TestMinBatchAllocation:
|
||||
"""Test cases for minimum batch allocation feature."""
|
||||
|
||||
def test_basic_minimum_allocation(self):
|
||||
"""Test that basic minimum allocations are respected."""
|
||||
# Each item represents a group with multiple token sequences
|
||||
queue = [
|
||||
{
|
||||
"tokens": [[1, 2], [3, 4]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
}, # 2 groups
|
||||
{
|
||||
"tokens": [[5, 6], [7, 8]],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.7, 0.8],
|
||||
}, # 2 groups
|
||||
{
|
||||
"tokens": [[9, 10], [11, 12]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
}, # 2 groups
|
||||
{
|
||||
"tokens": [[13, 14], [15, 16]],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.7, 0.8],
|
||||
}, # 2 groups
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{
|
||||
"registered_id": 0,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.25,
|
||||
}, # 25% = 2 groups min
|
||||
{
|
||||
"registered_id": 1,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.5,
|
||||
}, # 50% = 4 groups min
|
||||
]
|
||||
|
||||
batch_size = 8 # 8 token groups total
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(
|
||||
queue, batch_size, env_configs
|
||||
)
|
||||
|
||||
assert batch is not None
|
||||
|
||||
# Count groups (not items) per environment
|
||||
env_groups = {}
|
||||
total_groups = 0
|
||||
for item in batch:
|
||||
env_id = item["env_id"]
|
||||
groups = len(item["tokens"])
|
||||
env_groups[env_id] = env_groups.get(env_id, 0) + groups
|
||||
total_groups += groups
|
||||
|
||||
assert total_groups == batch_size
|
||||
|
||||
# Env 1 should have at least 50% (4 groups)
|
||||
assert env_groups.get(1, 0) >= 4
|
||||
# Env 0 should have at least 25% (2 groups)
|
||||
assert env_groups.get(0, 0) >= 2
|
||||
|
||||
def test_no_minimum_allocation_fallback(self):
|
||||
"""Test fallback to original function when no minimums specified."""
|
||||
queue = [
|
||||
{"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5, 0.6]},
|
||||
{"tokens": [[3, 4]], "env_id": 1, "masks": [[1, 1]], "scores": [0.7, 0.8]},
|
||||
{"tokens": [[5, 6]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5, 0.6]},
|
||||
{"tokens": [[7, 8]], "env_id": 1, "masks": [[1, 1]], "scores": [0.7, 0.8]},
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{"registered_id": 0, "connected": True},
|
||||
{"registered_id": 1, "connected": True},
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
|
||||
|
||||
# Should still form a batch using original logic
|
||||
assert batch is not None
|
||||
assert len(new_queue) < len(queue)
|
||||
|
||||
def test_conflicting_minimums_scale_down(self):
|
||||
"""Test that conflicting minimums > 100% are scaled down."""
|
||||
queue = [
|
||||
{
|
||||
"tokens": [[1, 2], [3, 4]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
}, # 2 groups
|
||||
{
|
||||
"tokens": [[5, 6], [7, 8]],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.7, 0.8],
|
||||
}, # 2 groups
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.7}, # 70%
|
||||
{
|
||||
"registered_id": 1,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.6,
|
||||
}, # 60% = 130% total
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
|
||||
|
||||
# Should still form a batch with scaled allocations
|
||||
assert batch is not None
|
||||
assert len(batch) == 2 # Both items needed to form batch of 4 groups
|
||||
|
||||
def test_missing_env_in_queue(self):
|
||||
"""Test handling when an env has minimum but no items in queue."""
|
||||
queue = [
|
||||
{
|
||||
"tokens": [[1, 2], [3, 4]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
}, # 2 groups
|
||||
{
|
||||
"tokens": [[5, 6], [7, 8]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
}, # 2 groups
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.3},
|
||||
{
|
||||
"registered_id": 1,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.5,
|
||||
}, # No items!
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
|
||||
|
||||
# Should return None because env 1 has minimum allocation but no items
|
||||
assert batch is None
|
||||
|
||||
def test_disconnected_env_ignored(self):
|
||||
"""Test that disconnected environments are ignored."""
|
||||
queue = [
|
||||
{"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5, 0.6]},
|
||||
{"tokens": [[3, 4]], "env_id": 1, "masks": [[1, 1]], "scores": [0.7, 0.8]},
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.25},
|
||||
{
|
||||
"registered_id": 1,
|
||||
"connected": False,
|
||||
"min_batch_allocation": 0.75,
|
||||
}, # Disconnected!
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 2, env_configs)
|
||||
|
||||
# Should only consider connected env
|
||||
assert batch is not None
|
||||
# May include env 1 items but won't enforce its minimum
|
||||
|
||||
def test_mixed_group_sizes(self):
|
||||
"""Test handling of different group sizes."""
|
||||
queue = [
|
||||
{"tokens": [[1]], "env_id": 0, "masks": [[1]], "scores": [0.5]}, # size 1
|
||||
{
|
||||
"tokens": [[2, 3, 4, 5]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1, 1, 1]],
|
||||
"scores": [0.6, 0.7, 0.8, 0.9],
|
||||
}, # size 4
|
||||
{
|
||||
"tokens": [[6, 7]],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
}, # size 2
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.5},
|
||||
{"registered_id": 1, "connected": True, "min_batch_allocation": 0.25},
|
||||
]
|
||||
|
||||
# Try to form batch of size 7 (which would need all items)
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 7, env_configs)
|
||||
|
||||
if batch is not None:
|
||||
total_tokens = sum(len(item["tokens"]) for item in batch)
|
||||
assert total_tokens == 7
|
||||
|
||||
def test_empty_queue(self):
|
||||
"""Test handling of empty queue."""
|
||||
queue = []
|
||||
env_configs = [
|
||||
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.5},
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
|
||||
|
||||
assert batch is None
|
||||
assert new_queue == []
|
||||
|
||||
def test_insufficient_items_for_batch(self):
|
||||
"""Test when there aren't enough items to form a full batch."""
|
||||
queue = [
|
||||
{"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5, 0.6]},
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.5},
|
||||
]
|
||||
|
||||
# Request batch size 4 but only have 2 tokens
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
|
||||
|
||||
assert batch is None
|
||||
assert len(new_queue) == 1 # Original queue unchanged
|
||||
|
||||
def test_heterogeneous_envs(self):
|
||||
"""Test envs with individual group sizes."""
|
||||
# Env 0: all groups have size 2
|
||||
# Env 1: all groups have size 4
|
||||
# Env 2: all groups have size 8
|
||||
queue = []
|
||||
|
||||
# Add items for env 0 (group size 2)
|
||||
for i in range(1):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [[i * 2, i * 2 + 1] for _ in range(2)],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1] for _ in range(2)],
|
||||
"scores": [0.5, 0.6],
|
||||
}
|
||||
)
|
||||
# for i in range(1):
|
||||
# queue.append(
|
||||
# {
|
||||
# "tokens": [[i * 2, i * 2 + 1] for _ in range(2)],
|
||||
# "env_id": 1,
|
||||
# "masks": [[1, 1] for _ in range(2)],
|
||||
# "scores": [0.5, 0.6],
|
||||
# }
|
||||
# )
|
||||
# Add 3 items of group size 2 to show why greedy packing doesn't work
|
||||
for i in range(3):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [[i * 2, i * 2 + 1] for _ in range(2)],
|
||||
"env_id": 6,
|
||||
"masks": [[1, 1] for _ in range(2)],
|
||||
"scores": [0.5, 0.6],
|
||||
}
|
||||
)
|
||||
|
||||
# Add items for env 1 (group size 4)
|
||||
for i in range(5):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [
|
||||
[i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3] for _ in range(4)
|
||||
],
|
||||
"env_id": 2,
|
||||
"masks": [[1, 1, 1, 1] for _ in range(4)],
|
||||
"scores": [0.7, 0.8, 0.9, 1.0],
|
||||
}
|
||||
)
|
||||
|
||||
# Add items for env 2 (group size 8)
|
||||
for i in range(3):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [[i * 8 + j] for j in range(8)],
|
||||
"env_id": 3,
|
||||
"masks": [[1] for _ in range(8)],
|
||||
"scores": [0.5] * 8,
|
||||
}
|
||||
)
|
||||
|
||||
env_configs = [
|
||||
{
|
||||
"registered_id": 0,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.1,
|
||||
}, # min 2 sequences
|
||||
# {
|
||||
# "registered_id": 1,
|
||||
# "connected": True,
|
||||
# "min_batch_allocation": 0.1,
|
||||
# }, # min 2 sequences
|
||||
# {
|
||||
# "registered_id": 2,
|
||||
# "connected": True,
|
||||
# "min_batch_allocation": 0.25,
|
||||
# }, # min 4 sequences
|
||||
{
|
||||
"registered_id": 3,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.5,
|
||||
}, # min 8 sequences
|
||||
]
|
||||
|
||||
batch_size = 16
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(
|
||||
queue, batch_size, env_configs
|
||||
)
|
||||
|
||||
# Since env 0 has min allocation of 10% but can't form any complete packs
|
||||
# (has 1 item of size 2, needs 4 to make pack of 8), the function should
|
||||
# return None as it cannot satisfy the minimum allocation requirement
|
||||
assert batch is None
|
||||
|
||||
# Queue should be unchanged
|
||||
assert len(new_queue) == len(queue)
|
||||
|
||||
def test_packing_constraint_enforcement(self):
|
||||
"""Test that packing to max group size is properly enforced."""
|
||||
# Create queue with items that can't form complete packs
|
||||
queue = [
|
||||
{
|
||||
"tokens": [[1, 2]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1]],
|
||||
"scores": [0.5],
|
||||
}, # size 1
|
||||
{
|
||||
"tokens": [[3, 4]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1]],
|
||||
"scores": [0.5],
|
||||
}, # size 1
|
||||
{
|
||||
"tokens": [[5, 6]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1]],
|
||||
"scores": [0.5],
|
||||
}, # size 1
|
||||
# Need 4 items of size 1 to make a pack of 4, only have 3
|
||||
{
|
||||
"tokens": [[7, 8], [9, 10], [11, 12], [13, 14]],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1]] * 4,
|
||||
"scores": [0.7] * 4,
|
||||
}, # size 4
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.25},
|
||||
{"registered_id": 1, "connected": True, "min_batch_allocation": None},
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
|
||||
|
||||
# Should return None because env 0 can't form complete packs
|
||||
assert batch is None
|
||||
|
||||
def test_fifo_order_preservation(self):
|
||||
"""Test that FIFO order is preserved when forming batches."""
|
||||
queue = []
|
||||
# Add items with sequential scores to track order
|
||||
for i in range(8):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [[i, i + 1]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1]],
|
||||
"scores": [float(i)], # Use score to track original order
|
||||
}
|
||||
)
|
||||
|
||||
env_configs = [
|
||||
{"registered_id": 0, "connected": True, "min_batch_allocation": None},
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
|
||||
|
||||
if batch is not None:
|
||||
# Check that we got the first 4 items (scores 0-3)
|
||||
batch_scores = [item["scores"][0] for item in batch]
|
||||
assert sorted(batch_scores) == [0.0, 1.0, 2.0, 3.0]
|
||||
|
||||
def test_exact_minimum_boundary(self):
|
||||
"""Test behavior at exact minimum allocation boundaries."""
|
||||
queue = [
|
||||
{
|
||||
"tokens": [[1, 2], [3, 4]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
},
|
||||
{
|
||||
"tokens": [[5, 6], [7, 8]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
},
|
||||
{
|
||||
"tokens": [[9, 10], [11, 12]],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.7, 0.8],
|
||||
},
|
||||
{
|
||||
"tokens": [[13, 14], [15, 16]],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.7, 0.8],
|
||||
},
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{
|
||||
"registered_id": 0,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.5,
|
||||
}, # Exactly 50%
|
||||
{
|
||||
"registered_id": 1,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.5,
|
||||
}, # Exactly 50%
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 8, env_configs)
|
||||
|
||||
assert batch is not None
|
||||
env_counts = {}
|
||||
for item in batch:
|
||||
env_id = item["env_id"]
|
||||
count = len(item["tokens"])
|
||||
env_counts[env_id] = env_counts.get(env_id, 0) + count
|
||||
|
||||
# Both envs should get exactly 4 sequences (50%)
|
||||
assert env_counts[0] == 4
|
||||
assert env_counts[1] == 4
|
||||
|
||||
def test_zero_minimum_allocation(self):
|
||||
"""Test that zero minimum allocation is handled correctly."""
|
||||
queue = [
|
||||
{"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]},
|
||||
{"tokens": [[3, 4]], "env_id": 1, "masks": [[1, 1]], "scores": [0.7]},
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.0}, # 0%
|
||||
{"registered_id": 1, "connected": True, "min_batch_allocation": 0.5}, # 50%
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 2, env_configs)
|
||||
|
||||
# Should work fine - env 0 has no minimum requirement
|
||||
assert batch is not None
|
||||
|
||||
def test_multiple_complete_packs(self):
|
||||
"""Test forming multiple complete packs from same environment."""
|
||||
queue = []
|
||||
# Add 16 items of size 1 from env 0 (can form 4 complete packs of 4)
|
||||
for i in range(16):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [[i]],
|
||||
"env_id": 0,
|
||||
"masks": [[1]],
|
||||
"scores": [0.5],
|
||||
}
|
||||
)
|
||||
|
||||
# Add 2 items of size 4 from env 1
|
||||
for i in range(2):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [[100 + i * 4, 101 + i * 4, 102 + i * 4, 103 + i * 4]],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1, 1, 1]],
|
||||
"scores": [0.7],
|
||||
}
|
||||
)
|
||||
|
||||
env_configs = [
|
||||
{
|
||||
"registered_id": 0,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.75,
|
||||
}, # 12 sequences
|
||||
{"registered_id": 1, "connected": True, "min_batch_allocation": None},
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 16, env_configs)
|
||||
|
||||
assert batch is not None
|
||||
env_counts = {}
|
||||
for item in batch:
|
||||
env_id = item["env_id"]
|
||||
count = len(item["tokens"])
|
||||
env_counts[env_id] = env_counts.get(env_id, 0) + count
|
||||
|
||||
# Env 0 should get at least 12 sequences
|
||||
assert env_counts.get(0, 0) >= 12
|
||||
assert sum(env_counts.values()) == 16
|
||||
|
||||
def test_no_packable_items(self):
|
||||
"""Test when no items can form complete packs."""
|
||||
queue = [
|
||||
{
|
||||
"tokens": [[1, 2], [3, 4]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
}, # size 2
|
||||
{
|
||||
"tokens": [[5, 6], [7, 8]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
}, # size 2
|
||||
{
|
||||
"tokens": [[9, 10], [11, 12]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
}, # size 2
|
||||
# Only 3 items of size 2, need 4 to make complete pack of 8
|
||||
{
|
||||
"tokens": [
|
||||
[13, 14, 15, 16],
|
||||
[17, 18, 19, 20],
|
||||
[21, 22, 23, 24],
|
||||
[25, 26, 27, 28],
|
||||
[29, 30, 31, 32],
|
||||
[33, 34, 35, 36],
|
||||
[37, 38, 39, 40],
|
||||
[41, 42, 43, 44],
|
||||
],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1, 1, 1]] * 8,
|
||||
"scores": [0.7] * 8,
|
||||
}, # size 8
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{
|
||||
"registered_id": 0,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.25,
|
||||
}, # Can't form complete packs
|
||||
{"registered_id": 1, "connected": True, "min_batch_allocation": None},
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 8, env_configs)
|
||||
|
||||
# Env 0 can't form complete packs (has 3 items, needs 4)
|
||||
assert batch is None
|
||||
|
||||
def test_env_without_items(self):
|
||||
"""Test env config without any items in queue."""
|
||||
queue = [
|
||||
{"tokens": [[1, 2]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]},
|
||||
{"tokens": [[3, 4]], "env_id": 0, "masks": [[1, 1]], "scores": [0.5]},
|
||||
]
|
||||
|
||||
env_configs = [
|
||||
{"registered_id": 0, "connected": True, "min_batch_allocation": 0.5},
|
||||
{
|
||||
"registered_id": 1,
|
||||
"connected": True,
|
||||
"min_batch_allocation": None,
|
||||
}, # No items
|
||||
{
|
||||
"registered_id": 2,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.3,
|
||||
}, # No items
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 2, env_configs)
|
||||
|
||||
# Should work - env 2 has no items so its minimum is ignored
|
||||
assert batch is not None
|
||||
|
||||
def test_scaling_with_single_env(self):
|
||||
"""Test scaling behavior with only one env having minimum."""
|
||||
queue = []
|
||||
for i in range(8):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [[i, i + 1]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1]],
|
||||
"scores": [0.5],
|
||||
}
|
||||
)
|
||||
|
||||
env_configs = [
|
||||
{
|
||||
"registered_id": 0,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 1.5,
|
||||
}, # 150% - impossible
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 4, env_configs)
|
||||
|
||||
# Should scale down to 100% and work
|
||||
assert batch is not None
|
||||
assert len(batch) == 4
|
||||
|
||||
def test_mixed_null_and_set_minimums(self):
|
||||
"""Test mix of environments with and without minimum allocations."""
|
||||
queue = []
|
||||
# Env 0: 4 items of size 2
|
||||
for i in range(4):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [[i * 2, i * 2 + 1], [i * 2 + 10, i * 2 + 11]],
|
||||
"env_id": 0,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.5, 0.6],
|
||||
}
|
||||
)
|
||||
# Env 1: 2 items of size 2
|
||||
for i in range(2):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [[i * 2 + 20, i * 2 + 21], [i * 2 + 30, i * 2 + 31]],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.7, 0.8],
|
||||
}
|
||||
)
|
||||
# Env 2: 2 items of size 2 (no minimum)
|
||||
for i in range(2):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [[i * 2 + 40, i * 2 + 41], [i * 2 + 50, i * 2 + 51]],
|
||||
"env_id": 2,
|
||||
"masks": [[1, 1], [1, 1]],
|
||||
"scores": [0.9, 1.0],
|
||||
}
|
||||
)
|
||||
|
||||
env_configs = [
|
||||
{
|
||||
"registered_id": 0,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.4,
|
||||
}, # 40% = 6.4 ≈ 6
|
||||
{
|
||||
"registered_id": 1,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.2,
|
||||
}, # 20% = 3.2 ≈ 3
|
||||
{
|
||||
"registered_id": 2,
|
||||
"connected": True,
|
||||
"min_batch_allocation": None,
|
||||
}, # No minimum
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 16, env_configs)
|
||||
|
||||
assert batch is not None
|
||||
env_counts = {}
|
||||
for item in batch:
|
||||
env_id = item["env_id"]
|
||||
count = len(item["tokens"])
|
||||
env_counts[env_id] = env_counts.get(env_id, 0) + count
|
||||
|
||||
# Check minimums are satisfied
|
||||
assert env_counts.get(0, 0) >= 6 # At least 40% of 16
|
||||
assert env_counts.get(1, 0) >= 3 # At least 20% of 16
|
||||
assert sum(env_counts.values()) == 16
|
||||
|
||||
def test_random_consistent_group_sizes(self):
|
||||
"""Random test where each env has a consistent power-of-2 group size."""
|
||||
for _ in range(100):
|
||||
batch_size = 64 * random.randint(1, 4)
|
||||
num_envs = random.randint(2, 4)
|
||||
|
||||
# Assign each env a consistent group size
|
||||
env_group_sizes = {}
|
||||
for env_id in range(num_envs):
|
||||
env_group_sizes[env_id] = 2 ** random.randint(0, 3) # 1, 2, 4, or 8
|
||||
|
||||
# Create queue
|
||||
queue = []
|
||||
for env_id in range(num_envs):
|
||||
group_size = env_group_sizes[env_id]
|
||||
num_items = random.randint(5, 20)
|
||||
for i in range(num_items):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [
|
||||
[env_id * 1000 + i * 10 + j] for j in range(group_size)
|
||||
],
|
||||
"env_id": env_id,
|
||||
"masks": [[1] for _ in range(group_size)],
|
||||
"scores": [0.5 + env_id * 0.1] * group_size,
|
||||
}
|
||||
)
|
||||
|
||||
# Random minimum allocations that sum to less than 1.0
|
||||
env_configs = []
|
||||
remaining = 0.9
|
||||
for env_id in range(num_envs):
|
||||
if env_id == num_envs - 1:
|
||||
min_alloc = remaining
|
||||
else:
|
||||
min_alloc = random.uniform(0.1, min(0.4, remaining))
|
||||
remaining -= min_alloc
|
||||
|
||||
env_configs.append(
|
||||
{
|
||||
"registered_id": env_id,
|
||||
"connected": True,
|
||||
"min_batch_allocation": min_alloc,
|
||||
}
|
||||
)
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(
|
||||
queue, batch_size, env_configs
|
||||
)
|
||||
|
||||
if batch is not None:
|
||||
# Verify batch size
|
||||
total_sequences = sum(len(item["tokens"]) for item in batch)
|
||||
assert total_sequences == batch_size
|
||||
|
||||
# Verify all items from same env have same group size
|
||||
env_group_sizes_seen = {}
|
||||
for item in batch:
|
||||
env_id = item["env_id"]
|
||||
group_size = len(item["tokens"])
|
||||
if env_id in env_group_sizes_seen:
|
||||
assert group_size == env_group_sizes_seen[env_id]
|
||||
else:
|
||||
env_group_sizes_seen[env_id] = group_size
|
||||
|
||||
def test_queue_dominated_by_one_env(self):
|
||||
"""Test minimum allocation when one env dominates the queue."""
|
||||
queue = []
|
||||
|
||||
# Only env 1 items in queue
|
||||
for i in range(100):
|
||||
queue.append(
|
||||
{
|
||||
"tokens": [[1000 + i, 1001 + i]],
|
||||
"env_id": 1,
|
||||
"masks": [[1, 1]],
|
||||
"scores": [0.7],
|
||||
}
|
||||
)
|
||||
|
||||
env_configs = [
|
||||
{
|
||||
"registered_id": 0,
|
||||
"connected": True,
|
||||
"min_batch_allocation": 0.5,
|
||||
}, # 50% but no items!
|
||||
{"registered_id": 1, "connected": True, "min_batch_allocation": 0.3}, # 30%
|
||||
]
|
||||
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 10, env_configs)
|
||||
|
||||
# Should return None because env 0 has minimum allocation but no items
|
||||
assert batch is None
|
||||
|
||||
# Test with env 0 having no minimum - should work
|
||||
env_configs[0]["min_batch_allocation"] = None
|
||||
batch, new_queue = grab_batch_with_minimum_allocations(queue, 10, env_configs)
|
||||
|
||||
assert batch is not None
|
||||
env_counts = {}
|
||||
for item in batch:
|
||||
env_id = item["env_id"]
|
||||
count = len(item["tokens"])
|
||||
env_counts[env_id] = env_counts.get(env_id, 0) + count
|
||||
|
||||
# Should all be from env 1
|
||||
assert env_counts.get(1, 0) == 10
|
||||
assert sum(env_counts.values()) == 10
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test = TestMinBatchAllocation()
|
||||
test.test_queue_dominated_by_one_env()
|
||||
Loading…
Add table
Add a link
Reference in a new issue