Add a curriculum for the circuit_logic dataset.

Summary of what the hunks below do:
  * CircuitLogicConfig: the fixed `num_terms` field becomes a
    `min_terms`/`max_terms` range, and validate() checks the range.
  * CircuitLogicDataset: __getitem__ draws the term count per item from the
    item's seeded RNG; the local __len__/__iter__/__next__ overrides are
    removed; item metadata gains a "difficulty" entry.
  * New CircuitLogicCurriculum (range attributes "terms" and "inputs"),
    exported from reasoning_gym.logic and passed to register_dataset().
  * Tests are updated for the new config fields and cover the curriculum.

NOTE(review): this patch was restored from a copy whose line breaks and
indentation had been collapsed. All hunk line counts match their headers, but
the indentation of context lines is inferred; confirm against the repository
before applying.

diff --git a/reasoning_gym/logic/__init__.py b/reasoning_gym/logic/__init__.py
index dc946118..c46321c2 100644
--- a/reasoning_gym/logic/__init__.py
+++ b/reasoning_gym/logic/__init__.py
@@ -3,7 +3,7 @@ Logic tasks for training reasoning capabilities.
 """
 
 from .aiw import AliceInWonderlandConfig, AliceInWonderlandCurriculum, AliceInWonderlandDataset
-from .circuit_logic import CircuitLogicConfig, CircuitLogicDataset
+from .circuit_logic import CircuitLogicConfig, CircuitLogicCurriculum, CircuitLogicDataset
 from .knights_knaves import KnightsKnavesConfig, KnightsKnavesDataset
 from .propositional_logic import PropositionalLogicConfig, PropositionalLogicCurriculum, PropositionalLogicDataset
 from .self_reference import SelfReferenceConfig, SelfReferenceCurriculum, SelfReferenceDataset
@@ -28,6 +28,7 @@ __all__ = [
     "SelfReferenceDataset",
     "CircuitLogicConfig",
     "CircuitLogicDataset",
+    "CircuitLogicCurriculum",
     "KnightsKnavesConfig",
     "KnightsKnavesDataset",
 ]
diff --git a/reasoning_gym/logic/circuit_logic.py b/reasoning_gym/logic/circuit_logic.py
index 7f8e69cb..798d1277 100644
--- a/reasoning_gym/logic/circuit_logic.py
+++ b/reasoning_gym/logic/circuit_logic.py
@@ -2,6 +2,7 @@ from dataclasses import dataclass
 from random import Random
 from typing import Any, Optional
 
+from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
 from ..factory import ProceduralDataset, register_dataset
 
 VERT = "│"
@@ -60,7 +61,8 @@ class CircuitLogicConfig:
     :param seed: Random seed
     """
 
-    num_terms: int = 5
+    min_terms: int = 3
+    max_terms: int = 5
     min_inputs: int = 2
     max_inputs: int = 4
     neg_prob: float = 0.3
@@ -70,7 +72,7 @@ class CircuitLogicConfig:
 
     def validate(self):
         assert 1 <= self.min_inputs <= self.max_inputs, "Invalid input range"
-        assert 1 <= self.num_terms, "Invalid number of terms"
+        assert 1 <= self.min_terms <= self.max_terms, "Invalid number of terms"
         assert 0.0 <= self.neg_prob <= 1.0, "neg_prob must be between 0 and 1"
 
 
@@ -112,28 +114,15 @@ class CircuitLogicDataset(ProceduralDataset):
             ("AND", "&"),
         ]
 
-    def __len__(self) -> int:
-        return self.config.size
-
-    def __iter__(self):
-        self._current_idx = 0
-        return self
-
-    def __next__(self) -> dict[str, Any]:
-        if self._current_idx >= self.config.size:
-            raise StopIteration
-        item = self[self._current_idx]
-        self._current_idx += 1
-        return item
-
     def __getitem__(self, idx: int) -> dict[str, Any]:
         """
         Generate one random circuit logic item using ASCII drawing.
         """
         rng = Random(self.seed + idx if self.seed is not None else None)
+        num_terms = rng.randint(self.config.min_terms, self.config.max_terms)
         return self._generate_circuit(
             rng=rng,
-            num_terms=self.config.num_terms,
+            num_terms=num_terms,
             min_inputs=self.config.min_inputs,
             max_inputs=self.config.max_inputs,
             neg_prob=self.config.neg_prob,
@@ -397,6 +386,10 @@ class CircuitLogicDataset(ProceduralDataset):
                 "term_strings": term_strings,
                 "final_gate": final_gate_name,
                 "inputs": inputs_list,
+                "difficulty": {
+                    "terms": num_terms,
+                    "inputs": (self.config.min_inputs, self.config.max_inputs),
+                },
             },
         }
 
@@ -411,4 +404,33 @@ class CircuitLogicDataset(ProceduralDataset):
         return 0.0
 
 
