mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
* Add curriculum for arc_agi * Resolve conflicts * Remove code smell * Remove unwanted code
245 lines
8.1 KiB
Python
245 lines
8.1 KiB
Python
import pytest
|
|
|
|
from reasoning_gym.arc.arc_agi import ArcAgiConfig, ArcAgiCurriculum, ArcAgiDataset
|
|
|
|
|
|
def test_arc_agi_config_validation():
    """Check that ``ArcAgiConfig.validate`` rejects bad parameters and accepts good ones."""
    # Each keyword set below is expected to trigger an AssertionError.
    rejected = (
        {"size": 0},
        {"rotations": ["invalid"]},
        {"mirrors": ["invalid"]},
    )
    for kwargs in rejected:
        with pytest.raises(AssertionError):
            ArcAgiConfig(**kwargs).validate()

    # None of these may raise. The last entry uses empty lists, i.e. no
    # rotation or mirror augmentations are requested.
    accepted = (
        {"size": 10, "seed": 42},
        {"rotations": ["90", "180"], "mirrors": ["horizontal", "diagonal"]},
        {"rotations": [], "mirrors": []},
    )
    for kwargs in accepted:
        ArcAgiConfig(**kwargs).validate()
|
|
|
|
|
|
def test_arc_agi_deterministic():
    """Two datasets built from the same seeded config must agree at every index."""
    shared_config = ArcAgiConfig(seed=42, size=10)
    first = ArcAgiDataset(shared_config)
    second = ArcAgiDataset(shared_config)

    item_count = len(first)
    for position in range(item_count):
        left = first[position]
        right = second[position]
        assert left == right, "ArcAgi datasets with same seed should match exactly"
|
|
|
|
|
|
def test_arc_agi_items():
    """Validate the structure of every generated item and of its metadata."""
    dataset = ArcAgiDataset(ArcAgiConfig(seed=42, size=10))

    for entry in dataset:
        assert isinstance(entry, dict)
        for required_key in ("question", "answer", "metadata"):
            assert required_key in entry

        metadata = entry["metadata"]
        for required_key in ("input", "output", "task_id"):
            assert required_key in metadata

        # Boards are expected as a tuple of row tuples: check the outer
        # container of both boards first, then every row.
        boards = (metadata["input"], metadata["output"])
        for board in boards:
            assert isinstance(board, tuple)
        for board in boards:
            assert all(isinstance(row, tuple) for row in board)

        # The task identifier must be a string.
        assert isinstance(metadata["task_id"], str)
|
|
|
|
|
|
def test_arc_agi_augmentations():
    """Test that augmentations can be selectively enabled/disabled.

    A baseline dataset is generated with every augmentation turned off. Then
    one dataset per augmentation (90-degree rotation, horizontal mirror, color
    permutation) is generated with the same seed and size. Each augmented
    dataset must differ from the baseline in at least one item's
    ``metadata["input"]``.
    """

    def generate_items(**overrides):
        """Return the items of a dataset built from the baseline settings plus ``overrides``."""
        settings = {
            "seed": 42,
            "size": 10,
            "rotations": [],
            "mirrors": [],
            "rotations_weights": [1.0],
            "mirrors_weights": [1.0],
            "use_color_permutation": False,
        }
        settings.update(overrides)
        return list(ArcAgiDataset(ArcAgiConfig(**settings)))

    def inputs_differ(augmented_items):
        """Return True if any augmented input board differs from the baseline board at the same position."""
        # zip() stops at the shorter sequence, so make sure nothing is silently skipped.
        assert len(augmented_items) == len(base_items)
        return any(
            base["metadata"]["input"] != augmented["metadata"]["input"]
            for base, augmented in zip(base_items, augmented_items)
        )

    # Baseline: all augmentations disabled.
    base_items = generate_items()

    # Rotation only. The weights list has one more entry than the rotations list.
    rot_items = generate_items(rotations=["90"], rotations_weights=[0.5, 0.5])
    assert inputs_differ(rot_items), "90-degree rotation augmentation had no effect"

    # Mirror only. The weights list has one more entry than the mirrors list.
    mirror_items = generate_items(mirrors=["horizontal"], mirrors_weights=[0.5, 0.5])
    assert inputs_differ(mirror_items), "Horizontal mirror augmentation had no effect"

    # Color permutation only.
    color_items = generate_items(use_color_permutation=True)
    assert inputs_differ(color_items), "Color permutation had no effect"
|
|
|
|
|
|
def test_arc_agi_scoring():
    """Exercise ``score_answer`` with correct, malformed, missing and wrong answers."""
    dataset = ArcAgiDataset(ArcAgiConfig(size=10, seed=123))

    for entry in dataset:
        wrong_answer = "1 0 0 0\n0 0 0 1"
        expectations = (
            (entry["answer"], 1.0),  # the correct solution
            ("invalid grid format", 0.0),  # text that is not a grid
            (None, 0.0),  # no answer at all
            (wrong_answer, 0.05),  # valid grid format, wrong content
        )
        for candidate, expected_score in expectations:
            assert dataset.score_answer(candidate, entry=entry) == expected_score
|
|
|
|
|
|
def test_arc_agi_dataset_modes():
    """Compare the number of task ids loaded for train-only, eval-only and combined modes."""

    def count_task_ids(use_train, use_eval):
        """Build a dataset for the given mode flags and return how many task ids it holds."""
        mode_config = ArcAgiConfig(use_train=use_train, use_eval=use_eval, size=10, seed=42)
        return len(ArcAgiDataset(mode_config)._task_ids)

    # Train-only mode must provide at least one task.
    train_count = count_task_ids(True, False)
    assert train_count > 0

    # Eval-only mode must provide at least one task.
    eval_count = count_task_ids(False, True)
    assert eval_count > 0

    # Enabling both must yield strictly more tasks than either mode alone.
    both_count = count_task_ids(True, True)
    assert both_count > train_count
    assert both_count > eval_count
|
|
|
|
|
|
def test_arc_agi_shuffled_order():
    """Toggling ``shuffle_example_order`` must change each question while leaving each answer equal."""
    # Settings shared by both configs; only shuffle_example_order differs.
    common = {
        "use_train": True,
        "use_eval": False,
        "rotations": [],
        "mirrors": [],
        "use_color_permutation": False,
        "size": 3,
        "seed": 42,
    }
    config_unshuffled = ArcAgiConfig(shuffle_example_order=False, **common)
    config_shuffled = ArcAgiConfig(shuffle_example_order=True, **common)

    unshuffled = ArcAgiDataset(config_unshuffled)
    shuffled = ArcAgiDataset(config_shuffled)

    for shuffled_item, plain_item in zip(shuffled, unshuffled):
        assert shuffled_item["question"] != plain_item["question"]
        assert shuffled_item["answer"] == plain_item["answer"]
|
|
|
|
|
|
def test_arc_agi_curriculum():
    """Test the ARC-AGI curriculum's handling of augmentation weight levels.

    Generates configurations from ``ArcAgiCurriculum`` and checks the
    ``rotations_weights`` and ``mirrors_weights`` values at the initial level,
    after one increment, after decrementing back, and after ten consecutive
    increments and ten consecutive decrements.
    """
    curriculum = ArcAgiCurriculum()
    base_value = {"size": 150, "seed": 1}

    # Weights expected whenever both attributes sit at their initial level.
    initial_rotations_weights = [0.3, 0.2, 0.3, 0.2]
    initial_mirrors_weights = [0.3, 0.3, 0.2, 0.1, 0.1]

    base_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)

    assert base_cfg.seed == 1
    assert base_cfg.size == 150
    assert base_cfg.rotations_weights == initial_rotations_weights
    assert base_cfg.mirrors_weights == initial_mirrors_weights

    # Test and validate increase in levels
    curriculum.increment_attr_level("rotations_weights")
    curriculum.increment_attr_level("mirrors_weights")

    increased_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert increased_cfg.rotations_weights == [0.15, 0.3, 0.25, 0.3]
    assert increased_cfg.mirrors_weights == [0.2, 0.2, 0.2, 0.2, 0.2]

    # Test and validate decrease in levels
    curriculum.decrement_attr_level("rotations_weights")
    curriculum.decrement_attr_level("mirrors_weights")

    decreased_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert decreased_cfg.rotations_weights == initial_rotations_weights
    assert decreased_cfg.mirrors_weights == initial_mirrors_weights

    # Test upper bound boundary condition: ten increments in a row must end
    # at the values below.
    for _ in range(10):
        curriculum.increment_attr_level("rotations_weights")
        curriculum.increment_attr_level("mirrors_weights")
    upper_bound_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert upper_bound_cfg.rotations_weights == [0.0, 0.4, 0.2, 0.4]
    assert upper_bound_cfg.mirrors_weights == [0.05, 0.05, 0.1, 0.4, 0.4]

    # Test lower bound boundary condition: ten decrements in a row must end
    # back at the initial weights.
    for _ in range(10):
        curriculum.decrement_attr_level("rotations_weights")
        curriculum.decrement_attr_level("mirrors_weights")
    lower_bound_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert lower_bound_cfg.rotations_weights == initial_rotations_weights
    assert lower_bound_cfg.mirrors_weights == initial_mirrors_weights
|