diff --git a/reasoning_gym/cognition/rubiks_cube.py b/reasoning_gym/cognition/rubiks_cube.py index 1ad88b1c..622c9ecb 100644 --- a/reasoning_gym/cognition/rubiks_cube.py +++ b/reasoning_gym/cognition/rubiks_cube.py @@ -14,6 +14,8 @@ class RubiksCubeConfig: scramble_steps: int = 3 # Number of random steps from initial state cube_size: int = 3 # Default to a standard 3x3x3 cube + seed: Optional[int] = None + size: int = 500 def validate(self) -> None: """Validate configuration parameters""" @@ -30,7 +32,7 @@ class RubiksCubeDataset(ProceduralDataset): "You are given a {cube_size}x{cube_size}x{cube_size} Rubik's cube. It looks like this:\n\n{cube_render} \n\nPlease provide a solution to solve this cube using Singmaster notation.", "You see a size {cube_size} Rubik's cube. It is arranged this:\n\n{cube_render} \n\nPlease provide a solution to solve this cube.", ] - super().__init__(config=config) + super().__init__(config=config, seed=config.seed, size=config.size) def _generate_random_moves(self, rng: Random, cube: Cube, num_steps: int = 50, wide=None) -> List[CubeMove]: """Generate a list of random moves (but don't apply them). diff --git a/tests/test_rubiks_cube.py b/tests/test_rubiks_cube.py index e193ceef..96e7d41d 100644 --- a/tests/test_rubiks_cube.py +++ b/tests/test_rubiks_cube.py @@ -1,6 +1,5 @@ import pytest -from magiccube.cube import Cube from reasoning_gym.cognition.rubiks_cube import RubiksCubeConfig, RubiksCubeDataset @@ -17,11 +16,13 @@ def test_rubikscube_config_validation(): def test_rubikscube_deterministic(): """Test that dataset generates same items with same seed""" - config = RubiksCubeConfig(seed=42, size=15) + config = RubiksCubeConfig(seed=42, size=15) # Only check first 15 entries for speed dataset1 = RubiksCubeDataset(config) dataset2 = RubiksCubeDataset(config) + assert len(dataset1) == 15 + assert len(dataset2) == 15 - for i in range(15): # Only check first 15 entries for speed + for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] @@ -29,7 +30,8 @@ def test_rubikscube_items(): """Test basic properties and solution of generated items""" config = RubiksCubeConfig( cube_size=3, - scramble_steps=4 + scramble_steps=4, + size=100, ) dataset = RubiksCubeDataset(config) @@ -46,7 +48,6 @@ def test_rubikscube_items(): assert "scramble_moves" in item["metadata"] assert "example_correct_answer" in item["metadata"] - assert dataset.score_answer(answer=item['metadata']['example_correct_answer'], entry=item) == 1.0 - assert dataset.score_answer(answer='R', entry=item) == 0.01 + assert dataset.score_answer(answer=item["metadata"]["example_correct_answer"], entry=item) == 1.0 + assert dataset.score_answer(answer="R", entry=item) == 0.01 assert dataset.score_answer(answer=None, entry=item) == 0.0 -