diff --git a/reasoning_gym/arc/arc_agi.py b/reasoning_gym/arc/arc_agi.py index ae2065b4..1bbff894 100644 --- a/reasoning_gym/arc/arc_agi.py +++ b/reasoning_gym/arc/arc_agi.py @@ -99,13 +99,13 @@ class ArcAgiDataset(ProceduralDataset): def _create_augmentation_fn(self, rng: Random) -> Callable[[Board], Board]: """Create a composite augmentation function from enabled options""" fns = [] - + if self.config.use_rotations: fns.append(rng.choice(ROTATION_AUGMENTATIONS)) - + if self.config.use_mirrors: fns.append(rng.choice(MIRROR_AUGMENTATIONS)) - + if self.config.use_color_permutation: color_table = list(range(10)) rng.shuffle(color_table) @@ -116,7 +116,7 @@ class ArcAgiDataset(ProceduralDataset): for fn in fns: result = fn(result) return result - + return composite_fn def __getitem__(self, idx: int) -> dict: @@ -137,16 +137,14 @@ class ArcAgiDataset(ProceduralDataset): # Apply augmentation to all train examples augmented_train = [] for p in train: - augmented_train.append({ - "input": augment(p["input"]), - "output": augment(p["output"]) - }) + augmented_train.append({"input": augment(p["input"]), "output": augment(p["output"])}) examples = [ - format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts) + format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts) for i, p in enumerate(augmented_train) ] examples = "".join(examples) + # Apply augmentation to test example augmented_test_input = augment(test["input"]) augmented_test_output = augment(test["output"]) diff --git a/tests/test_arc_agi.py b/tests/test_arc_agi.py index 2650372d..bb42e9d8 100644 --- a/tests/test_arc_agi.py +++ b/tests/test_arc_agi.py @@ -52,48 +52,28 @@ def test_arc_agi_items(): def test_arc_agi_augmentations(): """Test that augmentations can be selectively enabled/disabled""" # Test with all augmentations disabled - config = ArcAgiConfig( - seed=42, - size=10, - use_rotations=False, - use_mirrors=False, - use_color_permutation=False - ) + config = ArcAgiConfig(seed=42, size=10, use_rotations=False, use_mirrors=False, use_color_permutation=False) base_dataset = ArcAgiDataset(config) base_items = list(base_dataset) # Test with rotations only - rot_config = ArcAgiConfig( - seed=42, - size=10, - use_rotations=True, - use_mirrors=False, - use_color_permutation=False - ) + rot_config = ArcAgiConfig(seed=42, size=10, use_rotations=True, use_mirrors=False, use_color_permutation=False) rot_dataset = ArcAgiDataset(rot_config) rot_items = list(rot_dataset) # Items should differ when rotations are enabled assert any( - base_items[i]["metadata"]["input"] != rot_items[i]["metadata"]["input"] - for i in range(len(base_items)) + base_items[i]["metadata"]["input"] != rot_items[i]["metadata"]["input"] for i in range(len(base_items)) ), "Rotation augmentation had no effect" # Test with color permutation only - color_config = ArcAgiConfig( - seed=42, - size=10, - use_rotations=False, - use_mirrors=False, - use_color_permutation=True - ) + color_config = ArcAgiConfig(seed=42, size=10, use_rotations=False, use_mirrors=False, use_color_permutation=True) color_dataset = ArcAgiDataset(color_config) color_items = list(color_dataset) # Items should differ when color permutation is enabled assert any( - base_items[i]["metadata"]["input"] != color_items[i]["metadata"]["input"] - for i in range(len(base_items)) + base_items[i]["metadata"]["input"] != color_items[i]["metadata"]["input"] for i in range(len(base_items)) ), "Color permutation had no effect" @@ -113,7 +93,7 @@ def test_arc_agi_scoring(): assert dataset.score_answer(None, entry=item) == 0.0 # Test wrong but valid grid format - wrong_answer = "0 0\n0 0" + wrong_answer = "1 0 0 0\n0 0 0 1" assert dataset.score_answer(wrong_answer, entry=item) == 0.05