diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py
index 1b509970..8224a019 100644
--- a/reasoning_gym/algorithmic/__init__.py
+++ b/reasoning_gym/algorithmic/__init__.py
@@ -12,6 +12,7 @@ from .letter_counting import LetterCountingConfig, LetterCountingDataset
 from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset
 from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
 from .number_sorting import NumberSortingConfig, NumberSortingDataset
+from .palindrome_generation import PalindromeConfig, PalindromeDataset
 from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
 from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
 from .word_ladder import WordLadderConfig, WordLadderDataset
@@ -42,4 +43,6 @@ __all__ = [
     "TextTransformation",
     "WordLadderConfig",
     "WordLadderDataset",
+    "PalindromeConfig",
+    "PalindromeDataset",
 ]
diff --git a/reasoning_gym/algorithmic/palindrome_generation.py b/reasoning_gym/algorithmic/palindrome_generation.py
new file mode 100644
index 00000000..6d29e425
--- /dev/null
+++ b/reasoning_gym/algorithmic/palindrome_generation.py
@@ -0,0 +1,90 @@
+import random
+import string
+from dataclasses import dataclass
+from typing import Optional
+
+from ..factory import ProceduralDataset, register_dataset
+
+
+@dataclass
+class PalindromeConfig:
+    """
+    Configuration for the palindrome task.
+
+    - min_length: Minimum length of the palindrome.
+    - max_length: Maximum length of the palindrome.
+    - seed: Optional seed for reproducibility.
+    - size: Number of palindrome samples in the virtual dataset.
+    """
+
+    min_length: int = 3
+    max_length: int = 10
+    seed: Optional[int] = None
+    size: int = 50
+
+    def validate(self) -> None:
+        """Validate configuration parameters; raises AssertionError if they are inconsistent."""
+        assert self.min_length >= 1, "min_length must be >= 1"
+        assert self.max_length >= self.min_length, "max_length must be >= min_length"
+
+
+class PalindromeDataset(ProceduralDataset):
+    """
+    Generates letter-rearrangement tasks whose letters can always be assembled into a palindrome.
+    """
+
+    def __init__(self, config: PalindromeConfig):
+        super().__init__(config=config, seed=config.seed, size=config.size)
+
+    def __getitem__(self, idx: int) -> dict:
+        """
+        Generate a single palindrome task.
+
+        All randomness comes from an RNG seeded with `self.seed + idx`, so a given
+        seed and index always produce the same item.
+
+        Returns:
+            dict with:
+            - "question": Prompt listing the shuffled letters.
+            - "answer": One palindrome built from exactly those letters.
+            - "metadata": The shuffled letters and the generated palindrome.
+        """
+        rng = random.Random(self.seed + idx)
+        length = rng.randint(self.config.min_length, self.config.max_length)
+        letters = self._generate_palindrome_letters(rng, length)
+        # Random permutation of the letters; it may coincide with the palindrome order.
+        scrambled_letters = rng.sample(letters, len(letters))
+        palindrome = self._assemble_palindrome(letters)
+
+        question_str = (
+            f"Rearrange these letters to form a palindrome (a word, phrase, or sequence that remains the same in reverse): {', '.join(scrambled_letters)}\n\n"
+            "Example format:\n"
+            "racecar\n\n"
+            "What is your palindrome?"
+        )
+
+        return {
+            "question": question_str,
+            "answer": palindrome,
+            "metadata": {
+                "letters": scrambled_letters,
+                "generated_palindrome": palindrome,
+            },
+        }
+
+    def _generate_palindrome_letters(self, rng: random.Random, length: int) -> list[str]:
+        """Return `length` random lowercase letters arranged in palindrome order."""
+        half_length = length // 2
+        letters = rng.choices(string.ascii_lowercase, k=half_length)
+        if length % 2 == 1:
+            # Odd lengths get one unpaired letter in the middle.
+            middle_letter = rng.choice(string.ascii_lowercase)
+            return letters + [middle_letter] + letters[::-1]
+        return letters + letters[::-1]
+
+    def _assemble_palindrome(self, letters: list[str]) -> str:
+        """Join letters that are already in palindrome order into a string."""
+        return "".join(letters)
+
+
+register_dataset("palindrome", PalindromeDataset, PalindromeConfig)
diff --git a/tests/test_palindrome.py b/tests/test_palindrome.py
new file mode 100644
index 00000000..3af02379
--- /dev/null
+++ b/tests/test_palindrome.py
@@ -0,0 +1,60 @@
+import pytest
+
+from reasoning_gym.algorithmic.palindrome_generation import PalindromeConfig, PalindromeDataset
+
+
+def test_palindrome_config_validation():
+    """Test that invalid configs raise appropriate errors"""
+    with pytest.raises(AssertionError):
+        config = PalindromeConfig(min_length=0)  # Too short
+        config.validate()
+
+    with pytest.raises(AssertionError):
+        config = PalindromeConfig(min_length=5, max_length=3)  # Invalid range
+        config.validate()
+
+
+def test_palindrome_deterministic():
+    """Test that dataset generates same items with same seed"""
+    config = PalindromeConfig(seed=42, size=10)
+    dataset1 = PalindromeDataset(config)
+    dataset2 = PalindromeDataset(config)
+
+    for i in range(len(dataset1)):
+        assert dataset1[i] == dataset2[i]
+
+
+def test_palindrome_items():
+    """Test basic properties of generated items"""
+    config = PalindromeConfig(min_length=3, max_length=7, size=10, seed=42)
+    dataset = PalindromeDataset(config)
+
+    for item in dataset:
+        assert isinstance(item, dict)
+        assert "question" in item
+        assert "answer" in item
+        assert "metadata" in item
+
+        # Check metadata contains required fields
+        assert "letters" in item["metadata"]
+        assert "generated_palindrome" in item["metadata"]
+
+        # Verify answer is a palindrome of a configured length
+        palindrome = item["answer"]
+        assert palindrome == palindrome[::-1], f"{palindrome} is not a palindrome"
+        assert config.min_length <= len(palindrome) <= config.max_length
+
+
+def test_palindrome_randomization():
+    """Test that the presented letters are a rearrangement of the palindrome"""
+    config = PalindromeConfig(min_length=4, max_length=4, size=10, seed=42)
+    dataset = PalindromeDataset(config)
+
+    for item in dataset:
+        letters = item["metadata"]["letters"]
+        palindrome = item["metadata"]["generated_palindrome"]
+
+        # Same multiset of letters; the order may or may not differ, so it is not asserted
+        assert sorted(letters) == sorted(palindrome)
+        # The question must present exactly the letters recorded in the metadata
+        assert ", ".join(letters) in item["question"]