From a79d3d06f262921dabcfe38668c343265c98b21a Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Wed, 12 Feb 2025 15:18:51 +0100 Subject: [PATCH 1/3] string insertion --- reasoning_gym/algorithmic/__init__.py | 3 + reasoning_gym/algorithmic/string_insertion.py | 98 +++++++++++++++++++ tests/test_string_insertion.py | 94 ++++++++++++++++++ 3 files changed, 195 insertions(+) create mode 100644 reasoning_gym/algorithmic/string_insertion.py create mode 100644 tests/test_string_insertion.py diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index fe0a2dc2..20f3e90c 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -27,6 +27,7 @@ from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset +from .string_insertion import StringInsertionConfig, StringInsertionDataset __all__ = [ "SpellBackwardConfig", @@ -72,4 +73,6 @@ __all__ = [ "ABDataset", "CountPrimesConfig", "CountPrimesDataset", + "StringInsertionConfig", + "StringInsertionDataset", ] diff --git a/reasoning_gym/algorithmic/string_insertion.py b/reasoning_gym/algorithmic/string_insertion.py new file mode 100644 index 00000000..fb1d35e0 --- /dev/null +++ b/reasoning_gym/algorithmic/string_insertion.py @@ -0,0 +1,98 @@ +"""Insert into string according to a pattern + +https://github.com/yongchao98/CodeSteer-v1.0/blob/main/create_dataset/create_dataset_string_insertion.py +""" + +from dataclasses import dataclass +from random import Random +from typing import Optional + +from ..factory import ProceduralDataset, register_dataset + + +QUESTION_TEMPLATE = """Given a string consisting of characters A, B, C, D, and E, your job is to insert a character according to the following pattern: +1. If there is a substring ABCD in the string, insert the character A after the substring. +2. If there is a substring BCDE in the string, insert the character B after the substring. +3. If there is a substring CDEA in the string, insert the character C after the substring. +4. If there is a substring DEAB in the string, insert the character D after the substring. +5. If there is a substring EABC in the string, insert the character E after the substring. + +Once you have inserted a character, you have to skip over the substring and the inserted character and continue the search from the next character. + +Example +- Input: DDABCDEEDEAB +- Output: DDABCDAEEDEABD +- Explanation: + - Theere are two inserted characters: DDABCD[A]EEDEAB[D] (shown in square brackets) + - First, we insert A after ABCD. + - Even though with the newly inserted 'A' we can obtain the substring BCD[A], we can't use it to insert another character. + - Lastly, we insert D after DEAB. + +Given the following string, provide the answer after inserting the characters according to the pattern: {string} +""" + + +@dataclass +class StringInsertionConfig: + """Configuration for String Insertion dataset generation""" + + min_string_length: int = 5 # Minimum string length + max_string_length: int = 20 # Maximum string length + + size: int = 500 # Virtual dataset size + seed: Optional[int] = None + + def validate(self): + """Validate configuration parameters""" + assert 5 <= self.min_string_length, "Minimum string length should be at least 5" + assert self.min_string_length <= self.max_string_length, "Minimum string length should be less than maximum" + +class StringInsertionDataset(ProceduralDataset): + """Generates String Insertion exercises with configurable difficulty""" + + def __init__(self, config: StringInsertionConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + self.vocabulary = ['A', 'B', 'C', 'D', 'E'] + self.insertion_rules = [ + ("ABCD", "A"), + ("BCDE", "B"), + ("CDEA", "C"), + ("DEAB", "D"), + ("EABC", "E"), + ] + + def _get_answer(self, string: str) -> str: + """Apply insertion rules to a string""" + output = [] + i = 0 + while i < len(string): + inserted = False + for pattern, char in self.insertion_rules: + substring = string[i:i+len(pattern)] + if substring == pattern: + output.append(substring + char) + i += len(pattern) + inserted = True + break + if not inserted: + output.append(string[i]) + i += 1 + return "".join(output) + + def __getitem__(self, idx: int) -> dict: + """Generate a single String Insertion question""" + rng = Random(self.seed + idx) + + string_length = rng.randint(self.config.min_string_length, self.config.max_string_length) + string = [rng.choice(self.vocabulary) for _ in range(string_length)] + + answer = self._get_answer(string) + + return { + "question": QUESTION_TEMPLATE.format(string=string), + "answer": str(answer), + "metadata": {"string": string, "solution": answer}, + } + + +register_dataset("string_insertion", StringInsertionDataset, StringInsertionConfig) diff --git a/tests/test_string_insertion.py b/tests/test_string_insertion.py new file mode 100644 index 00000000..746ff5c5 --- /dev/null +++ b/tests/test_string_insertion.py @@ -0,0 +1,94 @@ +"""Tests for String Insertion questions generation""" + +import pytest + +from reasoning_gym.algorithmic.string_insertion import StringInsertionConfig, StringInsertionDataset + + +def test_string_insertion_config_validation(): + """Test that invalid configs raise appropriate errors""" + + for field in ["min_string_length", "max_string_length"]: + for i in range(-1, 5): + with pytest.raises(AssertionError): + config = StringInsertionConfig(**{field: i}) # [-1, 4] is invalid + config.validate() + + +def test_string_insertion_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = StringInsertionConfig(seed=42, size=10) + dataset1 = StringInsertionDataset(config) + dataset2 = StringInsertionDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_string_insertion_dataset_items(): + """Test basic properties of generated items""" + config = StringInsertionConfig(min_string_length=5, max_string_length=30, size=10, seed=42) + dataset = StringInsertionDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert "string" in item["metadata"] + assert "solution" in item["metadata"] + + string = item["metadata"]["string"] + solution = item["metadata"]["solution"] + + # Verify string dimensions + assert 5 <= len(string) <= 30 + assert len(string) <= len(solution) + + +def test_string_insertion_dataset_iteration(): + """Test that iteration respects dataset size""" + config = StringInsertionConfig(size=5, seed=42) + dataset = StringInsertionDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset) + + +def test_string_insertion_answer(): + """Test the _get_rotated method""" + config = StringInsertionConfig(seed=42) + dataset = StringInsertionDataset(config) + + # No pattern match + assert dataset._get_answer("AAAAAAA") == "AAAAAAA" + assert dataset._get_answer("ADBEEBEA") == "ADBEEBEA" + assert dataset._get_answer("ADEACA") == "ADEACA" + + # Insert A after ABCD + assert dataset._get_answer("ABCDE") == "ABCDAE" + + # Insert B after BCDE + assert dataset._get_answer("AEBCDEC") == "AEBCDEBC" + + # Insert C after CDEA + assert dataset._get_answer("BBACDEAC") == "BBACDEACC" + + # Insert D after DEAB + assert dataset._get_answer("BAAABDEAB") == "BAAABDEABD" + + # Insert E after EABC + assert dataset._get_answer("EABCBCBC") == "EABCEBCBC" + + # Multiple insertions + assert dataset._get_answer("AABCDEEEEEEEBCDEAAAAA") == "AABCDAEEEEEEEBCDEBAAAAA" + + # No reuse of newly inserted characters + assert dataset._get_answer("ABCDBCD") == "ABCDABCD" \ No newline at end of file From 50f5b508459fd56ebdeae7999cd3b9d4ad6c5080 Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Wed, 12 Feb 2025 17:26:23 +0100 Subject: [PATCH 2/3] lint --- reasoning_gym/algorithmic/__init__.py | 2 +- reasoning_gym/algorithmic/string_insertion.py | 12 ++++++------ tests/test_string_insertion.py | 10 +++++----- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 20f3e90c..2400f49c 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -24,10 +24,10 @@ from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset +from .string_insertion import StringInsertionConfig, StringInsertionDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset -from .string_insertion import StringInsertionConfig, StringInsertionDataset __all__ = [ "SpellBackwardConfig", diff --git a/reasoning_gym/algorithmic/string_insertion.py b/reasoning_gym/algorithmic/string_insertion.py index fb1d35e0..b217ed76 100644 --- a/reasoning_gym/algorithmic/string_insertion.py +++ b/reasoning_gym/algorithmic/string_insertion.py @@ -9,7 +9,6 @@ from typing import Optional from ..factory import ProceduralDataset, register_dataset - QUESTION_TEMPLATE = """Given a string consisting of characters A, B, C, D, and E, your job is to insert a character according to the following pattern: 1. If there is a substring ABCD in the string, insert the character A after the substring. 2. If there is a substring BCDE in the string, insert the character B after the substring. @@ -22,7 +21,7 @@ Once you have inserted a character, you have to skip over the substring and the Example - Input: DDABCDEEDEAB - Output: DDABCDAEEDEABD -- Explanation: +- Explanation: - Theere are two inserted characters: DDABCD[A]EEDEAB[D] (shown in square brackets) - First, we insert A after ABCD. - Even though with the newly inserted 'A' we can obtain the substring BCD[A], we can't use it to insert another character. @@ -37,7 +36,7 @@ class StringInsertionConfig: """Configuration for String Insertion dataset generation""" min_string_length: int = 5 # Minimum string length - max_string_length: int = 20 # Maximum string length + max_string_length: int = 20 # Maximum string length size: int = 500 # Virtual dataset size seed: Optional[int] = None @@ -47,12 +46,13 @@ class StringInsertionConfig: assert 5 <= self.min_string_length, "Minimum string length should be at least 5" assert self.min_string_length <= self.max_string_length, "Minimum string length should be less than maximum" + class StringInsertionDataset(ProceduralDataset): """Generates String Insertion exercises with configurable difficulty""" def __init__(self, config: StringInsertionConfig): super().__init__(config=config, seed=config.seed, size=config.size) - self.vocabulary = ['A', 'B', 'C', 'D', 'E'] + self.vocabulary = ["A", "B", "C", "D", "E"] self.insertion_rules = [ ("ABCD", "A"), ("BCDE", "B"), @@ -68,7 +68,7 @@ class StringInsertionDataset(ProceduralDataset): while i < len(string): inserted = False for pattern, char in self.insertion_rules: - substring = string[i:i+len(pattern)] + substring = string[i : i + len(pattern)] if substring == pattern: output.append(substring + char) i += len(pattern) @@ -82,7 +82,7 @@ class StringInsertionDataset(ProceduralDataset): def __getitem__(self, idx: int) -> dict: """Generate a single String Insertion question""" rng = Random(self.seed + idx) - + string_length = rng.randint(self.config.min_string_length, self.config.max_string_length) string = [rng.choice(self.vocabulary) for _ in range(string_length)] diff --git a/tests/test_string_insertion.py b/tests/test_string_insertion.py index 746ff5c5..12225954 100644 --- a/tests/test_string_insertion.py +++ b/tests/test_string_insertion.py @@ -7,13 +7,13 @@ from reasoning_gym.algorithmic.string_insertion import StringInsertionConfig, St def test_string_insertion_config_validation(): """Test that invalid configs raise appropriate errors""" - + for field in ["min_string_length", "max_string_length"]: for i in range(-1, 5): with pytest.raises(AssertionError): - config = StringInsertionConfig(**{field: i}) # [-1, 4] is invalid + config = StringInsertionConfig(**{field: i}) # [-1, 4] is invalid config.validate() - + def test_string_insertion_dataset_deterministic(): """Test that dataset generates same items with same seed""" @@ -67,7 +67,7 @@ def test_string_insertion_answer(): config = StringInsertionConfig(seed=42) dataset = StringInsertionDataset(config) - # No pattern match + # No pattern match assert dataset._get_answer("AAAAAAA") == "AAAAAAA" assert dataset._get_answer("ADBEEBEA") == "ADBEEBEA" assert dataset._get_answer("ADEACA") == "ADEACA" @@ -91,4 +91,4 @@ def test_string_insertion_answer(): assert dataset._get_answer("AABCDEEEEEEEBCDEAAAAA") == "AABCDAEEEEEEEBCDEBAAAAA" # No reuse of newly inserted characters - assert dataset._get_answer("ABCDBCD") == "ABCDABCD" \ No newline at end of file + assert dataset._get_answer("ABCDBCD") == "ABCDABCD" From 7a12f45d53ceebbce2bbedf25f495834bb933a97 Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Wed, 12 Feb 2025 22:28:23 +0100 Subject: [PATCH 3/3] string manipulation --- reasoning_gym/algorithmic/__init__.py | 3 + .../algorithmic/string_manipulation.py | 199 ++++++++++++++ tests/test_string_manipulation.py | 257 ++++++++++++++++++ 3 files changed, 459 insertions(+) create mode 100644 reasoning_gym/algorithmic/string_manipulation.py create mode 100644 tests/test_string_manipulation.py diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 875ab539..3dbbe0d2 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -25,6 +25,7 @@ from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset +from .string_manipulation import StringManipulationConfig, StringManipulationDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset @@ -75,4 +76,6 @@ __all__ = [ "ABDataset", "CountPrimesConfig", "CountPrimesDataset", + "StringManipulationConfig", + "StringManipulationDataset", ] diff --git a/reasoning_gym/algorithmic/string_manipulation.py b/reasoning_gym/algorithmic/string_manipulation.py new file mode 100644 index 00000000..b382921f --- /dev/null +++ b/reasoning_gym/algorithmic/string_manipulation.py @@ -0,0 +1,199 @@ +"""Manipulate a string according to a set of rules + +https://github.com/yongchao98/CodeSteer-v1.0/blob/main/create_dataset/create_dataset_string_deletion_and_modification.py +""" + +from dataclasses import dataclass +from random import Random +from typing import Optional + +from ..factory import ProceduralDataset, register_dataset + +QUESTION_TEMPLATE = """Your job is to repeatedly transform a string according to a set of rules until no further transformations can be performed, or a state is repeated. + +Evaluate the following rules in order, and apply the first applicable rule to the string: +{rules} + +Once you have applied a rule, repeat the process with the new string until no further transformations can be performed (i.e. the string doesn't change), or a state is repeated. +If a state is repeated, the process is terminated, and the repeated state is discarded (i.e. is not considered as the final answer) and the state before the repeated state is considered as the final answer. + +Example: +- Input: + - String: abbac + - Rules: + 1. If the string prefix is 'ab', replace it with 'ca'. + 2. If the string prefix is 'ca', replace it with 'bb' and append 'c' to the end. + 3. If the string ends with 'aa', replace it with 'cc'. +- Output: bbbacc +- Explanation: + - In the first iteration, rule 1 is applied to the string abbac, resulting in cabac + - In the second interation, rule 1 doesn't apply, but rule 2 is applied to the string cabac, resulting in bbbacc + - In the third iteration, none of the rules (1, 2, 3) apply, so the process is terminated, and the final answer is bbbacc + +Transform the following string according to the above list of rules: +{string} +""" + + +@dataclass +class StringManipulationConfig: + """Configuration for String Insertion dataset generation""" + + min_string_length: int = 5 # Minimum string length + max_string_length: int = 20 # Maximum string length + min_num_rules: int = 3 # Minimum number of rules/transforms + max_num_rules: int = 8 # Maximum number of rules/transforms + + size: int = 500 # Virtual dataset size + seed: Optional[int] = None + + def validate(self): + """Validate configuration parameters""" + assert 5 <= self.min_string_length, "Minimum string length should be at least 5" + assert self.min_string_length <= self.max_string_length, "Minimum string length should be less than maximum" + assert 3 <= self.min_num_rules, "Minimum number of rules should be at least 3" + assert self.min_num_rules <= self.max_num_rules, "Minimum number of rules should be less than maximum" + + +class StringManipulationDataset(ProceduralDataset): + """Generates String Insertion exercises with configurable difficulty""" + + def __init__(self, config: StringManipulationConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + self.vocabulary = ["a", "b", "c"] + self.rules = [ + ( + "If the string prefix is 'ab', replace it with 'ca'.", + lambda s: ("ca" + s[2:], 1) if s.startswith("ab") else (s, 0), + ), + ( + "If the string suffix is 'ac', replace it with 'cb'.", + lambda s: (s[:-2] + "cb", 2) if s.endswith("ac") else (s, 0), + ), + ( + "If the string prefix is 'bc', delete the first two characters and append 'aa' to the end.", + lambda s: (s[2:] + "aa", 3) if s.startswith("bc") else (s, 0), + ), + ( + "If the string suffix is 'bb', delete the last two characters.", + lambda s: (s[:-2], 4) if s.endswith("bb") else (s, 0), + ), + ( + "If the string prefix is 'cb', replace it with 'aa' and delete the last character.", + lambda s: ("aa" + s[2:-1], 5) if s.startswith("cb") and len(s) > 1 else (s, 0), + ), + ( + "If the string prefix is 'ca', replace it with 'bb' and append 'c' to the end.", + lambda s: ("bb" + s[2:] + "c", 6) if s.startswith("ca") else (s, 0), + ), + ( + "If the string suffix is 'cc', replace it with 'b' and prepend 'a' to the start.", + lambda s: ("a" + s[:-2] + "b", 7) if s.endswith("cc") else (s, 0), + ), + ( + "If the string prefix is 'aa', remove the first character.", + lambda s: (s[1:], 8) if s.startswith("aa") else (s, 0), + ), + ( + "If the string contains 'abc', replace the first occurrence with 'cab'.", + lambda s: (s.replace("abc", "cab", 1), 9) if "abc" in s else (s, 0), + ), + ( + "If the string contains 'bca', delete the first occurrence entirely.", + lambda s: (s.replace("bca", "", 1), 10) if "bca" in s else (s, 0), + ), + ( + "If the string ends with 'ba', replace it with 'ab'.", + lambda s: (s[:-2] + "ab", 11) if s.endswith("ba") else (s, 0), + ), + ( + "If the string starts with 'cc', remove the first two characters.", + lambda s: (s[2:], 12) if s.startswith("cc") else (s, 0), + ), + ( + "If the string contains 'acb', replace the first occurrence with its reverse ('bca').", + lambda s: (s.replace("acb", "bca", 1), 13) if "acb" in s else (s, 0), + ), + ( + "If the string ends with 'ca', remove the last character.", + lambda s: (s[:-1], 14) if s.endswith("ca") and len(s) > 0 else (s, 0), + ), + ( + "If the string starts with 'bb', remove the second character.", + lambda s: (s[0] + s[2:], 15) if s.startswith("bb") and len(s) >= 2 else (s, 0), + ), + ( + "If the string ends with 'aa', replace it with 'cc'.", + lambda s: (s[:-2] + "cc", 16) if s.endswith("aa") else (s, 0), + ), + ( + "If the string contains 'ca' (not at the start), remove the first occurrence found after the first character.", + lambda s: (s[:idx] + s[idx + 2 :], 17) if (idx := s.find("ca", 1)) != -1 else (s, 0), + ), + ( + "If the string contains an even number of 'b's (and at least one 'b'), append 'ab' at the end.", + lambda s: (s + "ab", 18) if (s.count("b") > 0 and s.count("b") % 2 == 0) else (s, 0), + ), + ( + "If the string length is greater than 15, remove the middle character.", + lambda s: (s[: len(s) // 2] + s[len(s) // 2 + 1 :], 19) if len(s) > 15 else (s, 0), + ), + ( + "If the string starts with 'ac', replace the first two characters with 'zz'.", + lambda s: ("zz" + s[2:], 20) if s.startswith("ac") else (s, 0), + ), + ] + + def _apply_rule(self, string: str, selected_rules: list[tuple[str, callable]]) -> tuple[str, int]: + """ + Apply the first applicable rule from the list of selected rules. + Returns a tuple containing the modified string and the rule index (1-based) that was applied. + If no rule is applicable, returns (s, 0). + """ + for _, rule_fn in selected_rules: + new_string, op_idx = rule_fn(string) + if op_idx != 0: + return new_string, op_idx + return string, 0 + + def _get_all_transforms(self, string: str, selected_rules: list[tuple[str, callable]]) -> list[str]: + """ + Repeatedly apply transformation rules to a string until no further transformations can be performed, + or a state is repeated. If a state is repeated, the process is terminated, and the state is not added to the list. + Returns a list of string states from the initial string to the final state (i.e. the desired answer). + """ + states = [string] + while True: + new_string, op_idx = self._apply_rule(states[-1], selected_rules) + if op_idx == 0 or new_string in states: + break + states.append(new_string) + return states + + def __getitem__(self, idx: int) -> dict: + """Generate a single String Insertion question""" + rng = Random(self.seed + idx) + + string_length = rng.randint(self.config.min_string_length, self.config.max_string_length) + string = "".join(rng.choice(self.vocabulary) for _ in range(string_length)) + + num_rules = rng.randint(self.config.min_num_rules, self.config.max_num_rules) + selected_rules = rng.sample(self.rules, num_rules) + rules_str = "\n".join(f"{i+1}. {rule}" for i, (rule, _) in enumerate(selected_rules)) + + states = self._get_all_transforms(string, selected_rules) + answer = states[-1] + + return { + "question": QUESTION_TEMPLATE.format(string=string, rules=rules_str), + "answer": str(answer), + "metadata": { + "string": string, + "solution": answer, + "states": states, + "selected_rules": [rule for rule, _ in selected_rules], + }, + } + + +register_dataset("string_manipulation", StringManipulationDataset, StringManipulationConfig) diff --git a/tests/test_string_manipulation.py b/tests/test_string_manipulation.py new file mode 100644 index 00000000..f62a7acd --- /dev/null +++ b/tests/test_string_manipulation.py @@ -0,0 +1,257 @@ +"""Tests for String Manipulation questions generation""" + +import pytest + +from reasoning_gym.algorithmic.string_manipulation import StringManipulationConfig, StringManipulationDataset + + +def test_string_manipulation_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = StringManipulationConfig(min_string_length=4) # Minimum string length should be at least 5 + config.validate() + + with pytest.raises(AssertionError): + config = StringManipulationConfig(min_string_length=10, max_string_length=7) # Max must be greater than min + config.validate() + + with pytest.raises(AssertionError): + config = StringManipulationConfig(min_num_rules=2) # Min number of rules should be at least 3 + config.validate() + + with pytest.raises(AssertionError): + config = StringManipulationConfig(min_num_rules=5, max_num_rules=3) # Max must be greater than min + config.validate() + + +def test_string_manipulation_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = StringManipulationConfig(seed=42, size=10) + dataset1 = StringManipulationDataset(config) + dataset2 = StringManipulationDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_string_manipulation_dataset_items(): + """Test basic properties of generated items""" + config = StringManipulationConfig( + min_string_length=7, max_string_length=25, min_num_rules=5, max_num_rules=12, size=10, seed=42 + ) + dataset = StringManipulationDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert "string" in item["metadata"] + assert "states" in item["metadata"] + # assert "selected_rules" in item["metadata"] + assert "solution" in item["metadata"] + + string = item["metadata"]["string"] + solution = item["metadata"]["solution"] + states = item["metadata"]["states"] + selected_rules = item["metadata"]["selected_rules"] + + # Verify dimensions + assert config.min_string_length <= len(string) <= config.max_string_length + assert config.min_num_rules <= len(selected_rules) <= config.max_num_rules + assert len(states) >= 1 + assert solution == states[-1] + + +def test_string_manipulation_dataset_iteration(): + """Test that iteration respects dataset size""" + config = StringManipulationConfig(size=5, seed=42) + dataset = StringManipulationDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset) + + +def test_string_manipulation_answer(): + """Test the method for getting the answer""" + config = StringManipulationConfig(seed=42) + dataset = StringManipulationDataset(config) + + rules = [ + ( + "If the string prefix is 'ab', replace it with 'ca'.", + lambda s: ("ca" + s[2:], 1) if s.startswith("ab") else (s, 0), + ) + ] + assert dataset._get_all_transforms("abbbab", rules)[-1] == "cabbab" + + rules = [ + ( + "If the string suffix is 'ac', replace it with 'cb'.", + lambda s: (s[:-2] + "cb", 2) if s.endswith("ac") else (s, 0), + ), + ] + assert dataset._get_all_transforms("abbbac", rules)[-1] == "abbbcb" + + rules = [ + ( + "If the string prefix is 'bc', delete the first two characters and append 'aa' to the end.", + lambda s: (s[2:] + "aa", 3) if s.startswith("bc") else (s, 0), + ), + ] + assert dataset._get_all_transforms("bcabbb", rules)[-1] == "abbbaa" + + rules = [ + ( + "If the string suffix is 'bb', delete the last two characters.", + lambda s: (s[:-2], 4) if s.endswith("bb") else (s, 0), + ), + ] + assert dataset._get_all_transforms("abbbabb", rules)[-1] == "abbba" + + rules = [ + ( + "If the string prefix is 'cb', replace it with 'aa' and delete the last character.", + lambda s: ("aa" + s[2:-1], 5) if s.startswith("cb") and len(s) > 1 else (s, 0), + ) + ] + assert dataset._get_all_transforms("cbabbb", rules)[-1] == "aaabb" + + rules = [ + ( + "If the string prefix is 'ca', replace it with 'bb' and append 'c' to the end.", + lambda s: ("bb" + s[2:] + "c", 6) if s.startswith("ca") else (s, 0), + ) + ] + assert dataset._get_all_transforms("caabbb", rules)[-1] == "bbabbbc" + + rules = [ + ( + "If the string suffix is 'cc', replace it with 'b' and prepend 'a' to the start.", + lambda s: ("a" + s[:-2] + "b", 7) if s.endswith("cc") else (s, 0), + ) + ] + assert dataset._get_all_transforms("abbbcc", rules)[-1] == "aabbbb" + + rules = [ + ( + "If the string prefix is 'aa', remove the first character.", + lambda s: (s[1:], 8) if s.startswith("aa") else (s, 0), + ) + ] + assert dataset._get_all_transforms("aabbb", rules)[-1] == "abbb" + + rules = [ + ( + "If the string contains 'abc', replace the first occurrence with 'cab'.", + lambda s: (s.replace("abc", "cab", 1), 9) if "abc" in s else (s, 0), + ) + ] + assert dataset._get_all_transforms("ababcb", rules)[-1] == "cababb" # 'ababcb' -> 'abcabb' -> 'cababb' + + rules = [ + ( + "If the string contains 'bca', delete the first occurrence entirely.", + lambda s: (s.replace("bca", "", 1), 10) if "bca" in s else (s, 0), + ) + ] + assert dataset._get_all_transforms("abbcab", rules)[-1] == "abb" + + rules = [ + ( + "If the string ends with 'ba', replace it with 'ab'.", + lambda s: (s[:-2] + "ab", 11) if s.endswith("ba") else (s, 0), + ) + ] + assert dataset._get_all_transforms("abbbba", rules)[-1] == "abbbab" + + rules = [ + ( + "If the string starts with 'cc', remove the first two characters.", + lambda s: (s[2:], 12) if s.startswith("cc") else (s, 0), + ) + ] + assert dataset._get_all_transforms("ccabbb", rules)[-1] == "abbb" + + rules = [ + ( + "If the string contains 'acb', replace the first occurrence with its reverse ('bca').", + lambda s: (s.replace("acb", "bca", 1), 13) if "acb" in s else (s, 0), + ) + ] + assert dataset._get_all_transforms("abacbb", rules)[-1] == "abbcab" + + rules = [ + ( + "If the string contains 'acb', replace the first occurrence with its reverse ('bca').", + lambda s: (s.replace("acb", "bca", 1), 13) if "acb" in s else (s, 0), + ) + ] + assert dataset._get_all_transforms("abacbb", rules)[-1] == "abbcab" + + rules = [ + ( + "If the string ends with 'ca', remove the last character.", + lambda s: (s[:-1], 14) if s.endswith("ca") and len(s) > 0 else (s, 0), + ) + ] + assert dataset._get_all_transforms("abbbca", rules)[-1] == "abbbc" + + rules = [ + ( + "If the string starts with 'bb', remove the second character.", + lambda s: (s[0] + s[2:], 15) if s.startswith("bb") and len(s) >= 2 else (s, 0), + ) + ] + assert dataset._get_all_transforms("bbabcbb", rules)[-1] == "babcbb" + + rules = [ + ( + "If the string ends with 'aa', replace it with 'cc'.", + lambda s: (s[:-2] + "cc", 16) if s.endswith("aa") else (s, 0), + ) + ] + assert dataset._get_all_transforms("abccbaa", rules)[-1] == "abccbcc" + + rules = [ + ( + "If the string contains 'ca' (not at the start), remove the first occurrence found after the first character.", + lambda s: (s[:idx] + s[idx + 2 :], 17) if (idx := s.find("ca", 1)) != -1 else (s, 0), + ) + ] + assert dataset._get_all_transforms("abacab", rules)[-1] == "abab" + assert dataset._get_all_transforms("caabab", rules)[-1] == "caabab" + + rules = [ + ( + "If the string contains an even number of 'b's (and at least one 'b'), append 'ab' at the end.", + lambda s: (s + "ab", 18) if (s.count("b") > 0 and s.count("b") % 2 == 0) else (s, 0), + ) + ] + assert dataset._get_all_transforms("abab", rules)[-1] == "ababab" + assert dataset._get_all_transforms("abbab", rules)[-1] == "abbab" + + rules = [ + ( + "If the string length is greater than 15, remove the middle character.", + lambda s: (s[: len(s) // 2] + s[len(s) // 2 + 1 :], 19) if len(s) > 15 else (s, 0), + ) + ] + assert ( + dataset._get_all_transforms("bccbcbbbcbbbbcccc", rules)[-1] == "bccbcbbbbbbcccc" + ) # bccbcbbbcbbbbcccc -> "bccbcbbbbbbbcccc" -> "bccbcbbbbbbcccc" + + rules = [ + ( + "If the string starts with 'ac', replace the first two characters with 'zz'.", + lambda s: ("zz" + s[2:], 20) if s.startswith("ac") else (s, 0), + ) + ] + assert dataset._get_all_transforms("acab", rules)[-1] == "zzab"