diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index a7e6ab73..61dab185 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -3,7 +3,7 @@ Arithmetic tasks for training reasoning capabilities: """ from .basic_arithmetic import BasicArithmeticCurriculum, BasicArithmeticDataset, BasicArithmeticDatasetConfig -from .bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticDataset +from .bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticCurriculum, BitwiseArithmeticDataset from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .chain_sum import ChainSumConfig, ChainSumDataset from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset @@ -63,4 +63,5 @@ __all__ = [ "DecimalChainSumDataset", "BitwiseArithmeticConfig", "BitwiseArithmeticDataset", + "BitwiseArithmeticCurriculum", ] diff --git a/reasoning_gym/arithmetic/bitwise_arithmetic.py b/reasoning_gym/arithmetic/bitwise_arithmetic.py index 97de426b..8b97221c 100644 --- a/reasoning_gym/arithmetic/bitwise_arithmetic.py +++ b/reasoning_gym/arithmetic/bitwise_arithmetic.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -151,7 +152,11 @@ class BitwiseArithmeticDataset(ProceduralDataset): + problem ) - return {"question": problem_str, "answer": answer, "metadata": {"problem": problem}} + return { + "question": problem_str, + "answer": answer, + "metadata": {"problem": problem, "difficulty": {"difficulty": self.config.difficulty}}, + } def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """ @@ -171,5 +176,24 @@ class BitwiseArithmeticDataset(ProceduralDataset): return 0.0 +class BitwiseArithmeticCurriculum(BaseCurriculum): + """Curriculum for Bitwise Arithmetic dataset""" + + def __init__(self): + super().__init__(BitwiseArithmeticCurriculum.__name__, BitwiseArithmeticConfig) + + self._define_attributes( + ScalarAttributeDefinition( + name="difficulty", + levels=[1, 2, 3, 4], + default_level=0, + description="Range of difficulty levels", + attr_type=AttributeType.STATIC, + min_value=1, + field_name="difficulty", + ), + ) + + # Register the dataset with the factory. -register_dataset("bitwise_arithmetic", BitwiseArithmeticDataset, BitwiseArithmeticConfig) +register_dataset("bitwise_arithmetic", BitwiseArithmeticDataset, BitwiseArithmeticConfig, BitwiseArithmeticCurriculum) diff --git a/tests/test_bitwise_arithmetic.py b/tests/test_bitwise_arithmetic.py index 854c764a..2756f5c0 100644 --- a/tests/test_bitwise_arithmetic.py +++ b/tests/test_bitwise_arithmetic.py @@ -1,6 +1,10 @@ import pytest -from reasoning_gym.arithmetic.bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticDataset +from reasoning_gym.arithmetic.bitwise_arithmetic import ( + BitwiseArithmeticConfig, + BitwiseArithmeticCurriculum, + BitwiseArithmeticDataset, +) def test_bitwise_arithmetic_config_validation(): @@ -116,3 +120,28 @@ def test_bitwise_arithmetic_answer_formats(): elif not correct.startswith("0x"): # For positive numbers without prefix assert dataset.score_answer(answer="0x" + correct, entry=item) == 1.0 + + +def test_bitwise_arithmetic_curriculum(): + """Test that curriculum generates appropriate configurations""" + + curriculum = BitwiseArithmeticCurriculum() + + base_value = {"size": 500, "seed": 42} + + base_cfg: BitwiseArithmeticConfig = curriculum.generate_configuration(base_value) + assert base_cfg.difficulty == 1 + assert base_cfg.size == 500 + assert base_cfg.seed == 42 + + curriculum.set_attr_level("difficulty", 1) # 0-indexed + cfg: BitwiseArithmeticConfig = curriculum.generate_configuration() + assert cfg.difficulty == 2 + + curriculum.increment_attr_level("difficulty") + cfg: BitwiseArithmeticConfig = curriculum.generate_configuration() + assert cfg.difficulty == 3 + + curriculum.decrement_attr_level("difficulty") + cfg: BitwiseArithmeticConfig = curriculum.generate_configuration() + assert cfg.difficulty == 2