diff --git a/reasoning_gym/cognition/__init__.py b/reasoning_gym/cognition/__init__.py
index 79185467..3b708d47 100644
--- a/reasoning_gym/cognition/__init__.py
+++ b/reasoning_gym/cognition/__init__.py
@@ -2,7 +2,7 @@
 Cognition tasks for training reasoning capabilities.
 """
 
-from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationDataset
+from .color_cube_rotation import ColorCubeRotationConfig, ColorCubeRotationCurriculum, ColorCubeRotationDataset
 from .figlet_fonts import FigletFontConfig, FigletFontDataset
 from .modulo_grid import ModuloGridConfig, ModuloGridDataset
 from .needle_haystack import NeedleHaystackConfig, NeedleHaystackDataset
@@ -13,6 +13,7 @@ from .rubiks_cube import RubiksCubeConfig, RubiksCubeDataset
 __all__ = [
     "ColorCubeRotationConfig",
     "ColorCubeRotationDataset",
+    "ColorCubeRotationCurriculum",
     "FigletFontConfig",
     "FigletFontDataset",
     "NumberSequenceConfig",
diff --git a/reasoning_gym/cognition/color_cube_rotation.py b/reasoning_gym/cognition/color_cube_rotation.py
index 59e3e3a1..997e6a05 100644
--- a/reasoning_gym/cognition/color_cube_rotation.py
+++ b/reasoning_gym/cognition/color_cube_rotation.py
@@ -3,6 +3,7 @@ from dataclasses import dataclass
 from enum import StrEnum
 from typing import Any, Optional
 
+from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
 from ..factory import ProceduralDataset, register_dataset
 
 
@@ -140,6 +141,7 @@ class ColorCubeRotationDataset(ProceduralDataset):
                 "rotations": [r.value for r in rotations],
                 "target_side": target_side.value,
                 "num_rotations": num_rotations,
+                "difficulty": {"rotations": num_rotations},
             },
         }
 
@@ -204,4 +206,25 @@ class ColorCubeRotationDataset(ProceduralDataset):
         return reward
 
 
-register_dataset("color_cube_rotation", ColorCubeRotationDataset, ColorCubeRotationConfig)
+class ColorCubeRotationCurriculum(BaseCurriculum):
+    """Curriculum for the color cube rotation task; its only attribute is the number of rotations."""
+
+    def __init__(self):
+        super().__init__(ColorCubeRotationCurriculum.__name__, ColorCubeRotationConfig)
+
+        # Define attributes
+        self._define_attributes(
+            RangeAttributeDefinition(
+                name="rotations",
+                levels=[1, 5, 10, 50, 100],
+                default_level=1,
+                description="Number of rotations to perform on the cube",
+                attr_type=AttributeType.APPEND,
+                min_value=1,
+                lower_field_name="min_rotations",
+                upper_field_name="max_rotations",
+            )
+        )
+
+
+register_dataset("color_cube_rotation", ColorCubeRotationDataset, ColorCubeRotationConfig, ColorCubeRotationCurriculum)
diff --git a/tests/test_color_cube_rotation.py b/tests/test_color_cube_rotation.py
index 87ecd8c6..e024e45a 100644
--- a/tests/test_color_cube_rotation.py
+++ b/tests/test_color_cube_rotation.py
@@ -1,7 +1,14 @@
 import pytest
 
 from reasoning_gym import create_dataset
-from reasoning_gym.cognition.color_cube_rotation import Color, ColorCubeRotationDataset, Cube, Side
+from reasoning_gym.cognition.color_cube_rotation import (
+    Color,
+    ColorCubeRotationConfig,
+    ColorCubeRotationCurriculum,
+    ColorCubeRotationDataset,
+    Cube,
+    Side,
+)
 
 
 def test_color_cube_rotation_generation():
@@ -82,3 +89,25 @@ def test_cube_rotations():
     assert cube.colors[Side.BACK] == original[Side.TOP]
     assert cube.colors[Side.RIGHT] == original[Side.RIGHT]  # Unchanged
     assert cube.colors[Side.LEFT] == original[Side.LEFT]  # Unchanged
+
+
+def test_color_cube_rotation_curriculum():
+    """Check the default rotations range and that level increments/decrements are reflected in the config."""
+    curriculum = ColorCubeRotationCurriculum()
+
+    base_value = {"size": 150, "seed": 1}
+
+    base_cfg: ColorCubeRotationConfig = curriculum.generate_configuration(base_value)
+    assert base_cfg.seed == 1
+    assert base_cfg.size == 150
+    assert base_cfg.min_rotations == 1 and base_cfg.max_rotations == 5
+
+    # test incrementing attribute levels
+    curriculum.increment_attr_level("rotations")
+    increased_cfg = curriculum.generate_configuration(base_value)
+    assert increased_cfg.min_rotations == 1 and increased_cfg.max_rotations == 10
+
+    # test decrementing attribute level
+    curriculum.decrement_attr_level("rotations")
+    partially_decreased_cfg = curriculum.generate_configuration(base_value)
+    assert partially_decreased_cfg.min_rotations == 1 and partially_decreased_cfg.max_rotations == 5