"""Tests for the ArcAgi dataset: config validation, determinism, item structure,
augmentations, scoring, train/eval modes, example-order shuffling, and curriculum."""

import pytest

from reasoning_gym.arc.arc_agi import ArcAgiConfig, ArcAgiCurriculum, ArcAgiDataset


def test_arc_agi_config_validation():
    """Test validation of ArcAgi configuration parameters"""
    with pytest.raises(AssertionError):
        ArcAgiConfig(size=0).validate()

    with pytest.raises(AssertionError):
        ArcAgiConfig(rotations=["invalid"]).validate()

    with pytest.raises(AssertionError):
        ArcAgiConfig(mirrors=["invalid"]).validate()

    # Valid configs should not raise
    config = ArcAgiConfig(size=10, seed=42)
    config.validate()

    config = ArcAgiConfig(rotations=["90", "180"], mirrors=["horizontal", "diagonal"])
    config.validate()

    # Empty lists should be valid (no augmentations)
    config = ArcAgiConfig(rotations=[], mirrors=[])
    config.validate()


def test_arc_agi_deterministic():
    """Test dataset reproducibility with fixed seed"""
    config = ArcAgiConfig(seed=42, size=10)
    ds1 = ArcAgiDataset(config)
    ds2 = ArcAgiDataset(config)

    for i in range(len(ds1)):
        assert ds1[i] == ds2[i], "ArcAgi datasets with same seed should match exactly"


def test_arc_agi_items():
    """Test basic structure and metadata of generated items"""
    config = ArcAgiConfig(seed=42, size=10)
    dataset = ArcAgiDataset(config)

    for item in dataset:
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        meta = item["metadata"]
        assert "input" in meta
        assert "output" in meta
        assert "task_id" in meta

        # Verify input/output are tuples of tuples (board format)
        assert isinstance(meta["input"], tuple)
        assert isinstance(meta["output"], tuple)
        assert all(isinstance(row, tuple) for row in meta["input"])
        assert all(isinstance(row, tuple) for row in meta["output"])

        # Verify task_id is a string
        assert isinstance(meta["task_id"], str)


def test_arc_agi_augmentations():
    """Test that augmentations can be selectively enabled/disabled"""
    # Test with all augmentations disabled
    config = ArcAgiConfig(
        seed=42,
        size=10,
        rotations=[],
        mirrors=[],
        rotations_weights=[1.0],
        mirrors_weights=[1.0],
        use_color_permutation=False,
    )
    base_dataset = ArcAgiDataset(config)
    base_items = list(base_dataset)

    # Test with specific rotation only
    rot_config = ArcAgiConfig(
        seed=42,
        size=10,
        rotations=["90"],
        mirrors=[],
        rotations_weights=[0.5, 0.5],
        mirrors_weights=[1.0],
        use_color_permutation=False,
    )
    rot_dataset = ArcAgiDataset(rot_config)
    rot_items = list(rot_dataset)

    # Items should differ with rotation enabled
    assert any(
        base_items[i]["metadata"]["input"] != rot_items[i]["metadata"]["input"] for i in range(len(base_items))
    ), "90-degree rotation augmentation had no effect"

    # Test with specific mirror only
    mirror_config = ArcAgiConfig(
        seed=42,
        size=10,
        rotations=[],
        mirrors=["horizontal"],
        rotations_weights=[1.0],
        mirrors_weights=[0.5, 0.5],
        use_color_permutation=False,
    )
    mirror_dataset = ArcAgiDataset(mirror_config)
    mirror_items = list(mirror_dataset)

    # Items should differ with mirror enabled
    assert any(
        base_items[i]["metadata"]["input"] != mirror_items[i]["metadata"]["input"] for i in range(len(base_items))
    ), "Horizontal mirror augmentation had no effect"

    # Test with color permutation only
    color_config = ArcAgiConfig(
        seed=42,
        size=10,
        rotations=[],
        mirrors=[],
        rotations_weights=[1.0],
        mirrors_weights=[1.0],
        use_color_permutation=True,
    )
    color_dataset = ArcAgiDataset(color_config)
    color_items = list(color_dataset)

    # Items should differ with color permutation enabled
    assert any(
        base_items[i]["metadata"]["input"] != color_items[i]["metadata"]["input"] for i in range(len(base_items))
    ), "Color permutation had no effect"


def test_arc_agi_scoring():
    """Test solution verification and scoring"""
    config = ArcAgiConfig(size=10, seed=123)
    dataset = ArcAgiDataset(config)

    for item in dataset:
        # Test correct solution
        assert dataset.score_answer(item["answer"], entry=item) == 1.0

        # Test invalid format
        assert dataset.score_answer("invalid grid format", entry=item) == 0.0

        # Test None answer
        assert dataset.score_answer(None, entry=item) == 0.0

        # Test wrong but valid grid format (partial credit for well-formed grids)
        wrong_answer = "1 0 0 0\n0 0 0 1"
        assert dataset.score_answer(wrong_answer, entry=item) == 0.05


def test_arc_agi_dataset_modes():
    """Test dataset behavior with different train/eval configurations"""
    # Test train-only mode
    train_config = ArcAgiConfig(use_train=True, use_eval=False, size=10, seed=42)
    train_ds = ArcAgiDataset(train_config)
    assert len(train_ds._task_ids) > 0

    # Test eval-only mode
    eval_config = ArcAgiConfig(use_train=False, use_eval=True, size=10, seed=42)
    eval_ds = ArcAgiDataset(eval_config)
    assert len(eval_ds._task_ids) > 0

    # Test both modes: the combined pool must be strictly larger than either alone
    both_config = ArcAgiConfig(use_train=True, use_eval=True, size=10, seed=42)
    both_ds = ArcAgiDataset(both_config)
    assert len(both_ds._task_ids) > len(train_ds._task_ids)
    assert len(both_ds._task_ids) > len(eval_ds._task_ids)


def test_arc_agi_shuffled_order():
    """Shuffling example order changes the question text but never the answer."""
    config_unshuffled = ArcAgiConfig(
        shuffle_example_order=False,
        use_train=True,
        use_eval=False,
        rotations=[],
        mirrors=[],
        use_color_permutation=False,
        size=3,
        seed=42,
    )
    config_shuffled = ArcAgiConfig(
        shuffle_example_order=True,
        use_train=True,
        use_eval=False,
        rotations=[],
        mirrors=[],
        use_color_permutation=False,
        size=3,
        seed=42,
    )
    unshuffled = ArcAgiDataset(config_unshuffled)
    shuffled = ArcAgiDataset(config_shuffled)
    for a, b in zip(shuffled, unshuffled):
        assert a["question"] != b["question"]
        assert a["answer"] == b["answer"]


def test_arc_agi_curriculum():
    """Test the ArcAgi curriculum weight levels, increments, and bounds."""
    curriculum = ArcAgiCurriculum()

    base_value = {"size": 150, "seed": 1}

    # generate_configuration returns a config object (we read .size/.seed/weight
    # attributes from it), so the annotation is ArcAgiConfig, not ArcAgiCurriculum.
    base_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert base_cfg.seed == 1
    assert base_cfg.size == 150
    assert base_cfg.rotations_weights == [0.3, 0.2, 0.3, 0.2]
    assert base_cfg.mirrors_weights == [0.3, 0.3, 0.2, 0.1, 0.1]

    # Test and validate increase in levels
    curriculum.increment_attr_level("rotations_weights")
    curriculum.increment_attr_level("mirrors_weights")
    increased_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert increased_cfg.rotations_weights == [0.15, 0.3, 0.25, 0.3]
    assert increased_cfg.mirrors_weights == [0.2, 0.2, 0.2, 0.2, 0.2]

    # Test and validate decrease in levels
    curriculum.decrement_attr_level("rotations_weights")
    curriculum.decrement_attr_level("mirrors_weights")
    decreased_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert decreased_cfg.rotations_weights == [0.3, 0.2, 0.3, 0.2]
    assert decreased_cfg.mirrors_weights == [0.3, 0.3, 0.2, 0.1, 0.1]

    # Test upper bound boundary condition (incrementing past the last level saturates)
    for _ in range(10):
        curriculum.increment_attr_level("rotations_weights")
        curriculum.increment_attr_level("mirrors_weights")
    upper_bound_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert upper_bound_cfg.rotations_weights == [0.0, 0.4, 0.2, 0.4]
    assert upper_bound_cfg.mirrors_weights == [0.05, 0.05, 0.1, 0.4, 0.4]

    # Test lower bound boundary condition (decrementing past level 0 saturates)
    for _ in range(10):
        curriculum.decrement_attr_level("rotations_weights")
        curriculum.decrement_attr_level("mirrors_weights")
    lower_bound_cfg: ArcAgiConfig = curriculum.generate_configuration(base_value)
    assert lower_bound_cfg.rotations_weights == [0.3, 0.2, 0.3, 0.2]
    assert lower_bound_cfg.mirrors_weights == [0.3, 0.3, 0.2, 0.1, 0.1]