mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
generlize to k rotations
This commit is contained in:
parent
807dad12de
commit
42312fe786
3 changed files with 76 additions and 26 deletions
|
|
@ -101,7 +101,7 @@ See the [Dataset Gallery](https://github.com/open-thought/reasoning-gym/blob/mai
|
|||
- `WordLadderDataset`: Generate word ladder puzzles where one word is transformed into another by changing one letter at a time
|
||||
- `GroupAnagramsDataset`: Group anagrams together in a list of words
|
||||
- `IsomorphicStrings`: Check if two strings are isomorphic (have the same character mapping)
|
||||
- `RotateMatrix`: Rotate a matrix by 90 degrees clockwise
|
||||
- `RotateMatrix`: Rotate a matrix by X degrees clockwise
|
||||
|
||||
### <small>Code Tasks</small>
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Rotate a square matrix by 90 degrees clockwise.
|
||||
"""Rotate a square matrix clockwise.
|
||||
|
||||
A popular Leetcode problem:
|
||||
https://leetcode.com/problems/rotate-image/description/
|
||||
|
|
@ -11,11 +11,11 @@ from typing import Optional
|
|||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
QUESTION_TEMPLATE = """Given a square matrix, your job is to rotate it by 90 degrees clockwise.
|
||||
QUESTION_TEMPLATE = """Given a square matrix, your job is to rotate it clockwise.
|
||||
|
||||
Example:
|
||||
|
||||
Input:
|
||||
Input: Rotate the matrix below by 90 degrees clockwise:
|
||||
1 2 3
|
||||
4 5 6
|
||||
7 8 9
|
||||
|
|
@ -25,7 +25,7 @@ Output:
|
|||
8 5 2
|
||||
9 6 3
|
||||
|
||||
Rotate the matrix below by 90 degrees clockwise:
|
||||
Rotate the matrix below by {degrees} degrees clockwise:
|
||||
{matrix}
|
||||
"""
|
||||
|
||||
|
|
@ -35,6 +35,7 @@ class RotateMatrixConfig:
|
|||
"""Configuration for Rotate Matrix dataset generation"""
|
||||
|
||||
max_n: int = 10 # Maximum size of the matrix
|
||||
max_rotations: int = 4 # Maximum number of rotations (90 degrees each)
|
||||
|
||||
size: int = 500 # Virtual dataset size
|
||||
seed: Optional[int] = None
|
||||
|
|
@ -42,6 +43,7 @@ class RotateMatrixConfig:
|
|||
def validate(self):
|
||||
"""Validate configuration parameters"""
|
||||
assert 1 <= self.max_n, "max_n must be at least 1"
|
||||
assert 0 <= self.max_rotations, "max_rotations must be at least 0"
|
||||
|
||||
|
||||
class RotateMatrixDataset(ProceduralDataset):
|
||||
|
|
@ -58,19 +60,21 @@ class RotateMatrixDataset(ProceduralDataset):
|
|||
matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
|
||||
return matrix
|
||||
|
||||
def _get_rotated(self, matrix: list[list[int]]) -> list[list[int]]:
|
||||
def _get_rotated(self, matrix: list[list[int]], num_rotations: int) -> list[list[int]]:
|
||||
"""Rotate the matrix by 90 degrees clockwise"""
|
||||
num_rotations %= 4
|
||||
n = len(matrix)
|
||||
output = deepcopy(matrix)
|
||||
|
||||
for l in range(n // 2):
|
||||
for i in range(l, n - 1 - l):
|
||||
(output[l][i], output[i][n - 1 - l], output[n - 1 - l][n - 1 - i], output[n - 1 - i][l]) = (
|
||||
matrix[n - 1 - i][l],
|
||||
matrix[l][i],
|
||||
matrix[i][n - 1 - l],
|
||||
matrix[n - 1 - l][n - 1 - i],
|
||||
)
|
||||
for _ in range(num_rotations):
|
||||
for l in range(n // 2):
|
||||
for i in range(l, n - 1 - l):
|
||||
(output[l][i], output[i][n - 1 - l], output[n - 1 - l][n - 1 - i], output[n - 1 - i][l]) = (
|
||||
output[n - 1 - i][l],
|
||||
output[l][i],
|
||||
output[i][n - 1 - l],
|
||||
output[n - 1 - l][n - 1 - i],
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
|
@ -83,14 +87,16 @@ class RotateMatrixDataset(ProceduralDataset):
|
|||
rng = Random(self.seed + idx)
|
||||
|
||||
matrix = self._get_matrix(rng)
|
||||
num_rotations = rng.randint(0, self.config.max_rotations)
|
||||
matrix_str = self._matrix_to_str(matrix)
|
||||
answer = self._get_rotated(matrix)
|
||||
|
||||
answer = self._get_rotated(matrix, num_rotations)
|
||||
answer_str = self._matrix_to_str(answer)
|
||||
|
||||
return {
|
||||
"question": QUESTION_TEMPLATE.format(matrix=matrix_str),
|
||||
"question": QUESTION_TEMPLATE.format(matrix=matrix_str, degrees=num_rotations * 90),
|
||||
"answer": answer_str,
|
||||
"metadata": {"matrix": matrix, "solution": answer},
|
||||
"metadata": {"matrix": matrix, "num_rotations": num_rotations, "solution": answer},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,10 @@ def test_rotate_matrix_config_validation():
|
|||
config = RotateMatrixConfig(max_n=0) # Zero not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RotateMatrixConfig(max_rotations=-1) # Negative not allowed
|
||||
config.validate()
|
||||
|
||||
|
||||
def test_rotate_matrix_dataset_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
|
|
@ -45,6 +49,7 @@ def test_rotate_matrix_dataset_items():
|
|||
|
||||
matrix = item["metadata"]["matrix"]
|
||||
solution = item["metadata"]["solution"]
|
||||
num_rotations = item["metadata"]["num_rotations"]
|
||||
|
||||
# Verify matrix dimensions
|
||||
assert len(matrix) <= config.max_n
|
||||
|
|
@ -52,6 +57,7 @@ def test_rotate_matrix_dataset_items():
|
|||
assert len(solution) <= config.max_n
|
||||
assert all(len(row) <= config.max_n for row in solution)
|
||||
assert set(e for row in matrix for e in row) == set(e for row in solution for e in row)
|
||||
assert num_rotations <= config.max_rotations
|
||||
|
||||
|
||||
def test_rotate_matrix_dataset_iteration():
|
||||
|
|
@ -71,23 +77,25 @@ def test_rotate_matrix_answer():
|
|||
config = RotateMatrixConfig(seed=42)
|
||||
dataset = RotateMatrixDataset(config)
|
||||
|
||||
# n = 1
|
||||
# n = 1, num_rotations = 1
|
||||
matrix = [[8]]
|
||||
expected = [[8]]
|
||||
assert dataset._get_rotated(matrix) == expected
|
||||
assert dataset._get_rotated(matrix, num_rotations=1) == expected
|
||||
|
||||
# n = 2
|
||||
# n = 3, num_rotations = 0 (no rotation)
|
||||
matrix = [
|
||||
[0, 1],
|
||||
[2, 3],
|
||||
[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8],
|
||||
]
|
||||
expected = [
|
||||
[2, 0],
|
||||
[3, 1],
|
||||
[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8],
|
||||
]
|
||||
assert dataset._get_rotated(matrix) == expected
|
||||
assert dataset._get_rotated(matrix, num_rotations=0) == expected
|
||||
|
||||
# n = 3
|
||||
# n = 3, num_rotations = 1 (90 degrees clockwise)
|
||||
matrix = [
|
||||
[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
|
|
@ -98,3 +106,39 @@ def test_rotate_matrix_answer():
|
|||
[7, 4, 1],
|
||||
[8, 5, 2],
|
||||
]
|
||||
assert dataset._get_rotated(matrix, num_rotations=1) == expected
|
||||
|
||||
# n = 3, num_rotations = 2 (180 degrees clockwise)
|
||||
matrix = [
|
||||
[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8],
|
||||
]
|
||||
expected = [
|
||||
[8, 7, 6],
|
||||
[5, 4, 3],
|
||||
[2, 1, 0],
|
||||
]
|
||||
assert dataset._get_rotated(matrix, num_rotations=2) == expected
|
||||
|
||||
# n = 3, num_rotations = 3 (270 degrees clockwise)
|
||||
matrix = [
|
||||
[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8],
|
||||
]
|
||||
expected = [[2, 5, 8], [1, 4, 7], [0, 3, 6]]
|
||||
assert dataset._get_rotated(matrix, num_rotations=3) == expected
|
||||
|
||||
# n = 4, num_rotations = 4 (360 degrees clockwise == 0 degrees)
|
||||
matrix = [
|
||||
[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8],
|
||||
]
|
||||
expected = [
|
||||
[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8],
|
||||
]
|
||||
assert dataset._get_rotated(matrix, num_rotations=4) == expected
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue