mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
parent
93793f5416
commit
d603d8b72b
3 changed files with 98 additions and 5 deletions
|
|
@ -1,6 +1,11 @@
|
|||
import pytest
|
||||
|
||||
from reasoning_gym.logic.aiw import AliceInWonderlandConfig, AliceInWonderlandDataset, TaskType
|
||||
from reasoning_gym.logic.aiw import (
|
||||
AliceInWonderlandConfig,
|
||||
AliceInWonderlandCurriculum,
|
||||
AliceInWonderlandDataset,
|
||||
TaskType,
|
||||
)
|
||||
|
||||
|
||||
def test_aiw_config_validation():
|
||||
|
|
@ -94,3 +99,40 @@ def test_aiw_random_ranges():
|
|||
|
||||
# Check all numbers fall within the range asserted below (1-12)
|
||||
assert all(1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}"
|
||||
|
||||
|
||||
def test_aiw_curriculum():
    """Check that AliceInWonderlandCurriculum level changes show up in generated configs.

    Steps through per-attribute increments and a decrement, then a global level
    change, regenerating the configuration from the same base values each time.
    """
    base_value = {"size": 100, "seed": 42}
    curriculum = AliceInWonderlandCurriculum()

    # Fresh curriculum: base values are carried over and both attributes are at their starting level.
    initial: AliceInWonderlandConfig = curriculum.generate_configuration(base_value)
    assert (initial.seed, initial.size) == (42, 100)
    assert initial.max_entities == 4
    assert initial.task_type_weights == [1.0, 0.0, 0.0]  # default: siblings only

    # Raise only the task_type_weight attribute by one level.
    curriculum.increment_attr_level("task_type_weight")
    after_weight_bump = curriculum.generate_configuration(base_value)
    assert after_weight_bump.task_type_weights == [0.9, 0.05, 0.05]  # some friends/colleagues mixed in

    # Raise num_entities too; the task weight level must be left where it was.
    curriculum.increment_attr_level("num_entities")
    after_entities_bump = curriculum.generate_configuration(base_value)
    assert after_entities_bump.task_type_weights == [0.9, 0.05, 0.05]
    assert after_entities_bump.max_entities == 6

    # Lower task_type_weight again; the num_entities level must be left where it was.
    curriculum.decrement_attr_level("task_type_weight")
    after_weight_drop = curriculum.generate_configuration(base_value)
    assert after_weight_drop.max_entities == 6
    assert after_weight_drop.task_type_weights == [1.0, 0.0, 0.0]

    # A global level change moves every attribute to level 2 at once.
    curriculum.set_global_level(2)
    at_level_two = curriculum.generate_configuration(base_value)
    assert at_level_two.max_entities == 8
    assert at_level_two.task_type_weights == [0.7, 0.15, 0.15]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue