diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py
index fe0a2dc2..20f3e90c 100644
--- a/reasoning_gym/algorithmic/__init__.py
+++ b/reasoning_gym/algorithmic/__init__.py
@@ -27,6 +27,7 @@ from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset
 from .word_ladder import WordLadderConfig, WordLadderDataset
 from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
 from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
+from .string_insertion import StringInsertionConfig, StringInsertionDataset
 
 __all__ = [
     "SpellBackwardConfig",
@@ -72,4 +73,6 @@ __all__ = [
     "ABDataset",
     "CountPrimesConfig",
     "CountPrimesDataset",
+    "StringInsertionConfig",
+    "StringInsertionDataset",
 ]
diff --git a/reasoning_gym/algorithmic/string_insertion.py b/reasoning_gym/algorithmic/string_insertion.py
new file mode 100644
index 00000000..fb1d35e0
--- /dev/null
+++ b/reasoning_gym/algorithmic/string_insertion.py
@@ -0,0 +1,107 @@
+"""Insert into string according to a pattern
+
+https://github.com/yongchao98/CodeSteer-v1.0/blob/main/create_dataset/create_dataset_string_insertion.py
+"""
+
+from dataclasses import dataclass
+from random import Random
+from typing import Optional
+
+from ..factory import ProceduralDataset, register_dataset
+
+
+QUESTION_TEMPLATE = """Given a string consisting of characters A, B, C, D, and E, your job is to insert a character according to the following pattern:
+1. If there is a substring ABCD in the string, insert the character A after the substring.
+2. If there is a substring BCDE in the string, insert the character B after the substring.
+3. If there is a substring CDEA in the string, insert the character C after the substring.
+4. If there is a substring DEAB in the string, insert the character D after the substring.
+5. If there is a substring EABC in the string, insert the character E after the substring.
+
+Once you have inserted a character, you have to skip over the substring and the inserted character and continue the search from the next character.
+
+Example
+- Input: DDABCDEEDEAB
+- Output: DDABCDAEEDEABD
+- Explanation:
+    - There are two inserted characters: DDABCD[A]EEDEAB[D] (shown in square brackets)
+    - First, we insert A after ABCD.
+    - Even though with the newly inserted 'A' we can obtain the substring BCD[A], we can't use it to insert another character.
+    - Lastly, we insert D after DEAB.
+
+Given the following string, provide the answer after inserting the characters according to the pattern: {string}
+"""
+
+
+@dataclass
+class StringInsertionConfig:
+    """Configuration for String Insertion dataset generation"""
+
+    min_string_length: int = 5  # Minimum string length
+    max_string_length: int = 20  # Maximum string length
+
+    size: int = 500  # Virtual dataset size
+    seed: Optional[int] = None
+
+    def validate(self):
+        """Validate configuration parameters"""
+        assert 5 <= self.min_string_length, "Minimum string length should be at least 5"
+        assert self.min_string_length <= self.max_string_length, "Minimum string length should not exceed maximum"
+
+
+class StringInsertionDataset(ProceduralDataset):
+    """Generates String Insertion exercises with configurable difficulty"""
+
+    def __init__(self, config: StringInsertionConfig):
+        super().__init__(config=config, seed=config.seed, size=config.size)
+        self.vocabulary = ["A", "B", "C", "D", "E"]
+        # (pattern, character to insert) pairs, tried in this order at every position
+        self.insertion_rules = [
+            ("ABCD", "A"),
+            ("BCDE", "B"),
+            ("CDEA", "C"),
+            ("DEAB", "D"),
+            ("EABC", "E"),
+        ]
+
+    def _get_answer(self, string: str) -> str:
+        """Apply insertion rules to a string
+
+        The string is scanned left to right. When a pattern matches at the current
+        position, the pattern plus its inserted character are emitted and the scan
+        resumes right after the pattern, so inserted characters never form new matches.
+        """
+        output = []
+        i = 0
+        while i < len(string):
+            inserted = False
+            for pattern, char in self.insertion_rules:
+                substring = string[i : i + len(pattern)]
+                if substring == pattern:
+                    output.append(substring + char)
+                    i += len(pattern)
+                    inserted = True
+                    break
+            if not inserted:
+                output.append(string[i])
+                i += 1
+        return "".join(output)
+
+    def __getitem__(self, idx: int) -> dict:
+        """Generate a single String Insertion question"""
+        rng = Random(self.seed + idx)
+
+        string_length = rng.randint(self.config.min_string_length, self.config.max_string_length)
+        # Build a real str: a list of characters never compares equal to a str pattern
+        # in _get_answer, and would be rendered as a Python list in the question text.
+        string = "".join(rng.choice(self.vocabulary) for _ in range(string_length))
+
+        answer = self._get_answer(string)
+
+        return {
+            "question": QUESTION_TEMPLATE.format(string=string),
+            "answer": str(answer),
+            "metadata": {"string": string, "solution": answer},
+        }
+
+
+register_dataset("string_insertion", StringInsertionDataset, StringInsertionConfig)
diff --git a/tests/test_string_insertion.py b/tests/test_string_insertion.py
new file mode 100644
index 00000000..746ff5c5
--- /dev/null
+++ b/tests/test_string_insertion.py
@@ -0,0 +1,101 @@
+"""Tests for String Insertion questions generation"""
+
+import pytest
+
+from reasoning_gym.algorithmic.string_insertion import StringInsertionConfig, StringInsertionDataset
+
+
+def test_string_insertion_config_validation():
+    """Test that invalid configs raise appropriate errors"""
+
+    for field in ["min_string_length", "max_string_length"]:
+        for i in range(-1, 5):
+            with pytest.raises(AssertionError):
+                config = StringInsertionConfig(**{field: i})  # [-1, 4] is invalid
+                config.validate()
+
+
+def test_string_insertion_dataset_deterministic():
+    """Test that dataset generates same items with same seed"""
+    config = StringInsertionConfig(seed=42, size=10)
+    dataset1 = StringInsertionDataset(config)
+    dataset2 = StringInsertionDataset(config)
+
+    for i in range(len(dataset1)):
+        assert dataset1[i] == dataset2[i]
+
+
+def test_string_insertion_dataset_items():
+    """Test basic properties of generated items"""
+    config = StringInsertionConfig(min_string_length=5, max_string_length=30, size=10, seed=42)
+    dataset = StringInsertionDataset(config)
+
+    for i in range(len(dataset)):
+        item = dataset[i]
+        # Check item structure
+        assert isinstance(item, dict)
+        assert "question" in item
+        assert "answer" in item
+        assert "metadata" in item
+
+        # Check metadata
+        assert "string" in item["metadata"]
+        assert "solution" in item["metadata"]
+
+        string = item["metadata"]["string"]
+        solution = item["metadata"]["solution"]
+
+        # Verify string dimensions
+        assert 5 <= len(string) <= 30
+        assert len(string) <= len(solution)
+
+        # Generated strings must be real strings, otherwise no rule can ever match
+        assert isinstance(string, str)
+        assert string in item["question"]
+        assert item["answer"] == solution == dataset._get_answer(string)
+        has_match = any(pattern in string for pattern, _ in dataset.insertion_rules)
+        assert (len(solution) > len(string)) == has_match
+
+
+def test_string_insertion_dataset_iteration():
+    """Test that iteration respects dataset size"""
+    config = StringInsertionConfig(size=5, seed=42)
+    dataset = StringInsertionDataset(config)
+
+    items = list(dataset)
+    assert len(items) == config.size
+
+    # Test multiple iterations yield same items
+    assert items == list(dataset)
+
+
+def test_string_insertion_answer():
+    """Test the _get_answer method"""
+    config = StringInsertionConfig(seed=42)
+    dataset = StringInsertionDataset(config)
+
+    # No pattern match
+    assert dataset._get_answer("AAAAAAA") == "AAAAAAA"
+    assert dataset._get_answer("ADBEEBEA") == "ADBEEBEA"
+    assert dataset._get_answer("ADEACA") == "ADEACA"
+
+    # Insert A after ABCD
+    assert dataset._get_answer("ABCDE") == "ABCDAE"
+
+    # Insert B after BCDE
+    assert dataset._get_answer("AEBCDEC") == "AEBCDEBC"
+
+    # Insert C after CDEA
+    assert dataset._get_answer("BBACDEAC") == "BBACDEACC"
+
+    # Insert D after DEAB
+    assert dataset._get_answer("BAAABDEAB") == "BAAABDEABD"
+
+    # Insert E after EABC
+    assert dataset._get_answer("EABCBCBC") == "EABCEBCBC"
+
+    # Multiple insertions
+    assert dataset._get_answer("AABCDEEEEEEEBCDEAAAAA") == "AABCDAEEEEEEEBCDEBAAAAA"
+
+    # No reuse of newly inserted characters
+    assert dataset._get_answer("ABCDBCD") == "ABCDABCD"