Bitwise arithmetic curriculum (#282)

* bitwise_arithmetic curriculum
* register BitwiseArithmeticCurriculum

---------

Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
vncntt 2025-03-07 16:32:00 -08:00 committed by GitHub
parent 444c793d3f
commit 775a42e9e4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 58 additions and 4 deletions

View file

@ -3,7 +3,7 @@ Arithmetic tasks for training reasoning capabilities:
"""
from .basic_arithmetic import BasicArithmeticCurriculum, BasicArithmeticDataset, BasicArithmeticDatasetConfig
from .bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticDataset
from .bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticCurriculum, BitwiseArithmeticDataset
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
from .chain_sum import ChainSumConfig, ChainSumDataset
from .count_bits import CountBitsConfig, CountBitsCurriculum, CountBitsDataset
@ -63,4 +63,5 @@ __all__ = [
"DecimalChainSumDataset",
"BitwiseArithmeticConfig",
"BitwiseArithmeticDataset",
"BitwiseArithmeticCurriculum",
]

View file

@ -2,6 +2,7 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Optional
from ..coaching import AttributeType, BaseCurriculum, ScalarAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@ -151,7 +152,11 @@ class BitwiseArithmeticDataset(ProceduralDataset):
+ problem
)
return {"question": problem_str, "answer": answer, "metadata": {"problem": problem}}
return {
"question": problem_str,
"answer": answer,
"metadata": {"problem": problem, "difficulty": {"difficulty": self.config.difficulty}},
}
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""
@ -171,5 +176,24 @@ class BitwiseArithmeticDataset(ProceduralDataset):
return 0.0
class BitwiseArithmeticCurriculum(BaseCurriculum):
"""Curriculum for Bitwise Arithmetic dataset"""
def __init__(self):
super().__init__(BitwiseArithmeticCurriculum.__name__, BitwiseArithmeticConfig)
self._define_attributes(
ScalarAttributeDefinition(
name="difficulty",
levels=[1, 2, 3, 4],
default_level=0,
description="Range of difficulty levels",
attr_type=AttributeType.STATIC,
min_value=1,
field_name="difficulty",
),
)
# Register the dataset with the factory.
register_dataset("bitwise_arithmetic", BitwiseArithmeticDataset, BitwiseArithmeticConfig)
register_dataset("bitwise_arithmetic", BitwiseArithmeticDataset, BitwiseArithmeticConfig, BitwiseArithmeticCurriculum)