[fix] issue #516 of cryptarithm validation issue

This commit is contained in:
theblackcat102 2026-03-06 00:43:56 +08:00
parent 5dcca08309
commit 467eb4da82
2 changed files with 249 additions and 64 deletions

View file

@ -1,7 +1,12 @@
import pytest
from reasoning_gym import create_dataset
from reasoning_gym.algorithmic.cryptarithm import CryptarithmConfig, CryptarithmCurriculum, CryptarithmDataset
from reasoning_gym.algorithmic.cryptarithm import (
CryptarithmConfig,
CryptarithmCurriculum,
CryptarithmDataset,
verify_cryptarithm_solution,
)
def test_cryptarithm_generation():
@ -111,62 +116,159 @@ def test_cryptarithm_score_answer():
puzzle = dataset[0]
correct_answer_str = puzzle["answer"] # e.g. "A=1,B=7,..."
# 1) Missing '<answer>' => score should be 0.0
# score = dataset.score_answer(answer=None, answer_str=correct_answer_str)
# assert score == 0.0, f"Expected 0.0 when missing '<answer>' prefix, got {score}"
# 2) Correct mapping => expecting 1.0
# 1) Correct mapping => expecting 1.0
score = dataset.score_answer(answer=correct_answer_str, entry=puzzle)
assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}"
# 3) Mismatch number of pairs => score should be 0.1
# 2) Correct mapping in different order => should still be 1.0
correct_mapping = {}
for pair in correct_answer_str.split(","):
alpha, num_str = pair.split("=")
correct_mapping[alpha] = int(num_str)
reversed_answer = ",".join(f"{letter}={correct_mapping[letter]}" for letter in reversed(sorted(correct_mapping.keys())))
score = dataset.score_answer(answer=reversed_answer, entry=puzzle)
assert score == 1.0, f"Expected 1.0 for correct answer in different order, got {score}"
# 3) Mismatch number of pairs => score should be 0.0 (parse succeeds but validation fails)
# For instance, drop the last pair
splitted = correct_answer_str.split(",")
mismatch_str = ",".join(splitted[:-1])
score = dataset.score_answer(answer=mismatch_str, entry=puzzle)
assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}"
assert score == 0.01, f"Expected 0.01 when #pairs does not match (missing letter), got {score}"
# 4) Parse error => 0.15 (e.g. remove '=' from the first pair)
# 4) Parse error => 0.0 (e.g. remove '=' from the first pair)
splitted = correct_answer_str.split(",")
splitted[0] = splitted[0].replace("=", "") # remove '=' in the first pair
parse_error_str = ",".join(splitted)
score = dataset.score_answer(answer=parse_error_str, entry=puzzle)
assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}"
assert score == 0.0, f"Expected 0.0 when parsing fails on at least one pair, got {score}"
# 5) Correct number of pairs, but duplicate alphabets => 0.3
# 5) Correct number of pairs, but duplicate alphabets => 0.01 (parseable but invalid)
# This makes the dictionary have fewer unique keys than expected
splitted = correct_answer_str.split(",")
if len(splitted) > 1:
splitted[0] = splitted[1] # Duplicate the second pair in the first position
duplicates_str = ",".join(splitted)
score = dataset.score_answer(answer=duplicates_str, entry=puzzle)
assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}"
assert score == 0.01, f"Expected 0.01 if the final dict has fewer unique alphabets, got {score}"
# 6) Partial correctness => some correct, some incorrect
splitted = correct_answer_str.split(",")
# 6) Wrong arithmetic - swap two digits to break the equation
correct_mapping = {}
for pair in splitted:
for pair in correct_answer_str.split(","):
alpha, num_str = pair.split("=")
correct_mapping[alpha] = int(num_str)
# Make exactly half of them correct, half incorrect
total = len(correct_mapping)
half = total // 2
new_pairs = []
i = 0
for alpha, num in correct_mapping.items():
if i < half:
new_pairs.append(f"{alpha}={num}") # keep correct
else:
new_pairs.append(f"{alpha}={(num+1) % 10}") # make incorrect
i += 1
# Swap two digit assignments to break arithmetic
letters = list(correct_mapping.keys())
if len(letters) >= 2:
wrong_mapping = correct_mapping.copy()
wrong_mapping[letters[0]], wrong_mapping[letters[1]] = (
wrong_mapping[letters[1]],
wrong_mapping[letters[0]],
)
partial_answer_str = ",".join(new_pairs)
score = dataset.score_answer(answer=partial_answer_str, entry=puzzle)
wrong_answer_str = ",".join(f"{l}={wrong_mapping[l]}" for l in sorted(letters))
score = dataset.score_answer(answer=wrong_answer_str, entry=puzzle)
assert score == 0.01, f"Expected 0.01 for invalid arithmetic, got {score}"
# The formula is (num_correct / total) * 0.7 + 0.3
expected_score = (half / total) * 0.7 + 0.3
assert abs(score - expected_score) < 1e-9, f"Partial correctness: expected {expected_score}, got {score}"
# 7) None or non-string answer => 0.0
score = dataset.score_answer(answer=None, entry=puzzle)
assert score == 0.0, f"Expected 0.0 for None answer, got {score}"
def test_cryptarithm_verify_solution():
"""Test the verify_cryptarithm_solution helper function."""
# Test case 1: Valid solution with simple arithmetic
mapping = {"A": 1, "B": 2}
words = ["A", "B"] # 1 + 2
result = "B" # 2 (wait, that's wrong - 1+2=3, not 2)
# Let me fix: 1 + 1 = 2
mapping = {"A": 1, "B": 2}
words = ["A", "A"] # 1 + 1
result = "B" # 2
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
assert is_valid, f"Valid solution marked invalid: {reason}"
# Test case 2: Valid solution with multi-digit numbers
mapping = {"A": 1, "B": 2, "C": 3, "D": 5}
words = ["AB", "CD"] # 12 + 35
result = "DC" # 53 (wait, 12+35=47, not 53)
# Fix: need 12 + 35 = 47
mapping = {"A": 1, "B": 2, "C": 3, "D": 4, "E": 7}
words = ["AB", "CD"] # 12 + 34
result = "DE" # 47 (wait, 12+34=46, not 47)
# Let me be more careful: 12 + 35 = 47
mapping = {"A": 1, "B": 2, "C": 3, "D": 4, "E": 5, "F": 7}
words = ["AB", "CE"] # 12 + 35
result = "DF" # 47
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
assert is_valid, f"Valid solution marked invalid: {reason}"
# Test case 3: Wrong arithmetic
mapping = {"A": 1, "B": 2, "C": 3}
words = ["AB"] # 12
result = "AC" # 13 (wrong!)
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
assert not is_valid, "Invalid arithmetic not detected"
assert "Arithmetic equation not satisfied" in reason
# Test case 4: Leading zero violation
mapping = {"A": 0, "B": 1}
words = ["AB"] # 01
result = "AB" # 01
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, False)
assert not is_valid, "Leading zero violation not detected"
assert "cannot map to 0" in reason
# Test case 5: Leading zero allowed
mapping = {"A": 0, "B": 1}
words = ["AB"] # 01
result = "AB" # 01
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
assert is_valid, f"Leading zero incorrectly rejected when allowed: {reason}"
# Test case 6: Duplicate digit assignments
mapping = {"A": 1, "B": 1, "C": 2} # A and B both map to 1
words = ["AB"] # Both A and B are in puzzle
result = "C" # C is also in puzzle
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
assert not is_valid, "Duplicate digits not detected"
assert "Duplicate digit" in reason
# Test case 7: Missing letter mapping
mapping = {"A": 1} # Missing B
words = ["AB"]
result = "AB"
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
assert not is_valid, "Missing letter not detected"
assert "Missing mapping" in reason
# Test case 8: Extra letter in mapping
mapping = {"A": 1, "B": 2, "C": 3} # C is not in puzzle
words = ["AB"] # 12
result = "AB" # 12
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
assert not is_valid, "Extra letter not detected"
assert "Extra letter" in reason
# Test case 9: Invalid digit (out of range)
mapping = {"A": 10, "B": 2} # 10 is invalid
words = ["AB"]
result = "AB"
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
assert not is_valid, "Invalid digit not detected"
assert "Invalid digit" in reason
# Test case 10: Real cryptarithm example
# SEND + MORE = MONEY
# S=9, E=5, N=6, D=7, M=1, O=0, R=8, Y=2
# 9567 + 1085 = 10652
mapping = {"S": 9, "E": 5, "N": 6, "D": 7, "M": 1, "O": 0, "R": 8, "Y": 2}
words = ["SEND", "MORE"]
result = "MONEY"
is_valid, reason = verify_cryptarithm_solution(mapping, words, result, False)
assert is_valid, f"Classic SEND+MORE=MONEY not validated: {reason}"
def test_cryptarithm_curriculum():