diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py
index 4eccc9e4..81ea5bd3 100644
--- a/reasoning_gym/algorithmic/cryptarithm.py
+++ b/reasoning_gym/algorithmic/cryptarithm.py
@@ -43,8 +43,7 @@ EXAMPLE_CASE = """
* Verify: BASE (7483) + BALL (7455) = GAMES (14938).
-ANSWER:
-B=7, A=4, S=8, E=3, L=5, M=9, G=1"""
+B=7, A=4, S=8, E=3, L=5, M=9, G=1"""
@dataclass
@@ -196,7 +195,7 @@ class CryptarithmDataset(ProceduralDataset):
if self.config.allow_leading_zero
else "No leading letter can be zero.\n"
)
- + "Provide a mapping from letters to digits that satisfies the equation in your final answer:\nANSWER:\nALPHABET_1=NUMBER_1, ALPHABET_2=NUMBER_2, ...\n"
+ + "Provide a mapping from letters to digits that satisfies the equation in your final answer:\n\nALPHABET_1=NUMBER_1, ALPHABET_2=NUMBER_2, ...\n"
)
if self.config.include_example:
question_str += "Here's an example:\n" + EXAMPLE_CASE
@@ -237,12 +236,15 @@ class CryptarithmDataset(ProceduralDataset):
alphabet, number = pair.split("=")
correct_mapping[alphabet] = int(number)
- if answer == None or "ANSWER:" not in answer:
+ if answer == None or "" not in answer:
return 0.0
number_mapping_line = ""
- if "ANSWER:" in answer:
- number_mapping_line = answer.split("ANSWER:\n")[-1]
+ if "" in answer:
+ number_mapping_line = answer.split("")[-1]
+ if "" not in number_mapping_line:
+ return 0.0
+ number_mapping_line = number_mapping_line.split("")[0].strip()
# case 1 : pairs are in a list format and the number of pairs matched up
if len(number_mapping_line.split(",")) != len(correct_mapping):
diff --git a/tests/test_cryptarithm.py b/tests/test_cryptarithm.py
index 942f180e..74563ef0 100644
--- a/tests/test_cryptarithm.py
+++ b/tests/test_cryptarithm.py
@@ -111,20 +111,25 @@ def test_cryptarithm_score_answer():
puzzle = dataset[0]
correct_answer_str = puzzle["answer"] # e.g. "A=1,B=7,..."
- # 1) Missing 'ANSWER:' => score should be 0.0
+ # 1) Missing '' => score should be 0.0
score = dataset.score_answer(answer=None, answer_str=correct_answer_str)
- assert score == 0.0, f"Expected 0.0 when missing 'ANSWER:' prefix, got {score}"
+ assert score == 0.0, f"Expected 0.0 when missing '' prefix, got {score}"
# 2) Correct mapping => expecting 1.0
- user_answer = f"ANSWER:\n{correct_answer_str}"
+ user_answer = f"{correct_answer_str}"
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}"
+ # 2.1) Missing end tag => expecting 1.0
+ user_answer = f"{correct_answer_str}"
+ score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
+ assert score == 0.0, f"Expected 0.0 for missing end answer tag, got {score}"
+
# 3) Mismatch number of pairs => score should be 0.1
# For instance, drop the last pair
splitted = correct_answer_str.split(",")
mismatch_str = ",".join(splitted[:-1])
- user_answer = f"ANSWER:\n{mismatch_str}"
+ user_answer = f"{mismatch_str}"
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}"
@@ -132,7 +137,7 @@ def test_cryptarithm_score_answer():
splitted = correct_answer_str.split(",")
splitted[0] = splitted[0].replace("=", "") # remove '=' in the first pair
parse_error_str = ",".join(splitted)
- user_answer = f"ANSWER:\n{parse_error_str}"
+ user_answer = f"{parse_error_str}"
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}"
@@ -142,7 +147,7 @@ def test_cryptarithm_score_answer():
if len(splitted) > 1:
splitted[0] = splitted[1] # Duplicate the second pair in the first position
duplicates_str = ",".join(splitted)
- user_answer = f"ANSWER:\n{duplicates_str}"
+ user_answer = f"{duplicates_str}"
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}"
@@ -166,7 +171,7 @@ def test_cryptarithm_score_answer():
i += 1
partial_answer_str = ",".join(new_pairs)
- user_answer = f"ANSWER:\n{partial_answer_str}"
+ user_answer = f"{partial_answer_str}"
score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
# The formula is (num_correct / total) * 0.7 + 0.3