From 8d4e9030c090b0d1510a1227a9686cea6570c632 Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Sat, 8 Mar 2025 01:57:37 +0100 Subject: [PATCH] manipulate matrix curriculum (#293) --- reasoning_gym/algorithmic/__init__.py | 3 +- .../algorithmic/manipulate_matrix.py | 65 ++++++++++++++++--- tests/test_manipulate_matrix.py | 35 +++++++++- 3 files changed, 93 insertions(+), 10 deletions(-) diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 931955d0..00d60120 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -21,7 +21,7 @@ from .isomorphic_strings import IsomorphicStringsConfig, IsomorphicStringsCurric from .jugs import JugsConfig, JugsDataset from .letter_counting import LetterCountingConfig, LetterCountingDataset from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset -from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixDataset +from .manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixCurriculum, ManipulateMatrixDataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset from .number_sorting import NumberSortingConfig, NumberSortingDataset from .palindrome_generation import PalindromeConfig, PalindromeDataset @@ -91,6 +91,7 @@ __all__ = [ "RotateMatrixDataset", "ManipulateMatrixConfig", "ManipulateMatrixDataset", + "ManipulateMatrixCurriculum", "BinaryMatrixConfig", "BinaryMatrixDataset", "BinaryMatrixCurriculum", diff --git a/reasoning_gym/algorithmic/manipulate_matrix.py b/reasoning_gym/algorithmic/manipulate_matrix.py index ec964094..f82ef736 100644 --- a/reasoning_gym/algorithmic/manipulate_matrix.py +++ b/reasoning_gym/algorithmic/manipulate_matrix.py @@ -7,6 +7,7 @@ from typing import Any, Optional import numpy as np +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """For the following matrix: @@ -52,8 +53,8 @@ class ManipulateMatrixConfig: def validate(self): """Validate configuration parameters""" - assert 1 <= self.min_rows, "min_rows must be at least 1" - assert 1 <= self.min_cols, "min_cols must be at least 1" + assert 2 <= self.min_rows, "min_rows must be at least 2" + assert 2 <= self.min_cols, "min_cols must be at least 2" assert self.min_rows <= self.max_rows, "max_rows must be at least min_rows" assert self.min_cols <= self.max_cols, "max_cols must be at least min_cols" assert 1 <= self.min_transforms, "min_transforms must be at least 1" @@ -118,10 +119,8 @@ class ManipulateMatrixDataset(ProceduralDataset): ) self._weights = np.exp(weights) / np.sum(np.exp(weights)) - def _get_matrix(self, rng: Random) -> list[list[int]]: + def _get_matrix(self, rng: Random, rows: int, cols: int) -> list[list[int]]: """Generate a random matrix""" - rows = rng.randint(self.config.min_rows, self.config.max_rows) - cols = rng.randint(self.config.min_cols, self.config.max_cols) numbers = [rng.randint(0, 9) for _ in range(rows * cols)] matrix = [numbers[i * cols : (i + 1) * cols] for i in range(rows)] return matrix @@ -205,7 +204,9 @@ class ManipulateMatrixDataset(ProceduralDataset): """Generate a single Manipulate Matrix question""" rng = Random(self.seed + idx) - matrix = self._get_matrix(rng) + rows = rng.randint(self.config.min_rows, self.config.max_rows) + cols = rng.randint(self.config.min_cols, self.config.max_cols) + matrix = self._get_matrix(rng, rows, cols) matrix_str = self._matrix_to_str(matrix) num_transforms = rng.randint(self.config.min_transforms, self.config.max_transforms) @@ -304,8 +305,56 @@ class ManipulateMatrixDataset(ProceduralDataset): matrix=matrix_str, operations="\n".join(op["instruction"] for op in operations) ), "answer": answer_str, - "metadata": {"matrix": matrix, "solution": answer, "operations": operations}, + "metadata": { + "matrix": matrix, + "solution": answer, + "operations": operations, + "difficulty": { + "rows": rows, + "cols": cols, + "num_transforms": num_transforms, + }, + }, } -register_dataset("manipulate_matrix", ManipulateMatrixDataset, ManipulateMatrixConfig) +class ManipulateMatrixCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(ManipulateMatrixCurriculum.__name__, ManipulateMatrixConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="rows", + levels=[10, 25, 50, 100], + default_level=0, + description="Number of rows in the matrix", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_rows", + upper_field_name="max_rows", + ), + RangeAttributeDefinition( + name="cols", + levels=[10, 25, 50, 100], + default_level=0, + description="Number of columns in the matrix", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_cols", + upper_field_name="max_cols", + ), + RangeAttributeDefinition( + name="num_transforms", + levels=[5, 10, 20, 30], + default_level=0, + description="Number of transformations to apply", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_transforms", + upper_field_name="max_transforms", + ), + ) + + +register_dataset("manipulate_matrix", ManipulateMatrixDataset, ManipulateMatrixConfig, ManipulateMatrixCurriculum) diff --git a/tests/test_manipulate_matrix.py b/tests/test_manipulate_matrix.py index 1a31af56..c58ba163 100644 --- a/tests/test_manipulate_matrix.py +++ b/tests/test_manipulate_matrix.py @@ -2,7 +2,11 @@ import pytest -from reasoning_gym.algorithmic.manipulate_matrix import ManipulateMatrixConfig, ManipulateMatrixDataset +from reasoning_gym.algorithmic.manipulate_matrix import ( + ManipulateMatrixConfig, + ManipulateMatrixCurriculum, + ManipulateMatrixDataset, +) def test_manipulate_matrix_config_validation(): @@ -219,3 +223,32 @@ def test_manipulate_matrix_score_answer(): # answer is none answer = None assert dataset.score_answer(answer, entry) == 0.0 + + +def test_manipulate_matrix_curriculum(): + curriculum = ManipulateMatrixCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: ManipulateMatrixConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_rows == 10 and base_cfg.max_rows == 10 + assert base_cfg.min_cols == 10 and base_cfg.max_cols == 10 + assert base_cfg.min_transforms == 5 and base_cfg.max_transforms == 5 + + # test incrementing attribute levels + curriculum.increment_attr_level("rows") + curriculum.increment_attr_level("cols") + curriculum.increment_attr_level("num_transforms") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_rows == 10 and increased_cfg.max_rows == 25 + assert increased_cfg.min_cols == 10 and increased_cfg.max_cols == 25 + assert increased_cfg.min_transforms == 5 and increased_cfg.max_transforms == 10 + + # test decrementing attribute level for rows again + curriculum.decrement_attr_level("rows") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_rows == 10 and partially_decreased_cfg.max_rows == 10 + assert partially_decreased_cfg.min_cols == 10 and partially_decreased_cfg.max_cols == 25 + assert increased_cfg.min_transforms == 5 and increased_cfg.max_transforms == 10