formatting

This commit is contained in:
Andreas Koepf 2025-02-16 16:30:28 +01:00
parent 0e4b6a9026
commit c832e2a438
2 changed files with 33 additions and 21 deletions

View file

@ -1,6 +1,8 @@
import pytest
from reasoning_gym import create_dataset
from reasoning_gym.algorithmic.cryptarithm import CryptarithmDataset, CryptarithmConfig
from reasoning_gym.algorithmic.cryptarithm import CryptarithmConfig, CryptarithmDataset
def test_cryptarithm_generation():
dataset = create_dataset("cryptarithm", seed=42, size=10)
@ -16,7 +18,7 @@ def test_cryptarithm_generation():
question = item["question"]
assert "Solve this cryptarithm:" in question
assert "Each letter stands for a unique digit (0-9)" in question
# Validate metadata structure
metadata = item["metadata"]
assert "letters" in metadata
@ -40,6 +42,7 @@ def test_cryptarithm_generation():
assert len(unique_number) == len(dataset)
def test_cryptarithm_config():
# Test invalid configs raise assertions
with pytest.raises(AssertionError):
@ -51,22 +54,24 @@ def test_cryptarithm_config():
with pytest.raises(AssertionError):
dataset = create_dataset("cryptarithm", size=0) # size must be positive
def test_leading_zero_constraint():
    """Check that no leading letter maps to 0 when allow_leading_zero=False."""
    puzzles = create_dataset(
        "cryptarithm",
        seed=42,
        size=5,
        allow_leading_zero=False,
        max_words=10,
        min_words=5,
    )

    for puzzle in puzzles:
        meta = puzzle["metadata"]
        mapping = meta["letter_to_digit"]

        # First letter of every word, followed by the first letter of the result.
        first_letters = [letters[0] for letters in meta["words_letters"]]
        first_letters.append(meta["result_letters"][0])

        for first in first_letters:
            assert mapping[first] != 0, "Leading letters cannot be zero when allow_leading_zero=False"
def test_deterministic_generation():
dataset1 = create_dataset("cryptarithm", seed=42, size=5)
dataset2 = create_dataset("cryptarithm", seed=42, size=5)
@ -76,23 +81,25 @@ def test_deterministic_generation():
assert dataset1[i]["answer"] == dataset2[i]["answer"]
assert dataset1[i]["metadata"] == dataset2[i]["metadata"]
def test_word_length_constraints():
    """Check that every word in the generated puzzles has 3 to 5 letters."""
    for entry in create_dataset("cryptarithm", seed=42, size=10):
        # NOTE(review): the 3..5 bound presumably mirrors the generator's own
        # word-length limits -- confirm against the dataset implementation.
        for letters in entry["metadata"]["words_letters"]:
            length = len(letters)
            assert 3 <= length <= 5, "Each word should be between 3 and 5 letters long"
def test_max_letters_constraint():
    """Check that no generated puzzle uses more than 10 distinct letters."""
    dataset = create_dataset("cryptarithm", seed=42, size=10)

    for item in dataset:
        letter_to_digit = item["metadata"]["letter_to_digit"]

        # Each letter stands for a unique digit (0-9), so a valid mapping can
        # hold at most 10 letters. Asserted once; the verbatim duplicate of
        # this assertion added nothing and was removed.
        assert len(letter_to_digit) <= 10, "Total unique letters should not exceed 10"