mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
28 lines
956 B
Python
28 lines
956 B
Python
import random
|
|
|
|
from atroposlib.api.utils import grab_exact_from_heterogeneous_queue
|
|
|
|
|
|
def test_grab_exact_from_heterogeneous_queue():
    """Randomly sample inputs for grab_exact_from_heterogeneous_queue.

    Each trial picks a target batch size that is a multiple of 64
    (64 to 1280) and builds a queue of 1 to 100 items, each holding eight
    single-token sequences under the "tokens" key. The helper is called with
    that queue and target size; whenever it hands back a batch (a ``None``
    batch is tolerated and skipped), the total number of sequences across
    the batch items must equal the requested size exactly.
    """
    for trial in range(10000):
        bs = 64 * random.randint(1, 20)
        # NOTE: every item carries exactly 8 sequences. A variant appending
        # 2-sequence items used to be commented out here, so mixed group
        # sizes are not exercised by this test.
        queue = [
            {"tokens": [[2 * i + 1] for _ in range(8)]}
            for i in range(random.randint(1, 100))
        ]
        # Remember the input size: the inputs are random, so the failure
        # message is the only record of what was fed in.
        initial_queue_len = len(queue)

        batch, queue = grab_exact_from_heterogeneous_queue(queue, bs)

        # Print one sample batch for manual inspection.
        if trial == 0:
            print(batch)

        if batch is None:
            continue

        # Report the same quantity that is being checked (sequence count),
        # not the number of items in the batch.
        total_sequences = sum(len(item["tokens"]) for item in batch)
        assert total_sequences == bs, (
            f"expected batch size {bs}, got {total_sequences} sequences "
            f"across {len(batch)} items "
            f"(initial queue length {initial_queue_len})"
        )
|