diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 7d6d9f02..f7f8d161 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -29,7 +29,7 @@ from .palindrome_partitioning import PalindromePartitioningConfig, PalindromePar from .pool_matrix import PoolMatrixConfig, PoolMatrixDataset from .ransom_note import RansomNoteConfig, RansomNoteDataset from .rotate_matrix import RotateMatrixConfig, RotateMatrixCurriculum, RotateMatrixDataset -from .rotten_oranges import RottenOrangesConfig, RottenOrangesDataset +from .rotten_oranges import RottenOrangesConfig, RottenOrangesCurriculum, RottenOrangesDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixCurriculum, SpiralMatrixDataset @@ -116,6 +116,7 @@ __all__ = [ "StringSynthesisDataset", "RottenOrangesConfig", "RottenOrangesDataset", + "RottenOrangesCurriculum", "JugsConfig", "JugsDataset", "BinaryAlternationConfig", diff --git a/reasoning_gym/algorithmic/rotten_oranges.py b/reasoning_gym/algorithmic/rotten_oranges.py index 3b849c4d..3995fcc4 100644 --- a/reasoning_gym/algorithmic/rotten_oranges.py +++ b/reasoning_gym/algorithmic/rotten_oranges.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from random import Random from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """You are given an n x n grid where each cell can have one of three values: @@ -40,7 +41,7 @@ class RottenOrangesConfig: def validate(self): """Validate configuration parameters""" - assert 1 <= self.min_n, "min_n must be at least 1" + assert 2 <= self.min_n, "min_n must be at least 2" assert self.min_n <= self.max_n, "min_n must be less than or equal to max_n" assert 0 < self.p_oranges <= 1, "p_oranges must be between 0 and 1" assert 0 < self.p_rotten <= 1, "p_rotten must be between 0 and 1" @@ -56,9 +57,8 @@ class RottenOrangesDataset(ProceduralDataset): """Get a string representation of the matrix""" return "\n".join(" ".join(str(x) for x in row) for row in matrix) - def _get_initial_matrix(self, rng: Random) -> list[list[int]]: + def _get_initial_matrix(self, rng: Random, n: int) -> list[list[int]]: """Generate a random matrix with oranges""" - n = rng.randint(self.config.min_n, self.config.max_n) matrix = [[0] * n for _ in range(n)] for i in range(n): for j in range(n): @@ -111,15 +111,39 @@ class RottenOrangesDataset(ProceduralDataset): """Generate a single Rotten Oranges question""" rng = Random(self.seed + idx) - matrix = self._get_initial_matrix(rng) + n = rng.randint(self.config.min_n, self.config.max_n) + matrix = self._get_initial_matrix(rng, n) matrix_str = self._matrix_to_str(matrix) answer = self._get_answer(matrix) return { "question": QUESTION_TEMPLATE.format(matrix=matrix_str), "answer": str(answer), - "metadata": {"matrix": matrix, "solution": answer}, + "metadata": { + "matrix": matrix, + "solution": answer, + "difficulty": {"n": n}, + }, } -register_dataset("rotten_oranges", RottenOrangesDataset, RottenOrangesConfig) +class RottenOrangesCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(RottenOrangesCurriculum.__name__, RottenOrangesConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="n", + levels=[10, 25, 50, 100], + default_level=0, + description="Size of the square matrix", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_n", + upper_field_name="max_n", + ) + ) + + +register_dataset("rotten_oranges", RottenOrangesDataset, RottenOrangesConfig, RottenOrangesCurriculum) diff --git a/tests/test_rotting_oranges.py b/tests/test_rotten_oranges.py similarity index 80% rename from tests/test_rotting_oranges.py rename to tests/test_rotten_oranges.py index 070fb319..613de36d 100644 --- a/tests/test_rotting_oranges.py +++ b/tests/test_rotten_oranges.py @@ -2,7 +2,7 @@ import pytest -from reasoning_gym.algorithmic.rotten_oranges import RottenOrangesConfig, RottenOrangesDataset +from reasoning_gym.algorithmic.rotten_oranges import RottenOrangesConfig, RottenOrangesCurriculum, RottenOrangesDataset def test_rotten_oranges_config_validation(): @@ -117,3 +117,24 @@ def test_rotten_oranges_answer(): [0, 1, 1], ] assert dataset._get_answer(matrix) == 4 + + +def test_rotten_oranges_curriculum(): + curriculum = RottenOrangesCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: RottenOrangesConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_n == 10 and base_cfg.max_n == 10 + + # test incrementing attribute levels + curriculum.increment_attr_level("n") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_n == 10 and increased_cfg.max_n == 25 + + # test decrementing attribute level for n again + curriculum.decrement_attr_level("n") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_n == 10 and partially_decreased_cfg.max_n == 10