diff --git a/reasoning_gym/logic/__init__.py b/reasoning_gym/logic/__init__.py index 99e5ea24..29149e6c 100644 --- a/reasoning_gym/logic/__init__.py +++ b/reasoning_gym/logic/__init__.py @@ -2,7 +2,7 @@ Logic tasks for training reasoning capabilities. """ -from .aiw import AliceInWonderlandConfig, AliceInWonderlandDataset +from .aiw import AliceInWonderlandConfig, AliceInWonderlandCurriculum, AliceInWonderlandDataset from .circuit_logic import CircuitLogicConfig, CircuitLogicDataset from .knights_knaves import KnightsKnavesConfig, KnightsKnavesDataset from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset @@ -12,6 +12,7 @@ from .zebra_puzzles import ZebraConfig, ZebraCurriculum, ZebraDataset __all__ = [ "AliceInWonderlandConfig", + "AliceInWonderlandCurriculum", "AliceInWonderlandDataset", "PropositionalLogicConfig", "PropositionalLogicDataset", diff --git a/reasoning_gym/logic/aiw.py b/reasoning_gym/logic/aiw.py index 7cabbb85..5f982249 100644 --- a/reasoning_gym/logic/aiw.py +++ b/reasoning_gym/logic/aiw.py @@ -4,6 +4,7 @@ from random import Random from string import Template from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -61,6 +62,7 @@ class AliceInWonderlandConfig: task_types: list[TaskType] = field( default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues ) + task_type_weights: list[float] = field(default_factory=lambda: [1 / 3, 1 / 3, 1 / 3]) seed: Optional[int] = None size: int = 10 max_entities: int = 6 # Added max_entities @@ -142,7 +144,8 @@ class AliceInWonderlandDataset(ProceduralDataset): dict: A dictionary containing the generated question, the right answer and a description of the example. """ - task_type = rng.choice(self.config.task_types) + + task_type = rng.choices(self.config.task_types, weights=self.config.task_type_weights, k=1)[0] female_name = rng.choice(self.config.female_names) male_name = rng.choice(self.config.male_names) @@ -187,11 +190,58 @@ class AliceInWonderlandDataset(ProceduralDataset): num_female_colleagues_bob_circle=num_female_colleagues_bob_circle, ) - return {"question": question, "answer": str(answer), "metadata": {"task_type": task_type.value}} + return { + "question": question, + "answer": str(answer), + "metadata": { + "task_type": task_type.value, + "difficulty": { + "task_type_weight": self.config.task_type_weights, + "num_entities": self.config.max_entities, + }, + }, + } def __getitem__(self, idx: int) -> dict: rng = Random(self.seed + idx) return self._get_aiw(rng) -register_dataset("aiw", AliceInWonderlandDataset, AliceInWonderlandConfig) +class AliceInWonderlandCurriculum(BaseCurriculum): + """Curriculum for the Alice in Wonderland dataset.""" + + def __init__(self): + super().__init__(AliceInWonderlandCurriculum.__name__, AliceInWonderlandConfig) + self._define_attributes( + ScalarAttributeDefinition( + name="task_type_weight", + field_name="task_type_weights", + attr_type=AttributeType.STATIC, + description="The weight of the task type", + levels=[ + [1.0, 0.0, 0.0], + [0.9, 0.05, 0.05], + [0.7, 0.15, 0.15], + [0.6, 0.2, 0.2], + [0.5, 0.25, 0.25], + [0.4, 0.3, 0.3], + [0.3, 0.35, 0.35], + [0.2, 0.4, 0.4], + [0.1, 0.45, 0.45], + ], + min_value=[1.0, 0.0, 0.0], + default_level=0, + ), + ScalarAttributeDefinition( + name="num_entities", + field_name="max_entities", + attr_type=AttributeType.STATIC, + description="The number of entities in the question", + levels=list(range(4, 18, 2)), + min_value=4, + default_level=0, + ), + ) + + +register_dataset("aiw", AliceInWonderlandDataset, AliceInWonderlandConfig, AliceInWonderlandCurriculum) diff --git a/tests/test_aiw.py b/tests/test_aiw.py index 5a2fb454..9979f977 100644 --- a/tests/test_aiw.py +++ b/tests/test_aiw.py @@ -1,6 +1,11 @@ import pytest -from reasoning_gym.logic.aiw import AliceInWonderlandConfig, AliceInWonderlandDataset, TaskType +from reasoning_gym.logic.aiw import ( + AliceInWonderlandConfig, + AliceInWonderlandCurriculum, + AliceInWonderlandDataset, + TaskType, +) def test_aiw_config_validation(): @@ -94,3 +99,40 @@ def test_aiw_random_ranges(): # Check all numbers are in reasonable range (1-6 as per implementation) assert all(1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}" + + +def test_aiw_curriculum(): + """Test the AIW curriculum functionality""" + curriculum = AliceInWonderlandCurriculum() + + base_value = {"size": 100, "seed": 42} + + # Test default configuration + base_cfg: AliceInWonderlandConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 42 + assert base_cfg.size == 100 + assert base_cfg.max_entities == 4 + assert base_cfg.task_type_weights == [1.0, 0.0, 0.0] # Default is siblings only + + # Test incrementing task_type_weight attribute + curriculum.increment_attr_level("task_type_weight") + task_weight_cfg = curriculum.generate_configuration(base_value) + assert task_weight_cfg.task_type_weights == [0.9, 0.05, 0.05] # Second level adds some friends/colleagues + + # Test incrementing num_entities attribute + curriculum.increment_attr_level("num_entities") + entities_cfg = curriculum.generate_configuration(base_value) + assert entities_cfg.max_entities == 6 # Increased max entities + assert entities_cfg.task_type_weights == [0.9, 0.05, 0.05] # Should preserve task weight level + + # Test decrementing task_type_weight attribute + curriculum.decrement_attr_level("task_type_weight") + updated_cfg = curriculum.generate_configuration(base_value) + assert updated_cfg.task_type_weights == [1.0, 0.0, 0.0] # Back to default weights + assert updated_cfg.max_entities == 6 # Should preserve entities level + + # Test global level setting + curriculum.set_global_level(2) # Set all attributes to level 2 + global_cfg = curriculum.generate_configuration(base_value) + assert global_cfg.task_type_weights == [0.7, 0.15, 0.15] # Level 2 of task weights + assert global_cfg.max_entities == 8 # Level 2 of num_entities