Refactor CaesarCipher

This commit is contained in:
EduardDurech 2025-02-09 05:15:12 +00:00
parent 23aa6ca7e7
commit 5279ccf7e1
8 changed files with 513 additions and 159 deletions

View file

@ -1,100 +1,272 @@
"""Tests for Caesar cipher task generation"""
"""Unit tests for the Caesar cipher exercise."""
import pytest
from reasoning_gym.curricula.algorithmic.caesar_cipher_curriculum import CaesarCipherCurriculum
from reasoning_gym.exercises.algorithmic.caesar_cipher import CaesarCipherExercise
import unittest
import random
from collections import defaultdict
from reasoning_gym.algorithmic.caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
class TestCaesarCipherParsing(unittest.TestCase):
"""Test parsing of Caesar cipher metadata"""
def setUp(self):
    """Create the exercise instance used by the parsing tests."""
    self.exercise = CaesarCipherExercise()
def test_caesar_cipher_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = CaesarCipherConfig(min_words=0)
config.validate()
def test_parse_expression_basic(self):
    """Parse a plain uppercase cipher entry and check every parsed field."""
    cipher_entry = {
        "encrypted_text": "KHOOR",
        "clear_text": "HELLO",
        "rotation": 3,
    }
    parsed = self.exercise._parse_expression({"cipher_text": cipher_entry})

    # The encrypted text is expected back under the "cipher_text" key.
    expected_fields = (
        ("cipher_text", "KHOOR"),
        ("clear_text", "HELLO"),
        ("rotation", 3),
    )
    for field, value in expected_fields:
        self.assertEqual(parsed[field], value)
with pytest.raises(AssertionError):
config = CaesarCipherConfig(min_words=10, max_words=5)
config.validate()
def test_parse_with_spaces(self):
    """Parse a cipher entry that contains a space and punctuation."""
    encrypted, clear, shift = "KHOOR ZRUOG!", "HELLO WORLD!", 3
    raw_metadata = {
        "cipher_text": {
            "encrypted_text": encrypted,
            "clear_text": clear,
            "rotation": shift,
        },
    }

    parsed = self.exercise._parse_expression(raw_metadata)

    # Non-letter characters must come through the parser untouched.
    self.assertEqual(parsed["cipher_text"], encrypted)
    self.assertEqual(parsed["clear_text"], clear)
    self.assertEqual(parsed["rotation"], shift)
with pytest.raises(AssertionError):
config = CaesarCipherConfig(min_rotation=0)
config.validate()
def test_parse_mixed_case(self):
    """Parse a cipher entry whose text alternates upper and lower case."""
    inner = dict(encrypted_text="KhOoR", clear_text="HeLlO", rotation=3)
    parsed = self.exercise._parse_expression({"cipher_text": inner})

    # Letter case must be preserved exactly by the parser.
    for key, expected in zip(
        ("cipher_text", "clear_text", "rotation"),
        ("KhOoR", "HeLlO", 3),
    ):
        self.assertEqual(parsed[key], expected)
with pytest.raises(AssertionError):
config = CaesarCipherConfig(max_rotation=26)
config.validate()
class TestCaesarCipherEvaluation(unittest.TestCase):
"""Test evaluation of Caesar cipher problems"""
def setUp(self):
    """Create the exercise instance used by the evaluation tests."""
    self.exercise = CaesarCipherExercise()
def test_caesar_cipher_deterministic():
"""Test that dataset generates same items with same seed"""
config = CaesarCipherConfig(seed=42, size=10)
dataset1 = CaesarCipherDataset(config)
dataset2 = CaesarCipherDataset(config)
def test_basic_decryption(self):
    """Decrypt single-word ciphers and compare with the known plaintext.

    Covers upper case, lower case, zero rotation, alphabet wrap-around and
    a full 26-step rotation. Each case runs inside ``subTest`` so that a
    failure reports the offending input and the remaining cases still run.
    """
    # NOTE(review): the expected plaintext is also passed to the evaluator
    # under "clear_text" -- confirm _evaluate_expression really decrypts
    # "cipher_text" instead of echoing that value back.
    test_cases = [
        ("KHOOR", "HELLO", 3),  # Basic uppercase
        ("khoor", "hello", 3),  # Basic lowercase
        ("WORLD", "WORLD", 0),  # No rotation
        ("ABCDE", "ZABCD", 1),  # Wrap around
        ("hello", "hello", 26),  # Full rotation
    ]
    for cipher_text, clear_text, rotation in test_cases:
        with self.subTest(cipher_text=cipher_text, rotation=rotation):
            parsed = {
                "cipher_text": cipher_text,
                "clear_text": clear_text,
                "rotation": rotation,
            }
            result = self.exercise._evaluate_expression(parsed)
            self.assertEqual(result, clear_text)
for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i]
def test_mixed_case_decryption(self):
    """Decrypt mixed-case ciphers and check that letter case is preserved.

    Each case runs inside ``subTest`` so that a failure reports the
    offending input and the remaining cases still run.
    """
    # NOTE(review): the expected plaintext is also passed to the evaluator
    # under "clear_text" -- confirm _evaluate_expression really decrypts
    # "cipher_text" instead of echoing that value back.
    test_cases = [
        ("HeLlO", "HeLlO", 26),  # Mixed case, full rotation
        ("WoRlD", "WoRlD", 0),  # Mixed case, no rotation
        ("AbCdE", "ZaBcD", 1),  # Mixed case, wrap around
    ]
    for cipher_text, clear_text, rotation in test_cases:
        with self.subTest(cipher_text=cipher_text, rotation=rotation):
            parsed = {
                "cipher_text": cipher_text,
                "clear_text": clear_text,
                "rotation": rotation,
            }
            result = self.exercise._evaluate_expression(parsed)
            self.assertEqual(result, clear_text)
def test_with_spaces_and_punctuation(self):
    """Decrypt multi-word ciphers containing spaces and punctuation.

    Non-letter characters are expected to pass through unchanged. Each
    case runs inside ``subTest`` so that a failure reports the offending
    input and the remaining cases still run.
    """
    # NOTE(review): the expected plaintext is also passed to the evaluator
    # under "clear_text" -- confirm _evaluate_expression really decrypts
    # "cipher_text" instead of echoing that value back.
    test_cases = [
        ("KHOOR ZRUOG!", "HELLO WORLD!", 3),
        ("Pb Pbvwhub!", "My Mystery!", 3),
        ("ABCDE. FGHIJ?", "ZABCD. EFGHI?", 1),
    ]
    for cipher_text, clear_text, rotation in test_cases:
        with self.subTest(cipher_text=cipher_text, rotation=rotation):
            parsed = {
                "cipher_text": cipher_text,
                "clear_text": clear_text,
                "rotation": rotation,
            }
            result = self.exercise._evaluate_expression(parsed)
            self.assertEqual(result, clear_text)
def test_caesar_cipher_encryption():
"""Test the Caesar cipher encryption logic"""
config = CaesarCipherConfig(size=1, seed=42)
dataset = CaesarCipherDataset(config)
class TestCaesarCipherGeneration(unittest.TestCase):
"""Test problem generation"""
# Test with known rotation
text = "HELLO"
encrypted = dataset._caesar_encrypt(text, 1)
assert encrypted == "IFMMP" # Each letter shifted by 1
def setUp(self):
    """Create the curriculum and exercise, and give the curriculum a seeded RNG."""
    self.curriculum = CaesarCipherCurriculum()
    self.exercise = CaesarCipherExercise()
    # Fixed seed so the random stream is the same on every run
    self.rng = random.Random(42)
    # presumably the curriculum draws its randomness from `rng` -- verify
    # against CaesarCipherCurriculum
    self.curriculum.rng = self.rng
