feat(env): Mahjong Puzzle Curriculum (#263)

* mahjong curriculum

* typo

* update levels
This commit is contained in:
Zafir Stojanovski 2025-03-05 22:28:02 +01:00 committed by GitHub
parent 19ca54da72
commit 3c544aba20
3 changed files with 45 additions and 3 deletions

View file

@ -4,7 +4,7 @@ import string
import pytest
from reasoning_gym.games.mahjong import MahjongPuzzleConfig, MahjongPuzzleDataset
from reasoning_gym.games.mahjong import MahjongPuzzleConfig, MahjongPuzzleCurriculum, MahjongPuzzleDataset
def test_mahjong_puzzle_config_validation():
@ -95,3 +95,24 @@ def test_mahjong_puzzle_answer():
for c in string.ascii_lowercase:
assert dataset._check_peng(cards, new_card=c) == False
assert dataset._check_chi(cards, new_card=c) == False
def test_mahjong_puzzle_curriculum():
curriculum = MahjongPuzzleCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: MahjongPuzzleConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_num_rounds == 10 and base_cfg.max_num_rounds == 10
# test incrementing attribute levels for num_rounds attribute
curriculum.increment_attr_level("num_rounds")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_num_rounds == 10 and increased_cfg.max_num_rounds == 50
# test incrementing again
curriculum.increment_attr_level("num_rounds")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_num_rounds == 10 and increased_cfg.max_num_rounds == 100