[fix] normalize to <answer></answer>

This commit is contained in:
theblackcat102 2025-02-19 08:40:31 +08:00
parent bbc31e4291
commit 9a2e9e949e
2 changed files with 20 additions and 13 deletions

View file

@@ -43,8 +43,7 @@ EXAMPLE_CASE = """
* Verify: BASE (7483) + BALL (7455) = GAMES (14938).
ANSWER:
B=7, A=4, S=8, E=3, L=5, M=9, G=1"""
<answer>B=7, A=4, S=8, E=3, L=5, M=9, G=1</answer>"""
@dataclass
@@ -196,7 +195,7 @@ class CryptarithmDataset(ProceduralDataset):
if self.config.allow_leading_zero
else "No leading letter can be zero.\n"
)
+ "Provide a mapping from letters to digits that satisfies the equation in your final answer:\nANSWER:\nALPHABET_1=NUMBER_1, ALPHABET_2=NUMBER_2, ...\n"
+ "Provide a mapping from letters to digits that satisfies the equation in your final answer:\n<answer>\nALPHABET_1=NUMBER_1, ALPHABET_2=NUMBER_2, ...</answer>\n"
)
if self.config.include_example:
question_str += "Here's an example:\n" + EXAMPLE_CASE
@@ -237,12 +236,15 @@ class CryptarithmDataset(ProceduralDataset):
alphabet, number = pair.split("=")
correct_mapping[alphabet] = int(number)
if answer == None or "ANSWER:" not in answer:
if answer == None or "<answer>" not in answer:
return 0.0
number_mapping_line = ""
if "ANSWER:" in answer:
number_mapping_line = answer.split("ANSWER:\n")[-1]
if "<answer>" in answer:
number_mapping_line = answer.split("<answer>")[-1]
if "</answer>" not in number_mapping_line:
return 0.0
number_mapping_line = number_mapping_line.split("</answer>")[0].strip()
# case 1 : pairs are in a list format and the number of pairs matched up
if len(number_mapping_line.split(",")) != len(correct_mapping):