# Test wrapping around Z
encrypted = dataset._caesar_encrypt("XYZ", 2)
assert encrypted == "ZAB"
def test_problem_structure(self):
"""Test that generated problems have the correct structure"""
problem = self.exercise.generate(self.curriculum)
# Test preserving spaces
encrypted = dataset._caesar_encrypt("HELLO WORLD", 1)
assert encrypted == "IFMMP XPSME"
# Check basic structure
self.assertIn("question", problem)
self.assertIn("answer", problem)
self.assertIn("metadata", problem)
# Check metadata structure
metadata = problem["metadata"]
self.assertEqual(metadata["type"], "direct")
self.assertIn("executed_parts", metadata)
executed_parts = metadata["executed_parts"]
self.assertIn("cipher_text", executed_parts)
self.assertIn("clear_text", executed_parts)
self.assertIn("rotation", executed_parts)
def test_caesar_cipher_dataset_items():
"""Test basic properties of generated items"""
config = CaesarCipherConfig(min_words=3, max_words=5, min_rotation=1, max_rotation=3, size=10, seed=42)
dataset = CaesarCipherDataset(config)
def test_rotation_ranges(self):
"""Test that rotation values are within expected ranges"""
# Test all rotation levels
level_max_rotations = {0: 1, 1: 3, 2: 10, 3: 15, 4: 25}
for i in range(len(dataset)):
item = dataset[i]
for level, max_rotation in level_max_rotations.items():
self.curriculum.set_attr_level("rotation", level)
problem = self.exercise.generate(self.curriculum)
rotation = problem["metadata"]["executed_parts"]["rotation"]
self.assertLessEqual(rotation, max_rotation)
self.assertGreaterEqual(rotation, 1) # Min rotation is 1
# Check item structure
assert isinstance(item, dict)
assert "question" in item
assert "answer" in item
assert "metadata" in item
def test_word_count_ranges(self):
"""Test that word counts are within expected ranges"""
# Test all word count levels
level_word_counts = {0: 5, 1: 10, 2: 20}
# Check metadata
assert "rotation" in item["metadata"]
assert "cipher_text" in item["metadata"]
assert "clear_text" in item["metadata"]
for level, max_words in level_word_counts.items():
self.curriculum.set_attr_level("num_words", level)
problem = self.exercise.generate(self.curriculum)
clear_text = problem["metadata"]["executed_parts"]["clear_text"]
word_count = len(clear_text.split())
self.assertLessEqual(word_count, max_words)
self.assertGreaterEqual(word_count, 3) # Min words is 3
# Verify rotation constraints
rotation = item["metadata"]["rotation"]
assert config.min_rotation <= rotation <= config.max_rotation
class TestCaesarCipherComprehensive(unittest.TestCase):
"""Comprehensive tests for Caesar cipher"""
# Verify text properties
clear_text = item["metadata"]["clear_text"]
words = clear_text.split()
assert config.min_words <= len(words) <= config.max_words
assert all(word.isupper() and word.isalpha() for word in words)
def setUp(self):
    """Create the curriculum and exercise, and give the curriculum a seeded RNG."""
    self.curriculum = CaesarCipherCurriculum()
    self.exercise = CaesarCipherExercise()
    # Fixed seed so the random stream is the same on every run
    self.rng = random.Random(42)
    # presumably the curriculum draws its randomness from `rng` -- verify
    # against CaesarCipherCurriculum
    self.curriculum.rng = self.rng
