diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py
index 7075e9ec..a7b5236f 100644
--- a/reasoning_gym/algorithmic/cryptarithm.py
+++ b/reasoning_gym/algorithmic/cryptarithm.py
@@ -13,7 +13,7 @@ No leading letter can be zero (unless allow_leading_zero=True).
from dataclasses import dataclass
from random import Random
-from typing import Optional
+from typing import Dict, Optional
from ..factory import ProceduralDataset, register_dataset
@@ -23,10 +23,27 @@ EXAMPLE_CASE = """
------
GAMES
-Answer (one possible solution):
+* BASE + BALL = GAMES, two 4-digit numbers sum to 5 digits, so G = 1.
-B=7, A=8, S=2, E=9, L=1, G=1, M=0
-Summation: 7829 + 7811 = 15640 (the puzzle might produce a different arrangement, but the principle is the same)."""
+* Units: E + L = S (no carry).
+
+* Tens: S + L = E + 10 (carry 1). Substitute S = E + L to get E + 2L = E + 10, so L = 5.
+
+* Since S = E + 5 and S is one digit, E < 5.
+
+* Hundreds: 2A + 1 = M (with carry).
+
+* Thousands: 2B = A + 10 (carry makes G = 1). So A = 2B - 10.
+
+* Try B = 7: Then A = 4 and M = 2(4) + 1 = 9.
+
+* With E < 5, try E = 3: Then S = 8.
+
+* Solution: B = 7, A = 4, S = 8, E = 3, L = 5, M = 9, G = 1
+
+* Verify: BASE (7483) + BALL (7455) = GAMES (14938).
+
+B=7, A=4, S=8, E=3, L=5, M=9, G=1"""
@dataclass
@@ -178,7 +195,7 @@ class CryptarithmDataset(ProceduralDataset):
if self.config.allow_leading_zero
else "No leading letter can be zero.\n"
)
- + "Provide a mapping from letters to digits that satisfies the equation.\n"
+ + "Provide a mapping from letters to digits that satisfies the equation in your final answer:\n\nALPHABET_1=NUMBER_1, ALPHABET_2=NUMBER_2, ...\n"
)
if self.config.include_example:
question_str += "Here's an example:\n" + EXAMPLE_CASE
@@ -202,5 +219,49 @@ class CryptarithmDataset(ProceduralDataset):
},
}
+    def score_answer(self, answer: Optional[str], answer_str: str) -> float:
+        """Score a predicted letter-to-digit mapping against the reference mapping.
+
+        Both mappings are comma-separated ``LETTER=DIGIT`` pairs, e.g. ``"A=1,B=3"``.
+        Whitespace around pairs, letters and digits is ignored.
+
+        Scoring tiers:
+            * 0.0  - ``answer`` is missing (``None``) or not a string.
+            * 0.1  - the number of comma-separated pairs differs from the reference.
+            * 0.15 - at least one pair is not of the form ``LETTER=INTEGER``.
+            * 0.3  - every pair parses, but duplicated letters leave fewer unique
+                     letters than the reference has.
+            * 0.3 to 1.0 - ``0.3 + 0.7 * (fraction of reference letters mapped to
+                     the reference digit)``; 1.0 means every letter matches.
+
+        NOTE(review): the answer is compared with a single reference mapping; if a
+        puzzle admits several valid solutions, the others only get partial credit.
+
+        Args:
+            answer: The user's answer, already parsed by `extract_answer`.
+            answer_str: The reference answer string, e.g. ``"A=1,B=3"``.
+
+        Returns:
+            float: The computed score between 0.0 and 1.0.
+
+        Raises:
+            ValueError: If the reference ``answer_str`` itself is malformed.
+        """
+        # A missing answer would otherwise crash on `.split`; it earns nothing.
+        if not isinstance(answer, str):
+            return 0.0
+
+        correct_mapping = {}
+        for pair in answer_str.split(","):
+            alphabet, number = pair.split("=")
+            correct_mapping[alphabet.strip()] = int(number)
+
+        # case 1 : the answer is a comma-separated list with the expected number of pairs
+        predicted_pairs = answer.split(",")
+        if len(predicted_pairs) != len(correct_mapping):
+            return 0.1
+
+        # case 2 : every pair follows the ALPHABET=NUMBER format
+        predict_mapping = {}
+        for pair in predicted_pairs:
+            try:
+                # Unpacking fails on a missing or repeated '=', int() on a non-number;
+                # both surface as ValueError.
+                alphabet, number = pair.split("=")
+                # Strip the letter so "A = 1" is keyed as "A", not "A ".
+                predict_mapping[alphabet.strip()] = int(number)
+            except ValueError:
+                return 0.15
+
+        # Duplicate letters collapse into one key, leaving too few unique letters.
+        if len(predict_mapping) != len(correct_mapping):
+            return 0.3
+
+        # case 3 : partial credit, linear in the number of correctly mapped letters
+        total_correct = sum(
+            1 for alphabet, number in correct_mapping.items() if predict_mapping.get(alphabet) == number
+        )
+        return (total_correct / len(correct_mapping)) * 0.7 + 0.3
+
register_dataset("cryptarithm", CryptarithmDataset, CryptarithmConfig)
diff --git a/tests/test_cryptarithm.py b/tests/test_cryptarithm.py
index 0ae3ea7f..b704e3b4 100644
--- a/tests/test_cryptarithm.py
+++ b/tests/test_cryptarithm.py
@@ -103,3 +103,67 @@ def test_max_letters_constraint():
# Check total unique letters doesn't exceed 10 (digits 0-9)
assert len(letter_to_digit) <= 10, "Total unique letters should not exceed 10"
+
+
+def test_cryptarithm_score_answer():
+    """Exercise each scoring tier of CryptarithmDataset.score_answer."""
+    dataset = create_dataset("cryptarithm", seed=42, size=1)
+    reference = dataset[0]["answer"]  # e.g. "A=1,B=7,..."
+    reference_pairs = reference.split(",")
+
+    def grade(candidate):
+        # Every case is scored against the same reference answer string.
+        return dataset.score_answer(answer=candidate, answer_str=reference)
+
+    # 1) A missing answer (answer=None, expected score 0.0) is not exercised
+    #    here; that check is currently disabled.
+
+    # 2) The reference answer itself must earn the full score.
+    score = grade(reference)
+    assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}"
+
+    # 3) Dropping the last pair changes the pair count => 0.1.
+    score = grade(",".join(reference_pairs[:-1]))
+    assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}"
+
+    # 4) Removing '=' from the first pair makes it unparseable => 0.15.
+    broken_first = [reference_pairs[0].replace("=", "")] + reference_pairs[1:]
+    score = grade(",".join(broken_first))
+    assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}"
+
+    # 5) Same pair count, but the second pair repeated in the first slot leaves
+    #    fewer unique letters than expected => 0.3.
+    duplicated = list(reference_pairs)
+    if len(duplicated) > 1:
+        duplicated[0] = duplicated[1]
+    score = grade(",".join(duplicated))
+    assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}"
+
+    # 6) Partial correctness: keep the first half of the letters correct and
+    #    shift the digit of every remaining letter by one (mod 10).
+    correct_mapping = {
+        letter: int(digit) for letter, digit in (pair.split("=") for pair in reference_pairs)
+    }
+    total = len(correct_mapping)
+    half = total // 2
+    mixed_pairs = [
+        f"{letter}={digit}" if position < half else f"{letter}={(digit + 1) % 10}"
+        for position, (letter, digit) in enumerate(correct_mapping.items())
+    ]
+    score = grade(",".join(mixed_pairs))
+
+    # The formula is (num_correct / total) * 0.7 + 0.3
+    expected_score = (half / total) * 0.7 + 0.3
+    assert abs(score - expected_score) < 1e-9, f"Partial correctness: expected {expected_score}, got {score}"