class TsumegoCurriculum(BaseCurriculum):
    """Curriculum for the tsumego dataset, driven by board size.

    Declares a single range attribute, ``board_size``, whose levels are
    9, 10, 11 and 12. The attribute is bound to the ``min_board_size`` and
    ``max_board_size`` fields of ``TsumegoConfig``.
    """

    def __init__(self):
        super().__init__(TsumegoCurriculum.__name__, TsumegoConfig)

        # Build the attribute first so the registration call stays a one-liner.
        board_size = RangeAttributeDefinition(
            name="board_size",
            description="The size of the board",
            # Config fields that receive the lower / upper end of the range.
            lower_field_name="min_board_size",
            upper_field_name="max_board_size",
            # Level ladder; the curriculum starts on the first entry.
            levels=[9, 10, 11, 12],
            default_level=0,
            min_value=9,
            attr_type=AttributeType.APPEND,
        )
        self._define_attributes(board_size)


# Register the dataset
register_dataset("tsumego", TsumegoDataset, TsumegoConfig, TsumegoCurriculum)
def test_tsumego_curriculum():
    """Verify how TsumegoCurriculum levels translate into board-size bounds."""
    from reasoning_gym.games.tsumego import TsumegoCurriculum

    overrides = {"size": 150, "seed": 1}

    def bounds(cfg):
        # (lower, upper) board-size pair of a generated configuration.
        return cfg.min_board_size, cfg.max_board_size

    tsumego = TsumegoCurriculum()

    # Initial configuration: base values are passed through, board fixed at 9.
    initial = tsumego.generate_configuration(overrides)
    assert initial.seed == 1
    assert initial.size == 150
    assert bounds(initial) == (9, 9)
    assert initial.max_stones == 15  # Default value from TsumegoConfig

    # Per-attribute stepping: up, up again, then back down one level.
    attr_steps = (
        (tsumego.increment_attr_level, (9, 10)),
        (tsumego.increment_attr_level, (9, 11)),
        (tsumego.decrement_attr_level, (9, 10)),
    )
    for change_level, expected in attr_steps:
        change_level("board_size")
        stepped = tsumego.generate_configuration(overrides)
        assert bounds(stepped) == expected
        assert stepped.max_stones == 15  # Unchanged

    # Global stepping, starting again from a fresh curriculum.
    fresh = TsumegoCurriculum()
    assert fresh.get_attr_level("board_size") == 0

    global_steps = (
        (fresh.increment_global_level, 1, (9, 10)),
        (fresh.increment_global_level, 2, (9, 11)),
        (fresh.decrement_global_level, 1, (9, 10)),
    )
    for change_global, level, expected in global_steps:
        change_global()
        assert fresh.get_attr_level("board_size") == level
        assert bounds(fresh.generate_configuration(overrides)) == expected