mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-30 17:40:45 +00:00
Merge pull request #87 from zafstojano/env/rotate-matrix
Rotate Matrix k times
This commit is contained in:
commit
307a031146
3 changed files with 250 additions and 0 deletions
|
|
@ -16,6 +16,7 @@ from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
|
||||||
from .number_sorting import NumberSortingConfig, NumberSortingDataset
|
from .number_sorting import NumberSortingConfig, NumberSortingDataset
|
||||||
from .palindrome_generation import PalindromeConfig, PalindromeDataset
|
from .palindrome_generation import PalindromeConfig, PalindromeDataset
|
||||||
from .ransom_note import RansomNoteConfig, RansomNoteDataset
|
from .ransom_note import RansomNoteConfig, RansomNoteDataset
|
||||||
|
from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset
|
||||||
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
|
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
|
||||||
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
|
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
|
||||||
from .word_ladder import WordLadderConfig, WordLadderDataset
|
from .word_ladder import WordLadderConfig, WordLadderDataset
|
||||||
|
|
@ -54,4 +55,6 @@ __all__ = [
|
||||||
"RansomNoteDataset",
|
"RansomNoteDataset",
|
||||||
"IsomorphicStringsConfig",
|
"IsomorphicStringsConfig",
|
||||||
"IsomorphicStringsDataset",
|
"IsomorphicStringsDataset",
|
||||||
|
"RotateMatrixConfig",
|
||||||
|
"RotateMatrixDataset",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
103
reasoning_gym/algorithmic/rotate_matrix.py
Normal file
103
reasoning_gym/algorithmic/rotate_matrix.py
Normal file
|
|
@ -0,0 +1,103 @@
|
||||||
|
"""Rotate a square matrix clockwise.
|
||||||
|
|
||||||
|
A popular Leetcode problem:
|
||||||
|
https://leetcode.com/problems/rotate-image/description/
|
||||||
|
"""
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from random import Random
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
QUESTION_TEMPLATE = """Given a square matrix, your job is to rotate it clockwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
Input: Rotate the matrix below by 90 degrees clockwise:
|
||||||
|
1 2 3
|
||||||
|
4 5 6
|
||||||
|
7 8 9
|
||||||
|
|
||||||
|
Output:
|
||||||
|
7 4 1
|
||||||
|
8 5 2
|
||||||
|
9 6 3
|
||||||
|
|
||||||
|
Rotate the matrix below by {degrees} degrees clockwise:
|
||||||
|
{matrix}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RotateMatrixConfig:
|
||||||
|
"""Configuration for Rotate Matrix dataset generation"""
|
||||||
|
|
||||||
|
max_n: int = 10 # Maximum size of the matrix
|
||||||
|
max_rotations: int = 4 # Maximum number of rotations (90 degrees each)
|
||||||
|
|
||||||
|
size: int = 500 # Virtual dataset size
|
||||||
|
seed: Optional[int] = None
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
"""Validate configuration parameters"""
|
||||||
|
assert 1 <= self.max_n, "max_n must be at least 1"
|
||||||
|
assert 0 <= self.max_rotations, "max_rotations must be at least 0"
|
||||||
|
|
||||||
|
|
||||||
|
class RotateMatrixDataset(ProceduralDataset):
|
||||||
|
"""Generates Rotate Matrix exercises with configurable difficulty"""
|
||||||
|
|
||||||
|
def __init__(self, config: RotateMatrixConfig):
|
||||||
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
|
|
||||||
|
def _get_matrix(self, rng: Random) -> list[list[int]]:
|
||||||
|
"""Generate a random matrix"""
|
||||||
|
n = rng.randint(1, self.config.max_n)
|
||||||
|
numbers = list(range(n**2))
|
||||||
|
rng.shuffle(numbers)
|
||||||
|
matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
|
||||||
|
return matrix
|
||||||
|
|
||||||
|
def _get_rotated(self, matrix: list[list[int]], num_rotations: int) -> list[list[int]]:
|
||||||
|
"""Rotate the matrix K times by 90 degrees clockwise"""
|
||||||
|
num_rotations %= 4
|
||||||
|
n = len(matrix)
|
||||||
|
output = deepcopy(matrix)
|
||||||
|
|
||||||
|
for _ in range(num_rotations):
|
||||||
|
for l in range(n // 2):
|
||||||
|
for i in range(l, n - 1 - l):
|
||||||
|
(output[l][i], output[i][n - 1 - l], output[n - 1 - l][n - 1 - i], output[n - 1 - i][l]) = (
|
||||||
|
output[n - 1 - i][l],
|
||||||
|
output[l][i],
|
||||||
|
output[i][n - 1 - l],
|
||||||
|
output[n - 1 - l][n - 1 - i],
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def _matrix_to_str(self, matrix: list[list[int]]) -> str:
|
||||||
|
"""Get a string representation of the matrix"""
|
||||||
|
return "\n".join(" ".join(str(x) for x in row) for row in matrix)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
"""Generate a single Spiral Matrix question"""
|
||||||
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
|
matrix = self._get_matrix(rng)
|
||||||
|
num_rotations = rng.randint(0, self.config.max_rotations)
|
||||||
|
matrix_str = self._matrix_to_str(matrix)
|
||||||
|
|
||||||
|
answer = self._get_rotated(matrix, num_rotations)
|
||||||
|
answer_str = self._matrix_to_str(answer)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"question": QUESTION_TEMPLATE.format(matrix=matrix_str, degrees=num_rotations * 90),
|
||||||
|
"answer": answer_str,
|
||||||
|
"metadata": {"matrix": matrix, "num_rotations": num_rotations, "solution": answer},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
register_dataset("rotate_matrix", RotateMatrixDataset, RotateMatrixConfig)
|
||||||
144
tests/test_rotate_matrix.py
Normal file
144
tests/test_rotate_matrix.py
Normal file
|
|
@ -0,0 +1,144 @@
|
||||||
|
"""Tests for Rotate Matrix questions generation"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reasoning_gym.algorithmic.rotate_matrix import RotateMatrixConfig, RotateMatrixDataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_rotate_matrix_config_validation():
|
||||||
|
"""Test that invalid configs raise appropriate errors"""
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = RotateMatrixConfig(max_n=-1) # Negative not allowed
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = RotateMatrixConfig(max_n=0) # Zero not allowed
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
config = RotateMatrixConfig(max_rotations=-1) # Negative not allowed
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_rotate_matrix_dataset_deterministic():
|
||||||
|
"""Test that dataset generates same items with same seed"""
|
||||||
|
config = RotateMatrixConfig(seed=42, size=10)
|
||||||
|
dataset1 = RotateMatrixDataset(config)
|
||||||
|
dataset2 = RotateMatrixDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset1)):
|
||||||
|
assert dataset1[i] == dataset2[i]
|
||||||
|
|
||||||
|
|
||||||
|
def test_rotate_matrix_dataset_items():
|
||||||
|
"""Test basic properties of generated items"""
|
||||||
|
config = RotateMatrixConfig(max_n=7, size=10, seed=42)
|
||||||
|
dataset = RotateMatrixDataset(config)
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
item = dataset[i]
|
||||||
|
# Check item structure
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
assert "question" in item
|
||||||
|
assert "answer" in item
|
||||||
|
assert "metadata" in item
|
||||||
|
|
||||||
|
# Check metadata
|
||||||
|
assert "matrix" in item["metadata"]
|
||||||
|
assert "solution" in item["metadata"]
|
||||||
|
|
||||||
|
matrix = item["metadata"]["matrix"]
|
||||||
|
solution = item["metadata"]["solution"]
|
||||||
|
num_rotations = item["metadata"]["num_rotations"]
|
||||||
|
|
||||||
|
# Verify matrix dimensions
|
||||||
|
assert len(matrix) <= config.max_n
|
||||||
|
assert all(len(row) <= config.max_n for row in matrix)
|
||||||
|
assert len(solution) <= config.max_n
|
||||||
|
assert all(len(row) <= config.max_n for row in solution)
|
||||||
|
assert set(e for row in matrix for e in row) == set(e for row in solution for e in row)
|
||||||
|
assert num_rotations <= config.max_rotations
|
||||||
|
|
||||||
|
|
||||||
|
def test_rotate_matrix_dataset_iteration():
|
||||||
|
"""Test that iteration respects dataset size"""
|
||||||
|
config = RotateMatrixConfig(size=5, seed=42)
|
||||||
|
dataset = RotateMatrixDataset(config)
|
||||||
|
|
||||||
|
items = list(dataset)
|
||||||
|
assert len(items) == config.size
|
||||||
|
|
||||||
|
# Test multiple iterations yield same items
|
||||||
|
assert items == list(dataset)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rotate_matrix_answer():
|
||||||
|
"""Test the _get_rotated method"""
|
||||||
|
config = RotateMatrixConfig(seed=42)
|
||||||
|
dataset = RotateMatrixDataset(config)
|
||||||
|
|
||||||
|
# n = 1, num_rotations = 1
|
||||||
|
matrix = [[8]]
|
||||||
|
expected = [[8]]
|
||||||
|
assert dataset._get_rotated(matrix, num_rotations=1) == expected
|
||||||
|
|
||||||
|
# n = 3, num_rotations = 0 (no rotation)
|
||||||
|
matrix = [
|
||||||
|
[0, 1, 2],
|
||||||
|
[3, 4, 5],
|
||||||
|
[6, 7, 8],
|
||||||
|
]
|
||||||
|
expected = [
|
||||||
|
[0, 1, 2],
|
||||||
|
[3, 4, 5],
|
||||||
|
[6, 7, 8],
|
||||||
|
]
|
||||||
|
assert dataset._get_rotated(matrix, num_rotations=0) == expected
|
||||||
|
|
||||||
|
# n = 3, num_rotations = 1 (90 degrees clockwise)
|
||||||
|
matrix = [
|
||||||
|
[0, 1, 2],
|
||||||
|
[3, 4, 5],
|
||||||
|
[6, 7, 8],
|
||||||
|
]
|
||||||
|
expected = [
|
||||||
|
[6, 3, 0],
|
||||||
|
[7, 4, 1],
|
||||||
|
[8, 5, 2],
|
||||||
|
]
|
||||||
|
assert dataset._get_rotated(matrix, num_rotations=1) == expected
|
||||||
|
|
||||||
|
# n = 3, num_rotations = 2 (180 degrees clockwise)
|
||||||
|
matrix = [
|
||||||
|
[0, 1, 2],
|
||||||
|
[3, 4, 5],
|
||||||
|
[6, 7, 8],
|
||||||
|
]
|
||||||
|
expected = [
|
||||||
|
[8, 7, 6],
|
||||||
|
[5, 4, 3],
|
||||||
|
[2, 1, 0],
|
||||||
|
]
|
||||||
|
assert dataset._get_rotated(matrix, num_rotations=2) == expected
|
||||||
|
|
||||||
|
# n = 3, num_rotations = 3 (270 degrees clockwise)
|
||||||
|
matrix = [
|
||||||
|
[0, 1, 2],
|
||||||
|
[3, 4, 5],
|
||||||
|
[6, 7, 8],
|
||||||
|
]
|
||||||
|
expected = [[2, 5, 8], [1, 4, 7], [0, 3, 6]]
|
||||||
|
assert dataset._get_rotated(matrix, num_rotations=3) == expected
|
||||||
|
|
||||||
|
# n = 4, num_rotations = 4 (360 degrees clockwise == 0 degrees)
|
||||||
|
matrix = [
|
||||||
|
[0, 1, 2],
|
||||||
|
[3, 4, 5],
|
||||||
|
[6, 7, 8],
|
||||||
|
]
|
||||||
|
expected = [
|
||||||
|
[0, 1, 2],
|
||||||
|
[3, 4, 5],
|
||||||
|
[6, 7, 8],
|
||||||
|
]
|
||||||
|
assert dataset._get_rotated(matrix, num_rotations=4) == expected
|
||||||
Loading…
Add table
Add a link
Reference in a new issue