Merge pull request #87 from zafstojano/env/rotate-matrix

Rotate Matrix k times
This commit is contained in:
Andreas Köpf 2025-02-08 17:46:58 +01:00 committed by GitHub
commit 307a031146
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 250 additions and 0 deletions

View file

@ -16,6 +16,7 @@ from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
from .number_sorting import NumberSortingConfig, NumberSortingDataset from .number_sorting import NumberSortingConfig, NumberSortingDataset
from .palindrome_generation import PalindromeConfig, PalindromeDataset from .palindrome_generation import PalindromeConfig, PalindromeDataset
from .ransom_note import RansomNoteConfig, RansomNoteDataset from .ransom_note import RansomNoteConfig, RansomNoteDataset
from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
from .word_ladder import WordLadderConfig, WordLadderDataset from .word_ladder import WordLadderConfig, WordLadderDataset
@ -54,4 +55,6 @@ __all__ = [
"RansomNoteDataset", "RansomNoteDataset",
"IsomorphicStringsConfig", "IsomorphicStringsConfig",
"IsomorphicStringsDataset", "IsomorphicStringsDataset",
"RotateMatrixConfig",
"RotateMatrixDataset",
] ]

View file

@ -0,0 +1,103 @@
"""Rotate a square matrix clockwise.
A popular Leetcode problem:
https://leetcode.com/problems/rotate-image/description/
"""
from copy import deepcopy
from dataclasses import dataclass
from random import Random
from typing import Optional
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Given a square matrix, your job is to rotate it clockwise.
Example:
Input: Rotate the matrix below by 90 degrees clockwise:
1 2 3
4 5 6
7 8 9
Output:
7 4 1
8 5 2
9 6 3
Rotate the matrix below by {degrees} degrees clockwise:
{matrix}
"""
@dataclass
class RotateMatrixConfig:
"""Configuration for Rotate Matrix dataset generation"""
max_n: int = 10 # Maximum size of the matrix
max_rotations: int = 4 # Maximum number of rotations (90 degrees each)
size: int = 500 # Virtual dataset size
seed: Optional[int] = None
def validate(self):
"""Validate configuration parameters"""
assert 1 <= self.max_n, "max_n must be at least 1"
assert 0 <= self.max_rotations, "max_rotations must be at least 0"
class RotateMatrixDataset(ProceduralDataset):
"""Generates Rotate Matrix exercises with configurable difficulty"""
def __init__(self, config: RotateMatrixConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def _get_matrix(self, rng: Random) -> list[list[int]]:
"""Generate a random matrix"""
n = rng.randint(1, self.config.max_n)
numbers = list(range(n**2))
rng.shuffle(numbers)
matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
return matrix
def _get_rotated(self, matrix: list[list[int]], num_rotations: int) -> list[list[int]]:
"""Rotate the matrix K times by 90 degrees clockwise"""
num_rotations %= 4
n = len(matrix)
output = deepcopy(matrix)
for _ in range(num_rotations):
for l in range(n // 2):
for i in range(l, n - 1 - l):
(output[l][i], output[i][n - 1 - l], output[n - 1 - l][n - 1 - i], output[n - 1 - i][l]) = (
output[n - 1 - i][l],
output[l][i],
output[i][n - 1 - l],
output[n - 1 - l][n - 1 - i],
)
return output
def _matrix_to_str(self, matrix: list[list[int]]) -> str:
"""Get a string representation of the matrix"""
return "\n".join(" ".join(str(x) for x in row) for row in matrix)
def __getitem__(self, idx: int) -> dict:
"""Generate a single Spiral Matrix question"""
rng = Random(self.seed + idx)
matrix = self._get_matrix(rng)
num_rotations = rng.randint(0, self.config.max_rotations)
matrix_str = self._matrix_to_str(matrix)
answer = self._get_rotated(matrix, num_rotations)
answer_str = self._matrix_to_str(answer)
return {
"question": QUESTION_TEMPLATE.format(matrix=matrix_str, degrees=num_rotations * 90),
"answer": answer_str,
"metadata": {"matrix": matrix, "num_rotations": num_rotations, "solution": answer},
}
register_dataset("rotate_matrix", RotateMatrixDataset, RotateMatrixConfig)

144
tests/test_rotate_matrix.py Normal file
View file

@ -0,0 +1,144 @@
"""Tests for Rotate Matrix questions generation"""
import pytest
from reasoning_gym.algorithmic.rotate_matrix import RotateMatrixConfig, RotateMatrixDataset
def test_rotate_matrix_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = RotateMatrixConfig(max_n=-1) # Negative not allowed
config.validate()
with pytest.raises(AssertionError):
config = RotateMatrixConfig(max_n=0) # Zero not allowed
config.validate()
with pytest.raises(AssertionError):
config = RotateMatrixConfig(max_rotations=-1) # Negative not allowed
config.validate()
def test_rotate_matrix_dataset_deterministic():
"""Test that dataset generates same items with same seed"""
config = RotateMatrixConfig(seed=42, size=10)
dataset1 = RotateMatrixDataset(config)
dataset2 = RotateMatrixDataset(config)
for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i]
def test_rotate_matrix_dataset_items():
"""Test basic properties of generated items"""
config = RotateMatrixConfig(max_n=7, size=10, seed=42)
dataset = RotateMatrixDataset(config)
for i in range(len(dataset)):
item = dataset[i]
# Check item structure
assert isinstance(item, dict)
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Check metadata
assert "matrix" in item["metadata"]
assert "solution" in item["metadata"]
matrix = item["metadata"]["matrix"]
solution = item["metadata"]["solution"]
num_rotations = item["metadata"]["num_rotations"]
# Verify matrix dimensions
assert len(matrix) <= config.max_n
assert all(len(row) <= config.max_n for row in matrix)
assert len(solution) <= config.max_n
assert all(len(row) <= config.max_n for row in solution)
assert set(e for row in matrix for e in row) == set(e for row in solution for e in row)
assert num_rotations <= config.max_rotations
def test_rotate_matrix_dataset_iteration():
"""Test that iteration respects dataset size"""
config = RotateMatrixConfig(size=5, seed=42)
dataset = RotateMatrixDataset(config)
items = list(dataset)
assert len(items) == config.size
# Test multiple iterations yield same items
assert items == list(dataset)
def test_rotate_matrix_answer():
"""Test the _get_rotated method"""
config = RotateMatrixConfig(seed=42)
dataset = RotateMatrixDataset(config)
# n = 1, num_rotations = 1
matrix = [[8]]
expected = [[8]]
assert dataset._get_rotated(matrix, num_rotations=1) == expected
# n = 3, num_rotations = 0 (no rotation)
matrix = [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
]
expected = [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
]
assert dataset._get_rotated(matrix, num_rotations=0) == expected
# n = 3, num_rotations = 1 (90 degrees clockwise)
matrix = [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
]
expected = [
[6, 3, 0],
[7, 4, 1],
[8, 5, 2],
]
assert dataset._get_rotated(matrix, num_rotations=1) == expected
# n = 3, num_rotations = 2 (180 degrees clockwise)
matrix = [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
]
expected = [
[8, 7, 6],
[5, 4, 3],
[2, 1, 0],
]
assert dataset._get_rotated(matrix, num_rotations=2) == expected
# n = 3, num_rotations = 3 (270 degrees clockwise)
matrix = [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
]
expected = [[2, 5, 8], [1, 4, 7], [0, 3, 6]]
assert dataset._get_rotated(matrix, num_rotations=3) == expected
# n = 4, num_rotations = 4 (360 degrees clockwise == 0 degrees)
matrix = [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
]
expected = [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
]
assert dataset._get_rotated(matrix, num_rotations=4) == expected