feat: Add configurable rotation and mirror augmentation variants

This commit is contained in:
Andreas Koepf (aider) 2025-02-09 00:14:43 +01:00 committed by Andreas Koepf
parent b73040b066
commit ec8036c099
3 changed files with 486 additions and 19 deletions

View file

@ -8,10 +8,23 @@ def test_arc_agi_config_validation():
with pytest.raises(AssertionError):
ArcAgiConfig(size=0).validate()
# Valid config should not raise
with pytest.raises(AssertionError):
ArcAgiConfig(rotations=["invalid"]).validate()
with pytest.raises(AssertionError):
ArcAgiConfig(mirrors=["invalid"]).validate()
# Valid configs should not raise
config = ArcAgiConfig(size=10, seed=42)
config.validate()
config = ArcAgiConfig(rotations=["90", "180"], mirrors=["horizontal", "diagonal"])
config.validate()
# Empty lists should be valid (no augmentations)
config = ArcAgiConfig(rotations=[], mirrors=[])
config.validate()
def test_arc_agi_deterministic():
"""Test dataset reproducibility with fixed seed"""
@ -52,26 +65,36 @@ def test_arc_agi_items():
def test_arc_agi_augmentations():
"""Test that augmentations can be selectively enabled/disabled"""
# Test with all augmentations disabled
config = ArcAgiConfig(seed=42, size=10, use_rotations=False, use_mirrors=False, use_color_permutation=False)
config = ArcAgiConfig(seed=42, size=10, rotations=[], mirrors=[], use_color_permutation=False)
base_dataset = ArcAgiDataset(config)
base_items = list(base_dataset)
# Test with rotations only
rot_config = ArcAgiConfig(seed=42, size=10, use_rotations=True, use_mirrors=False, use_color_permutation=False)
# Test with specific rotation only
rot_config = ArcAgiConfig(seed=42, size=10, rotations=["90"], mirrors=[], use_color_permutation=False)
rot_dataset = ArcAgiDataset(rot_config)
rot_items = list(rot_dataset)
# Items should differ when rotations are enabled
# Items should differ with rotation enabled
assert any(
base_items[i]["metadata"]["input"] != rot_items[i]["metadata"]["input"] for i in range(len(base_items))
), "Rotation augmentation had no effect"
), "90-degree rotation augmentation had no effect"
# Test with specific mirror only
mirror_config = ArcAgiConfig(seed=42, size=10, rotations=[], mirrors=["horizontal"], use_color_permutation=False)
mirror_dataset = ArcAgiDataset(mirror_config)
mirror_items = list(mirror_dataset)
# Items should differ with mirror enabled
assert any(
base_items[i]["metadata"]["input"] != mirror_items[i]["metadata"]["input"] for i in range(len(base_items))
), "Horizontal mirror augmentation had no effect"
# Test with color permutation only
color_config = ArcAgiConfig(seed=42, size=10, use_rotations=False, use_mirrors=False, use_color_permutation=True)
color_config = ArcAgiConfig(seed=42, size=10, rotations=[], mirrors=[], use_color_permutation=True)
color_dataset = ArcAgiDataset(color_config)
color_items = list(color_dataset)
# Items should differ when color permutation is enabled
# Items should differ with color permutation enabled
assert any(
base_items[i]["metadata"]["input"] != color_items[i]["metadata"]["input"] for i in range(len(base_items))
), "Color permutation had no effect"