generlize to k rotations

This commit is contained in:
Zafir Stojanovski 2025-02-08 15:14:04 +01:00
parent 807dad12de
commit 42312fe786
3 changed files with 76 additions and 26 deletions

View file

@ -101,7 +101,7 @@ See the [Dataset Gallery](https://github.com/open-thought/reasoning-gym/blob/mai
- `WordLadderDataset`: Generate word ladder puzzles where one word is transformed into another by changing one letter at a time
- `GroupAnagramsDataset`: Group anagrams together in a list of words
- `IsomorphicStrings`: Check if two strings are isomorphic (have the same character mapping)
- `RotateMatrix`: Rotate a matrix by 90 degrees clockwise
- `RotateMatrix`: Rotate a matrix by X degrees clockwise
### <small>Code Tasks</small>

View file

@ -1,4 +1,4 @@
"""Rotate a square matrix by 90 degrees clockwise.
"""Rotate a square matrix clockwise.
A popular Leetcode problem:
https://leetcode.com/problems/rotate-image/description/
@ -11,11 +11,11 @@ from typing import Optional
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Given a square matrix, your job is to rotate it by 90 degrees clockwise.
QUESTION_TEMPLATE = """Given a square matrix, your job is to rotate it clockwise.
Example:
Input:
Input: Rotate the matrix below by 90 degrees clockwise:
1 2 3
4 5 6
7 8 9
@ -25,7 +25,7 @@ Output:
8 5 2
9 6 3
Rotate the matrix below by 90 degrees clockwise:
Rotate the matrix below by {degrees} degrees clockwise:
{matrix}
"""
@ -35,6 +35,7 @@ class RotateMatrixConfig:
"""Configuration for Rotate Matrix dataset generation"""
max_n: int = 10 # Maximum size of the matrix
max_rotations: int = 4 # Maximum number of rotations (90 degrees each)
size: int = 500 # Virtual dataset size
seed: Optional[int] = None
@ -42,6 +43,7 @@ class RotateMatrixConfig:
def validate(self):
"""Validate configuration parameters"""
assert 1 <= self.max_n, "max_n must be at least 1"
assert 0 <= self.max_rotations, "max_rotations must be at least 0"
class RotateMatrixDataset(ProceduralDataset):
@ -58,19 +60,21 @@ class RotateMatrixDataset(ProceduralDataset):
matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
return matrix
def _get_rotated(self, matrix: list[list[int]]) -> list[list[int]]:
def _get_rotated(self, matrix: list[list[int]], num_rotations: int) -> list[list[int]]:
"""Rotate the matrix by 90 degrees clockwise"""
num_rotations %= 4
n = len(matrix)
output = deepcopy(matrix)
for l in range(n // 2):
for i in range(l, n - 1 - l):
(output[l][i], output[i][n - 1 - l], output[n - 1 - l][n - 1 - i], output[n - 1 - i][l]) = (
matrix[n - 1 - i][l],
matrix[l][i],
matrix[i][n - 1 - l],
matrix[n - 1 - l][n - 1 - i],
)
for _ in range(num_rotations):
for l in range(n // 2):
for i in range(l, n - 1 - l):
(output[l][i], output[i][n - 1 - l], output[n - 1 - l][n - 1 - i], output[n - 1 - i][l]) = (
output[n - 1 - i][l],
output[l][i],
output[i][n - 1 - l],
output[n - 1 - l][n - 1 - i],
)
return output
@ -83,14 +87,16 @@ class RotateMatrixDataset(ProceduralDataset):
rng = Random(self.seed + idx)
matrix = self._get_matrix(rng)
num_rotations = rng.randint(0, self.config.max_rotations)
matrix_str = self._matrix_to_str(matrix)
answer = self._get_rotated(matrix)
answer = self._get_rotated(matrix, num_rotations)
answer_str = self._matrix_to_str(answer)
return {
"question": QUESTION_TEMPLATE.format(matrix=matrix_str),
"question": QUESTION_TEMPLATE.format(matrix=matrix_str, degrees=num_rotations * 90),
"answer": answer_str,
"metadata": {"matrix": matrix, "solution": answer},
"metadata": {"matrix": matrix, "num_rotations": num_rotations, "solution": answer},
}

View file

@ -15,6 +15,10 @@ def test_rotate_matrix_config_validation():
config = RotateMatrixConfig(max_n=0) # Zero not allowed
config.validate()
with pytest.raises(AssertionError):
config = RotateMatrixConfig(max_rotations=-1) # Negative not allowed
config.validate()
def test_rotate_matrix_dataset_deterministic():
"""Test that dataset generates same items with same seed"""
@ -45,6 +49,7 @@ def test_rotate_matrix_dataset_items():
matrix = item["metadata"]["matrix"]
solution = item["metadata"]["solution"]
num_rotations = item["metadata"]["num_rotations"]
# Verify matrix dimensions
assert len(matrix) <= config.max_n
@ -52,6 +57,7 @@ def test_rotate_matrix_dataset_items():
assert len(solution) <= config.max_n
assert all(len(row) <= config.max_n for row in solution)
assert set(e for row in matrix for e in row) == set(e for row in solution for e in row)
assert num_rotations <= config.max_rotations
def test_rotate_matrix_dataset_iteration():
@ -71,23 +77,25 @@ def test_rotate_matrix_answer():
config = RotateMatrixConfig(seed=42)
dataset = RotateMatrixDataset(config)
# n = 1
# n = 1, num_rotations = 1
matrix = [[8]]
expected = [[8]]
assert dataset._get_rotated(matrix) == expected
assert dataset._get_rotated(matrix, num_rotations=1) == expected
# n = 2
# n = 3, num_rotations = 0 (no rotation)
matrix = [
[0, 1],
[2, 3],
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
]
expected = [
[2, 0],
[3, 1],
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
]
assert dataset._get_rotated(matrix) == expected
assert dataset._get_rotated(matrix, num_rotations=0) == expected
# n = 3
# n = 3, num_rotations = 1 (90 degrees clockwise)
matrix = [
[0, 1, 2],
[3, 4, 5],
@ -98,3 +106,39 @@ def test_rotate_matrix_answer():
[7, 4, 1],
[8, 5, 2],
]
assert dataset._get_rotated(matrix, num_rotations=1) == expected
# n = 3, num_rotations = 2 (180 degrees clockwise)
matrix = [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
]
expected = [
[8, 7, 6],
[5, 4, 3],
[2, 1, 0],
]
assert dataset._get_rotated(matrix, num_rotations=2) == expected
# n = 3, num_rotations = 3 (270 degrees clockwise)
matrix = [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
]
expected = [[2, 5, 8], [1, 4, 7], [0, 3, 6]]
assert dataset._get_rotated(matrix, num_rotations=3) == expected
# n = 4, num_rotations = 4 (360 degrees clockwise == 0 degrees)
matrix = [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
]
expected = [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
]
assert dataset._get_rotated(matrix, num_rotations=4) == expected