added aiw curric (#356)

added metadata
This commit is contained in:
joesharratt1229 2025-03-13 20:10:52 +00:00 committed by GitHub
parent 93793f5416
commit d603d8b72b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 98 additions and 5 deletions

View file

@ -1,6 +1,11 @@
import pytest
from reasoning_gym.logic.aiw import AliceInWonderlandConfig, AliceInWonderlandDataset, TaskType
from reasoning_gym.logic.aiw import (
AliceInWonderlandConfig,
AliceInWonderlandCurriculum,
AliceInWonderlandDataset,
TaskType,
)
def test_aiw_config_validation():
@ -94,3 +99,40 @@ def test_aiw_random_ranges():
# Check all numbers are in reasonable range (1-6 as per implementation)
assert all(1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}"
def test_aiw_curriculum():
"""Test the AIW curriculum functionality"""
curriculum = AliceInWonderlandCurriculum()
base_value = {"size": 100, "seed": 42}
# Test default configuration
base_cfg: AliceInWonderlandConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 42
assert base_cfg.size == 100
assert base_cfg.max_entities == 4
assert base_cfg.task_type_weights == [1.0, 0.0, 0.0] # Default is siblings only
# Test incrementing task_type_weight attribute
curriculum.increment_attr_level("task_type_weight")
task_weight_cfg = curriculum.generate_configuration(base_value)
assert task_weight_cfg.task_type_weights == [0.9, 0.05, 0.05] # Second level adds some friends/colleagues
# Test incrementing num_entities attribute
curriculum.increment_attr_level("num_entities")
entities_cfg = curriculum.generate_configuration(base_value)
assert entities_cfg.max_entities == 6 # Increased max entities
assert entities_cfg.task_type_weights == [0.9, 0.05, 0.05] # Should preserve task weight level
# Test decrementing task_type_weight attribute
curriculum.decrement_attr_level("task_type_weight")
updated_cfg = curriculum.generate_configuration(base_value)
assert updated_cfg.task_type_weights == [1.0, 0.0, 0.0] # Back to default weights
assert updated_cfg.max_entities == 6 # Should preserve entities level
# Test global level setting
curriculum.set_global_level(2) # Set all attributes to level 2
global_cfg = curriculum.generate_configuration(base_value)
assert global_cfg.task_type_weights == [0.7, 0.15, 0.15] # Level 2 of task weights
assert global_cfg.max_entities == 8 # Level 2 of num_entities