generlize to k rotations

This commit is contained in:
Zafir Stojanovski 2025-02-08 15:14:04 +01:00
parent 9dd5e85439
commit 7f0ddc4f84
3 changed files with 76 additions and 26 deletions

View file

@ -1,4 +1,4 @@
"""Rotate a square matrix by 90 degrees clockwise.
"""Rotate a square matrix clockwise.
A popular Leetcode problem:
https://leetcode.com/problems/rotate-image/description/
@ -11,11 +11,11 @@ from typing import Optional
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Given a square matrix, your job is to rotate it by 90 degrees clockwise.
QUESTION_TEMPLATE = """Given a square matrix, your job is to rotate it clockwise.
Example:
Input:
Input: Rotate the matrix below by 90 degrees clockwise:
1 2 3
4 5 6
7 8 9
@ -25,7 +25,7 @@ Output:
8 5 2
9 6 3
Rotate the matrix below by 90 degrees clockwise:
Rotate the matrix below by {degrees} degrees clockwise:
{matrix}
"""
@ -35,6 +35,7 @@ class RotateMatrixConfig:
"""Configuration for Rotate Matrix dataset generation"""
max_n: int = 10 # Maximum size of the matrix
max_rotations: int = 4 # Maximum number of rotations (90 degrees each)
size: int = 500 # Virtual dataset size
seed: Optional[int] = None
@ -42,6 +43,7 @@ class RotateMatrixConfig:
def validate(self):
"""Validate configuration parameters"""
assert 1 <= self.max_n, "max_n must be at least 1"
assert 0 <= self.max_rotations, "max_rotations must be at least 0"
class RotateMatrixDataset(ProceduralDataset):
@ -58,19 +60,21 @@ class RotateMatrixDataset(ProceduralDataset):
matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
return matrix
def _get_rotated(self, matrix: list[list[int]]) -> list[list[int]]:
def _get_rotated(self, matrix: list[list[int]], num_rotations: int) -> list[list[int]]:
"""Rotate the matrix by 90 degrees clockwise"""
num_rotations %= 4
n = len(matrix)
output = deepcopy(matrix)
for l in range(n // 2):
for i in range(l, n - 1 - l):
(output[l][i], output[i][n - 1 - l], output[n - 1 - l][n - 1 - i], output[n - 1 - i][l]) = (
matrix[n - 1 - i][l],
matrix[l][i],
matrix[i][n - 1 - l],
matrix[n - 1 - l][n - 1 - i],
)
for _ in range(num_rotations):
for l in range(n // 2):
for i in range(l, n - 1 - l):
(output[l][i], output[i][n - 1 - l], output[n - 1 - l][n - 1 - i], output[n - 1 - i][l]) = (
output[n - 1 - i][l],
output[l][i],
output[i][n - 1 - l],
output[n - 1 - l][n - 1 - i],
)
return output
@ -83,14 +87,16 @@ class RotateMatrixDataset(ProceduralDataset):
rng = Random(self.seed + idx)
matrix = self._get_matrix(rng)
num_rotations = rng.randint(0, self.config.max_rotations)
matrix_str = self._matrix_to_str(matrix)
answer = self._get_rotated(matrix)
answer = self._get_rotated(matrix, num_rotations)
answer_str = self._matrix_to_str(answer)
return {
"question": QUESTION_TEMPLATE.format(matrix=matrix_str),
"question": QUESTION_TEMPLATE.format(matrix=matrix_str, degrees=num_rotations * 90),
"answer": answer_str,
"metadata": {"matrix": matrix, "solution": answer},
"metadata": {"matrix": matrix, "num_rotations": num_rotations, "solution": answer},
}