diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 00d60120..fec825d0 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -28,7 +28,7 @@ from .palindrome_generation import PalindromeConfig, PalindromeDataset from .palindrome_partitioning import PalindromePartitioningConfig, PalindromePartitioningDataset from .pool_matrix import PoolMatrixConfig, PoolMatrixDataset from .ransom_note import RansomNoteConfig, RansomNoteDataset -from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset +from .rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMatrixDataset from .rotten_oranges import RottenOrangesConfig, RottenOrangesDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset @@ -89,6 +89,7 @@ __all__ = [ "IsomorphicStringsCurriculum", "RotateMatrixConfig", "RotateMatrixDataset", + "RotateMatrixCurriculum", "ManipulateMatrixConfig", "ManipulateMatrixDataset", "ManipulateMatrixCurriculum", diff --git a/reasoning_gym/algorithmic/rotate_matrix.py b/reasoning_gym/algorithmic/rotate_matrix.py index 2154243f..7df1fba2 100644 --- a/reasoning_gym/algorithmic/rotate_matrix.py +++ b/reasoning_gym/algorithmic/rotate_matrix.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from random import Random from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """Given a square matrix, your job is to rotate it clockwise. @@ -24,16 +25,18 @@ Rotate the matrix below by {degrees} degrees clockwise: class RotateMatrixConfig: """Configuration for Rotate Matrix dataset generation""" + min_n: int = 2 # Minimum size of the matrix max_n: int = 10 # Maximum size of the matrix - max_rotations: int = 4 # Maximum number of rotations (90 degrees each) + min_rotations: int = 0 # Minimum number of rotations + max_rotations: int = 10 # Maximum number of rotations (90 degrees each) size: int = 500 # Virtual dataset size seed: Optional[int] = None def validate(self): """Validate configuration parameters""" - assert 1 <= self.max_n, "max_n must be at least 1" - assert 0 <= self.max_rotations, "max_rotations must be at least 0" + assert 2 <= self.min_n <= self.max_n, "min_n and max_n must be between 2 and 10" + assert 0 <= self.min_rotations <= self.max_rotations, "min_rotations must be between 0 and max_rotations" class RotateMatrixDataset(ProceduralDataset): @@ -42,11 +45,9 @@ class RotateMatrixDataset(ProceduralDataset): def __init__(self, config: RotateMatrixConfig): super().__init__(config=config, seed=config.seed, size=config.size) - def _get_matrix(self, rng: Random) -> list[list[int]]: + def _get_matrix(self, rng: Random, n: int) -> list[list[int]]: """Generate a random matrix""" - n = rng.randint(1, self.config.max_n) - numbers = list(range(n**2)) - rng.shuffle(numbers) + numbers = list(rng.randint(0, 9) for _ in range(n**2)) matrix = [numbers[i * n : (i + 1) * n] for i in range(n)] return matrix @@ -70,8 +71,9 @@ class RotateMatrixDataset(ProceduralDataset): """Generate a single Rotate Matrix question""" rng = Random(self.seed + idx) - matrix = self._get_matrix(rng) - num_rotations = rng.randint(0, self.config.max_rotations) + n = rng.randint(self.config.min_n, self.config.max_n) + matrix = self._get_matrix(rng, n) + num_rotations = rng.randint(self.config.min_rotations, self.config.max_rotations) matrix_str = self._matrix_to_str(matrix) answer = self._get_rotated(matrix, num_rotations) @@ -80,8 +82,45 @@ class RotateMatrixDataset(ProceduralDataset): return { "question": QUESTION_TEMPLATE.format(matrix=matrix_str, degrees=num_rotations * 90), "answer": answer_str, - "metadata": {"matrix": matrix, "num_rotations": num_rotations, "solution": answer}, + "metadata": { + "matrix": matrix, + "num_rotations": num_rotations, + "solution": answer, + "difficulty": { + "n": n, + "num_rotations": num_rotations, + }, + }, } -register_dataset("rotate_matrix", RotateMatrixDataset, RotateMatrixConfig) +class RotateMatrixCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(RotateMatrixCurriculum.__name__, RotateMatrixConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="n", + levels=[10, 25, 50, 100], + default_level=0, + description="Size of the square matrix", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_n", + upper_field_name="max_n", + ), + RangeAttributeDefinition( + name="num_rotations", + levels=[4, 8, 12, 16], + default_level=0, + description="Number of 90-degree rotations", + attr_type=AttributeType.APPEND, + min_value=0, + lower_field_name="min_rotations", + upper_field_name="max_rotations", + ), + ) + + +register_dataset("rotate_matrix", RotateMatrixDataset, RotateMatrixConfig, RotateMatrixCurriculum) diff --git a/tests/test_rotate_matrix.py b/tests/test_rotate_matrix.py index c2e43df9..173b4ef4 100644 --- a/tests/test_rotate_matrix.py +++ b/tests/test_rotate_matrix.py @@ -2,7 +2,7 @@ import pytest -from reasoning_gym.algorithmic.rotate_matrix import RotateMatrixConfig, RotateMatrixDataset +from reasoning_gym.algorithmic.rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMatrixDataset def test_rotate_matrix_config_validation(): @@ -142,3 +142,28 @@ def test_rotate_matrix_answer(): [6, 7, 8], ] assert dataset._get_rotated(matrix, num_rotations=4) == expected + + +def test_rotate_matrix_curriculum(): + curriculum = RotateMatrixCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: RotateMatrixConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_n == 10 and base_cfg.max_n == 10 + assert base_cfg.min_rotations == 4 and base_cfg.max_rotations == 4 + + # test incrementing attribute levels + curriculum.increment_attr_level("n") + curriculum.increment_attr_level("num_rotations") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_n == 10 and increased_cfg.max_n == 25 + assert increased_cfg.min_rotations == 4 and increased_cfg.max_rotations == 8 + + # test decrementing attribute level for n again + curriculum.decrement_attr_level("n") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_n == 10 and partially_decreased_cfg.max_n == 10 + assert partially_decreased_cfg.min_rotations == 4 and partially_decreased_cfg.max_rotations == 8