Added zebra curriculum (#328)

* added zebra curriculum

* added metadata
This commit is contained in:
joesharratt1229 2025-03-11 00:22:54 +01:00 committed by GitHub
parent f204a848d9
commit 54074b17ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 115 additions and 5 deletions

View file

@ -8,7 +8,7 @@ from .knights_knaves import KnightsKnavesConfig, KnightsKnavesDataset
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset
from .self_reference import SelfReferenceConfig, SelfReferenceDataset from .self_reference import SelfReferenceConfig, SelfReferenceDataset
from .syllogisms import SyllogismConfig, SyllogismDataset from .syllogisms import SyllogismConfig, SyllogismDataset
from .zebra_puzzles import ZebraConfig, ZebraDataset from .zebra_puzzles import ZebraConfig, ZebraCurriculum, ZebraDataset
__all__ = [ __all__ = [
"AliceInWonderlandConfig", "AliceInWonderlandConfig",
@ -19,6 +19,7 @@ __all__ = [
"SyllogismDataset", "SyllogismDataset",
"syllogism_dataset", "syllogism_dataset",
"ZebraConfig", "ZebraConfig",
"ZebraCurriculum",
"ZebraDataset", "ZebraDataset",
"SelfReference", "SelfReference",
"SelfReferenceConfig", "SelfReferenceConfig",

View file

@ -2,6 +2,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Any, Optional from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
from .contrib.logic_puzzle.generate import generate_puzzle from .contrib.logic_puzzle.generate import generate_puzzle
@ -50,8 +51,7 @@ class ZebraDataset(ProceduralDataset):
"question": question, "question": question,
"answer": answer, "answer": answer,
"metadata": { "metadata": {
"num_people": K, "difficulty": {"num_people": K, "num_characteristics": M},
"num_characteristics": M,
}, },
} }
@ -74,4 +74,29 @@ class ZebraDataset(ProceduralDataset):
return 0.0 return 0.0
register_dataset("zebra_puzzles", ZebraDataset, ZebraConfig) class ZebraCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(ZebraCurriculum.__name__, ZebraConfig)
self._define_attributes(
ScalarAttributeDefinition(
name="num_people",
levels=list(range(2, 8)),
default_level=0,
description="The number of people in the Zebra puzzle",
attr_type=AttributeType.STATIC,
min_value=2,
field_name="num_people",
),
ScalarAttributeDefinition(
name="num_characteristics",
levels=list(range(2, 8)),
default_level=0,
description="The number of characteristics in the Zebra puzzle",
attr_type=AttributeType.STATIC,
min_value=2,
field_name="num_characteristics",
),
)
register_dataset("zebra_puzzles", ZebraDataset, ZebraConfig, ZebraCurriculum)

View file

@ -1,6 +1,6 @@
import pytest import pytest
from reasoning_gym.logic.zebra_puzzles import ZebraConfig, ZebraDataset from reasoning_gym.logic.zebra_puzzles import ZebraConfig, ZebraCurriculum, ZebraDataset
def test_zebra_deterministic(): def test_zebra_deterministic():
@ -27,3 +27,87 @@ def test_zebra_puzzles():
# Test the scoring # Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer=None, entry=item) == 0.0 assert dataset.score_answer(answer=None, entry=item) == 0.0
def test_zebra_curriculum():
"""Test the ZebraCurriculum functionality"""
curriculum = ZebraCurriculum()
base_value = {"size": 150, "seed": 1}
# Test initial configuration
base_cfg = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.num_people == 2 # Default level 0 maps to 2 people
assert base_cfg.num_characteristics == 2 # Default level 0 maps to 2 characteristics
# Test incrementing num_people attribute
curriculum.increment_attr_level("num_people")
people_cfg = curriculum.generate_configuration(base_value)
assert people_cfg.num_people == 3 # Level 1 maps to 3 people
assert people_cfg.num_characteristics == 2 # Unchanged
# Test incrementing num_characteristics attribute
curriculum.increment_attr_level("num_characteristics")
both_cfg = curriculum.generate_configuration(base_value)
assert both_cfg.num_people == 3 # Preserved
assert both_cfg.num_characteristics == 3 # Level 1 maps to 3 characteristics
# Test decrementing num_people attribute
curriculum.decrement_attr_level("num_people")
char_only_cfg = curriculum.generate_configuration(base_value)
assert char_only_cfg.num_people == 2 # Back to level 0
assert char_only_cfg.num_characteristics == 3 # Preserved
# Test global level adjustments
curriculum = ZebraCurriculum() # Reset curriculum
assert curriculum.get_attr_level("num_people") == 0
assert curriculum.get_attr_level("num_characteristics") == 0
# Increase global level
curriculum.increment_global_level()
assert curriculum.get_attr_level("num_people") == 1
assert curriculum.get_attr_level("num_characteristics") == 1
global_level_cfg = curriculum.generate_configuration(base_value)
assert global_level_cfg.num_people == 3
assert global_level_cfg.num_characteristics == 3
# Increase global level again
curriculum.increment_global_level()
assert curriculum.get_attr_level("num_people") == 2
assert curriculum.get_attr_level("num_characteristics") == 2
global_level_cfg_2 = curriculum.generate_configuration(base_value)
assert global_level_cfg_2.num_people == 4
assert global_level_cfg_2.num_characteristics == 4
# Decrease global level
curriculum.decrement_global_level()
assert curriculum.get_attr_level("num_people") == 1
assert curriculum.get_attr_level("num_characteristics") == 1
global_level_cfg_3 = curriculum.generate_configuration(base_value)
assert global_level_cfg_3.num_people == 3
assert global_level_cfg_3.num_characteristics == 3
# Test upper bound
curriculum = ZebraCurriculum() # Reset curriculum
for _ in range(10): # Try going beyond max level
curriculum.increment_attr_level("num_people")
curriculum.increment_attr_level("num_characteristics")
max_cfg = curriculum.generate_configuration(base_value)
assert max_cfg.num_people == 7 # Capped at 7
assert max_cfg.num_characteristics == 7 # Capped at 7
# Test lower bound
curriculum = ZebraCurriculum() # Reset curriculum
curriculum.decrement_attr_level("num_people") # Try going below min level
curriculum.decrement_attr_level("num_characteristics") # Try going below min level
min_cfg = curriculum.generate_configuration(base_value)
assert min_cfg.num_people == 2 # Stays at 2
assert min_cfg.num_characteristics == 2 # Stays at 2