mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
feat: Add color cube rotation dataset with cube rotation logic
This commit is contained in:
parent
45da09afe8
commit
9fddc73842
1 changed files with 185 additions and 0 deletions
185
reasoning_gym/cognition/color_cube_rotation.py
Normal file
185
reasoning_gym/cognition/color_cube_rotation.py
Normal file
|
|
@ -0,0 +1,185 @@
|
||||||
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from ..dataset import ProceduralDataset
|
||||||
|
|
||||||
|
|
||||||
|
class Color(StrEnum):
|
||||||
|
RED = "red"
|
||||||
|
GREEN = "green"
|
||||||
|
BLUE = "blue"
|
||||||
|
YELLOW = "yellow"
|
||||||
|
WHITE = "white"
|
||||||
|
ORANGE = "orange"
|
||||||
|
|
||||||
|
|
||||||
|
class Side(StrEnum):
|
||||||
|
TOP = "top"
|
||||||
|
RIGHT = "right"
|
||||||
|
FRONT = "front"
|
||||||
|
LEFT = "left"
|
||||||
|
BACK = "back"
|
||||||
|
BOTTOM = "bottom"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Cube:
|
||||||
|
"""Represents a cube with colored sides"""
|
||||||
|
colors: Dict[Side, Color]
|
||||||
|
|
||||||
|
def rotate_front_to_top(self) -> None:
|
||||||
|
"""Rotate cube so front face becomes top"""
|
||||||
|
old = self.colors.copy()
|
||||||
|
self.colors[Side.TOP] = old[Side.FRONT]
|
||||||
|
self.colors[Side.FRONT] = old[Side.BOTTOM]
|
||||||
|
self.colors[Side.BOTTOM] = old[Side.BACK]
|
||||||
|
self.colors[Side.BACK] = old[Side.TOP]
|
||||||
|
# Right and left stay in place
|
||||||
|
|
||||||
|
def rotate_right_to_top(self) -> None:
|
||||||
|
"""Rotate cube so right face becomes top"""
|
||||||
|
old = self.colors.copy()
|
||||||
|
self.colors[Side.TOP] = old[Side.RIGHT]
|
||||||
|
self.colors[Side.RIGHT] = old[Side.BOTTOM]
|
||||||
|
self.colors[Side.BOTTOM] = old[Side.LEFT]
|
||||||
|
self.colors[Side.LEFT] = old[Side.TOP]
|
||||||
|
# Front and back stay in place
|
||||||
|
|
||||||
|
def rotate_back_to_top(self) -> None:
|
||||||
|
"""Rotate cube so back face becomes top"""
|
||||||
|
old = self.colors.copy()
|
||||||
|
self.colors[Side.TOP] = old[Side.BACK]
|
||||||
|
self.colors[Side.BACK] = old[Side.BOTTOM]
|
||||||
|
self.colors[Side.BOTTOM] = old[Side.FRONT]
|
||||||
|
self.colors[Side.FRONT] = old[Side.TOP]
|
||||||
|
# Right and left stay in place
|
||||||
|
|
||||||
|
def rotate_left_to_top(self) -> None:
|
||||||
|
"""Rotate cube so left face becomes top"""
|
||||||
|
old = self.colors.copy()
|
||||||
|
self.colors[Side.TOP] = old[Side.LEFT]
|
||||||
|
self.colors[Side.LEFT] = old[Side.BOTTOM]
|
||||||
|
self.colors[Side.BOTTOM] = old[Side.RIGHT]
|
||||||
|
self.colors[Side.RIGHT] = old[Side.TOP]
|
||||||
|
# Front and back stay in place
|
||||||
|
|
||||||
|
def rotate_bottom_to_top(self) -> None:
|
||||||
|
"""Rotate cube so bottom face becomes top"""
|
||||||
|
old = self.colors.copy()
|
||||||
|
self.colors[Side.TOP] = old[Side.BOTTOM]
|
||||||
|
self.colors[Side.BOTTOM] = old[Side.TOP]
|
||||||
|
self.colors[Side.FRONT] = old[Side.BACK]
|
||||||
|
self.colors[Side.BACK] = old[Side.FRONT]
|
||||||
|
# Right and left stay in place
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ColorCubeRotationConfig:
|
||||||
|
"""Configuration for color cube rotation task generation"""
|
||||||
|
min_rotations: int = 1
|
||||||
|
max_rotations: int = 3
|
||||||
|
seed: Optional[int] = None
|
||||||
|
size: int = 500
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
"""Validate configuration parameters"""
|
||||||
|
assert self.min_rotations > 0, "min_rotations must be positive"
|
||||||
|
assert self.max_rotations >= self.min_rotations, "max_rotations must be >= min_rotations"
|
||||||
|
|
||||||
|
|
||||||
|
class ColorCubeRotationDataset(ProceduralDataset):
|
||||||
|
"""Generates color cube rotation reasoning tasks"""
|
||||||
|
|
||||||
|
def __init__(self, config: ColorCubeRotationConfig):
|
||||||
|
self.config = config
|
||||||
|
self.config.validate()
|
||||||
|
super().__init__(seed=config.seed, size=config.size)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict:
|
||||||
|
rng = random.Random(self.seed + idx)
|
||||||
|
|
||||||
|
# Generate initial cube state
|
||||||
|
cube = self._generate_cube(rng)
|
||||||
|
initial_state = cube.colors.copy()
|
||||||
|
|
||||||
|
# Generate sequence of rotations
|
||||||
|
num_rotations = rng.randint(self.config.min_rotations, self.config.max_rotations)
|
||||||
|
rotations = []
|
||||||
|
for _ in range(num_rotations):
|
||||||
|
from_side = rng.choice(list(Side))
|
||||||
|
if from_side != Side.TOP: # Skip meaningless top-to-top rotation
|
||||||
|
rotations.append(from_side)
|
||||||
|
self._rotate_to_top(cube, from_side)
|
||||||
|
|
||||||
|
# Select target side for question
|
||||||
|
target_side = rng.choice(list(Side))
|
||||||
|
|
||||||
|
# Generate story
|
||||||
|
story = self._generate_story(initial_state, rotations, target_side)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"question": story,
|
||||||
|
"answer": cube.colors[target_side],
|
||||||
|
"metadata": {
|
||||||
|
"initial_state": {k.value: v.value for k,v in initial_state.items()},
|
||||||
|
"rotations": [r.value for r in rotations],
|
||||||
|
"target_side": target_side.value,
|
||||||
|
"num_rotations": num_rotations,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def _generate_cube(self, rng: random.Random) -> Cube:
|
||||||
|
"""Generate a cube with random colors"""
|
||||||
|
colors = list(Color)
|
||||||
|
rng.shuffle(colors) # Randomize color order
|
||||||
|
return Cube({side: color for side, color in zip(Side, colors)})
|
||||||
|
|
||||||
|
def _rotate_to_top(self, cube: Cube, from_side: Side) -> None:
|
||||||
|
"""Rotate cube so that given side becomes top"""
|
||||||
|
rotation_map = {
|
||||||
|
Side.FRONT: cube.rotate_front_to_top,
|
||||||
|
Side.RIGHT: cube.rotate_right_to_top,
|
||||||
|
Side.BACK: cube.rotate_back_to_top,
|
||||||
|
Side.LEFT: cube.rotate_left_to_top,
|
||||||
|
Side.BOTTOM: cube.rotate_bottom_to_top,
|
||||||
|
}
|
||||||
|
if from_side in rotation_map:
|
||||||
|
rotation_map[from_side]()
|
||||||
|
|
||||||
|
def _generate_story(self, initial_state: Dict[Side, Color],
|
||||||
|
rotations: List[Side], target_side: Side) -> str:
|
||||||
|
"""Generate story describing cube state and rotations"""
|
||||||
|
# Describe initial state
|
||||||
|
story_parts = ["A cube has:"]
|
||||||
|
for side in Side:
|
||||||
|
story_parts.append(f"- a {initial_state[side].value} {side.value} side")
|
||||||
|
|
||||||
|
# Describe rotations
|
||||||
|
for from_side in rotations:
|
||||||
|
story_parts.append(
|
||||||
|
f"\nThe cube is rotated so that the side which was before at the {from_side.value} "
|
||||||
|
"is now at the top."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ask question
|
||||||
|
story_parts.append(f"\nWhat is now the color of the {target_side.value} side of the cube?")
|
||||||
|
|
||||||
|
return "\n".join(story_parts)
|
||||||
|
|
||||||
|
|
||||||
|
def color_cube_rotation_dataset(
|
||||||
|
min_rotations: int = 1,
|
||||||
|
max_rotations: int = 3,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
size: int = 500,
|
||||||
|
) -> ColorCubeRotationDataset:
|
||||||
|
"""Create a ColorCubeRotationDataset with the given configuration"""
|
||||||
|
config = ColorCubeRotationConfig(
|
||||||
|
min_rotations=min_rotations,
|
||||||
|
max_rotations=max_rotations,
|
||||||
|
seed=seed,
|
||||||
|
size=size,
|
||||||
|
)
|
||||||
|
return ColorCubeRotationDataset(config)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue