added tsumego curric (#323)

This commit is contained in:
joesharratt1229 2025-03-11 00:19:55 +01:00 committed by GitHub
parent 9aeef4ebb0
commit e9944149bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 80 additions and 2 deletions

View file

@ -19,7 +19,7 @@ from .rush_hour import RushHourConfig, RushHourDataset
from .sokoban import SokobanConfig, SokobanDataset
from .sudoku import SudokuConfig, SudokuDataset
from .tower_of_hanoi import HanoiConfig, HanoiDataset
from .tsumego import TsumegoConfig, TsumegoDataset
from .tsumego import TsumegoConfig, TsumegoCurriculum, TsumegoDataset
__all__ = [
"CountdownConfig",
@ -49,6 +49,7 @@ __all__ = [
"NQueensConfig",
"NQueensCurriculum",
"TsumegoConfig",
"TsumegoCurriculum",
"TsumegoDataset",
"KnightSwapConfig",
"KnightSwapDataset",

View file

@ -21,6 +21,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
# Added constant to avoid repetition of adjacent directions
@ -290,5 +291,22 @@ class TsumegoDataset(ProceduralDataset):
return reward
class TsumegoCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(TsumegoCurriculum.__name__, TsumegoConfig)
self._define_attributes(
RangeAttributeDefinition(
name="board_size",
levels=[9, 10, 11, 12],
default_level=0,
min_value=9,
attr_type=AttributeType.APPEND,
lower_field_name="min_board_size",
upper_field_name="max_board_size",
description="The size of the board",
)
)
# Register the dataset
register_dataset("tsumego", TsumegoDataset, TsumegoConfig)
register_dataset("tsumego", TsumegoDataset, TsumegoConfig, TsumegoCurriculum)

View file

@ -260,3 +260,62 @@ def test_capture_verification():
final_white = sum(row.count("O") for row in board_after)
assert final_white < initial_white, "The solution move should capture at least one opponent stone."
def test_tsumego_curriculum():
"""Test the TsumegoCurriculum functionality"""
from reasoning_gym.games.tsumego import TsumegoCurriculum
curriculum = TsumegoCurriculum()
base_value = {"size": 150, "seed": 1}
# Test initial configuration
base_cfg = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_board_size == 9 and base_cfg.max_board_size == 9
assert base_cfg.max_stones == 15 # Default value from TsumegoConfig
# Test incrementing attribute level
curriculum.increment_attr_level("board_size")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_board_size == 9 and increased_cfg.max_board_size == 10
assert increased_cfg.max_stones == 15 # Unchanged
# Test incrementing attribute level again
curriculum.increment_attr_level("board_size")
increased_cfg_2 = curriculum.generate_configuration(base_value)
assert increased_cfg_2.min_board_size == 9 and increased_cfg_2.max_board_size == 11
assert increased_cfg_2.max_stones == 15 # Unchanged
# Test decrementing attribute level
curriculum.decrement_attr_level("board_size")
decreased_cfg = curriculum.generate_configuration(base_value)
assert decreased_cfg.min_board_size == 9 and decreased_cfg.max_board_size == 10
assert decreased_cfg.max_stones == 15 # Unchanged
# Test global level adjustments
curriculum = TsumegoCurriculum() # Reset curriculum
assert curriculum.get_attr_level("board_size") == 0
# Increase global level
curriculum.increment_global_level()
assert curriculum.get_attr_level("board_size") == 1
global_level_cfg = curriculum.generate_configuration(base_value)
assert global_level_cfg.min_board_size == 9 and global_level_cfg.max_board_size == 10
# Increase global level again
curriculum.increment_global_level()
assert curriculum.get_attr_level("board_size") == 2
global_level_cfg_2 = curriculum.generate_configuration(base_value)
assert global_level_cfg_2.min_board_size == 9 and global_level_cfg_2.max_board_size == 11
# Decrease global level
curriculum.decrement_global_level()
assert curriculum.get_attr_level("board_size") == 1
global_level_cfg_3 = curriculum.generate_configuration(base_value)
assert global_level_cfg_3.min_board_size == 9 and global_level_cfg_3.max_board_size == 10