diff --git a/reasoning_gym/algorithmic/pool_matrix.py b/reasoning_gym/algorithmic/pool_matrix.py index 811085d5..dda7ed2d 100644 --- a/reasoning_gym/algorithmic/pool_matrix.py +++ b/reasoning_gym/algorithmic/pool_matrix.py @@ -9,7 +9,30 @@ import numpy as np from ..factory import ProceduralDataset, register_dataset -QUESTION_TEMPLATE = """Perform {pool_type} pooling on the following matrix: +QUESTION_TEMPLATE = """Your job is to perform max/average pooling on the given matrix. +The stride is equal to the kernel size, meaning there is no overlap between the pooling regions. + +Example 1: +- Input: Perform max pooling on the following matrix with a kernel size of 2: +1 2 3 4 +5 6 7 8 +9 10 11 12 +13 14 15 16 +- Output: +6 8 +14 16 + +Example 2: +- Input: Perform average pooling on the following matrix with a kernel size of 2: +1 2 3 4 +5 6 7 8 +9 10 11 12 +13 14 15 16 +- Output: +3.5 5.5 +11.5 13.5 + +Perform {pool_type} pooling on the following matrix with a kernel size of {pool_size}: {matrix} """ @@ -18,6 +41,8 @@ QUESTION_TEMPLATE = """Perform {pool_type} pooling on the following matrix: class PoolMatrixConfig: """Configuration for Pool Matrix dataset generation""" + min_rows: int = 2 # Minimum rows of the matrix + min_cols: int = 2 # Minimum columns of the matrix max_rows: int = 10 # Maximum rows of the matrix max_cols: int = 10 # Maximum columns of the matrix max_pool_size: int = 3 # Maximum pooling size @@ -27,8 +52,10 @@ class PoolMatrixConfig: def validate(self): """Validate configuration parameters""" - assert 1 <= self.max_rows, "max_rows must be at least 1" - assert 1 <= self.max_cols, "max_cols must be at least 1" + assert 2 <= self.min_rows, "min_rows must be at least 2" + assert 2 <= self.min_cols, "min_cols must be at least 2" + assert self.min_rows <= self.max_rows, "max_rows must be at least min_rows" + assert self.min_cols <= self.max_cols, "max_cols must be at least min_cols" assert 1 <= self.max_pool_size, "max_pool_size must be at least 1" @@ -40,9 +67,9 @@ class PoolMatrixDataset(ProceduralDataset): def _get_matrix(self, rng: Random) -> np.ndarray: """Generate a random matrix""" - rows = rng.randint(1, self.config.max_rows) - cols = rng.randint(1, self.config.max_cols) - return np.array([[rng.randint(0, 10) for _ in range(cols)] for _ in range(rows)]) + rows = rng.randint(self.config.min_rows, self.config.max_rows) + cols = rng.randint(self.config.min_rows, self.config.max_cols) + return np.random.randint(0, 10, (rows, cols)) def _matrix_to_str(self, matrix: np.ndarray) -> str: """Get a string representation of the matrix""" @@ -89,6 +116,7 @@ class PoolMatrixDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: """Generate a single Pool Matrix question""" rng = Random(self.seed + idx) + np.random.seed(self.seed + idx) matrix = self._get_matrix(rng) matrix_str = self._matrix_to_str(matrix) @@ -100,7 +128,7 @@ class PoolMatrixDataset(ProceduralDataset): answer_str = self._matrix_to_str(answer) return { - "question": QUESTION_TEMPLATE.format(matrix=matrix_str, pool_type=pool_type), + "question": QUESTION_TEMPLATE.format(matrix=matrix_str, pool_type=pool_type, pool_size=pool_size), "answer": answer_str, "metadata": { "matrix": matrix.tolist(), diff --git a/tests/test_pool_matrix.py b/tests/test_pool_matrix.py index 422bd642..aa3fe6b6 100644 --- a/tests/test_pool_matrix.py +++ b/tests/test_pool_matrix.py @@ -9,7 +9,7 @@ from reasoning_gym.algorithmic.pool_matrix import PoolMatrixConfig, PoolMatrixDa def test_pool_matrix_config_validation(): """Test that invalid configs raise appropriate errors""" - for field in ["max_rows", "max_cols", "max_pool_size"]: + for field in ["min_rows", "min_cols", "max_rows", "max_cols"]: with pytest.raises(AssertionError): config = PoolMatrixConfig(**{field: -1}) # Negative not allowed config.validate() @@ -18,6 +18,18 @@ def test_pool_matrix_config_validation(): config = PoolMatrixConfig(**{field: 0}) # Zero not allowed config.validate() + with pytest.raises(AssertionError): + config = PoolMatrixConfig(**{field: 1}) # One not allowed + config.validate() + + with pytest.raises(AssertionError): + config = PoolMatrixConfig(max_pool_size=-1) # Negative not allowed + config.validate() + + with pytest.raises(AssertionError): + config = PoolMatrixConfig(max_pool_size=0) # Zero not allowed + config.validate() + def test_pool_matrix_dataset_deterministic(): """Test that dataset generates same items with same seed""" @@ -80,9 +92,6 @@ def test_pool_matrix_answer(): dataset = PoolMatrixDataset(config) # 1. Max pooling - matrix = np.array([[1]]) - assert np.allclose(dataset._max_pool(matrix, 2), np.array([[1]])) - matrix = np.array([[1, 2], [3, 4]]) assert np.allclose(dataset._max_pool(matrix, 2), np.array([[4]])) @@ -106,10 +115,6 @@ def test_pool_matrix_answer(): assert np.allclose(dataset._max_pool(matrix, 2), np.array([[6, 8], [14, 16]])) # 2. Average pooling - - matrix = np.array([[1]]) - assert np.allclose(dataset._average_pool(matrix, 2), np.array([[1]])) - matrix = np.array([[1, 2], [3, 4]]) assert np.allclose(dataset._average_pool(matrix, 2), np.array([[2.5]]))