diff --git a/tests/test_zebra.py b/tests/test_zebra.py index d233c438..054fad22 100644 --- a/tests/test_zebra.py +++ b/tests/test_zebra.py @@ -3,9 +3,18 @@ import pytest from reasoning_gym.logic.zebra_puzzles import ZebraConfig, ZebraDataset +def test_zebra_deterministic(): + """Test that dataset generates same items with same seed""" + config = ZebraConfig(seed=42, size=10, num_people=4, num_characteristics=4) + dataset1 = ZebraDataset(config) + dataset2 = ZebraDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + def test_zebra_puzzles(): """Test basic properties and solution of generated items""" - config = ZebraConfig(seed=42, size=10, num_people=4, num_characteristics=4) dataset = ZebraDataset(config)