diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 4c33d08f..1582718d 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -16,6 +16,7 @@ from .number_filtering import NumberFilteringConfig, NumberFilteringDataset from .number_sorting import NumberSortingConfig, NumberSortingDataset from .palindrome_generation import PalindromeConfig, PalindromeDataset from .ransom_note import RansomNoteConfig, RansomNoteDataset +from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .word_ladder import WordLadderConfig, WordLadderDataset @@ -54,4 +55,6 @@ __all__ = [ "RansomNoteDataset", "IsomorphicStringsConfig", "IsomorphicStringsDataset", + "RotateMatrixConfig", + "RotateMatrixDataset", ] diff --git a/reasoning_gym/algorithmic/rotate_matrix.py b/reasoning_gym/algorithmic/rotate_matrix.py new file mode 100644 index 00000000..ac50a281 --- /dev/null +++ b/reasoning_gym/algorithmic/rotate_matrix.py @@ -0,0 +1,103 @@ +"""Rotate a square matrix clockwise. + +A popular Leetcode problem: +https://leetcode.com/problems/rotate-image/description/ +""" + +from copy import deepcopy +from dataclasses import dataclass +from random import Random +from typing import Optional + +from ..factory import ProceduralDataset, register_dataset + +QUESTION_TEMPLATE = """Given a square matrix, your job is to rotate it clockwise. + +Example: + +Input: Rotate the matrix below by 90 degrees clockwise: +1 2 3 +4 5 6 +7 8 9 + +Output: +7 4 1 +8 5 2 +9 6 3 + +Rotate the matrix below by {degrees} degrees clockwise: +{matrix} +""" + + +@dataclass +class RotateMatrixConfig: + """Configuration for Rotate Matrix dataset generation""" + + max_n: int = 10 # Maximum size of the matrix + max_rotations: int = 4 # Maximum number of rotations (90 degrees each) + + size: int = 500 # Virtual dataset size + seed: Optional[int] = None + + def validate(self): + """Validate configuration parameters""" + assert 1 <= self.max_n, "max_n must be at least 1" + assert 0 <= self.max_rotations, "max_rotations must be at least 0" + + +class RotateMatrixDataset(ProceduralDataset): + """Generates Rotate Matrix exercises with configurable difficulty""" + + def __init__(self, config: RotateMatrixConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def _get_matrix(self, rng: Random) -> list[list[int]]: + """Generate a random matrix""" + n = rng.randint(1, self.config.max_n) + numbers = list(range(n**2)) + rng.shuffle(numbers) + matrix = [numbers[i * n : (i + 1) * n] for i in range(n)] + return matrix + + def _get_rotated(self, matrix: list[list[int]], num_rotations: int) -> list[list[int]]: + """Rotate the matrix K times by 90 degrees clockwise""" + num_rotations %= 4 + n = len(matrix) + output = deepcopy(matrix) + + for _ in range(num_rotations): + for l in range(n // 2): + for i in range(l, n - 1 - l): + (output[l][i], output[i][n - 1 - l], output[n - 1 - l][n - 1 - i], output[n - 1 - i][l]) = ( + output[n - 1 - i][l], + output[l][i], + output[i][n - 1 - l], + output[n - 1 - l][n - 1 - i], + ) + + return output + + def _matrix_to_str(self, matrix: list[list[int]]) -> str: + """Get a string representation of the matrix""" + return "\n".join(" ".join(str(x) for x in row) for row in matrix) + + def __getitem__(self, idx: int) -> dict: + """Generate a single Spiral Matrix question""" + rng = Random(self.seed + idx) + + matrix = self._get_matrix(rng) + num_rotations = rng.randint(0, self.config.max_rotations) + matrix_str = self._matrix_to_str(matrix) + + answer = self._get_rotated(matrix, num_rotations) + answer_str = self._matrix_to_str(answer) + + return { + "question": QUESTION_TEMPLATE.format(matrix=matrix_str, degrees=num_rotations * 90), + "answer": answer_str, + "metadata": {"matrix": matrix, "num_rotations": num_rotations, "solution": answer}, + } + + +register_dataset("rotate_matrix", RotateMatrixDataset, RotateMatrixConfig) diff --git a/tests/test_rotate_matrix.py b/tests/test_rotate_matrix.py new file mode 100644 index 00000000..c2e43df9 --- /dev/null +++ b/tests/test_rotate_matrix.py @@ -0,0 +1,144 @@ +"""Tests for Rotate Matrix questions generation""" + +import pytest + +from reasoning_gym.algorithmic.rotate_matrix import RotateMatrixConfig, RotateMatrixDataset + + +def test_rotate_matrix_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = RotateMatrixConfig(max_n=-1) # Negative not allowed + config.validate() + + with pytest.raises(AssertionError): + config = RotateMatrixConfig(max_n=0) # Zero not allowed + config.validate() + + with pytest.raises(AssertionError): + config = RotateMatrixConfig(max_rotations=-1) # Negative not allowed + config.validate() + + +def test_rotate_matrix_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = RotateMatrixConfig(seed=42, size=10) + dataset1 = RotateMatrixDataset(config) + dataset2 = RotateMatrixDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_rotate_matrix_dataset_items(): + """Test basic properties of generated items""" + config = RotateMatrixConfig(max_n=7, size=10, seed=42) + dataset = RotateMatrixDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert "matrix" in item["metadata"] + assert "solution" in item["metadata"] + + matrix = item["metadata"]["matrix"] + solution = item["metadata"]["solution"] + num_rotations = item["metadata"]["num_rotations"] + + # Verify matrix dimensions + assert len(matrix) <= config.max_n + assert all(len(row) <= config.max_n for row in matrix) + assert len(solution) <= config.max_n + assert all(len(row) <= config.max_n for row in solution) + assert set(e for row in matrix for e in row) == set(e for row in solution for e in row) + assert num_rotations <= config.max_rotations + + +def test_rotate_matrix_dataset_iteration(): + """Test that iteration respects dataset size""" + config = RotateMatrixConfig(size=5, seed=42) + dataset = RotateMatrixDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset) + + +def test_rotate_matrix_answer(): + """Test the _get_rotated method""" + config = RotateMatrixConfig(seed=42) + dataset = RotateMatrixDataset(config) + + # n = 1, num_rotations = 1 + matrix = [[8]] + expected = [[8]] + assert dataset._get_rotated(matrix, num_rotations=1) == expected + + # n = 3, num_rotations = 0 (no rotation) + matrix = [ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + ] + expected = [ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + ] + assert dataset._get_rotated(matrix, num_rotations=0) == expected + + # n = 3, num_rotations = 1 (90 degrees clockwise) + matrix = [ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + ] + expected = [ + [6, 3, 0], + [7, 4, 1], + [8, 5, 2], + ] + assert dataset._get_rotated(matrix, num_rotations=1) == expected + + # n = 3, num_rotations = 2 (180 degrees clockwise) + matrix = [ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + ] + expected = [ + [8, 7, 6], + [5, 4, 3], + [2, 1, 0], + ] + assert dataset._get_rotated(matrix, num_rotations=2) == expected + + # n = 3, num_rotations = 3 (270 degrees clockwise) + matrix = [ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + ] + expected = [[2, 5, 8], [1, 4, 7], [0, 3, 6]] + assert dataset._get_rotated(matrix, num_rotations=3) == expected + + # n = 4, num_rotations = 4 (360 degrees clockwise == 0 degrees) + matrix = [ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + ] + expected = [ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + ] + assert dataset._get_rotated(matrix, num_rotations=4) == expected