mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-27 17:23:08 +00:00
Merge pull request #7 from misrasaurabh1/codeflash/optimize-grab_exact_from_heterogeneous_queue-ma3pegzo
⚡️ Speed up function `grab_exact_from_heterogeneous_queue` by 1,680%
This commit is contained in:
commit
c1ba77ec26
1 changed files with 53 additions and 30 deletions
|
|
def grab_exact_from_heterogeneous_queue(queue, batch_size):
    """Grab groups from ``queue`` whose token counts sum to exactly ``batch_size``.

    Groups of equal size are bucketed and flushed into candidate "packs" worth
    ``max_group_size`` tokens each; packed groups are then taken greedily, in
    order, until the running token total hits ``batch_size`` exactly.

    :param queue: list of group dicts, each holding a ``"tokens"`` sequence.
    :param batch_size: exact number of tokens the returned batch must contain.
    :return: ``(batch, new_queue)`` on success; ``(None, queue)`` when no
        exact batch can be assembled (the queue is returned unchanged).
    """
    # Pass 1: per-group sizes, total token count, and the largest group size.
    # Bail out early when a batch is obviously impossible.
    if not queue:
        return None, queue

    group_sizes = []
    total_tokens = 0
    max_group_size = 0
    for item in queue:
        size = len(item["tokens"])
        group_sizes.append(size)
        total_tokens += size
        if size > max_group_size:
            max_group_size = size

    if total_tokens < batch_size:
        return None, queue

    # Pass 2: bucket queue indices by group size; whenever a bucket's groups
    # sum to max_group_size tokens, flush the bucket into the candidate list.
    group_batching_storage = {size: [] for size in set(group_sizes)}
    potential_batch_indices = []
    for i, group_size in enumerate(group_sizes):
        bucket = group_batching_storage[group_size]
        bucket.append(i)
        if len(bucket) * group_size == max_group_size:
            potential_batch_indices.extend(bucket)
            bucket.clear()  # reuse the list in place; faster than rebinding []

    # The flushed packs must hold at least batch_size tokens in total.
    if sum(group_sizes[i] for i in potential_batch_indices) < batch_size:
        return None, queue

    # Pass 3: take candidates in order until the token total is exact.
    batch = []
    batch_indices = []
    running_tokens = 0
    for idx in potential_batch_indices:
        batch.append(queue[idx])
        batch_indices.append(idx)
        running_tokens += group_sizes[idx]
        if running_tokens == batch_size:
            break
        if running_tokens > batch_size:
            # Overshot: these packs cannot form an exact batch.
            return None, queue

    if running_tokens != batch_size:
        return None, queue

    # Rebuild the queue in a single pass with O(1) membership tests.
    taken = set(batch_indices)
    new_queue = [item for i, item in enumerate(queue) if i not in taken]
    return batch, new_queue
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue