Patch summary (descriptive text only; the patch itself starts at the first "diff --git" line).

Adds a MiniSudokuCurriculum for the mini_sudoku dataset:
  * reasoning_gym/games/__init__.py: imports MiniSudokuCurriculum and lists it in __all__.
  * reasoning_gym/games/mini_sudoku.py: imports AttributeType, BaseCurriculum and
    RangeAttributeDefinition from ..coaching; adds a "difficulty" entry ({"empty": num_empty})
    to the returned item metadata; defines MiniSudokuCurriculum with a single range attribute
    "empty" (levels [4, 6, 8, 10], default_level=1, lower_field_name="min_empty",
    upper_field_name="max_empty"); passes the curriculum as a fourth argument to register_dataset.
  * tests/test_mini_sudoku.py: adds test_mini_sudoku_curriculum, which asserts min_empty/max_empty
    of the generated config at the default level (4, 6), after incrementing "empty" (4, 8) and
    after decrementing it again (4, 6).

NOTE(review): the source of this patch had lost its line breaks and indentation. The line layout
below was rebuilt from the hunk line counts; the indentation of context lines and the position
of blank context lines are assumed -- confirm against the target files before applying.

diff --git a/reasoning_gym/games/__init__.py b/reasoning_gym/games/__init__.py
index 15bd96d1..a17d0ea6 100644
--- a/reasoning_gym/games/__init__.py
+++ b/reasoning_gym/games/__init__.py
@@ -12,7 +12,7 @@ from .futoshiki import FutoshikiConfig, FutoshikiDataset
 from .knight_swap import KnightSwapConfig, KnightSwapDataset
 from .mahjong import MahjongPuzzleConfig, MahjongPuzzleCurriculum, MahjongPuzzleDataset
 from .maze import MazeConfig, MazeDataset
-from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
+from .mini_sudoku import MiniSudokuConfig, MiniSudokuCurriculum, MiniSudokuDataset
 from .n_queens import NQueensConfig, NQueensCurriculum, NQueensDataset
 from .puzzle24 import Puzzle24Config, Puzzle24Dataset
 from .rush_hour import RushHourConfig, RushHourDataset
@@ -30,6 +30,7 @@ __all__ = [
     "FutoshikiDataset",
     "MiniSudokuConfig",
     "MiniSudokuDataset",
+    "MiniSudokuCurriculum",
     "Puzzle24Config",
     "Puzzle24Dataset",
     "SudokuConfig",
diff --git a/reasoning_gym/games/mini_sudoku.py b/reasoning_gym/games/mini_sudoku.py
index 46df6aa2..16a7032f 100644
--- a/reasoning_gym/games/mini_sudoku.py
+++ b/reasoning_gym/games/mini_sudoku.py
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 from random import Random
 from typing import Any, Optional
 
+from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
 from ..factory import ProceduralDataset, register_dataset
 
 
@@ -191,7 +192,14 @@ class MiniSudokuDataset(ProceduralDataset):
         return {
             "question": question,
             "answer": solution_str,
-            "metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty},
+            "metadata": {
+                "puzzle": puzzle,
+                "solution": solved_board,
+                "num_empty": num_empty,
+                "difficulty": {
+                    "empty": num_empty,
+                },
+            },
         }
 
     def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@@ -232,4 +240,23 @@ class MiniSudokuDataset(ProceduralDataset):
         return reward
 
 
-register_dataset("mini_sudoku", MiniSudokuDataset, MiniSudokuConfig)
+class MiniSudokuCurriculum(BaseCurriculum):
+    def __init__(self):
+        super().__init__(MiniSudokuCurriculum.__name__, MiniSudokuConfig)
+
+        # Define attributes
+        self._define_attributes(
+            RangeAttributeDefinition(
+                name="empty",
+                levels=[4, 6, 8, 10],
+                default_level=1,
+                description="Number of empty cells in the puzzle",
+                attr_type=AttributeType.APPEND,
+                min_value=0,
+                lower_field_name="min_empty",
+                upper_field_name="max_empty",
+            )
+        )
+
+
+register_dataset("mini_sudoku", MiniSudokuDataset, MiniSudokuConfig, MiniSudokuCurriculum)
diff --git a/tests/test_mini_sudoku.py b/tests/test_mini_sudoku.py
index 606a544a..ac892f4f 100644
--- a/tests/test_mini_sudoku.py
+++ b/tests/test_mini_sudoku.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from reasoning_gym.games.mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
+from reasoning_gym.games.mini_sudoku import MiniSudokuConfig, MiniSudokuCurriculum, MiniSudokuDataset
 
 
 def test_mini_sudoku_config_validation():
@@ -120,3 +120,24 @@ def is_valid_solution(board: list[list[int]]) -> bool:
                 return False
 
     return True
+
+
+def test_mini_sudoku_curriculum():
+    curriculum = MiniSudokuCurriculum()
+
+    base_value = {"size": 150, "seed": 1}
+
+    base_cfg: MiniSudokuConfig = curriculum.generate_configuration(base_value)
+    assert base_cfg.seed == 1
+    assert base_cfg.size == 150
+    assert base_cfg.min_empty == 4 and base_cfg.max_empty == 6
+
+    # test incrementing attribute levels
+    curriculum.increment_attr_level("empty")
+    increased_cfg = curriculum.generate_configuration(base_value)
+    assert increased_cfg.min_empty == 4 and increased_cfg.max_empty == 8
+
+    # test decrementing attribute level for empty again
+    curriculum.decrement_attr_level("empty")
+    partially_decreased_cfg = curriculum.generate_configuration(base_value)
+    assert partially_decreased_cfg.min_empty == 4 and partially_decreased_cfg.max_empty == 6