diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py
index f3aa4c09..bc930848 100644
--- a/reasoning_gym/algorithmic/cryptarithm.py
+++ b/reasoning_gym/algorithmic/cryptarithm.py
@@ -262,7 +262,7 @@ class CryptarithmDataset(ProceduralDataset):
 
         # case 3 : partial score for the number of correct mapping answer
         total_correct, total = 0, 0
-        for alphabet, number in correct_mapping:
+        for alphabet, number in correct_mapping.items():
             total += 1
             if alphabet in predict_mapping:
                 if predict_mapping[alphabet] == number:
diff --git a/tests/test_cryptarithm.py b/tests/test_cryptarithm.py
index 0ae3ea7f..e0ae9c85 100644
--- a/tests/test_cryptarithm.py
+++ b/tests/test_cryptarithm.py
@@ -103,3 +103,73 @@ def test_max_letters_constraint():
 
     # Check total unique letters doesn't exceed 10 (digits 0-9)
     assert len(letter_to_digit) <= 10, "Total unique letters should not exceed 10"
+
+def test_cryptarithm_score_answer():
+    """Test the CryptarithmDataset.score_answer method for various correctness levels."""
+    dataset = create_dataset("cryptarithm", seed=42, size=1)
+    puzzle = dataset[0]
+    correct_answer_str = puzzle["answer"] # e.g. "A=1,B=7,..."
+
+    # 1) Missing 'ANSWER:' => score should be 0.0
+    score = dataset.score_answer(answer=None, answer_str=correct_answer_str)
+    assert score == 0.0, f"Expected 0.0 when missing 'ANSWER:' prefix, got {score}"
+
+    # 2) Correct mapping => expecting 1.0
+    user_answer = f"ANSWER:\n{correct_answer_str}"
+    score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
+    assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}"
+
+    # 3) Mismatch number of pairs => score should be 0.1
+    # For instance, drop the last pair
+    splitted = correct_answer_str.split(',')
+    mismatch_str = ','.join(splitted[:-1])
+    user_answer = f"ANSWER:\n{mismatch_str}"
+    score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
+    assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}"
+
+    # 4) Parse error => 0.15 (e.g. remove '=' from the first pair)
+    splitted = correct_answer_str.split(',')
+    splitted[0] = splitted[0].replace('=', '') # remove '=' in the first pair
+    parse_error_str = ','.join(splitted)
+    user_answer = f"ANSWER:\n{parse_error_str}"
+    score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
+    assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}"
+
+    # 5) Correct number of pairs, but duplicate alphabets => 0.3
+    # This makes the dictionary have fewer unique keys than expected
+    splitted = correct_answer_str.split(',')
+    if len(splitted) > 1:
+        splitted[0] = splitted[1] # Duplicate the second pair in the first position
+    duplicates_str = ','.join(splitted)
+    user_answer = f"ANSWER:\n{duplicates_str}"
+    score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
+    assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}"
+
+    # 6) Partial correctness => some correct, some incorrect
+    splitted = correct_answer_str.split(',')
+    correct_mapping = {}
+    for pair in splitted:
+        alpha, num_str = pair.split('=')
+        correct_mapping[alpha] = int(num_str)
+
+    # Make exactly half of them correct, half incorrect
+    total = len(correct_mapping)
+    half = total // 2
+    new_pairs = []
+    i = 0
+    for alpha, num in correct_mapping.items():
+        if i < half:
+            new_pairs.append(f"{alpha}={num}") # keep correct
+        else:
+            new_pairs.append(f"{alpha}={(num+1) % 10}") # make incorrect
+        i += 1
+
+    partial_answer_str = ','.join(new_pairs)
+    user_answer = f"ANSWER:\n{partial_answer_str}"
+    score = dataset.score_answer(answer=user_answer, answer_str=correct_answer_str)
+
+    # The formula is (num_correct / total) * 0.7 + 0.3
+    expected_score = (half / total) * 0.7 + 0.3
+    assert abs(score - expected_score) < 1e-9, (
+        f"Partial correctness: expected {expected_score}, got {score}"
+    )
\ No newline at end of file