This commit is contained in:
Zafir Stojanovski 2025-02-12 17:26:23 +01:00 committed by abdulhakeem
parent aaf1df285e
commit c64a32155a
3 changed files with 12 additions and 12 deletions

View file

@ -25,10 +25,10 @@ from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset
from .string_insertion import StringInsertionConfig, StringInsertionDataset
from .word_ladder import WordLadderConfig, WordLadderDataset from .word_ladder import WordLadderConfig, WordLadderDataset
from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
from .string_insertion import StringInsertionConfig, StringInsertionDataset
__all__ = [ __all__ = [
"SpellBackwardConfig", "SpellBackwardConfig",

View file

@ -9,7 +9,6 @@ from typing import Optional
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Given a string consisting of characters A, B, C, D, and E, your job is to insert a character according to the following pattern: QUESTION_TEMPLATE = """Given a string consisting of characters A, B, C, D, and E, your job is to insert a character according to the following pattern:
1. If there is a substring ABCD in the string, insert the character A after the substring. 1. If there is a substring ABCD in the string, insert the character A after the substring.
2. If there is a substring BCDE in the string, insert the character B after the substring. 2. If there is a substring BCDE in the string, insert the character B after the substring.
@ -22,7 +21,7 @@ Once you have inserted a character, you have to skip over the substring and the
Example Example
- Input: DDABCDEEDEAB - Input: DDABCDEEDEAB
- Output: DDABCDAEEDEABD - Output: DDABCDAEEDEABD
- Explanation: - Explanation:
- Theere are two inserted characters: DDABCD[A]EEDEAB[D] (shown in square brackets) - Theere are two inserted characters: DDABCD[A]EEDEAB[D] (shown in square brackets)
- First, we insert A after ABCD. - First, we insert A after ABCD.
- Even though with the newly inserted 'A' we can obtain the substring BCD[A], we can't use it to insert another character. - Even though with the newly inserted 'A' we can obtain the substring BCD[A], we can't use it to insert another character.
@ -37,7 +36,7 @@ class StringInsertionConfig:
"""Configuration for String Insertion dataset generation""" """Configuration for String Insertion dataset generation"""
min_string_length: int = 5 # Minimum string length min_string_length: int = 5 # Minimum string length
max_string_length: int = 20 # Maximum string length max_string_length: int = 20 # Maximum string length
size: int = 500 # Virtual dataset size size: int = 500 # Virtual dataset size
seed: Optional[int] = None seed: Optional[int] = None
@ -47,12 +46,13 @@ class StringInsertionConfig:
assert 5 <= self.min_string_length, "Minimum string length should be at least 5" assert 5 <= self.min_string_length, "Minimum string length should be at least 5"
assert self.min_string_length <= self.max_string_length, "Minimum string length should be less than maximum" assert self.min_string_length <= self.max_string_length, "Minimum string length should be less than maximum"
class StringInsertionDataset(ProceduralDataset): class StringInsertionDataset(ProceduralDataset):
"""Generates String Insertion exercises with configurable difficulty""" """Generates String Insertion exercises with configurable difficulty"""
def __init__(self, config: StringInsertionConfig): def __init__(self, config: StringInsertionConfig):
super().__init__(config=config, seed=config.seed, size=config.size) super().__init__(config=config, seed=config.seed, size=config.size)
self.vocabulary = ['A', 'B', 'C', 'D', 'E'] self.vocabulary = ["A", "B", "C", "D", "E"]
self.insertion_rules = [ self.insertion_rules = [
("ABCD", "A"), ("ABCD", "A"),
("BCDE", "B"), ("BCDE", "B"),
@ -68,7 +68,7 @@ class StringInsertionDataset(ProceduralDataset):
while i < len(string): while i < len(string):
inserted = False inserted = False
for pattern, char in self.insertion_rules: for pattern, char in self.insertion_rules:
substring = string[i:i+len(pattern)] substring = string[i : i + len(pattern)]
if substring == pattern: if substring == pattern:
output.append(substring + char) output.append(substring + char)
i += len(pattern) i += len(pattern)
@ -82,7 +82,7 @@ class StringInsertionDataset(ProceduralDataset):
def __getitem__(self, idx: int) -> dict: def __getitem__(self, idx: int) -> dict:
"""Generate a single String Insertion question""" """Generate a single String Insertion question"""
rng = Random(self.seed + idx) rng = Random(self.seed + idx)
string_length = rng.randint(self.config.min_string_length, self.config.max_string_length) string_length = rng.randint(self.config.min_string_length, self.config.max_string_length)
string = [rng.choice(self.vocabulary) for _ in range(string_length)] string = [rng.choice(self.vocabulary) for _ in range(string_length)]

View file

@ -7,13 +7,13 @@ from reasoning_gym.algorithmic.string_insertion import StringInsertionConfig, St
def test_string_insertion_config_validation(): def test_string_insertion_config_validation():
"""Test that invalid configs raise appropriate errors""" """Test that invalid configs raise appropriate errors"""
for field in ["min_string_length", "max_string_length"]: for field in ["min_string_length", "max_string_length"]:
for i in range(-1, 5): for i in range(-1, 5):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
config = StringInsertionConfig(**{field: i}) # [-1, 4] is invalid config = StringInsertionConfig(**{field: i}) # [-1, 4] is invalid
config.validate() config.validate()
def test_string_insertion_dataset_deterministic(): def test_string_insertion_dataset_deterministic():
"""Test that dataset generates same items with same seed""" """Test that dataset generates same items with same seed"""
@ -67,7 +67,7 @@ def test_string_insertion_answer():
config = StringInsertionConfig(seed=42) config = StringInsertionConfig(seed=42)
dataset = StringInsertionDataset(config) dataset = StringInsertionDataset(config)
# No pattern match # No pattern match
assert dataset._get_answer("AAAAAAA") == "AAAAAAA" assert dataset._get_answer("AAAAAAA") == "AAAAAAA"
assert dataset._get_answer("ADBEEBEA") == "ADBEEBEA" assert dataset._get_answer("ADBEEBEA") == "ADBEEBEA"
assert dataset._get_answer("ADEACA") == "ADEACA" assert dataset._get_answer("ADEACA") == "ADEACA"
@ -91,4 +91,4 @@ def test_string_insertion_answer():
assert dataset._get_answer("AABCDEEEEEEEBCDEAAAAA") == "AABCDAEEEEEEEBCDEBAAAAA" assert dataset._get_answer("AABCDEEEEEEEBCDEAAAAA") == "AABCDAEEEEEEEBCDEBAAAAA"
# No reuse of newly inserted characters # No reuse of newly inserted characters
assert dataset._get_answer("ABCDBCD") == "ABCDABCD" assert dataset._get_answer("ABCDBCD") == "ABCDABCD"