diff --git a/reasoning_gym/games/__init__.py b/reasoning_gym/games/__init__.py index c2d39fef..8ed7afbd 100644 --- a/reasoning_gym/games/__init__.py +++ b/reasoning_gym/games/__init__.py @@ -11,7 +11,7 @@ from .emoji_mystery import EmojiMysteryConfig, EmojiMysteryCurriculum, EmojiMyst from .futoshiki import FutoshikiConfig, FutoshikiDataset from .knight_swap import KnightSwapConfig, KnightSwapDataset from .mahjong import MahjongPuzzleConfig, MahjongPuzzleCurriculum, MahjongPuzzleDataset -from .maze import MazeConfig, MazeDataset +from .maze import MazeConfig, MazeCurriculum, MazeDataset from .mini_sudoku import MiniSudokuConfig, MiniSudokuCurriculum, MiniSudokuDataset from .n_queens import NQueensConfig, NQueensCurriculum, NQueensDataset from .puzzle24 import Puzzle24Config, Puzzle24Dataset @@ -44,6 +44,7 @@ __all__ = [ "RushHourDataset", "MazeConfig", "MazeDataset", + "MazeCurriculum", "HanoiConfig", "HanoiDataset", "NQueensDataset", diff --git a/reasoning_gym/games/maze.py b/reasoning_gym/games/maze.py index c656755e..39bf9e9e 100644 --- a/reasoning_gym/games/maze.py +++ b/reasoning_gym/games/maze.py @@ -3,6 +3,7 @@ import string from dataclasses import dataclass from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -110,6 +111,10 @@ class MazeDataset(ProceduralDataset): "goal": self.goal_char, "wall": self.wall_char, "path": self.path_char, + "difficulty": { + "dist": dist, + "grid_size": size, + }, }, } @@ -184,4 +189,33 @@ class MazeDataset(ProceduralDataset): return "\n".join("".join(row) for row in grid) -register_dataset("maze", MazeDataset, MazeConfig) +class MazeCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(MazeCurriculum.__name__, MazeConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="dist", + levels=[10, 25, 50, 100], + default_level=1, + description="Distance from start to goal", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_dist", + upper_field_name="max_dist", + ), + RangeAttributeDefinition( + name="grid_size", + levels=[10, 25, 50, 100], + default_level=1, + description="Size of the square grid", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_grid_size", + upper_field_name="max_grid_size", + ), + ) + + +register_dataset("maze", MazeDataset, MazeConfig, MazeCurriculum) diff --git a/tests/test_maze.py b/tests/test_maze.py index fda8ed14..ad2c8158 100644 --- a/tests/test_maze.py +++ b/tests/test_maze.py @@ -1,7 +1,7 @@ import pytest from reasoning_gym import create_dataset -from reasoning_gym.games.maze import MazeConfig, MazeDataset +from reasoning_gym.games.maze import MazeConfig, MazeCurriculum, MazeDataset def test_maze_config_validation(): @@ -125,3 +125,28 @@ def _bfs_distance(maze, start, goal, wall_char): queue.append((nr, nc, dist + 1)) return None # no path found + + +def test_maze_curriculum(): + curriculum = MazeCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: MazeConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_dist == 10 and base_cfg.max_dist == 25 + assert base_cfg.min_grid_size == 10 and base_cfg.max_grid_size == 25 + + # test incrementing attribute levels + curriculum.increment_attr_level("dist") + curriculum.increment_attr_level("grid_size") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_dist == 10 and increased_cfg.max_dist == 50 + assert increased_cfg.min_grid_size == 10 and increased_cfg.max_grid_size == 50 + + # test decrementing attribute level for dist again + curriculum.decrement_attr_level("dist") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_dist == 10 and partially_decreased_cfg.max_dist == 25 + assert partially_decreased_cfg.min_grid_size == 10 and partially_decreased_cfg.max_grid_size == 50