feat(env): Binary Alternation Curriculum (#278)

* binary alternation

---------

Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
Zafir Stojanovski 2025-03-07 22:44:32 +01:00 committed by GitHub
parent ce55d528ad
commit dfc28c94d6
3 changed files with 58 additions and 7 deletions

View file

@ -7,6 +7,7 @@ from dataclasses import dataclass
from random import Random
from typing import Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
QUESTION_TEMPLATE = """Given a binary string, return the minimum number of character swaps to make it alternating, or -1 if it is impossible.
@ -43,8 +44,7 @@ class BinaryAlternationDataset(ProceduralDataset):
def __init__(self, config: BinaryAlternationConfig):
super().__init__(config=config, seed=config.seed, size=config.size)
def _get_binary_string(self, rng: Random, solvable: bool) -> str:
n = rng.randint(self.config.min_n, self.config.max_n)
def _get_binary_string(self, rng: Random, n: int, solvable: bool) -> str:
ones, zeros = n // 2, n // 2
# Check if we need to add an extra bit
@ -96,15 +96,40 @@ class BinaryAlternationDataset(ProceduralDataset):
"""Generate a single Count Bits question"""
rng = Random(self.seed + idx)
n = rng.randint(self.config.min_n, self.config.max_n)
solvable = rng.random() < self.config.p_solvable
string = self._get_binary_string(rng, solvable)
string = self._get_binary_string(rng, n, solvable)
answer = self._get_answer(string)
return {
"question": QUESTION_TEMPLATE.format(string=string),
"answer": str(answer),
"metadata": {"string": string, "solution": answer, "solvable": solvable},
"metadata": {
"string": string,
"solution": answer,
"solvable": solvable,
"difficulty": {"n": n},
},
}
register_dataset("binary_alternation", BinaryAlternationDataset, BinaryAlternationConfig)
class BinaryAlternationCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(BinaryAlternationCurriculum.__name__, BinaryAlternationConfig)
# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="n",
levels=[10, 50, 500, 1000],
default_level=0,
description="Number of bits in the binary string",
attr_type=AttributeType.APPEND,
min_value=1,
lower_field_name="min_n",
upper_field_name="max_n",
)
)
register_dataset("binary_alternation", BinaryAlternationDataset, BinaryAlternationConfig, BinaryAlternationCurriculum)