minor arc_1d tweaks

This commit is contained in:
Andreas Koepf 2025-02-23 16:37:40 +01:00
parent ec3050a4f6
commit 469934d9b7
2 changed files with 57 additions and 34 deletions

View file

@ -802,13 +802,13 @@ def task_duplicate_block_from_seeds(rng: Random, size: int) -> Optional[dict[str
def task_fill_from_pixel(rng: Random, size: int) -> Optional[dict[str, list[int]]]:
"""Generate a task where a pixel fills in one direction until hitting another pixel."""
if size < 6:
if size < 8:
return None
block_size = rng.randint(3, size - 3)
block_size = rng.randint(3, size - 5)
# Position block with space for seed
block_pos = rng.randint(1, size - block_size - 1)
block_pos = rng.randint(2, size - block_size - 2)
# Create input
question = gen_field(size)
@ -819,9 +819,9 @@ def task_fill_from_pixel(rng: Random, size: int) -> Optional[dict[str, list[int]
question[block_pos + i] = block_color
# Place seed pixel and determine fill direction
seed_color = rng.randint(1, 9)
while seed_color == block_color:
seed_color = rng.randint(1, 9)
seed_color = rng.randint(1, 8)
if seed_color >= block_color:
seed_color += 1
is_left = rng.random() < 0.5
@ -847,48 +847,51 @@ def task_fill_from_pixel(rng: Random, size: int) -> Optional[dict[str, list[int]
def task_mark_size_two_blocks(rng: Random, size: int) -> Optional[dict[str, list[int]]]:
"""Generate a task where size-2 blocks are marked with surrounding pixels."""
blocks = []
pos = 0
if size < 8:
return None
# Generate blocks with minimum gap of 2
# Start with one size-2 block
blocks = [2]
pos = 4 # Space for first block (2) + gap (2)
# Generate more blocks
while pos < size:
if rng.random() < 0.4:
block_size = rng.randint(1, 3)
# Check if we have space for block and potential markers
needed_space = block_size + (2 if block_size == 2 else 0)
if pos + needed_space < size:
blocks.append((pos, block_size))
pos += block_size + 2 # Minimum gap of 2
if pos + block_size <= size:
blocks.append(block_size)
pos += block_size + 2 # block + gap
else:
blocks.append(0)
pos += 1
pos += 1
# Shuffle block sizes
rng.shuffle(blocks)
if len(blocks) < 2:
return None
# Assign positions with proper gaps
block_positions = []
pos = 0
# Verify gaps between blocks (including markers)
valid = True
for i in range(len(blocks) - 1):
pos1, size1 = blocks[i]
pos2, _ = blocks[i + 1]
needed_gap = 3 if size1 == 2 else 2
if pos2 - (pos1 + size1) < needed_gap:
valid = False
break
if not valid:
return None
for block_size in blocks:
if block_size == 0:
pos += 1
else:
block_positions.append((pos, block_size))
pos += block_size + 2 # Move past block + gap
# Create input with blocks
question = gen_field(size)
for pos, block_size in blocks:
# Place block
for pos, block_size in block_positions:
block_color = rng.randint(1, 8)
if block_color >= 3: # avoid marker color 3
block_color += 1
for i in range(block_size):
question[pos + i] = 1
question[pos + i] = block_color
# Create answer with markers
answer = question.copy()
for pos, block_size in blocks:
for pos, block_size in block_positions:
if block_size == 2:
# Add markers for size 2 blocks
if pos > 0:
answer[pos - 1] = 3
if pos + block_size < size:
@ -935,7 +938,10 @@ def task_fill_until_collision(rng: Random, size: int) -> Optional[dict[str, list
# Color random pixels
for pos in positions:
question[pos] = rng.randint(1, 9)
c = rng.randint(1, 8)
if c >= 5: # don't use side marker color 5
c += 1
question[pos] = c
positions.sort()

View file

@ -1,3 +1,5 @@
from random import Random
import pytest
from reasoning_gym.arc import Arc1DConfig, Arc1DDataset
@ -125,3 +127,18 @@ def test_arc_1d_size_ranges(min_size: int, max_size: int):
assert min_size <= len(entry["metadata"]["test_example"]["input"]) <= max_size
assert min_size <= len(entry["metadata"]["test_example"]["output"]) <= max_size
assert dataset.score_answer(entry["answer"], entry) == 1.0
def test_arc_1d_generate_all_tasks():
config = Arc1DConfig(size=100, seed=17, min_size=8, max_size=10)
dataset = Arc1DDataset(config)
tasks = dataset.ARC_1D_TASKS
rng = Random(999)
for task_name, (generator_fn, args) in tasks.items():
for j in range(3):
for i in range(20):
x = generator_fn(rng=rng, size=10, **args)
if x is not None:
break
assert i < 20
print(task_name, j, i, x)