From 31818d3e0baf6e3376c4a0980db808ed9d6edd0e Mon Sep 17 00:00:00 2001 From: "Andreas Koepf (aider)" Date: Sun, 2 Feb 2025 22:15:49 +0100 Subject: [PATCH] test: Add unit test for score_answer method in N-Queens dataset --- tests/test_n_queens.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_n_queens.py b/tests/test_n_queens.py index f5b8108c..946685be 100644 --- a/tests/test_n_queens.py +++ b/tests/test_n_queens.py @@ -102,6 +102,27 @@ def test_nqueens_board_generation(): assert is_valid_solution(board) +def test_nqueens_score_answer(): + """Test the score_answer method""" + config = NQueensConfig(n=8, size=10, seed=42) + dataset = NQueensDataset(config) + + # Test a few items + for i in range(len(dataset)): + item = dataset[i] + + # Test correct answer gets score 1.0 + valid_answer = item["metadata"]["valid_answers"][0] + assert dataset.score_answer(valid_answer, item) == 1.0 + + # Test invalid answer gets score 0.01 + invalid_answer = "_ _ _ _\n_ _ _ _\n_ _ _ _\n_ _ _ _" + assert dataset.score_answer(invalid_answer, item) == 0.01 + + # Test None answer gets score 0.0 + assert dataset.score_answer(None, item) == 0.0 + + def is_valid_solution(board: list[list[str]]) -> bool: """Helper function to verify N Queens solution validity""" rows, cols, diags, off_diags = set(), set(), set(), set()