diff --git a/reasoning_gym/algorithmic/rotate_matrix.py b/reasoning_gym/algorithmic/rotate_matrix.py index 4fdf651e..adeaa47c 100644 --- a/reasoning_gym/algorithmic/rotate_matrix.py +++ b/reasoning_gym/algorithmic/rotate_matrix.py @@ -60,22 +60,16 @@ class RotateMatrixDataset(ProceduralDataset): matrix = [numbers[i * n : (i + 1) * n] for i in range(n)] return matrix + def _rot90(self, matrix: list[list[int]]) -> list[list[int]]: + """quarter clockwise rotation""" + return [list(row) for row in zip(*matrix[::-1])] + def _get_rotated(self, matrix: list[list[int]], num_rotations: int) -> list[list[int]]: """Rotate the matrix K times by 90 degrees clockwise""" num_rotations %= 4 - n = len(matrix) output = deepcopy(matrix) - for _ in range(num_rotations): - for l in range(n // 2): - for i in range(l, n - 1 - l): - (output[l][i], output[i][n - 1 - l], output[n - 1 - l][n - 1 - i], output[n - 1 - i][l]) = ( - output[n - 1 - i][l], - output[l][i], - output[i][n - 1 - l], - output[n - 1 - l][n - 1 - i], - ) - + output = self._rot90(output) return output def _matrix_to_str(self, matrix: list[list[int]]) -> str: