diff --git a/reasoning_gym/logic/__init__.py b/reasoning_gym/logic/__init__.py index 29149e6c..dc946118 100644 --- a/reasoning_gym/logic/__init__.py +++ b/reasoning_gym/logic/__init__.py @@ -5,7 +5,7 @@ Logic tasks for training reasoning capabilities. from .aiw import AliceInWonderlandConfig, AliceInWonderlandCurriculum, AliceInWonderlandDataset from .circuit_logic import CircuitLogicConfig, CircuitLogicDataset from .knights_knaves import KnightsKnavesConfig, KnightsKnavesDataset -from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset +from .propositional_logic import PropositionalLogicConfig, PropositionalLogicCurriculum, PropositionalLogicDataset from .self_reference import SelfReferenceConfig, SelfReferenceCurriculum, SelfReferenceDataset from .syllogisms import SyllogismConfig, SyllogismDataset from .zebra_puzzles import ZebraConfig, ZebraCurriculum, ZebraDataset @@ -16,6 +16,7 @@ __all__ = [ "AliceInWonderlandDataset", "PropositionalLogicConfig", "PropositionalLogicDataset", + "PropositionalLogicCurriculum", "SyllogismConfig", "SyllogismDataset", "syllogism_dataset", diff --git a/reasoning_gym/logic/propositional_logic.py b/reasoning_gym/logic/propositional_logic.py index b9387dfa..b7033f4a 100644 --- a/reasoning_gym/logic/propositional_logic.py +++ b/reasoning_gym/logic/propositional_logic.py @@ -6,6 +6,7 @@ from enum import StrEnum from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -87,6 +88,7 @@ class PropositionalLogicConfig: max_vars: int = 4 # Maximum number of variables min_statements: int = 2 # Minimum number of given statements max_statements: int = 4 # Maximum number of statements + min_complexity: int = 1 # Minimum operator depth max_complexity: int = 3 # Maximum operator depth seed: Optional[int] = None size: int = 500 # Virtual dataset size @@ -96,8 +98,9 @@ class PropositionalLogicConfig: assert self.min_vars > 0, "min_vars must be positive" assert self.max_vars >= self.min_vars, "max_vars must be >= min_vars" assert self.min_statements > 0, "min_statements must be positive" - assert self.max_statements >= self.min_statements - assert self.max_complexity > 0, "max_complexity must be positive" + assert self.max_statements >= self.min_statements, "max_statements must be >= min_statements" + assert self.min_complexity > 0, "min_complexity must be positive" + assert self.max_complexity >= self.min_complexity, "max_complexity must be >= min_complexity" class Expression: @@ -217,6 +220,11 @@ class PropositionalLogicDataset(ProceduralDataset): "variables": variables, "complexity": self._measure_complexity(conclusion), "example_answer": str(conclusion), + "difficulty": { + "vars": num_vars, + "statements": num_statements, + "complexity": (self.config.min_complexity, self.config.max_complexity), + }, }, } @@ -224,7 +232,7 @@ class PropositionalLogicDataset(ProceduralDataset): """Generate a list of premise statements""" premises = [] for _ in range(num_statements): - depth = rng.randint(1, self.config.max_complexity) + depth = rng.randint(self.config.min_complexity, self.config.max_complexity) premises.append(self._generate_expression(rng, variables, depth)) return premises @@ -329,4 +337,45 @@ class PropositionalLogicDataset(ProceduralDataset): return True -register_dataset("propositional_logic", PropositionalLogicDataset, PropositionalLogicConfig) +class PropositionalLogicCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(PropositionalLogicCurriculum.__name__, PropositionalLogicConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="vars", + levels=[2, 4, 6, 8, 10], + default_level=0, + description="Number of variables in the logical expressions", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_vars", + upper_field_name="max_vars", + ), + RangeAttributeDefinition( + name="statements", + levels=[2, 4, 6, 8, 10], + default_level=0, + description="Number of premises in the logical expressions", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_statements", + upper_field_name="max_statements", + ), + RangeAttributeDefinition( + name="complexity", + levels=[1, 2, 3, 4, 5], + default_level=0, + description="Complexity of the logical expressions", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_complexity", + upper_field_name="max_complexity", + ), + ) + + +register_dataset( + "propositional_logic", PropositionalLogicDataset, PropositionalLogicConfig, PropositionalLogicCurriculum +) diff --git a/tests/test_propositional_logic.py b/tests/test_propositional_logic.py index d708cc4a..9d1f205c 100644 --- a/tests/test_propositional_logic.py +++ b/tests/test_propositional_logic.py @@ -6,6 +6,7 @@ from reasoning_gym.logic.propositional_logic import ( Expression, Operator, PropositionalLogicConfig, + PropositionalLogicCurriculum, PropositionalLogicDataset, ) @@ -101,3 +102,32 @@ def test_propositional_logic_dataset_score_answer_incorrect(): for i, item in enumerate(dataset): score = dataset.score_answer("Wrong", item) assert score == 0.0 + + +def test_propositional_logic_curriculum(): + curriculum = PropositionalLogicCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: PropositionalLogicConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_vars == 2 and base_cfg.max_vars == 2 + assert base_cfg.min_statements == 2 and base_cfg.max_statements == 2 + assert base_cfg.min_complexity == 1 and base_cfg.max_complexity == 1 + + # test incrementing attribute levels + curriculum.increment_attr_level("vars") + curriculum.increment_attr_level("statements") + curriculum.increment_attr_level("complexity") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_vars == 2 and increased_cfg.max_vars == 4 + assert increased_cfg.min_statements == 2 and increased_cfg.max_statements == 4 + assert increased_cfg.min_complexity == 1 and increased_cfg.max_complexity == 2 + + # test decrementing attribute level for vars again + curriculum.decrement_attr_level("vars") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_vars == 2 and partially_decreased_cfg.max_vars == 2 + assert partially_decreased_cfg.min_statements == 2 and partially_decreased_cfg.max_statements == 4 + assert partially_decreased_cfg.min_complexity == 1 and partially_decreased_cfg.max_complexity == 2