diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 999b4521..05d321da 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -5,6 +5,7 @@ Arithmetic tasks for training reasoning capabilities: from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .chain_sum import ChainSum, ChainSumConfig +from .count_bits import CountBitsConfig, CountBitsDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset from .gcd import GCDConfig, GCDDataset from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig @@ -35,4 +36,6 @@ __all__ = [ "GSMSymbolicDataset", "TimeIntervalsConfig", "TimeIntervalsDataset", + "CountBitsConfig", + "CountBitsDataset", ] diff --git a/reasoning_gym/arithmetic/count_bits.py b/reasoning_gym/arithmetic/count_bits.py new file mode 100644 index 00000000..5dc2c099 --- /dev/null +++ b/reasoning_gym/arithmetic/count_bits.py @@ -0,0 +1,47 @@ +"""Count number of 1 bits in a number.""" + +from dataclasses import dataclass +from random import Random +from typing import Optional + +from ..factory import ProceduralDataset, register_dataset + +QUESTION_TEMPLATE = """How many 1 bits are there in the binary representation of the number {number}?""" + + +@dataclass +class CountBitsConfig: + """Configuration for Count Bits dataset generation""" + + max_n: int = 2**31 - 1 # Maximum number to consider + + size: int = 500 # Virtual dataset size + seed: Optional[int] = None + + def validate(self): + """Validate configuration parameters""" + assert 1 <= self.max_n, "max_n must be at least 1" + + +class CountBitsDataset(ProceduralDataset): + """Generates Count Bits exercises with configurable difficulty""" + + def __init__(self, config: CountBitsConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> dict: + """Generate a single Count Bits question""" + rng = Random(self.seed + idx) + + number = rng.randint(1, self.config.max_n) + binary = bin(number)[2:] + answer = binary.count("1") + + return { + "question": QUESTION_TEMPLATE.format(number=number), + "answer": str(answer), + "metadata": {"number": number, "solution": answer, "binary": binary}, + } + + +register_dataset("count_bits", CountBitsDataset, CountBitsConfig) diff --git a/tests/test_count_bits.py b/tests/test_count_bits.py new file mode 100644 index 00000000..6a36c886 --- /dev/null +++ b/tests/test_count_bits.py @@ -0,0 +1,83 @@ +"""Tests for Count bits questions generation""" + +import pytest + +from reasoning_gym.arithmetic.count_bits import CountBitsConfig, CountBitsDataset + + +def test_count_bits_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = CountBitsConfig(max_n=-1) # Negative not allowed + config.validate() + + with pytest.raises(AssertionError): + config = CountBitsConfig(max_n=0) # Zero not allowed + config.validate() + + +def test_count_bits_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = CountBitsConfig(seed=42, size=10) + dataset1 = CountBitsDataset(config) + dataset2 = CountBitsDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_count_bits_dataset_items(): + """Test basic properties of generated items""" + config = CountBitsConfig(max_n=10, size=10, seed=42) + dataset = CountBitsDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + # Check item structure + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata + assert "number" in item["metadata"] + assert "solution" in item["metadata"] + assert "binary" in item["metadata"] + + number = item["metadata"]["number"] + solution = item["metadata"]["solution"] + binary = item["metadata"]["binary"] + + # Verify values + assert number <= config.max_n + assert solution >= 0 + assert set(binary) <= {"0", "1"} + + +def test_count_bits_dataset_iteration(): + """Test that iteration respects dataset size""" + config = CountBitsConfig(size=5, seed=42) + dataset = CountBitsDataset(config) + + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same items + assert items == list(dataset) + + +def test_count_bits_answer(): + """Verify the number of 1 bits in the binary representation of a number""" + config = CountBitsConfig(size=5, seed=42) + dataset = CountBitsDataset(config) + + for item in dataset: + number = item["metadata"]["number"] + solution = item["metadata"]["solution"] + + # Count number of 1 bits in the number by shifting + count = 0 + while number: + count += number & 1 + number >>= 1 + assert solution == count