mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
Refactor CaesarCipher
This commit is contained in:
parent
23aa6ca7e7
commit
5279ccf7e1
8 changed files with 513 additions and 159 deletions
|
|
@ -1,100 +1,272 @@
|
|||
"""Tests for Caesar cipher task generation"""
|
||||
"""Unit tests for the Caesar cipher exercise."""
|
||||
|
||||
import pytest
|
||||
from reasoning_gym.curricula.algorithmic.caesar_cipher_curriculum import CaesarCipherCurriculum
|
||||
from reasoning_gym.exercises.algorithmic.caesar_cipher import CaesarCipherExercise
|
||||
import unittest
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
from reasoning_gym.algorithmic.caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
|
||||
class TestCaesarCipherParsing(unittest.TestCase):
|
||||
"""Test parsing of Caesar cipher metadata"""
|
||||
|
||||
def setUp(self):
|
||||
self.exercise = CaesarCipherExercise()
|
||||
|
||||
def test_caesar_cipher_config_validation():
|
||||
"""Test that invalid configs raise appropriate errors"""
|
||||
with pytest.raises(AssertionError):
|
||||
config = CaesarCipherConfig(min_words=0)
|
||||
config.validate()
|
||||
def test_parse_expression_basic(self):
|
||||
"""Test parsing of basic Caesar cipher metadata"""
|
||||
test_metadata = {
|
||||
"cipher_text": {
|
||||
"encrypted_text": "KHOOR",
|
||||
"clear_text": "HELLO",
|
||||
"rotation": 3
|
||||
}
|
||||
}
|
||||
parsed = self.exercise._parse_expression(test_metadata)
|
||||
self.assertEqual(parsed["cipher_text"], "KHOOR")
|
||||
self.assertEqual(parsed["clear_text"], "HELLO")
|
||||
self.assertEqual(parsed["rotation"], 3)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = CaesarCipherConfig(min_words=10, max_words=5)
|
||||
config.validate()
|
||||
def test_parse_with_spaces(self):
|
||||
"""Test parsing with spaces and punctuation"""
|
||||
test_metadata = {
|
||||
"cipher_text": {
|
||||
"encrypted_text": "KHOOR ZRUOG!",
|
||||
"clear_text": "HELLO WORLD!",
|
||||
"rotation": 3
|
||||
}
|
||||
}
|
||||
parsed = self.exercise._parse_expression(test_metadata)
|
||||
self.assertEqual(parsed["cipher_text"], "KHOOR ZRUOG!")
|
||||
self.assertEqual(parsed["clear_text"], "HELLO WORLD!")
|
||||
self.assertEqual(parsed["rotation"], 3)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = CaesarCipherConfig(min_rotation=0)
|
||||
config.validate()
|
||||
def test_parse_mixed_case(self):
|
||||
"""Test parsing with mixed case text"""
|
||||
test_metadata = {
|
||||
"cipher_text": {
|
||||
"encrypted_text": "KhOoR",
|
||||
"clear_text": "HeLlO",
|
||||
"rotation": 3
|
||||
}
|
||||
}
|
||||
parsed = self.exercise._parse_expression(test_metadata)
|
||||
self.assertEqual(parsed["cipher_text"], "KhOoR")
|
||||
self.assertEqual(parsed["clear_text"], "HeLlO")
|
||||
self.assertEqual(parsed["rotation"], 3)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
config = CaesarCipherConfig(max_rotation=26)
|
||||
config.validate()
|
||||
class TestCaesarCipherEvaluation(unittest.TestCase):
|
||||
"""Test evaluation of Caesar cipher problems"""
|
||||
|
||||
def setUp(self):
|
||||
self.exercise = CaesarCipherExercise()
|
||||
|
||||
def test_caesar_cipher_deterministic():
|
||||
"""Test that dataset generates same items with same seed"""
|
||||
config = CaesarCipherConfig(seed=42, size=10)
|
||||
dataset1 = CaesarCipherDataset(config)
|
||||
dataset2 = CaesarCipherDataset(config)
|
||||
def test_basic_decryption(self):
|
||||
"""Test basic decryption cases"""
|
||||
test_cases = [
|
||||
("KHOOR", "HELLO", 3), # Basic uppercase
|
||||
("khoor", "hello", 3), # Basic lowercase
|
||||
("WORLD", "WORLD", 0), # No rotation
|
||||
("ABCDE", "ZABCD", 1), # Wrap around
|
||||
("hello", "hello", 26) # Full rotation
|
||||
]
|
||||
for cipher_text, clear_text, rotation in test_cases:
|
||||
parsed = {
|
||||
"cipher_text": cipher_text,
|
||||
"clear_text": clear_text,
|
||||
"rotation": rotation
|
||||
}
|
||||
result = self.exercise._evaluate_expression(parsed)
|
||||
self.assertEqual(result, clear_text)
|
||||
|
||||
for i in range(len(dataset1)):
|
||||
assert dataset1[i] == dataset2[i]
|
||||
def test_mixed_case_decryption(self):
|
||||
"""Test decryption with mixed case"""
|
||||
test_cases = [
|
||||
("HeLlO", "HeLlO", 26), # Mixed case, full rotation
|
||||
("WoRlD", "WoRlD", 0), # Mixed case, no rotation
|
||||
("AbCdE", "ZaBcD", 1) # Mixed case, wrap around
|
||||
]
|
||||
for cipher_text, clear_text, rotation in test_cases:
|
||||
parsed = {
|
||||
"cipher_text": cipher_text,
|
||||
"clear_text": clear_text,
|
||||
"rotation": rotation
|
||||
}
|
||||
result = self.exercise._evaluate_expression(parsed)
|
||||
self.assertEqual(result, clear_text)
|
||||
|
||||
def test_with_spaces_and_punctuation(self):
|
||||
"""Test decryption with spaces and punctuation"""
|
||||
test_cases = [
|
||||
("KHOOR ZRUOG!", "HELLO WORLD!", 3),
|
||||
("Pb Pbvwhub!", "My Mystery!", 3),
|
||||
("ABCDE. FGHIJ?", "ZABCD. EFGHI?", 1)
|
||||
]
|
||||
for cipher_text, clear_text, rotation in test_cases:
|
||||
parsed = {
|
||||
"cipher_text": cipher_text,
|
||||
"clear_text": clear_text,
|
||||
"rotation": rotation
|
||||
}
|
||||
result = self.exercise._evaluate_expression(parsed)
|
||||
self.assertEqual(result, clear_text)
|
||||
|
||||
def test_caesar_cipher_encryption():
|
||||
"""Test the Caesar cipher encryption logic"""
|
||||
config = CaesarCipherConfig(size=1, seed=42)
|
||||
dataset = CaesarCipherDataset(config)
|
||||
class TestCaesarCipherGeneration(unittest.TestCase):
|
||||
"""Test problem generation"""
|
||||
|
||||
# Test with known rotation
|
||||
text = "HELLO"
|
||||
encrypted = dataset._caesar_encrypt(text, 1)
|
||||
assert encrypted == "IFMMP" # Each letter shifted by 1
|
||||
def setUp(self):
|
||||
self.curriculum = CaesarCipherCurriculum()
|
||||
self.exercise = CaesarCipherExercise()
|
||||
self.rng = random.Random(42)
|
||||
self.curriculum.rng = self.rng
|
||||
|
||||
# Test wrapping around Z
|
||||
encrypted = dataset._caesar_encrypt("XYZ", 2)
|
||||
assert encrypted == "ZAB"
|
||||
def test_problem_structure(self):
|
||||
"""Test that generated problems have the correct structure"""
|
||||
problem = self.exercise.generate(self.curriculum)
|
||||
|
||||
# Test preserving spaces
|
||||
encrypted = dataset._caesar_encrypt("HELLO WORLD", 1)
|
||||
assert encrypted == "IFMMP XPSME"
|
||||
# Check basic structure
|
||||
self.assertIn("question", problem)
|
||||
self.assertIn("answer", problem)
|
||||
self.assertIn("metadata", problem)
|
||||
|
||||
# Check metadata structure
|
||||
metadata = problem["metadata"]
|
||||
self.assertEqual(metadata["type"], "direct")
|
||||
self.assertIn("executed_parts", metadata)
|
||||
executed_parts = metadata["executed_parts"]
|
||||
self.assertIn("cipher_text", executed_parts)
|
||||
self.assertIn("clear_text", executed_parts)
|
||||
self.assertIn("rotation", executed_parts)
|
||||
|
||||
def test_caesar_cipher_dataset_items():
|
||||
"""Test basic properties of generated items"""
|
||||
config = CaesarCipherConfig(min_words=3, max_words=5, min_rotation=1, max_rotation=3, size=10, seed=42)
|
||||
dataset = CaesarCipherDataset(config)
|
||||
def test_rotation_ranges(self):
|
||||
"""Test that rotation values are within expected ranges"""
|
||||
# Test all rotation levels
|
||||
level_max_rotations = {0: 1, 1: 3, 2: 10, 3: 15, 4: 25}
|
||||
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
for level, max_rotation in level_max_rotations.items():
|
||||
self.curriculum.set_attr_level("rotation", level)
|
||||
problem = self.exercise.generate(self.curriculum)
|
||||
rotation = problem["metadata"]["executed_parts"]["rotation"]
|
||||
self.assertLessEqual(rotation, max_rotation)
|
||||
self.assertGreaterEqual(rotation, 1) # Min rotation is 1
|
||||
|
||||
# Check item structure
|
||||
assert isinstance(item, dict)
|
||||
assert "question" in item
|
||||
assert "answer" in item
|
||||
assert "metadata" in item
|
||||
def test_word_count_ranges(self):
|
||||
"""Test that word counts are within expected ranges"""
|
||||
# Test all word count levels
|
||||
level_word_counts = {0: 5, 1: 10, 2: 20}
|
||||
|
||||
# Check metadata
|
||||
assert "rotation" in item["metadata"]
|
||||
assert "cipher_text" in item["metadata"]
|
||||
assert "clear_text" in item["metadata"]
|
||||
for level, max_words in level_word_counts.items():
|
||||
self.curriculum.set_attr_level("num_words", level)
|
||||
problem = self.exercise.generate(self.curriculum)
|
||||
clear_text = problem["metadata"]["executed_parts"]["clear_text"]
|
||||
word_count = len(clear_text.split())
|
||||
self.assertLessEqual(word_count, max_words)
|
||||
self.assertGreaterEqual(word_count, 3) # Min words is 3
|
||||
|
||||
# Verify rotation constraints
|
||||
rotation = item["metadata"]["rotation"]
|
||||
assert config.min_rotation <= rotation <= config.max_rotation
|
||||
class TestCaesarCipherComprehensive(unittest.TestCase):
|
||||
"""Comprehensive tests for Caesar cipher"""
|
||||
|
||||
# Verify text properties
|
||||
clear_text = item["metadata"]["clear_text"]
|
||||
words = clear_text.split()
|
||||
assert config.min_words <= len(words) <= config.max_words
|
||||
assert all(word.isupper() and word.isalpha() for word in words)
|
||||
def setUp(self):
|
||||
self.curriculum = CaesarCipherCurriculum()
|
||||
self.exercise = CaesarCipherExercise()
|
||||
self.rng = random.Random(42)
|
||||
self.curriculum.rng = self.rng
|
||||
|
||||
# Verify encryption
|
||||
cipher_text = item["metadata"]["cipher_text"]
|
||||
decrypted = dataset._caesar_encrypt(cipher_text, -rotation) # Decrypt by negative rotation
|
||||
assert decrypted == clear_text
|
||||
def test_text_case_styles(self):
|
||||
"""Test different text case styles"""
|
||||
case_styles = ["UPPER", "lower", "Mixed"]
|
||||
num_samples = 100 # Test with multiple samples to ensure we see all styles
|
||||
|
||||
# Test each level
|
||||
for level, expected_styles in enumerate(case_styles):
|
||||
self.curriculum.set_attr_level("text_case", level)
|
||||
styles_seen = set()
|
||||
|
||||
def test_caesar_cipher_iteration():
|
||||
"""Test that iteration respects dataset size"""
|
||||
config = CaesarCipherConfig(size=5, seed=42)
|
||||
dataset = CaesarCipherDataset(config)
|
||||
# Generate multiple problems to catch all possible styles
|
||||
for _ in range(num_samples):
|
||||
problem = self.exercise.generate(self.curriculum)
|
||||
text = problem["metadata"]["executed_parts"]["clear_text"]
|
||||
|
||||
items = list(dataset)
|
||||
assert len(items) == config.size
|
||||
# Determine the style of this text
|
||||
if text.isupper():
|
||||
styles_seen.add("UPPER")
|
||||
elif text.islower():
|
||||
styles_seen.add("lower")
|
||||
else:
|
||||
styles_seen.add("Mixed")
|
||||
|
||||
# Test multiple iterations yield same items
|
||||
assert items == list(dataset)
|
||||
# At each level, we should see all styles up to that level
|
||||
expected_styles_set = set(case_styles[:level + 1])
|
||||
self.assertEqual(styles_seen, expected_styles_set,
|
||||
f"At level {level}, expected to see styles {expected_styles_set} but saw {styles_seen}")
|
||||
|
||||
def test_template_variation(self):
|
||||
"""Test that different templates are used"""
|
||||
templates_seen = set()
|
||||
num_samples = 100
|
||||
|
||||
for _ in range(num_samples):
|
||||
problem = self.exercise.generate(self.curriculum)
|
||||
templates_seen.add(problem["question"].split(":")[0])
|
||||
|
||||
self.assertGreater(len(templates_seen), 1, "Not enough template variation")
|
||||
|
||||
def test_comprehensive_random_evaluation(self):
|
||||
"""Test random evaluation with various configurations and track statistics."""
|
||||
self.rng = random.Random(42) # Fixed seed for reproducibility
|
||||
self.curriculum.rng = self.rng
|
||||
|
||||
# Track statistics
|
||||
rotations_used = defaultdict(int)
|
||||
word_counts = defaultdict(int)
|
||||
case_styles = defaultdict(int)
|
||||
total_samples = 1000
|
||||
|
||||
# Generate test cases
|
||||
for _ in range(total_samples):
|
||||
# Set random attribute levels
|
||||
self.curriculum.set_attr_level("rotation", self.rng.randint(0, 4))
|
||||
self.curriculum.set_attr_level("num_words", self.rng.randint(0, 2))
|
||||
self.curriculum.set_attr_level("text_case", self.rng.randint(0, 2))
|
||||
|
||||
# Generate and evaluate a random problem
|
||||
problem = self.exercise.generate(self.curriculum)
|
||||
metadata = problem["metadata"]["executed_parts"]
|
||||
|
||||
# Track statistics
|
||||
rotations_used[metadata["rotation"]] += 1
|
||||
word_counts[len(metadata["clear_text"].split())] += 1
|
||||
|
||||
# Determine case style
|
||||
text = metadata["clear_text"]
|
||||
if text.isupper():
|
||||
case_styles["UPPER"] += 1
|
||||
elif text.islower():
|
||||
case_styles["lower"] += 1
|
||||
else:
|
||||
case_styles["Mixed"] += 1
|
||||
|
||||
# Verify encryption is correct
|
||||
cipher_text = metadata["cipher_text"]
|
||||
clear_text = metadata["clear_text"]
|
||||
rotation = metadata["rotation"]
|
||||
|
||||
# Verify each character is correctly encrypted
|
||||
for c1, c2 in zip(cipher_text, clear_text):
|
||||
if c1.isalpha():
|
||||
expected = chr((ord(c2.upper()) - ord('A') + rotation) % 26 + ord('A'))
|
||||
self.assertEqual(c1.upper(), expected)
|
||||
else:
|
||||
self.assertEqual(c1, c2)
|
||||
|
||||
# Print statistics
|
||||
print("\nRotations used:")
|
||||
for rotation, count in sorted(rotations_used.items()):
|
||||
print(f" Rotation {rotation}: {count}")
|
||||
|
||||
print("\nWord counts:")
|
||||
for words, count in sorted(word_counts.items()):
|
||||
print(f" {words} words: {count}")
|
||||
|
||||
print("\nCase styles:")
|
||||
for style, count in case_styles.items():
|
||||
print(f" {style}: {count}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue