diff --git a/reasoning_gym/games/__init__.py b/reasoning_gym/games/__init__.py index 46869a31..69dc438a 100644 --- a/reasoning_gym/games/__init__.py +++ b/reasoning_gym/games/__init__.py @@ -9,7 +9,7 @@ Game tasks for training reasoning capabilities: from .boxnet import BoxnetConfig, BoxnetCurriculum, BoxnetDataset from .countdown import CountdownConfig, CountdownDataset from .emoji_mystery import EmojiMysteryConfig, EmojiMysteryCurriculum, EmojiMysteryDataset -from .futoshiki import FutoshikiConfig, FutoshikiDataset +from .futoshiki import FutoshikiConfig, FutoshikiCurriculum, FutoshikiDataset from .knight_swap import KnightSwapConfig, KnightSwapDataset from .mahjong import MahjongPuzzleConfig, MahjongPuzzleCurriculum, MahjongPuzzleDataset from .maze import MazeConfig, MazeCurriculum, MazeDataset @@ -18,7 +18,7 @@ from .n_queens import NQueensConfig, NQueensCurriculum, NQueensDataset from .puzzle24 import Puzzle24Config, Puzzle24Dataset from .rush_hour import RushHourConfig, RushHourDataset from .sokoban import SokobanConfig, SokobanCurriculum, SokobanDataset -from .sudoku import SudokuConfig, SudokuDataset +from .sudoku import SudokuConfig, SudokuCurriculum, SudokuDataset from .tower_of_hanoi import HanoiConfig, HanoiDataset from .tsumego import TsumegoConfig, TsumegoCurriculum, TsumegoDataset @@ -40,6 +40,7 @@ __all__ = [ "Puzzle24Config", "Puzzle24Dataset", "SudokuConfig", + "SudokuCurriculum", "SudokuDataset", "SokobanConfig", "SokobanCurriculum", diff --git a/reasoning_gym/games/sudoku.py b/reasoning_gym/games/sudoku.py index d7ecf8d2..83016cb7 100644 --- a/reasoning_gym/games/sudoku.py +++ b/reasoning_gym/games/sudoku.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -210,7 +211,14 @@ class SudokuDataset(ProceduralDataset): return { "question": question, "answer": solution_str, - "metadata": {"puzzle": puzzle, "solution": solved_board, "num_empty": num_empty}, + "metadata": { + "puzzle": puzzle, + "solution": solved_board, + "num_empty": num_empty, + "difficulty": { + "num_empty": num_empty, + }, + }, } def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: @@ -251,4 +259,23 @@ class SudokuDataset(ProceduralDataset): return reward -register_dataset("sudoku", SudokuDataset, SudokuConfig) +class SudokuCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(SudokuCurriculum.__name__, SudokuConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="empty", + levels=[20, 30, 40, 50], + default_level=1, + description="Number of empty cells in the puzzle", + attr_type=AttributeType.APPEND, + min_value=0, + lower_field_name="min_empty", + upper_field_name="max_empty", + ) + ) + + +register_dataset("sudoku", SudokuDataset, SudokuConfig, SudokuCurriculum) diff --git a/tests/test_sudoku.py b/tests/test_sudoku.py index ce41ce3a..27cc9a93 100644 --- a/tests/test_sudoku.py +++ b/tests/test_sudoku.py @@ -2,7 +2,7 @@ import pytest -from reasoning_gym.games.sudoku import SudokuConfig, SudokuDataset +from reasoning_gym.games.sudoku import SudokuConfig, SudokuCurriculum, SudokuDataset def test_sudoku_config_validation(): @@ -120,3 +120,24 @@ def is_valid_solution(board: list[list[int]]) -> bool: return False return True + + +def test_sudoku_curriculum(): + curriculum = SudokuCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: SudokuConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_empty == 20 and base_cfg.max_empty == 30 + + # test incrementing attribute levels + curriculum.increment_attr_level("empty") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_empty == 20 and increased_cfg.max_empty == 40 + + # test decrementing attribute level for empty again + curriculum.decrement_attr_level("empty") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_empty == 20 and partially_decreased_cfg.max_empty == 30