Curriculum/cognition (#314)

* added rectangle count curriculum

* added number sequences

* registered curriculum
This commit is contained in:
joesharratt1229 2025-03-11 00:10:28 +01:00 committed by GitHub
parent d0b49cfffd
commit 0dce7adbad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 117 additions and 7 deletions

View file

@ -1,6 +1,12 @@
import pytest
from reasoning_gym.cognition.number_sequences import NumberSequenceConfig, NumberSequenceDataset, Operation, PatternRule
from reasoning_gym.cognition.number_sequences import (
NumberSequenceConfig,
NumberSequenceCurriculum,
NumberSequenceDataset,
Operation,
PatternRule,
)
def test_sequence_config_validation():
@ -75,3 +81,36 @@ def test_sequence_dataset_iteration():
# Test multiple iterations yield same items
assert items == list(dataset)
def test_number_sequence_curriculum():
"""Test the number sequence curriculum functionality"""
curriculum = NumberSequenceCurriculum()
# Test with custom base values
base_value = {"size": 150, "seed": 42}
# Test basic configuration generation
base_cfg: NumberSequenceConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 42
assert base_cfg.size == 150
assert base_cfg.max_complexity == 1 # Default level (0) corresponds to complexity 1
# Test attribute level increment
curriculum.increment_attr_level("max_complexity")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.max_complexity == 2 # Level 1 corresponds to complexity 2
# Test attribute level increment again
curriculum.increment_attr_level("max_complexity")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.max_complexity == 3 # Level 2 corresponds to complexity 3
# Test that other parameters remain unchanged
assert increased_cfg.seed == 42
assert increased_cfg.size == 150
# Test attribute level decrement
curriculum.decrement_attr_level("max_complexity")
decreased_cfg = curriculum.generate_configuration(base_value)
assert decreased_cfg.max_complexity == 2 # Back to level 1, complexity 2