Fix index-out-of-range errors in the arc_1d dataset (#190)

This commit is contained in:
Andreas Koepf 2025-02-23 12:51:41 +01:00
parent a1a305c8d7
commit e444bbf7a1
3 changed files with 30 additions and 24 deletions

View file

@ -18,7 +18,7 @@ class Arc1DConfig:
def validate(self) -> None: def validate(self) -> None:
"""Validate configuration parameters""" """Validate configuration parameters"""
assert self.min_size > 0, "min_size must be positive" assert self.min_size >= 8, "min_size must be >= 8"
assert self.max_size >= self.min_size, "max_size must be >= min_size" assert self.max_size >= self.min_size, "max_size must be >= min_size"
assert self.num_train > 0, "num_train must be positive" assert self.num_train > 0, "num_train must be positive"
assert self.size > 0, "size must be positive" assert self.size > 0, "size must be positive"

View file

@ -38,7 +38,7 @@ def task_move_n_pix(rng: Random, size: int, move_pix: int, solid: bool) -> Optio
def task_move_n_pix_wrapped(rng: Random, size: int, move_pix: int, solid: bool) -> Optional[dict[str, list[int]]]: def task_move_n_pix_wrapped(rng: Random, size: int, move_pix: int, solid: bool) -> Optional[dict[str, list[int]]]:
"""Generate a task where a block is moved to the right by move_pix pixels with wrapping.""" """Generate a task where a block is moved to the right by move_pix pixels with wrapping."""
block_size = rng.randint(1, size) block_size = rng.randint(1, size)
block_pos = rng.randint(0, size) block_pos = rng.randint(0, size - 1)
if solid: if solid:
color = rng.randint(1, 9) color = rng.randint(1, 9)
@ -95,8 +95,8 @@ def task_block_touch_dot(rng: Random, size: int) -> Optional[dict[str, list[int]
dot_color = 1 dot_color = 1
block_color = rng.randint(2, 9) block_color = rng.randint(2, 9)
block_size = rng.randint(1, size) block_size = rng.randint(1, size - 1)
dot_pos = rng.randint(0, size) dot_pos = rng.randint(0, size - 1)
can_place_left = dot_pos >= block_size can_place_left = dot_pos >= block_size
can_place_right = dot_pos + block_size < size can_place_right = dot_pos + block_size < size
@ -134,8 +134,8 @@ def task_block_touch_dot_n_pix(rng: Random, size: int, move_pix: int) -> Optiona
dot_color = 2 dot_color = 2
block_color = rng.randint(3, 9) block_color = rng.randint(3, 9)
block_size = rng.randint(1, size) block_size = rng.randint(1, size - 1)
dot_pos = rng.randint(0, size) dot_pos = rng.randint(0, size - 1)
can_place_left = dot_pos >= block_size can_place_left = dot_pos >= block_size
can_place_right = dot_pos + block_size < size can_place_right = dot_pos + block_size < size
@ -177,8 +177,8 @@ def task_block_scale_to_dot(rng: Random, size: int) -> Optional[dict[str, list[i
dot_color = 2 dot_color = 2
block_color = rng.randint(3, 9) block_color = rng.randint(3, 9)
block_size = rng.randint(1, size) block_size = rng.randint(1, size - 1)
dot_pos = rng.randint(0, size) dot_pos = rng.randint(0, size - 1)
can_place_left = dot_pos >= block_size can_place_left = dot_pos >= block_size
can_place_right = dot_pos + block_size < size can_place_right = dot_pos + block_size < size
@ -271,16 +271,13 @@ def task_reflect_block_with_border_pixel_random(rng: Random, size: int) -> Optio
side = "left" if rng.random() < 0.5 else "right" side = "left" if rng.random() < 0.5 else "right"
pos = rng.randint(0, size - block_size) pos = rng.randint(0, size - block_size)
block = [rng.randint(1, 9) for _ in range(block_size)]
border_color = rng.randint(1, 9) border_color = rng.randint(1, 9)
other_colors = tuple(c for c in range(1, 9) if c != border_color)
block = [rng.choice(other_colors) for _ in range(block_size)]
if side == "left": if side == "left":
if block[0] == border_color:
return None
block[0] = border_color block[0] = border_color
else: else:
if block[block_size - 1] == border_color:
return None
block[block_size - 1] = border_color block[block_size - 1] = border_color
question = write_block(pos, block, gen_field(size)) question = write_block(pos, block, gen_field(size))
@ -294,8 +291,8 @@ def task_reflect_block_around_dot(rng: Random, size: int) -> Optional[dict[str,
"""Generate a task where a block is reflected around a dot.""" """Generate a task where a block is reflected around a dot."""
dot_color = 2 dot_color = 2
dot_pos = rng.randint(0, size) dot_pos = rng.randint(0, size - 1)
block_size = rng.randint(1, size) block_size = rng.randint(1, size - 1)
block_pos = rng.randint(0, size - block_size) block_pos = rng.randint(0, size - block_size)
block_end = block_pos + block_size - 1 block_end = block_pos + block_size - 1
@ -471,7 +468,7 @@ def task_copy_block_to_dots_colors(rng: Random, size: int) -> Optional[dict[str,
dot_colors = [] dot_colors = []
pos = block_size + block_size // 2 + 1 pos = block_size + block_size // 2 + 1
while pos < size - block_size: while pos <= size - block_size:
if rng.random() < 0.5: if rng.random() < 0.5:
dot_color = rng.randint(1, 9) dot_color = rng.randint(1, 9)
dot_positions.append(pos) dot_positions.append(pos)
@ -759,13 +756,14 @@ def task_duplicate_block_from_seeds(rng: Random, size: int) -> Optional[dict[str
return None return None
# Position block with space for seeds # Position block with space for seeds
block_pos = rng.randint(2, size - block_size - 1) block_pos = rng.randint(2, size - block_size - 2)
# Decide seed placement # Decide seed placement
left_seed = rng.random() < 0.5 left_seed = False
right_seed = rng.random() < 0.5 right_seed = False
if not (left_seed or right_seed): while not left_seed and not right_seed:
return None left_seed = rng.random() < 0.5
right_seed = rng.random() < 0.5
# Create input # Create input
question = gen_field(size) question = gen_field(size)
@ -1039,8 +1037,8 @@ def task_color_left_half_blocks(rng: Random, size: int) -> Optional[dict[str, li
# Generate blocks with gap 1 # Generate blocks with gap 1
while pos < size: while pos < size:
if rng.random() < 0.4: if rng.random() < 0.4:
block_size = rng.randint(2, 8) block_size = rng.randint(2, size // 2)
if pos + block_size >= size: if pos + block_size > size:
break break
blocks.append((pos, block_size)) blocks.append((pos, block_size))

View file

@ -69,7 +69,7 @@ def test_arc_1d_items():
def test_arc_1d_iteration(): def test_arc_1d_iteration():
"""Test that iteration respects dataset size""" """Test that iteration respects dataset size"""
config = Arc1DConfig(size=5, seed=42) # Small size for testing config = Arc1DConfig(size=100, seed=42) # Small size for testing
dataset = Arc1DDataset(config) dataset = Arc1DDataset(config)
# Test manual iteration # Test manual iteration
@ -105,3 +105,11 @@ def test_arc_1d_scoring():
# Test None answer # Test None answer
assert dataset.score_answer(None, entry) == 0.0 assert dataset.score_answer(None, entry) == 0.0
@pytest.mark.parametrize("board_size", [8, 9, 10, 12, 15, 20])
def test_arc_1d_sizes(board_size: int):
    """Check that each entry's own answer scores 1.0 at a fixed board size.

    Every parametrized case sets both ``min_size`` and ``max_size`` to
    ``board_size`` and uses a seed offset by the board size (42 + board_size),
    then walks the whole dataset.
    """
    dataset = Arc1DDataset(
        Arc1DConfig(
            size=1000,
            seed=42 + board_size,
            min_size=board_size,
            max_size=board_size,
        )
    )
    # Visit every entry the dataset yields and score its own answer.
    for item in dataset:
        reference = item["answer"]
        assert dataset.score_answer(reference, item) == 1.0