fix: Add validation for size parameter in ABConfig

This commit is contained in:
Andreas Koepf (aider) 2025-02-11 23:39:02 +01:00 committed by Andreas Koepf
parent 38922c7e6e
commit 59461aaec8
2 changed files with 15 additions and 12 deletions

View file

@@ -1,5 +1,7 @@
import random
import pytest
from reasoning_gym.algorithmic.ab import ABConfig, ABDataset, compute_steps, generate_program
@@ -28,17 +30,17 @@ def test_ab_program_generation():
"""Test program generation and computation"""
rng = random.Random(42)
program = generate_program(5, rng)
# Test program format
assert len(program) == 5
assert all(token in ["A#", "#A", "B#", "#B"] for token in program)
# Test computation
steps, non_halting = compute_steps(program)
assert isinstance(steps, list)
assert isinstance(non_halting, bool)
assert len(steps) > 0
# Test each step follows valid transformation rules
for step in steps:
assert all(token in ["A#", "#A", "B#", "#B"] for token in step)
@@ -48,15 +50,15 @@ def test_ab_scoring():
"""Test scoring functionality"""
config = ABConfig(seed=42, size=10, length=5)
dataset = ABDataset(config)
for item in dataset:
# Test correct answer
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
# Test wrong answer
wrong_answer = "A# B#" if item["answer"] != "A# B#" else "B# A#"
assert dataset.score_answer(answer=wrong_answer, entry=item) == 0.01
# Test None answer
assert dataset.score_answer(answer=None, entry=item) == 0.0
@@ -65,14 +67,14 @@ def test_ab_iteration():
"""Test dataset iteration behavior"""
config = ABConfig(size=5, seed=42)
dataset = ABDataset(config)
# Test length
assert len(dataset) == config.size
# Test iteration
items = list(dataset)
assert len(items) == config.size
# Test multiple iterations yield same results
items2 = list(dataset)
assert items == items2
@@ -88,11 +90,11 @@ def test_ab_item_structure():
assert "question" in item
assert "answer" in item
assert "metadata" in item
# Test question format
assert "A::B is a system" in item["question"]
assert "Return the final state" in item["question"]
# Test answer format
answer_tokens = item["answer"].split()
assert all(token in ["A#", "#A", "B#", "#B"] for token in answer_tokens)