color cube rotation curriculum (#347)

This commit is contained in:
Zafir Stojanovski 2025-03-13 21:04:34 +01:00 committed by GitHub
parent ec3e414a8c
commit 9fcc277101
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 53 additions and 3 deletions

View file

@ -1,7 +1,14 @@
import pytest
from reasoning_gym import create_dataset
from reasoning_gym.cognition.color_cube_rotation import Color, ColorCubeRotationDataset, Cube, Side
from reasoning_gym.cognition.color_cube_rotation import (
Color,
ColorCubeRotationConfig,
ColorCubeRotationCurriculum,
ColorCubeRotationDataset,
Cube,
Side,
)
def test_color_cube_rotation_generation():
@ -82,3 +89,24 @@ def test_cube_rotations():
assert cube.colors[Side.BACK] == original[Side.TOP]
assert cube.colors[Side.RIGHT] == original[Side.RIGHT] # Unchanged
assert cube.colors[Side.LEFT] == original[Side.LEFT] # Unchanged
def test_shortest_path_curriculum():
curriculum = ColorCubeRotationCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: ColorCubeRotationConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_rotations == 1 and base_cfg.max_rotations == 5
# test incrementing attribute levels
curriculum.increment_attr_level("rotations")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_rotations == 1 and increased_cfg.max_rotations == 10
# test decrementing attribute level
curriculum.decrement_attr_level("rotations")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_rotations == 1 and partially_decreased_cfg.max_rotations == 5