diff --git a/reasoning_gym/cognition/__init__.py b/reasoning_gym/cognition/__init__.py index b4e23932..4cdb9164 100644 --- a/reasoning_gym/cognition/__init__.py +++ b/reasoning_gym/cognition/__init__.py @@ -8,7 +8,7 @@ from .modulo_grid import ModuloGridConfig, ModuloGridDataset from .needle_haystack import NeedleHaystackConfig, NeedleHaystackCurriculum, NeedleHaystackDataset from .number_sequences import NumberSequenceConfig, NumberSequenceCurriculum, NumberSequenceDataset from .rectangle_count import RectangleCountConfig, RectangleCountCurriculum, RectangleCountDataset -from .rubiks_cube import RubiksCubeConfig, RubiksCubeDataset +from .rubiks_cube import RubiksCubeConfig, RubiksCubeCurriculum, RubiksCubeDataset __all__ = [ "ColorCubeRotationConfig", @@ -21,6 +21,7 @@ __all__ = [ "NumberSequenceCurriculum", "RubiksCubeConfig", "RubiksCubeDataset", + "RubiksCubeCurriculum", "RectangleCountConfig", "RectangleCountCurriculum", "RectangleCountDataset", diff --git a/reasoning_gym/cognition/rubiks_cube.py b/reasoning_gym/cognition/rubiks_cube.py index 072bd0de..0f69135f 100644 --- a/reasoning_gym/cognition/rubiks_cube.py +++ b/reasoning_gym/cognition/rubiks_cube.py @@ -6,6 +6,7 @@ from typing import Any, Optional from magiccube.cube import Cube, CubeMove, CubeMoveType from magiccube.solver.basic.basic_solver import BasicSolver +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -13,7 +14,8 @@ from ..factory import ProceduralDataset, register_dataset class RubiksCubeConfig: """Configuration for RubiksCube task generation""" - scramble_steps: int = 3 # Number of random steps from initial state + min_scramble_steps: int = 3 # Minimum number of random steps from initial state + max_scramble_steps: int = 10 # Maximum number of random steps from initial state cube_size: int = 3 # Default to a standard 3x3x3 cube remove_ansi: bool = True seed: Optional[int] = None @@ -23,7 +25,10 @@ class RubiksCubeConfig: """Validate configuration parameters""" assert self.cube_size > 1, "cube_size must be greater than 1" assert self.cube_size < 7, "cube_size must be less than 7" - assert self.scramble_steps > 0, "scramble_steps must be greater than 0" + assert self.min_scramble_steps > 0, "min_scramble_steps must be greater than 0" + assert ( + self.max_scramble_steps >= self.min_scramble_steps + ), "max_scramble_steps must be greater than min_scramble_steps" class RubiksCubeDataset(ProceduralDataset): @@ -77,7 +82,8 @@ class RubiksCubeDataset(ProceduralDataset): rng = Random(self.seed + idx) cube = Cube(self.config.cube_size) - scramble_moves = self._generate_random_moves(rng, cube, num_steps=self.config.scramble_steps) + num_steps = rng.randint(self.config.min_scramble_steps, self.config.max_scramble_steps) + scramble_moves = self._generate_random_moves(rng, cube, num_steps=num_steps) cube.rotate(scramble_moves) # render cube @@ -100,9 +106,13 @@ class RubiksCubeDataset(ProceduralDataset): "answer": None, "metadata": { "cube_size": self.config.cube_size, - "scramble_steps": self.config.scramble_steps, + "scramble_steps": num_steps, "scramble_moves": " ".join([str(move) for move in scramble_moves]), "example_correct_answer": actions_string, + "difficulty": { + "scramble_steps": num_steps, + "cube_size": self.config.cube_size, + }, }, } @@ -154,5 +164,33 @@ class RubiksCubeDataset(ProceduralDataset): return move_str +class RubiksCubeCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(RubiksCubeCurriculum.__name__, RubiksCubeConfig) + + # Define attributes + self._define_attributes( + ScalarAttributeDefinition( + name="cube_size", + field_name="cube_size", + levels=[3, 4, 5, 6, 7], + default_level=0, + description="Board size", + attr_type=AttributeType.STATIC, + min_value=3, + ), + RangeAttributeDefinition( + name="scramble_steps", + levels=[3, 10, 50, 100, 500, 1000], + default_level=1, + description="Number of random moves to scramble the cube", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_scramble_steps", + upper_field_name="max_scramble_steps", + ), + ) + + # Register the dataset -register_dataset("rubiks_cube", RubiksCubeDataset, RubiksCubeConfig) +register_dataset("rubiks_cube", RubiksCubeDataset, RubiksCubeConfig, RubiksCubeCurriculum) diff --git a/tests/test_rubiks_cube.py b/tests/test_rubiks_cube.py index 3f5f39f9..50e87ee6 100644 --- a/tests/test_rubiks_cube.py +++ b/tests/test_rubiks_cube.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.cognition.rubiks_cube import RubiksCubeConfig, RubiksCubeDataset +from reasoning_gym.cognition.rubiks_cube import RubiksCubeConfig, RubiksCubeCurriculum, RubiksCubeDataset def test_rubikscube_config_validation(): @@ -10,7 +10,7 @@ def test_rubikscube_config_validation(): config.validate() with pytest.raises(AssertionError): - config = RubiksCubeConfig(scramble_steps=0) # Don't give an unscrambled cube + config = RubiksCubeConfig(max_scramble_steps=0) # Don't give an unscrambled cube config.validate() @@ -30,7 +30,8 @@ def test_rubikscube_items(): """Test basic properties and solution of generated items""" config = RubiksCubeConfig( cube_size=3, - scramble_steps=4, + min_scramble_steps=4, + max_scramble_steps=4, seed=42, size=100, ) @@ -60,3 +61,28 @@ def test_rubikscube_items(): if len(item["metadata"]["example_correct_answer"]) > 0: assert dataset.score_answer(answer="", entry=item) == 0.01 + + +def test_rubiks_cube_curriculum(): + curriculum = RubiksCubeCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: RubiksCubeConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.cube_size == 3 + assert base_cfg.min_scramble_steps == 3 and base_cfg.max_scramble_steps == 10 + + # test incrementing attribute levels for cube_size & scramble_stepsd attributes + curriculum.increment_attr_level("cube_size") + curriculum.increment_attr_level("scramble_steps") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.cube_size == 4 + assert increased_cfg.min_scramble_steps == 3 and increased_cfg.max_scramble_steps == 50 + + # test decrementing attribute level for cube_size again + curriculum.decrement_attr_level("cube_size") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.cube_size == 3 + assert partially_decreased_cfg.min_scramble_steps == 3 and partially_decreased_cfg.max_scramble_steps == 50