diff --git a/reasoning_gym/algorithmic/manipulate_matrix.py b/reasoning_gym/algorithmic/manipulate_matrix.py index 1b5412f4..ab0bf592 100644 --- a/reasoning_gym/algorithmic/manipulate_matrix.py +++ b/reasoning_gym/algorithmic/manipulate_matrix.py @@ -3,7 +3,9 @@ from copy import deepcopy from dataclasses import dataclass from random import Random -from typing import Optional +from typing import Any, Optional + +import numpy as np from ..factory import ProceduralDataset, register_dataset @@ -28,21 +30,22 @@ def num_cols(matrix: list[list[int]]) -> int: class ManipulateMatrixConfig: """Configuration for Manipulate Matrix dataset generation""" - min_rows: int = 1 # Minimum number of rows - min_cols: int = 1 # Minimum number of columns + min_rows: int = 2 # Minimum number of rows + min_cols: int = 2 # Minimum number of columns max_rows: int = 10 # Maximum number of rows max_cols: int = 10 # Maximum number of columns - max_transforms: int = 5 # Maximum number of transformations to apply - p_rotate: float = 0.2 # Probability of rotating the matrix - p_hmirror: float = 0.2 # Probability of horizontally mirroring the matrix - p_vmirror: float = 0.2 # Probability of vertically mirroring the matrix - p_dmirror: float = 0.2 # Probability of mirroring along the diagonal - p_cmirror: float = 0.2 # Probability of mirroring along the counterdiagonal - p_map: float = 0.2 # Probability of mapping a certain value to another - p_crop: float = 0.2 # Probability of cropping the matrix - p_remove_every_nth_row: float = 0.2 # Probability of removing every nth row - p_remove_every_nth_col: float = 0.2 # Probability of removing every nth column - p_zero_divisible: float = 0.2 # Probability of setting elements divisible by some number to zero + min_transforms: int = 1 # Minimum number of transformations to apply + max_transforms: int = 10 # Maximum number of transformations to apply + w_rotate: float = 1 # Weight of rotating the matrix + w_hmirror: float = 1 # Weight of horizontally mirroring the matrix + w_vmirror: float = 1 # Weight of vertically mirroring the matrix + w_dmirror: float = 1 # Weight of mirroring along the diagonal + w_cmirror: float = 1 # Weight of mirroring along the counterdiagonal + w_map: float = 1 # Weight of mapping a certain value to another + w_crop: float = 1 # Weight of cropping the matrix + w_remove_every_nth_row: float = 1 # Weight of removing every nth row + w_remove_every_nth_col: float = 1 # Weight of removing every nth column + w_zero_divisible: float = 1 # Weight of setting elements divisible by some number to zero size: int = 500 # Virtual dataset size seed: Optional[int] = None @@ -53,17 +56,27 @@ class ManipulateMatrixConfig: assert 1 <= self.min_cols, "min_cols must be at least 1" assert self.min_rows <= self.max_rows, "max_rows must be at least min_rows" assert self.min_cols <= self.max_cols, "max_cols must be at least min_cols" - assert 0 <= self.max_transforms, "max_transforms must be non-negative" - assert 0 <= self.p_rotate <= 1, "p_rotate must be between 0 and 1" - assert 0 <= self.p_hmirror <= 1, "p_hmirror must be between 0 and 1" - assert 0 <= self.p_vmirror <= 1, "p_vmirror must be between 0 and 1" - assert 0 <= self.p_dmirror <= 1, "p_dmirror must be between 0 and 1" - assert 0 <= self.p_cmirror <= 1, "p_cmirror must be between 0 and 1" - assert 0 <= self.p_map <= 1, "p_map must be between 0 and 1" - assert 0 <= self.p_crop <= 1, "p_crop must be between 0 and 1" - assert 0 <= self.p_remove_every_nth_row <= 1, "p_remove_every_nth_row must be between 0 and 1" - assert 0 <= self.p_remove_every_nth_col <= 1, "p_remove_nth_col must be between 0 and 1" - assert 0 <= self.p_zero_divisible <= 1, "p_zero_divisible must be between 0 and 1" + assert 1 <= self.min_transforms, "min_transforms must be at least 1" + assert self.min_transforms <= self.max_transforms, "max_transforms must be at least min_transforms" + assert ( + np.sum( + np.exp( + [ + self.w_rotate, + self.w_hmirror, + self.w_vmirror, + self.w_dmirror, + self.w_cmirror, + self.w_map, + self.w_crop, + self.w_remove_every_nth_row, + self.w_remove_every_nth_col, + self.w_zero_divisible, + ] + ) + ) + > 0 + ), "At least one weight must be non-zero" class ManipulateMatrixDataset(ProceduralDataset): @@ -89,6 +102,21 @@ class ManipulateMatrixDataset(ProceduralDataset): "remove_every_nth_row", "remove_every_nth_col", ] + weights = np.array( + [ + config.w_rotate, + config.w_hmirror, + config.w_vmirror, + config.w_dmirror, + config.w_cmirror, + config.w_map, + config.w_crop, + config.w_remove_every_nth_row, + config.w_remove_every_nth_col, + config.w_zero_divisible, + ] + ) + self._weights = np.exp(weights) / np.sum(np.exp(weights)) def _get_matrix(self, rng: Random) -> list[list[int]]: """Generate a random matrix""" @@ -102,6 +130,25 @@ class ManipulateMatrixDataset(ProceduralDataset): """Get a string representation of the matrix""" return "\n".join(" ".join(str(x) for x in row) for row in matrix) + def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: + oracle_answer = entry["answer"].strip() + if answer is not None and len(answer) > 0: + answer = answer.strip() + if answer == oracle_answer: + return 1.0 + + # perhaps the model's answer has unnecessary spaces (e.g. after last row element) + answer = self._matrix_to_str([row.strip().split() for row in answer.strip().split("\n")]).strip() + if answer == oracle_answer: + return 1.0 + + if oracle_answer in answer: + return len(oracle_answer) / len(answer) + else: + return 0.01 + + return 0.0 + def _identity(self, matrix: list[list[int]]) -> list[list[int]]: """Identity transformation""" return matrix @@ -163,15 +210,16 @@ class ManipulateMatrixDataset(ProceduralDataset): matrix = self._get_matrix(rng) matrix_str = self._matrix_to_str(matrix) - num_transforms = rng.randint(0, self.config.max_transforms) - transforms = rng.sample(self._all_transforms, num_transforms) + num_transforms = rng.randint(self.config.min_transforms, self.config.max_transforms) operations = [] - answer = deepcopy(matrix) - for transform in transforms: + while len(operations) < num_transforms: + # Choose a transform randomly, weighted by the probability of each transform + transform = rng.choices(self._all_transforms, weights=self._weights, k=1)[0] + # Rotate - if transform == "rotate" and rng.random() < self.config.p_rotate: + if transform == "rotate": rotation = rng.choice(list(self._rotations.keys())) answer = self._rotations[rotation](answer) operations.append( @@ -182,39 +230,39 @@ class ManipulateMatrixDataset(ProceduralDataset): } ) # Horizontal mirror - if transform == "hmirror" and rng.random() < self.config.p_hmirror: + if transform == "hmirror": answer = self._hmirror(answer) operations.append({"transform": transform, "instruction": "- Horizontally mirror the matrix"}) # Vertical mirror - if transform == "vmirror" and rng.random() < self.config.p_vmirror: + if transform == "vmirror": answer = self._vmirror(answer) operations.append({"transform": transform, "instruction": "- Vertically mirror the matrix"}) # Diagonal mirror - if transform == "dmirror" and rng.random() < self.config.p_dmirror: + if transform == "dmirror": answer = self._dmirror(answer) operations.append({"transform": transform, "instruction": "- Mirror the matrix along the diagonal"}) # Counterdiagonal mirror - if transform == "cmirror" and rng.random() < self.config.p_cmirror: + if transform == "cmirror": answer = self._cmirror(answer) operations.append( {"transform": transform, "instruction": "- Mirror the matrix along the counterdiagonal"} ) # Map a value to another - if transform == "map" and rng.random() < self.config.p_map: + if transform == "map": a, b = rng.sample(range(10), 2) answer = self._map(answer, a, b) operations.append( {"transform": transform, "from": a, "to": b, "instruction": f"- Map each occurrence of {a} to {b}"} ) # Set elements divisible by k to zero - if transform == "zero_divisible" and rng.random() < self.config.p_zero_divisible: + if transform == "zero_divisible": k = rng.randint(1, 9) answer = self._zero_divisible(answer, k) operations.append( {"transform": transform, "k": k, "instruction": f"- Set all elements divisible by {k} to zero"} ) # Crop the matrix - if transform == "crop" and rng.random() < self.config.p_crop: + if transform == "crop": row_start = rng.randint(1, num_rows(answer)) row_end = rng.randint(row_start, num_rows(answer)) col_start = rng.randint(1, num_cols(answer)) @@ -231,11 +279,7 @@ class ManipulateMatrixDataset(ProceduralDataset): } ) # Remove every nth row - if ( - transform == "remove_every_nth_row" - and rng.random() < self.config.p_remove_every_nth_row - and num_rows(answer) > 1 - ): + if transform == "remove_every_nth_row" and num_rows(answer) > 1: n = rng.randint(2, num_rows(answer)) answer = self._remove_every_nth_row(answer, n) formatting = "st" if n == 1 else "nd" if n == 2 else "th" @@ -243,11 +287,7 @@ class ManipulateMatrixDataset(ProceduralDataset): {"transform": transform, "n": n, "instruction": f"- Remove every {n}-{formatting} row (1-indexed)"} ) # Remove every nth column - if ( - transform == "remove_every_nth_col" - and rng.random() < self.config.p_remove_every_nth_col - and num_cols(answer) > 1 - ): + if transform == "remove_every_nth_col" and num_cols(answer) > 1: n = rng.randint(2, num_cols(answer)) answer = self._remove_every_nth_col(answer, n) formatting = "st" if n == 1 else "nd" if n == 2 else "th" diff --git a/tests/test_manipulate_matrix.py b/tests/test_manipulate_matrix.py index 9dfc3655..2801340f 100644 --- a/tests/test_manipulate_matrix.py +++ b/tests/test_manipulate_matrix.py @@ -8,9 +8,11 @@ from reasoning_gym.algorithmic.manipulate_matrix import ManipulateMatrixConfig, def test_manipulate_matrix_config_validation(): """Test that invalid configs raise appropriate errors""" - with pytest.raises(AssertionError): - config = ManipulateMatrixConfig(max_transforms=-1) # max_transforms should be non-negative - config.validate() + for field in ["min_transforms", "max_transforms"]: + for num_transforms in [-1, 0]: # Number of transforms should be positive + with pytest.raises(AssertionError): + config = ManipulateMatrixConfig(**{field: num_transforms}) + config.validate() invalid_dims = [-1, 0] # Dimensions should be positive integers dim_fields = ["min_rows", "min_cols", "max_rows", "max_cols"] @@ -21,25 +23,6 @@ def test_manipulate_matrix_config_validation(): config = ManipulateMatrixConfig(**{field: dim}) config.validate() - invalid_probabilities = [-0.01, 1.01] # Probabilities should be between 0 and 1 inclusive - probability_fields = [ - "p_hmirror", - "p_vmirror", - "p_dmirror", - "p_cmirror", - "p_map", - "p_crop", - "p_remove_every_nth_row", - "p_remove_every_nth_col", - "p_zero_divisible", - ] - - for field in probability_fields: - for prob in invalid_probabilities: - with pytest.raises(AssertionError): - config = ManipulateMatrixConfig(**{field: prob}) - config.validate() - def test_manipulate_matrix_dataset_deterministic(): """Test that dataset generates same items with same seed""" @@ -212,3 +195,27 @@ def test_manipulate_matrix_transforms(): [16, 18, 20], [21, 23, 25], ] + + +def test_manipulate_matrix_score_answer(): + """Test the score_answer method""" + config = ManipulateMatrixConfig(seed=42) + dataset = ManipulateMatrixDataset(config) + + entry = {"answer": dataset._matrix_to_str([[1, 2, 3], [4, 5, 6], [7, 8, 9]])} + + # perfect match + answer = "1 2 3\n4 5 6\n7 8 9" + assert dataset.score_answer(answer, entry) == 1.0 + + # model answer contains unnecessary empty spaces + answer = "1 2 3\n4 5 6 \n7 8 9 " + assert dataset.score_answer(answer, entry) == 1.0 + + # incorrect answer + answer = "1 2 3\n4 5 6\n7 8 8" + assert dataset.score_answer(answer, entry) == 0.01 + + # answer is none + answer = None + assert dataset.score_answer(answer, entry) == 0.0