[fix] issue #516 of cryptarithm validation issue

This commit is contained in:
theblackcat102 2026-03-06 00:43:56 +08:00
parent 5dcca08309
commit 467eb4da82
2 changed files with 249 additions and 64 deletions

View file

@ -21,6 +21,89 @@ from ..factory import ProceduralDataset, register_dataset
DATASET_NAME = "cryptarithm"
def verify_cryptarithm_solution(
mapping: dict[str, int],
words_letters: list[str],
result_letters: str,
allow_leading_zero: bool,
) -> tuple[bool, str]:
"""Validate if a letter-to-digit mapping satisfies the cryptarithm puzzle constraints.
Args:
mapping: Dictionary mapping letters to digits (e.g., {'A': 1, 'B': 2})
words_letters: List of word strings using letters (e.g., ['ABC', 'DE'])
result_letters: Result string using letters (e.g., 'FGH')
allow_leading_zero: Whether leading zeros are allowed
Returns:
(is_valid, failure_reason) tuple:
- is_valid: True if mapping satisfies all constraints
- failure_reason: String describing why validation failed (empty if valid)
"""
# Collect all letters used in the puzzle
all_puzzle_letters = set()
for word in words_letters:
all_puzzle_letters.update(word)
all_puzzle_letters.update(result_letters)
# Check 1: All letters must be mapped
mapped_letters = set(mapping.keys())
if mapped_letters != all_puzzle_letters:
missing = all_puzzle_letters - mapped_letters
extra = mapped_letters - all_puzzle_letters
if missing:
return False, f"Missing mapping for letter(s): {sorted(missing)}"
if extra:
return False, f"Extra letter(s) in mapping: {sorted(extra)}"
# Check 2: All digits must be valid (0-9)
for letter, digit in mapping.items():
if not isinstance(digit, int) or digit < 0 or digit > 9:
return False, f"Invalid digit for letter {letter}: {digit}"
# Check 3: Uniqueness constraint - each digit can only be assigned to one letter
digit_values = list(mapping.values())
if len(set(digit_values)) != len(digit_values):
return False, "Duplicate digit assignments detected"
# Check 4: Leading zero constraint (if not allowed)
if not allow_leading_zero:
# Check leading letters of all words
for word in words_letters:
if word: # non-empty word
leading_letter = word[0]
if mapping.get(leading_letter) == 0:
return False, f"Leading letter '{leading_letter}' cannot map to 0"
# Check leading letter of result
if result_letters:
leading_letter = result_letters[0]
if mapping.get(leading_letter) == 0:
return False, f"Leading letter '{leading_letter}' in result cannot map to 0"
# Check 5: Arithmetic constraint - the sum must be correct
try:
# Convert each word from letters to numbers
word_numbers = []
for word in words_letters:
number_str = "".join(str(mapping[letter]) for letter in word)
word_numbers.append(int(number_str))
# Convert result from letters to number
result_number_str = "".join(str(mapping[letter]) for letter in result_letters)
result_number = int(result_number_str)
# Check if sum is correct
computed_sum = sum(word_numbers)
if computed_sum != result_number:
return False, f"Arithmetic equation not satisfied: {word_numbers} sums to {computed_sum}, expected {result_number}"
except (KeyError, ValueError) as e:
return False, f"Error applying mapping: {e}"
# All checks passed
return True, ""
@dataclass
class CryptarithmConfig:
"""Configuration for Cryptarithm dataset generation."""
@ -173,6 +256,7 @@ class CryptarithmDataset(ProceduralDataset):
)
# 8) Create a human-readable answer, e.g. "A=1,B=0,C=9,..."
# Note: This is ONE valid solution. Other solutions may exist and are equally valid.
sorted_letter_keys = sorted(letter_to_digit.keys())
answer_str = ",".join(f"{letter}={letter_to_digit[letter]}" for letter in sorted_letter_keys)
@ -183,6 +267,7 @@ class CryptarithmDataset(ProceduralDataset):
"metadata": {
"source_dataset": DATASET_NAME,
"source_index": idx,
"allow_leading_zero": self.config.allow_leading_zero,
"letters": list(letter_to_digit.keys()),
"word_values": words_numbers,
"sum_number": total_sum,
@ -199,50 +284,48 @@ class CryptarithmDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Determine if the solution provided solves the Cryptarithm task.
The function awards 1.0 for a correct format and answers for all alphabet pairs.
Validates that the provided letter-to-digit mapping satisfies all constraints:
1. All letters are mapped to unique digits (0-9)
2. Leading letters don't map to 0 (if allow_leading_zero=False)
3. The arithmetic equation is satisfied
The function awards 1.0 for any valid solution (not just the stored solution).
Multiple valid solutions may exist and are equally acceptable.
Args:
answer (Optional[str]): The user's answer already parsed by `extract_answer`
answer_str (dict[str, Any]): The original dataset answer_str containing the correct answer. ie "A=1,B=3..."
answer (Optional[str]): The user's answer in format "A=1,B=2,C=3"
entry (dict[str, Any]): The dataset entry containing puzzle metadata
Returns:
float: The computed score between 0.0 and 1.0.
float: 1.0 for valid solution, 0.01 for parseable but invalid, 0.0 for parse error
"""
if not isinstance(answer, str):
return 0.0
correct_mapping = {}
correct_answer_str = entry["answer"] # e.g. "A=1,B=7,..."
for pair in correct_answer_str.split(","):
alphabet, number = pair.split("=")
correct_mapping[alphabet] = int(number)
# Parse the answer into a letter-to-digit mapping
try:
predicted_mapping = {}
for pair in answer.split(","):
letter, digit_str = pair.strip().split("=")
letter = letter.strip()
predicted_mapping[letter] = int(digit_str.strip())
except (ValueError, AttributeError):
return 0.0 # Parse error
# case 1 : pairs are in a list format and the number of pairs matched up
if len(answer.split(",")) != len(correct_mapping):
return 0.1
# Extract puzzle constraints from metadata
words_letters = entry["metadata"]["words_letters"]
result_letters = entry["metadata"]["result_letters"]
allow_leading_zero = entry["metadata"].get("allow_leading_zero", False)
predict_mapping = {}
for pair in answer.split(","):
try:
alphabet, number = pair.strip().split("=")
# as the unique alphabet grows we may want this to scale linearly with the number alphabet
predict_mapping[alphabet] = int(number)
except ValueError:
return 0.15
# case 2 : all alphabet has correct format ALPHABET=NUMBER format
if len(predict_mapping) != len(correct_mapping):
return 0.3
# Validate the solution using the helper function
is_valid, failure_reason = verify_cryptarithm_solution(
predicted_mapping, words_letters, result_letters, allow_leading_zero
)
# case 3 : partial score for the number of correct mapping answer
total_correct, total = 0, 0
for alphabet, number in correct_mapping.items():
total += 1
if alphabet in predict_mapping:
if predict_mapping[alphabet] == number:
total_correct += 1
# note: linear relationship is probably not good?
return (total_correct / total) * 0.7 + 0.3
if is_valid:
return 1.0
else:
return 0.01 # Parseable but doesn't satisfy constraints
class CryptarithmCurriculum(BaseCurriculum):