"""Rubik's cube solving tasks (``reasoning_gym/cognition/rubiks_cube.py``).

Requires the third-party ``magiccube`` package, which the accompanying
``pyproject.toml`` change adds to the project dependencies as
``magiccube==0.3.0`` (next to ``sympy>=1.13.1``).
"""

import random
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from magiccube.cube import Cube
from magiccube.solver.basic.basic_solver import BasicSolver

from ..factory import ProceduralDataset, register_dataset

# Matches terminal escape sequences so the coloured magiccube rendering can be
# reduced to plain text. Compiled once instead of on every call.
_ANSI_ESCAPE = re.compile(r"(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]")


@dataclass
class RubiksCubeConfig:
    """Configuration for RubiksCube task generation"""

    scramble_steps: int = 3  # Number of random steps from initial state
    cube_size: int = 3  # Default to a standard 3x3x3 cube

    def validate(self) -> None:
        """Validate configuration parameters.

        Raises:
            AssertionError: if ``cube_size`` is not in the range 2..6 or
                ``scramble_steps`` is not positive.
        """
        assert self.cube_size > 1, "cube_size must be greater than 1"
        assert self.cube_size < 7, "cube_size must be less than 7"
        assert self.scramble_steps > 0, "scramble_steps must be greater than 0"


class RubiksCubeDataset(ProceduralDataset):
    """Generates RubiksCube tasks"""

    def __init__(self, config: RubiksCubeConfig):
        # Templates are set before calling the base initialiser so they exist
        # regardless of what the base class does during construction.
        self._prompt_templates = [
            "You are given a {cube_size}x{cube_size}x{cube_size} Rubik's cube. "
            "It looks like this:\n\n{cube_render} \n\n"
            "Please provide a solution to solve this cube using Singmaster notation.",
            "You see a size {cube_size} Rubik's cube. "
            "It is arranged this:\n\n{cube_render} \n\n"
            "Please provide a solution to solve this cube.",
        ]
        super().__init__(config=config)

    def __getitem__(self, idx: int) -> dict:
        """Generate a single RubiksCube task

        NOTE(review): ``idx`` is not used and the prompt template is drawn
        from the global ``random`` state, so items are not reproducible per
        index. Confirm whether ``ProceduralDataset`` expects seeded,
        index-stable items.

        Returns:
            dict with keys:
                - question: str, the task description with the cube rendering
                - answer: None, indicating to use the dynamic evaluator
                - metadata: dict with generation parameters and example
                  solution (``example_correct_answer`` is ``None`` for cube
                  sizes other than 3, for which no solver is run)
        """
        cube = Cube(self.config.cube_size)
        scramble_moves = cube.scramble(num_steps=self.config.scramble_steps)

        # Capture the scrambled state *before* running the solver: the solver
        # is handed this same cube object and presumably rotates it in place
        # (TODO confirm against magiccube), so a later str(cube) would no
        # longer describe the puzzle shown in the question.
        cube_string = str(cube)
        cube_render = self.remove_ansi(cube_string)

        # The example solution is only produced for 3x3x3 cubes; default to
        # None so the metadata below is well defined for every size.
        actions_string = None
        if self.config.cube_size == 3:
            actions = BasicSolver(cube).solve()
            actions_string = " ".join(str(move) for move in actions)

        question = random.choice(self._prompt_templates).format(
            cube_size=self.config.cube_size,
            cube_render=cube_render,
        )

        return {
            "question": question,
            "answer": None,
            "metadata": {
                "cube_size": self.config.cube_size,
                "scramble_steps": self.config.scramble_steps,
                "scramble_moves": scramble_moves,
                "cube_string": cube_string,
                "example_correct_answer": actions_string,
            },
        }

    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
        """Determine if the solution provided solves the cube

        Args:
            answer: move sequence to evaluate, or None when no answer was given.
            entry: dataset item as returned by ``__getitem__``; only
                ``entry["metadata"]["cube_size"]`` and
                ``entry["metadata"]["scramble_moves"]`` are read.

        Returns:
            1.0 if applying ``answer`` to the scrambled cube solves it,
            0.01 if an answer was given but does not solve the cube (including
            answers that cannot be applied), 0.0 if ``answer`` is None.
        """
        if answer is None:
            return 0.0

        # Reconstruct the test cube
        metadata = entry["metadata"]
        eval_cube = Cube(metadata["cube_size"])
        eval_cube.rotate(metadata["scramble_moves"])

        # Test the solution. The answer is free-form text, so anything the
        # cube refuses to apply is scored as a failed attempt, not a crash.
        try:
            eval_cube.rotate(answer)
        except Exception:  # evaluation boundary: exact magiccube error type not relied upon
            return 0.01  # At least you tried

        if eval_cube.is_done():
            return 1.0
        return 0.01  # At least you tried

    def remove_ansi(self, line):
        """Remove terminal colors from magiccube rendering"""
        return _ANSI_ESCAPE.sub("", line)


# Register the dataset
register_dataset("RubiksCube", RubiksCubeDataset, RubiksCubeConfig)
"""Tests for the RubiksCube dataset (``tests/test_rubikscube.py``)."""

import pytest

from magiccube.cube import Cube
from reasoning_gym.cognition.rubiks_cube import RubiksCubeConfig, RubiksCubeDataset


def test_rubikscube_config_validation():
    """Test that invalid configs raise appropriate errors"""
    with pytest.raises(AssertionError):
        config = RubiksCubeConfig(cube_size=1)  # Too small
        config.validate()

    with pytest.raises(AssertionError):
        config = RubiksCubeConfig(cube_size=7)  # Too large
        config.validate()

    with pytest.raises(AssertionError):
        config = RubiksCubeConfig(scramble_steps=0)  # Don't give an unscrambled cube
        config.validate()


def test_rubikscube_items():
    """Test basic properties and solution of generated items"""
    config = RubiksCubeConfig(cube_size=3, scramble_steps=4)
    dataset = RubiksCubeDataset(config)

    top_level_keys = ("question", "answer", "metadata")
    metadata_keys = (
        "cube_size",
        "cube_string",
        "scramble_steps",
        "scramble_moves",
        "example_correct_answer",
    )

    for item in dataset:
        assert isinstance(item, dict)
        for key in top_level_keys:
            assert key in item

        # Check metadata contains required fields
        for key in metadata_keys:
            assert key in item["metadata"]

        # The stored example solution must solve the cube, a single arbitrary
        # move is expected to score as a failed attempt, and a missing answer
        # scores zero.
        example = item["metadata"]["example_correct_answer"]
        assert dataset.score_answer(answer=example, entry=item) == 1.0
        assert dataset.score_answer(answer="R", entry=item) == 0.01
        assert dataset.score_answer(answer=None, entry=item) == 0.0