mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
Add curriculum to ab dataset (#345)
* Add curriculum to ab dataset * Add difficulty to metadata
This commit is contained in:
parent
4f45c8d655
commit
454250a4ea
3 changed files with 66 additions and 4 deletions
|
|
@ -2,7 +2,7 @@ import random
|
|||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.algorithmic.ab import ABConfig, ABDataset, compute_steps, generate_program
|
||||
from reasoning_gym.algorithmic.ab import ABConfig, ABCurriculum, ABDataset, compute_steps, generate_program
|
||||
|
||||
|
||||
def test_ab_config_validation():
|
||||
|
|
@ -98,3 +98,39 @@ def test_ab_item_structure():
|
|||
# Test answer format
|
||||
answer_tokens = item["answer"].split()
|
||||
assert all(token in ["A#", "#A", "B#", "#B"] for token in answer_tokens)
|
||||
|
||||
|
||||
def test_ab_curriculum():
    """Test the curriculum ab dataset.

    Checks that the base ``size``/``seed`` values are passed through to the
    generated configuration, and that moving the ``length`` attribute up and
    down the curriculum yields the expected values, including at both ends
    of the level range.
    """
    curriculum = ABCurriculum()
    base_value = {"size": 150, "seed": 1}

    # generate_configuration() yields a dataset configuration (it exposes
    # seed/size/length), so the results are annotated as ABConfig rather than
    # as the curriculum type that produced them.
    base_cfg: ABConfig = curriculum.generate_configuration(base_value)

    assert base_cfg.seed == 1
    assert base_cfg.size == 150
    assert base_cfg.length == 1

    # Test and validate increase in levels
    curriculum.increment_attr_level("length")
    increase_cfg: ABConfig = curriculum.generate_configuration(base_value)

    assert increase_cfg.length == 10

    # Test and validate decrease in levels
    curriculum.decrement_attr_level("length")
    decrease_cfg: ABConfig = curriculum.generate_configuration(base_value)

    assert decrease_cfg.length == 1

    # Test upper bound boundary condition: incrementing repeatedly is
    # expected to leave "length" at its largest value.
    for _ in range(10):
        curriculum.increment_attr_level("length")
    upper_bound_cfg: ABConfig = curriculum.generate_configuration(base_value)
    assert upper_bound_cfg.length == 100

    # Test lower bound boundary condition: decrementing repeatedly is
    # expected to bring "length" back to its smallest value.
    for _ in range(10):
        curriculum.decrement_attr_level("length")
    lower_bound_cfg: ABConfig = curriculum.generate_configuration(base_value)
    assert lower_bound_cfg.length == 1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue