diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 20f3e90c..2400f49c 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -24,10 +24,10 @@ from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset +from .string_insertion import StringInsertionConfig, StringInsertionDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset -from .string_insertion import StringInsertionConfig, StringInsertionDataset __all__ = [ "SpellBackwardConfig", diff --git a/reasoning_gym/algorithmic/string_insertion.py b/reasoning_gym/algorithmic/string_insertion.py index fb1d35e0..b217ed76 100644 --- a/reasoning_gym/algorithmic/string_insertion.py +++ b/reasoning_gym/algorithmic/string_insertion.py @@ -9,7 +9,6 @@ from typing import Optional from ..factory import ProceduralDataset, register_dataset - QUESTION_TEMPLATE = """Given a string consisting of characters A, B, C, D, and E, your job is to insert a character according to the following pattern: 1. If there is a substring ABCD in the string, insert the character A after the substring. 2. If there is a substring BCDE in the string, insert the character B after the substring. @@ -22,7 +21,7 @@ Once you have inserted a character, you have to skip over the substring and the Example - Input: DDABCDEEDEAB - Output: DDABCDAEEDEABD -- Explanation: +- Explanation: - Theere are two inserted characters: DDABCD[A]EEDEAB[D] (shown in square brackets) - First, we insert A after ABCD. - Even though with the newly inserted 'A' we can obtain the substring BCD[A], we can't use it to insert another character. @@ -37,7 +36,7 @@ class StringInsertionConfig: """Configuration for String Insertion dataset generation""" min_string_length: int = 5 # Minimum string length - max_string_length: int = 20 # Maximum string length + max_string_length: int = 20 # Maximum string length size: int = 500 # Virtual dataset size seed: Optional[int] = None @@ -47,12 +46,13 @@ class StringInsertionConfig: assert 5 <= self.min_string_length, "Minimum string length should be at least 5" assert self.min_string_length <= self.max_string_length, "Minimum string length should be less than maximum" + class StringInsertionDataset(ProceduralDataset): """Generates String Insertion exercises with configurable difficulty""" def __init__(self, config: StringInsertionConfig): super().__init__(config=config, seed=config.seed, size=config.size) - self.vocabulary = ['A', 'B', 'C', 'D', 'E'] + self.vocabulary = ["A", "B", "C", "D", "E"] self.insertion_rules = [ ("ABCD", "A"), ("BCDE", "B"), @@ -68,7 +68,7 @@ class StringInsertionDataset(ProceduralDataset): while i < len(string): inserted = False for pattern, char in self.insertion_rules: - substring = string[i:i+len(pattern)] + substring = string[i : i + len(pattern)] if substring == pattern: output.append(substring + char) i += len(pattern) @@ -82,7 +82,7 @@ class StringInsertionDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: """Generate a single String Insertion question""" rng = Random(self.seed + idx) - + string_length = rng.randint(self.config.min_string_length, self.config.max_string_length) string = [rng.choice(self.vocabulary) for _ in range(string_length)] diff --git a/tests/test_string_insertion.py b/tests/test_string_insertion.py index 746ff5c5..12225954 100644 --- a/tests/test_string_insertion.py +++ b/tests/test_string_insertion.py @@ -7,13 +7,13 @@ from reasoning_gym.algorithmic.string_insertion import StringInsertionConfig, St def test_string_insertion_config_validation(): """Test that invalid configs raise appropriate errors""" - + for field in ["min_string_length", "max_string_length"]: for i in range(-1, 5): with pytest.raises(AssertionError): - config = StringInsertionConfig(**{field: i}) # [-1, 4] is invalid + config = StringInsertionConfig(**{field: i}) # [-1, 4] is invalid config.validate() - + def test_string_insertion_dataset_deterministic(): """Test that dataset generates same items with same seed""" @@ -67,7 +67,7 @@ def test_string_insertion_answer(): config = StringInsertionConfig(seed=42) dataset = StringInsertionDataset(config) - # No pattern match + # No pattern match assert dataset._get_answer("AAAAAAA") == "AAAAAAA" assert dataset._get_answer("ADBEEBEA") == "ADBEEBEA" assert dataset._get_answer("ADEACA") == "ADEACA" @@ -91,4 +91,4 @@ def test_string_insertion_answer(): assert dataset._get_answer("AABCDEEEEEEEBCDEAAAAA") == "AABCDAEEEEEEEBCDEBAAAAA" # No reuse of newly inserted characters - assert dataset._get_answer("ABCDBCD") == "ABCDABCD" \ No newline at end of file + assert dataset._get_answer("ABCDBCD") == "ABCDABCD"