Bitwise arithmetic curriculum (#282)

* bitwise_arithmetic curriculum
* register BitwiseArithmeticCurriculum

---------

Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
vncntt 2025-03-07 16:32:00 -08:00 committed by GitHub
parent 444c793d3f
commit 775a42e9e4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 58 additions and 4 deletions

View file

@ -3,7 +3,7 @@ Arithmetic tasks for training reasoning capabilities:
""" """
from .basic_arithmetic import BasicArithmeticCurriculum, BasicArithmeticDataset, BasicArithmeticDatasetConfig from .basic_arithmetic import BasicArithmeticCurriculum, BasicArithmeticDataset, BasicArithmeticDatasetConfig
from .bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticDataset from .bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticCurriculum, BitwiseArithmeticDataset
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
from .chain_sum import ChainSumConfig, ChainSumDataset from .chain_sum import ChainSumConfig, ChainSumDataset
from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset
@ -63,4 +63,5 @@ __all__ = [
"DecimalChainSumDataset", "DecimalChainSumDataset",
"BitwiseArithmeticConfig", "BitwiseArithmeticConfig",
"BitwiseArithmeticDataset", "BitwiseArithmeticDataset",
"BitwiseArithmeticCurriculum",
] ]

View file

@ -2,6 +2,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Any, Optional from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -151,7 +152,11 @@ class BitwiseArithmeticDataset(ProceduralDataset):
+ problem + problem
) )
return {"question": problem_str, "answer": answer, "metadata": {"problem": problem}} return {
"question": problem_str,
"answer": answer,
"metadata": {"problem": problem, "difficulty": {"difficulty": self.config.difficulty}},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
""" """
@ -171,5 +176,24 @@ class BitwiseArithmeticDataset(ProceduralDataset):
return 0.0 return 0.0
class BitwiseArithmeticCurriculum(BaseCurriculum):
"""Curriculum for Bitwise Arithmetic dataset"""
def __init__(self):
super().__init__(BitwiseArithmeticCurriculum.__name__, BitwiseArithmeticConfig)
self._define_attributes(
ScalarAttributeDefinition(
name="difficulty",
levels=[1, 2, 3, 4],
default_level=0,
description="Range of difficulty levels",
attr_type=AttributeType.STATIC,
min_value=1,
field_name="difficulty",
),
)
# Register the dataset with the factory. # Register the dataset with the factory.
register_dataset("bitwise_arithmetic", BitwiseArithmeticDataset, BitwiseArithmeticConfig) register_dataset("bitwise_arithmetic", BitwiseArithmeticDataset, BitwiseArithmeticConfig, BitwiseArithmeticCurriculum)

View file

@ -1,6 +1,10 @@
import pytest import pytest
from reasoning_gym.arithmetic.bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticDataset from reasoning_gym.arithmetic.bitwise_arithmetic import (
BitwiseArithmeticConfig,
BitwiseArithmeticCurriculum,
BitwiseArithmeticDataset,
)
def test_bitwise_arithmetic_config_validation(): def test_bitwise_arithmetic_config_validation():
@ -116,3 +120,28 @@ def test_bitwise_arithmetic_answer_formats():
elif not correct.startswith("0x"): elif not correct.startswith("0x"):
# For positive numbers without prefix # For positive numbers without prefix
assert dataset.score_answer(answer="0x" + correct, entry=item) == 1.0 assert dataset.score_answer(answer="0x" + correct, entry=item) == 1.0
def test_bitwise_arithmetic_curriculum():
"""Test that curriculum generates appropriate configurations"""
curriculum = BitwiseArithmeticCurriculum()
base_value = {"size": 500, "seed": 42}
base_cfg: BitwiseArithmeticConfig = curriculum.generate_configuration(base_value)
assert base_cfg.difficulty == 1
assert base_cfg.size == 500
assert base_cfg.seed == 42
curriculum.set_attr_level("difficulty", 1) # 0-indexed
cfg: BitwiseArithmeticConfig = curriculum.generate_configuration()
assert cfg.difficulty == 2
curriculum.increment_attr_level("difficulty")
cfg: BitwiseArithmeticConfig = curriculum.generate_configuration()
assert cfg.difficulty == 3
curriculum.decrement_attr_level("difficulty")
cfg: BitwiseArithmeticConfig = curriculum.generate_configuration()
assert cfg.difficulty == 2