mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
count bits (#101)
This commit is contained in:
parent
a8c39ddcfb
commit
ed10111834
3 changed files with 133 additions and 0 deletions
|
|
@ -5,6 +5,7 @@ Arithmetic tasks for training reasoning capabilities:
|
|||
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
|
||||
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
|
||||
from .chain_sum import ChainSum, ChainSumConfig
|
||||
from .count_bits import CountBitsConfig, CountBitsDataset
|
||||
from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset
|
||||
from .gcd import GCDConfig, GCDDataset
|
||||
from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig
|
||||
|
|
@ -35,4 +36,6 @@ __all__ = [
|
|||
"GSMSymbolicDataset",
|
||||
"TimeIntervalsConfig",
|
||||
"TimeIntervalsDataset",
|
||||
"CountBitsConfig",
|
||||
"CountBitsDataset",
|
||||
]
|
||||
|
|
|
|||
47
reasoning_gym/arithmetic/count_bits.py
Normal file
47
reasoning_gym/arithmetic/count_bits.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
"""Count number of 1 bits in a number."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
# Prompt template; the {number} placeholder is filled in with the sampled integer.
QUESTION_TEMPLATE = """How many 1 bits are there in the binary representation of the number {number}?"""
|
||||
|
||||
|
||||
@dataclass
class CountBitsConfig:
    """Configuration for Count Bits dataset generation"""

    max_n: int = 2**31 - 1  # Maximum number to consider

    size: int = 500  # Virtual dataset size
    seed: Optional[int] = None  # Optional base seed

    def validate(self):
        """Validate configuration parameters.

        Raises:
            AssertionError: If ``max_n`` is smaller than 1.
        """
        # Raise explicitly instead of using an ``assert`` statement so the
        # check still runs under ``python -O``; the exception type and
        # message are unchanged for existing callers.
        if not 1 <= self.max_n:
            raise AssertionError("max_n must be at least 1")
|
||||
|
||||
|
||||
class CountBitsDataset(ProceduralDataset):
    """Procedural dataset of "count the 1 bits" questions"""

    def __init__(self, config: CountBitsConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """Build the Count Bits item for position ``idx``.

        The item is derived from a fresh RNG seeded with ``self.seed + idx``.

        Returns:
            A dict with the formatted ``question``, the ``answer`` as a
            string, and ``metadata`` holding the sampled number, its bit
            count and its binary digits.
        """
        item_rng = Random(self.seed + idx)
        value = item_rng.randint(1, self.config.max_n)

        # value is always >= 1, so the "b" format yields the plain binary digits
        bit_string = format(value, "b")
        ones = bit_string.count("1")
        metadata = {"number": value, "solution": ones, "binary": bit_string}

        return {
            "question": QUESTION_TEMPLATE.format(number=value),
            "answer": str(ones),
            "metadata": metadata,
        }
|
||||
|
||||
|
||||
# Hand the dataset and config classes to the factory under the name "count_bits".
register_dataset("count_bits", CountBitsDataset, CountBitsConfig)
|
||||
83
tests/test_count_bits.py
Normal file
83
tests/test_count_bits.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
"""Tests for Count bits questions generation"""
|
||||
|
||||
import pytest
|
||||
|
||||
from reasoning_gym.arithmetic.count_bits import CountBitsConfig, CountBitsDataset
|
||||
|
||||
|
||||
def test_count_bits_config_validation():
    """Configs whose max_n is below 1 must fail validation"""
    # Negative and zero upper bounds are both rejected
    for invalid_max_n in (-1, 0):
        with pytest.raises(AssertionError):
            config = CountBitsConfig(max_n=invalid_max_n)
            config.validate()
|
||||
|
||||
|
||||
def test_count_bits_dataset_deterministic():
    """Two datasets built from the same seeded config must agree item by item"""
    config = CountBitsConfig(seed=42, size=10)
    first, second = CountBitsDataset(config), CountBitsDataset(config)

    for idx in range(len(first)):
        assert first[idx] == second[idx]
|
||||
|
||||
|
||||
def test_count_bits_dataset_items():
    """Test structure and internal consistency of generated items"""
    config = CountBitsConfig(max_n=10, size=10, seed=42)
    dataset = CountBitsDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]

        # Check item structure
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Check metadata
        assert "number" in item["metadata"]
        assert "solution" in item["metadata"]
        assert "binary" in item["metadata"]

        number = item["metadata"]["number"]
        solution = item["metadata"]["solution"]
        binary = item["metadata"]["binary"]

        # Verify values: numbers are sampled from [1, max_n] inclusive
        assert 1 <= number <= config.max_n

        # The binary string must be the unprefixed base-2 form of the number
        assert set(binary) <= {"0", "1"}
        assert binary == bin(number)[2:]

        # A positive number has at least one set bit, and the solution must
        # agree with both the binary string and the string answer
        assert solution >= 1
        assert solution == binary.count("1")
        assert item["answer"] == str(solution)

        # The question must mention the sampled number
        assert str(number) in item["question"]
|
||||
|
||||
|
||||
def test_count_bits_dataset_iteration():
    """Iterating yields exactly ``size`` items, and does so repeatably"""
    config = CountBitsConfig(size=5, seed=42)
    dataset = CountBitsDataset(config)

    first_pass = [item for item in dataset]
    assert len(first_pass) == config.size

    # A second pass over the same dataset must reproduce the first
    second_pass = [item for item in dataset]
    assert first_pass == second_pass
|
||||
|
||||
|
||||
def test_count_bits_answer():
    """Check each stored solution against an independently computed bit count"""
    config = CountBitsConfig(size=5, seed=42)
    dataset = CountBitsDataset(config)

    for item in dataset:
        remaining = item["metadata"]["number"]
        expected = 0

        # Each step clears the lowest set bit, so the loop runs once per 1 bit
        while remaining:
            remaining &= remaining - 1
            expected += 1

        assert item["metadata"]["solution"] == expected
|
||||
Loading…
Add table
Add a link
Reference in a new issue