number formatting curriculum (#341)

This commit is contained in:
Zafir Stojanovski 2025-03-13 20:57:43 +01:00 committed by GitHub
parent 3984d7cdfb
commit 7a4b8fc5a8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 90 additions and 9 deletions

View file

@ -2,17 +2,21 @@
import pytest
from reasoning_gym.arithmetic.number_format import NumberFormatConfig, NumberFormatDataset
from reasoning_gym.arithmetic.number_format import NumberFormatConfig, NumberFormatCurriculum, NumberFormatDataset
def test_number_format_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = NumberFormatConfig(max_num_candidates=0) # Zero not allowed
config = NumberFormatConfig(min_num_candidates=0) # Zero not allowed
config.validate()
with pytest.raises(AssertionError):
config = NumberFormatConfig(max_num_candidates=1) # One not allowed
config = NumberFormatConfig(min_num_candidates=1) # One not allowed
config.validate()
with pytest.raises(AssertionError):
config = NumberFormatConfig(max_num_candidates=5, min_num_candidates=6) # min > max
config.validate()
with pytest.raises(AssertionError):
@ -119,3 +123,32 @@ def test_number_format_answer():
# Answer is unparsable
model_answer = "test"
assert dataset.score_answer(model_answer, entry) == 0.0
def test_number_format_curriculum():
curriculum = NumberFormatCurriculum()
base_value = {"size": 150, "seed": 1}
base_cfg: NumberFormatConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_num_candidates == 5 and base_cfg.max_num_candidates == 25
assert base_cfg.min_n == 10 and base_cfg.max_n == 1_000
assert base_cfg.max_delta == 1e1
# test incrementing attribute levels
curriculum.increment_attr_level("num_candidates")
curriculum.increment_attr_level("n")
curriculum.increment_attr_level("max_delta")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_num_candidates == 5 and increased_cfg.max_num_candidates == 100
assert increased_cfg.min_n == 10 and increased_cfg.max_n == 1_000_000
assert increased_cfg.max_delta == 1e0
# test decrementing attribute level
curriculum.decrement_attr_level("num_candidates")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_num_candidates == 5 and partially_decreased_cfg.max_num_candidates == 25
assert partially_decreased_cfg.min_n == 10 and partially_decreased_cfg.max_n == 1_000_000
assert partially_decreased_cfg.max_delta == 1e0