diff --git a/README.md b/README.md
index 10956a51..44b59fd5 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,7 @@ Available dataset names (which can be used with `create_dataset()`):
 'polynomial_equations',
 'simple_equations',
 'base_conversion',
+'caesar_cipher',
 'letter_counting',
 'number_filtering',
 'number_sorting',
diff --git a/reasoning_gym/algorithmic/caesar_cipher.py b/reasoning_gym/algorithmic/caesar_cipher.py
new file mode 100644
--- /dev/null
+++ b/reasoning_gym/algorithmic/caesar_cipher.py
@@ -0,0 +1,104 @@
+"""Caesar cipher task generator"""
+
+import re
+from dataclasses import dataclass
+from random import Random
+from string import ascii_uppercase
+from typing import List, Optional
+
+from reasoning_gym.data import read_data_file
+from ..factory import ProceduralDataset, register_dataset
+
+
+@dataclass
+class CaesarCipherConfig:
+    """Configuration for Caesar cipher task generation"""
+
+    delimiter: str = "."  # Delimiter for splitting text into sentences
+    min_words: int = 3  # Minimum words per sentence
+    max_words: int = 20  # Maximum words per sentence
+    min_rotation: int = 1  # Minimum Caesar rotation
+    max_rotation: int = 25  # Maximum Caesar rotation
+    seed: Optional[int] = None
+    size: int = 500  # Virtual dataset size
+
+    def validate(self) -> None:
+        """Validate configuration parameters
+
+        Raises AssertionError when a word or rotation bound is out of range.
+        """
+        assert self.min_words > 0, "min_words must be positive"
+        assert self.max_words >= self.min_words, "max_words must be >= min_words"
+        assert 0 < self.min_rotation <= self.max_rotation < 26, "rotation must be in range [1,25]"
+
+
+class CaesarCipherDataset(ProceduralDataset):
+    """Generates Caesar cipher encryption/decryption tasks"""
+
+    def __init__(self, config: CaesarCipherConfig):
+        """Load the source text and collect the sentences usable as clear text
+
+        Raises ValueError when no sentence satisfies the configured limits.
+        """
+        super().__init__(config=config, seed=config.seed, size=config.size)
+
+        # Load and preprocess text
+        text = read_data_file("in_the_year_2889.txt")
+
+        # Split into sentences and filter
+        sentences = [s.strip() for s in text.split(config.delimiter) if s.strip()]
+
+        # Process each sentence
+        self.valid_sentences = []
+        for sentence in sentences:
+            # Keep only pure A-Z words: isalpha() alone also accepts non-ASCII
+            # letters, which a 26-letter rotation cannot encrypt
+            words = [w.upper() for w in sentence.split() if w.isascii() and w.isalpha()]
+            if self.config.min_words <= len(words) <= self.config.max_words:
+                self.valid_sentences.append(" ".join(words))
+
+        # Fail here with a clear message instead of an IndexError from
+        # rng.choice() on the first item access
+        if not self.valid_sentences:
+            raise ValueError("no sentences match the configured delimiter and word limits")
+
+    def _caesar_encrypt(self, text: str, rotation: int) -> str:
+        """Apply Caesar cipher encryption with given rotation
+
+        Only the letters A-Z are rotated; every other character (spaces,
+        digits, punctuation, lowercase and non-ASCII letters) is copied
+        through unchanged. A negative rotation undoes the positive one.
+        """
+        # % yields a non-negative result for a positive modulus, so negative
+        # rotations wrap to the equivalent forward shift
+        shift = rotation % len(ascii_uppercase)
+        rotated_alphabet = ascii_uppercase[shift:] + ascii_uppercase[:shift]
+        return text.translate(str.maketrans(ascii_uppercase, rotated_alphabet))
+
+    def __getitem__(self, idx: int) -> dict:
+        """Generate a single Caesar cipher task
+
+        The RNG is seeded from the dataset seed plus idx, so the same index
+        always yields the same item.
+        """
+        rng = Random(self.seed + idx)
+
+        # Select random sentence and rotation
+        sentence = rng.choice(self.valid_sentences)
+        rotation = rng.randint(self.config.min_rotation, self.config.max_rotation)
+
+        # Generate cipher text
+        cipher_text = self._caesar_encrypt(sentence, rotation)
+
+        return {
+            "question": f"Decrypt this Caesar cipher text: {cipher_text}",
+            "answer": sentence,
+            "metadata": {
+                "rotation": rotation,
+                "cipher_text": cipher_text,
+                "clear_text": sentence
+            }
+        }
+
+
+register_dataset("caesar_cipher", CaesarCipherDataset, CaesarCipherConfig)
diff --git a/tests/test_caesar_cipher.py b/tests/test_caesar_cipher.py
new file mode 100644
--- /dev/null
+++ b/tests/test_caesar_cipher.py
@@ -0,0 +1,111 @@
+"""Tests for Caesar cipher task generation"""
+
+import pytest
+
+from reasoning_gym.algorithmic.caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
+
+
+def test_caesar_cipher_config_validation():
+    """Test that invalid configs raise appropriate errors"""
+    with pytest.raises(AssertionError):
+        config = CaesarCipherConfig(min_words=0)
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = CaesarCipherConfig(min_words=10, max_words=5)
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = CaesarCipherConfig(min_rotation=0)
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = CaesarCipherConfig(max_rotation=26)
+        config.validate()
+
+
+def test_caesar_cipher_deterministic():
+    """Test that dataset generates same items with same seed"""
+    config = CaesarCipherConfig(seed=42, size=10)
+    dataset1 = CaesarCipherDataset(config)
+    dataset2 = CaesarCipherDataset(config)
+
+    for i in range(len(dataset1)):
+        assert dataset1[i] == dataset2[i]
+
+
+def test_caesar_cipher_encryption():
+    """Test the Caesar cipher encryption logic"""
+    config = CaesarCipherConfig(size=1, seed=42)
+    dataset = CaesarCipherDataset(config)
+
+    # Test with known rotation
+    text = "HELLO"
+    encrypted = dataset._caesar_encrypt(text, 1)
+    assert encrypted == "IFMMP"  # Each letter shifted by 1
+
+    # Test wrapping around Z
+    encrypted = dataset._caesar_encrypt("XYZ", 2)
+    assert encrypted == "ZAB"
+
+    # Test preserving spaces
+    encrypted = dataset._caesar_encrypt("HELLO WORLD", 1)
+    assert encrypted == "IFMMP XPSME"
+
+    # Test that characters outside A-Z pass through unchanged
+    encrypted = dataset._caesar_encrypt("abc \u00c9 42!", 5)
+    assert encrypted == "abc \u00c9 42!"
+
+
+def test_caesar_cipher_dataset_items():
+    """Test basic properties of generated items"""
+    config = CaesarCipherConfig(
+        min_words=3,
+        max_words=5,
+        min_rotation=1,
+        max_rotation=3,
+        size=10,
+        seed=42
+    )
+    dataset = CaesarCipherDataset(config)
+
+    for i in range(len(dataset)):
+        item = dataset[i]
+
+        # Check item structure
+        assert isinstance(item, dict)
+        assert "question" in item
+        assert "answer" in item
+        assert "metadata" in item
+
+        # Check metadata
+        assert "rotation" in item["metadata"]
+        assert "cipher_text" in item["metadata"]
+        assert "clear_text" in item["metadata"]
+
+        # Verify rotation constraints
+        rotation = item["metadata"]["rotation"]
+        assert config.min_rotation <= rotation <= config.max_rotation
+
+        # Verify text properties
+        clear_text = item["metadata"]["clear_text"]
+        words = clear_text.split()
+        assert config.min_words <= len(words) <= config.max_words
+        assert all(word.isupper() and word.isascii() and word.isalpha() for word in words)
+
+        # Verify encryption
+        cipher_text = item["metadata"]["cipher_text"]
+        decrypted = dataset._caesar_encrypt(cipher_text, -rotation)  # Decrypt by negative rotation
+        assert decrypted == clear_text
+
+
+def test_caesar_cipher_iteration():
+    """Test that iteration respects dataset size"""
+    config = CaesarCipherConfig(size=5, seed=42)
+    dataset = CaesarCipherDataset(config)
+
+    items = list(dataset)
+    assert len(items) == config.size
+
+    # Test multiple iterations yield same items
+    assert items == list(dataset)