"""Tests for Spiral Matrix questions generation""" import pytest from reasoning_gym.algorithmic.spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset def test_spiral_matrix_config_validation(): """Test that invalid configs raise appropriate errors""" with pytest.raises(AssertionError): config = SpiralMatrixConfig(max_n=-1) # Negative not allowed config.validate() with pytest.raises(AssertionError): config = SpiralMatrixConfig(max_n=0) # Zero not allowed config.validate() with pytest.raises(AssertionError): config = SpiralMatrixConfig(max_n=1) # One not allowed config.validate() def test_spiral_matrix_dataset_deterministic(): """Test that dataset generates same items with same seed""" config = SpiralMatrixConfig(seed=42, size=10) dataset1 = SpiralMatrixDataset(config) dataset2 = SpiralMatrixDataset(config) for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] def test_spiral_matrix_dataset_items(): """Test basic properties of generated items""" config = SpiralMatrixConfig(max_n=5, size=10, seed=42) dataset = SpiralMatrixDataset(config) for i in range(len(dataset)): item = dataset[i] # Check item structure assert isinstance(item, dict) assert "question" in item assert "answer" in item assert "metadata" in item # Check metadata assert "matrix" in item["metadata"] assert "solution" in item["metadata"] matrix = item["metadata"]["matrix"] solution = item["metadata"]["solution"] # Verify list dimensions assert len(matrix) <= config.max_n assert all(len(row) <= config.max_n for row in matrix) assert sum(len(row) for row in matrix) == len(solution) def test_spiral_matrix_dataset_iteration(): """Test that iteration respects dataset size""" config = SpiralMatrixConfig(size=5, seed=42) dataset = SpiralMatrixDataset(config) items = list(dataset) assert len(items) == config.size # Test multiple iterations yield same items assert items == list(dataset) def test_spiral_matrix_answer(): """Test the _get_spiral method""" config = SpiralMatrixConfig(seed=42) dataset = SpiralMatrixDataset(config) # 2D grid matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] assert dataset._get_spiral(matrix) == [1, 2, 3, 6, 9, 8, 7, 4, 5] # Answer is identical (up to trimming) entry = {"answer": "1 2 3 6 9 8 7 4 5"} answer = "\n\n1 2 3 6 9 8 7 4 5\n" assert dataset.score_answer(answer, entry) == 1.0 # Score answer in list format (partially correct) entry = {"answer": "1 2 3 6 9 8 7 4 5"} answer = "[1, 2, 3, 6, 9, 8, 7, 4, 5]" assert dataset.score_answer(answer, entry) == 0.1 # Answer is incorrect entry = {"answer": "1 2 3 6 9 8 7 4 5"} answer = "1 2 3" assert dataset.score_answer(answer, entry) == 0.0 # Answer is none entry = {"answer": "1 2 3 6 9 8 7 4 5"} answer = None assert dataset.score_answer(answer, entry) == 0.0 def test_spiral_matrix_curriculum(): curriculum = SpiralMatrixCurriculum() base_value = {"size": 150, "seed": 1} base_cfg: SpiralMatrixConfig = curriculum.generate_configuration(base_value) assert base_cfg.seed == 1 assert base_cfg.size == 150 assert base_cfg.min_n == 10 and base_cfg.max_n == 10 # test incrementing attribute levels curriculum.increment_attr_level("n") increased_cfg = curriculum.generate_configuration(base_value) assert increased_cfg.min_n == 10 and increased_cfg.max_n == 25 # test decrementing attribute levels curriculum.decrement_attr_level("n") decreased_cfg = curriculum.generate_configuration(base_value) assert decreased_cfg.min_n == 10 and decreased_cfg.max_n == 10