added self reference curr (#329)

This commit is contained in:
joesharratt1229 2025-03-11 00:23:26 +01:00 committed by GitHub
parent 54074b17ef
commit b497e35fb8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 95 additions and 6 deletions

View file

@ -6,7 +6,7 @@ from .aiw import AliceInWonderlandConfig, AliceInWonderlandDataset
from .circuit_logic import CircuitLogicConfig, CircuitLogicDataset from .circuit_logic import CircuitLogicConfig, CircuitLogicDataset
from .knights_knaves import KnightsKnavesConfig, KnightsKnavesDataset from .knights_knaves import KnightsKnavesConfig, KnightsKnavesDataset
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset
from .self_reference import SelfReferenceConfig, SelfReferenceDataset from .self_reference import SelfReferenceConfig, SelfReferenceCurriculum, SelfReferenceDataset
from .syllogisms import SyllogismConfig, SyllogismDataset from .syllogisms import SyllogismConfig, SyllogismDataset
from .zebra_puzzles import ZebraConfig, ZebraCurriculum, ZebraDataset from .zebra_puzzles import ZebraConfig, ZebraCurriculum, ZebraDataset
@ -21,7 +21,7 @@ __all__ = [
"ZebraConfig", "ZebraConfig",
"ZebraCurriculum", "ZebraCurriculum",
"ZebraDataset", "ZebraDataset",
"SelfReference", "SelfReferenceCurriculum",
"SelfReferenceConfig", "SelfReferenceConfig",
"SelfReferenceDataset", "SelfReferenceDataset",
"CircuitLogicConfig", "CircuitLogicConfig",

View file

@ -2,6 +2,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Any, Optional from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -329,9 +330,10 @@ class SelfReferenceDataset(ProceduralDataset):
- metadata: dict with generation parameters - metadata: dict with generation parameters
""" """
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
difficulty = self.config.difficulty
# Generate puzzle # Generate puzzle
puzzle = generate_dynamic_puzzle(self.config.difficulty, rng) puzzle = generate_dynamic_puzzle(difficulty, rng)
puzz_s = ( puzz_s = (
"Given the truthfulness of these statements, please tell me the number of possible solutions: \n" "Given the truthfulness of these statements, please tell me the number of possible solutions: \n"
+ print_puzzle_dynamic(puzzle) + print_puzzle_dynamic(puzzle)
@ -344,7 +346,7 @@ class SelfReferenceDataset(ProceduralDataset):
return { return {
"question": puzz_s, "question": puzz_s,
"answer": answer, "answer": answer,
"metadata": {}, "metadata": {"difficulty": difficulty},
} }
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
@ -366,4 +368,20 @@ class SelfReferenceDataset(ProceduralDataset):
return 0.0 return 0.0
register_dataset("self_reference", SelfReferenceDataset, SelfReferenceConfig) class SelfReferenceCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(SelfReferenceCurriculum.__name__, SelfReferenceConfig)
self._define_attributes(
ScalarAttributeDefinition(
name="difficulty",
field_name="difficulty",
levels=list(range(1, 11)),
default_level=0,
description="The difficulty of the puzzle",
attr_type=AttributeType.STATIC,
min_value=1,
)
)
register_dataset("self_reference", SelfReferenceDataset, SelfReferenceConfig, SelfReferenceCurriculum)

View file

@ -1,6 +1,6 @@
import pytest import pytest
from reasoning_gym.logic.self_reference import SelfReferenceConfig, SelfReferenceDataset from reasoning_gym.logic.self_reference import SelfReferenceConfig, SelfReferenceCurriculum, SelfReferenceDataset
def test_self_reference(): def test_self_reference():
@ -53,3 +53,74 @@ def test_self_reference():
assert dataset.score_answer(answer=99, entry=item) == 0.0 assert dataset.score_answer(answer=99, entry=item) == 0.0
assert dataset.score_answer(answer="99", entry=item) == 0.0 assert dataset.score_answer(answer="99", entry=item) == 0.0
assert dataset.score_answer(answer=None, entry=item) == 0.0 assert dataset.score_answer(answer=None, entry=item) == 0.0
def test_self_reference_curriculum():
"""Test the SelfReferenceCurriculum functionality"""
curriculum = SelfReferenceCurriculum()
base_value = {"size": 150, "seed": 1}
# Test initial configuration
base_cfg = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.difficulty == 1 # Default level 0 maps to difficulty=1
# Test incrementing difficulty attribute
curriculum.increment_attr_level("difficulty")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.difficulty == 2
assert increased_cfg.seed == 1 # Unchanged
assert increased_cfg.size == 150 # Unchanged
# Test incrementing difficulty attribute again
curriculum.increment_attr_level("difficulty")
increased_cfg_2 = curriculum.generate_configuration(base_value)
assert increased_cfg_2.difficulty == 3
# Test decrementing difficulty attribute
curriculum.decrement_attr_level("difficulty")
decreased_cfg = curriculum.generate_configuration(base_value)
assert decreased_cfg.difficulty == 2
# Test global level adjustments
curriculum = SelfReferenceCurriculum() # Reset curriculum
assert curriculum.get_attr_level("difficulty") == 0 # Default level is 0, maps to difficulty=1
# Increase global level
curriculum.increment_global_level()
assert curriculum.get_attr_level("difficulty") == 1
global_level_cfg = curriculum.generate_configuration(base_value)
assert global_level_cfg.difficulty == 2
# Increase global level again
curriculum.increment_global_level()
assert curriculum.get_attr_level("difficulty") == 2
global_level_cfg_2 = curriculum.generate_configuration(base_value)
assert global_level_cfg_2.difficulty == 3
# Decrease global level
curriculum.decrement_global_level()
assert curriculum.get_attr_level("difficulty") == 1
global_level_cfg_3 = curriculum.generate_configuration(base_value)
assert global_level_cfg_3.difficulty == 2
# Test upper bound
curriculum = SelfReferenceCurriculum() # Reset curriculum
for _ in range(15): # Try going beyond max level (10)
curriculum.increment_attr_level("difficulty")
max_cfg = curriculum.generate_configuration(base_value)
assert max_cfg.difficulty == 10 # Should be capped at 10 (the highest level)
# Test lower bound
curriculum = SelfReferenceCurriculum() # Reset curriculum
curriculum.decrement_attr_level("difficulty") # Try going below min level
min_cfg = curriculum.generate_configuration(base_value)
assert min_cfg.difficulty == 1 # Should be capped at 1 (the lowest level)