reasoning-gym/reasoning_gym/arithmetic/count_bits.py
Zafir Stojanovski dced3bfc45
fix(curriculum): Make boundaries in curriculum more sensible (#407)
* init

* fix tests

* unify codeio

* filtered for libraries not present in reasoning-gym

* fix more bounds

* puzzle24

* knight swap curriculum

* fix number sorting

* fix attributes

* add validation of config in creation of dataset

* dry run for instantiating and validating the datasets

* remove unused imports

* fix curriculum tests to reference newly updated attribute names
2025-04-04 20:24:14 +02:00

78 lines
2.4 KiB
Python

"""Count number of 1 bits in a number."""
from dataclasses import dataclass
from random import Random
from typing import Optional
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
# Prompt text for every generated item; the ``{number}`` placeholder is
# filled in with ``str.format`` when an item is built.
QUESTION_TEMPLATE = "How many 1 bits are there in the binary representation of the number {number}?"

# Identifier handed to ``register_dataset`` and recorded as ``source_dataset``
# in each item's metadata.
DATASET_NAME = "count_bits"
@dataclass
class CountBitsConfig:
    """Settings that control how Count Bits items are generated."""

    # Smallest number that may be drawn (inclusive).
    min_n: int = 1
    # Largest number that may be drawn (inclusive).
    max_n: int = 2**31 - 1
    # Virtual dataset size.
    size: int = 500
    # Base seed; ``None`` leaves the choice to the dataset machinery.
    seed: Optional[int] = None

    def validate(self):
        """Check that ``1 <= min_n <= max_n``.

        Raises:
            AssertionError: If ``min_n`` is below 1 or exceeds ``max_n``.
        """
        lower, upper = self.min_n, self.max_n
        assert lower >= 1 and lower <= upper, "min_n must be between 1 and max_n"
class CountBitsDataset(ProceduralDataset):
    """Dataset of questions asking for the number of 1 bits in an integer."""

    def __init__(self, config: CountBitsConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Build the Count Bits item for position ``idx``.

        Args:
            idx: Index of the item; it is added to ``self.seed`` to seed the
                random generator used for this item.

        Returns:
            A dict with the formatted ``question``, the ``answer`` (the count
            of 1 bits as a string) and a ``metadata`` dict describing the
            drawn number, its binary form and the configured range.
        """
        # A fresh generator per item, seeded from the dataset seed and index.
        item_rng = Random(self.seed + idx)
        cfg = self.config
        value = item_rng.randint(cfg.min_n, cfg.max_n)

        # bin() yields e.g. "0b1011"; slice off the "0b" prefix.
        bit_string = bin(value)[2:]
        ones = bit_string.count("1")

        metadata = {
            "source_dataset": DATASET_NAME,
            "source_index": idx,
            "number": value,
            "solution": ones,
            "binary": bit_string,
            "n": value,
            "difficulty": {
                "n": (cfg.min_n, cfg.max_n),
            },
        }
        return {
            "question": QUESTION_TEMPLATE.format(number=value),
            "answer": str(ones),
            "metadata": metadata,
        }
class CountBitsCurriculum(BaseCurriculum):
    """Curriculum for Count Bits exposing one ranged attribute, ``n``."""

    def __init__(self):
        super().__init__(CountBitsCurriculum.__name__, CountBitsConfig)

        # Single range attribute whose bounds are written to the config's
        # ``min_n`` / ``max_n`` fields.
        n_range = RangeAttributeDefinition(
            name="n",
            levels=[10, 1_000, 1_000_000, 100_000_000, 2**31 - 1],
            description="Number to count bits in",
            lower_field_name="min_n",
            upper_field_name="max_n",
            ensure_interval=True,
        )
        self._define_attributes(n_range)
# Pass the dataset name together with its dataset, config and curriculum
# classes to the factory's ``register_dataset``.
register_dataset(DATASET_NAME, CountBitsDataset, CountBitsConfig, CountBitsCurriculum)