From 9bb6d028a30463c363eac1e1db2b6e608d15d0b5 Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Wed, 5 Mar 2025 22:44:04 +0100 Subject: [PATCH] feat(env): Count Bits Curriculum (#267) * add min n * count bits --- reasoning_gym/arithmetic/__init__.py | 3 ++- reasoning_gym/arithmetic/count_bits.py | 25 +++++++++++++++++++-- tests/test_count_bits.py | 30 +++++++++++++++++++++----- 3 files changed, 50 insertions(+), 8 deletions(-) diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 5a5ae48b..3aaedfe9 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -6,7 +6,7 @@ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConf from .bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticDataset from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .chain_sum import ChainSumConfig, ChainSumDataset -from .count_bits import CountBitsConfig, CountBitsDataset +from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticDataset from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumDataset from .dice import DiceConfig, DiceDataset @@ -48,6 +48,7 @@ __all__ = [ "TimeIntervalsDataset", "CountBitsConfig", "CountBitsDataset", + "CountBitsCurriculum", "DiceConfig", "DiceDataset", "NumberFormatConfig", diff --git a/reasoning_gym/arithmetic/count_bits.py b/reasoning_gym/arithmetic/count_bits.py index 5dc2c099..059df677 100644 --- a/reasoning_gym/arithmetic/count_bits.py +++ b/reasoning_gym/arithmetic/count_bits.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from random import Random from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """How many 1 bits are there in the binary representation of the number {number}?""" @@ -13,6 +14,7 @@ QUESTION_TEMPLATE = """How many 1 bits are there in the binary representation of class CountBitsConfig: """Configuration for Count Bits dataset generation""" + min_n: int = 1 # Minimum number to consider max_n: int = 2**31 - 1 # Maximum number to consider size: int = 500 # Virtual dataset size @@ -20,7 +22,7 @@ class CountBitsConfig: def validate(self): """Validate configuration parameters""" - assert 1 <= self.max_n, "max_n must be at least 1" + assert 1 <= self.min_n <= self.max_n, "min_n must be between 1 and max_n" class CountBitsDataset(ProceduralDataset): @@ -33,7 +35,7 @@ class CountBitsDataset(ProceduralDataset): """Generate a single Count Bits question""" rng = Random(self.seed + idx) - number = rng.randint(1, self.config.max_n) + number = rng.randint(self.config.min_n, self.config.max_n) binary = bin(number)[2:] answer = binary.count("1") @@ -44,4 +46,23 @@ class CountBitsDataset(ProceduralDataset): } +class CountBitsCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(CountBitsCurriculum.__name__, CountBitsConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="n", + levels=[1_000, 1_000_000, 100_000_000, 2**31 - 1], + default_level=0, + description="Number to count bits in", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_n", + upper_field_name="max_n", + ), + ) + + register_dataset("count_bits", CountBitsDataset, CountBitsConfig) diff --git a/tests/test_count_bits.py b/tests/test_count_bits.py index 6a36c886..86134368 100644 --- a/tests/test_count_bits.py +++ b/tests/test_count_bits.py @@ -2,17 +2,21 @@ import pytest -from reasoning_gym.arithmetic.count_bits import CountBitsConfig, CountBitsDataset +from reasoning_gym.arithmetic.count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset def test_count_bits_config_validation(): """Test that invalid configs raise appropriate errors""" with pytest.raises(AssertionError): - config = CountBitsConfig(max_n=-1) # Negative not allowed + config = CountBitsConfig(min_n=-1) # Negative not allowed config.validate() with pytest.raises(AssertionError): - config = CountBitsConfig(max_n=0) # Zero not allowed + config = CountBitsConfig(min_n=0) # Zero not allowed + config.validate() + + with pytest.raises(AssertionError): + config = CountBitsConfig(min_n=10, max_n=5) # min_n > max_n config.validate() @@ -28,7 +32,7 @@ def test_count_bits_dataset_deterministic(): def test_count_bits_dataset_items(): """Test basic properties of generated items""" - config = CountBitsConfig(max_n=10, size=10, seed=42) + config = CountBitsConfig(min_n=1, max_n=1_000_000, size=10, seed=42) dataset = CountBitsDataset(config) for i in range(len(dataset)): @@ -49,7 +53,7 @@ def test_count_bits_dataset_items(): binary = item["metadata"]["binary"] # Verify values - assert number <= config.max_n + assert config.min_n <= number <= config.max_n assert solution >= 0 assert set(binary) <= {"0", "1"} @@ -81,3 +85,19 @@ def test_count_bits_answer(): count += number & 1 number >>= 1 assert solution == count + + +def test_count_bits_curriculum(): + curriculum = CountBitsCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: CountBitsConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_n == 1_000 and base_cfg.max_n == 1_000 + + # test incrementing attribute levels + curriculum.increment_attr_level("n") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_n == 1_000 and increased_cfg.max_n == 1_000_000