diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py
index a7b5236f..52c492d1 100644
--- a/reasoning_gym/algorithmic/cryptarithm.py
+++ b/reasoning_gym/algorithmic/cryptarithm.py
@@ -13,37 +13,29 @@ No leading letter can be zero (unless allow_leading_zero=True).
from dataclasses import dataclass
from random import Random
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
from ..factory import ProceduralDataset, register_dataset
-EXAMPLE_CASE = """
+EXAMPLE_CASE = """- Input:
BASE
+ BALL
------
GAMES
-* BASE + BALL = GAMES, two 4-digit numbers sum to 5 digits, so G = 1.
-
-* Units: E + L = S (no carry).
-
-* Tens: S + L = E + 10 (carry 1). Substitute S = E + L to get E + 2L = E + 10, so L = 5.
-
-* Since S = E + 5 and S is one digit, E < 5.
-
-* Hundreds: 2A + 1 = M (with carry).
-
-* Thousands: 2B = A + 10 (carry makes G = 1). So A = 2B - 10.
-
-* Try B = 7: Then A = 4 and M = 2(4) + 1 = 9.
-
-* With E < 5, try E = 3: Then S = 8.
-
-* Solution: B = 7, A = 4, S = 8, E = 3, L = 5, M = 9, G = 1
-
-* Verify: BASE (7483) + BALL (7455) = GAMES (14938).
-
-B=7, A=4, S=8, E=3, L=5, M=9, G=1"""
+- Output: B=7, A=4, S=8, E=3, L=5, M=9, G=1
+- Explanation:
+ * BASE + BALL = GAMES, two 4-digit numbers sum to 5 digits, so G = 1.
+ * Units: E + L = S (no carry).
+ * Tens: S + L = E + 10 (carry 1). Substitute S = E + L to get E + 2L = E + 10, so L = 5.
+ * Since S = E + 5 and S is one digit, E < 5.
+ * Hundreds: 2A + 1 = M (with carry).
+ * Thousands: 2B = A + 10 (carry makes G = 1). So A = 2B - 10.
+ * Try B = 7: Then A = 4 and M = 2(4) + 1 = 9.
+ * With E < 5, try E = 3: Then S = 8.
+ * Solution: B = 7, A = 4, S = 8, E = 3, L = 5, M = 9, G = 1
+ * Verify: BASE (7483) + BALL (7455) = GAMES (14938).
+"""
@dataclass
@@ -195,10 +187,10 @@ class CryptarithmDataset(ProceduralDataset):
if self.config.allow_leading_zero
else "No leading letter can be zero.\n"
)
- + "Provide a mapping from letters to digits that satisfies the equation in your final answer:\n\nALPHABET_1=NUMBER_1, ALPHABET_2=NUMBER_2, ...\n"
+ + 'Provide a comma separated mapping from letters to digits that satisfies the equation in your final answer. Output format: "A=1,B=2,C=3" (without quotes)\n'
)
if self.config.include_example:
- question_str += "Here's an example:\n" + EXAMPLE_CASE
+ question_str += "\nHere's an example:\n" + EXAMPLE_CASE
# 8) Create a human-readable answer, e.g. "A=1,B=0,C=9,..."
sorted_letter_keys = sorted(letter_to_digit.keys())
@@ -219,7 +211,7 @@ class CryptarithmDataset(ProceduralDataset):
},
}
- def score_answer(self, answer: Optional[str], answer_str: Dict[str, any]) -> float:
+ def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
"""Determine if the solution provided solves the Cryptarithm task.
The function awards 1.0 for a correct format and answers for all alphabet pairs.
@@ -232,7 +224,8 @@ class CryptarithmDataset(ProceduralDataset):
float: The computed score between 0.0 and 1.0.
"""
correct_mapping = {}
- for pair in answer_str.split(","):
+ correct_answer_str = entry["answer"] # e.g. "A=1,B=7,..."
+ for pair in correct_answer_str.split(","):
alphabet, number = pair.split("=")
correct_mapping[alphabet] = int(number)
diff --git a/tests/test_cryptarithm.py b/tests/test_cryptarithm.py
index b704e3b4..64b7b5e4 100644
--- a/tests/test_cryptarithm.py
+++ b/tests/test_cryptarithm.py
@@ -116,21 +116,21 @@ def test_cryptarithm_score_answer():
# assert score == 0.0, f"Expected 0.0 when missing '' prefix, got {score}"
# 2) Correct mapping => expecting 1.0
- score = dataset.score_answer(answer=correct_answer_str, answer_str=correct_answer_str)
+ score = dataset.score_answer(answer=correct_answer_str, entry=puzzle)
assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}"
# 3) Mismatch number of pairs => score should be 0.1
# For instance, drop the last pair
splitted = correct_answer_str.split(",")
mismatch_str = ",".join(splitted[:-1])
- score = dataset.score_answer(answer=mismatch_str, answer_str=correct_answer_str)
+ score = dataset.score_answer(answer=mismatch_str, entry=puzzle)
assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}"
# 4) Parse error => 0.15 (e.g. remove '=' from the first pair)
splitted = correct_answer_str.split(",")
splitted[0] = splitted[0].replace("=", "") # remove '=' in the first pair
parse_error_str = ",".join(splitted)
- score = dataset.score_answer(answer=parse_error_str, answer_str=correct_answer_str)
+ score = dataset.score_answer(answer=parse_error_str, entry=puzzle)
assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}"
# 5) Correct number of pairs, but duplicate alphabets => 0.3
@@ -139,7 +139,7 @@ def test_cryptarithm_score_answer():
if len(splitted) > 1:
splitted[0] = splitted[1] # Duplicate the second pair in the first position
duplicates_str = ",".join(splitted)
- score = dataset.score_answer(answer=duplicates_str, answer_str=correct_answer_str)
+ score = dataset.score_answer(answer=duplicates_str, entry=puzzle)
assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}"
# 6) Partial correctness => some correct, some incorrect
@@ -162,7 +162,7 @@ def test_cryptarithm_score_answer():
i += 1
partial_answer_str = ",".join(new_pairs)
- score = dataset.score_answer(answer=partial_answer_str, answer_str=correct_answer_str)
+ score = dataset.score_answer(answer=partial_answer_str, entry=puzzle)
# The formula is (num_correct / total) * 0.7 + 0.3
expected_score = (half / total) * 0.7 + 0.3