diff --git a/reasoning_gym/algorithmic/unscramble_words.py b/reasoning_gym/algorithmic/unscramble_words.py index 87febd7f..0dd4521d 100644 --- a/reasoning_gym/algorithmic/unscramble_words.py +++ b/reasoning_gym/algorithmic/unscramble_words.py @@ -19,6 +19,7 @@ class UnscrambleWordsConfig: max_words: int = 20 # Maximum words per task min_corruption_level: float = 0.1 # Minimum fraction of characters to swap max_corruption_level: float = 0.9 # Maximum fraction of characters to swap + consecutive_words: bool = True # Whether to select consecutive words from text seed: Optional[int] = None size: int = 500 # Virtual dataset size @@ -54,7 +55,7 @@ class UnscrambleWordsDataset(ProceduralDataset): return word word = list(word) - num_swaps = int(len(word) * corruption_level) + num_swaps = max(1, int(len(word) * corruption_level)) # Ensure at least one swap for _ in range(num_swaps): # Pick two different random positions @@ -72,8 +73,14 @@ class UnscrambleWordsDataset(ProceduralDataset): num_words = rng.randint(self.config.min_words, self.config.max_words) corruption_level = rng.uniform(self.config.min_corruption_level, self.config.max_corruption_level) - # Select random words - selected_words = rng.sample(self.words, num_words) + # Select words based on configuration + if self.config.consecutive_words: + # Select consecutive words from a random starting position + start_idx = rng.randint(0, len(self.words) - num_words) + selected_words = self.words[start_idx:start_idx + num_words] + else: + # Select random words + selected_words = rng.sample(self.words, num_words) # Scramble each word scrambled_words = [