diff --git a/atroposlib/api/utils.py b/atroposlib/api/utils.py index c2fef67c..5f95feb1 100644 --- a/atroposlib/api/utils.py +++ b/atroposlib/api/utils.py @@ -22,40 +22,63 @@ def grab_exact_from_heterogeneous_queue( :param batch_size: :return: batch, new_queue """ - # check if we can even potentially grab a batch - if sum(len(item["tokens"]) for item in queue) < batch_size: + + # Pass 1: precompute group sizes, total tokens and early exit if not enough tokens. + total_groups = len(queue) + if total_groups == 0: return None, queue - # Get max batch size - max_group_size = max(len(group["tokens"]) for group in queue) - group_sizes = set(len(group["tokens"]) for group in queue) - group_batching_storage = {i: [] for i in group_sizes} - # pack the groups into [max_group_size // group_size] packs - potential_batch = [] - for i, item in enumerate(queue): - key = len(item["tokens"]) - group_batching_storage[key].append({"group": item, "indx": i}) - if len(group_batching_storage[key]) * key == max_group_size: - potential_batch.extend(group_batching_storage[key]) - group_batching_storage[key] = [] - if ( - sum(len(grouped_items["group"]["tokens"]) for grouped_items in potential_batch) - < batch_size - ): + + group_sizes = [] + lengths = [] + total_tokens = 0 + max_group_size = 0 + + for item in queue: + l = len(item["tokens"]) + lengths.append(l) + group_sizes.append(l) + total_tokens += l + if l > max_group_size: + max_group_size = l + + if total_tokens < batch_size: return None, queue - # we have a batch + + group_sizes_set = set(group_sizes) + group_batching_storage = {size: [] for size in group_sizes_set} + + # Index into the queue and batch related indices into "packs" + potential_batch_indices = [] + for i, group_size in enumerate(group_sizes): + group_batching_storage[group_size].append(i) + if len(group_batching_storage[group_size]) * group_size == max_group_size: + potential_batch_indices.extend(group_batching_storage[group_size]) + group_batching_storage[group_size].clear() # much faster than = [] + + # Calculate total batch tokens only once (avoid repeated sums) + potential_batch_token_total = sum(lengths[i] for i in potential_batch_indices) + if potential_batch_token_total < batch_size: + return None, queue + + # Batch selection batch = [] - indxes_to_remove_from_queue = [] - for item in potential_batch: - group = item["group"] - indx = item["indx"] + batch_indices = [] + running_tokens = 0 + for idx in potential_batch_indices: + group = queue[idx] batch.append(group) - indxes_to_remove_from_queue.append(indx) - if sum(len(item["tokens"]) for item in batch) == batch_size: + batch_indices.append(idx) + running_tokens += lengths[idx] + if running_tokens == batch_size: break - if sum(len(item["tokens"]) for item in batch) != batch_size: + elif running_tokens > batch_size: + # Should never happen due to problem constraints, but sanity check + return None, queue + + if running_tokens != batch_size: return None, queue - # remove the items from the queue - new_queue = [ - item for i, item in enumerate(queue) if i not in indxes_to_remove_from_queue - ] + + # Construct new_queue with a single pass, using a set for O(1) lookup + batch_indices_set = set(batch_indices) + new_queue = [item for i, item in enumerate(queue) if i not in batch_indices_set] return batch, new_queue