diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py index 03bbe9ff..784e1702 100644 --- a/reasoning_gym/algorithmic/cryptarithm.py +++ b/reasoning_gym/algorithmic/cryptarithm.py @@ -21,6 +21,89 @@ from ..factory import ProceduralDataset, register_dataset DATASET_NAME = "cryptarithm" +def verify_cryptarithm_solution( + mapping: dict[str, int], + words_letters: list[str], + result_letters: str, + allow_leading_zero: bool, +) -> tuple[bool, str]: + """Validate if a letter-to-digit mapping satisfies the cryptarithm puzzle constraints. + + Args: + mapping: Dictionary mapping letters to digits (e.g., {'A': 1, 'B': 2}) + words_letters: List of word strings using letters (e.g., ['ABC', 'DE']) + result_letters: Result string using letters (e.g., 'FGH') + allow_leading_zero: Whether leading zeros are allowed + + Returns: + (is_valid, failure_reason) tuple: + - is_valid: True if mapping satisfies all constraints + - failure_reason: String describing why validation failed (empty if valid) + """ + # Collect all letters used in the puzzle + all_puzzle_letters = set() + for word in words_letters: + all_puzzle_letters.update(word) + all_puzzle_letters.update(result_letters) + + # Check 1: All letters must be mapped + mapped_letters = set(mapping.keys()) + if mapped_letters != all_puzzle_letters: + missing = all_puzzle_letters - mapped_letters + extra = mapped_letters - all_puzzle_letters + if missing: + return False, f"Missing mapping for letter(s): {sorted(missing)}" + if extra: + return False, f"Extra letter(s) in mapping: {sorted(extra)}" + + # Check 2: All digits must be valid (0-9) + for letter, digit in mapping.items(): + if not isinstance(digit, int) or digit < 0 or digit > 9: + return False, f"Invalid digit for letter {letter}: {digit}" + + # Check 3: Uniqueness constraint - each digit can only be assigned to one letter + digit_values = list(mapping.values()) + if len(set(digit_values)) != len(digit_values): + return False, "Duplicate digit assignments detected" + + # Check 4: Leading zero constraint (if not allowed) + if not allow_leading_zero: + # Check leading letters of all words + for word in words_letters: + if word: # non-empty word + leading_letter = word[0] + if mapping.get(leading_letter) == 0: + return False, f"Leading letter '{leading_letter}' cannot map to 0" + # Check leading letter of result + if result_letters: + leading_letter = result_letters[0] + if mapping.get(leading_letter) == 0: + return False, f"Leading letter '{leading_letter}' in result cannot map to 0" + + # Check 5: Arithmetic constraint - the sum must be correct + try: + # Convert each word from letters to numbers + word_numbers = [] + for word in words_letters: + number_str = "".join(str(mapping[letter]) for letter in word) + word_numbers.append(int(number_str)) + + # Convert result from letters to number + result_number_str = "".join(str(mapping[letter]) for letter in result_letters) + result_number = int(result_number_str) + + # Check if sum is correct + computed_sum = sum(word_numbers) + if computed_sum != result_number: + return False, f"Arithmetic equation not satisfied: {word_numbers} sums to {computed_sum}, expected {result_number}" + + except (KeyError, ValueError) as e: + return False, f"Error applying mapping: {e}" + + # All checks passed + return True, "" + + @dataclass class CryptarithmConfig: """Configuration for Cryptarithm dataset generation.""" @@ -173,6 +256,7 @@ class CryptarithmDataset(ProceduralDataset): ) # 8) Create a human-readable answer, e.g. "A=1,B=0,C=9,..." + # Note: This is ONE valid solution. Other solutions may exist and are equally valid. sorted_letter_keys = sorted(letter_to_digit.keys()) answer_str = ",".join(f"{letter}={letter_to_digit[letter]}" for letter in sorted_letter_keys) @@ -183,6 +267,7 @@ class CryptarithmDataset(ProceduralDataset): "metadata": { "source_dataset": DATASET_NAME, "source_index": idx, + "allow_leading_zero": self.config.allow_leading_zero, "letters": list(letter_to_digit.keys()), "word_values": words_numbers, "sum_number": total_sum, @@ -199,50 +284,48 @@ class CryptarithmDataset(ProceduralDataset): def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """Determine if the solution provided solves the Cryptarithm task. - The function awards 1.0 for a correct format and answers for all alphabet pairs. + Validates that the provided letter-to-digit mapping satisfies all constraints: + 1. All letters are mapped to unique digits (0-9) + 2. Leading letters don't map to 0 (if allow_leading_zero=False) + 3. The arithmetic equation is satisfied + + The function awards 1.0 for any valid solution (not just the stored solution). + Multiple valid solutions may exist and are equally acceptable. Args: - answer (Optional[str]): The user's answer already parsed by `extract_answer` - answer_str (dict[str, Any]): The original dataset answer_str containing the correct answer. ie "A=1,B=3..." + answer (Optional[str]): The user's answer in format "A=1,B=2,C=3" + entry (dict[str, Any]): The dataset entry containing puzzle metadata Returns: - float: The computed score between 0.0 and 1.0. + float: 1.0 for valid solution, 0.01 for parseable but invalid, 0.0 for parse error """ if not isinstance(answer, str): return 0.0 - correct_mapping = {} - correct_answer_str = entry["answer"] # e.g. "A=1,B=7,..." - for pair in correct_answer_str.split(","): - alphabet, number = pair.split("=") - correct_mapping[alphabet] = int(number) + # Parse the answer into a letter-to-digit mapping + try: + predicted_mapping = {} + for pair in answer.split(","): + letter, digit_str = pair.strip().split("=") + letter = letter.strip() + predicted_mapping[letter] = int(digit_str.strip()) + except (ValueError, AttributeError): + return 0.0 # Parse error - # case 1 : pairs are in a list format and the number of pairs matched up - if len(answer.split(",")) != len(correct_mapping): - return 0.1 + # Extract puzzle constraints from metadata + words_letters = entry["metadata"]["words_letters"] + result_letters = entry["metadata"]["result_letters"] + allow_leading_zero = entry["metadata"].get("allow_leading_zero", False) - predict_mapping = {} - for pair in answer.split(","): - try: - alphabet, number = pair.strip().split("=") - # as the unique alphabet grows we may want this to scale linearly with the number alphabet - predict_mapping[alphabet] = int(number) - except ValueError: - return 0.15 - # case 2 : all alphabet has correct format ALPHABET=NUMBER format - if len(predict_mapping) != len(correct_mapping): - return 0.3 + # Validate the solution using the helper function + is_valid, failure_reason = verify_cryptarithm_solution( + predicted_mapping, words_letters, result_letters, allow_leading_zero + ) - # case 3 : partial score for the number of correct mapping answer - total_correct, total = 0, 0 - for alphabet, number in correct_mapping.items(): - total += 1 - if alphabet in predict_mapping: - if predict_mapping[alphabet] == number: - total_correct += 1 - - # note: linear relationship is probably not good? - return (total_correct / total) * 0.7 + 0.3 + if is_valid: + return 1.0 + else: + return 0.01 # Parseable but doesn't satisfy constraints class CryptarithmCurriculum(BaseCurriculum): diff --git a/tests/test_cryptarithm.py b/tests/test_cryptarithm.py index 8e8c212d..b26c7d6a 100644 --- a/tests/test_cryptarithm.py +++ b/tests/test_cryptarithm.py @@ -1,7 +1,12 @@ import pytest from reasoning_gym import create_dataset -from reasoning_gym.algorithmic.cryptarithm import CryptarithmConfig, CryptarithmCurriculum, CryptarithmDataset +from reasoning_gym.algorithmic.cryptarithm import ( + CryptarithmConfig, + CryptarithmCurriculum, + CryptarithmDataset, + verify_cryptarithm_solution, +) def test_cryptarithm_generation(): @@ -111,62 +116,159 @@ def test_cryptarithm_score_answer(): puzzle = dataset[0] correct_answer_str = puzzle["answer"] # e.g. "A=1,B=7,..." - # 1) Missing '' => score should be 0.0 - # score = dataset.score_answer(answer=None, answer_str=correct_answer_str) - # assert score == 0.0, f"Expected 0.0 when missing '' prefix, got {score}" - - # 2) Correct mapping => expecting 1.0 + # 1) Correct mapping => expecting 1.0 score = dataset.score_answer(answer=correct_answer_str, entry=puzzle) assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}" - # 3) Mismatch number of pairs => score should be 0.1 + # 2) Correct mapping in different order => should still be 1.0 + correct_mapping = {} + for pair in correct_answer_str.split(","): + alpha, num_str = pair.split("=") + correct_mapping[alpha] = int(num_str) + reversed_answer = ",".join(f"{letter}={correct_mapping[letter]}" for letter in reversed(sorted(correct_mapping.keys()))) + score = dataset.score_answer(answer=reversed_answer, entry=puzzle) + assert score == 1.0, f"Expected 1.0 for correct answer in different order, got {score}" + + # 3) Mismatch number of pairs => score should be 0.0 (parse succeeds but validation fails) # For instance, drop the last pair splitted = correct_answer_str.split(",") mismatch_str = ",".join(splitted[:-1]) score = dataset.score_answer(answer=mismatch_str, entry=puzzle) - assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}" + assert score == 0.01, f"Expected 0.01 when #pairs does not match (missing letter), got {score}" - # 4) Parse error => 0.15 (e.g. remove '=' from the first pair) + # 4) Parse error => 0.0 (e.g. remove '=' from the first pair) splitted = correct_answer_str.split(",") splitted[0] = splitted[0].replace("=", "") # remove '=' in the first pair parse_error_str = ",".join(splitted) score = dataset.score_answer(answer=parse_error_str, entry=puzzle) - assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}" + assert score == 0.0, f"Expected 0.0 when parsing fails on at least one pair, got {score}" - # 5) Correct number of pairs, but duplicate alphabets => 0.3 + # 5) Correct number of pairs, but duplicate alphabets => 0.01 (parseable but invalid) # This makes the dictionary have fewer unique keys than expected splitted = correct_answer_str.split(",") if len(splitted) > 1: splitted[0] = splitted[1] # Duplicate the second pair in the first position duplicates_str = ",".join(splitted) score = dataset.score_answer(answer=duplicates_str, entry=puzzle) - assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}" + assert score == 0.01, f"Expected 0.01 if the final dict has fewer unique alphabets, got {score}" - # 6) Partial correctness => some correct, some incorrect - splitted = correct_answer_str.split(",") + # 6) Wrong arithmetic - swap two digits to break the equation correct_mapping = {} - for pair in splitted: + for pair in correct_answer_str.split(","): alpha, num_str = pair.split("=") correct_mapping[alpha] = int(num_str) - # Make exactly half of them correct, half incorrect - total = len(correct_mapping) - half = total // 2 - new_pairs = [] - i = 0 - for alpha, num in correct_mapping.items(): - if i < half: - new_pairs.append(f"{alpha}={num}") # keep correct - else: - new_pairs.append(f"{alpha}={(num+1) % 10}") # make incorrect - i += 1 + # Swap two digit assignments to break arithmetic + letters = list(correct_mapping.keys()) + if len(letters) >= 2: + wrong_mapping = correct_mapping.copy() + wrong_mapping[letters[0]], wrong_mapping[letters[1]] = ( + wrong_mapping[letters[1]], + wrong_mapping[letters[0]], + ) - partial_answer_str = ",".join(new_pairs) - score = dataset.score_answer(answer=partial_answer_str, entry=puzzle) + wrong_answer_str = ",".join(f"{l}={wrong_mapping[l]}" for l in sorted(letters)) + score = dataset.score_answer(answer=wrong_answer_str, entry=puzzle) + assert score == 0.01, f"Expected 0.01 for invalid arithmetic, got {score}" - # The formula is (num_correct / total) * 0.7 + 0.3 - expected_score = (half / total) * 0.7 + 0.3 - assert abs(score - expected_score) < 1e-9, f"Partial correctness: expected {expected_score}, got {score}" + # 7) None or non-string answer => 0.0 + score = dataset.score_answer(answer=None, entry=puzzle) + assert score == 0.0, f"Expected 0.0 for None answer, got {score}" + + +def test_cryptarithm_verify_solution(): + """Test the verify_cryptarithm_solution helper function.""" + + # Test case 1: Valid solution with simple arithmetic + mapping = {"A": 1, "B": 2} + words = ["A", "B"] # 1 + 2 + result = "B" # 2 (wait, that's wrong - 1+2=3, not 2) + # Let me fix: 1 + 1 = 2 + mapping = {"A": 1, "B": 2} + words = ["A", "A"] # 1 + 1 + result = "B" # 2 + is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True) + assert is_valid, f"Valid solution marked invalid: {reason}" + + # Test case 2: Valid solution with multi-digit numbers + mapping = {"A": 1, "B": 2, "C": 3, "D": 5} + words = ["AB", "CD"] # 12 + 35 + result = "DC" # 53 (wait, 12+35=47, not 53) + # Fix: need 12 + 35 = 47 + mapping = {"A": 1, "B": 2, "C": 3, "D": 4, "E": 7} + words = ["AB", "CD"] # 12 + 34 + result = "DE" # 47 (wait, 12+34=46, not 47) + # Let me be more careful: 12 + 35 = 47 + mapping = {"A": 1, "B": 2, "C": 3, "D": 4, "E": 5, "F": 7} + words = ["AB", "CE"] # 12 + 35 + result = "DF" # 47 + is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True) + assert is_valid, f"Valid solution marked invalid: {reason}" + + # Test case 3: Wrong arithmetic + mapping = {"A": 1, "B": 2, "C": 3} + words = ["AB"] # 12 + result = "AC" # 13 (wrong!) + is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True) + assert not is_valid, "Invalid arithmetic not detected" + assert "Arithmetic equation not satisfied" in reason + + # Test case 4: Leading zero violation + mapping = {"A": 0, "B": 1} + words = ["AB"] # 01 + result = "AB" # 01 + is_valid, reason = verify_cryptarithm_solution(mapping, words, result, False) + assert not is_valid, "Leading zero violation not detected" + assert "cannot map to 0" in reason + + # Test case 5: Leading zero allowed + mapping = {"A": 0, "B": 1} + words = ["AB"] # 01 + result = "AB" # 01 + is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True) + assert is_valid, f"Leading zero incorrectly rejected when allowed: {reason}" + + # Test case 6: Duplicate digit assignments + mapping = {"A": 1, "B": 1, "C": 2} # A and B both map to 1 + words = ["AB"] # Both A and B are in puzzle + result = "C" # C is also in puzzle + is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True) + assert not is_valid, "Duplicate digits not detected" + assert "Duplicate digit" in reason + + # Test case 7: Missing letter mapping + mapping = {"A": 1} # Missing B + words = ["AB"] + result = "AB" + is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True) + assert not is_valid, "Missing letter not detected" + assert "Missing mapping" in reason + + # Test case 8: Extra letter in mapping + mapping = {"A": 1, "B": 2, "C": 3} # C is not in puzzle + words = ["AB"] # 12 + result = "AB" # 12 + is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True) + assert not is_valid, "Extra letter not detected" + assert "Extra letter" in reason + + # Test case 9: Invalid digit (out of range) + mapping = {"A": 10, "B": 2} # 10 is invalid + words = ["AB"] + result = "AB" + is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True) + assert not is_valid, "Invalid digit not detected" + assert "Invalid digit" in reason + + # Test case 10: Real cryptarithm example + # SEND + MORE = MONEY + # S=9, E=5, N=6, D=7, M=1, O=0, R=8, Y=2 + # 9567 + 1085 = 10652 + mapping = {"S": 9, "E": 5, "N": 6, "D": 7, "M": 1, "O": 0, "R": 8, "Y": 2} + words = ["SEND", "MORE"] + result = "MONEY" + is_valid, reason = verify_cryptarithm_solution(mapping, words, result, False) + assert is_valid, f"Classic SEND+MORE=MONEY not validated: {reason}" def test_cryptarithm_curriculum():