diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 384f9cf4..a62aed26 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -8,7 +8,7 @@ Algorithmic tasks for training reasoning capabilities: from .ab import ABConfig, ABDataset from .base_conversion import BaseConversionConfig, BaseConversionDataset -from .binary_alternation import BinaryAlternationConfig, BinaryAlternationDataset +from .binary_alternation import BinaryAlternationConfig, BinaryAlternationCurriculum, BinaryAlternationDataset from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset from .count_primes import CountPrimesConfig, CountPrimesDataset @@ -111,4 +111,5 @@ __all__ = [ "JugsDataset", "BinaryAlternationConfig", "BinaryAlternationDataset", + "BinaryAlternationCurriculum", ] diff --git a/reasoning_gym/algorithmic/binary_alternation.py b/reasoning_gym/algorithmic/binary_alternation.py index ea50b0c8..7180291d 100644 --- a/reasoning_gym/algorithmic/binary_alternation.py +++ b/reasoning_gym/algorithmic/binary_alternation.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from random import Random from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """Given a binary string, return the minimum number of character swaps to make it alternating, or -1 if it is impossible. @@ -43,8 +44,7 @@ class BinaryAlternationDataset(ProceduralDataset): def __init__(self, config: BinaryAlternationConfig): super().__init__(config=config, seed=config.seed, size=config.size) - def _get_binary_string(self, rng: Random, solvable: bool) -> str: - n = rng.randint(self.config.min_n, self.config.max_n) + def _get_binary_string(self, rng: Random, n: int, solvable: bool) -> str: ones, zeros = n // 2, n // 2 # Check if we need to add an extra bit @@ -96,15 +96,40 @@ class BinaryAlternationDataset(ProceduralDataset): """Generate a single Count Bits question""" rng = Random(self.seed + idx) + n = rng.randint(self.config.min_n, self.config.max_n) solvable = rng.random() < self.config.p_solvable - string = self._get_binary_string(rng, solvable) + string = self._get_binary_string(rng, n, solvable) answer = self._get_answer(string) return { "question": QUESTION_TEMPLATE.format(string=string), "answer": str(answer), - "metadata": {"string": string, "solution": answer, "solvable": solvable}, + "metadata": { + "string": string, + "solution": answer, + "solvable": solvable, + "difficulty": {"n": n}, + }, } -register_dataset("binary_alternation", BinaryAlternationDataset, BinaryAlternationConfig) +class BinaryAlternationCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(BinaryAlternationCurriculum.__name__, BinaryAlternationConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="n", + levels=[10, 50, 500, 1000], + default_level=0, + description="Number of bits in the binary string", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_n", + upper_field_name="max_n", + ) + ) + + +register_dataset("binary_alternation", BinaryAlternationDataset, BinaryAlternationConfig, BinaryAlternationCurriculum) diff --git a/tests/test_binary_alternation.py b/tests/test_binary_alternation.py index 81a581e6..bed89afd 100644 --- a/tests/test_binary_alternation.py +++ b/tests/test_binary_alternation.py @@ -2,7 +2,11 @@ import pytest -from reasoning_gym.algorithmic.binary_alternation import BinaryAlternationConfig, BinaryAlternationDataset +from reasoning_gym.algorithmic.binary_alternation import ( + BinaryAlternationConfig, + BinaryAlternationCurriculum, + BinaryAlternationDataset, +) def test_binary_alternation_config_validation(): @@ -102,3 +106,24 @@ def test_binary_alternation_answer(): # One shot example string = "111000" assert dataset._get_answer(string) == 1 + + +def test_chain_sum_curriculum(): + curriculum = BinaryAlternationCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: BinaryAlternationConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_n == 10 and base_cfg.max_n == 10 + + # test incrementing attribute levels + curriculum.increment_attr_level("n") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_n == 10 and increased_cfg.max_n == 50 + + # test decrementing attribute levels + curriculum.decrement_attr_level("n") + decreased_cfg = curriculum.generate_configuration(base_value) + assert decreased_cfg.min_n == 10 and decreased_cfg.max_n == 10