diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py index 4eccc9e4..81ea5bd3 100644 --- a/reasoning_gym/algorithmic/cryptarithm.py +++ b/reasoning_gym/algorithmic/cryptarithm.py @@ -43,8 +43,7 @@ EXAMPLE_CASE = """ * Verify: BASE (7483) + BALL (7455) = GAMES (14938). -ANSWER: -B=7, A=4, S=8, E=3, L=5, M=9, G=1""" +B=7, A=4, S=8, E=3, L=5, M=9, G=1""" @dataclass @@ -196,7 +195,7 @@ class CryptarithmDataset(ProceduralDataset): if self.config.allow_leading_zero else "No leading letter can be zero.\n" ) - + "Provide a mapping from letters to digits that satisfies the equation in your final answer:\nANSWER:\nALPHABET_1=NUMBER_1, ALPHABET_2=NUMBER_2, ...\n" + + "Provide a mapping from letters to digits that satisfies the equation in your final answer:\n\nALPHABET_1=NUMBER_1, ALPHABET_2=NUMBER_2, ...\n" ) if self.config.include_example: question_str += "Here's an example:\n" + EXAMPLE_CASE @@ -237,12 +236,15 @@ class CryptarithmDataset(ProceduralDataset): alphabet, number = pair.split("=") correct_mapping[alphabet] = int(number) - if answer == None or "ANSWER:" not in answer: + if answer == None or "" not in answer: return 0.0 number_mapping_line = "" - if "ANSWER:" in answer: - number_mapping_line = answer.split("ANSWER:\n")[-1] + if "" in answer: + number_mapping_line = answer.split("")[-1] + if "" not in number_mapping_line: + return 0.0 + number_mapping_line = number_mapping_line.split("")[0].strip() # case 1 : pairs are in a list format and the number of pairs matched up if len(number_mapping_line.split(",")) != len(correct_mapping): diff --git a/tests/test_cryptarithm.py b/tests/test_cryptarithm.py index 942f180e..74563ef0 100644 --- a/tests/test_cryptarithm.py +++ b/tests/test_cryptarithm.py @@ -111,20 +111,25 @@ def test_cryptarithm_score_answer(): puzzle = dataset[0] correct_answer_str = puzzle["answer"] # e.g. "A=1,B=7,..." 
- # 1) Missing 'ANSWER:' => score should be 0.0 + # 1) Missing '' => score should be 0.0 score = dataset.score_answer(answer=None, answer_str=correct_answer_str) - assert score == 0.0, f"Expected 0.0 when missing 'ANSWER:' prefix, got {score}" + assert score == 0.0, f"Expected 0.0 when missing '' prefix, got {score}" # 2) Correct mapping => expecting 1.0 - user_answer = f"ANSWER:\n{correct_answer_str}" + user_answer = f"{correct_answer_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}" + # 2.1) Missing end tag => expecting 0.0 + user_answer = f"{correct_answer_str}" + score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) + assert score == 0.0, f"Expected 0.0 for missing end answer tag, got {score}" # 3) Mismatch number of pairs => score should be 0.1 # For instance, drop the last pair splitted = correct_answer_str.split(",") mismatch_str = ",".join(splitted[:-1]) - user_answer = f"ANSWER:\n{mismatch_str}" + user_answer = f"{mismatch_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}" @@ -132,7 +137,7 @@ def test_cryptarithm_score_answer(): splitted = correct_answer_str.split(",") splitted[0] = splitted[0].replace("=", "") # remove '=' in the first pair parse_error_str = ",".join(splitted) - user_answer = f"ANSWER:\n{parse_error_str}" + user_answer = f"{parse_error_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}" @@ -142,7 +147,7 @@ def test_cryptarithm_score_answer(): if len(splitted) > 1: splitted[0] = splitted[1] # Duplicate the second pair in the first position duplicates_str = ",".join(splitted) - user_answer = f"ANSWER:\n{duplicates_str}" + user_answer = f"{duplicates_str}" score = 
dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}" @@ -166,7 +171,7 @@ def test_cryptarithm_score_answer(): i += 1 partial_answer_str = ",".join(new_pairs) - user_answer = f"ANSWER:\n{partial_answer_str}" + user_answer = f"{partial_answer_str}" score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str) # The formula is (num_correct / total) * 0.7 + 0.3