Add selectable augmentations to the ARC-AGI dataset.

Summary of the patch that follows:
  * ArcAgiConfig gains three boolean flags (use_rotations, use_mirrors,
    use_color_permutation), all defaulting to True.
  * ArcAgiDataset._create_augmentation_fn(rng) builds one composite
    function per item: an rng-chosen entry of ROTATION_AUGMENTATIONS, an
    rng-chosen entry of MIRROR_AUGMENTATIONS, and a cmap() over a shuffled
    list(range(10)) colour table, each included only when its flag is set.
  * __getitem__ applies that single composite function to every train
    input/output board and to the test input/output boards, and stores the
    augmented test boards in the item metadata.
  * tests/test_arc_agi.py adds test_arc_agi_augmentations, comparing a
    dataset with all flags off against rotations-only and
    colour-permutation-only datasets built from the same seed.

Review notes (not applied, so the hunks match the recorded blob ids):
  * use_mirrors is never exercised on its own by the new test.
  * The new test iterates with range(len(base_items)); zip() over the two
    item lists would be the more idiomatic form.
  * Indentation was reconstructed from a whitespace-collapsed copy using the
    4-space convention; verify against the target tree before applying.

diff --git a/reasoning_gym/arc/arc_agi.py b/reasoning_gym/arc/arc_agi.py
index 1e3d41da..df8bcb06 100644
--- a/reasoning_gym/arc/arc_agi.py
+++ b/reasoning_gym/arc/arc_agi.py
@@ -22,6 +22,9 @@ class ArcAgiConfig:
     board_format_opts: BoardFormattingOptions = field(default_factory=lambda: BoardFormattingOptions())
     seed: Optional[int] = None
     size: int = 500
+    use_rotations: bool = True
+    use_mirrors: bool = True
+    use_color_permutation: bool = True
 
     def validate(self):
         assert self.size > 0, "Size of dataset must be positive."
@@ -93,6 +96,29 @@ class ArcAgiDataset(ProceduralDataset):
             self._tasks[x.id] = x.to_dict()
         self._task_ids = list(self._tasks.keys())
 
+    def _create_augmentation_fn(self, rng: Random) -> Callable[[Board], Board]:
+        """Create a composite augmentation function from enabled options"""
+        fns = []
+
+        if self.config.use_rotations:
+            fns.append(rng.choice(ROTATION_AUGMENTATIONS))
+
+        if self.config.use_mirrors:
+            fns.append(rng.choice(MIRROR_AUGMENTATIONS))
+
+        if self.config.use_color_permutation:
+            color_table = list(range(10))
+            rng.shuffle(color_table)
+            fns.append(lambda x: cmap(x, color_table))
+
+        def composite_fn(board: Board) -> Board:
+            result = board
+            for fn in fns:
+                result = fn(result)
+            return result
+
+        return composite_fn
+
     def __getitem__(self, idx: int) -> dict:
         """
         Generate a single ARC-AGI-1 task
@@ -102,15 +128,31 @@ class ArcAgiDataset(ProceduralDataset):
         task_id = rng.choice(self._task_ids)
         task = self._tasks[task_id]
 
+        # Create augmentation function to be used for all examples
+        augment = self._create_augmentation_fn(rng)
+
         train = task["train"]
         test = task["test"][0]
 
+        # Apply augmentation to all train examples
+        augmented_train = []
+        for p in train:
+            augmented_train.append({
+                "input": augment(p["input"]),
+                "output": augment(p["output"])
+            })
+
         examples = [
-            format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts) for i, p in enumerate(train)
+            format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts)
+            for i, p in enumerate(augmented_train)
         ]
         examples = "".join(examples)
 
-        test_input = format_board(test["input"], self.board_format_opts)
-        test_output = format_board(test["output"], self.board_format_opts)
+        # Apply augmentation to test example
+        augmented_test_input = augment(test["input"])
+        augmented_test_output = augment(test["output"])
+
+        test_input = format_board(augmented_test_input, self.board_format_opts)
+        test_output = format_board(augmented_test_output, self.board_format_opts)
 
         input_prompt = self._prompt_templates.format(examples=examples, input_grid=test_input)
@@ -121,8 +163,8 @@ class ArcAgiDataset(ProceduralDataset):
             "question": input_prompt,
             "answer": test_output,
             "metadata": {
-                "input": totuple(test["input"]),
-                "output": totuple(test["output"]),
+                "input": totuple(augmented_test_input),
+                "output": totuple(augmented_test_output),
                 "task_id": task_id,
             },
         }
diff --git a/tests/test_arc_agi.py b/tests/test_arc_agi.py
index cb36f123..2650372d 100644
--- a/tests/test_arc_agi.py
+++ b/tests/test_arc_agi.py
@@ -49,6 +49,54 @@ def test_arc_agi_items():
         assert isinstance(meta["task_id"], str)
 
 
+def test_arc_agi_augmentations():
+    """Test that augmentations can be selectively enabled/disabled"""
+    # Test with all augmentations disabled
+    config = ArcAgiConfig(
+        seed=42,
+        size=10,
+        use_rotations=False,
+        use_mirrors=False,
+        use_color_permutation=False
+    )
+    base_dataset = ArcAgiDataset(config)
+    base_items = list(base_dataset)
+
+    # Test with rotations only
+    rot_config = ArcAgiConfig(
+        seed=42,
+        size=10,
+        use_rotations=True,
+        use_mirrors=False,
+        use_color_permutation=False
+    )
+    rot_dataset = ArcAgiDataset(rot_config)
+    rot_items = list(rot_dataset)
+
+    # Items should differ when rotations are enabled
+    assert any(
+        base_items[i]["metadata"]["input"] != rot_items[i]["metadata"]["input"]
+        for i in range(len(base_items))
+    ), "Rotation augmentation had no effect"
+
+    # Test with color permutation only
+    color_config = ArcAgiConfig(
+        seed=42,
+        size=10,
+        use_rotations=False,
+        use_mirrors=False,
+        use_color_permutation=True
+    )
+    color_dataset = ArcAgiDataset(color_config)
+    color_items = list(color_dataset)
+
+    # Items should differ when color permutation is enabled
+    assert any(
+        base_items[i]["metadata"]["input"] != color_items[i]["metadata"]["input"]
+        for i in range(len(base_items))
+    ), "Color permutation had no effect"
+
+
 def test_arc_agi_scoring():
     """Test solution verification and scoring"""
     config = ArcAgiConfig(size=10, seed=123)