diff --git a/reasoning_gym/arc/__init__.py b/reasoning_gym/arc/__init__.py index fcb2b262..47ba75b1 100644 --- a/reasoning_gym/arc/__init__.py +++ b/reasoning_gym/arc/__init__.py @@ -1,5 +1,5 @@ from .arc_1d import Arc1DConfig, Arc1DCurriculum, Arc1DDataset -from .arc_agi import ArcAgiConfig, ArcAgiDataset +from .arc_agi import ArcAgiConfig, ArcAgiCurriculum, ArcAgiDataset from .rearc import ReArcConfig, ReArcCurriculum, ReArcDataset __all__ = [ @@ -8,6 +8,7 @@ __all__ = [ "Arc1DCurriculum", "ArcAgiConfig", "ArcAgiDataset", + "ArcAgiCurriculum", "ReArcDataset", "ReArcConfig", "ReArcCurriculum", diff --git a/reasoning_gym/arc/arc_agi.py b/reasoning_gym/arc/arc_agi.py index b46a7091..0ee35f5a 100644 --- a/reasoning_gym/arc/arc_agi.py +++ b/reasoning_gym/arc/arc_agi.py @@ -14,6 +14,8 @@ from reasoning_gym.arc.board_format import ( from reasoning_gym.dataset import ProceduralDataset from reasoning_gym.factory import register_dataset +from ..coaching import BaseCurriculum, ScalarAttributeDefinition + DATASET_NAME = "arc_agi" @@ -31,6 +33,13 @@ class ArcAgiConfig: use_color_permutation: bool = True shuffle_example_order: bool = True # whether to shuffle the order of example board pairs for each riddle + rotations_weights: list[float] = field( + default_factory=lambda: [0.25, 0.25, 0.25, 0.25] + ) # ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270] + mirrors_weights: list[float] = field( + default_factory=lambda: [0.2, 0.2, 0.2, 0.2, 0.2] + ) # MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror] + seed: Optional[int] = None size: int = 500 @@ -117,13 +126,19 @@ class ArcAgiDataset(ProceduralDataset): # Map rotation strings to functions rotation_map = {"90": rot90, "180": rot180, "270": rot270} if self.config.rotations: - chosen_rot = rng.choice([identity] + [rotation_map[r] for r in self.config.rotations]) + chosen_rot = rng.choices( + [identity] + [rotation_map[r] for r in self.config.rotations], + weights=self.config.rotations_weights, + k=1, + )[0] fns.append(chosen_rot) # Map mirror strings to functions mirror_map = {"horizontal": hmirror, "vertical": vmirror, "diagonal": dmirror, "counterdiagonal": cmirror} if self.config.mirrors: - chosen_mirror = rng.choice([identity] + [mirror_map[m] for m in self.config.mirrors]) + chosen_mirror = rng.choices( + [identity] + [mirror_map[m] for m in self.config.mirrors], weights=self.config.mirrors_weights, k=1 + )[0] fns.append(chosen_mirror) if self.config.use_color_permutation: @@ -189,6 +204,10 @@ class ArcAgiDataset(ProceduralDataset): "input": totuple(augmented_test_input), "output": totuple(augmented_test_output), "task_id": task_id, + "difficulty": { + "rotations_weights": self.config.rotations_weights, + "mirrors_weights": self.config.mirrors_weights, + }, }, } @@ -207,4 +226,34 @@ class ArcAgiDataset(ProceduralDataset): return reward -register_dataset(DATASET_NAME, ArcAgiDataset, ArcAgiConfig) +class ArcAgiCurriculum(BaseCurriculum): + """Curriculum for ARC-AGI-1 tasks""" + + def __init__(self): + super().__init__(ArcAgiCurriculum.__name__, ArcAgiConfig) + + # Define attributes + self._define_attributes( + ScalarAttributeDefinition( + name="rotations_weights", + field_name="rotations_weights", + # ROTATION_AUGMENTATIONS = [identity, rot90, rot180, rot270] + levels=[[0.3, 0.2, 0.3, 0.2], [0.15, 0.3, 0.25, 0.3], [0.1, 0.35, 0.2, 0.35], [0.0, 0.4, 0.2, 0.4]], + description="Rotation augmentation weights", + ), + ScalarAttributeDefinition( + name="mirrors_weights", + field_name="mirrors_weights", + # MIRROR_AUGMENTATIONS = [identity, hmirror, vmirror, dmirror, cmirror] + levels=[ + [0.3, 0.3, 0.2, 0.1, 0.1], + [0.2, 0.2, 0.2, 0.2, 0.2], + [0.1, 0.1, 0.2, 0.3, 0.3], + [0.05, 0.05, 0.1, 0.4, 0.4], + ], + description="Mirror augmentation weights", + ), + ) + + +register_dataset("arc_agi", ArcAgiDataset, ArcAgiConfig, ArcAgiCurriculum) diff --git a/tests/test_arc_agi.py b/tests/test_arc_agi.py index ca7a92e3..1f1c4831 100644 --- a/tests/test_arc_agi.py +++ b/tests/test_arc_agi.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.arc.arc_agi import ArcAgiConfig, ArcAgiDataset +from reasoning_gym.arc.arc_agi import ArcAgiConfig, ArcAgiCurriculum, ArcAgiDataset def test_arc_agi_config_validation(): @@ -65,12 +65,28 @@ def test_arc_agi_items(): def test_arc_agi_augmentations(): """Test that augmentations can be selectively enabled/disabled""" # Test with all augmentations disabled - config = ArcAgiConfig(seed=42, size=10, rotations=[], mirrors=[], use_color_permutation=False) + config = ArcAgiConfig( + seed=42, + size=10, + rotations=[], + mirrors=[], + rotations_weights=[1.0], + mirrors_weights=[1.0], + use_color_permutation=False, + ) base_dataset = ArcAgiDataset(config) base_items = list(base_dataset) # Test with specific rotation only - rot_config = ArcAgiConfig(seed=42, size=10, rotations=["90"], mirrors=[], use_color_permutation=False) + rot_config = ArcAgiConfig( + seed=42, + size=10, + rotations=["90"], + mirrors=[], + rotations_weights=[0.5, 0.5], + mirrors_weights=[1.0], + use_color_permutation=False, + ) rot_dataset = ArcAgiDataset(rot_config) rot_items = list(rot_dataset) @@ -80,7 +96,15 @@ def test_arc_agi_augmentations(): ), "90-degree rotation augmentation had no effect" # Test with specific mirror only - mirror_config = ArcAgiConfig(seed=42, size=10, rotations=[], mirrors=["horizontal"], use_color_permutation=False) + mirror_config = ArcAgiConfig( + seed=42, + size=10, + rotations=[], + mirrors=["horizontal"], + rotations_weights=[1.0], + mirrors_weights=[0.5, 0.5], + use_color_permutation=False, + ) mirror_dataset = ArcAgiDataset(mirror_config) mirror_items = list(mirror_dataset) @@ -90,7 +114,15 @@ def test_arc_agi_augmentations(): ), "Horizontal mirror augmentation had no effect" # Test with color permutation only - color_config = ArcAgiConfig(seed=42, size=10, rotations=[], mirrors=[], use_color_permutation=True) + color_config = ArcAgiConfig( + seed=42, + size=10, + rotations=[], + mirrors=[], + rotations_weights=[1.0], + mirrors_weights=[1.0], + use_color_permutation=True, + ) color_dataset = ArcAgiDataset(color_config) color_items = list(color_dataset) @@ -166,3 +198,48 @@ def test_arc_agi_shuffled_order(): for a, b in zip(shuffled, unshuffled): assert a["question"] != b["question"] assert a["answer"] == b["answer"] + + +def test_arc_agi_curriculum(): + """Test the curriculum for complex arithmetic.""" + curriculum = ArcAgiCurriculum() + base_value = {"size": 150, "seed": 1} + + base_cfg: ArcAgiCurriculum = curriculum.generate_configuration(base_value) + + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.rotations_weights == [0.3, 0.2, 0.3, 0.2] + assert base_cfg.mirrors_weights == [0.3, 0.3, 0.2, 0.1, 0.1] + + # Test and validate increase in levels + curriculum.increment_attr_level("rotations_weights") + curriculum.increment_attr_level("mirrors_weights") + + increased_cfg: ArcAgiCurriculum = curriculum.generate_configuration(base_value) + assert increased_cfg.rotations_weights == [0.15, 0.3, 0.25, 0.3] + assert increased_cfg.mirrors_weights == [0.2, 0.2, 0.2, 0.2, 0.2] + + # Test and validate decrease in levels + curriculum.decrement_attr_level("rotations_weights") + curriculum.decrement_attr_level("mirrors_weights") + + decreased_cfg: ArcAgiCurriculum = curriculum.generate_configuration(base_value) + assert decreased_cfg.rotations_weights == [0.3, 0.2, 0.3, 0.2] + assert decreased_cfg.mirrors_weights == [0.3, 0.3, 0.2, 0.1, 0.1] + + # Test upper bound boundary condition + for _ in range(10): + curriculum.increment_attr_level("rotations_weights") + curriculum.increment_attr_level("mirrors_weights") + upper_bound_cfg: ArcAgiCurriculum = curriculum.generate_configuration(base_value) + assert upper_bound_cfg.rotations_weights == [0.0, 0.4, 0.2, 0.4] + assert upper_bound_cfg.mirrors_weights == [0.05, 0.05, 0.1, 0.4, 0.4] + + # Test lower bound boundary condition + for _ in range(10): + curriculum.decrement_attr_level("rotations_weights") + curriculum.decrement_attr_level("mirrors_weights") + lower_bound_cfg: ArcAgiCurriculum = curriculum.generate_configuration(base_value) + assert lower_bound_cfg.rotations_weights == [0.3, 0.2, 0.3, 0.2] + assert lower_bound_cfg.mirrors_weights == [0.3, 0.3, 0.2, 0.1, 0.1]