diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 082a65b1..0a7e9742 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -27,6 +27,7 @@ from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .spiral_matrix import SpiralMatrixConfig, SpiralMatrixDataset from .string_insertion import StringInsertionConfig, StringInsertionDataset from .string_manipulation import StringManipulationConfig, StringManipulationDataset +from .string_synthesis import StringSynthesisConfig, StringSynthesisDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset @@ -81,4 +82,6 @@ __all__ = [ "StringInsertionDataset", "StringManipulationConfig", "StringManipulationDataset", + "StringSynthesisConfig", + "StringSynthesisDataset", ] diff --git a/reasoning_gym/algorithmic/string_synthesis.py b/reasoning_gym/algorithmic/string_synthesis.py new file mode 100644 index 00000000..c78ed35b --- /dev/null +++ b/reasoning_gym/algorithmic/string_synthesis.py @@ -0,0 +1,139 @@ +"""Iteratively synthesizes a string by inserting characters according to a pattern. + +https://github.com/yongchao98/CodeSteer-v1.0/blob/main/create_dataset/create_dataset_string_synthesis.py +""" + +from dataclasses import dataclass +from random import Random +from typing import Optional + +from ..factory import ProceduralDataset, register_dataset + +QUESTION_TEMPLATE = """There are nine different blocks [A] [B] [C] {{A}} {{B}} {{C}} (A) (B) (C) +1. One [A], one [B], and one [C] can be combined to form one {{A}}. +2. One [A] and one [B] can be combined to form one {{C}}. +3. One [B] and one [C] can be combined to form one {{B}}. +4. Two [C] can be combined to form one {{C}}. +5. One {{A}} and one {{C}} can be combined to form one (A) and one (B). +6. Two {{B}} can be combined to form one (C). + +Given a certain number of initial blocks, your job is to cycle through the rules 1-6 above, synthesizing new blocks until no more rules can be applied, or until a state (counts of each block type) is repeated. +In the case a state is repeated the answer is the state before the repetition! + +The output should be the count of each block type after the rules have been applied in the order they are listed above. +For example 1 0 3 0 2 0 0 0 1 means that you have 1 [A] 0 [B] 3 [C] 0 {{A}} 2 {{B}} 0 {{C}} 0 (A) 0 (B) 1 (C). + +Example: +- Input: You have 2 [A], 3 [B], and 3 [C]. +- Output: 0 0 0 2 1 0 0 0 0 +- Explanation: + 0. Initial state: 2 3 3 0 0 0 0 0 0 + 1. We can apply Rule 1 and obtain 1 {{A}}. New state: 1 2 2 1 0 0 0 0 0 + 2. We can apply Rule 1 again and obtain 1 {{A}}. New state 0 1 1 2 0 0 0 0 0 + 3. We can apply Rule 3 and obtain 1 {{B}}. New state 0 0 0 2 1 0 0 0 0 + 4. No more rules can be applied. The answer is 0 0 0 2 1 0 0 0 0 + +Now, you have {A_square} [A], {B_square} [B], and {C_square} [C] blocks. Provide the count of each block type after applying the above rules. +""" + + +@dataclass +class StringSynthesisConfig: + """Configuration for String Synthesis dataset generation""" + + min_initial_blocks: int = 0 # Minimum number of initial blocks + max_initial_blocks: int = 5 # Maximum number of initial blocks + max_iterations: int = 1_000 # Maximum number of iterations to apply the rules (Safety check for infinite loops) + + size: int = 500 # Virtual dataset size + seed: Optional[int] = None + + def validate(self): + """Validate configuration parameters""" + assert 0 <= self.min_initial_blocks, "min_initial_blocks must be non-negative" + assert ( + self.min_initial_blocks <= self.max_initial_blocks + ), "min_initial_blocks must be less than or equal to max_initial_blocks" + assert 0 < self.max_iterations, "max_iterations must be positive" + + +class StringSynthesisDataset(ProceduralDataset): + """Generates String Synthesis exercises with configurable difficulty""" + + def __init__(self, config: StringSynthesisConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def _apply_rule(self, counts: list[int]) -> list[int]: + """ + Apply the first applicable rule to the given counts. + In case no rule is applicable, the counts are returned unchanged. + """ + # label the indices for the counts + A_square, B_square, C_square, A_curly, B_curly, C_curly, A_round, B_round, C_round = range(9) + # Rule 1: One [A], one [B], and one [C] can be combined to form one {A} + if counts[A_square] >= 1 and counts[B_square] >= 1 and counts[C_square] >= 1: + counts[A_square] -= 1 + counts[B_square] -= 1 + counts[C_square] -= 1 + counts[A_curly] += 1 + # Rule 2: One [A] and one [B] can be combined to form one {C} + elif counts[A_square] >= 1 and counts[B_square] >= 1: + counts[A_square] -= 1 + counts[B_square] -= 1 + counts[C_curly] += 1 + # Rule 3: One [B] and one [C] can be combined to form one {B} + elif counts[B_square] >= 1 and counts[C_square] >= 1: + counts[B_square] -= 1 + counts[C_square] -= 1 + counts[B_curly] += 1 + # Rule 4: Two [C] can be combined to form one {C} + elif counts[C_square] >= 2: + counts[C_square] -= 2 + counts[C_curly] += 1 + # Rule 5: One {A} and one {C} can be combined to form one (A) and one (B) + elif counts[A_curly] >= 1 and counts[C_curly] >= 1: + counts[A_curly] -= 1 + counts[C_curly] -= 1 + counts[A_round] += 1 + counts[B_round] += 1 + # Rule 6: Two {B} can be combined to form one (C) + elif counts[B_curly] >= 2: + counts[B_curly] -= 2 + counts[C_round] += 1 + return counts + + def _get_answer(self, A_square: int, B_square: int, C_square: int) -> list[list[int]]: + """Calculate the answer for a given input""" + # [A] [B] [C] {A} {B} {C} (A) (B) (C) + counts = [A_square, B_square, C_square] + [0 for _ in range(6)] + states = [counts] + + for _ in range(self.config.max_iterations): + new_counts = self._apply_rule(counts[:]) + if new_counts in states: + break + states.append(new_counts) + counts = new_counts + + return states + + def __getitem__(self, idx: int) -> dict: + """Generate a single String Synthesis question""" + rng = Random(self.seed + idx) + + A_square = rng.randint(self.config.min_initial_blocks, self.config.max_initial_blocks) + B_square = rng.randint(self.config.min_initial_blocks, self.config.max_initial_blocks) + C_square = rng.randint(self.config.min_initial_blocks, self.config.max_initial_blocks) + + states = self._get_answer(A_square, B_square, C_square) + answer = states[-1] + answer_str = " ".join(str(x) for x in answer) + + return { + "question": QUESTION_TEMPLATE.format(A_square=A_square, B_square=B_square, C_square=C_square), + "answer": answer_str, + "metadata": {"states": states, "solution": answer}, + } + + +register_dataset("string_synthesis", StringSynthesisDataset, StringSynthesisConfig) diff --git a/tests/test_string_synthesis.py b/tests/test_string_synthesis.py new file mode 100644 index 00000000..39fa4133 --- /dev/null +++ b/tests/test_string_synthesis.py @@ -0,0 +1,119 @@ +"""Tests for String Synthesis questions generation""" + +import pytest + +from reasoning_gym.algorithmic.string_synthesis import StringSynthesisConfig, StringSynthesisDataset + + +def test_string_synthesis_config_validation(): + """Test that invalid configs raise appropriate errors""" + + with pytest.raises(AssertionError): + config = StringSynthesisConfig(min_initial_blocks=-1) # Negative not allowed + config.validate() + + with pytest.raises(AssertionError): + config = StringSynthesisConfig(min_initial_blocks=3, max_initial_blocks=2) # Min > Max + config.validate() + + with pytest.raises(AssertionError): + config = StringSynthesisConfig(max_iterations=0) # Zero not allowed + config.validate() + + +def test_string_synthesis_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = StringSynthesisConfig(seed=42, size=10) + dataset1 = StringSynthesisDataset(config) + dataset2 = StringSynthesisDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_string_synthesis_dataset_items(): + """Test basic properties of generated items""" + config = StringSynthesisConfig(min_initial_blocks=1, max_initial_blocks=3, size=10, seed=42) + dataset = StringSynthesisDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert "states" in item["metadata"] + assert "solution" in item["metadata"] + + states = item["metadata"]["states"] + solution = item["metadata"]["solution"] + + # Verify dimensions + assert len(states) >= 1 + first_state = states[0] + assert len(first_state) == 9 + for i in range(3): + assert 0 <= first_state[i] <= 3 + for i in range(3, 9): + assert first_state[i] == 0 + assert solution == states[-1] + for i in range(9): + assert 0 <= solution[i] + + +def test_string_synthesis_dataset_iteration(): + """Test that iteration respects dataset size""" + config = StringSynthesisConfig(size=5, seed=42) + dataset = StringSynthesisDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset) + + +def test_string_synthesis_answer(): + """Test the _get_answer method""" + config = StringSynthesisConfig(seed=42) + dataset = StringSynthesisDataset(config) + + # Empty input + counts = [0, 0, 0, 0, 0, 0, 0, 0, 0] + assert dataset._apply_rule(counts) == [0, 0, 0, 0, 0, 0, 0, 0, 0] + + # Rule 1 + counts = [1, 1, 1, 0, 0, 0, 0, 0, 0] + assert dataset._apply_rule(counts) == [0, 0, 0, 1, 0, 0, 0, 0, 0] + + # Rule 2 + counts = [1, 1, 0, 0, 0, 0, 0, 0, 0] + assert dataset._apply_rule(counts) == [0, 0, 0, 0, 0, 1, 0, 0, 0] + + # Rule 3 + counts = [0, 1, 1, 0, 0, 0, 0, 0, 0] + assert dataset._apply_rule(counts) == [0, 0, 0, 0, 1, 0, 0, 0, 0] + + # Rule 4 + counts = [0, 0, 2, 0, 0, 0, 0, 0, 0] + assert dataset._apply_rule(counts) == [0, 0, 0, 0, 0, 1, 0, 0, 0] + + # Rule 5 + counts = [0, 0, 0, 1, 0, 1, 0, 0, 0] + assert dataset._apply_rule(counts) == [0, 0, 0, 0, 0, 0, 1, 1, 0] + + # Rule 6 + counts = [0, 0, 0, 0, 2, 0, 0, 0, 0] + assert dataset._apply_rule(counts) == [0, 0, 0, 0, 0, 0, 0, 0, 1] + + # 1-shot example provided in the prompt + A_square, B_square, C_square = 2, 3, 3 + assert dataset._get_answer(A_square, B_square, C_square) == [ + [2, 3, 3, 0, 0, 0, 0, 0, 0], # Initial state + [1, 2, 2, 1, 0, 0, 0, 0, 0], # Rule 1 + [0, 1, 1, 2, 0, 0, 0, 0, 0], # Rule 1 again + [0, 0, 0, 2, 1, 0, 0, 0, 0], # Rule 3 (final state) + ]