diff --git a/tests/test_propositional_logic.py b/tests/test_propositional_logic.py index aff2ebc3..c73ab1cc 100644 --- a/tests/test_propositional_logic.py +++ b/tests/test_propositional_logic.py @@ -87,3 +87,17 @@ def test_propositional_logic_dataset_iteration(): # Test multiple iterations yield same items assert items == list(dataset) + + +def test_propositional_logic_dataset_score_answer_correct(): + dataset = PropositionalLogicDataset(PropositionalLogicConfig(size=50, seed=101)) + for i, item in enumerate(dataset): + score = dataset.score_answer(item["metadata"]["example_answer"], item) + assert score == 1.0 + + +def test_propositional_logic_dataset_score_answer_incorrect(): + dataset = PropositionalLogicDataset(PropositionalLogicConfig(size=100, seed=101)) + for i, item in enumerate(dataset): + score = dataset.score_answer("Wrong", item) + assert score == 0.01