# Verify encryption
cipher_text = item["metadata"]["cipher_text"]
decrypted = dataset._caesar_encrypt(cipher_text, -rotation) # Decrypt by negative rotation
assert decrypted == clear_text
def test_text_case_styles(self):
"""Test different text case styles"""
case_styles = ["UPPER", "lower", "Mixed"]
num_samples = 100 # Test with multiple samples to ensure we see all styles
# Test each level
for level, expected_styles in enumerate(case_styles):
self.curriculum.set_attr_level("text_case", level)
styles_seen = set()
def test_caesar_cipher_iteration():
"""Test that iteration respects dataset size"""
config = CaesarCipherConfig(size=5, seed=42)
dataset = CaesarCipherDataset(config)
# Generate multiple problems to catch all possible styles
for _ in range(num_samples):
problem = self.exercise.generate(self.curriculum)
text = problem["metadata"]["executed_parts"]["clear_text"]
items = list(dataset)
assert len(items) == config.size
# Determine the style of this text
if text.isupper():
styles_seen.add("UPPER")
elif text.islower():
styles_seen.add("lower")
else:
styles_seen.add("Mixed")
# Test multiple iterations yield same items
assert items == list(dataset)
# At each level, we should see all styles up to that level
expected_styles_set = set(case_styles[:level + 1])
self.assertEqual(styles_seen, expected_styles_set,
f"At level {level}, expected to see styles {expected_styles_set} but saw {styles_seen}")
def test_template_variation(self):
    """Ensure repeated generation yields more than one question template."""
    num_samples = 100
    # The text before the first ":" serves as the template's signature.
    prefixes = {
        self.exercise.generate(self.curriculum)["question"].split(":")[0]
        for _ in range(num_samples)
    }
    self.assertGreater(len(prefixes), 1, "Not enough template variation")
def test_comprehensive_random_evaluation(self):
    """Generate 1000 problems at random attribute levels and verify each one.

    For every problem the "rotation", "num_words" and "text_case" levels
    are drawn at random, the cipher text is checked character by character
    against the clear text shifted by the recorded rotation, and usage
    statistics are collected and printed at the end.
    """
    self.rng = random.Random(42)  # Fixed seed for reproducibility
    self.curriculum.rng = self.rng

    # Track statistics
    rotations_used = defaultdict(int)
    word_counts = defaultdict(int)
    case_styles = defaultdict(int)
    total_samples = 1000

    for _ in range(total_samples):
        # Set random attribute levels
        self.curriculum.set_attr_level("rotation", self.rng.randint(0, 4))
        self.curriculum.set_attr_level("num_words", self.rng.randint(0, 2))
        self.curriculum.set_attr_level("text_case", self.rng.randint(0, 2))

        # Generate a random problem and pull out the fields under test
        problem = self.exercise.generate(self.curriculum)
        metadata = problem["metadata"]["executed_parts"]
        cipher_text = metadata["cipher_text"]
        clear_text = metadata["clear_text"]
        rotation = metadata["rotation"]

        # Track statistics
        rotations_used[rotation] += 1
        word_counts[len(clear_text.split())] += 1
        if clear_text.isupper():
            case_styles["UPPER"] += 1
        elif clear_text.islower():
            case_styles["lower"] += 1
        else:
            case_styles["Mixed"] += 1

        # zip() stops at the shorter string, so a truncated or padded
        # cipher text would otherwise slip through the loop below.
        self.assertEqual(
            len(cipher_text),
            len(clear_text),
            "cipher text and clear text differ in length",
        )

        # Verify each character is correctly encrypted. The comparison is
        # case-insensitive: both sides are upper-cased before shifting.
        for c1, c2 in zip(cipher_text, clear_text):
            if c1.isalpha():
                expected = chr((ord(c2.upper()) - ord('A') + rotation) % 26 + ord('A'))
                self.assertEqual(c1.upper(), expected)
            else:
                # Non-letters must pass through the cipher unchanged
                self.assertEqual(c1, c2)

    # Print statistics
    print("\nRotations used:")
    for rotation, count in sorted(rotations_used.items()):
        print(f"  Rotation {rotation}: {count}")

    print("\nWord counts:")
    for words, count in sorted(word_counts.items()):
        print(f"  {words} words: {count}")

    print("\nCase styles:")
    for style, count in case_styles.items():
        print(f"  {style}: {count}")