mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-26 17:13:17 +00:00
number formatting curriculum (#341)
This commit is contained in:
parent
3984d7cdfb
commit
7a4b8fc5a8
3 changed files with 90 additions and 9 deletions
|
|
@ -2,17 +2,21 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.arithmetic.number_format import NumberFormatConfig, NumberFormatDataset
|
||||
from reasoning_gym.arithmetic.number_format import NumberFormatConfig, NumberFormatCurriculum, NumberFormatDataset
|
||||
|
||||
|
||||
def test_number_format_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
config = NumberFormatConfig(max_num_candidates=0) # Zero not allowed
|
||||
config = NumberFormatConfig(min_num_candidates=0) # Zero not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = NumberFormatConfig(max_num_candidates=1) # One not allowed
|
||||
config = NumberFormatConfig(min_num_candidates=1) # One not allowed
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = NumberFormatConfig(max_num_candidates=5, min_num_candidates=6) # min > max
|
||||
config.validate()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
|
|
@ -119,3 +123,32 @@ def test_number_format_answer():
|
|||
# Answer is unparsable
|
||||
model_answer = "test"
|
||||
assert dataset.score_answer(model_answer, entry) == 0.0
|
||||
|
||||
|
||||
def test_number_format_curriculum():
    """Check the configs produced by NumberFormatCurriculum as levels change.

    Generates a configuration at the initial attribute levels, then after
    raising every attribute by one level, and finally after lowering only
    ``num_candidates`` again, asserting the expected bounds each time.
    """
    curriculum = NumberFormatCurriculum()
    base_value = {"size": 150, "seed": 1}

    # Initial levels: the seed/size passed in must be carried over unchanged.
    base_cfg: NumberFormatConfig = curriculum.generate_configuration(base_value)
    assert base_cfg.seed == 1
    assert base_cfg.size == 150
    assert base_cfg.min_num_candidates == 5
    assert base_cfg.max_num_candidates == 25
    assert base_cfg.min_n == 10
    assert base_cfg.max_n == 1_000
    assert base_cfg.max_delta == 1e1

    # Raise each attribute by one level, in the same order as before.
    for attr_name in ("num_candidates", "n", "max_delta"):
        curriculum.increment_attr_level(attr_name)

    increased_cfg = curriculum.generate_configuration(base_value)
    assert increased_cfg.min_num_candidates == 5
    assert increased_cfg.max_num_candidates == 100
    assert increased_cfg.min_n == 10
    assert increased_cfg.max_n == 1_000_000
    assert increased_cfg.max_delta == 1e0

    # Lower only num_candidates; the other attributes keep their raised level.
    curriculum.decrement_attr_level("num_candidates")

    partially_decreased_cfg = curriculum.generate_configuration(base_value)
    assert partially_decreased_cfg.min_num_candidates == 5
    assert partially_decreased_cfg.max_num_candidates == 25
    assert partially_decreased_cfg.min_n == 10
    assert partially_decreased_cfg.max_n == 1_000_000
    assert partially_decreased_cfg.max_delta == 1e0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue