spiral matrix curriculum (#296)

This commit is contained in:
Zafir Stojanovski 2025-03-08 20:56:08 +01:00 committed by GitHub
parent d82c73b6f8
commit e4e516a949
3 changed files with 55 additions and 8 deletions

View file

@ -32,7 +32,7 @@ from .rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMat
from .rotten_oranges import RottenOrangesConfig, RottenOrangesDataset from .rotten_oranges import RottenOrangesConfig, RottenOrangesDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset
from .string_insertion import StringInsertionConfig, StringInsertionDataset from .string_insertion import StringInsertionConfig, StringInsertionDataset
from .string_manipulation import StringManipulationConfig, StringManipulationDataset from .string_manipulation import StringManipulationConfig, StringManipulationDataset
from .string_splitting import StringSplittingConfig, StringSplittingDataset from .string_splitting import StringSplittingConfig, StringSplittingDataset
@ -82,6 +82,7 @@ __all__ = [
"PalindromePartitioningDataset", "PalindromePartitioningDataset",
"SpiralMatrixConfig", "SpiralMatrixConfig",
"SpiralMatrixDataset", "SpiralMatrixDataset",
"SpiralMatrixCurriculum",
"RansomNoteConfig", "RansomNoteConfig",
"RansomNoteDataset", "RansomNoteDataset",
"IsomorphicStringsConfig", "IsomorphicStringsConfig",

View file

@ -8,6 +8,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Any, Optional from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Given a matrix, your job is to generate a list of elements in spiral order, starting from the top-left element. QUESTION_TEMPLATE = """Given a matrix, your job is to generate a list of elements in spiral order, starting from the top-left element.
@ -30,6 +31,7 @@ For the matrix below, what is the list of elements in spiral order?
class SpiralMatrixConfig: class SpiralMatrixConfig:
"""Configuration for Spiral Matrix dataset generation""" """Configuration for Spiral Matrix dataset generation"""
min_n: int = 2 # Minimum number of rows/cols in the matrix
max_n: int = 10 # Maximum number of rows/cols in the matrix max_n: int = 10 # Maximum number of rows/cols in the matrix
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
@ -37,7 +39,7 @@ class SpiralMatrixConfig:
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
assert 2 <= self.max_n, "max_n must be at least 2" assert 2 <= self.min_n <= self.max_n, "min_n must be between 2 and max_n"
class SpiralMatrixDataset(ProceduralDataset): class SpiralMatrixDataset(ProceduralDataset):
@ -46,9 +48,8 @@ class SpiralMatrixDataset(ProceduralDataset):
def __init__(self, config: SpiralMatrixConfig): def __init__(self, config: SpiralMatrixConfig):
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)
def _get_matrix(self, rng: Random) -> list[list[int]]: def _get_matrix(self, rng: Random, n: int) -> list[list[int]]:
"""Generate a random matrix""" """Generate a random matrix"""
n = rng.randint(2, self.config.max_n)
numbers = [rng.randint(0, 9) for _ in range(n**2)] numbers = [rng.randint(0, 9) for _ in range(n**2)]
rng.shuffle(numbers) rng.shuffle(numbers)
matrix = [numbers[i * n : (i + 1) * n] for i in range(n)] matrix = [numbers[i * n : (i + 1) * n] for i in range(n)]
@ -100,7 +101,8 @@ class SpiralMatrixDataset(ProceduralDataset):
"""Generate a single Spiral Matrix question""" """Generate a single Spiral Matrix question"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
matrix = self._get_matrix(rng) n = rng.randint(2, self.config.max_n)
matrix = self._get_matrix(rng, n)
matrix_str = self._matrix_to_str(matrix) matrix_str = self._matrix_to_str(matrix)
answer = self._get_spiral(matrix) answer = self._get_spiral(matrix)
answer_str = self._list_to_str(answer) answer_str = self._list_to_str(answer)
@ -108,7 +110,11 @@ class SpiralMatrixDataset(ProceduralDataset):
return { return {
"question": QUESTION_TEMPLATE.format(matrix=matrix_str), "question": QUESTION_TEMPLATE.format(matrix=matrix_str),
"answer": answer_str, "answer": answer_str,
"metadata": {"matrix": matrix, "solution": answer}, "metadata": {
"matrix": matrix,
"solution": answer,
"difficulty": {"n": n},
},
} }
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -133,4 +139,23 @@ class SpiralMatrixDataset(ProceduralDataset):
return 0.0 return 0.0
register_dataset("spiral_matrix", SpiralMatrixDataset, SpiralMatrixConfig) class SpiralMatrixCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(SpiralMatrixCurriculum.__name__, SpiralMatrixConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="n",
levels=[10, 25, 50, 100],
default_level=0,
description="Number of rows/cols in the matrix",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_n",
upper_field_name="max_n",
)
)
register_dataset("spiral_matrix", SpiralMatrixDataset, SpiralMatrixConfig, SpiralMatrixCurriculum)

View file

@ -2,7 +2,7 @@
import pytest import pytest
from reasoning_gym.algorithmic.spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset from reasoning_gym.algorithmic.spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset
def test_spiral_matrix_config_validation(): def test_spiral_matrix_config_validation():
@ -96,3 +96,24 @@ def test_spiral_matrix_answer():
entry = {"answer": "1 2 3 6 9 8 7 4 5"} entry = {"answer": "1 2 3 6 9 8 7 4 5"}
answer = None answer = None
assert dataset.score_answer(answer, entry) == 0.0 assert dataset.score_answer(answer, entry) == 0.0
def test_spiral_matrix_curriculum():
curriculum = SpiralMatrixCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: SpiralMatrixConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_n == 10 and base_cfg.max_n == 10
# test incrementing attribute levels
curriculum.increment_attr_level("n")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_n == 10 and increased_cfg.max_n == 25
# test decrementing attribute levels
curriculum.decrement_attr_level("n")
decreased_cfg = curriculum.generate_configuration(base_value)
assert decreased_cfg.min_n == 10 and decreased_cfg.max_n == 10