mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
feat(env): Rubiks Cube Curriculum (#357)
This commit is contained in:
parent
d603d8b72b
commit
099ea88402
3 changed files with 74 additions and 9 deletions
|
|
@ -6,6 +6,7 @@ from typing import Any, Optional
|
|||
from magiccube.cube import Cube, CubeMove, CubeMoveType
|
||||
from magiccube.solver.basic.basic_solver import BasicSolver
|
||||
|
||||
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
|
|
@ -13,7 +14,8 @@ from ..factory import ProceduralDataset, register_dataset
|
|||
class RubiksCubeConfig:
|
||||
"""Configuration for RubiksCube task generation"""
|
||||
|
||||
scramble_steps: int = 3 # Number of random steps from initial state
|
||||
min_scramble_steps: int = 3 # Minimum number of random steps from initial state
|
||||
max_scramble_steps: int = 10 # Maximum number of random steps from initial state
|
||||
cube_size: int = 3 # Default to a standard 3x3x3 cube
|
||||
remove_ansi: bool = True
|
||||
seed: Optional[int] = None
|
||||
|
|
@ -23,7 +25,10 @@ class RubiksCubeConfig:
|
|||
"""Validate configuration parameters"""
|
||||
assert self.cube_size > 1, "cube_size must be greater than 1"
|
||||
assert self.cube_size < 7, "cube_size must be less than 7"
|
||||
assert self.scramble_steps > 0, "scramble_steps must be greater than 0"
|
||||
assert self.min_scramble_steps > 0, "min_scramble_steps must be greater than 0"
|
||||
assert (
|
||||
self.max_scramble_steps >= self.min_scramble_steps
|
||||
), "max_scramble_steps must be greater than min_scramble_steps"
|
||||
|
||||
|
||||
class RubiksCubeDataset(ProceduralDataset):
|
||||
|
|
@ -77,7 +82,8 @@ class RubiksCubeDataset(ProceduralDataset):
|
|||
rng = Random(self.seed + idx)
|
||||
|
||||
cube = Cube(self.config.cube_size)
|
||||
scramble_moves = self._generate_random_moves(rng, cube, num_steps=self.config.scramble_steps)
|
||||
num_steps = rng.randint(self.config.min_scramble_steps, self.config.max_scramble_steps)
|
||||
scramble_moves = self._generate_random_moves(rng, cube, num_steps=num_steps)
|
||||
cube.rotate(scramble_moves)
|
||||
|
||||
# render cube
|
||||
|
|
@ -100,9 +106,13 @@ class RubiksCubeDataset(ProceduralDataset):
|
|||
"answer": None,
|
||||
"metadata": {
|
||||
"cube_size": self.config.cube_size,
|
||||
"scramble_steps": self.config.scramble_steps,
|
||||
"scramble_steps": num_steps,
|
||||
"scramble_moves": " ".join([str(move) for move in scramble_moves]),
|
||||
"example_correct_answer": actions_string,
|
||||
"difficulty": {
|
||||
"scramble_steps": num_steps,
|
||||
"cube_size": self.config.cube_size,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -154,5 +164,33 @@ class RubiksCubeDataset(ProceduralDataset):
|
|||
return move_str
|
||||
|
||||
|
||||
class RubiksCubeCurriculum(BaseCurriculum):
    """Curriculum describing the difficulty attributes of the Rubik's cube task.

    Two attributes are declared over ``RubiksCubeConfig``:

    * ``cube_size`` -- a scalar attribute written to the ``cube_size`` field.
    * ``scramble_steps`` -- a range attribute written to the
      ``min_scramble_steps`` / ``max_scramble_steps`` fields.
    """

    def __init__(self):
        """Register the curriculum under its class name and define its attributes."""
        super().__init__(RubiksCubeCurriculum.__name__, RubiksCubeConfig)

        # Define attributes
        self._define_attributes(
            ScalarAttributeDefinition(
                name="cube_size",
                field_name="cube_size",
                # The config validation asserts ``cube_size < 7``, so 6 is the
                # largest level that can yield a valid configuration; a level
                # of 7 would always fail that assertion.
                levels=[3, 4, 5, 6],
                default_level=0,
                description="Cube size",
                attr_type=AttributeType.STATIC,
                min_value=3,
            ),
            RangeAttributeDefinition(
                name="scramble_steps",
                # Level values feed the min/max scramble-step fields named below.
                levels=[3, 10, 50, 100, 500, 1000],
                default_level=1,
                description="Number of random moves to scramble the cube",
                attr_type=AttributeType.APPEND,
                min_value=1,
                lower_field_name="min_scramble_steps",
                upper_field_name="max_scramble_steps",
            ),
        )
|
||||
|
||||
|
||||
# Register the dataset
|
||||
register_dataset("rubiks_cube", RubiksCubeDataset, RubiksCubeConfig)
|
||||
register_dataset("rubiks_cube", RubiksCubeDataset, RubiksCubeConfig, RubiksCubeCurriculum)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue