diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index b2f8709a..4c799bcd 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -16,10 +16,10 @@ from .number_sorting import NumberSortingConfig, NumberSortingDataset from .palindrome_generation import PalindromeConfig, PalindromeDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset +from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset -from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset __all__ = [ "SpellBackwardConfig", diff --git a/reasoning_gym/algorithmic/spiral_matrix.py b/reasoning_gym/algorithmic/spiral_matrix.py index a7d40d86..ecff246d 100644 --- a/reasoning_gym/algorithmic/spiral_matrix.py +++ b/reasoning_gym/algorithmic/spiral_matrix.py @@ -14,7 +14,7 @@ QUESTION_TEMPLATE = """Given a matrix, your job is to generate a list of element Example: -Input: +Input: 1 2 3 4 5 6 7 8 9 10 11 12 @@ -30,8 +30,8 @@ For the matrix below, what is the list of elements in spiral order? class SpiralMatrixConfig: """Configuration for Spiral Matrix dataset generation""" - max_rows: int = 10 # Maximum number of rows in the matrix - max_cols: int = 10 # Maximum number of columns in the matrix + max_rows: int = 10 # Maximum number of rows in the matrix + max_cols: int = 10 # Maximum number of columns in the matrix size: int = 500 # Virtual dataset size seed: Optional[int] = None @@ -67,29 +67,33 @@ class SpiralMatrixDataset(ProceduralDataset): for i in range(l, r): out.append(matrix[t][i]) t += 1 - if t == b: break + if t == b: + break for i in range(t, b): - out.append(matrix[i][r-1]) + out.append(matrix[i][r - 1]) r -= 1 - if l == r: break + if l == r: + break - for i in range(r-1, l-1, -1): - out.append(matrix[b-1][i]) + for i in range(r - 1, l - 1, -1): + out.append(matrix[b - 1][i]) b -= 1 - if t == b: break + if t == b: + break - for i in range(b-1, t-1, -1): + for i in range(b - 1, t - 1, -1): out.append(matrix[i][l]) l += 1 - if l == r: break + if l == r: + break return out - + def _matrix_to_str(self, matrix: list[list[int]]) -> str: """Get a string representation of the matrix""" return "\n".join(" ".join(str(x) for x in row) for row in matrix) - + def _list_to_str(self, array: list[int]) -> str: """Get a string representation of the array""" return " ".join(str(x) for x in array) @@ -97,7 +101,7 @@ class SpiralMatrixDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: """Generate a single Spiral Matrix question""" rng = Random(self.seed + idx) - + matrix = self._get_matrix(rng) matrix_str = self._matrix_to_str(matrix) answer = self._get_spiral(matrix)