diff --git a/tests/test_n_queens.py b/tests/test_n_queens.py index f5b8108c..946685be 100644 --- a/tests/test_n_queens.py +++ b/tests/test_n_queens.py @@ -102,6 +102,27 @@ def test_nqueens_board_generation(): assert is_valid_solution(board) +def test_nqueens_score_answer(): + """Test the score_answer method""" + config = NQueensConfig(n=8, size=10, seed=42) + dataset = NQueensDataset(config) + + # Test a few items + for i in range(len(dataset)): + item = dataset[i] + + # Test correct answer gets score 1.0 + valid_answer = item["metadata"]["valid_answers"][0] + assert dataset.score_answer(valid_answer, item) == 1.0 + + # Test invalid answer gets score 0.01 + invalid_answer = "_ _ _ _\n_ _ _ _\n_ _ _ _\n_ _ _ _" + assert dataset.score_answer(invalid_answer, item) == 0.01 + + # Test None answer gets score 0.0 + assert dataset.score_answer(None, item) == 0.0 + + def is_valid_solution(board: list[list[str]]) -> bool: """Helper function to verify N Queens solution validity""" rows, cols, diags, off_diags = set(), set(), set(), set()