mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
313 lines
12 KiB
Python
313 lines
12 KiB
Python
import pytest
|
|
|
|
from reasoning_gym import create_dataset
|
|
from reasoning_gym.algorithmic.cryptarithm import (
|
|
CryptarithmConfig,
|
|
CryptarithmCurriculum,
|
|
CryptarithmDataset,
|
|
verify_cryptarithm_solution,
|
|
)
|
|
|
|
|
|
def test_cryptarithm_generation():
    """Check that generated items have the expected keys, a valid digit mapping and a consistent sum."""
    dataset = create_dataset("cryptarithm", seed=42, size=10)
    assert isinstance(dataset, CryptarithmDataset)

    seen_sums = set()
    for entry in dataset:
        # Top-level fields every item must carry.
        for field in ("question", "answer", "metadata"):
            assert field in entry

        # The prompt must contain the fixed instruction sentences.
        prompt = entry["question"]
        assert "Solve this cryptarithm:" in prompt
        assert "Each letter stands for a unique digit (0-9)" in prompt

        # Metadata must expose everything needed to re-check the puzzle.
        meta = entry["metadata"]
        for field in (
            "letters",
            "letter_to_digit",
            "words_letters",
            "result_letters",
            "word_values",
            "sum_number",
        ):
            assert field in meta

        # The mapping must be injective and stay within single digits.
        mapping = meta["letter_to_digit"]
        distinct_digits = set(mapping.values())
        assert len(distinct_digits) == len(mapping), "Each letter should map to a unique digit"
        assert all(0 <= digit <= 9 for digit in distinct_digits), "All digits should be between 0 and 9"

        # The stored addends must add up to the stored result.
        total = meta["sum_number"]
        assert sum(meta["word_values"]) == total, "Sum of word values should equal result value"
        seen_sums.add(total)

    # Every item in the dataset is expected to have a distinct sum.
    assert len(seen_sums) == len(dataset)
|
|
|
|
|
|
def test_cryptarithm_config():
    """Check that invalid cryptarithm configurations are rejected with ``AssertionError``."""
    # The return value is intentionally discarded: each call is expected to raise
    # before a dataset is produced.
    with pytest.raises(AssertionError):
        create_dataset("cryptarithm", min_words=1)  # min_words must be >= 2

    with pytest.raises(AssertionError):
        create_dataset("cryptarithm", min_words=4, max_words=3)  # min must be <= max

    with pytest.raises(AssertionError):
        create_dataset("cryptarithm", size=0)  # size must be positive
|
|
|
|
|
|
def test_leading_zero_constraint():
    """With ``allow_leading_zero=False`` no addend or result may start with a letter mapped to 0."""
    dataset = create_dataset(
        "cryptarithm",
        seed=42,
        size=5,
        allow_leading_zero=False,
        max_words=10,
        min_words=5,
    )

    for entry in dataset:
        meta = entry["metadata"]
        mapping = meta["letter_to_digit"]
        addends = meta["words_letters"]
        total_word = meta["result_letters"]

        # First letter of every addend, followed by the first letter of the result.
        first_letters = [*(word[0] for word in addends), total_word[0]]
        for letter in first_letters:
            assert mapping[letter] != 0, "Leading letters cannot be zero when allow_leading_zero=False"
|
|
|
|
|
|
def test_deterministic_generation():
    """Two datasets built with the same seed and size must yield identical items."""
    first = create_dataset("cryptarithm", seed=42, size=5)
    second = create_dataset("cryptarithm", seed=42, size=5)

    for index in range(5):
        # Compare field by field so a failure points at the diverging part.
        for field in ("question", "answer", "metadata"):
            assert first[index][field] == second[index][field]
|
|
|
|
|
|
def test_word_length_constraints():
    """Every addend word in a generated puzzle must be 3 to 5 letters long."""
    dataset = create_dataset("cryptarithm", seed=42, size=10)

    for entry in dataset:
        addends = entry["metadata"]["words_letters"]

        # The 3-5 range is presumably what the generator enforces; this test pins it.
        for addend in addends:
            assert 3 <= len(addend) <= 5, "Each word should be between 3 and 5 letters long"
|
|
|
|
|
|
def test_max_letters_constraint():
    """A puzzle can use at most ten distinct letters, one per decimal digit."""
    dataset = create_dataset("cryptarithm", seed=42, size=10)

    for entry in dataset:
        mapping = entry["metadata"]["letter_to_digit"]

        # Only digits 0-9 are available, so more than ten letters cannot be mapped uniquely.
        assert len(mapping) <= 10, "Total unique letters should not exceed 10"
|
|
|
|
|
|
def test_cryptarithm_score_answer():
    """Test the CryptarithmDataset.score_answer method for various correctness levels.

    Scores asserted below:
      * 1.0  - correct mapping, regardless of the order of the pairs;
      * 0.01 - answer parses but is not a valid solution (missing letter,
        duplicated letter, broken arithmetic);
      * 0.0  - answer cannot be parsed, or is ``None``.
    """
    dataset = create_dataset("cryptarithm", seed=42, size=1)
    puzzle = dataset[0]
    correct_answer_str = puzzle["answer"]  # e.g. "A=1,B=7,..."

    # 1) Correct mapping => expecting 1.0
    score = dataset.score_answer(answer=correct_answer_str, entry=puzzle)
    assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}"

    # Parse the reference answer once; the pairs and the mapping are reused by the cases below.
    correct_pairs = correct_answer_str.split(",")
    correct_mapping = {}
    for pair in correct_pairs:
        alpha, num_str = pair.split("=")
        correct_mapping[alpha] = int(num_str)

    # 2) Correct mapping in different order => should still be 1.0
    reversed_answer = ",".join(
        f"{letter}={correct_mapping[letter]}" for letter in sorted(correct_mapping, reverse=True)
    )
    score = dataset.score_answer(answer=reversed_answer, entry=puzzle)
    assert score == 1.0, f"Expected 1.0 for correct answer in different order, got {score}"

    # 3) Mismatch number of pairs => 0.01 (parse succeeds but validation fails)
    # For instance, drop the last pair
    mismatch_str = ",".join(correct_pairs[:-1])
    score = dataset.score_answer(answer=mismatch_str, entry=puzzle)
    assert score == 0.01, f"Expected 0.01 when #pairs does not match (missing letter), got {score}"

    # 4) Parse error => 0.0 (e.g. remove '=' from the first pair)
    broken_pairs = list(correct_pairs)
    broken_pairs[0] = broken_pairs[0].replace("=", "")  # remove '=' in the first pair
    parse_error_str = ",".join(broken_pairs)
    score = dataset.score_answer(answer=parse_error_str, entry=puzzle)
    assert score == 0.0, f"Expected 0.0 when parsing fails on at least one pair, got {score}"

    # 5) Correct number of pairs, but duplicate alphabets => 0.01 (parseable but invalid)
    # This makes the dictionary have fewer unique keys than expected
    duplicated_pairs = list(correct_pairs)
    if len(duplicated_pairs) > 1:
        duplicated_pairs[0] = duplicated_pairs[1]  # Duplicate the second pair in the first position
    duplicates_str = ",".join(duplicated_pairs)
    score = dataset.score_answer(answer=duplicates_str, entry=puzzle)
    assert score == 0.01, f"Expected 0.01 if the final dict has fewer unique alphabets, got {score}"

    # 6) Wrong arithmetic - swap two digits to break the equation
    # NOTE(review): this assumes swapping the first two letters' digits invalidates the
    # solution for the seed=42 puzzle; confirm if the seed or generator changes.
    letters = list(correct_mapping)
    if len(letters) >= 2:
        first, second = letters[0], letters[1]
        wrong_mapping = correct_mapping.copy()
        wrong_mapping[first], wrong_mapping[second] = wrong_mapping[second], wrong_mapping[first]

        wrong_answer_str = ",".join(f"{letter}={wrong_mapping[letter]}" for letter in sorted(letters))
        score = dataset.score_answer(answer=wrong_answer_str, entry=puzzle)
        assert score == 0.01, f"Expected 0.01 for invalid arithmetic, got {score}"

    # 7) None answer => 0.0
    score = dataset.score_answer(answer=None, entry=puzzle)
    assert score == 0.0, f"Expected 0.0 for None answer, got {score}"
|
|
|
|
|
|
def test_cryptarithm_verify_solution():
    """Test the verify_cryptarithm_solution helper function.

    Each case calls ``verify_cryptarithm_solution(mapping, words, result, flag)`` and
    checks the returned ``(is_valid, reason)`` pair. The fourth positional argument is
    presumably the allow-leading-zero switch: cases 4 and 5 differ only in that flag
    and expect opposite outcomes.
    """

    # Test case 1: Valid solution with simple arithmetic
    mapping = {"A": 1, "B": 2}
    words = ["A", "A"]  # 1 + 1
    result = "B"  # 2
    is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
    assert is_valid, f"Valid solution marked invalid: {reason}"

    # Test case 2: Valid solution with multi-digit numbers
    mapping = {"A": 1, "B": 2, "C": 3, "D": 4, "E": 5, "F": 7}
    words = ["AB", "CE"]  # 12 + 35
    result = "DF"  # 47
    is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
    assert is_valid, f"Valid solution marked invalid: {reason}"

    # Test case 3: Wrong arithmetic
    mapping = {"A": 1, "B": 2, "C": 3}
    words = ["AB"]  # 12
    result = "AC"  # 13 (wrong!)
    is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
    assert not is_valid, "Invalid arithmetic not detected"
    assert "Arithmetic equation not satisfied" in reason

    # Test case 4: Leading zero violation
    mapping = {"A": 0, "B": 1}
    words = ["AB"]  # 01
    result = "AB"  # 01
    is_valid, reason = verify_cryptarithm_solution(mapping, words, result, False)
    assert not is_valid, "Leading zero violation not detected"
    assert "cannot map to 0" in reason

    # Test case 5: Leading zero allowed
    mapping = {"A": 0, "B": 1}
    words = ["AB"]  # 01
    result = "AB"  # 01
    is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
    assert is_valid, f"Leading zero incorrectly rejected when allowed: {reason}"

    # Test case 6: Duplicate digit assignments
    mapping = {"A": 1, "B": 1, "C": 2}  # A and B both map to 1
    words = ["AB"]  # Both A and B are in puzzle
    result = "C"  # C is also in puzzle
    is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
    assert not is_valid, "Duplicate digits not detected"
    assert "Duplicate digit" in reason

    # Test case 7: Missing letter mapping
    mapping = {"A": 1}  # Missing B
    words = ["AB"]
    result = "AB"
    is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
    assert not is_valid, "Missing letter not detected"
    assert "Missing mapping" in reason

    # Test case 8: Extra letter in mapping
    mapping = {"A": 1, "B": 2, "C": 3}  # C is not in puzzle
    words = ["AB"]  # 12
    result = "AB"  # 12
    is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
    assert not is_valid, "Extra letter not detected"
    assert "Extra letter" in reason

    # Test case 9: Invalid digit (out of range)
    mapping = {"A": 10, "B": 2}  # 10 is invalid
    words = ["AB"]
    result = "AB"
    is_valid, reason = verify_cryptarithm_solution(mapping, words, result, True)
    assert not is_valid, "Invalid digit not detected"
    assert "Invalid digit" in reason

    # Test case 10: Real cryptarithm example
    # SEND + MORE = MONEY
    # S=9, E=5, N=6, D=7, M=1, O=0, R=8, Y=2
    # 9567 + 1085 = 10652
    mapping = {"S": 9, "E": 5, "N": 6, "D": 7, "M": 1, "O": 0, "R": 8, "Y": 2}
    words = ["SEND", "MORE"]
    result = "MONEY"
    is_valid, reason = verify_cryptarithm_solution(mapping, words, result, False)
    assert is_valid, f"Classic SEND+MORE=MONEY not validated: {reason}"
|
|
|
|
|
|
def test_cryptarithm_curriculum():
    """Test curriculum for cryptarithm dataset.

    Checks that raising or lowering the "words" attribute level changes ``max_words``
    of the generated configuration (5 -> 10 -> ... -> 50) while ``min_words`` stays
    at 2, and that the level stops changing once either end is reached.
    """

    curriculum = CryptarithmCurriculum()
    base_value = {"size": 150, "seed": 1}

    # The generated objects are read as configs (seed/size/min_words/max_words),
    # so they are annotated as CryptarithmConfig rather than as the curriculum.
    base_cfg: CryptarithmConfig = curriculum.generate_configuration(base_value)

    assert base_cfg.seed == 1
    assert base_cfg.size == 150
    assert base_cfg.min_words == 2
    assert base_cfg.max_words == 5

    # Test and validate increase in level
    curriculum.increment_attr_level("words")
    increased_cfg: CryptarithmConfig = curriculum.generate_configuration(base_value)

    assert increased_cfg.min_words == 2
    assert increased_cfg.max_words == 10

    # Test and validate decrease in level
    curriculum.decrement_attr_level("words")
    decreased_cfg: CryptarithmConfig = curriculum.generate_configuration(base_value)

    assert decreased_cfg.min_words == 2
    assert decreased_cfg.max_words == 5

    # Test upper bound boundary conditions: extra increments past the top level must be harmless
    for _ in range(10):
        curriculum.increment_attr_level("words")
    upper_bound_cfg: CryptarithmConfig = curriculum.generate_configuration(base_value)
    assert upper_bound_cfg.min_words == 2
    assert upper_bound_cfg.max_words == 50

    # Test lower bound boundary conditions: extra decrements past the bottom level must be harmless
    for _ in range(10):
        curriculum.decrement_attr_level("words")
    lower_bound_cfg: CryptarithmConfig = curriculum.generate_configuration(base_value)
    assert lower_bound_cfg.min_words == 2
    assert lower_bound_cfg.max_words == 5
|