reasoning-gym/tests/test_mahjong_puzzle.py
Zafir Stojanovski d0a42116fb
feat(env): Mahjong Puzzle Curriculum (#263)
* mahjong curriculum

* typo

* update levels
2025-03-05 22:28:02 +01:00

118 lines
3.9 KiB
Python

"""Tests for Mahjong Puzzle questions generation"""
import string
import pytest
from reasoning_gym.games.mahjong import MahjongPuzzleConfig, MahjongPuzzleCurriculum, MahjongPuzzleDataset
def test_mahjong_puzzle_config_validation():
    """Each invalid round setting must raise AssertionError on build/validate."""
    invalid_settings = (
        {"min_num_rounds": -1},  # Negative not allowed
        {"min_num_rounds": 0},  # Zero not allowed
        {"min_num_rounds": 3, "max_num_rounds": 2},  # Min > Max
    )
    for kwargs in invalid_settings:
        # Construction and validate() both sit inside the raises-context,
        # so an AssertionError from either one satisfies the check.
        with pytest.raises(AssertionError):
            bad_config = MahjongPuzzleConfig(**kwargs)
            bad_config.validate()
def test_mahjong_puzzle_dataset_deterministic():
    """Two datasets built from one seeded config must agree item by item."""
    shared_config = MahjongPuzzleConfig(seed=42, size=10)
    first = MahjongPuzzleDataset(shared_config)
    second = MahjongPuzzleDataset(shared_config)

    # Compare by index, bounded by the length of the first dataset.
    for idx in range(len(first)):
        assert first[idx] == second[idx]
def test_mahjong_puzzle_dataset_items():
    """Check the structure and value ranges of every generated item."""
    min_rounds, max_rounds = 3, 5
    outcomes = ["Peng", "Chi", "Pass"]
    config = MahjongPuzzleConfig(
        min_num_rounds=min_rounds,
        max_num_rounds=max_rounds,
        size=10,
        seed=42,
    )
    dataset = MahjongPuzzleDataset(config)

    for idx in range(len(dataset)):
        item = dataset[idx]

        # Top-level item structure
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item

        # Metadata structure
        metadata = item["metadata"]
        for key in ("rounds", "solution"):
            assert key in metadata
        rounds = metadata["rounds"]
        solution = metadata["solution"]

        # The final solution is one of the three outcomes, and the number of
        # rounds stays within the bounds requested in the config above.
        assert solution in outcomes
        assert min_rounds <= len(rounds) <= max_rounds

        # Every round is a dict describing the added/removed card, a
        # 13-card hand, and its own outcome.
        for entry in rounds:
            assert isinstance(entry, dict)
            assert "add" in entry
            assert "remove" in entry
            assert len(entry["cards"]) == 13
            assert entry["result"] in outcomes
def test_mahjong_puzzle_dataset_iteration():
    """Iterating the dataset yields exactly ``config.size`` items, repeatably."""
    config = MahjongPuzzleConfig(size=5, seed=42)
    dataset = MahjongPuzzleDataset(config)

    # First pass: the number of yielded items matches the configured size.
    first_pass = list(dataset)
    assert len(first_pass) == config.size

    # Second pass: a fresh iteration reproduces the same items.
    second_pass = list(dataset)
    assert first_pass == second_pass
def test_mahjong_puzzle_answer():
    """Test the _check_peng and _check_chi helpers on hand-written hands."""
    config = MahjongPuzzleConfig(seed=42)
    dataset = MahjongPuzzleDataset(config)

    # Peng: the hand holds two "B" cards but only a single "A".
    cards = "ABBCCDDEEFFGH"
    assert dataset._check_peng(cards, new_card="B")  # B, B, B
    assert not dataset._check_peng(cards, new_card="A")

    # Chi: "C" extends the "A", "B" already in the hand.
    cards = "ABDGIKMOQSUWY"
    assert dataset._check_chi(cards, new_card="C")  # A, B, C
    assert not dataset._check_chi(cards, new_card="A")

    # Pass: neither helper should report a match.
    cards = "ACEGIKMOQSUWY"
    # NOTE(review): the hand is upper-case while the candidate cards are
    # lower-case, so presumably none of them can ever match -- confirm this
    # is the intended "Pass" coverage.
    for c in string.ascii_lowercase:
        assert not dataset._check_peng(cards, new_card=c)
        assert not dataset._check_chi(cards, new_card=c)
def test_mahjong_puzzle_curriculum():
    """Check the base curriculum config and two num_rounds level increments."""
    curriculum = MahjongPuzzleCurriculum()
    base_value = {"size": 150, "seed": 1}

    # Base level: seed/size are passed through and both round bounds are 10.
    base_cfg: MahjongPuzzleConfig = curriculum.generate_configuration(base_value)
    assert (base_cfg.seed, base_cfg.size) == (1, 150)
    assert (base_cfg.min_num_rounds, base_cfg.max_num_rounds) == (10, 10)

    # Each increment of the num_rounds attribute raises the upper bound
    # (first to 50, then to 100) while the lower bound stays at 10.
    for expected_bounds in ((10, 50), (10, 100)):
        curriculum.increment_attr_level("num_rounds")
        increased_cfg = curriculum.generate_configuration(base_value)
        actual_bounds = (increased_cfg.min_num_rounds, increased_cfg.max_num_rounds)
        assert actual_bounds == expected_bounds