mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
ransom note curriculum (#300)
Co-authored-by: Andreas Köpf <andreas.koepf@xamla.com>
This commit is contained in:
parent
501f2d8322
commit
b58371c533
3 changed files with 108 additions and 56 deletions
|
|
@ -4,34 +4,34 @@ import json
|
|||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.ransom_note import RansomNoteConfig, RansomNoteDataset
|
||||
from reasoning_gym.algorithmic.ransom_note import RansomNoteConfig, RansomNoteCurriculum, RansomNoteDataset
|
||||
|
||||
|
||||
def test_ransom_note_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(max_note_length=-1) # Negative not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(max_note_length=0) # Zero not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(max_magazine_length=-1) # Negative not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(max_magazine_length=0) # Zero not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(max_magazine_length=1) # One not allowed
|
||||
config = RansomNoteConfig(min_note_length=0) # min_note_length must be at least 1
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(
|
||||
max_note_length=3, max_magazine_length=2
|
||||
min_note_length=5, max_note_length=4
|
||||
) # min_note_length must be less than or equal to max_note_length
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(min_magazine_length=1) # min_magazine_length must be at least 2
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(
|
||||
min_magazine_length=5, max_magazine_length=4
|
||||
) # min_magazine_length must be less than or equal to max_magazine_length
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = RansomNoteConfig(
|
||||
max_note_length=5, max_magazine_length=5
|
||||
) # max_note_length must be less than max_magazine_length
|
||||
config.validate()
|
||||
|
||||
|
|
@ -56,7 +56,9 @@ def test_ransom_note_dataset_deterministic():
|
|||
|
||||
def test_group_anagrams_dataset_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = RansomNoteConfig(max_note_length=10, max_magazine_length=30, size=10, seed=42)
|
||||
config = RansomNoteConfig(
|
||||
min_note_length=1, max_note_length=10, min_magazine_length=2, max_magazine_length=30, size=10, seed=42
|
||||
)
|
||||
dataset = RansomNoteDataset(config)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
|
|
@ -114,3 +116,28 @@ def test_ransom_note_answer():
|
|||
# Inorrect solution
|
||||
ransom_note, magazine = "az", "badhergh"
|
||||
assert dataset._can_construct(ransom_note, magazine) == False
|
||||
|
||||
|
||||
def test_ransom_note_curriculum():
|
||||
curriculum = RansomNoteCurriculum()
|
||||
|
||||
base_value = {"size": 150, "seed": 1}
|
||||
|
||||
base_cfg: RansomNoteConfig = curriculum.generate_configuration(base_value)
|
||||
assert base_cfg.seed == 1
|
||||
assert base_cfg.size == 150
|
||||
assert base_cfg.min_note_length == 10 and base_cfg.max_note_length == 10
|
||||
assert base_cfg.min_magazine_length == 50 and base_cfg.max_magazine_length == 50
|
||||
|
||||
# test incrementing attribute levels
|
||||
curriculum.increment_attr_level("note_length")
|
||||
curriculum.increment_attr_level("magazine_length")
|
||||
increased_cfg = curriculum.generate_configuration(base_value)
|
||||
assert increased_cfg.min_note_length == 10 and increased_cfg.max_note_length == 50
|
||||
assert increased_cfg.min_magazine_length == 50 and increased_cfg.max_magazine_length == 100
|
||||
|
||||
# test decrementing attribute level for note_length again
|
||||
curriculum.decrement_attr_level("note_length")
|
||||
partially_decreased_cfg = curriculum.generate_configuration(base_value)
|
||||
assert partially_decreased_cfg.min_note_length == 10 and partially_decreased_cfg.max_note_length == 10
|
||||
assert partially_decreased_cfg.min_magazine_length == 50 and partially_decreased_cfg.max_magazine_length == 100
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue