update prompt and score answer

This commit is contained in:
Zafir Stojanovski 2025-02-16 15:18:45 +01:00
parent 5803a2962e
commit 1a3e4372ef
2 changed files with 59 additions and 19 deletions

View file

@ -15,6 +15,10 @@ def test_spiral_matrix_config_validation():
config = SpiralMatrixConfig(max_n=0) # Zero not allowed
config.validate()
with pytest.raises(AssertionError):
config = SpiralMatrixConfig(max_n=1) # One not allowed
config.validate()
def test_spiral_matrix_dataset_deterministic():
"""Test that dataset generates same items with same seed"""
@ -69,18 +73,26 @@ def test_spiral_matrix_answer():
config = SpiralMatrixConfig(seed=42)
dataset = SpiralMatrixDataset(config)
# One element
matrix = [[0]]
assert dataset._get_spiral(matrix) == [0]
# One row
matrix = [[0, 1, 2]]
assert dataset._get_spiral(matrix) == [0, 1, 2]
# One column
matrix = [[0], [1], [2]]
assert dataset._get_spiral(matrix) == [0, 1, 2]
# 2D grid
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
assert dataset._get_spiral(matrix) == [1, 2, 3, 6, 9, 8, 7, 4, 5]
# Answer is identical (up to trimming)
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
answer = "\n\n1 2 3 6 9 8 7 4 5\n"
assert dataset.score_answer(answer, entry) == 1.0
# Score answer in list format (partially correct)
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
answer = "[1, 2, 3, 6, 9, 8, 7, 4, 5]"
assert dataset.score_answer(answer, entry) == 0.5
# Answer is incorrect
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
answer = "1 2 3"
assert dataset.score_answer(answer, entry) == 0.01
# Answer is none
entry = {"answer": "1 2 3 6 9 8 7 4 5"}
answer = None
assert dataset.score_answer(answer, entry) == 0.0