feat(env): Rubiks Cube Curriculum (#357)

This commit is contained in:
Zafir Stojanovski 2025-03-13 21:12:32 +01:00 committed by GitHub
parent d603d8b72b
commit 099ea88402
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 74 additions and 9 deletions

View file

@ -6,6 +6,7 @@ from typing import Any, Optional
from magiccube.cube import Cube, CubeMove, CubeMoveType
from magiccube.solver.basic.basic_solver import BasicSolver
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -13,7 +14,8 @@ from ..factory import ProceduralDataset, register_dataset
class RubiksCubeConfig:
"""Configuration for RubiksCube task generation"""
scramble_steps: int = 3 # Number of random steps from initial state
min_scramble_steps: int = 3 # Minimum number of random steps from initial state
max_scramble_steps: int = 10 # Maximum number of random steps from initial state
cube_size: int = 3 # Default to a standard 3x3x3 cube
remove_ansi: bool = True
seed: Optional[int] = None
@ -23,7 +25,10 @@ class RubiksCubeConfig:
"""Validate configuration parameters"""
assert self.cube_size > 1, "cube_size must be greater than 1"
assert self.cube_size < 7, "cube_size must be less than 7"
assert self.scramble_steps > 0, "scramble_steps must be greater than 0"
assert self.min_scramble_steps > 0, "min_scramble_steps must be greater than 0"
# Equal bounds are accepted (this pins the scramble length to a single value),
# so the message must say "greater than or equal to" to match the ">=" check.
assert (
    self.max_scramble_steps >= self.min_scramble_steps
), "max_scramble_steps must be greater than or equal to min_scramble_steps"
class RubiksCubeDataset(ProceduralDataset):
@ -77,7 +82,8 @@ class RubiksCubeDataset(ProceduralDataset):
rng = Random(self.seed + idx)
cube = Cube(self.config.cube_size)
scramble_moves = self._generate_random_moves(rng, cube, num_steps=self.config.scramble_steps)
num_steps = rng.randint(self.config.min_scramble_steps, self.config.max_scramble_steps)
scramble_moves = self._generate_random_moves(rng, cube, num_steps=num_steps)
cube.rotate(scramble_moves)
# render cube
@ -100,9 +106,13 @@ class RubiksCubeDataset(ProceduralDataset):
"answer": None,
"metadata": {
"cube_size": self.config.cube_size,
"scramble_steps": self.config.scramble_steps,
"scramble_steps": num_steps,
"scramble_moves": " ".join([str(move) for move in scramble_moves]),
"example_correct_answer": actions_string,
"difficulty": {
"scramble_steps": num_steps,
"cube_size": self.config.cube_size,
},
},
}
@ -154,5 +164,33 @@ class RubiksCubeDataset(ProceduralDataset):
return move_str
class RubiksCubeCurriculum(BaseCurriculum):
    """Curriculum for the Rubik's Cube task.

    Registers two attributes against ``RubiksCubeConfig``:

    * ``cube_size`` -- a scalar attribute bound to the ``cube_size`` field.
    * ``scramble_steps`` -- a range attribute bound to the
      ``min_scramble_steps`` / ``max_scramble_steps`` fields.
    """

    def __init__(self):
        """Create the curriculum and define its attributes."""
        super().__init__(RubiksCubeCurriculum.__name__, RubiksCubeConfig)

        # Define attributes
        self._define_attributes(
            ScalarAttributeDefinition(
                name="cube_size",
                field_name="cube_size",
                # RubiksCubeConfig validation asserts ``cube_size < 7``, so 6 is
                # the largest level that can yield a valid configuration.
                levels=[3, 4, 5, 6],
                default_level=0,
                description="Cube size",
                attr_type=AttributeType.STATIC,
                min_value=3,
            ),
            RangeAttributeDefinition(
                name="scramble_steps",
                levels=[3, 10, 50, 100, 500, 1000],
                default_level=1,
                description="Number of random moves to scramble the cube",
                attr_type=AttributeType.APPEND,
                min_value=1,
                lower_field_name="min_scramble_steps",
                upper_field_name="max_scramble_steps",
            ),
        )
# Register the dataset
register_dataset("rubiks_cube", RubiksCubeDataset, RubiksCubeConfig)
register_dataset("rubiks_cube", RubiksCubeDataset, RubiksCubeConfig, RubiksCubeCurriculum)