mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-05-01 17:45:24 +00:00
lint
This commit is contained in:
parent
aaf1df285e
commit
c64a32155a
3 changed files with 12 additions and 12 deletions
|
|
@ -25,10 +25,10 @@ from .rotate_matrix import RotateMatrixConfig, RotateMatrixDataset
|
||||||
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
|
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
|
||||||
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
|
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
|
||||||
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset
|
from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset
|
||||||
|
from .string_insertion import StringInsertionConfig, StringInsertionDataset
|
||||||
from .word_ladder import WordLadderConfig, WordLadderDataset
|
from .word_ladder import WordLadderConfig, WordLadderDataset
|
||||||
from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
|
from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset
|
||||||
from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
|
from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset
|
||||||
from .string_insertion import StringInsertionConfig, StringInsertionDataset
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SpellBackwardConfig",
|
"SpellBackwardConfig",
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ from typing import Optional
|
||||||
|
|
||||||
from ..factory import ProceduralDataset, register_dataset
|
from ..factory import ProceduralDataset, register_dataset
|
||||||
|
|
||||||
|
|
||||||
QUESTION_TEMPLATE = """Given a string consisting of characters A, B, C, D, and E, your job is to insert a character according to the following pattern:
|
QUESTION_TEMPLATE = """Given a string consisting of characters A, B, C, D, and E, your job is to insert a character according to the following pattern:
|
||||||
1. If there is a substring ABCD in the string, insert the character A after the substring.
|
1. If there is a substring ABCD in the string, insert the character A after the substring.
|
||||||
2. If there is a substring BCDE in the string, insert the character B after the substring.
|
2. If there is a substring BCDE in the string, insert the character B after the substring.
|
||||||
|
|
@ -22,7 +21,7 @@ Once you have inserted a character, you have to skip over the substring and the
|
||||||
Example
|
Example
|
||||||
- Input: DDABCDEEDEAB
|
- Input: DDABCDEEDEAB
|
||||||
- Output: DDABCDAEEDEABD
|
- Output: DDABCDAEEDEABD
|
||||||
- Explanation:
|
- Explanation:
|
||||||
- Theere are two inserted characters: DDABCD[A]EEDEAB[D] (shown in square brackets)
|
- Theere are two inserted characters: DDABCD[A]EEDEAB[D] (shown in square brackets)
|
||||||
- First, we insert A after ABCD.
|
- First, we insert A after ABCD.
|
||||||
- Even though with the newly inserted 'A' we can obtain the substring BCD[A], we can't use it to insert another character.
|
- Even though with the newly inserted 'A' we can obtain the substring BCD[A], we can't use it to insert another character.
|
||||||
|
|
@ -37,7 +36,7 @@ class StringInsertionConfig:
|
||||||
"""Configuration for String Insertion dataset generation"""
|
"""Configuration for String Insertion dataset generation"""
|
||||||
|
|
||||||
min_string_length: int = 5 # Minimum string length
|
min_string_length: int = 5 # Minimum string length
|
||||||
max_string_length: int = 20 # Maximum string length
|
max_string_length: int = 20 # Maximum string length
|
||||||
|
|
||||||
size: int = 500 # Virtual dataset size
|
size: int = 500 # Virtual dataset size
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
|
|
@ -47,12 +46,13 @@ class StringInsertionConfig:
|
||||||
assert 5 <= self.min_string_length, "Minimum string length should be at least 5"
|
assert 5 <= self.min_string_length, "Minimum string length should be at least 5"
|
||||||
assert self.min_string_length <= self.max_string_length, "Minimum string length should be less than maximum"
|
assert self.min_string_length <= self.max_string_length, "Minimum string length should be less than maximum"
|
||||||
|
|
||||||
|
|
||||||
class StringInsertionDataset(ProceduralDataset):
|
class StringInsertionDataset(ProceduralDataset):
|
||||||
"""Generates String Insertion exercises with configurable difficulty"""
|
"""Generates String Insertion exercises with configurable difficulty"""
|
||||||
|
|
||||||
def __init__(self, config: StringInsertionConfig):
|
def __init__(self, config: StringInsertionConfig):
|
||||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||||
self.vocabulary = ['A', 'B', 'C', 'D', 'E']
|
self.vocabulary = ["A", "B", "C", "D", "E"]
|
||||||
self.insertion_rules = [
|
self.insertion_rules = [
|
||||||
("ABCD", "A"),
|
("ABCD", "A"),
|
||||||
("BCDE", "B"),
|
("BCDE", "B"),
|
||||||
|
|
@ -68,7 +68,7 @@ class StringInsertionDataset(ProceduralDataset):
|
||||||
while i < len(string):
|
while i < len(string):
|
||||||
inserted = False
|
inserted = False
|
||||||
for pattern, char in self.insertion_rules:
|
for pattern, char in self.insertion_rules:
|
||||||
substring = string[i:i+len(pattern)]
|
substring = string[i : i + len(pattern)]
|
||||||
if substring == pattern:
|
if substring == pattern:
|
||||||
output.append(substring + char)
|
output.append(substring + char)
|
||||||
i += len(pattern)
|
i += len(pattern)
|
||||||
|
|
@ -82,7 +82,7 @@ class StringInsertionDataset(ProceduralDataset):
|
||||||
def __getitem__(self, idx: int) -> dict:
|
def __getitem__(self, idx: int) -> dict:
|
||||||
"""Generate a single String Insertion question"""
|
"""Generate a single String Insertion question"""
|
||||||
rng = Random(self.seed + idx)
|
rng = Random(self.seed + idx)
|
||||||
|
|
||||||
string_length = rng.randint(self.config.min_string_length, self.config.max_string_length)
|
string_length = rng.randint(self.config.min_string_length, self.config.max_string_length)
|
||||||
string = [rng.choice(self.vocabulary) for _ in range(string_length)]
|
string = [rng.choice(self.vocabulary) for _ in range(string_length)]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,13 +7,13 @@ from reasoning_gym.algorithmic.string_insertion import StringInsertionConfig, St
|
||||||
|
|
||||||
def test_string_insertion_config_validation():
|
def test_string_insertion_config_validation():
|
||||||
"""Test that invalid configs raise appropriate errors"""
|
"""Test that invalid configs raise appropriate errors"""
|
||||||
|
|
||||||
for field in ["min_string_length", "max_string_length"]:
|
for field in ["min_string_length", "max_string_length"]:
|
||||||
for i in range(-1, 5):
|
for i in range(-1, 5):
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
config = StringInsertionConfig(**{field: i}) # [-1, 4] is invalid
|
config = StringInsertionConfig(**{field: i}) # [-1, 4] is invalid
|
||||||
config.validate()
|
config.validate()
|
||||||
|
|
||||||
|
|
||||||
def test_string_insertion_dataset_deterministic():
|
def test_string_insertion_dataset_deterministic():
|
||||||
"""Test that dataset generates same items with same seed"""
|
"""Test that dataset generates same items with same seed"""
|
||||||
|
|
@ -67,7 +67,7 @@ def test_string_insertion_answer():
|
||||||
config = StringInsertionConfig(seed=42)
|
config = StringInsertionConfig(seed=42)
|
||||||
dataset = StringInsertionDataset(config)
|
dataset = StringInsertionDataset(config)
|
||||||
|
|
||||||
# No pattern match
|
# No pattern match
|
||||||
assert dataset._get_answer("AAAAAAA") == "AAAAAAA"
|
assert dataset._get_answer("AAAAAAA") == "AAAAAAA"
|
||||||
assert dataset._get_answer("ADBEEBEA") == "ADBEEBEA"
|
assert dataset._get_answer("ADBEEBEA") == "ADBEEBEA"
|
||||||
assert dataset._get_answer("ADEACA") == "ADEACA"
|
assert dataset._get_answer("ADEACA") == "ADEACA"
|
||||||
|
|
@ -91,4 +91,4 @@ def test_string_insertion_answer():
|
||||||
assert dataset._get_answer("AABCDEEEEEEEBCDEAAAAA") == "AABCDAEEEEEEEBCDEBAAAAA"
|
assert dataset._get_answer("AABCDEEEEEEEBCDEAAAAA") == "AABCDAEEEEEEEBCDEBAAAAA"
|
||||||
|
|
||||||
# No reuse of newly inserted characters
|
# No reuse of newly inserted characters
|
||||||
assert dataset._get_answer("ABCDBCD") == "ABCDABCD"
|
assert dataset._get_answer("ABCDBCD") == "ABCDABCD"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue