diff --git a/reasoning_gym/graphs/__init__.py b/reasoning_gym/graphs/__init__.py
index a70816df..43cab6a7 100644
--- a/reasoning_gym/graphs/__init__.py
+++ b/reasoning_gym/graphs/__init__.py
@@ -1,7 +1,7 @@
 from .course_schedule import CourseScheduleConfig, CourseScheduleCurriculum, CourseScheduleDataset
 from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset
 from .largest_island import LargestIslandConfig, LargestIslandCurriculum, LargestIslandDataset
-from .quantum_lock import QuantumLockConfig, QuantumLockDataset
+from .quantum_lock import QuantumLockConfig, QuantumLockCurriculum, QuantumLockDataset
 from .shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestPathDataset
 
 __all__ = [
@@ -9,6 +9,7 @@ __all__ = [
     "FamilyRelationshipsDataset",
     "QuantumLockConfig",
     "QuantumLockDataset",
+    "QuantumLockCurriculum",
     "LargestIslandDataset",
     "LargestIslandConfig",
     "LargestIslandCurriculum",
diff --git a/reasoning_gym/graphs/quantum_lock.py b/reasoning_gym/graphs/quantum_lock.py
index e2196906..7f0b5431 100644
--- a/reasoning_gym/graphs/quantum_lock.py
+++ b/reasoning_gym/graphs/quantum_lock.py
@@ -4,6 +4,7 @@ from dataclasses import dataclass
 from random import Random
 from typing import Any, Optional
 
+from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
 from ..factory import ProceduralDataset, register_dataset
 
 
@@ -47,14 +48,16 @@ Buttons:
             - metadata: dict with generation parameters
         """
         rng = Random(self.seed + idx)
+        # Per-item difficulty is drawn from [1, config.difficulty], so the configured value is an upper bound.
+        difficulty = rng.randint(1, self.config.difficulty)
 
-        puzzle_data = self.generate_quantum_puzzle(rng, self.config.difficulty)
+        puzzle_data = self.generate_quantum_puzzle(rng, difficulty)
 
         return {
             "question": self.format_puzzle(rng.choice(self._prompt_templates), puzzle=puzzle_data),
             "answer": " → ".join(puzzle_data["solution"]),
             "metadata": {
-                "difficulty": self.config.difficulty,
+                "difficulty": difficulty,
                 "solution_path": puzzle_data["solution"],
                 "target_value": puzzle_data["target_value"],
                 "buttons": puzzle_data["buttons"],
@@ -233,5 +236,23 @@ Buttons:
         )
 
 
+class QuantumLockCurriculum(BaseCurriculum):
+    """Curriculum that steps ``QuantumLockConfig.difficulty`` through the values 1 to 10."""
+
+    def __init__(self):
+        super().__init__(QuantumLockCurriculum.__name__, QuantumLockConfig)
+        self._define_attributes(
+            ScalarAttributeDefinition(
+                name="difficulty",
+                field_name="difficulty",
+                levels=list(range(1, 11)),
+                default_level=0,
+                attr_type=AttributeType.STATIC,
+                description="The difficulty of the puzzle",
+                min_value=1,
+            )
+        )
+
+
 # Register the dataset
-register_dataset("quantum_lock", QuantumLockDataset, QuantumLockConfig)
+register_dataset("quantum_lock", QuantumLockDataset, QuantumLockConfig, QuantumLockCurriculum)
diff --git a/tests/test_quantum_lock.py b/tests/test_quantum_lock.py
index cf9693e6..6cac9ffa 100644
--- a/tests/test_quantum_lock.py
+++ b/tests/test_quantum_lock.py
@@ -1,6 +1,6 @@
 import pytest
 
-from reasoning_gym.graphs.quantum_lock import QuantumLockConfig, QuantumLockDataset
+from reasoning_gym.graphs.quantum_lock import QuantumLockConfig, QuantumLockCurriculum, QuantumLockDataset
 
 
 def test_quantumlock_config_validation():
@@ -115,3 +115,73 @@ def test_quantumlock_scoring():
         if solution:
             lower_solution = "".join(solution).lower()
             assert dataset.score_answer(lower_solution, item) == 1.0
+
+
+def test_quantum_lock_curriculum():
+    """Test the QuantumLockCurriculum functionality"""
+    curriculum = QuantumLockCurriculum()
+
+    base_value = {"size": 150, "seed": 1}
+
+    # Test initial configuration
+    base_cfg = curriculum.generate_configuration(base_value)
+    assert base_cfg.seed == 1
+    assert base_cfg.size == 150
+    assert base_cfg.difficulty == 1  # Default difficulty level
+
+    # Test incrementing difficulty attribute
+    curriculum.increment_attr_level("difficulty")
+    increased_cfg = curriculum.generate_configuration(base_value)
+    assert increased_cfg.difficulty == 2
+    assert increased_cfg.seed == 1  # Unchanged
+    assert increased_cfg.size == 150  # Unchanged
+
+    # Test incrementing difficulty attribute again
+    curriculum.increment_attr_level("difficulty")
+    increased_cfg_2 = curriculum.generate_configuration(base_value)
+    assert increased_cfg_2.difficulty == 3
+
+    # Test decrementing difficulty attribute
+    curriculum.decrement_attr_level("difficulty")
+    decreased_cfg = curriculum.generate_configuration(base_value)
+    assert decreased_cfg.difficulty == 2
+
+    # Test global level adjustments
+    curriculum = QuantumLockCurriculum()  # Reset curriculum
+    assert curriculum.get_attr_level("difficulty") == 0  # Default level is 0, maps to difficulty=1
+
+    # Increase global level
+    curriculum.increment_global_level()
+    assert curriculum.get_attr_level("difficulty") == 1
+
+    global_level_cfg = curriculum.generate_configuration(base_value)
+    assert global_level_cfg.difficulty == 2
+
+    # Increase global level again
+    curriculum.increment_global_level()
+    assert curriculum.get_attr_level("difficulty") == 2
+
+    global_level_cfg_2 = curriculum.generate_configuration(base_value)
+    assert global_level_cfg_2.difficulty == 3
+
+    # Decrease global level
+    curriculum.decrement_global_level()
+    assert curriculum.get_attr_level("difficulty") == 1
+
+    global_level_cfg_3 = curriculum.generate_configuration(base_value)
+    assert global_level_cfg_3.difficulty == 2
+
+    # Test upper bound
+    curriculum = QuantumLockCurriculum()  # Reset curriculum
+    for _ in range(15):  # Try going beyond max level (10)
+        curriculum.increment_attr_level("difficulty")
+
+    max_cfg = curriculum.generate_configuration(base_value)
+    assert max_cfg.difficulty == 10  # Should be capped at 10 (the highest level)
+
+    # Test lower bound
+    curriculum = QuantumLockCurriculum()  # Reset curriculum
+    curriculum.decrement_attr_level("difficulty")  # Try going below min level
+
+    min_cfg = curriculum.generate_configuration(base_value)
+    assert min_cfg.difficulty == 1  # Should be capped at 1 (the lowest level)