feat: Add configurable augmentations to ArcAgiDataset with consistent application

This commit is contained in:
Andreas Koepf (aider) 2025-02-08 23:59:45 +01:00
parent 492570ff5c
commit f8e76b8048
2 changed files with 95 additions and 5 deletions

View file

@ -22,6 +22,9 @@ class ArcAgiConfig:
board_format_opts: BoardFormattingOptions = field(default_factory=lambda: BoardFormattingOptions())
seed: Optional[int] = None
size: int = 500
use_rotations: bool = True
use_mirrors: bool = True
use_color_permutation: bool = True
def validate(self):
assert self.size > 0, "Size of dataset must be positive."
@ -93,6 +96,29 @@ class ArcAgiDataset(ProceduralDataset):
self._tasks[x.id] = x.to_dict()
self._task_ids = list(self._tasks.keys())
def _create_augmentation_fn(self, rng: Random) -> Callable[[Board], Board]:
    """Build a single board transform from the augmentations enabled in the config.

    Random draws are taken from ``rng`` once, while the transform is being
    built, so the returned callable applies the same steps every time it is
    called.

    Args:
        rng: Random source used to pick a rotation, pick a mirror, and
            shuffle the color table (in that order, each only if enabled).

    Returns:
        A callable that feeds a board through the selected steps in order:
        rotation, then mirror, then color remapping. With every option
        disabled it returns the board it was given.
    """
    cfg = self.config
    pipeline = []

    if cfg.use_rotations:
        pipeline.append(rng.choice(ROTATION_AUGMENTATIONS))

    if cfg.use_mirrors:
        pipeline.append(rng.choice(MIRROR_AUGMENTATIONS))

    if cfg.use_color_permutation:
        # Shuffled table of the values 0-9, handed to cmap for every board.
        # NOTE(review): presumably cmap uses it as a per-cell color lookup - confirm.
        palette = list(range(10))
        rng.shuffle(palette)

        def recolor(board: Board) -> Board:
            return cmap(board, palette)

        pipeline.append(recolor)

    def augment(board: Board) -> Board:
        # Chain the steps: each one receives the previous step's output.
        for step in pipeline:
            board = step(board)
        return board

    return augment
def __getitem__(self, idx: int) -> dict:
"""
Generate a single ARC-AGI-1 task
@ -102,15 +128,31 @@ class ArcAgiDataset(ProceduralDataset):
task_id = rng.choice(self._task_ids)
task = self._tasks[task_id]
# Create augmentation function to be used for all examples
augment = self._create_augmentation_fn(rng)
train = task["train"]
test = task["test"][0]
# Apply augmentation to all train examples
augmented_train = []
for p in train:
augmented_train.append({
"input": augment(p["input"]),
"output": augment(p["output"])
})
examples = [
format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts) for i, p in enumerate(train)
format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts)
for i, p in enumerate(augmented_train)
]
examples = "".join(examples)
test_input = format_board(test["input"], self.board_format_opts)
test_output = format_board(test["output"], self.board_format_opts)
# Apply augmentation to test example
augmented_test_input = augment(test["input"])
augmented_test_output = augment(test["output"])
test_input = format_board(augmented_test_input, self.board_format_opts)
test_output = format_board(augmented_test_output, self.board_format_opts)
input_prompt = self._prompt_templates.format(examples=examples, input_grid=test_input)
@ -121,8 +163,8 @@ class ArcAgiDataset(ProceduralDataset):
"question": input_prompt,
"answer": test_output,
"metadata": {
"input": totuple(test["input"]),
"output": totuple(test["output"]),
"input": totuple(augmented_test_input),
"output": totuple(augmented_test_output),
"task_id": task_id,
},
}

View file

@ -49,6 +49,54 @@ def test_arc_agi_items():
assert isinstance(meta["task_id"], str)
def test_arc_agi_augmentations():
    """Test that augmentations can be selectively enabled/disabled"""

    def generate_items(**enabled):
        # Every variant uses the same seed and size, so the items are
        # compared position by position; only the augmentation flags differ.
        flags = {
            "use_rotations": False,
            "use_mirrors": False,
            "use_color_permutation": False,
        }
        flags.update(enabled)
        return list(ArcAgiDataset(ArcAgiConfig(seed=42, size=10, **flags)))

    # Baseline with all augmentations disabled
    base_items = generate_items()

    # Rotations only: at least one item should differ from the baseline
    rot_items = generate_items(use_rotations=True)
    assert any(
        base["metadata"]["input"] != rotated["metadata"]["input"]
        for base, rotated in zip(base_items, rot_items)
    ), "Rotation augmentation had no effect"

    # Color permutation only: at least one item should differ from the baseline
    color_items = generate_items(use_color_permutation=True)
    assert any(
        base["metadata"]["input"] != recolored["metadata"]["input"]
        for base, recolored in zip(base_items, color_items)
    ), "Color permutation had no effect"
def test_arc_agi_scoring():
"""Test solution verification and scoring"""
config = ArcAgiConfig(size=10, seed=123)