mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-24 17:05:03 +00:00
minor arc_1d tweaks
This commit is contained in:
parent
ec3050a4f6
commit
469934d9b7
2 changed files with 57 additions and 34 deletions
|
|
@ -1,3 +1,5 @@
|
|||
from random import Random
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.arc import Arc1DConfig, Arc1DDataset
|
||||
|
|
@ -125,3 +127,18 @@ def test_arc_1d_size_ranges(min_size: int, max_size: int):
|
|||
assert min_size <= len(entry["metadata"]["test_example"]["input"]) <= max_size
|
||||
assert min_size <= len(entry["metadata"]["test_example"]["output"]) <= max_size
|
||||
assert dataset.score_answer(entry["answer"], entry) == 1.0
|
||||
|
||||
|
||||
def test_arc_1d_generate_all_tasks():
|
||||
config = Arc1DConfig(size=100, seed=17, min_size=8, max_size=10)
|
||||
dataset = Arc1DDataset(config)
|
||||
tasks = dataset.ARC_1D_TASKS
|
||||
rng = Random(999)
|
||||
for task_name, (generator_fn, args) in tasks.items():
|
||||
for j in range(3):
|
||||
for i in range(20):
|
||||
x = generator_fn(rng=rng, size=10, **args)
|
||||
if x is not None:
|
||||
break
|
||||
assert i < 20
|
||||
print(task_name, j, i, x)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue