[fix] normalize to <answer></answer>

This commit is contained in:
theblackcat102 2025-02-19 08:40:31 +08:00
parent bbc31e4291
commit 9a2e9e949e
2 changed files with 20 additions and 13 deletions

View file

@ -111,20 +111,25 @@ def test_cryptarithm_score_answer():
puzzle = dataset[0]
correct_answer_str = puzzle["answer"] # e.g. "A=1,B=7,..."
# 1) Missing 'ANSWER:' => score should be 0.0
# 1) Missing '<answer>' => score should be 0.0
score = dataset.score_answer(answer=None, answer_str=correct_answer_str)
assert score == 0.0, f"Expected 0.0 when missing 'ANSWER:' prefix, got {score}"
assert score == 0.0, f"Expected 0.0 when missing '<answer>' prefix, got {score}"
# 2) Correct mapping => expecting 1.0
user_answer = f"ANSWER:\n{correct_answer_str}"
user_answer = f"<answer>{correct_answer_str}</answer>"
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}"
# 2.1) Missing end tag => expecting 1.0
user_answer = f"<answer>{correct_answer_str}"
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
assert score == 0.0, f"Expected 0.0 for missing end answer tag, got {score}"
# 3) Mismatch number of pairs => score should be 0.1
# For instance, drop the last pair
splitted = correct_answer_str.split(",")
mismatch_str = ",".join(splitted[:-1])
user_answer = f"ANSWER:\n{mismatch_str}"
user_answer = f"<answer>{mismatch_str}</answer>"
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}"
@ -132,7 +137,7 @@ def test_cryptarithm_score_answer():
splitted = correct_answer_str.split(",")
splitted[0] = splitted[0].replace("=", "") # remove '=' in the first pair
parse_error_str = ",".join(splitted)
user_answer = f"ANSWER:\n{parse_error_str}"
user_answer = f"<answer>{parse_error_str}</answer>"
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}"
@ -142,7 +147,7 @@ def test_cryptarithm_score_answer():
if len(splitted) > 1:
splitted[0] = splitted[1] # Duplicate the second pair in the first position
duplicates_str = ",".join(splitted)
user_answer = f"ANSWER:\n{duplicates_str}"
user_answer = f"<answer>{duplicates_str}</answer>"
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}"
@ -166,7 +171,7 @@ def test_cryptarithm_score_answer():
i += 1
partial_answer_str = ",".join(new_pairs)
user_answer = f"ANSWER:\n{partial_answer_str}"
user_answer = f"<answer>{partial_answer_str}</answer>"
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
# The formula is (num_correct / total) * 0.7 + 0.3