diff --git a/reasoning_gym/arc/arc_agi.py b/reasoning_gym/arc/arc_agi.py index b96698bb..98c3f000 100644 --- a/reasoning_gym/arc/arc_agi.py +++ b/reasoning_gym/arc/arc_agi.py @@ -27,6 +27,7 @@ class ArcAgiConfig: default_factory=lambda: ["horizontal", "vertical", "diagonal", "counterdiagonal"] ) # empty list for no mirrors use_color_permutation: bool = True + shuffle_example_order: bool = True # whether to shuffle the order of example board pairs for each riddle seed: Optional[int] = None size: int = 500 @@ -87,8 +88,8 @@ def cmap(board: Board, colors: list[int]) -> Board: return [[colors[c] for c in row] for row in board] -ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270] -MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror] +# ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270] +# MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror] class ArcAgiDataset(ProceduralDataset): @@ -156,6 +157,9 @@ class ArcAgiDataset(ProceduralDataset): for p in train: augmented_train.append({"input": augment(p["input"]), "output": augment(p["output"])}) + if self.config.shuffle_example_order: + rng.shuffle(augmented_train) + examples = [ format_board_pair(i + 1, p, formatting_options=self.config.board_format_opts) for i, p in enumerate(augmented_train) diff --git a/tests/test_arc_agi.py b/tests/test_arc_agi.py index da43e6ab..cfeecf21 100644 --- a/tests/test_arc_agi.py +++ b/tests/test_arc_agi.py @@ -137,3 +137,32 @@ def test_arc_agi_dataset_modes(): both_ds = ArcAgiDataset(both_config) assert len(both_ds._task_ids) > len(train_ds._task_ids) assert len(both_ds._task_ids) > len(eval_ds._task_ids) + + +def test_arc_agi_shuffled_order(): + config_unshuffled = ArcAgiConfig( + shuffle_example_order=False, + use_train=True, + use_eval=False, + rotations=[], + mirrors=[], + use_color_permutation=False, + size=3, + seed=42, + ) + config_shuffled = ArcAgiConfig( + shuffle_example_order=True, + use_train=True, + use_eval=False, + rotations=[], + mirrors=[], + use_color_permutation=False, + size=3, + seed=42, + ) + unshuffled = ArcAgiDataset(config_unshuffled) + shuffled = ArcAgiDataset(config_shuffled) + + for a, b in zip(shuffled, unshuffled): + assert a["question"] != b["question"] + assert a["answer"] == b["answer"]