diff --git a/reasoning_gym/algorithmic/caesar_cipher.py b/reasoning_gym/algorithmic/caesar_cipher.py index 2b457909..01b8f4ed 100644 --- a/reasoning_gym/algorithmic/caesar_cipher.py +++ b/reasoning_gym/algorithmic/caesar_cipher.py @@ -7,6 +7,7 @@ from string import ascii_uppercase from typing import List, Optional from reasoning_gym.data import read_data_file + from ..factory import ProceduralDataset, register_dataset @@ -37,10 +38,10 @@ class CaesarCipherDataset(ProceduralDataset): # Load and preprocess text text = read_data_file("in_the_year_2889.txt") - + # Split into sentences and filter sentences = [s.strip() for s in text.split(config.delimiter) if s.strip()] - + # Process each sentence self.valid_sentences = [] for sentence in sentences: @@ -55,7 +56,7 @@ class CaesarCipherDataset(ProceduralDataset): for char in text: if char.isalpha(): # Convert to 0-25 range, rotate, convert back to ASCII - base = ord('A') + base = ord("A") rotated = (ord(char) - base + rotation) % 26 result.append(chr(base + rotated)) else: @@ -65,22 +66,18 @@ class CaesarCipherDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: """Generate a single Caesar cipher task""" rng = Random(self.seed + idx) - + # Select random sentence and rotation sentence = rng.choice(self.valid_sentences) rotation = rng.randint(self.config.min_rotation, self.config.max_rotation) - + # Generate cipher text cipher_text = self._caesar_encrypt(sentence, rotation) - + return { "question": f"Decrypt this Caesar cipher text: {cipher_text}", "answer": sentence, - "metadata": { - "rotation": rotation, - "cipher_text": cipher_text, - "clear_text": sentence - } + "metadata": {"rotation": rotation, "cipher_text": cipher_text, "clear_text": sentence}, } diff --git a/reasoning_gym/algorithmic/letter_jumble.py b/reasoning_gym/algorithmic/letter_jumble.py index 59f1836b..9919ccf0 100644 --- a/reasoning_gym/algorithmic/letter_jumble.py +++ b/reasoning_gym/algorithmic/letter_jumble.py @@ -6,6 +6,7 @@ from random import Random from typing import List, Optional from reasoning_gym.data import read_data_file + from ..factory import ProceduralDataset, register_dataset @@ -31,7 +32,9 @@ class LetterJumbleConfig: assert self.max_words >= self.min_words, "max_words must be >= min_words" assert 0 <= self.min_corruption_level <= 1, "min_corruption_level must be in [0,1]" assert 0 <= self.max_corruption_level <= 1, "max_corruption_level must be in [0,1]" - assert self.max_corruption_level >= self.min_corruption_level, "max_corruption_level must be >= min_corruption_level" + assert ( + self.max_corruption_level >= self.min_corruption_level + ), "max_corruption_level must be >= min_corruption_level" class LetterJumbleDataset(ProceduralDataset): @@ -44,50 +47,47 @@ class LetterJumbleDataset(ProceduralDataset): text = read_data_file("in_the_year_2889.txt") # Extract words and filter by length self.words = [ - word for word in re.findall(r"\b\w+\b", text) - if self.config.min_word_len <= len(word) <= self.config.max_word_len - and word.isalpha() + word + for word in re.findall(r"\b\w+\b", text) + if self.config.min_word_len <= len(word) <= self.config.max_word_len and word.isalpha() ] def _scramble_word(self, word: str, corruption_level: float, rng: Random) -> str: """Scramble a word by swapping random pairs of characters""" if len(word) < 2: # Can't scramble 1-character words return word - + word = list(word) num_swaps = max(1, int(len(word) * corruption_level)) # Ensure at least one swap - + for _ in range(num_swaps): # Pick two different random positions pos1, pos2 = rng.sample(range(len(word)), 2) # Swap characters word[pos1], word[pos2] = word[pos2], word[pos1] - + return "".join(word) def __getitem__(self, idx: int) -> dict: """Generate a single word jumbling task""" rng = Random(self.seed + idx) - + # Select number of words and corruption level num_words = rng.randint(self.config.min_words, self.config.max_words) corruption_level = rng.uniform(self.config.min_corruption_level, self.config.max_corruption_level) - + # Select words based on configuration if self.config.consecutive_words: # Select consecutive words from a random starting position start_idx = rng.randint(0, len(self.words) - num_words) - selected_words = self.words[start_idx:start_idx + num_words] + selected_words = self.words[start_idx : start_idx + num_words] else: # Select random words selected_words = rng.sample(self.words, num_words) - + # Scramble each word - scrambled_words = [ - self._scramble_word(word, corruption_level, rng) - for word in selected_words - ] - + scrambled_words = [self._scramble_word(word, corruption_level, rng) for word in selected_words] + return { "question": f"Unscramble these words: {' '.join(scrambled_words)}", "answer": " ".join(selected_words), @@ -95,8 +95,8 @@ class LetterJumbleDataset(ProceduralDataset): "num_words": num_words, "corruption_level": corruption_level, "scrambled_words": scrambled_words, - "original_words": selected_words - } + "original_words": selected_words, + }, } diff --git a/reasoning_gym/games/countdown.py b/reasoning_gym/games/countdown.py index b1715a50..4721844d 100644 --- a/reasoning_gym/games/countdown.py +++ b/reasoning_gym/games/countdown.py @@ -78,11 +78,11 @@ class CountdownDataset(ProceduralDataset): def _generate_candidate_expression(self, rng: Random, num_terms: int) -> Tuple[sympy.Expr, List[int], List[Symbol]]: """Generate a candidate expression with random numbers and operators - + Args: rng: Random number generator num_terms: Number of terms to include - + Returns: Tuple of (sympy expression, list of numbers, list of symbols) """ @@ -139,23 +139,23 @@ class CountdownDataset(ProceduralDataset): for attempt in range(max_attempts): try: expr, numbers, syms = self._generate_candidate_expression(rng, num_terms) - + # Substitute actual numbers to get target subs = {sym: num for sym, num in zip(syms, numbers)} target = int(expr.subs(subs)) - + # Convert to string expression expr_str = str(expr) for i, sym in enumerate(syms): expr_str = expr_str.replace(str(sym), str(numbers[i])) - + # Ensure target is within bounds if self.config.min_target <= target <= self.config.max_target: return expr_str, numbers, target - + except (ValueError, ZeroDivisionError): continue - + raise ValueError(f"Failed to generate valid expression after {max_attempts} attempts") diff --git a/tests/test_caesar_cipher.py b/tests/test_caesar_cipher.py index 2ad86bbb..fa572d8d 100644 --- a/tests/test_caesar_cipher.py +++ b/tests/test_caesar_cipher.py @@ -38,7 +38,7 @@ def test_caesar_cipher_encryption(): """Test the Caesar cipher encryption logic""" config = CaesarCipherConfig(size=1, seed=42) dataset = CaesarCipherDataset(config) - + # Test with known rotation text = "HELLO" encrypted = dataset._caesar_encrypt(text, 1) @@ -55,34 +55,27 @@ def test_caesar_cipher_encryption(): def test_caesar_cipher_dataset_items(): """Test basic properties of generated items""" - config = CaesarCipherConfig( - min_words=3, - max_words=5, - min_rotation=1, - max_rotation=3, - size=10, - seed=42 - ) + config = CaesarCipherConfig(min_words=3, max_words=5, min_rotation=1, max_rotation=3, size=10, seed=42) dataset = CaesarCipherDataset(config) for i in range(len(dataset)): item = dataset[i] - + # Check item structure assert isinstance(item, dict) assert "question" in item assert "answer" in item assert "metadata" in item - + # Check metadata assert "rotation" in item["metadata"] assert "cipher_text" in item["metadata"] assert "clear_text" in item["metadata"] - + # Verify rotation constraints rotation = item["metadata"]["rotation"] assert config.min_rotation <= rotation <= config.max_rotation - + # Verify text properties clear_text = item["metadata"]["clear_text"] words = clear_text.split() diff --git a/tests/test_countdown.py b/tests/test_countdown.py index 365273a6..e426caf2 100644 --- a/tests/test_countdown.py +++ b/tests/test_countdown.py @@ -61,7 +61,7 @@ def test_countdown_game_items(): # Verify all numbers are within config range assert all(config.min_value <= n <= config.max_value for n in item["metadata"]["numbers"]) - + # Verify expression evaluates correctly expr = item["metadata"]["expression"] try: diff --git a/tests/test_letter_jumble.py b/tests/test_letter_jumble.py index b659d2ad..8203f2f0 100644 --- a/tests/test_letter_jumble.py +++ b/tests/test_letter_jumble.py @@ -1,8 +1,9 @@ """Tests for letter jumbling task generation""" -import pytest from random import Random +import pytest + from reasoning_gym.algorithmic.letter_jumble import LetterJumbleConfig, LetterJumbleDataset @@ -45,15 +46,15 @@ def test_letter_jumble_scrambling(): min_corruption_level=0.5, max_corruption_level=0.5, size=1, - seed=42 + seed=42, ) dataset = LetterJumbleDataset(config) - + # Test with known word word = "testing" rng = Random(42) scrambled = dataset._scramble_word(word, 0.5, rng) - + # Verify scrambled word: # - Has same length as original assert len(scrambled) == len(word) @@ -73,35 +74,35 @@ def test_letter_jumble_dataset_items(): min_corruption_level=0.1, max_corruption_level=0.3, size=50, - seed=42 + seed=42, ) dataset = LetterJumbleDataset(config) for i in range(len(dataset)): item = dataset[i] - + # Check item structure assert isinstance(item, dict) assert "question" in item assert "answer" in item assert "metadata" in item - + # Check metadata metadata = item["metadata"] assert "num_words" in metadata assert "corruption_level" in metadata assert "scrambled_words" in metadata assert "original_words" in metadata - + # Verify word counts num_words = metadata["num_words"] assert config.min_words <= num_words <= config.max_words assert len(metadata["scrambled_words"]) == num_words assert len(metadata["original_words"]) == num_words - + # Verify corruption level assert config.min_corruption_level <= metadata["corruption_level"] <= config.max_corruption_level - + # Verify word properties for word in metadata["original_words"]: assert config.min_word_len <= len(word) <= config.max_word_len