diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index fec825d0..7d6d9f02 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -32,7 +32,7 @@ from .rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMat from .rotten_oranges import RottenOrangesConfig, RottenOrangesDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset -from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset +from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset from .string_insertion import StringInsertionConfig, StringInsertionDataset from .string_manipulation import StringManipulationConfig, StringManipulationDataset from .string_splitting import StringSplittingConfig, StringSplittingDataset @@ -82,6 +82,7 @@ __all__ = [ "PalindromePartitioningDataset", "SpiralMatrixConfig", "SpiralMatrixDataset", + "SpiralMatrixCurriculum", "RansomNoteConfig", "RansomNoteDataset", "IsomorphicStringsConfig", diff --git a/reasoning_gym/algorithmic/spiral_matrix.py b/reasoning_gym/algorithmic/spiral_matrix.py index f895c628..52bf2785 100644 --- a/reasoning_gym/algorithmic/spiral_matrix.py +++ b/reasoning_gym/algorithmic/spiral_matrix.py @@ -8,6 +8,7 @@ from dataclasses import dataclass from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """Given a matrix, your job is to generate a list of elements in spiral order, starting from the top-left element. @@ -30,6 +31,7 @@ For the matrix below, what is the list of elements in spiral order? 
class SpiralMatrixConfig: """Configuration for Spiral Matrix dataset generation""" + min_n: int = 2 # Minimum number of rows/cols in the matrix max_n: int = 10 # Maximum number of rows/cols in the matrix size: int = 500 # Virtual dataset size @@ -37,7 +39,7 @@ class SpiralMatrixConfig: def validate(self): """Validate configuration parameters""" - assert 2 <= self.max_n, "max_n must be at least 2" + assert 2 <= self.min_n <= self.max_n, "min_n must be between 2 and max_n" class SpiralMatrixDataset(ProceduralDataset): @@ -46,9 +48,8 @@ class SpiralMatrixDataset(ProceduralDataset): def __init__(self, config: SpiralMatrixConfig): super().__init__(config=config, seed=config.seed, size=config.size) - def _get_matrix(self, rng: Random) -> list[list[int]]: + def _get_matrix(self, rng: Random, n: int) -> list[list[int]]: """Generate a random matrix""" - n = rng.randint(2, self.config.max_n) numbers = [rng.randint(0, 9) for _ in range(n**2)] rng.shuffle(numbers) matrix = [numbers[i * n : (i + 1) * n] for i in range(n)] @@ -100,7 +101,8 @@ class SpiralMatrixDataset(ProceduralDataset): """Generate a single Spiral Matrix question""" rng = Random(self.seed + idx) - matrix = self._get_matrix(rng) + n = rng.randint(self.config.min_n, self.config.max_n) + matrix = self._get_matrix(rng, n) matrix_str = self._matrix_to_str(matrix) answer = self._get_spiral(matrix) answer_str = self._list_to_str(answer) @@ -108,7 +110,11 @@ class SpiralMatrixDataset(ProceduralDataset): return { "question": QUESTION_TEMPLATE.format(matrix=matrix_str), "answer": answer_str, - "metadata": {"matrix": matrix, "solution": answer}, + "metadata": { + "matrix": matrix, + "solution": answer, + "difficulty": {"n": n}, + }, } def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: @@ -133,4 +139,23 @@ class SpiralMatrixDataset(ProceduralDataset): return 0.0 -register_dataset("spiral_matrix", SpiralMatrixDataset, SpiralMatrixConfig) +class SpiralMatrixCurriculum(BaseCurriculum): + def __init__(self): + 
super().__init__(SpiralMatrixCurriculum.__name__, SpiralMatrixConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="n", + levels=[10, 25, 50, 100], + default_level=0, + description="Number of rows/cols in the matrix", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_n", + upper_field_name="max_n", + ) + ) + + +register_dataset("spiral_matrix", SpiralMatrixDataset, SpiralMatrixConfig, SpiralMatrixCurriculum) diff --git a/tests/test_spiral_matrix.py b/tests/test_spiral_matrix.py index 9e5c510e..a6aec30e 100644 --- a/tests/test_spiral_matrix.py +++ b/tests/test_spiral_matrix.py @@ -2,7 +2,7 @@ import pytest -from reasoning_gym.algorithmic.spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset +from reasoning_gym.algorithmic.spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset def test_spiral_matrix_config_validation(): @@ -96,3 +96,24 @@ def test_spiral_matrix_answer(): entry = {"answer": "1 2 3 6 9 8 7 4 5"} answer = None assert dataset.score_answer(answer, entry) == 0.0 + + +def test_spiral_matrix_curriculum(): + curriculum = SpiralMatrixCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: SpiralMatrixConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_n == 10 and base_cfg.max_n == 10 + + # test incrementing attribute levels + curriculum.increment_attr_level("n") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_n == 10 and increased_cfg.max_n == 25 + + # test decrementing attribute levels + curriculum.decrement_attr_level("n") + decreased_cfg = curriculum.generate_configuration(base_value) + assert decreased_cfg.min_n == 10 and decreased_cfg.max_n == 10