diff --git a/reasoning_gym/algorithmic/ab.py b/reasoning_gym/algorithmic/ab.py index c0070fb0..8970e372 100644 --- a/reasoning_gym/algorithmic/ab.py +++ b/reasoning_gym/algorithmic/ab.py @@ -58,7 +58,8 @@ class ABConfig: def validate(self) -> None: """Validate configuration parameters""" - assert self.length > 0, "difficulty must be greater than 0" + assert self.length > 0, "length must be greater than 0" + assert self.size > 0, "size must be greater than 0" class ABDataset(ProceduralDataset): diff --git a/tests/test_ab.py b/tests/test_ab.py index ced8ba85..489c4acf 100644 --- a/tests/test_ab.py +++ b/tests/test_ab.py @@ -1,5 +1,7 @@ import random + import pytest + from reasoning_gym.algorithmic.ab import ABConfig, ABDataset, compute_steps, generate_program @@ -28,17 +30,17 @@ def test_ab_program_generation(): """Test program generation and computation""" rng = random.Random(42) program = generate_program(5, rng) - + # Test program format assert len(program) == 5 assert all(token in ["A#", "#A", "B#", "#B"] for token in program) - + # Test computation steps, non_halting = compute_steps(program) assert isinstance(steps, list) assert isinstance(non_halting, bool) assert len(steps) > 0 - + # Test each step follows valid transformation rules for step in steps: assert all(token in ["A#", "#A", "B#", "#B"] for token in step) @@ -48,15 +50,15 @@ def test_ab_scoring(): """Test scoring functionality""" config = ABConfig(seed=42, size=10, length=5) dataset = ABDataset(config) - + for item in dataset: # Test correct answer assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 - + # Test wrong answer wrong_answer = "A# B#" if item["answer"] != "A# B#" else "B# A#" assert dataset.score_answer(answer=wrong_answer, entry=item) == 0.01 - + # Test None answer assert dataset.score_answer(answer=None, entry=item) == 0.0 @@ -65,14 +67,14 @@ def test_ab_iteration(): """Test dataset iteration behavior""" config = ABConfig(size=5, seed=42) dataset = ABDataset(config) - + # Test length assert len(dataset) == config.size - + # Test iteration items = list(dataset) assert len(items) == config.size - + # Test multiple iterations yield same results items2 = list(dataset) assert items == items2 @@ -88,11 +90,11 @@ def test_ab_item_structure(): assert "question" in item assert "answer" in item assert "metadata" in item - + # Test question format assert "A::B is a system" in item["question"] assert "Return the final state" in item["question"] - + # Test answer format answer_tokens = item["answer"].split() assert all(token in ["A#", "#A", "B#", "#B"] for token in answer_tokens)