reasoning-gym/tests/test_cryptarithm.py
Andreas Köpf d2c895f1d3
Refactor Curriculum Attributes (#335)
* remove min_value from AttributeDefinition
* remove type from AttributeDefinition
* Add CurriculumContext
* add ensure_interval option for RangeAttributes
* docs: Add legend explaining curriculum indicators in dataset gallery
* update GALLERY.md
2025-03-16 15:40:28 +01:00

211 lines
8.2 KiB
Python

import pytest
from reasoning_gym import create_dataset
from reasoning_gym.algorithmic.cryptarithm import CryptarithmConfig, CryptarithmCurriculum, CryptarithmDataset
def test_cryptarithm_generation():
    """Check structure, digit mapping and arithmetic of generated cryptarithm items.

    Every item must carry the expected top-level and metadata keys, map each
    letter to a distinct digit in 0-9, and have word values that add up to the
    stored sum. Across the whole dataset the sums must be pairwise distinct.
    """
    dataset = create_dataset("cryptarithm", seed=42, size=10)
    assert isinstance(dataset, CryptarithmDataset)

    expected_item_keys = ("question", "answer", "metadata")
    expected_metadata_keys = (
        "letters",
        "letter_to_digit",
        "words_letters",
        "result_letters",
        "word_values",
        "sum_number",
    )

    seen_sums = set()
    for entry in dataset:
        # Required top-level keys.
        for key in expected_item_keys:
            assert key in entry

        # The prompt must contain the fixed instruction phrases.
        prompt = entry["question"]
        assert "Solve this cryptarithm:" in prompt
        assert "Each letter stands for a unique digit (0-9)" in prompt

        # Required metadata keys.
        meta = entry["metadata"]
        for key in expected_metadata_keys:
            assert key in meta

        # No two letters may share a digit, and every digit must be a single decimal digit.
        mapping = meta["letter_to_digit"]
        distinct_digits = set(mapping.values())
        assert len(distinct_digits) == len(mapping), "Each letter should map to a unique digit"
        assert all(0 <= digit <= 9 for digit in distinct_digits), "All digits should be between 0 and 9"

        # The addends must add up to the recorded result.
        addends = meta["word_values"]
        total = meta["sum_number"]
        assert sum(addends) == total, "Sum of word values should equal result value"
        seen_sums.add(total)

    # Checked once after the loop: every item produced a different sum.
    assert len(seen_sums) == len(dataset)
def test_cryptarithm_config():
    """Invalid cryptarithm configurations must be rejected with an AssertionError."""
    # The return value is deliberately not bound: each call is expected to raise
    # before producing a dataset.
    with pytest.raises(AssertionError):
        create_dataset("cryptarithm", min_words=1)  # min_words must be >= 2

    with pytest.raises(AssertionError):
        create_dataset("cryptarithm", min_words=4, max_words=3)  # min must be <= max

    with pytest.raises(AssertionError):
        create_dataset("cryptarithm", size=0)  # size must be positive
def test_leading_zero_constraint():
    """With allow_leading_zero=False, no addend or result may start with a letter mapped to 0."""
    dataset = create_dataset("cryptarithm", seed=42, size=5, allow_leading_zero=False, max_words=10, min_words=5)

    for entry in dataset:
        meta = entry["metadata"]
        mapping = meta["letter_to_digit"]

        # First letter of every addend, followed by the first letter of the result.
        first_letters = [word[0] for word in meta["words_letters"]]
        first_letters.append(meta["result_letters"][0])

        for letter in first_letters:
            assert mapping[letter] != 0, "Leading letters cannot be zero when allow_leading_zero=False"
def test_deterministic_generation():
    """Two datasets built from the same seed and size must yield identical items."""
    first = create_dataset("cryptarithm", seed=42, size=5)
    second = create_dataset("cryptarithm", seed=42, size=5)

    compared_fields = ("question", "answer", "metadata")
    for index in range(5):
        # Compare field by field so a failure points at the part that diverged.
        for field in compared_fields:
            assert first[index][field] == second[index][field]
def test_word_length_constraints():
    """Every addend in a default-configured puzzle must be 3 to 5 letters long."""
    dataset = create_dataset("cryptarithm", seed=42, size=10)

    for entry in dataset:
        # The 3-5 bound is what this test expects of the generator's word lengths.
        for word in entry["metadata"]["words_letters"]:
            word_length = len(word)
            assert 3 <= word_length <= 5, "Each word should be between 3 and 5 letters long"
def test_max_letters_constraint():
    """A puzzle can use at most ten distinct letters, one per decimal digit."""
    dataset = create_dataset("cryptarithm", seed=42, size=10)

    for entry in dataset:
        mapping = entry["metadata"]["letter_to_digit"]
        # Only ten digits (0-9) exist, so more than ten letters could not map uniquely.
        letter_count = len(mapping)
        assert letter_count <= 10, "Total unique letters should not exceed 10"
def test_cryptarithm_score_answer():
    """Exercise CryptarithmDataset.score_answer across its expected score tiers.

    Starting from the reference answer of a single generated puzzle, the test
    derives progressively damaged answers and checks the score each one gets:
    1.0 (exact), 0.1 (wrong pair count), 0.15 (unparsable pair), 0.3 (duplicate
    letter) and a linear partial-credit value in between.
    """
    dataset = create_dataset("cryptarithm", seed=42, size=1)
    puzzle = dataset[0]
    reference_answer = puzzle["answer"]  # e.g. "A=1,B=7,..."
    reference_pairs = reference_answer.split(",")

    # NOTE: case 1 (missing '<answer>' prefix => 0.0) is not exercised by this test.

    # 2) Exact reference mapping => 1.0
    score = dataset.score_answer(answer=reference_answer, entry=puzzle)
    assert score == 1.0, f"Expected 1.0 for perfectly correct answer, got {score}"

    # 3) One pair too few (last pair dropped) => 0.1
    truncated_answer = ",".join(reference_pairs[:-1])
    score = dataset.score_answer(answer=truncated_answer, entry=puzzle)
    assert score == 0.1, f"Expected 0.1 when #pairs does not match, got {score}"

    # 4) First pair loses its '=' so it cannot be parsed => 0.15
    malformed_pairs = list(reference_pairs)
    malformed_pairs[0] = malformed_pairs[0].replace("=", "")
    score = dataset.score_answer(answer=",".join(malformed_pairs), entry=puzzle)
    assert score == 0.15, f"Expected 0.15 when parsing fails on at least one pair, got {score}"

    # 5) Right number of pairs but a repeated letter => 0.3
    # Overwriting the first pair with the second leaves one letter unassigned.
    duplicated_pairs = list(reference_pairs)
    if len(duplicated_pairs) > 1:
        duplicated_pairs[0] = duplicated_pairs[1]
    score = dataset.score_answer(answer=",".join(duplicated_pairs), entry=puzzle)
    assert score == 0.3, f"Expected 0.3 if the final dict has fewer unique alphabets, got {score}"

    # 6) Partial credit: keep the first half of the letters right, break the rest.
    correct_mapping = {
        letter: int(digit_text) for letter, digit_text in (pair.split("=") for pair in reference_pairs)
    }
    total = len(correct_mapping)
    half = total // 2
    partial_pairs = [
        # Shifting by one modulo 10 always yields a digit different from the correct one.
        f"{letter}={digit}" if position < half else f"{letter}={(digit + 1) % 10}"
        for position, (letter, digit) in enumerate(correct_mapping.items())
    ]
    score = dataset.score_answer(answer=",".join(partial_pairs), entry=puzzle)

    # Expected partial score: (num_correct / total) * 0.7 + 0.3
    expected_score = (half / total) * 0.7 + 0.3
    assert abs(score - expected_score) < 1e-9, f"Partial correctness: expected {expected_score}, got {score}"
def test_cryptarithm_curriculum():
    """Test that the cryptarithm curriculum maps "words" levels to the expected word-count range.

    The configurations returned by ``generate_configuration`` are read as
    config objects (``seed``, ``size``, ``min_words``, ``max_words``), so they
    are annotated as ``CryptarithmConfig`` rather than as the curriculum itself.
    """
    curriculum = CryptarithmCurriculum()
    base_value = {"size": 150, "seed": 1}

    # Initial level: defaults are passed through and the word range is 2-5.
    base_cfg: CryptarithmConfig = curriculum.generate_configuration(base_value)
    assert base_cfg.seed == 1
    assert base_cfg.size == 150
    assert base_cfg.min_words == 2
    assert base_cfg.max_words == 5

    # Test and validate increase in level
    curriculum.increment_attr_level("words")
    increased_cfg: CryptarithmConfig = curriculum.generate_configuration(base_value)
    assert increased_cfg.min_words == 2
    assert increased_cfg.max_words == 10

    # Test and validate decrease in level
    curriculum.decrement_attr_level("words")
    decreased_cfg: CryptarithmConfig = curriculum.generate_configuration(base_value)
    assert decreased_cfg.min_words == 2
    assert decreased_cfg.max_words == 5

    # Test upper bound boundary conditions: after ten increments max_words is expected to be 50
    for _ in range(10):
        curriculum.increment_attr_level("words")
    upper_bound_cfg: CryptarithmConfig = curriculum.generate_configuration(base_value)
    assert upper_bound_cfg.min_words == 2
    assert upper_bound_cfg.max_words == 50

    # Test lower bound boundary conditions: ten decrements are expected to restore the initial range
    for _ in range(10):
        curriculum.decrement_attr_level("words")
    lower_bound_cfg: CryptarithmConfig = curriculum.generate_configuration(base_value)
    assert lower_bound_cfg.min_words == 2
    assert lower_bound_cfg.max_words == 5