mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-22 16:49:06 +00:00
feat: Add configurable rotation and mirror augmentation variants
This commit is contained in:
parent
b73040b066
commit
ec8036c099
3 changed files with 486 additions and 19 deletions
|
|
@ -8,10 +8,23 @@ def test_arc_agi_config_validation():
|
|||
with pytest.raises(AssertionError):
|
||||
ArcAgiConfig(size=0).validate()
|
||||
|
||||
# Valid config should not raise
|
||||
with pytest.raises(AssertionError):
|
||||
ArcAgiConfig(rotations=["invalid"]).validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
ArcAgiConfig(mirrors=["invalid"]).validate()
|
||||
|
||||
# Valid configs should not raise
|
||||
config = ArcAgiConfig(size=10, seed=42)
|
||||
config.validate()
|
||||
|
||||
config = ArcAgiConfig(rotations=["90", "180"], mirrors=["horizontal", "diagonal"])
|
||||
config.validate()
|
||||
|
||||
# Empty lists should be valid (no augmentations)
|
||||
config = ArcAgiConfig(rotations=[], mirrors=[])
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_arc_agi_deterministic():
|
||||
"""Test dataset reproducibility with fixed seed"""
|
||||
|
|
@ -52,26 +65,36 @@ def test_arc_agi_items():
|
|||
def test_arc_agi_augmentations():
|
||||
"""Test that augmentations can be selectively enabled/disabled"""
|
||||
# Test with all augmentations disabled
|
||||
config = ArcAgiConfig(seed=42, size=10, use_rotations=False, use_mirrors=False, use_color_permutation=False)
|
||||
config = ArcAgiConfig(seed=42, size=10, rotations=[], mirrors=[], use_color_permutation=False)
|
||||
base_dataset = ArcAgiDataset(config)
|
||||
base_items = list(base_dataset)
|
||||
|
||||
# Test with rotations only
|
||||
rot_config = ArcAgiConfig(seed=42, size=10, use_rotations=True, use_mirrors=False, use_color_permutation=False)
|
||||
# Test with specific rotation only
|
||||
rot_config = ArcAgiConfig(seed=42, size=10, rotations=["90"], mirrors=[], use_color_permutation=False)
|
||||
rot_dataset = ArcAgiDataset(rot_config)
|
||||
rot_items = list(rot_dataset)
|
||||
|
||||
# Items should differ when rotations are enabled
|
||||
# Items should differ with rotation enabled
|
||||
assert any(
|
||||
base_items[i]["metadata"]["input"] != rot_items[i]["metadata"]["input"] for i in range(len(base_items))
|
||||
), "Rotation augmentation had no effect"
|
||||
), "90-degree rotation augmentation had no effect"
|
||||
|
||||
# Test with specific mirror only
|
||||
mirror_config = ArcAgiConfig(seed=42, size=10, rotations=[], mirrors=["horizontal"], use_color_permutation=False)
|
||||
mirror_dataset = ArcAgiDataset(mirror_config)
|
||||
mirror_items = list(mirror_dataset)
|
||||
|
||||
# Items should differ with mirror enabled
|
||||
assert any(
|
||||
base_items[i]["metadata"]["input"] != mirror_items[i]["metadata"]["input"] for i in range(len(base_items))
|
||||
), "Horizontal mirror augmentation had no effect"
|
||||
|
||||
# Test with color permutation only
|
||||
color_config = ArcAgiConfig(seed=42, size=10, use_rotations=False, use_mirrors=False, use_color_permutation=True)
|
||||
color_config = ArcAgiConfig(seed=42, size=10, rotations=[], mirrors=[], use_color_permutation=True)
|
||||
color_dataset = ArcAgiDataset(color_config)
|
||||
color_items = list(color_dataset)
|
||||
|
||||
# Items should differ when color permutation is enabled
|
||||
# Items should differ with color permutation enabled
|
||||
assert any(
|
||||
base_items[i]["metadata"]["input"] != color_items[i]["metadata"]["input"] for i in range(len(base_items))
|
||||
), "Color permutation had no effect"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue