mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
118 lines
3.9 KiB
Python
118 lines
3.9 KiB
Python
"""Tests for Mahjong Puzzle questions generation"""
|
|
|
|
import string
|
|
|
|
import pytest
|
|
|
|
from reasoning_gym.games.mahjong import MahjongPuzzleConfig, MahjongPuzzleCurriculum, MahjongPuzzleDataset
|
|
|
|
|
|
def test_mahjong_puzzle_config_validation():
    """Invalid round bounds must be rejected with an AssertionError."""
    invalid_kwargs = [
        {"min_num_rounds": -1},  # negative minimum is not allowed
        {"min_num_rounds": 0},  # zero minimum is not allowed
        {"min_num_rounds": 3, "max_num_rounds": 2},  # minimum above maximum
    ]

    for kwargs in invalid_kwargs:
        # Construction and validation both sit inside the raises-block, so the
        # error may surface from either step.
        with pytest.raises(AssertionError):
            MahjongPuzzleConfig(**kwargs).validate()
|
|
|
|
|
|
def test_mahjong_puzzle_dataset_deterministic():
    """Two datasets built from the same seeded config yield identical items."""
    config = MahjongPuzzleConfig(seed=42, size=10)
    first = MahjongPuzzleDataset(config)
    second = MahjongPuzzleDataset(config)

    # Compare by index so __getitem__ (not __iter__) is what gets exercised.
    total = len(first)
    for idx in range(total):
        assert first[idx] == second[idx]
|
|
|
|
|
|
def test_mahjong_puzzle_dataset_items():
    """Every generated item has the expected shape and in-range metadata."""
    config = MahjongPuzzleConfig(min_num_rounds=3, max_num_rounds=5, size=10, seed=42)
    dataset = MahjongPuzzleDataset(config)
    valid_outcomes = ["Peng", "Chi", "Pass"]

    for idx in range(len(dataset)):
        item = dataset[idx]

        # Top-level structure of a dataset entry
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item

        # Metadata carries the round history and the final solution
        meta = item["metadata"]
        for key in ("rounds", "solution"):
            assert key in meta

        rounds = meta["rounds"]
        solution = meta["solution"]

        assert solution in valid_outcomes
        # Round count must respect the configured [min, max] bounds (3..5 here)
        assert 3 <= len(rounds) <= 5

        for rnd in rounds:
            assert isinstance(rnd, dict)
            assert "add" in rnd
            assert "remove" in rnd
            # The hand is checked to hold 13 cards after every round
            assert len(rnd["cards"]) == 13
            assert rnd["result"] in valid_outcomes
|
|
|
|
|
|
def test_mahjong_puzzle_dataset_iteration():
    """Iterating the dataset yields exactly ``size`` items, reproducibly."""
    config = MahjongPuzzleConfig(size=5, seed=42)
    dataset = MahjongPuzzleDataset(config)

    first_pass = list(dataset)
    assert len(first_pass) == config.size

    # A second full pass over the same dataset object must give the same items
    second_pass = list(dataset)
    assert first_pass == second_pass
|
|
|
|
|
|
def test_mahjong_puzzle_answer():
    """Test the Peng/Chi detection helpers ``_check_peng`` and ``_check_chi``."""
    config = MahjongPuzzleConfig(seed=42)
    dataset = MahjongPuzzleDataset(config)

    # Peng: the hand already holds two copies of the incoming card.
    cards = "ABBCCDDEEFFGH"
    assert dataset._check_peng(cards, new_card="B")  # B, B, B
    assert not dataset._check_peng(cards, new_card="A")  # only one A in hand

    # Chi: the incoming card completes a run of three consecutive letters.
    cards = "ABDGIKMOQSUWY"
    assert dataset._check_chi(cards, new_card="C")  # A, B, C
    assert not dataset._check_chi(cards, new_card="A")  # A, B present but no C

    # Pass: the hand holds every other letter (A, C, E, ..., Y), one copy each.
    # Re-drawing a letter already in hand gives only two copies (no Peng), and
    # both of its alphabet neighbours are absent (no Chi). These are real
    # uppercase cards, unlike the lowercase probe below, which never matches.
    cards = "ACEGIKMOQSUWY"
    for c in cards:
        assert not dataset._check_peng(cards, new_card=c)
        assert not dataset._check_chi(cards, new_card=c)

    # Lowercase letters fall outside the uppercase letters used for the hands
    # above, so they should never combine with anything in the hand.
    for c in string.ascii_lowercase:
        assert not dataset._check_peng(cards, new_card=c)
        assert not dataset._check_chi(cards, new_card=c)
|
|
|
|
|
|
def test_mahjong_puzzle_curriculum():
    """Each num_rounds level raises max_num_rounds while min stays at 10."""
    curriculum = MahjongPuzzleCurriculum()
    base_value = {"size": 150, "seed": 1}

    # Level 0: seed and size pass through unchanged; the round range is [10, 10].
    base_cfg: MahjongPuzzleConfig = curriculum.generate_configuration(base_value)
    assert base_cfg.seed == 1
    assert base_cfg.size == 150
    assert (base_cfg.min_num_rounds, base_cfg.max_num_rounds) == (10, 10)

    # Step the num_rounds attribute up one level at a time and check the new range.
    for expected_max in (50, 100):
        curriculum.increment_attr_level("num_rounds")
        stepped_cfg = curriculum.generate_configuration(base_value)
        assert (stepped_cfg.min_num_rounds, stepped_cfg.max_num_rounds) == (10, expected_max)
|