diff --git a/reasoning_gym/algorithmic/game_of_life_halting.py b/reasoning_gym/algorithmic/game_of_life_halting.py index b3ade503..99b9661e 100644 --- a/reasoning_gym/algorithmic/game_of_life_halting.py +++ b/reasoning_gym/algorithmic/game_of_life_halting.py @@ -263,6 +263,9 @@ class GameOfLifeHaltingDataset(ProceduralDataset): }, ] + def __init__(self, config: GameOfLifeHaltingConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + def __getitem__(self, idx: int) -> dict: """Generate a single GameOfLife task diff --git a/tests/test_game_of_life.py b/tests/test_game_of_life.py index bc3c0c61..a0b8798e 100644 --- a/tests/test_game_of_life.py +++ b/tests/test_game_of_life.py @@ -23,11 +23,14 @@ def test_game_of_life_config_validation(): def test_game_of_life_deterministic(): """Test that dataset generates same items with same seed""" config = GameOfLifeConfig(seed=42, size=10) + config2 = GameOfLifeConfig(seed=43, size=10) dataset1 = GameOfLifeDataset(config) dataset2 = GameOfLifeDataset(config) + dataset3 = GameOfLifeDataset(config2) for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] + assert dataset1[i] != dataset3[i] def test_game_of_life_basic_properties(): diff --git a/tests/test_game_of_life_halting.py b/tests/test_game_of_life_halting.py index 503808e3..b09d0989 100644 --- a/tests/test_game_of_life_halting.py +++ b/tests/test_game_of_life_halting.py @@ -25,3 +25,16 @@ def test_game_of_life(): # # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 + + +def test_game_of_life_halting_deterministic(): + """Test that dataset generates same items with same seed""" + config = GameOfLifeHaltingConfig(seed=42, size=10) + config2 = GameOfLifeHaltingConfig(seed=43, size=10) + dataset1 = GameOfLifeHaltingDataset(config) + dataset2 = GameOfLifeHaltingDataset(config) + dataset3 = GameOfLifeHaltingDataset(config2) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + assert dataset1[i] != dataset3[i]