diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py
index 9acc5007..decfd4d2 100644
--- a/reasoning_gym/algorithmic/__init__.py
+++ b/reasoning_gym/algorithmic/__init__.py
@@ -7,6 +7,7 @@ Algorithmic tasks for training reasoning capabilities:
 """
 
 from .base_conversion import BaseConversionConfig, BaseConversionDataset
+from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset
 from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
 from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset
 from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset
@@ -60,4 +61,6 @@ __all__ = [
     "IsomorphicStringsDataset",
     "RotateMatrixConfig",
     "RotateMatrixDataset",
+    "BinaryMatrixConfig",
+    "BinaryMatrixDataset",
 ]
diff --git a/reasoning_gym/algorithmic/binary_matrix.py b/reasoning_gym/algorithmic/binary_matrix.py
new file mode 100644
index 00000000..c34d0566
--- /dev/null
+++ b/reasoning_gym/algorithmic/binary_matrix.py
@@ -0,0 +1,126 @@
+"""Find the distance to the nearest 0 for each cell in a binary matrix.
+
+A popular Leetcode problem:
+https://leetcode.com/problems/01-matrix/description/
+"""
+
+from collections import deque
+from copy import deepcopy
+from dataclasses import dataclass
+from random import Random
+from typing import Optional
+
+from ..factory import ProceduralDataset, register_dataset
+
+QUESTION_TEMPLATE = """Given a square matrix, your job is to find the taxicab (Manhattan) distance to the nearest 0 for each cell.
+
+Example:
+
+Input: Find the distance to the nearest 0 for each cell in the matrix below:
+0 0 0
+0 1 0
+1 1 1
+
+Output:
+0 0 0
+0 1 0
+1 2 1
+
+Find the distance to the nearest 0 for each cell in the matrix below:
+{matrix}
+"""
+
+
+@dataclass
+class BinaryMatrixConfig:
+    """Configuration for Binary Matrix dataset generation"""
+
+    max_n: int = 10  # Maximum size of the matrix
+    p_zero: float = 0.7  # Probability of a cell being 0
+
+    size: int = 500  # Virtual dataset size
+    seed: Optional[int] = None
+
+    def validate(self):
+        """Validate configuration parameters"""
+        assert 1 <= self.max_n, "max_n must be at least 1"
+        assert 0 < self.p_zero <= 1, "p_zero must be between 0 and 1"
+
+
+class BinaryMatrixDataset(ProceduralDataset):
+    """Generates Binary Matrix exercises with configurable difficulty"""
+
+    def __init__(self, config: BinaryMatrixConfig):
+        super().__init__(config=config, seed=config.seed, size=config.size)
+
+    def _get_binary_matrix(self, rng: Random) -> list[list[int]]:
+        """Generate a random square binary matrix containing at least one 0.
+
+        The side length is drawn uniformly from 1..max_n. One 0 is always included so
+        that a solution exists, and the cells are shuffled so that this guaranteed 0
+        is not pinned to the top-left corner.
+        """
+        n = rng.randint(1, self.config.max_n)
+        # Ensure at least one 0 in the matrix, so that a solution exists
+        numbers = [0] + [0 if rng.random() < self.config.p_zero else 1 for _ in range(n**2 - 1)]
+        # Shuffle so the guaranteed 0 can land in any cell
+        rng.shuffle(numbers)
+        matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
+        return matrix
+
+    def _get_distances(self, matrix: list[list[int]]) -> list[list[int]]:
+        """Get the taxicab distance to the nearest 0 for each cell in the matrix.
+
+        Runs a breadth-first search seeded with every 0 cell at once, so the first time
+        a cell is reached is along a shortest path. If the matrix contains no 0, every
+        cell keeps the initial value float("inf").
+        """
+        n = len(matrix)
+        directions = [(1, 0), (-1, 0), (0, 1), (0, -1)]
+        unreached = float("inf")
+        output = [[unreached] * n for _ in range(n)]
+        queue = deque()
+
+        for r in range(n):
+            for c in range(n):
+                if matrix[r][c] == 0:
+                    output[r][c] = 0
+                    queue.append((r, c))
+
+        while queue:
+            r, c = queue.popleft()
+            for dr, dc in directions:
+                new_r, new_c = r + dr, c + dc
+                # A cell still marked unreached has not been assigned a distance yet
+                if 0 <= new_r < n and 0 <= new_c < n and output[new_r][new_c] == unreached:
+                    output[new_r][new_c] = output[r][c] + 1
+                    queue.append((new_r, new_c))
+
+        return output
+
+    def _matrix_to_str(self, matrix: list[list[int]]) -> str:
+        """Get a string representation of the matrix"""
+        return "\n".join(" ".join(str(x) for x in row) for row in matrix)
+
+    def __getitem__(self, idx: int) -> dict:
+        """Generate a single Binary Matrix question
+
+        Returns a dict with the formatted "question", the expected "answer" string,
+        and "metadata" holding the raw matrix and its distance matrix.
+        """
+        rng = Random(self.seed + idx)
+
+        matrix = self._get_binary_matrix(rng)
+        matrix_str = self._matrix_to_str(matrix)
+
+        answer = self._get_distances(matrix)
+        answer_str = self._matrix_to_str(answer)
+
+        return {
+            "question": QUESTION_TEMPLATE.format(matrix=matrix_str),
+            "answer": answer_str,
+            "metadata": {"matrix": matrix, "solution": answer},
+        }
+
+
+register_dataset("binary_matrix", BinaryMatrixDataset, BinaryMatrixConfig)
diff --git a/tests/test_binary_matrix.py b/tests/test_binary_matrix.py
new file mode 100644
index 00000000..60c700d7
--- /dev/null
+++ b/tests/test_binary_matrix.py
@@ -0,0 +1,105 @@
+"""Tests for Binary Matrix questions generation"""
+
+import pytest
+
+from reasoning_gym.algorithmic.binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset
+
+
+def test_binary_matrix_config_validation():
+    """Test that invalid configs raise appropriate errors"""
+    with pytest.raises(AssertionError):
+        config = BinaryMatrixConfig(max_n=-1)  # Negative not allowed
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = BinaryMatrixConfig(max_n=0)  # Zero not allowed
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = BinaryMatrixConfig(p_zero=0)  # <= 0 not allowed
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = BinaryMatrixConfig(p_zero=1.01)  # > 1 not allowed
+        config.validate()
+
+
+def test_binary_matrix_dataset_deterministic():
+    """Test that dataset generates same items with same seed"""
+    config = BinaryMatrixConfig(seed=42, size=10)
+    dataset1 = BinaryMatrixDataset(config)
+    dataset2 = BinaryMatrixDataset(config)
+
+    for i in range(len(dataset1)):
+        assert dataset1[i] == dataset2[i]
+
+
+def test_binary_matrix_dataset_items():
+    """Test basic properties of generated items"""
+    config = BinaryMatrixConfig(max_n=5, size=10, seed=42)
+    dataset = BinaryMatrixDataset(config)
+
+    for i in range(len(dataset)):
+        item = dataset[i]
+        # Check item structure
+        assert isinstance(item, dict)
+        assert "question" in item
+        assert "answer" in item
+        assert "metadata" in item
+
+        # Check metadata
+        assert "matrix" in item["metadata"]
+        assert "solution" in item["metadata"]
+
+        matrix = item["metadata"]["matrix"]
+        solution = item["metadata"]["solution"]
+
+        # Verify list dimensions: both matrices are square and of the same size
+        assert 1 <= len(matrix) <= config.max_n
+        assert all(len(row) == len(matrix) for row in matrix)
+        assert len(solution) == len(matrix)
+        assert all(len(row) == len(matrix) for row in solution)
+
+        # Verify matrix values
+        for r in range(len(matrix)):
+            for c in range(len(matrix[r])):
+                assert matrix[r][c] in {0, 1}
+                assert solution[r][c] >= matrix[r][c]
+
+
+def test_binary_matrix_dataset_iteration():
+    """Test that iteration respects dataset size"""
+    config = BinaryMatrixConfig(size=5, seed=42)
+    dataset = BinaryMatrixDataset(config)
+
+    items = list(dataset)
+    assert len(items) == config.size
+
+    # Test multiple iterations yield same items
+    assert items == list(dataset)
+
+
+def test_binary_matrix_answer():
+    """Test the _get_distances method"""
+    config = BinaryMatrixConfig(seed=42)
+    dataset = BinaryMatrixDataset(config)
+
+    # 1x1 matrix
+    matrix = [[0]]
+    assert dataset._get_distances(matrix) == [[0]]
+
+    # 2x2 matrix
+    matrix = [[0, 1], [1, 1]]
+    assert dataset._get_distances(matrix) == [[0, 1], [1, 2]]
+
+    # 3x3 matrix
+    matrix = [[0, 0, 0], [0, 1, 0], [1, 1, 1]]
+    assert dataset._get_distances(matrix) == [[0, 0, 0], [0, 1, 0], [1, 2, 1]]
+
+    # All-zero matrix
+    matrix = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
+    assert dataset._get_distances(matrix) == [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
+
+    # Only 0 is in the centre
+    matrix = [[1, 1, 1], [1, 0, 1], [1, 1, 1]]
+    assert dataset._get_distances(matrix) == [[2, 1, 2], [1, 0, 1], [2, 1, 2]]