matrix manipulation

This commit is contained in:
Zafir Stojanovski 2025-02-10 13:51:39 +01:00
parent ed37eae559
commit 3d66cc6a7f
3 changed files with 481 additions and 0 deletions

View file

@ -12,6 +12,7 @@ from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset
from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset
from .letter_counting import LetterCountingConfig, LetterCountingDataset
from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset
from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixDataset
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
from .number_sorting import NumberSortingConfig, NumberSortingDataset
from .palindrome_generation import PalindromeConfig, PalindromeDataset
@ -60,4 +61,6 @@ __all__ = [
"IsomorphicStringsDataset",
"RotateMatrixConfig",
"RotateMatrixDataset",
"ManipulateMatrixConfig",
"ManipulateMatrixDataset",
]

View file

@ -0,0 +1,268 @@
"""Manipulate matrices by performing augmentations such as rotations, flips, mapping, etc."""
from copy import deepcopy
from dataclasses import dataclass
from random import Random
from typing import Optional
from ..factory import ProceduralDataset, register_dataset
# Prompt template for one question. ``{matrix}`` receives the rendered input
# matrix and ``{operations}`` receives the sampled instructions, one "- ..."
# bullet per line. The fixed identity bullet is always present, so the list of
# steps is never empty even when no transform is sampled for an item.
QUESTION_TEMPLATE = """For the following matrix:
{matrix}
Perform the following series of operations in order:
- Identity transformation, i.e. no change
{operations}
"""
def num_rows(matrix: list[list[int]]) -> int:
    """Return how many rows the matrix has (0 for an empty matrix)."""
    row_count = len(matrix)
    return row_count
def num_cols(matrix: list[list[int]]) -> int:
    """Return the length of the first row, or 0 when the matrix has no rows."""
    if not matrix:
        return 0
    first_row = matrix[0]
    return len(first_row)
@dataclass
class ManipulateMatrixConfig:
    """Configuration for Manipulate Matrix dataset generation"""

    max_rows: int = 10  # Maximum number of rows
    max_cols: int = 10  # Maximum number of columns
    p_rotate: float = 0.2  # Probability of rotating the matrix
    p_hmirror: float = 0.2  # Probability of horizontally mirroring the matrix
    p_vmirror: float = 0.2  # Probability of vertically mirroring the matrix
    p_dmirror: float = 0.2  # Probability of mirroring along the diagonal
    p_cmirror: float = 0.2  # Probability of mirroring along the counterdiagonal
    p_map: float = 0.2  # Probability of mapping a certain value to another
    p_crop: float = 0.2  # Probability of cropping the matrix
    p_remove_every_nth_row: float = 0.2  # Probability of removing every nth row
    p_remove_every_nth_col: float = 0.2  # Probability of removing every nth column
    p_zero_divisible: float = 0.2  # Probability of setting elements divisible by some number to zero

    size: int = 500  # Virtual dataset size
    seed: Optional[int] = None

    def validate(self):
        """Validate configuration parameters.

        Raises:
            AssertionError: if ``max_rows`` or ``max_cols`` is below 1, or if
                any ``p_*`` probability lies outside the closed range [0, 1].
        """
        assert 1 <= self.max_rows, "max_rows must be at least 1"
        assert 1 <= self.max_cols, "max_cols must be at least 1"

        # Drive all probability checks from the field names so that each
        # message is guaranteed to name the field that actually failed.
        probability_fields = (
            "p_rotate",
            "p_hmirror",
            "p_vmirror",
            "p_dmirror",
            "p_cmirror",
            "p_map",
            "p_crop",
            "p_remove_every_nth_row",
            "p_remove_every_nth_col",
            "p_zero_divisible",
        )
        for name in probability_fields:
            assert 0 <= getattr(self, name) <= 1, f"{name} must be between 0 and 1"
