feat(env): Mahjong Puzzle Curriculum (#263)

* mahjong curriculum

* typo

* update levels
This commit is contained in:
Zafir Stojanovski 2025-03-05 22:28:02 +01:00 committed by GitHub
parent 8ecc723607
commit d0a42116fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 45 additions and 3 deletions

View file

@ -10,7 +10,7 @@ from .countdown import CountdownConfig, CountdownDataset
from .emoji_mystery import EmojiMysteryConfig, EmojiMysteryDataset from .emoji_mystery import EmojiMysteryConfig, EmojiMysteryDataset
from .futoshiki import FutoshikiConfig, FutoshikiDataset from .futoshiki import FutoshikiConfig, FutoshikiDataset
from .knight_swap import KnightSwapConfig, KnightSwapDataset from .knight_swap import KnightSwapConfig, KnightSwapDataset
from .mahjong import MahjongPuzzleConfig, MahjongPuzzleDataset from .mahjong import MahjongPuzzleConfig, MahjongPuzzleCurriculum, MahjongPuzzleDataset
from .maze import MazeConfig, MazeDataset from .maze import MazeConfig, MazeDataset
from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
from .n_queens import NQueensConfig, NQueensCurriculum, NQueensDataset from .n_queens import NQueensConfig, NQueensCurriculum, NQueensDataset
@ -48,4 +48,5 @@ __all__ = [
"KnightSwapDataset", "KnightSwapDataset",
"MahjongPuzzleConfig", "MahjongPuzzleConfig",
"MahjongPuzzleDataset", "MahjongPuzzleDataset",
"MahjongPuzzleCurriculum",
] ]

View file

@ -8,6 +8,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Optional from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """There are several letter cards, and the game rules are as follows: QUESTION_TEMPLATE = """There are several letter cards, and the game rules are as follows:
@ -38,7 +39,7 @@ class MahjongPuzzleConfig:
def validate(self): def validate(self):
"""Validate configuration parameters""" """Validate configuration parameters"""
assert 1 <= self.min_num_rounds, "min_num_rounds must be reater than 0" assert 1 <= self.min_num_rounds, "min_num_rounds must be greater than 0"
assert self.min_num_rounds <= self.max_num_rounds, "min_num_rounds must be less than max_num_rounds" assert self.min_num_rounds <= self.max_num_rounds, "min_num_rounds must be less than max_num_rounds"
@ -122,4 +123,23 @@ class MahjongPuzzleDataset(ProceduralDataset):
} }
class MahjongPuzzleCurriculum(BaseCurriculum):
    """Curriculum for the Mahjong puzzle dataset.

    Registers a single range attribute, ``num_rounds``, whose lower and upper
    bounds are written to the ``min_num_rounds`` and ``max_num_rounds``
    configuration fields.
    """

    def __init__(self):
        super().__init__(MahjongPuzzleCurriculum.__name__, MahjongPuzzleConfig)

        # Round-count ladder; the curriculum starts at level 0 (10 rounds).
        num_rounds_attribute = RangeAttributeDefinition(
            name="num_rounds",
            levels=[10, 50, 100, 500],
            default_level=0,
            description="Number of rounds in the game",
            attr_type=AttributeType.APPEND,
            min_value=1,
            lower_field_name="min_num_rounds",
            upper_field_name="max_num_rounds",
        )
        self._define_attributes(num_rounds_attribute)
register_dataset("mahjong_puzzle", MahjongPuzzleDataset, MahjongPuzzleConfig) register_dataset("mahjong_puzzle", MahjongPuzzleDataset, MahjongPuzzleConfig)

View file

@ -4,7 +4,7 @@ import string
import pytest import pytest
from reasoning_gym.games.mahjong import MahjongPuzzleConfig, MahjongPuzzleDataset from reasoning_gym.games.mahjong import MahjongPuzzleConfig, MahjongPuzzleCurriculum, MahjongPuzzleDataset
def test_mahjong_puzzle_config_validation(): def test_mahjong_puzzle_config_validation():
@ -95,3 +95,24 @@ def test_mahjong_puzzle_answer():
for c in string.ascii_lowercase: for c in string.ascii_lowercase:
assert dataset._check_peng(cards, new_card=c) == False assert dataset._check_peng(cards, new_card=c) == False
assert dataset._check_chi(cards, new_card=c) == False assert dataset._check_chi(cards, new_card=c) == False
def test_mahjong_puzzle_curriculum():
    """Check the default configuration and two increments of ``num_rounds``."""
    curriculum = MahjongPuzzleCurriculum()
    base_value = {"size": 150, "seed": 1}

    # At the default level both bounds equal the first level value.
    initial_cfg: MahjongPuzzleConfig = curriculum.generate_configuration(base_value)
    assert initial_cfg.seed == 1
    assert initial_cfg.size == 150
    assert (initial_cfg.min_num_rounds, initial_cfg.max_num_rounds) == (10, 10)

    # Each increment of num_rounds raises the upper bound to the next level
    # while the lower bound stays at 10.
    for expected_upper in (50, 100):
        curriculum.increment_attr_level("num_rounds")
        raised_cfg = curriculum.generate_configuration(base_value)
        assert (raised_cfg.min_num_rounds, raised_cfg.max_num_rounds) == (10, expected_upper)