mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Make sure the Rubik's Cube is scrambled deterministically based on seed + idx
This commit is contained in:
parent
9c6d097c38
commit
2653a28903
2 changed files with 50 additions and 14 deletions
|
|
@ -1,12 +1,13 @@
|
|||
from dataclasses import dataclass
|
||||
import random
|
||||
from random import Random
|
||||
import re
|
||||
from magiccube.cube import Cube
|
||||
from magiccube.cube import Cube, CubeMove, CubeMoveType
|
||||
from magiccube.solver.basic.basic_solver import BasicSolver
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class RubiksCubeConfig:
|
||||
"""Configuration for RubiksCube task generation"""
|
||||
|
|
@ -20,6 +21,7 @@ class RubiksCubeConfig:
|
|||
assert self.cube_size < 7, "cube_size must be less than 7"
|
||||
assert self.scramble_steps > 0, "scramble_steps must be greater than 0"
|
||||
|
||||
|
||||
class RubiksCubeDataset(ProceduralDataset):
|
||||
"""Generates RubiksCube tasks"""
|
||||
|
||||
|
|
@ -30,6 +32,35 @@ class RubiksCubeDataset(ProceduralDataset):
|
|||
]
|
||||
super().__init__(config=config)
|
||||
|
||||
def _generate_random_moves(self, rng: Random, cube: Cube, num_steps: int = 50, wide=None) -> List[CubeMove]:
    """Build a random sequence of moves for ``cube`` without applying it.

    Every random draw is taken from ``rng``, so the result depends only on
    the generator's state, ``cube.size``, ``num_steps`` and ``wide``.

    Args:
        rng: Random generator that supplies every draw.
        cube: Cube whose ``size`` decides the default for ``wide`` and the
            upper bound of the layer draw.
        num_steps: Number of moves to generate.
        wide: Whether wide moves may be produced. When None, wide moves
            are enabled only for cubes with size greater than 3.

    Returns:
        A list of ``num_steps`` moves.
    """
    # No explicit preference given: fall back to the size-based default.
    if wide is None:
        wide = cube.size > 3

    # The M, E and S move types are not part of the pool.
    face_turns = [
        CubeMoveType.L,
        CubeMoveType.R,
        CubeMoveType.D,
        CubeMoveType.U,
        CubeMoveType.B,
        CubeMoveType.F,
    ]
    flags = [False, True]

    sequence = []
    for _ in range(num_steps):
        # The draw order (face, reversed, wide, layer) fixes how the
        # generator is consumed, so it must stay exactly as is.
        face = rng.choice(face_turns)
        is_reversed = rng.choice(flags)
        if wide:
            is_wide = rng.choice(flags)
            layer = rng.randint(1, cube.size // 2)
        else:
            is_wide = False
            layer = 1
        sequence.append(CubeMove(face, is_reversed, is_wide, layer))

    return sequence
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
    """Generate a single RubiksCube task.

    The scramble and the prompt template are both drawn from a generator
    seeded with ``self.seed + idx``.

    Args:
        idx: Index of the item to generate.

    Returns:
        dict with keys:
            - question: prompt text containing the rendered scrambled cube
            - answer: None, indicating to use the dynamic evaluator
            - metadata: dict with generation parameters and example solution
              (``example_correct_answer`` is None unless ``cube_size`` is 3)
    """
    rng = Random(self.seed + idx)

    cube = Cube(self.config.cube_size)
    scramble_moves = self._generate_random_moves(rng, cube, num_steps=self.config.scramble_steps)
    cube.rotate(scramble_moves)
    # Capture the scrambled state as text before the cube is handed to the solver.
    cube_render = self.remove_ansi(str(cube))

    # An example solution is only computed for size-3 cubes. Default to None
    # so the metadata below never reads an unbound local for other sizes.
    actions_string = None
    if self.config.cube_size == 3:
        solver = BasicSolver(cube)
        actions = solver.solve()
        actions_string = " ".join(str(move) for move in actions)

    return {
        "question": rng.choice(self._prompt_templates).format(
            cube_size=self.config.cube_size, cube_render=cube_render
        ),
        "answer": None,
        "metadata": {
            "cube_size": self.config.cube_size,
            "scramble_steps": self.config.scramble_steps,
            "scramble_moves": " ".join(str(move) for move in scramble_moves),
            "example_correct_answer": actions_string,
        },
    }
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
|
||||
""" Determine if the solution provided solves the cube """
|
||||
"""Determine if the solution provided solves the cube"""
|
||||
|
||||
reward = 0.0
|
||||
if answer is not None:
|
||||
|
|
@ -71,7 +106,7 @@ class RubiksCubeDataset(ProceduralDataset):
|
|||
# Reconstruct the test cube
|
||||
eval_cube = Cube(entry["metadata"]["cube_size"])
|
||||
eval_cube.rotate(entry["metadata"]["scramble_moves"])
|
||||
|
||||
|
||||
# Test the solution
|
||||
eval_cube.rotate(answer)
|
||||
solved = eval_cube.is_done()
|
||||
|
|
@ -79,15 +114,15 @@ class RubiksCubeDataset(ProceduralDataset):
|
|||
if solved:
|
||||
reward = 1.0
|
||||
else:
|
||||
reward = 0.01 # At least you tried
|
||||
reward = 0.01 # At least you tried
|
||||
|
||||
return reward
|
||||
|
||||
def remove_ansi(self, line):
    """Strip terminal colour escape sequences from a magiccube rendering.

    Args:
        line: Text that may contain ANSI escape sequences.

    Returns:
        ``line`` with every matched escape sequence removed.
    """
    escape_pattern = r"(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]"
    return re.sub(escape_pattern, "", line)
|
||||
|
||||
|
||||
# Register the dataset
|
||||
register_dataset("RubiksCube", RubiksCubeDataset, RubiksCubeConfig)
|
||||
register_dataset("rubiks_cube", RubiksCubeDataset, RubiksCubeConfig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue