diff --git a/reasoning_gym/logic/__init__.py b/reasoning_gym/logic/__init__.py index c46321c2..99fb75cd 100644 --- a/reasoning_gym/logic/__init__.py +++ b/reasoning_gym/logic/__init__.py @@ -4,7 +4,7 @@ Logic tasks for training reasoning capabilities. from .aiw import AliceInWonderlandConfig, AliceInWonderlandCurriculum, AliceInWonderlandDataset from .circuit_logic import CircuitLogicConfig, CircuitLogicCurriculum, CircuitLogicDataset -from .knights_knaves import KnightsKnavesConfig, KnightsKnavesDataset +from .knights_knaves import KnightsKnavesConfig, KnightsKnavesCurriculum, KnightsKnavesDataset from .propositional_logic import PropositionalLogicConfig, PropositionalLogicCurriculum, PropositionalLogicDataset from .self_reference import SelfReferenceConfig, SelfReferenceCurriculum, SelfReferenceDataset from .syllogisms import SyllogismConfig, SyllogismDataset @@ -31,4 +31,5 @@ __all__ = [ "CircuitLogicCurriculum", "KnightsKnavesConfig", "KnightsKnavesDataset", + "KnightsKnavesCurriculum", ] diff --git a/reasoning_gym/logic/knights_knaves.py b/reasoning_gym/logic/knights_knaves.py index fe4f503f..09e3a4b7 100644 --- a/reasoning_gym/logic/knights_knaves.py +++ b/reasoning_gym/logic/knights_knaves.py @@ -8,6 +8,8 @@ import numpy as np from reasoning_gym.factory import ProceduralDataset, register_dataset +from ..coaching import BaseCurriculum, ScalarAttributeDefinition + DATASET_NAME = "knights_knaves" COMMON_NAMES = [ @@ -462,6 +464,11 @@ class KnightsKnavesDataset(ProceduralDataset): "solution": problem["solution"], "names": formatted["names"], "knight_knave_terms": formatted["knight_knave"], + "difficulty": { + "n_people": self.config.n_people, + "depth_constraint": self.config.depth_constraint, + "width_constraint": self.config.width_constraint, + }, } return {"question": question, "answer": answer, "metadata": metadata} @@ -515,4 +522,30 @@ class KnightsKnavesDataset(ProceduralDataset): return 0.0 -register_dataset(DATASET_NAME, KnightsKnavesDataset, KnightsKnavesConfig) +class KnightsKnavesCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(KnightsKnavesCurriculum.__name__, KnightsKnavesConfig) + + self._define_attributes( + ScalarAttributeDefinition( + name="n_people", + levels=[2, 3, 4, 5], + description="Number of people in the problem", + field_name="n_people", + ), + ScalarAttributeDefinition( + name="depth_constraint", + levels=[2, 3, 4, 5], + description="Depth of the problem", + field_name="depth_constraint", + ), + ScalarAttributeDefinition( + name="width_constraint", + levels=[2, 3, 4, 5], + description="Width of the problem", + field_name="width_constraint", + ), + ) + + +register_dataset(DATASET_NAME, KnightsKnavesDataset, KnightsKnavesConfig, KnightsKnavesCurriculum) diff --git a/tests/test_knights_knaves.py b/tests/test_knights_knaves.py index bcaf1fe7..1f8d2c1c 100644 --- a/tests/test_knights_knaves.py +++ b/tests/test_knights_knaves.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.logic.knights_knaves import KnightsKnavesConfig, KnightsKnavesDataset +from reasoning_gym.logic.knights_knaves import KnightsKnavesConfig, KnightsKnavesCurriculum, KnightsKnavesDataset def test_config_validation(): @@ -234,3 +234,42 @@ def test_depth_constraint_specific_problem(): solutions = KnightsKnavesDataset.find_solution(test_statements) assert len(solutions) == 1, "Should have exactly one solution" assert solutions[0] == (True, False, False) + + +def test_curriculum(): + curriculum = KnightsKnavesCurriculum() + + assert len(curriculum.attributes) == 3 + + base_value = {"size": 150, "seed": 1} + + base_cfg = curriculum.generate_configuration(base_value) + + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.n_people == 2 + assert base_cfg.depth_constraint == 2 + + # test incrementing attribute levels + curriculum.increment_attr_level("n_people") + curriculum.increment_attr_level("depth_constraint") + curriculum.increment_attr_level("width_constraint") + + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.n_people == 3 + assert increased_cfg.depth_constraint == 3 + assert increased_cfg.width_constraint == 3 + # test decrementing attribute level + curriculum.decrement_attr_level("n_people") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.n_people == 2 + assert partially_decreased_cfg.depth_constraint == 3 + assert partially_decreased_cfg.width_constraint == 3 + + curriculum.increment_attr_level("n_people") + curriculum.increment_attr_level("depth_constraint") + curriculum.increment_attr_level("width_constraint") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.n_people == 3 + assert increased_cfg.depth_constraint == 4 + assert increased_cfg.width_constraint == 4