mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
matrix manipulation
This commit is contained in:
parent
ed37eae559
commit
3d66cc6a7f
3 changed files with 481 additions and 0 deletions
|
|
@ -12,6 +12,7 @@ from .group_anagrams import GroupAnagramsConfig, GroupAnagramsDataset
|
|||
from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsDataset
|
||||
from .letter_counting import LetterCountingConfig, LetterCountingDataset
|
||||
from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset
|
||||
from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixDataset
|
||||
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
|
||||
from .number_sorting import NumberSortingConfig, NumberSortingDataset
|
||||
from .palindrome_generation import PalindromeConfig, PalindromeDataset
|
||||
|
|
@ -60,4 +61,6 @@ __all__ = [
|
|||
"IsomorphicStringsDataset",
|
||||
"RotateMatrixConfig",
|
||||
"RotateMatrixDataset",
|
||||
"ManipulateMatrixConfig",
|
||||
"ManipulateMatrixDataset",
|
||||
]
|
||||
|
|
|
|||
268
reasoning_gym/algorithmic/manipulate_matrix.py
Normal file
268
reasoning_gym/algorithmic/manipulate_matrix.py
Normal file
|
|
@ -0,0 +1,268 @@
|
|||
"""Manipulate matrices by performing augmentations such as rotations, flips, mapping, etc."""
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = """For the following matrix:
|
||||
{matrix}
|
||||
|
||||
Perform the following series of operations in order:
|
||||
- Identity transformation, i.e. no change
|
||||
{operations}
|
||||
"""
|
||||
|
||||
|
||||
def num_rows(matrix: list[list[int]]) -> int:
|
||||
return len(matrix)
|
||||
|
||||
|
||||
def num_cols(matrix: list[list[int]]) -> int:
|
||||
return len(matrix[0]) if matrix else 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ManipulateMatrixConfig:
|
||||
"""Configuration for Manipulate Matrix dataset generation"""
|
||||
|
||||
max_rows: int = 10 # Maximum number of rows
|
||||
max_cols: int = 10 # Maximum number of columns
|
||||
p_rotate: float = 0.2 # Probability of rotating the matrix
|
||||
p_hmirror: float = 0.2 # Probability of horizontally mirroring the matrix
|
||||
p_vmirror: float = 0.2 # Probability of vertically mirroring the matrix
|
||||
p_dmirror: float = 0.2 # Probability of mirroring along the diagonal
|
||||
p_cmirror: float = 0.2 # Probability of mirroring along the counterdiagonal
|
||||
p_map: float = 0.2 # Probability of mapping a certain value to another
|
||||
p_crop: float = 0.2 # Probability of cropping the matrix
|
||||
p_remove_every_nth_row: float = 0.2 # Probability of removing every nth row
|
||||
p_remove_every_nth_col: float = 0.2 # Probability of removing every nth column
|
||||
p_zero_divisible: float = 0.2 # Probability of setting elements divisible by some number to zero
|
||||
|
||||
size: int = 500 # Virtual dataset size
|
||||
seed: Optional[int] = None
|
||||
|
||||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
assert 1 <= self.max_rows, "max_rows must be at least 1"
|
||||
assert 1 <= self.max_cols, "max_cols must be at least 1"
|
||||
assert 0 <= self.p_rotate <= 1, "p_rotate must be between 0 and 1"
|
||||
assert 0 <= self.p_hmirror <= 1, "p_hmirror must be between 0 and 1"
|
||||
assert 0 <= self.p_vmirror <= 1, "p_vmirror must be between 0 and 1"
|
||||
assert 0 <= self.p_dmirror <= 1, "p_dmirror must be between 0 and 1"
|
||||
assert 0 <= self.p_cmirror <= 1, "p_cmirror must be between 0 and 1"
|
||||
assert 0 <= self.p_map <= 1, "p_map must be between 0 and 1"
|
||||
assert 0 <= self.p_crop <= 1, "p_crop must be between 0 and 1"
|
||||
assert 0 <= self.p_remove_every_nth_row <= 1, "p_remove_every_nth_row must be between 0 and 1"
|
||||
assert 0 <= self.p_remove_every_nth_col <= 1, "p_remove_nth_col must be between 0 and 1"
|
||||
assert 0 <= self.p_zero_divisible <= 1, "p_zero_divisible must be between 0 and 1"
|
||||
|
||||
|
||||
class ManipulateMatrixDataset(ProceduralDataset):
|
||||
"""Generates Manipulate Matrix exercises with configurable difficulty"""
|
||||
|
||||
def __init__(self, config: ManipulateMatrixConfig):
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
self._rotations = {
|
||||
"90": self._rot90,
|
||||
"180": self._rot180,
|
||||
"270": self._rot270,
|
||||
"360": self._identity,
|
||||
}
|
||||
self._all_transforms = [
|
||||
"rotate",
|
||||
"hmirror",
|
||||
"vmirror",
|
||||
"dmirror",
|
||||
"cmirror",
|
||||
"map",
|
||||
"zero_divisible",
|
||||
"crop",
|
||||
"remove_every_nth_row",
|
||||
"remove_every_nth_col",
|
||||
]
|
||||
|
||||
def _get_matrix(self, rng: Random) -> list[list[int]]:
|
||||
"""Generate a random matrix"""
|
||||
rows = rng.randint(1, self.config.max_rows)
|
||||
cols = rng.randint(1, self.config.max_cols)
|
||||
numbers = [rng.randint(0, 9) for _ in range(rows * cols)]
|
||||
matrix = [numbers[i * cols : (i + 1) * cols] for i in range(rows)]
|
||||
return matrix
|
||||
|
||||
def _matrix_to_str(self, matrix: list[list[int]]) -> str:
|
||||
"""Get a string representation of the matrix"""
|
||||
return "\n".join(" ".join(str(x) for x in row) for row in matrix)
|
||||
|
||||
def _identity(self, matrix: list[list[int]]) -> list[list[int]]:
|
||||
"""Identity transformation"""
|
||||
return matrix
|
||||
|
||||
def _rot90(self, matrix: list[list[int]]) -> list[list[int]]:
|
||||
"""quarter clockwise rotation"""
|
||||
return [list(row) for row in zip(*matrix[::-1])]
|
||||
|
||||
def _rot180(self, matrix: list[list[int]]) -> list[list[int]]:
|
||||
"""half rotation"""
|
||||
return [list(row[::-1]) for row in matrix[::-1]]
|
||||
|
||||
def _rot270(self, matrix: list[list[int]]) -> list[list[int]]:
|
||||
"""quarter anticlockwise rotation"""
|
||||
return [list(row[::-1]) for row in zip(*matrix[::-1])][::-1]
|
||||
|
||||
def _hmirror(self, matrix: list[list[int]]) -> list[list[int]]:
|
||||
"""mirroring along horizontal"""
|
||||
return matrix[::-1]
|
||||
|
||||
def _vmirror(self, matrix: list[list[int]]) -> list[list[int]]:
|
||||
"""mirroring along vertical"""
|
||||
return [row[::-1] for row in matrix]
|
||||
|
||||
def _dmirror(self, matrix: list[list[int]]) -> list[list[int]]:
|
||||
"""mirroring along diagonal"""
|
||||
return list(list(row) for row in zip(*matrix))
|
||||
|
||||
def _cmirror(self, matrix: list[list[int]]) -> list[list[int]]:
|
||||
"""mirroring along counterdiagonal"""
|
||||
return list(list(row) for row in zip(*[r[::-1] for r in matrix[::-1]]))
|
||||
|
||||
def _map(self, matrix: list[list[int]], a: int, b: int) -> list[list[int]]:
|
||||
"""mapping a to b"""
|
||||
return [[b if x == a else x for x in row] for row in matrix]
|
||||
|
||||
def _zero_divisible(self, matrix: list[list[int]], k: int) -> list[list[int]]:
|
||||
"""set elements divisible by k to zero"""
|
||||
return [[0 if x % k == 0 else x for x in row] for row in matrix]
|
||||
|
||||
def _crop(
|
||||
self, matrix: list[list[int]], row_start: int, row_end: int, col_start: int, col_end: int
|
||||
) -> list[list[int]]:
|
||||
"""crop the matrix (1-indexed)"""
|
||||
return [row[col_start - 1 : col_end] for row in matrix[row_start - 1 : row_end]]
|
||||
|
||||
def _remove_every_nth_row(self, matrix: list[list[int]], n: int) -> list[list[int]]:
|
||||
"""remove every nth row (1-indexed)"""
|
||||
return [row for i, row in enumerate(matrix, start=1) if i % n != 0]
|
||||
|
||||
def _remove_every_nth_col(self, matrix: list[list[int]], n: int) -> list[list[int]]:
|
||||
"""remove every nth column (1-indexed)"""
|
||||
return [[col for i, col in enumerate(row, start=1) if i % n != 0] for row in matrix]
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single Manipulate Matrix question"""
|
||||
rng = Random(self.seed + idx)
|
||||
|
||||
matrix = self._get_matrix(rng)
|
||||
matrix_str = self._matrix_to_str(matrix)
|
||||
|
||||
# Shuffle the order of operations (make sure to copy the list to guarantee same order)
|
||||
all_transforms = deepcopy(self._all_transforms)
|
||||
rng.shuffle(all_transforms)
|
||||
operations = []
|
||||
|
||||
answer = deepcopy(matrix)
|
||||
|
||||
for transform in all_transforms:
|
||||
# Rotate
|
||||
if transform == "rotate" and rng.random() < self.config.p_rotate:
|
||||
rotation = rng.choice(list(self._rotations.keys()))
|
||||
answer = self._rotations[rotation](answer)
|
||||
operations.append(
|
||||
{
|
||||
"transform": transform,
|
||||
"degrees": rotation,
|
||||
"instruction": f"- Rotate the matrix {rotation} degrees",
|
||||
}
|
||||
)
|
||||
# Horizontal mirror
|
||||
if transform == "hmirror" and rng.random() < self.config.p_hmirror:
|
||||
answer = self._hmirror(answer)
|
||||
operations.append({"transform": transform, "instruction": "- Horizontally mirror the matrix"})
|
||||
# Vertical mirror
|
||||
if transform == "vmirror" and rng.random() < self.config.p_vmirror:
|
||||
answer = self._vmirror(answer)
|
||||
operations.append({"transform": transform, "instruction": "- Vertically mirror the matrix"})
|
||||
# Diagonal mirror
|
||||
if transform == "dmirror" and rng.random() < self.config.p_dmirror:
|
||||
answer = self._dmirror(answer)
|
||||
operations.append({"transform": transform, "instruction": "- Mirror the matrix along the diagonal"})
|
||||
# Counterdiagonal mirror
|
||||
if transform == "cmirror" and rng.random() < self.config.p_cmirror:
|
||||
answer = self._cmirror(answer)
|
||||
operations.append(
|
||||
{"transform": transform, "instruction": "- Mirror the matrix along the counterdiagonal"}
|
||||
)
|
||||
# Map a value to another
|
||||
if transform == "map" and rng.random() < self.config.p_map:
|
||||
a, b = rng.sample(range(10), 2)
|
||||
answer = self._map(answer, a, b)
|
||||
operations.append(
|
||||
{"transform": transform, "from": a, "to": b, "instruction": f"- Map each occurrence of {a} to {b}"}
|
||||
)
|
||||
# Set elements divisible by k to zero
|
||||
if transform == "zero_divisible" and rng.random() < self.config.p_zero_divisible:
|
||||
k = rng.randint(1, 9)
|
||||
answer = self._zero_divisible(answer, k)
|
||||
operations.append(
|
||||
{"transform": transform, "k": k, "instruction": f"- Set all elements divisible by {k} to zero"}
|
||||
)
|
||||
# Crop the matrix
|
||||
if transform == "crop" and rng.random() < self.config.p_crop:
|
||||
row_start = rng.randint(1, num_rows(answer))
|
||||
row_end = rng.randint(row_start, num_rows(answer))
|
||||
col_start = rng.randint(1, num_cols(answer))
|
||||
col_end = rng.randint(col_start, num_cols(answer))
|
||||
answer = self._crop(answer, row_start, row_end, col_start, col_end)
|
||||
operations.append(
|
||||
{
|
||||
"transform": transform,
|
||||
"row_start": row_start,
|
||||
"row_end": row_end,
|
||||
"col_start": col_start,
|
||||
"col_end": col_end,
|
||||
"instruction": f"- Crop the matrix to rows {row_start}-{row_end} and columns {col_start}-{col_end} (1-indexed)",
|
||||
}
|
||||
)
|
||||
# Remove every nth row
|
||||
if (
|
||||
transform == "remove_every_nth_row"
|
||||
and rng.random() < self.config.p_remove_every_nth_row
|
||||
and num_rows(answer) > 1
|
||||
):
|
||||
n = rng.randint(2, num_rows(answer))
|
||||
answer = self._remove_every_nth_row(answer, n)
|
||||
formatting = "st" if n == 1 else "nd" if n == 2 else "th"
|
||||
operations.append(
|
||||
{"transform": transform, "n": n, "instruction": f"- Remove every {n}-{formatting} row (1-indexed)"}
|
||||
)
|
||||
# Remove every nth column
|
||||
if (
|
||||
transform == "remove_every_nth_col"
|
||||
and rng.random() < self.config.p_remove_every_nth_col
|
||||
and num_cols(answer) > 1
|
||||
):
|
||||
n = rng.randint(2, num_cols(answer))
|
||||
answer = self._remove_every_nth_col(answer, n)
|
||||
formatting = "st" if n == 1 else "nd" if n == 2 else "th"
|
||||
operations.append(
|
||||
{
|
||||
"transform": transform,
|
||||
"n": n,
|
||||
"instruction": f"- Remove every {n}-{formatting} column (1-indexed)",
|
||||
}
|
||||
)
|
||||
|
||||
answer_str = self._matrix_to_str(answer)
|
||||
|
||||
return {
|
||||
"question": QUESTION_TEMPLATE.format(
|
||||
matrix=matrix_str, operations="\n".join(op["instruction"] for op in operations)
|
||||
),
|
||||
"answer": answer_str,
|
||||
"metadata": {"matrix": matrix, "solution": answer, "operations": operations},
|
||||
}
|
||||
|
||||
|
||||
register_dataset("manipulate_matrix", ManipulateMatrixDataset, ManipulateMatrixConfig)
|
||||
210
tests/test_manipulate_matrix.py
Normal file
210
tests/test_manipulate_matrix.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
"""Tests for Manipulate Matrix questions generation"""
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixDataset
|
||||
|
||||
|
||||
def test_manipulate_matrix_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
|
||||
invalid_dims = [-1, 0] # Dimensions should be positive integers
|
||||
dim_fields = ["max_rows", "max_cols"]
|
||||
|
||||
for field in dim_fields:
|
||||
for dim in invalid_dims:
|
||||
with pytest.raises(AssertionError):
|
||||
config = ManipulateMatrixConfig(**{field: dim})
|
||||
config.validate()
|
||||
|
||||
invalid_probabilities = [-0.01, 1.01] # Probabilities should be between 0 and 1 inclusive
|
||||
probability_fields = [
|
||||
"p_hmirror",
|
||||
"p_vmirror",
|
||||
"p_dmirror",
|
||||
"p_cmirror",
|
||||
"p_map",
|
||||
"p_crop",
|
||||
"p_remove_every_nth_row",
|
||||
"p_remove_every_nth_col",
|
||||
"p_zero_divisible",
|
||||
]
|
||||
|
||||
for field in probability_fields:
|
||||
for prob in invalid_probabilities:
|
||||
with pytest.raises(AssertionError):
|
||||
config = ManipulateMatrixConfig(**{field: prob})
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_manipulate_matrix_dataset_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
config = ManipulateMatrixConfig(seed=42, size=10)
|
||||
dataset1 = ManipulateMatrixDataset(config)
|
||||
dataset2 = ManipulateMatrixDataset(config)
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
|
||||
|
||||
def test_manipulate_matrix_dataset_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = ManipulateMatrixConfig(max_rows=7, max_cols=7, size=10, seed=42)
|
||||
dataset = ManipulateMatrixDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
# Check item structure
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
|
||||
# Check metadata
|
||||
assert "matrix" in item["metadata"]
|
||||
assert "solution" in item["metadata"]
|
||||
assert "operations" in item["metadata"]
|
||||
|
||||
matrix = item["metadata"]["matrix"]
|
||||
solution = item["metadata"]["solution"]
|
||||
operations = item["metadata"]["operations"]
|
||||
|
||||
# Verify matrix dimensions
|
||||
assert len(matrix) <= config.max_rows
|
||||
assert all(len(row) <= config.max_cols for row in matrix)
|
||||
assert len(solution) <= config.max_rows
|
||||
assert all(len(row) <= config.max_cols for row in solution)
|
||||
for op in operations:
|
||||
assert "transform" in op
|
||||
assert "instruction" in op
|
||||
|
||||
|
||||
def test_manipulate_matrix_dataset_iteration():
|
||||
"""Test that iteration respects dataset size"""
|
||||
config = ManipulateMatrixConfig(size=5, seed=42)
|
||||
dataset = ManipulateMatrixDataset(config)
|
||||
|
||||
items = list(dataset)
|
||||
assert len(items) == config.size
|
||||
|
||||
assert items == list(dataset)
|
||||
|
||||
|
||||
def test_manipulate_matrix_transforms():
|
||||
"""Test the _get_rotated method"""
|
||||
config = ManipulateMatrixConfig(seed=42)
|
||||
dataset = ManipulateMatrixDataset(config)
|
||||
matrix = [
|
||||
[1, 2, 3, 4, 5],
|
||||
[6, 7, 8, 9, 10],
|
||||
[11, 12, 13, 14, 15],
|
||||
[16, 17, 18, 19, 20],
|
||||
[21, 22, 23, 24, 25],
|
||||
]
|
||||
|
||||
# identity
|
||||
assert dataset._identity(matrix) == matrix
|
||||
|
||||
# rot 90 degrees
|
||||
assert dataset._rot90(matrix) == [
|
||||
[21, 16, 11, 6, 1],
|
||||
[22, 17, 12, 7, 2],
|
||||
[23, 18, 13, 8, 3],
|
||||
[24, 19, 14, 9, 4],
|
||||
[25, 20, 15, 10, 5],
|
||||
]
|
||||
|
||||
# rot 180 degrees
|
||||
assert dataset._rot180(matrix) == [
|
||||
[25, 24, 23, 22, 21],
|
||||
[20, 19, 18, 17, 16],
|
||||
[15, 14, 13, 12, 11],
|
||||
[10, 9, 8, 7, 6],
|
||||
[5, 4, 3, 2, 1],
|
||||
]
|
||||
|
||||
# rot 270 degrees
|
||||
assert dataset._rot270(matrix) == [
|
||||
[5, 10, 15, 20, 25],
|
||||
[4, 9, 14, 19, 24],
|
||||
[3, 8, 13, 18, 23],
|
||||
[2, 7, 12, 17, 22],
|
||||
[1, 6, 11, 16, 21],
|
||||
]
|
||||
|
||||
# hmirror
|
||||
assert dataset._hmirror(matrix) == [
|
||||
[21, 22, 23, 24, 25],
|
||||
[16, 17, 18, 19, 20],
|
||||
[11, 12, 13, 14, 15],
|
||||
[6, 7, 8, 9, 10],
|
||||
[1, 2, 3, 4, 5],
|
||||
]
|
||||
|
||||
# vmirror
|
||||
assert dataset._vmirror(matrix) == [
|
||||
[5, 4, 3, 2, 1],
|
||||
[10, 9, 8, 7, 6],
|
||||
[15, 14, 13, 12, 11],
|
||||
[20, 19, 18, 17, 16],
|
||||
[25, 24, 23, 22, 21],
|
||||
]
|
||||
|
||||
# dmirror
|
||||
assert dataset._dmirror(matrix) == [
|
||||
[1, 6, 11, 16, 21],
|
||||
[2, 7, 12, 17, 22],
|
||||
[3, 8, 13, 18, 23],
|
||||
[4, 9, 14, 19, 24],
|
||||
[5, 10, 15, 20, 25],
|
||||
]
|
||||
|
||||
# cmirror
|
||||
assert dataset._cmirror(matrix) == [
|
||||
[25, 20, 15, 10, 5],
|
||||
[24, 19, 14, 9, 4],
|
||||
[23, 18, 13, 8, 3],
|
||||
[22, 17, 12, 7, 2],
|
||||
[21, 16, 11, 6, 1],
|
||||
]
|
||||
|
||||
# map
|
||||
assert dataset._map(matrix, a=13, b=0) == [
|
||||
[1, 2, 3, 4, 5],
|
||||
[6, 7, 8, 9, 10],
|
||||
[11, 12, 0, 14, 15], # 13 -> 0
|
||||
[16, 17, 18, 19, 20],
|
||||
[21, 22, 23, 24, 25],
|
||||
]
|
||||
|
||||
# zero divisible
|
||||
assert dataset._zero_divisible(matrix, k=3) == [
|
||||
[1, 2, 0, 4, 5],
|
||||
[0, 7, 8, 0, 10],
|
||||
[11, 0, 13, 14, 0],
|
||||
[16, 17, 0, 19, 20],
|
||||
[0, 22, 23, 0, 25],
|
||||
]
|
||||
|
||||
# crop
|
||||
assert dataset._crop(matrix, row_start=2, row_end=4, col_start=1, col_end=3) == [
|
||||
[6, 7, 8],
|
||||
[11, 12, 13],
|
||||
[16, 17, 18],
|
||||
]
|
||||
|
||||
# remove every nth row
|
||||
assert dataset._remove_every_nth_row(matrix, n=2) == [
|
||||
[1, 2, 3, 4, 5],
|
||||
[11, 12, 13, 14, 15],
|
||||
[21, 22, 23, 24, 25],
|
||||
]
|
||||
|
||||
# remove every nth col
|
||||
assert dataset._remove_every_nth_col(matrix, n=2) == [
|
||||
[1, 3, 5],
|
||||
[6, 8, 10],
|
||||
[11, 13, 15],
|
||||
[16, 18, 20],
|
||||
[21, 23, 25],
|
||||
]
|
||||
Loading…
Add table
Add a link
Reference in a new issue