Add curriculum for arc_agi (#372)

* Add curriculum for arc_agi

* Resolve conflicts

* Remove code smell

* Remove unwanted code
This commit is contained in:
Adefioye 2025-04-01 14:17:52 -05:00 committed by GitHub
parent 50846c3534
commit e3af2dd2bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 136 additions and 9 deletions

View file

@ -1,6 +1,6 @@
import pytest
from reasoning_gym.arc.arc_agi import ArcAgiConfig, ArcAgiDataset
from reasoning_gym.arc.arc_agi import ArcAgiConfig, ArcAgiCurriculum, ArcAgiDataset
def test_arc_agi_config_validation():
@ -65,12 +65,28 @@ def test_arc_agi_items():
def test_arc_agi_augmentations():
"""Test that augmentations can be selectively enabled/disabled"""
# Test with all augmentations disabled
config = ArcAgiConfig(seed=42, size=10, rotations=[], mirrors=[], use_color_permutation=False)
config = ArcAgiConfig(
seed=42,
size=10,
rotations=[],
mirrors=[],
rotations_weights=[1.0],
mirrors_weights=[1.0],
use_color_permutation=False,
)
base_dataset = ArcAgiDataset(config)
base_items = list(base_dataset)
# Test with specific rotation only
rot_config = ArcAgiConfig(seed=42, size=10, rotations=["90"], mirrors=[], use_color_permutation=False)
rot_config = ArcAgiConfig(
seed=42,
size=10,
rotations=["90"],
mirrors=[],
rotations_weights=[0.5, 0.5],
mirrors_weights=[1.0],
use_color_permutation=False,
)
rot_dataset = ArcAgiDataset(rot_config)
rot_items = list(rot_dataset)
@ -80,7 +96,15 @@ def test_arc_agi_augmentations():
), "90-degree rotation augmentation had no effect"
# Test with specific mirror only
mirror_config = ArcAgiConfig(seed=42, size=10, rotations=[], mirrors=["horizontal"], use_color_permutation=False)
mirror_config = ArcAgiConfig(
seed=42,
size=10,
rotations=[],
mirrors=["horizontal"],
rotations_weights=[1.0],
mirrors_weights=[0.5, 0.5],
use_color_permutation=False,
)
mirror_dataset = ArcAgiDataset(mirror_config)
mirror_items = list(mirror_dataset)
@ -90,7 +114,15 @@ def test_arc_agi_augmentations():
), "Horizontal mirror augmentation had no effect"
# Test with color permutation only
color_config = ArcAgiConfig(seed=42, size=10, rotations=[], mirrors=[], use_color_permutation=True)
color_config = ArcAgiConfig(
seed=42,
size=10,
rotations=[],
mirrors=[],
rotations_weights=[1.0],
mirrors_weights=[1.0],
use_color_permutation=True,
)
color_dataset = ArcAgiDataset(color_config)
color_items = list(color_dataset)
@ -166,3 +198,48 @@ def test_arc_agi_shuffled_order():
for a, b in zip(shuffled, unshuffled):
assert a["question"] != b["question"]
assert a["answer"] == b["answer"]
def test_arc_agi_curriculum():
    """Test that the ARC-AGI curriculum yields the expected weights per level.

    Exercises stepping the ``rotations_weights`` and ``mirrors_weights``
    attributes up and down, and checks that repeated increments/decrements
    settle on the expected boundary values.
    """
    curriculum = ArcAgiCurriculum()

    base_value = {"size": 150, "seed": 1}

    # Expected weights at the starting level; re-asserted after stepping back
    # down and again after over-decrementing at the lower bound.
    level0_rotations_weights = [0.3, 0.2, 0.3, 0.2]
    level0_mirrors_weights = [0.3, 0.3, 0.2, 0.1, 0.1]

    # generate_configuration returns a config object (seed/size/weights are
    # read from it below), not a curriculum.
    base_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)

    assert base_cfg.seed == 1
    assert base_cfg.size == 150
    assert base_cfg.rotations_weights == level0_rotations_weights
    assert base_cfg.mirrors_weights == level0_mirrors_weights

    # Test and validate increase in levels
    curriculum.increment_attr_level("rotations_weights")
    curriculum.increment_attr_level("mirrors_weights")
    increased_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert increased_cfg.rotations_weights == [0.15, 0.3, 0.25, 0.3]
    assert increased_cfg.mirrors_weights == [0.2, 0.2, 0.2, 0.2, 0.2]

    # Test and validate decrease in levels
    curriculum.decrement_attr_level("rotations_weights")
    curriculum.decrement_attr_level("mirrors_weights")
    decreased_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert decreased_cfg.rotations_weights == level0_rotations_weights
    assert decreased_cfg.mirrors_weights == level0_mirrors_weights

    # Test upper bound boundary condition: increment far more times than the
    # levels asserted above; the expected values imply the level saturates.
    for _ in range(10):
        curriculum.increment_attr_level("rotations_weights")
        curriculum.increment_attr_level("mirrors_weights")
    upper_bound_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert upper_bound_cfg.rotations_weights == [0.0, 0.4, 0.2, 0.4]
    assert upper_bound_cfg.mirrors_weights == [0.05, 0.05, 0.1, 0.4, 0.4]

    # Test lower bound boundary condition: over-decrementing is expected to
    # land back on the starting-level weights.
    for _ in range(10):
        curriculum.decrement_attr_level("rotations_weights")
        curriculum.decrement_attr_level("mirrors_weights")
    lower_bound_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert lower_bound_cfg.rotations_weights == level0_rotations_weights
    assert lower_bound_cfg.mirrors_weights == level0_mirrors_weights