From 54074b17ef7c5d9b2521ae7a316e91de232cfc2d Mon Sep 17 00:00:00 2001 From: joesharratt1229 <118444587+joesharratt1229@users.noreply.github.com> Date: Tue, 11 Mar 2025 00:22:54 +0100 Subject: [PATCH] Added zebra curriculum (#328) * added zebra curriculum * added metadata --- reasoning_gym/logic/__init__.py | 3 +- reasoning_gym/logic/zebra_puzzles.py | 31 +++++++++- tests/test_zebra.py | 86 +++++++++++++++++++++++++++- 3 files changed, 115 insertions(+), 5 deletions(-) diff --git a/reasoning_gym/logic/__init__.py b/reasoning_gym/logic/__init__.py index 422e0235..a0d7b482 100644 --- a/reasoning_gym/logic/__init__.py +++ b/reasoning_gym/logic/__init__.py @@ -8,7 +8,7 @@ from .knights_knaves import KnightsKnavesConfig, KnightsKnavesDataset from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset from .self_reference import SelfReferenceConfig, SelfReferenceDataset from .syllogisms import SyllogismConfig, SyllogismDataset -from .zebra_puzzles import ZebraConfig, ZebraDataset +from .zebra_puzzles import ZebraConfig, ZebraCurriculum, ZebraDataset __all__ = [ "AliceInWonderlandConfig", @@ -19,6 +19,7 @@ __all__ = [ "SyllogismDataset", "syllogism_dataset", "ZebraConfig", + "ZebraCurriculum", "ZebraDataset", "SelfReference", "SelfReferenceConfig", diff --git a/reasoning_gym/logic/zebra_puzzles.py b/reasoning_gym/logic/zebra_puzzles.py index 7e424b13..e891b760 100644 --- a/reasoning_gym/logic/zebra_puzzles.py +++ b/reasoning_gym/logic/zebra_puzzles.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset from .contrib.logic_puzzle.generate import generate_puzzle @@ -50,8 +51,7 @@ class ZebraDataset(ProceduralDataset): "question": question, "answer": answer, "metadata": { - "num_people": K, - "num_characteristics": M, + "difficulty": {"num_people": K, "num_characteristics": M}, }, } @@ -74,4 +74,29 @@ class ZebraDataset(ProceduralDataset): return 0.0 -register_dataset("zebra_puzzles", ZebraDataset, ZebraConfig) +class ZebraCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(ZebraCurriculum.__name__, ZebraConfig) + self._define_attributes( + ScalarAttributeDefinition( + name="num_people", + levels=list(range(2, 8)), + default_level=0, + description="The number of people in the Zebra puzzle", + attr_type=AttributeType.STATIC, + min_value=2, + field_name="num_people", + ), + ScalarAttributeDefinition( + name="num_characteristics", + levels=list(range(2, 8)), + default_level=0, + description="The number of characteristics in the Zebra puzzle", + attr_type=AttributeType.STATIC, + min_value=2, + field_name="num_characteristics", + ), + ) + + +register_dataset("zebra_puzzles", ZebraDataset, ZebraConfig, ZebraCurriculum) diff --git a/tests/test_zebra.py b/tests/test_zebra.py index 054fad22..21ca9bb0 100644 --- a/tests/test_zebra.py +++ b/tests/test_zebra.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.logic.zebra_puzzles import ZebraConfig, ZebraDataset +from reasoning_gym.logic.zebra_puzzles import ZebraConfig, ZebraCurriculum, ZebraDataset def test_zebra_deterministic(): @@ -27,3 +27,87 @@ def test_zebra_puzzles(): # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 + + +def test_zebra_curriculum(): + """Test the ZebraCurriculum functionality""" + + curriculum = ZebraCurriculum() + + base_value = {"size": 150, "seed": 1} + + # Test initial configuration + base_cfg = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.num_people == 2 # Default level 0 maps to 2 people + assert base_cfg.num_characteristics == 2 # Default level 0 maps to 2 characteristics + + # Test incrementing num_people attribute + curriculum.increment_attr_level("num_people") + people_cfg = curriculum.generate_configuration(base_value) + assert people_cfg.num_people == 3 # Level 1 maps to 3 people + assert people_cfg.num_characteristics == 2 # Unchanged + + # Test incrementing num_characteristics attribute + curriculum.increment_attr_level("num_characteristics") + both_cfg = curriculum.generate_configuration(base_value) + assert both_cfg.num_people == 3 # Preserved + assert both_cfg.num_characteristics == 3 # Level 1 maps to 3 characteristics + + # Test decrementing num_people attribute + curriculum.decrement_attr_level("num_people") + char_only_cfg = curriculum.generate_configuration(base_value) + assert char_only_cfg.num_people == 2 # Back to level 0 + assert char_only_cfg.num_characteristics == 3 # Preserved + + # Test global level adjustments + curriculum = ZebraCurriculum() # Reset curriculum + assert curriculum.get_attr_level("num_people") == 0 + assert curriculum.get_attr_level("num_characteristics") == 0 + + # Increase global level + curriculum.increment_global_level() + assert curriculum.get_attr_level("num_people") == 1 + assert curriculum.get_attr_level("num_characteristics") == 1 + + global_level_cfg = curriculum.generate_configuration(base_value) + assert global_level_cfg.num_people == 3 + assert global_level_cfg.num_characteristics == 3 + + # Increase global level again + curriculum.increment_global_level() + assert curriculum.get_attr_level("num_people") == 2 + assert curriculum.get_attr_level("num_characteristics") == 2 + + global_level_cfg_2 = curriculum.generate_configuration(base_value) + assert global_level_cfg_2.num_people == 4 + assert global_level_cfg_2.num_characteristics == 4 + + # Decrease global level + curriculum.decrement_global_level() + assert curriculum.get_attr_level("num_people") == 1 + assert curriculum.get_attr_level("num_characteristics") == 1 + + global_level_cfg_3 = curriculum.generate_configuration(base_value) + assert global_level_cfg_3.num_people == 3 + assert global_level_cfg_3.num_characteristics == 3 + + # Test upper bound + curriculum = ZebraCurriculum() # Reset curriculum + for _ in range(10): # Try going beyond max level + curriculum.increment_attr_level("num_people") + curriculum.increment_attr_level("num_characteristics") + + max_cfg = curriculum.generate_configuration(base_value) + assert max_cfg.num_people == 7 # Capped at 7 + assert max_cfg.num_characteristics == 7 # Capped at 7 + + # Test lower bound + curriculum = ZebraCurriculum() # Reset curriculum + curriculum.decrement_attr_level("num_people") # Try going below min level + curriculum.decrement_attr_level("num_characteristics") # Try going below min level + + min_cfg = curriculum.generate_configuration(base_value) + assert min_cfg.num_people == 2 # Stays at 2 + assert min_cfg.num_characteristics == 2 # Stays at 2