class ManipulateMatrixDataset(ProceduralDataset):
    """Generates Manipulate Matrix exercises with configurable difficulty"""

    def __init__(self, config: ManipulateMatrixConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)
        # Rotation amount (exactly as printed in the instruction) -> implementation.
        self._rotations = {
            "90": self._rot90,
            "180": self._rot180,
            "270": self._rot270,
            "360": self._identity,
        }
        # Each transform is considered at most once per item; the order in
        # which they are considered is shuffled per item in __getitem__.
        self._all_transforms = [
            "rotate",
            "hmirror",
            "vmirror",
            "dmirror",
            "cmirror",
            "map",
            "zero_divisible",
            "crop",
            "remove_every_nth_row",
            "remove_every_nth_col",
        ]

    @staticmethod
    def _ordinal_suffix(n: int) -> str:
        """Return the English ordinal suffix for n ("st", "nd", "rd" or "th")."""
        # 11, 12 and 13 (and 111, 112, ...) take "th" despite ending in 1/2/3.
        if 11 <= n % 100 <= 13:
            return "th"
        return {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")

    def _get_matrix(self, rng: Random) -> list[list[int]]:
        """Generate a random matrix of single-digit values.

        The shape is drawn uniformly from 1..max_rows by 1..max_cols.
        """
        rows = rng.randint(1, self.config.max_rows)
        cols = rng.randint(1, self.config.max_cols)
        numbers = [rng.randint(0, 9) for _ in range(rows * cols)]
        matrix = [numbers[i * cols : (i + 1) * cols] for i in range(rows)]
        return matrix

    def _matrix_to_str(self, matrix: list[list[int]]) -> str:
        """Render the matrix with one row per line and space-separated values."""
        return "\n".join(" ".join(str(x) for x in row) for row in matrix)

    def _identity(self, matrix: list[list[int]]) -> list[list[int]]:
        """Identity transformation (returns the same object, not a copy)."""
        return matrix

    def _rot90(self, matrix: list[list[int]]) -> list[list[int]]:
        """Quarter clockwise rotation."""
        return [list(row) for row in zip(*matrix[::-1])]

    def _rot180(self, matrix: list[list[int]]) -> list[list[int]]:
        """Half rotation."""
        return [list(row[::-1]) for row in matrix[::-1]]

    def _rot270(self, matrix: list[list[int]]) -> list[list[int]]:
        """Quarter anticlockwise rotation."""
        # Transpose, then reverse the row order.
        return [list(row) for row in zip(*matrix)][::-1]

    def _hmirror(self, matrix: list[list[int]]) -> list[list[int]]:
        """Mirror along the horizontal axis, i.e. reverse the order of the rows."""
        return matrix[::-1]

    def _vmirror(self, matrix: list[list[int]]) -> list[list[int]]:
        """Mirror along the vertical axis, i.e. reverse every row."""
        return [row[::-1] for row in matrix]

    def _dmirror(self, matrix: list[list[int]]) -> list[list[int]]:
        """Mirror along the main diagonal (transpose)."""
        return [list(row) for row in zip(*matrix)]

    def _cmirror(self, matrix: list[list[int]]) -> list[list[int]]:
        """Mirror along the counterdiagonal."""
        # Rotating by 180 degrees and then transposing flips across the anti-diagonal.
        return [list(row) for row in zip(*[r[::-1] for r in matrix[::-1]])]

    def _map(self, matrix: list[list[int]], a: int, b: int) -> list[list[int]]:
        """Replace every occurrence of a with b."""
        return [[b if x == a else x for x in row] for row in matrix]

    def _zero_divisible(self, matrix: list[list[int]], k: int) -> list[list[int]]:
        """Set elements divisible by k to zero (k must be non-zero)."""
        return [[0 if x % k == 0 else x for x in row] for row in matrix]

    def _crop(
        self, matrix: list[list[int]], row_start: int, row_end: int, col_start: int, col_end: int
    ) -> list[list[int]]:
        """Crop the matrix to the given row/column ranges (1-indexed, inclusive)."""
        return [row[col_start - 1 : col_end] for row in matrix[row_start - 1 : row_end]]

    def _remove_every_nth_row(self, matrix: list[list[int]], n: int) -> list[list[int]]:
        """Remove every nth row (1-indexed), i.e. rows n, 2n, 3n, ..."""
        return [row for i, row in enumerate(matrix, start=1) if i % n != 0]

    def _remove_every_nth_col(self, matrix: list[list[int]], n: int) -> list[list[int]]:
        """Remove every nth column (1-indexed), i.e. columns n, 2n, 3n, ..."""
        return [[col for i, col in enumerate(row, start=1) if i % n != 0] for row in matrix]

    def __getitem__(self, idx: int) -> dict:
        """Generate a single Manipulate Matrix question.

        The item is derived from a ``Random`` seeded with ``self.seed + idx``,
        so the same index always yields the same item for a given seed.

        Returns:
            A dict with the rendered ``question``, the rendered ``answer`` and
            ``metadata`` holding the input ``matrix``, the ``solution`` matrix
            and the list of applied ``operations`` (each with at least the
            keys ``transform`` and ``instruction``).
        """
        rng = Random(self.seed + idx)
        cfg = self.config

        matrix = self._get_matrix(rng)
        matrix_str = self._matrix_to_str(matrix)

        # Shuffle a copy so the canonical order on the instance is untouched
        # and every item starts its shuffle from the same list.
        all_transforms = list(self._all_transforms)
        rng.shuffle(all_transforms)

        operations = []
        # Deep copy so "matrix" and "solution" in the metadata never share rows.
        answer = deepcopy(matrix)

        # NOTE: the order of rng calls below determines the generated data for
        # a given seed; keep it stable when editing.
        for transform in all_transforms:
            # Rotate
            if transform == "rotate" and rng.random() < cfg.p_rotate:
                rotation = rng.choice(list(self._rotations.keys()))
                answer = self._rotations[rotation](answer)
                operations.append(
                    {
                        "transform": transform,
                        "degrees": rotation,
                        "instruction": f"- Rotate the matrix {rotation} degrees",
                    }
                )

            # Horizontal mirror
            elif transform == "hmirror" and rng.random() < cfg.p_hmirror:
                answer = self._hmirror(answer)
                operations.append({"transform": transform, "instruction": "- Horizontally mirror the matrix"})

            # Vertical mirror
            elif transform == "vmirror" and rng.random() < cfg.p_vmirror:
                answer = self._vmirror(answer)
                operations.append({"transform": transform, "instruction": "- Vertically mirror the matrix"})

            # Diagonal mirror
            elif transform == "dmirror" and rng.random() < cfg.p_dmirror:
                answer = self._dmirror(answer)
                operations.append({"transform": transform, "instruction": "- Mirror the matrix along the diagonal"})

            # Counterdiagonal mirror
            elif transform == "cmirror" and rng.random() < cfg.p_cmirror:
                answer = self._cmirror(answer)
                operations.append(
                    {"transform": transform, "instruction": "- Mirror the matrix along the counterdiagonal"}
                )

            # Map a value to another
            elif transform == "map" and rng.random() < cfg.p_map:
                a, b = rng.sample(range(10), 2)
                answer = self._map(answer, a, b)
                operations.append(
                    {"transform": transform, "from": a, "to": b, "instruction": f"- Map each occurrence of {a} to {b}"}
                )

            # Set elements divisible by k to zero
            elif transform == "zero_divisible" and rng.random() < cfg.p_zero_divisible:
                k = rng.randint(1, 9)
                answer = self._zero_divisible(answer, k)
                operations.append(
                    {"transform": transform, "k": k, "instruction": f"- Set all elements divisible by {k} to zero"}
                )

            # Crop the matrix (bounds are drawn so the result is never empty)
            elif transform == "crop" and rng.random() < cfg.p_crop:
                row_start = rng.randint(1, num_rows(answer))
                row_end = rng.randint(row_start, num_rows(answer))
                col_start = rng.randint(1, num_cols(answer))
                col_end = rng.randint(col_start, num_cols(answer))
                answer = self._crop(answer, row_start, row_end, col_start, col_end)
                operations.append(
                    {
                        "transform": transform,
                        "row_start": row_start,
                        "row_end": row_end,
                        "col_start": col_start,
                        "col_end": col_end,
                        "instruction": f"- Crop the matrix to rows {row_start}-{row_end} and columns {col_start}-{col_end} (1-indexed)",
                    }
                )

            # Remove every nth row; n >= 2 guarantees the first row survives.
            # The probability draw happens before the size check on purpose,
            # matching the established rng call order.
            elif (
                transform == "remove_every_nth_row"
                and rng.random() < cfg.p_remove_every_nth_row
                and num_rows(answer) > 1
            ):
                n = rng.randint(2, num_rows(answer))
                answer = self._remove_every_nth_row(answer, n)
                formatting = self._ordinal_suffix(n)
                operations.append(
                    {"transform": transform, "n": n, "instruction": f"- Remove every {n}-{formatting} row (1-indexed)"}
                )

            # Remove every nth column; n >= 2 guarantees the first column survives.
            elif (
                transform == "remove_every_nth_col"
                and rng.random() < cfg.p_remove_every_nth_col
                and num_cols(answer) > 1
            ):
                n = rng.randint(2, num_cols(answer))
                answer = self._remove_every_nth_col(answer, n)
                formatting = self._ordinal_suffix(n)
                operations.append(
                    {
                        "transform": transform,
                        "n": n,
                        "instruction": f"- Remove every {n}-{formatting} column (1-indexed)",
                    }
                )

        answer_str = self._matrix_to_str(answer)

        return {
            "question": QUESTION_TEMPLATE.format(
                matrix=matrix_str, operations="\n".join(op["instruction"] for op in operations)
            ),
            "answer": answer_str,
            "metadata": {"matrix": matrix, "solution": answer, "operations": operations},
        }
# Hand the dataset class and its config class to the factory's register_dataset
# under the name "manipulate_matrix".
# NOTE(review): presumably this makes the dataset constructible by name — confirm in ..factory.
register_dataset("manipulate_matrix", ManipulateMatrixDataset, ManipulateMatrixConfig)

View file

@ -0,0 +1,210 @@
"""Tests for Manipulate Matrix questions generation"""
import pytest
from reasoning_gym.algorithmic.manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixDataset
def test_manipulate_matrix_config_validation():
    """Test that invalid configs raise appropriate errors"""
    invalid_dims = [-1, 0]  # Dimensions should be positive integers
    dim_fields = ["max_rows", "max_cols"]

    for field in dim_fields:
        for dim in invalid_dims:
            # Build the config outside the raises-block so that only
            # validate() can satisfy the expectation.
            config = ManipulateMatrixConfig(**{field: dim})
            with pytest.raises(AssertionError):
                config.validate()

    invalid_probabilities = [-0.01, 1.01]  # Probabilities should be between 0 and 1 inclusive
    probability_fields = [
        "p_rotate",
        "p_hmirror",
        "p_vmirror",
        "p_dmirror",
        "p_cmirror",
        "p_map",
        "p_crop",
        "p_remove_every_nth_row",
        "p_remove_every_nth_col",
        "p_zero_divisible",
    ]

    for field in probability_fields:
        for prob in invalid_probabilities:
            config = ManipulateMatrixConfig(**{field: prob})
            with pytest.raises(AssertionError):
                config.validate()
def test_manipulate_matrix_dataset_deterministic():
    """Two datasets built from one seeded config must yield equal items"""
    shared_config = ManipulateMatrixConfig(seed=42, size=10)
    first = ManipulateMatrixDataset(shared_config)
    second = ManipulateMatrixDataset(shared_config)

    for position in range(len(first)):
        assert first[position] == second[position]
def test_manipulate_matrix_dataset_items():
    """Check the structure and the size bounds of every generated item"""
    config = ManipulateMatrixConfig(max_rows=7, max_cols=7, size=10, seed=42)
    dataset = ManipulateMatrixDataset(config)

    for position in range(len(dataset)):
        sample = dataset[position]

        # Top-level structure
        assert isinstance(sample, dict)
        for key in ("question", "answer", "metadata"):
            assert key in sample

        # Metadata fields
        metadata = sample["metadata"]
        for key in ("matrix", "solution", "operations"):
            assert key in metadata

        # Input and solution must both fit inside the configured bounds
        for grid in (metadata["matrix"], metadata["solution"]):
            assert len(grid) <= config.max_rows
            assert all(len(row) <= config.max_cols for row in grid)

        # Every recorded operation names its transform and carries its instruction
        for step in metadata["operations"]:
            assert "transform" in step
            assert "instruction" in step
def test_manipulate_matrix_dataset_iteration():
    """Iterating yields exactly `size` items and is repeatable"""
    config = ManipulateMatrixConfig(size=5, seed=42)
    dataset = ManipulateMatrixDataset(config)

    first_pass = list(dataset)
    assert len(first_pass) == config.size

    # A second full pass must reproduce the first one item for item.
    second_pass = list(dataset)
    assert first_pass == second_pass
def test_manipulate_matrix_transforms():
    """Test every individual matrix transform against hand-computed results"""
    config = ManipulateMatrixConfig(seed=42)
    dataset = ManipulateMatrixDataset(config)
    matrix = [
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
        [11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20],
        [21, 22, 23, 24, 25],
    ]

    # identity
    assert dataset._identity(matrix) == matrix

    # rot 90 degrees
    assert dataset._rot90(matrix) == [
        [21, 16, 11, 6, 1],
        [22, 17, 12, 7, 2],
        [23, 18, 13, 8, 3],
        [24, 19, 14, 9, 4],
        [25, 20, 15, 10, 5],
    ]

    # rot 180 degrees
    assert dataset._rot180(matrix) == [
        [25, 24, 23, 22, 21],
        [20, 19, 18, 17, 16],
        [15, 14, 13, 12, 11],
        [10, 9, 8, 7, 6],
        [5, 4, 3, 2, 1],
    ]

    # rot 270 degrees
    assert dataset._rot270(matrix) == [
        [5, 10, 15, 20, 25],
        [4, 9, 14, 19, 24],
        [3, 8, 13, 18, 23],
        [2, 7, 12, 17, 22],
        [1, 6, 11, 16, 21],
    ]

    # hmirror
    assert dataset._hmirror(matrix) == [
        [21, 22, 23, 24, 25],
        [16, 17, 18, 19, 20],
        [11, 12, 13, 14, 15],
        [6, 7, 8, 9, 10],
        [1, 2, 3, 4, 5],
    ]

    # vmirror
    assert dataset._vmirror(matrix) == [
        [5, 4, 3, 2, 1],
        [10, 9, 8, 7, 6],
        [15, 14, 13, 12, 11],
        [20, 19, 18, 17, 16],
        [25, 24, 23, 22, 21],
    ]

    # dmirror
    assert dataset._dmirror(matrix) == [
        [1, 6, 11, 16, 21],
        [2, 7, 12, 17, 22],
        [3, 8, 13, 18, 23],
        [4, 9, 14, 19, 24],
        [5, 10, 15, 20, 25],
    ]

    # cmirror
    assert dataset._cmirror(matrix) == [
        [25, 20, 15, 10, 5],
        [24, 19, 14, 9, 4],
        [23, 18, 13, 8, 3],
        [22, 17, 12, 7, 2],
        [21, 16, 11, 6, 1],
    ]

    # map
    assert dataset._map(matrix, a=13, b=0) == [
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
        [11, 12, 0, 14, 15],  # 13 -> 0
        [16, 17, 18, 19, 20],
        [21, 22, 23, 24, 25],
    ]

    # zero divisible
    assert dataset._zero_divisible(matrix, k=3) == [
        [1, 2, 0, 4, 5],
        [0, 7, 8, 0, 10],
        [11, 0, 13, 14, 0],
        [16, 17, 0, 19, 20],
        [0, 22, 23, 0, 25],
    ]

    # crop
    assert dataset._crop(matrix, row_start=2, row_end=4, col_start=1, col_end=3) == [
        [6, 7, 8],
        [11, 12, 13],
        [16, 17, 18],
    ]

    # remove every nth row
    assert dataset._remove_every_nth_row(matrix, n=2) == [
        [1, 2, 3, 4, 5],
        [11, 12, 13, 14, 15],
        [21, 22, 23, 24, 25],
    ]

    # remove every nth col
    assert dataset._remove_every_nth_col(matrix, n=2) == [
        [1, 3, 5],
        [6, 8, 10],
        [11, 13, 15],
        [16, 18, 20],
        [21, 23, 25],
    ]

    # Rectangular input: a square matrix cannot reveal swapped row/column
    # handling in the shape-changing transforms.
    wide = [
        [1, 2, 3],
        [4, 5, 6],
    ]
    assert dataset._rot90(wide) == [[4, 1], [5, 2], [6, 3]]
    assert dataset._rot180(wide) == [[6, 5, 4], [3, 2, 1]]
    assert dataset._rot270(wide) == [[3, 6], [2, 5], [1, 4]]
    assert dataset._dmirror(wide) == [[1, 4], [2, 5], [3, 6]]
    assert dataset._cmirror(wide) == [[6, 3], [5, 2], [4, 1]]