-register_dataset("circuit_logic", CircuitLogicDataset, CircuitLogicConfig)
+class CircuitLogicCurriculum(BaseCurriculum):
+    def __init__(self):
+        super().__init__(CircuitLogicCurriculum.__name__, CircuitLogicConfig)
+
+        # Define attributes
+        self._define_attributes(
+            RangeAttributeDefinition(
+                name="terms",
+                levels=[3, 5, 10, 20, 30],
+                default_level=1,
+                description="Number of terms in the expression",
+                attr_type=AttributeType.APPEND,
+                min_value=1,
+                lower_field_name="min_terms",
+                upper_field_name="max_terms",
+            ),
+            RangeAttributeDefinition(
+                name="inputs",
+                levels=[2, 4, 6, 8, 10],
+                default_level=1,
+                description="Number of inputs per term",
+                attr_type=AttributeType.APPEND,
+                min_value=1,
+                lower_field_name="min_inputs",
+                upper_field_name="max_inputs",
+            ),
+        )
+
+
+register_dataset("circuit_logic", CircuitLogicDataset, CircuitLogicConfig, CircuitLogicCurriculum)
diff --git a/tests/test_circuit_logic.py b/tests/test_circuit_logic.py
index 6b1f590a..138b46b8 100644
--- a/tests/test_circuit_logic.py
+++ b/tests/test_circuit_logic.py
@@ -1,16 +1,24 @@
 import pytest
 
-from reasoning_gym.logic import CircuitLogicConfig, CircuitLogicDataset
+from reasoning_gym.logic import CircuitLogicConfig, CircuitLogicCurriculum, CircuitLogicDataset
 
 
 def test_circuit_logic_config_validation():
     """Test that invalid configs raise appropriate errors"""
+    with pytest.raises(AssertionError):
+        config = CircuitLogicConfig(min_inputs=0)
+        config.validate()
+
     with pytest.raises(AssertionError):
         config = CircuitLogicConfig(min_inputs=3, max_inputs=2)
         config.validate()
 
     with pytest.raises(AssertionError):
-        config = CircuitLogicConfig(num_terms=0)
+        config = CircuitLogicConfig(min_terms=0)
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = CircuitLogicConfig(min_terms=5, max_terms=4)
         config.validate()
 
     with pytest.raises(AssertionError):
@@ -34,7 +42,7 @@ def test_circuit_logic_deterministic():
 
 def test_circuit_logic_items():
     """Test basic properties of generated items"""
-    config = CircuitLogicConfig(num_terms=3, min_inputs=2, max_inputs=3, neg_prob=0.3, size=50, seed=42)
+    config = CircuitLogicConfig(min_terms=3, max_terms=3, min_inputs=2, max_inputs=3, neg_prob=0.3, size=50, seed=42)
     dataset = CircuitLogicDataset(config)
 
     for i in range(len(dataset)):
@@ -68,7 +76,13 @@ def test_circuit_logic_items():
 def test_circuit_logic_expression_validity():
     """Test that generated expressions follow logical circuit rules"""
     config = CircuitLogicConfig(
-        num_terms=2, min_inputs=2, max_inputs=2, neg_prob=0.0, size=20, seed=42  # Disable negation for simpler testing
+        min_terms=2,
+        max_terms=2,
+        min_inputs=2,
+        max_inputs=2,
+        neg_prob=0.0,
+        size=20,
+        seed=42,  # Disable negation for simpler testing
     )
     dataset = CircuitLogicDataset(config)
 
@@ -88,7 +102,7 @@ def test_circuit_logic_expression_validity():
 
 def test_circuit_logic_answer_verification():
     """Test that answers match logical evaluation of circuits"""
-    config = CircuitLogicConfig(num_terms=2, min_inputs=2, max_inputs=2, size=20, seed=42)
+    config = CircuitLogicConfig(min_terms=2, max_terms=2, min_inputs=2, max_inputs=2, size=20, seed=42)
     dataset = CircuitLogicDataset(config)
 
     def evaluate_term(term: str, assignments: dict) -> int:
@@ -158,7 +172,7 @@ def test_circuit_logic_answer_verification():
 
 def test_circuit_logic_ascii_diagram():
     """Test properties of the ASCII circuit diagram"""
-    config = CircuitLogicConfig(num_terms=2, min_inputs=2, max_inputs=2, size=10, seed=42)
+    config = CircuitLogicConfig(min_terms=2, max_terms=2, min_inputs=2, max_inputs=2, size=10, seed=42)
     dataset = CircuitLogicDataset(config)
 
     for i in range(len(dataset)):
@@ -222,3 +236,28 @@ def test_circuit_logic_iteration():
     first_items = list(dataset)
     second_items = list(dataset)
     assert first_items == second_items
+
+
+def test_circuit_logic_curriculum():
+    curriculum = CircuitLogicCurriculum()
+
+    base_value = {"size": 150, "seed": 1}
+
+    base_cfg: CircuitLogicConfig = curriculum.generate_configuration(base_value)
+    assert base_cfg.seed == 1
+    assert base_cfg.size == 150
+    assert base_cfg.min_terms == 3 and base_cfg.max_terms == 5
+    assert base_cfg.min_inputs == 2 and base_cfg.max_inputs == 4
+
+    # test incrementing attribute levels
+    curriculum.increment_attr_level("terms")
+    curriculum.increment_attr_level("inputs")
+    increased_cfg = curriculum.generate_configuration(base_value)
+    assert increased_cfg.min_terms == 3 and increased_cfg.max_terms == 10
+    assert increased_cfg.min_inputs == 2 and increased_cfg.max_inputs == 6
+
+    # test decrementing attribute level for terms again
+    curriculum.decrement_attr_level("terms")
+    partially_decreased_cfg = curriculum.generate_configuration(base_value)
+    assert partially_decreased_cfg.min_terms == 3 and partially_decreased_cfg.max_terms == 5
+    assert partially_decreased_cfg.min_inputs == 2 and partially_decreased_cfg.max_inputs == 6