feat(env): Binary Alternation Curriculum (#278)

* binary alternation

---------

Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
Zafir Stojanovski 2025-03-07 22:44:32 +01:00 committed by GitHub
parent 0a35e608ec
commit a8e920b552
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 58 additions and 7 deletions

View file

@ -8,7 +8,7 @@ Algorithmic tasks for training reasoning capabilities:
from .ab import ABConfig, ABDataset
from .base_conversion import BaseConversionConfig, BaseConversionDataset
from .binary_alternation import BinaryAlternationConfig, BinaryAlternationDataset
from .binary_alternation import BinaryAlternationConfig, BinaryAlternationCurriculum, BinaryAlternationDataset
from .binary_matrix import BinaryMatrixConfig, BinaryMatrixDataset
from .caesar_cipher import CaesarCipherConfig, CaesarCipherDataset
from .count_primes import CountPrimesConfig, CountPrimesDataset
@ -111,4 +111,5 @@ __all__ = [
"JugsDataset",
"BinaryAlternationConfig",
"BinaryAlternationDataset",
"BinaryAlternationCurriculum",
]

View file

@ -7,6 +7,7 @@ from dataclasses import dataclass
from random import Random
from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Given a binary string, return the minimum number of character swaps to make it alternating, or -1 if it is impossible.
@ -43,8 +44,7 @@ class BinaryAlternationDataset(ProceduralDataset):
def __init__(self, config: BinaryAlternationConfig):
    """Initialise the dataset, forwarding the config plus its seed and size to the base class."""
    super().__init__(
        config=config,
        seed=config.seed,
        size=config.size,
    )
def _get_binary_string(self, rng: Random, solvable: bool) -> str:
n = rng.randint(self.config.min_n, self.config.max_n)
def _get_binary_string(self, rng: Random, n: int, solvable: bool) -> str:
ones, zeros = n // 2, n // 2
# Check if we need to add an extra bit
@ -96,15 +96,40 @@ class BinaryAlternationDataset(ProceduralDataset):
"""Generate a single Count Bits question"""
rng = Random(self.seed + idx)
n = rng.randint(self.config.min_n, self.config.max_n)
solvable = rng.random() < self.config.p_solvable
string = self._get_binary_string(rng, solvable)
string = self._get_binary_string(rng, n, solvable)
answer = self._get_answer(string)
return {
"question": QUESTION_TEMPLATE.format(string=string),
"answer": str(answer),
"metadata": {"string": string, "solution": answer, "solvable": solvable},
"metadata": {
"string": string,
"solution": answer,
"solvable": solvable,
"difficulty": {"n": n},
},
}
register_dataset("binary_alternation", BinaryAlternationDataset, BinaryAlternationConfig)
class BinaryAlternationCurriculum(BaseCurriculum):
    """Curriculum for the binary alternation task.

    Declares a single range attribute, ``n`` (number of bits in the binary
    string), whose bounds are mapped onto the ``min_n`` / ``max_n`` fields of
    ``BinaryAlternationConfig``.
    """

    def __init__(self):
        super().__init__(BinaryAlternationCurriculum.__name__, BinaryAlternationConfig)

        # Build the attribute definition first, then register it in one call.
        string_length = RangeAttributeDefinition(
            name="n",
            levels=[10, 50, 500, 1000],
            default_level=0,
            description="Number of bits in the binary string",
            attr_type=AttributeType.APPEND,
            min_value=1,
            lower_field_name="min_n",
            upper_field_name="max_n",
        )
        self._define_attributes(string_length)
register_dataset("binary_alternation", BinaryAlternationDataset, BinaryAlternationConfig, BinaryAlternationCurriculum)

View file

@ -2,7 +2,11 @@
import pytest
from reasoning_gym.algorithmic.binary_alternation import BinaryAlternationConfig, BinaryAlternationDataset
from reasoning_gym.algorithmic.binary_alternation import (
BinaryAlternationConfig,
BinaryAlternationCurriculum,
BinaryAlternationDataset,
)
def test_binary_alternation_config_validation():
@ -102,3 +106,24 @@ def test_binary_alternation_answer():
# One shot example
string = "111000"
assert dataset._get_answer(string) == 1
def test_chain_sum_curriculum():
    """Check that the ``n`` attribute level of the curriculum drives ``min_n`` / ``max_n``."""
    # NOTE(review): the name says "chain_sum", but this test exercises
    # BinaryAlternationCurriculum -- consider renaming it.
    curriculum = BinaryAlternationCurriculum()
    overrides = {"size": 150, "seed": 1}

    def n_bounds(cfg):
        # Collapse the two range fields into one tuple for compact comparisons.
        return cfg.min_n, cfg.max_n

    # Default level: seed/size are passed through and the range is 10..10.
    initial_cfg: BinaryAlternationConfig = curriculum.generate_configuration(overrides)
    assert (initial_cfg.seed, initial_cfg.size) == (1, 150)
    assert n_bounds(initial_cfg) == (10, 10)

    # Raising the level extends the upper bound to the next entry.
    curriculum.increment_attr_level("n")
    harder_cfg = curriculum.generate_configuration(overrides)
    assert n_bounds(harder_cfg) == (10, 50)

    # Lowering the level restores the original range.
    curriculum.decrement_attr_level("n")
    easier_cfg = curriculum.generate_configuration(overrides)
    assert n_bounds(easier_cfg) == (10, 10)