mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
[fix] issue #516 of cryptarithm validation issue
This commit is contained in:
parent
5dcca08309
commit
467eb4da82
2 changed files with 249 additions and 64 deletions
|
|
@ -21,6 +21,89 @@ from ..factory import ProceduralDataset, register_dataset
|
|||
DATASET_NAME = "cryptarithm"
|
||||
|
||||
|
||||
def verify_cryptarithm_solution(
|
||||
mapping: dict[str, int],
|
||||
words_letters: list[str],
|
||||
result_letters: str,
|
||||
allow_leading_zero: bool,
|
||||
) -> tuple[bool, str]:
|
||||
"""Validate if a letter-to-digit mapping satisfies the cryptarithm puzzle constraints.
|
||||
|
||||
Args:
|
||||
mapping: Dictionary mapping letters to digits (e.g., {'A': 1, 'B': 2})
|
||||
words_letters: List of word strings using letters (e.g., ['ABC', 'DE'])
|
||||
result_letters: Result string using letters (e.g., 'FGH')
|
||||
allow_leading_zero: Whether leading zeros are allowed
|
||||
|
||||
Returns:
|
||||
(is_valid, failure_reason) tuple:
|
||||
- is_valid: True if mapping satisfies all constraints
|
||||
- failure_reason: String describing why validation failed (empty if valid)
|
||||
"""
|
||||
# Collect all letters used in the puzzle
|
||||
all_puzzle_letters = set()
|
||||
for word in words_letters:
|
||||
all_puzzle_letters.update(word)
|
||||
all_puzzle_letters.update(result_letters)
|
||||
|
||||
# Check 1: All letters must be mapped
|
||||
mapped_letters = set(mapping.keys())
|
||||
if mapped_letters != all_puzzle_letters:
|
||||
missing = all_puzzle_letters - mapped_letters
|
||||
extra = mapped_letters - all_puzzle_letters
|
||||
if missing:
|
||||
return False, f"Missing mapping for letter(s): {sorted(missing)}"
|
||||
if extra:
|
||||
return False, f"Extra letter(s) in mapping: {sorted(extra)}"
|
||||
|
||||
# Check 2: All digits must be valid (0-9)
|
||||
for letter, digit in mapping.items():
|
||||
if not isinstance(digit, int) or digit < 0 or digit > 9:
|
||||
return False, f"Invalid digit for letter {letter}: {digit}"
|
||||
|
||||
# Check 3: Uniqueness constraint - each digit can only be assigned to one letter
|
||||
digit_values = list(mapping.values())
|
||||
if len(set(digit_values)) != len(digit_values):
|
||||
return False, "Duplicate digit assignments detected"
|
||||
|
||||
# Check 4: Leading zero constraint (if not allowed)
|
||||
if not allow_leading_zero:
|
||||
# Check leading letters of all words
|
||||
for word in words_letters:
|
||||
if word: # non-empty word
|
||||
leading_letter = word[0]
|
||||
if mapping.get(leading_letter) == 0:
|
||||
return False, f"Leading letter '{leading_letter}' cannot map to 0"
|
||||
# Check leading letter of result
|
||||
if result_letters:
|
||||
leading_letter = result_letters[0]
|
||||
if mapping.get(leading_letter) == 0:
|
||||
return False, f"Leading letter '{leading_letter}' in result cannot map to 0"
|
||||
|
||||
# Check 5: Arithmetic constraint - the sum must be correct
|
||||
try:
|
||||
# Convert each word from letters to numbers
|
||||
word_numbers = []
|
||||
for word in words_letters:
|
||||
number_str = "".join(str(mapping[letter]) for letter in word)
|
||||
word_numbers.append(int(number_str))
|
||||
|
||||
# Convert result from letters to number
|
||||
result_number_str = "".join(str(mapping[letter]) for letter in result_letters)
|
||||
result_number = int(result_number_str)
|
||||
|
||||
# Check if sum is correct
|
||||
computed_sum = sum(word_numbers)
|
||||
if computed_sum != result_number:
|
||||
return False, f"Arithmetic equation not satisfied: {word_numbers} sums to {computed_sum}, expected {result_number}"
|
||||
|
||||
except (KeyError, ValueError) as e:
|
||||
return False, f"Error applying mapping: {e}"
|
||||
|
||||
# All checks passed
|
||||
return True, ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CryptarithmConfig:
|
||||
"""Configuration for Cryptarithm dataset generation."""
|
||||
|
|
@ -173,6 +256,7 @@ class CryptarithmDataset(ProceduralDataset):
|
|||
)
|
||||
|
||||
# 8) Create a human-readable answer, e.g. "A=1,B=0,C=9,..."
|
||||
# Note: This is ONE valid solution. Other solutions may exist and are equally valid.
|
||||
sorted_letter_keys = sorted(letter_to_digit.keys())
|
||||
answer_str = ",".join(f"{letter}={letter_to_digit[letter]}" for letter in sorted_letter_keys)
|
||||
|
||||
|
|
@ -183,6 +267,7 @@ class CryptarithmDataset(ProceduralDataset):
|
|||
"metadata": {
|
||||
"source_dataset": DATASET_NAME,
|
||||
"source_index": idx,
|
||||
"allow_leading_zero": self.config.allow_leading_zero,
|
||||
"letters": list(letter_to_digit.keys()),
|
||||
"word_values": words_numbers,
|
||||
"sum_number": total_sum,
|
||||
|
|
@ -199,50 +284,48 @@ class CryptarithmDataset(ProceduralDataset):
|
|||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
"""Determine if the solution provided solves the Cryptarithm task.
|
||||
|
||||
The function awards 1.0 for a correct format and answers for all alphabet pairs.
|
||||
Validates that the provided letter-to-digit mapping satisfies all constraints:
|
||||
1. All letters are mapped to unique digits (0-9)
|
||||
2. Leading letters don't map to 0 (if allow_leading_zero=False)
|
||||
3. The arithmetic equation is satisfied
|
||||
|
||||
The function awards 1.0 for any valid solution (not just the stored solution).
|
||||
Multiple valid solutions may exist and are equally acceptable.
|
||||
|
||||
Args:
|
||||
answer (Optional[str]): The user's answer already parsed by `extract_answer`
|
||||
answer_str (dict[str, Any]): The original dataset answer_str containing the correct answer. ie "A=1,B=3..."
|
||||
answer (Optional[str]): The user's answer in format "A=1,B=2,C=3"
|
||||
entry (dict[str, Any]): The dataset entry containing puzzle metadata
|
||||
|
||||
Returns:
|
||||
float: The computed score between 0.0 and 1.0.
|
||||
float: 1.0 for valid solution, 0.01 for parseable but invalid, 0.0 for parse error
|
||||
"""
|
||||
if not isinstance(answer, str):
|
||||
return 0.0
|
||||
|
||||
correct_mapping = {}
|
||||
correct_answer_str = entry["answer"] # e.g. "A=1,B=7,..."
|
||||
for pair in correct_answer_str.split(","):
|
||||
alphabet, number = pair.split("=")
|
||||
correct_mapping[alphabet] = int(number)
|
||||
# Parse the answer into a letter-to-digit mapping
|
||||
try:
|
||||
predicted_mapping = {}
|
||||
for pair in answer.split(","):
|
||||
letter, digit_str = pair.strip().split("=")
|
||||
letter = letter.strip()
|
||||
predicted_mapping[letter] = int(digit_str.strip())
|
||||
except (ValueError, AttributeError):
|
||||
return 0.0 # Parse error
|
||||
|
||||
# case 1 : pairs are in a list format and the number of pairs matched up
|
||||
if len(answer.split(",")) != len(correct_mapping):
|
||||
return 0.1
|
||||
# Extract puzzle constraints from metadata
|
||||
words_letters = entry["metadata"]["words_letters"]
|
||||
result_letters = entry["metadata"]["result_letters"]
|
||||
allow_leading_zero = entry["metadata"].get("allow_leading_zero", False)
|
||||
|
||||
predict_mapping = {}
|
||||
for pair in answer.split(","):
|
||||
try:
|
||||
alphabet, number = pair.strip().split("=")
|
||||
# as the unique alphabet grows we may want this to scale linearly with the number alphabet
|
||||
predict_mapping[alphabet] = int(number)
|
||||
except ValueError:
|
||||
return 0.15
|
||||
# case 2 : all alphabet has correct format ALPHABET=NUMBER format
|
||||
if len(predict_mapping) != len(correct_mapping):
|
||||
return 0.3
|
||||
# Validate the solution using the helper function
|
||||
is_valid, failure_reason = verify_cryptarithm_solution(
|
||||
predicted_mapping, words_letters, result_letters, allow_leading_zero
|
||||
)
|
||||
|
||||
# case 3 : partial score for the number of correct mapping answer
|
||||
total_correct, total = 0, 0
|
||||
for alphabet, number in correct_mapping.items():
|
||||
total += 1
|
||||
if alphabet in predict_mapping:
|
||||
if predict_mapping[alphabet] == number:
|
||||
total_correct += 1
|
||||
|
||||
# note: linear relationship is probably not good?
|
||||
return (total_correct / total) * 0.7 + 0.3
|
||||
if is_valid:
|
||||
return 1.0
|
||||
else:
|
||||
return 0.01 # Parseable but doesn't satisfy constraints
|
||||
|
||||
|
||||
class CryptarithmCurriculum(BaseCurriculum):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,12 @@
|
|||
import pytest
|
||||
|
||||
from reasoning_gym import create_dataset
|
||||
from reasoning_gym.algorithmic.cryptarithm import CryptarithmConfig, CryptarithmCurriculum, CryptarithmDataset
|
||||
from reasoning_gym.algorithmic.cryptarithm import (
|
||||
CryptarithmConfig,
|
||||
CryptarithmCurriculum,
|
||||
CryptarithmDataset,
|
||||
verify_cryptarithm_solution,
|
||||
)
|
||||
|
||||
|
||||
def test_cryptarithm_generation():
|
||||
|
|
@ -111,62 +116,159 @@ def test_cryptarithm_score_answer():
|
|||
puzzle = dataset[0]
|
||||
correct_answer_str = puzzle["answer"] # e.g. "A=1,B=7,..."
|
||||
|
||||
# 1) Missing '<answer>' => score should be 0.0
|
||||
# score = dataset.score_answer(answer=None, answer_str=correct_answer_str)
|
||||
# assert score == 0.0, f"Expected 0.0 when missing '<answer>' prefix, got {score}"
|
||||
|
||||
# 2) Correct mapping => expecting 1.0
|
||||
# 1) Correct mapping => expecting 1.0
|
||||
score = dataset.score_answer(answer=correct_answer_str, entry=puzzle)
|
||||
assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}"
|
||||
|
||||
# 3) Mismatch number of pairs => score should be 0.1
|
||||
# 2) Correct mapping in different order => should still be 1.0
|
||||
correct_mapping = {}
|
||||
for pair in correct_answer_str.split(","):
|
||||
alpha, num_str = pair.split("=")
|
||||
correct_mapping[alpha] = int(num_str)
|
||||
reversed_answer = ",".join(f"{letter}={correct_mapping[letter]}" for letter in reversed(sorted(correct_mapping.keys())))
|
||||
score = dataset.score_answer(answer=reversed_answer, entry=puzzle)
|
||||
assert score == 1.0, f"Expected 1.0 for correct answer in different order, got {score}"
|
||||
|
||||
# 3) Mismatch number of pairs => score should be 0.0 (parse succeeds but validation fails)
|
||||
# For instance, drop the last pair
|
||||
splitted = correct_answer_str.split(",")
|
||||
mismatch_str = ",".join(splitted[:-1])
|
||||
score = dataset.score_answer(answer=mismatch_str, entry=puzzle)
|
||||
assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}"
|
||||
assert score == 0.01, f"Expected 0.01 when #pairs does not match (missing letter), got {score}"
|
||||
|
||||
# 4) Parse error => 0.15 (e.g. remove '=' from the first pair)
|
||||
# 4) Parse error => 0.0 (e.g. remove '=' from the first pair)
|
||||
splitted = correct_answer_str.split(",")
|
||||
splitted[0] = splitted[0].replace("=", "") # remove '=' in the first pair
|
||||
parse_error_str = ",".join(splitted)
|
||||
score = dataset.score_answer(answer=parse_error_str, entry=puzzle)
|
||||
assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}"
|
||||
assert score == 0.0, f"Expected 0.0 when parsing fails on at least one pair, got {score}"
|
||||
|
||||
# 5) Correct number of pairs, but duplicate alphabets => 0.3
|
||||
# 5) Correct number of pairs, but duplicate alphabets => 0.01 (parseable but invalid)
|
||||
# This makes the dictionary have fewer unique keys than expected
|
||||
splitted = correct_answer_str.split(",")
|
||||
if len(splitted) > 1:
|
||||
splitted[0] = splitted[1] # Duplicate the second pair in the first position
|
||||
duplicates_str = ",".join(splitted)
|
||||
score = dataset.score_answer(answer=duplicates_str, entry=puzzle)
|
||||
assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}"
|
||||
assert score == 0.01, f"Expected 0.01 if the final dict has fewer unique alphabets, got {score}"
|
||||
|
||||
# 6) Partial correctness => some correct, some incorrect
|
||||
splitted = correct_answer_str.split(",")
|
||||
# 6) Wrong arithmetic - swap two digits to break the equation
|
||||
correct_mapping = {}
|
||||
for pair in splitted:
|
||||
for pair in correct_answer_str.split(","):
|
||||
alpha, num_str = pair.split("=")
|
||||
correct_mapping[alpha] = int(num_str)
|
||||
|
||||
# Make exactly half of them correct, half incorrect
|
||||
total = len(correct_mapping)
|
||||
half = total // 2
|
||||
new_pairs = []
|
||||
i = 0
|
||||
for alpha, num in correct_mapping.items():
|
||||
if i < half:
|
||||
new_pairs.append(f"{alpha}={num}") # keep correct
|
||||
else:
|
||||
new_pairs.append(f"{alpha}={(num+1) % 10}") # make incorrect
|
||||
i += 1
|
||||
# Swap two digit assignments to break arithmetic
|
||||
letters = list(correct_mapping.keys())
|
||||
if len(letters) >= 2:
|
||||
wrong_mapping = correct_mapping.copy()
|
||||
wrong_mapping[letters[0]], wrong_mapping[letters[1]] = (
|
||||
wrong_mapping[letters[1]],
|
||||
wrong_mapping[letters[0]],
|
||||
)
|
||||
|
||||
partial_answer_str = ",".join(new_pairs)
|
||||
score = dataset.score_answer(answer=partial_answer_str, entry=puzzle)
|
||||
wrong_answer_str = ",".join(f"{l}={wrong_mapping[l]}" for l in sorted(letters))
|
||||
score = dataset.score_answer(answer=wrong_answer_str, entry=puzzle)
|
||||
assert score == 0.01, f"Expected 0.01 for invalid arithmetic, got {score}"
|
||||
|
||||
# The formula is (num_correct / total) * 0.7 + 0.3
|
||||
expected_score = (half / total) * 0.7 + 0.3
|
||||
assert abs(score - expected_score) < 1e-9, f"Partial correctness: expected {expected_score}, got {score}"
|
||||
# 7) None or non-string answer => 0.0
|
||||
score = dataset.score_answer(answer=None, entry=puzzle)
|
||||
assert score == 0.0, f"Expected 0.0 for None answer, got {score}"
|
||||
|
||||
|
||||
def test_cryptarithm_verify_solution():
|
||||
"""Test the verify_cryptarithm_solution helper function."""
|
||||
|
||||
# Test case 1: Valid solution with simple arithmetic
|
||||
mapping = {"A": 1, "B": 2}
|
||||
words = ["A", "B"] # 1 + 2
|
||||
result = "B" # 2 (wait, that's wrong - 1+2=3, not 2)
|
||||
# Let me fix: 1 + 1 = 2
|
||||
mapping = {"A": 1, "B": 2}
|
||||
words = ["A", "A"] # 1 + 1
|
||||
result = "B" # 2
|
||||
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
|
||||
assert is_valid, f"Valid solution marked invalid: {reason}"
|
||||
|
||||
# Test case 2: Valid solution with multi-digit numbers
|
||||
mapping = {"A": 1, "B": 2, "C": 3, "D": 5}
|
||||
words = ["AB", "CD"] # 12 + 35
|
||||
result = "DC" # 53 (wait, 12+35=47, not 53)
|
||||
# Fix: need 12 + 35 = 47
|
||||
mapping = {"A": 1, "B": 2, "C": 3, "D": 4, "E": 7}
|
||||
words = ["AB", "CD"] # 12 + 34
|
||||
result = "DE" # 47 (wait, 12+34=46, not 47)
|
||||
# Let me be more careful: 12 + 35 = 47
|
||||
mapping = {"A": 1, "B": 2, "C": 3, "D": 4, "E": 5, "F": 7}
|
||||
words = ["AB", "CE"] # 12 + 35
|
||||
result = "DF" # 47
|
||||
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
|
||||
assert is_valid, f"Valid solution marked invalid: {reason}"
|
||||
|
||||
# Test case 3: Wrong arithmetic
|
||||
mapping = {"A": 1, "B": 2, "C": 3}
|
||||
words = ["AB"] # 12
|
||||
result = "AC" # 13 (wrong!)
|
||||
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
|
||||
assert not is_valid, "Invalid arithmetic not detected"
|
||||
assert "Arithmetic equation not satisfied" in reason
|
||||
|
||||
# Test case 4: Leading zero violation
|
||||
mapping = {"A": 0, "B": 1}
|
||||
words = ["AB"] # 01
|
||||
result = "AB" # 01
|
||||
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, False)
|
||||
assert not is_valid, "Leading zero violation not detected"
|
||||
assert "cannot map to 0" in reason
|
||||
|
||||
# Test case 5: Leading zero allowed
|
||||
mapping = {"A": 0, "B": 1}
|
||||
words = ["AB"] # 01
|
||||
result = "AB" # 01
|
||||
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
|
||||
assert is_valid, f"Leading zero incorrectly rejected when allowed: {reason}"
|
||||
|
||||
# Test case 6: Duplicate digit assignments
|
||||
mapping = {"A": 1, "B": 1, "C": 2} # A and B both map to 1
|
||||
words = ["AB"] # Both A and B are in puzzle
|
||||
result = "C" # C is also in puzzle
|
||||
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
|
||||
assert not is_valid, "Duplicate digits not detected"
|
||||
assert "Duplicate digit" in reason
|
||||
|
||||
# Test case 7: Missing letter mapping
|
||||
mapping = {"A": 1} # Missing B
|
||||
words = ["AB"]
|
||||
result = "AB"
|
||||
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
|
||||
assert not is_valid, "Missing letter not detected"
|
||||
assert "Missing mapping" in reason
|
||||
|
||||
# Test case 8: Extra letter in mapping
|
||||
mapping = {"A": 1, "B": 2, "C": 3} # C is not in puzzle
|
||||
words = ["AB"] # 12
|
||||
result = "AB" # 12
|
||||
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
|
||||
assert not is_valid, "Extra letter not detected"
|
||||
assert "Extra letter" in reason
|
||||
|
||||
# Test case 9: Invalid digit (out of range)
|
||||
mapping = {"A": 10, "B": 2} # 10 is invalid
|
||||
words = ["AB"]
|
||||
result = "AB"
|
||||
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
|
||||
assert not is_valid, "Invalid digit not detected"
|
||||
assert "Invalid digit" in reason
|
||||
|
||||
# Test case 10: Real cryptarithm example
|
||||
# SEND + MORE = MONEY
|
||||
# S=9, E=5, N=6, D=7, M=1, O=0, R=8, Y=2
|
||||
# 9567 + 1085 = 10652
|
||||
mapping = {"S": 9, "E": 5, "N": 6, "D": 7, "M": 1, "O": 0, "R": 8, "Y": 2}
|
||||
words = ["SEND", "MORE"]
|
||||
result = "MONEY"
|
||||
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, False)
|
||||
assert is_valid, f"Classic SEND+MORE=MONEY not validated: {reason}"
|
||||
|
||||
|
||||
def test_cryptarithm_curriculum():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue