diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py
index 9acc5007..dac1b62c 100644
--- a/reasoning_gym/algorithmic/__init__.py
+++ b/reasoning_gym/algorithmic/__init__.py
@@ -12,6 +12,7 @@ from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset
 from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset
 from .letter_counting import LetterCountingConfig, LetterCountingDataset
 from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset
+from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixDataset
 from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
 from .number_sorting import NumberSortingConfig, NumberSortingDataset
 from .palindrome_generation import PalindromeConfig, PalindromeDataset
@@ -60,4 +61,6 @@ __all__ = [
     "IsomorphicStringsDataset",
     "RotateMatrixConfig",
     "RotateMatrixDataset",
+    "ManipulateMatrixConfig",
+    "ManipulateMatrixDataset",
 ]
diff --git a/reasoning_gym/algorithmic/manipulate_matrix.py b/reasoning_gym/algorithmic/manipulate_matrix.py
new file mode 100644
index 00000000..5b1ec95e
--- /dev/null
+++ b/reasoning_gym/algorithmic/manipulate_matrix.py
@@ -0,0 +1,280 @@
+"""Manipulate matrices by performing augmentations such as rotations, flips, mapping, etc."""
+
+from copy import deepcopy
+from dataclasses import dataclass
+from random import Random
+from typing import Optional
+
+from ..factory import ProceduralDataset, register_dataset
+
+QUESTION_TEMPLATE = """For the following matrix:
+{matrix}
+
+Perform the following series of operations in order:
+- Identity transformation, i.e. no change
+{operations}
+"""
+
+
+def num_rows(matrix: list[list[int]]) -> int:
+    """Number of rows in the matrix"""
+    return len(matrix)
+
+
+def num_cols(matrix: list[list[int]]) -> int:
+    """Number of columns in the matrix (0 for an empty matrix)"""
+    return len(matrix[0]) if matrix else 0
+
+
+def _ordinal(n: int) -> str:
+    """Format n with its English ordinal suffix, e.g. 2nd, 3rd, 4th, 11th, 21st"""
+    if 10 <= n % 100 <= 20:
+        suffix = "th"
+    else:
+        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
+    return f"{n}{suffix}"
+
+
+@dataclass
+class ManipulateMatrixConfig:
+    """Configuration for Manipulate Matrix dataset generation"""
+
+    max_rows: int = 10  # Maximum number of rows
+    max_cols: int = 10  # Maximum number of columns
+    p_rotate: float = 0.2  # Probability of rotating the matrix
+    p_hmirror: float = 0.2  # Probability of horizontally mirroring the matrix
+    p_vmirror: float = 0.2  # Probability of vertically mirroring the matrix
+    p_dmirror: float = 0.2  # Probability of mirroring along the diagonal
+    p_cmirror: float = 0.2  # Probability of mirroring along the counterdiagonal
+    p_map: float = 0.2  # Probability of mapping a certain value to another
+    p_crop: float = 0.2  # Probability of cropping the matrix
+    p_remove_every_nth_row: float = 0.2  # Probability of removing every nth row
+    p_remove_every_nth_col: float = 0.2  # Probability of removing every nth column
+    p_zero_divisible: float = 0.2  # Probability of setting elements divisible by some number to zero
+
+    size: int = 500  # Virtual dataset size
+    seed: Optional[int] = None
+
+    def validate(self):
+        """Validate configuration parameters"""
+        assert 1 <= self.max_rows, "max_rows must be at least 1"
+        assert 1 <= self.max_cols, "max_cols must be at least 1"
+        assert 0 <= self.p_rotate <= 1, "p_rotate must be between 0 and 1"
+        assert 0 <= self.p_hmirror <= 1, "p_hmirror must be between 0 and 1"
+        assert 0 <= self.p_vmirror <= 1, "p_vmirror must be between 0 and 1"
+        assert 0 <= self.p_dmirror <= 1, "p_dmirror must be between 0 and 1"
+        assert 0 <= self.p_cmirror <= 1, "p_cmirror must be between 0 and 1"
+        assert 0 <= self.p_map <= 1, "p_map must be between 0 and 1"
+        assert 0 <= self.p_crop <= 1, "p_crop must be between 0 and 1"
+        assert 0 <= self.p_remove_every_nth_row <= 1, "p_remove_every_nth_row must be between 0 and 1"
+        assert 0 <= self.p_remove_every_nth_col <= 1, "p_remove_every_nth_col must be between 0 and 1"
+        assert 0 <= self.p_zero_divisible <= 1, "p_zero_divisible must be between 0 and 1"
+
+
+class ManipulateMatrixDataset(ProceduralDataset):
+    """Generates Manipulate Matrix exercises with configurable difficulty"""
+
+    def __init__(self, config: ManipulateMatrixConfig):
+        super().__init__(config=config, seed=config.seed, size=config.size)
+        self._rotations = {
+            "90": self._rot90,
+            "180": self._rot180,
+            "270": self._rot270,
+            "360": self._identity,
+        }
+        self._all_transforms = [
+            "rotate",
+            "hmirror",
+            "vmirror",
+            "dmirror",
+            "cmirror",
+            "map",
+            "zero_divisible",
+            "crop",
+            "remove_every_nth_row",
+            "remove_every_nth_col",
+        ]
+
+    def _get_matrix(self, rng: Random) -> list[list[int]]:
+        """Generate a random matrix"""
+        rows = rng.randint(1, self.config.max_rows)
+        cols = rng.randint(1, self.config.max_cols)
+        numbers = [rng.randint(0, 9) for _ in range(rows * cols)]
+        matrix = [numbers[i * cols : (i + 1) * cols] for i in range(rows)]
+        return matrix
+
+    def _matrix_to_str(self, matrix: list[list[int]]) -> str:
+        """Get a string representation of the matrix"""
+        return "\n".join(" ".join(str(x) for x in row) for row in matrix)
+
+    def _identity(self, matrix: list[list[int]]) -> list[list[int]]:
+        """Identity transformation"""
+        return matrix
+
+    def _rot90(self, matrix: list[list[int]]) -> list[list[int]]:
+        """quarter clockwise rotation"""
+        return [list(row) for row in zip(*matrix[::-1])]
+
+    def _rot180(self, matrix: list[list[int]]) -> list[list[int]]:
+        """half rotation"""
+        return [list(row[::-1]) for row in matrix[::-1]]
+
+    def _rot270(self, matrix: list[list[int]]) -> list[list[int]]:
+        """quarter anticlockwise rotation"""
+        return [list(row[::-1]) for row in zip(*matrix[::-1])][::-1]
+
+    def _hmirror(self, matrix: list[list[int]]) -> list[list[int]]:
+        """mirroring along horizontal"""
+        return matrix[::-1]
+
+    def _vmirror(self, matrix: list[list[int]]) -> list[list[int]]:
+        """mirroring along vertical"""
+        return [row[::-1] for row in matrix]
+
+    def _dmirror(self, matrix: list[list[int]]) -> list[list[int]]:
+        """mirroring along diagonal"""
+        return list(list(row) for row in zip(*matrix))
+
+    def _cmirror(self, matrix: list[list[int]]) -> list[list[int]]:
+        """mirroring along counterdiagonal"""
+        return list(list(row) for row in zip(*[r[::-1] for r in matrix[::-1]]))
+
+    def _map(self, matrix: list[list[int]], a: int, b: int) -> list[list[int]]:
+        """mapping a to b"""
+        return [[b if x == a else x for x in row] for row in matrix]
+
+    def _zero_divisible(self, matrix: list[list[int]], k: int) -> list[list[int]]:
+        """set elements divisible by k to zero"""
+        return [[0 if x % k == 0 else x for x in row] for row in matrix]
+
+    def _crop(
+        self, matrix: list[list[int]], row_start: int, row_end: int, col_start: int, col_end: int
+    ) -> list[list[int]]:
+        """crop the matrix (1-indexed)"""
+        return [row[col_start - 1 : col_end] for row in matrix[row_start - 1 : row_end]]
+
+    def _remove_every_nth_row(self, matrix: list[list[int]], n: int) -> list[list[int]]:
+        """remove every nth row (1-indexed)"""
+        return [row for i, row in enumerate(matrix, start=1) if i % n != 0]
+
+    def _remove_every_nth_col(self, matrix: list[list[int]], n: int) -> list[list[int]]:
+        """remove every nth column (1-indexed)"""
+        return [[col for i, col in enumerate(row, start=1) if i % n != 0] for row in matrix]
+
+    def __getitem__(self, idx: int) -> dict:
+        """Generate a single Manipulate Matrix question"""
+        rng = Random(self.seed + idx)
+
+        matrix = self._get_matrix(rng)
+        matrix_str = self._matrix_to_str(matrix)
+
+        # Shuffle the order of operations (make sure to copy the list to guarantee same order)
+        all_transforms = list(self._all_transforms)
+        rng.shuffle(all_transforms)
+        operations = []
+
+        answer = deepcopy(matrix)
+
+        for transform in all_transforms:
+            # Each name matches exactly one branch, so rng.random() is drawn once per transform
+            # Rotate
+            if transform == "rotate" and rng.random() < self.config.p_rotate:
+                rotation = rng.choice(list(self._rotations.keys()))
+                answer = self._rotations[rotation](answer)
+                operations.append(
+                    {
+                        "transform": transform,
+                        "degrees": rotation,
+                        "instruction": f"- Rotate the matrix {rotation} degrees",
+                    }
+                )
+            # Horizontal mirror
+            elif transform == "hmirror" and rng.random() < self.config.p_hmirror:
+                answer = self._hmirror(answer)
+                operations.append({"transform": transform, "instruction": "- Horizontally mirror the matrix"})
+            # Vertical mirror
+            elif transform == "vmirror" and rng.random() < self.config.p_vmirror:
+                answer = self._vmirror(answer)
+                operations.append({"transform": transform, "instruction": "- Vertically mirror the matrix"})
+            # Diagonal mirror
+            elif transform == "dmirror" and rng.random() < self.config.p_dmirror:
+                answer = self._dmirror(answer)
+                operations.append({"transform": transform, "instruction": "- Mirror the matrix along the diagonal"})
+            # Counterdiagonal mirror
+            elif transform == "cmirror" and rng.random() < self.config.p_cmirror:
+                answer = self._cmirror(answer)
+                operations.append(
+                    {"transform": transform, "instruction": "- Mirror the matrix along the counterdiagonal"}
+                )
+            # Map a value to another
+            elif transform == "map" and rng.random() < self.config.p_map:
+                a, b = rng.sample(range(10), 2)
+                answer = self._map(answer, a, b)
+                operations.append(
+                    {"transform": transform, "from": a, "to": b, "instruction": f"- Map each occurrence of {a} to {b}"}
+                )
+            # Set elements divisible by k to zero
+            elif transform == "zero_divisible" and rng.random() < self.config.p_zero_divisible:
+                k = rng.randint(1, 9)
+                answer = self._zero_divisible(answer, k)
+                operations.append(
+                    {"transform": transform, "k": k, "instruction": f"- Set all elements divisible by {k} to zero"}
+                )
+            # Crop the matrix
+            elif transform == "crop" and rng.random() < self.config.p_crop:
+                row_start = rng.randint(1, num_rows(answer))
+                row_end = rng.randint(row_start, num_rows(answer))
+                col_start = rng.randint(1, num_cols(answer))
+                col_end = rng.randint(col_start, num_cols(answer))
+                answer = self._crop(answer, row_start, row_end, col_start, col_end)
+                operations.append(
+                    {
+                        "transform": transform,
+                        "row_start": row_start,
+                        "row_end": row_end,
+                        "col_start": col_start,
+                        "col_end": col_end,
+                        "instruction": f"- Crop the matrix to rows {row_start}-{row_end} and columns {col_start}-{col_end} (1-indexed)",
+                    }
+                )
+            # Remove every nth row
+            elif (
+                transform == "remove_every_nth_row"
+                and rng.random() < self.config.p_remove_every_nth_row
+                and num_rows(answer) > 1
+            ):
+                # n >= 2 always keeps row 1, so the result is never empty
+                n = rng.randint(2, num_rows(answer))
+                answer = self._remove_every_nth_row(answer, n)
+                operations.append(
+                    {"transform": transform, "n": n, "instruction": f"- Remove every {_ordinal(n)} row (1-indexed)"}
+                )
+            # Remove every nth column
+            elif (
+                transform == "remove_every_nth_col"
+                and rng.random() < self.config.p_remove_every_nth_col
+                and num_cols(answer) > 1
+            ):
+                # n >= 2 always keeps column 1, so the result is never empty
+                n = rng.randint(2, num_cols(answer))
+                answer = self._remove_every_nth_col(answer, n)
+                operations.append(
+                    {
+                        "transform": transform,
+                        "n": n,
+                        "instruction": f"- Remove every {_ordinal(n)} column (1-indexed)",
+                    }
+                )
+
+        answer_str = self._matrix_to_str(answer)
+
+        return {
+            "question": QUESTION_TEMPLATE.format(
+                matrix=matrix_str, operations="\n".join(op["instruction"] for op in operations)
+            ),
+            "answer": answer_str,
+            "metadata": {"matrix": matrix, "solution": answer, "operations": operations},
+        }
+
+
+register_dataset("manipulate_matrix", ManipulateMatrixDataset, ManipulateMatrixConfig)
diff --git a/tests/test_manipulate_matrix.py b/tests/test_manipulate_matrix.py
new file mode 100644
index 00000000..eb44aa6f
--- /dev/null
+++ b/tests/test_manipulate_matrix.py
@@ -0,0 +1,218 @@
+"""Tests for Manipulate Matrix questions generation"""
+
+import pytest
+
+from reasoning_gym.algorithmic.manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixDataset, _ordinal
+
+
+def test_manipulate_matrix_config_validation():
+    """Test that invalid configs raise appropriate errors"""
+
+    invalid_dims = [-1, 0]  # Dimensions should be positive integers
+    dim_fields = ["max_rows", "max_cols"]
+
+    for field in dim_fields:
+        for dim in invalid_dims:
+            with pytest.raises(AssertionError):
+                config = ManipulateMatrixConfig(**{field: dim})
+                config.validate()
+
+    invalid_probabilities = [-0.01, 1.01]  # Probabilities should be between 0 and 1 inclusive
+    probability_fields = [
+        "p_rotate",
+        "p_hmirror",
+        "p_vmirror",
+        "p_dmirror",
+        "p_cmirror",
+        "p_map",
+        "p_crop",
+        "p_remove_every_nth_row",
+        "p_remove_every_nth_col",
+        "p_zero_divisible",
+    ]
+
+    for field in probability_fields:
+        for prob in invalid_probabilities:
+            with pytest.raises(AssertionError):
+                config = ManipulateMatrixConfig(**{field: prob})
+                config.validate()
+
+
+def test_manipulate_matrix_dataset_deterministic():
+    """Test that dataset generates same items with same seed"""
+    config = ManipulateMatrixConfig(seed=42, size=10)
+    dataset1 = ManipulateMatrixDataset(config)
+    dataset2 = ManipulateMatrixDataset(config)
+
+    for i in range(len(dataset1)):
+        assert dataset1[i] == dataset2[i]
+
+
+def test_manipulate_matrix_dataset_items():
+    """Test basic properties of generated items"""
+    config = ManipulateMatrixConfig(max_rows=7, max_cols=7, size=10, seed=42)
+    dataset = ManipulateMatrixDataset(config)
+
+    for i in range(len(dataset)):
+        item = dataset[i]
+        # Check item structure
+        assert isinstance(item, dict)
+        assert "question" in item
+        assert "answer" in item
+        assert "metadata" in item
+
+        # Check metadata
+        assert "matrix" in item["metadata"]
+        assert "solution" in item["metadata"]
+        assert "operations" in item["metadata"]
+
+        matrix = item["metadata"]["matrix"]
+        solution = item["metadata"]["solution"]
+        operations = item["metadata"]["operations"]
+
+        # Verify matrix dimensions (rotations/transposes may swap rows and columns in the solution)
+        assert len(matrix) <= config.max_rows
+        assert all(len(row) <= config.max_cols for row in matrix)
+        assert len(solution) <= max(config.max_rows, config.max_cols)
+        assert all(len(row) <= max(config.max_rows, config.max_cols) for row in solution)
+        for op in operations:
+            assert "transform" in op
+            assert "instruction" in op
+
+
+def test_manipulate_matrix_dataset_iteration():
+    """Test that iteration respects dataset size"""
+    config = ManipulateMatrixConfig(size=5, seed=42)
+    dataset = ManipulateMatrixDataset(config)
+
+    items = list(dataset)
+    assert len(items) == config.size
+
+    assert items == list(dataset)
+
+
+def test_manipulate_matrix_transforms():
+    """Test each individual matrix transform on a fixed 5x5 matrix"""
+    config = ManipulateMatrixConfig(seed=42)
+    dataset = ManipulateMatrixDataset(config)
+    matrix = [
+        [1, 2, 3, 4, 5],
+        [6, 7, 8, 9, 10],
+        [11, 12, 13, 14, 15],
+        [16, 17, 18, 19, 20],
+        [21, 22, 23, 24, 25],
+    ]
+
+    # identity
+    assert dataset._identity(matrix) == matrix
+
+    # rot 90 degrees
+    assert dataset._rot90(matrix) == [
+        [21, 16, 11, 6, 1],
+        [22, 17, 12, 7, 2],
+        [23, 18, 13, 8, 3],
+        [24, 19, 14, 9, 4],
+        [25, 20, 15, 10, 5],
+    ]
+
+    # rot 180 degrees
+    assert dataset._rot180(matrix) == [
+        [25, 24, 23, 22, 21],
+        [20, 19, 18, 17, 16],
+        [15, 14, 13, 12, 11],
+        [10, 9, 8, 7, 6],
+        [5, 4, 3, 2, 1],
+    ]
+
+    # rot 270 degrees
+    assert dataset._rot270(matrix) == [
+        [5, 10, 15, 20, 25],
+        [4, 9, 14, 19, 24],
+        [3, 8, 13, 18, 23],
+        [2, 7, 12, 17, 22],
+        [1, 6, 11, 16, 21],
+    ]
+
+    # hmirror
+    assert dataset._hmirror(matrix) == [
+        [21, 22, 23, 24, 25],
+        [16, 17, 18, 19, 20],
+        [11, 12, 13, 14, 15],
+        [6, 7, 8, 9, 10],
+        [1, 2, 3, 4, 5],
+    ]
+
+    # vmirror
+    assert dataset._vmirror(matrix) == [
+        [5, 4, 3, 2, 1],
+        [10, 9, 8, 7, 6],
+        [15, 14, 13, 12, 11],
+        [20, 19, 18, 17, 16],
+        [25, 24, 23, 22, 21],
+    ]
+
+    # dmirror
+    assert dataset._dmirror(matrix) == [
+        [1, 6, 11, 16, 21],
+        [2, 7, 12, 17, 22],
+        [3, 8, 13, 18, 23],
+        [4, 9, 14, 19, 24],
+        [5, 10, 15, 20, 25],
+    ]
+
+    # cmirror
+    assert dataset._cmirror(matrix) == [
+        [25, 20, 15, 10, 5],
+        [24, 19, 14, 9, 4],
+        [23, 18, 13, 8, 3],
+        [22, 17, 12, 7, 2],
+        [21, 16, 11, 6, 1],
+    ]
+
+    # map
+    assert dataset._map(matrix, a=13, b=0) == [
+        [1, 2, 3, 4, 5],
+        [6, 7, 8, 9, 10],
+        [11, 12, 0, 14, 15],  # 13 -> 0
+        [16, 17, 18, 19, 20],
+        [21, 22, 23, 24, 25],
+    ]
+
+    # zero divisible
+    assert dataset._zero_divisible(matrix, k=3) == [
+        [1, 2, 0, 4, 5],
+        [0, 7, 8, 0, 10],
+        [11, 0, 13, 14, 0],
+        [16, 17, 0, 19, 20],
+        [0, 22, 23, 0, 25],
+    ]
+
+    # crop
+    assert dataset._crop(matrix, row_start=2, row_end=4, col_start=1, col_end=3) == [
+        [6, 7, 8],
+        [11, 12, 13],
+        [16, 17, 18],
+    ]
+
+    # remove every nth row
+    assert dataset._remove_every_nth_row(matrix, n=2) == [
+        [1, 2, 3, 4, 5],
+        [11, 12, 13, 14, 15],
+        [21, 22, 23, 24, 25],
+    ]
+
+    # remove every nth col
+    assert dataset._remove_every_nth_col(matrix, n=2) == [
+        [1, 3, 5],
+        [6, 8, 10],
+        [11, 13, 15],
+        [16, 18, 20],
+        [21, 23, 25],
+    ]
+
+
+def test_manipulate_matrix_ordinal():
+    """Test the ordinal formatting used by the remove-every-nth instructions"""
+    expected = {1: "1st", 2: "2nd", 3: "3rd", 4: "4th", 11: "11th", 12: "12th", 13: "13th", 21: "21st", 22: "22nd"}
+    for n, text in expected.items():
+        assert _ordinal(n) == text