diff --git a/reasoning_gym/games/__init__.py b/reasoning_gym/games/__init__.py
index f4d718b7..2e8e79ef 100644
--- a/reasoning_gym/games/__init__.py
+++ b/reasoning_gym/games/__init__.py
@@ -10,7 +10,7 @@ from .countdown import CountdownConfig, CountdownDataset
 from .emoji_mystery import EmojiMysteryConfig, EmojiMysteryDataset
 from .futoshiki import FutoshikiConfig, FutoshikiDataset
 from .knight_swap import KnightSwapConfig, KnightSwapDataset
-from .mahjong import MahjongPuzzleConfig, MahjongPuzzleDataset
+from .mahjong import MahjongPuzzleConfig, MahjongPuzzleCurriculum, MahjongPuzzleDataset
 from .maze import MazeConfig, MazeDataset
 from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
 from .n_queens import NQueensConfig, NQueensCurriculum, NQueensDataset
@@ -48,4 +48,5 @@ __all__ = [
     "KnightSwapDataset",
     "MahjongPuzzleConfig",
     "MahjongPuzzleDataset",
+    "MahjongPuzzleCurriculum",
 ]
diff --git a/reasoning_gym/games/mahjong.py b/reasoning_gym/games/mahjong.py
index 77071057..cbd653db 100644
--- a/reasoning_gym/games/mahjong.py
+++ b/reasoning_gym/games/mahjong.py
@@ -8,6 +8,7 @@ from dataclasses import dataclass
 from random import Random
 from typing import Optional
 
+from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
 from ..factory import ProceduralDataset, register_dataset
 
 QUESTION_TEMPLATE = """There are several letter cards, and the game rules are as follows:
@@ -38,7 +39,7 @@ class MahjongPuzzleConfig:
 
     def validate(self):
         """Validate configuration parameters"""
-        assert 1 <= self.min_num_rounds, "min_num_rounds must be reater than 0"
+        assert 1 <= self.min_num_rounds, "min_num_rounds must be greater than 0"
         assert self.min_num_rounds <= self.max_num_rounds, "min_num_rounds must be less than max_num_rounds"
 
 
@@ -122,4 +123,27 @@ class MahjongPuzzleDataset(ProceduralDataset):
         }
 
 
+class MahjongPuzzleCurriculum(BaseCurriculum):
+    """Curriculum for the Mahjong puzzle with a single ``num_rounds`` difficulty attribute."""
+
+    def __init__(self):
+        """Register the ``num_rounds`` range attribute against ``MahjongPuzzleConfig``."""
+        super().__init__(MahjongPuzzleCurriculum.__name__, MahjongPuzzleConfig)
+
+        # One range attribute drives both config bounds: its value is written
+        # to the min_num_rounds / max_num_rounds fields named below.
+        self._define_attributes(
+            RangeAttributeDefinition(
+                name="num_rounds",
+                levels=[10, 50, 100, 500],
+                default_level=0,
+                description="Number of rounds in the game",
+                attr_type=AttributeType.APPEND,
+                min_value=1,
+                lower_field_name="min_num_rounds",
+                upper_field_name="max_num_rounds",
+            )
+        )
+
+
 register_dataset("mahjong_puzzle", MahjongPuzzleDataset, MahjongPuzzleConfig)
diff --git a/tests/test_mahjong_puzzle.py b/tests/test_mahjong_puzzle.py
index f2f4f0b9..eaa0d9c0 100644
--- a/tests/test_mahjong_puzzle.py
+++ b/tests/test_mahjong_puzzle.py
@@ -4,7 +4,7 @@ import string
 
 import pytest
 
-from reasoning_gym.games.mahjong import MahjongPuzzleConfig, MahjongPuzzleDataset
+from reasoning_gym.games.mahjong import MahjongPuzzleConfig, MahjongPuzzleCurriculum, MahjongPuzzleDataset
 
 
 def test_mahjong_puzzle_config_validation():
@@ -95,3 +95,29 @@ def test_mahjong_puzzle_answer():
     for c in string.ascii_lowercase:
         assert dataset._check_peng(cards, new_card=c) == False
         assert dataset._check_chi(cards, new_card=c) == False
+
+
+def test_mahjong_puzzle_curriculum():
+    """Check the default config and the effect of raising the num_rounds level."""
+    curriculum = MahjongPuzzleCurriculum()
+
+    base_value = {"size": 150, "seed": 1}
+
+    base_cfg: MahjongPuzzleConfig = curriculum.generate_configuration(base_value)
+    assert base_cfg.seed == 1
+    assert base_cfg.size == 150
+    # One assertion per bound so a failure names the bound that is wrong.
+    assert base_cfg.min_num_rounds == 10
+    assert base_cfg.max_num_rounds == 10
+
+    # test incrementing attribute levels for num_rounds attribute
+    curriculum.increment_attr_level("num_rounds")
+    increased_cfg = curriculum.generate_configuration(base_value)
+    assert increased_cfg.min_num_rounds == 10
+    assert increased_cfg.max_num_rounds == 50
+
+    # test incrementing again
+    curriculum.increment_attr_level("num_rounds")
+    increased_cfg = curriculum.generate_configuration(base_value)
+    assert increased_cfg.min_num_rounds == 10
+    assert increased_cfg.max_num_rounds == 100