diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 79b4aec2..ad9cccd3 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -8,6 +8,7 @@ Algorithmic tasks for training reasoning capabilities: from .ab import ABConfig, ABDataset from .base_conversion import BaseConversionConfig, BaseConversionDataset +from .binary_alternation import BinaryAlternationConfig, BinaryAlternationDataset from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset from .count_primes import CountPrimesConfig, CountPrimesDataset @@ -38,7 +39,6 @@ from .string_synthesis import StringSynthesisConfig, StringSynthesisDataset from .word_ladder import WordLadderConfig, WordLadderDataset from .word_sequence_reversal import WordSequenceReversalConfig, WordSequenceReversalDataset from .word_sorting import TextTransformation, WordSortingConfig, WordSortingDataset -from .binary_alternation import BinaryAlternationConfig, BinaryAlternationDataset __all__ = [ "SpellBackwardConfig", diff --git a/reasoning_gym/algorithmic/binary_alternation.py b/reasoning_gym/algorithmic/binary_alternation.py index 29b74a13..ca204c6d 100644 --- a/reasoning_gym/algorithmic/binary_alternation.py +++ b/reasoning_gym/algorithmic/binary_alternation.py @@ -27,8 +27,8 @@ Now, determine the minimum number of swaps to make the following binary string a class BinaryAlternationConfig: """Configuration for Count Bits dataset generation""" - min_n: int = 10 # Minimum number of bits in the binary string - max_n: int = 30 # Maximum number of bits in the binary string + min_n: int = 10 # Minimum number of bits in the binary string + max_n: int = 30 # Maximum number of bits in the binary string p_solvable: float = 0.8 # Probability of generating a solvable sample size: int = 500 # Virtual dataset size @@ -50,7 +50,7 @@ class BinaryAlternationDataset(ProceduralDataset): def _get_binary_string(self, rng: Random, solvable: bool) -> str: n = rng.randint(self.config.min_n, self.config.max_n) ones, zeros = n // 2, n // 2 - + # Check if we need to add an extra bit if n % 2 == 1: if rng.random() < 0.5: @@ -69,13 +69,12 @@ class BinaryAlternationDataset(ProceduralDataset): ones += 2 else: zeros += 2 - + # Generate the string string = ["1"] * ones + ["0"] * zeros rng.shuffle(string) return "".join(string) - def _get_answer(self, string: str) -> int: """Calculate the minimum number of swaps to make the string alternating""" @@ -85,11 +84,11 @@ class BinaryAlternationDataset(ProceduralDataset): if c != expected: incorrect += 1 expected = "1" if expected == "0" else "0" - return incorrect // 2 # number of swaps is half of incorrect positions - + return incorrect // 2 # number of swaps is half of incorrect positions + ones, zeros = string.count("1"), string.count("0") - if abs(ones-zeros) > 1: - return -1 # impossible to make alternating + if abs(ones - zeros) > 1: + return -1 # impossible to make alternating if ones > zeros: return get_num_swaps("1") elif ones < zeros: @@ -111,4 +110,5 @@ class BinaryAlternationDataset(ProceduralDataset): "metadata": {"string": string, "solution": answer, "solvable": solvable}, } + register_dataset("binary_alternation", BinaryAlternationDataset, BinaryAlternationConfig) diff --git a/tests/test_binary_alternation.py b/tests/test_binary_alternation.py index 2e02e7e4..81a581e6 100644 --- a/tests/test_binary_alternation.py +++ b/tests/test_binary_alternation.py @@ -26,7 +26,7 @@ def test_binary_alternation_config_validation(): with pytest.raises(AssertionError): config = BinaryAlternationConfig(p_solvable=-0.01) # < 0 not allowed config.validate() - + with pytest.raises(AssertionError): config = BinaryAlternationConfig(p_solvable=1.01) # > 0 not allowed config.validate() @@ -101,4 +101,4 @@ def test_binary_alternation_answer(): # One shot example string = "111000" - assert dataset._get_answer(string) == 1 \ No newline at end of file + assert dataset._get_answer(string) == 1