added basic arith curricula (#276)

* added basic arith curricula
* register BasicArithmeticCurriculum

---------

Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
joesharratt1229 2025-03-07 22:54:49 +01:00 committed by GitHub
parent f490b9f760
commit 1888fe2bb4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 82 additions and 4 deletions

View file

@ -2,7 +2,7 @@
Arithmetic tasks for training reasoning capabilities: Arithmetic tasks for training reasoning capabilities:
""" """
from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig from .basic_arithmetic import BasicArithmeticCurriculum, BasicArithmeticDataset, BasicArithmeticDatasetConfig
from .bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticDataset from .bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticDataset
from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
from .chain_sum import ChainSumConfig, ChainSumDataset from .chain_sum import ChainSumConfig, ChainSumDataset
@ -24,6 +24,7 @@ from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset
__all__ = [ __all__ = [
"BasicArithmeticDataset", "BasicArithmeticDataset",
"BasicArithmeticDatasetConfig", "BasicArithmeticDatasetConfig",
"BasicArithmeticCurriculum",
"ChainSumDataset", "ChainSumDataset",
"ChainSumConfig", "ChainSumConfig",
"CalendarArithmeticConfig", "CalendarArithmeticConfig",

View file

@ -2,6 +2,7 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -94,9 +95,8 @@ class BasicArithmeticDataset(ProceduralDataset):
"question": question, "question": question,
"answer": str(result), "answer": str(result),
"metadata": { "metadata": {
"num_terms": num_terms,
"num_digits": num_digits,
"expression": expression, "expression": expression,
"difficulty": {"num_terms": num_terms, "num_digits": num_digits},
}, },
} }
@ -233,5 +233,32 @@ class BasicArithmeticDataset(ProceduralDataset):
return template.format(expression) return template.format(expression)
class BasicArithmeticCurriculum(BaseCurriculum):
    """Curriculum definition for the basic arithmetic dataset.

    Declares two range attributes, each with four levels and starting at
    level 0:

    * ``num_terms``  -- levels 2, 5, 10, 20; bound to the ``min_terms`` /
      ``max_terms`` config fields.
    * ``num_digits`` -- levels 1, 2, 5, 10; bound to the ``min_digits`` /
      ``max_digits`` config fields.

    The generated configuration class is ``BasicArithmeticDatasetConfig``.
    """

    def __init__(self):
        super().__init__(name=BasicArithmeticCurriculum.__name__, config_cls=BasicArithmeticDatasetConfig)

        # How many operands appear in a generated expression.
        terms_attribute = RangeAttributeDefinition(
            name="num_terms",
            description="Number of terms in the expression",
            levels=[2, 5, 10, 20],
            default_level=0,
            attr_type=AttributeType.APPEND,
            min_value=2,
            lower_field_name="min_terms",
            upper_field_name="max_terms",
        )

        # How many digits each operand may have.
        digits_attribute = RangeAttributeDefinition(
            name="num_digits",
            description="Number of digits in the numbers",
            levels=[1, 2, 5, 10],
            default_level=0,
            attr_type=AttributeType.APPEND,
            min_value=1,
            lower_field_name="min_digits",
            upper_field_name="max_digits",
        )

        # Registration order matches the original: terms first, then digits.
        self._define_attributes(terms_attribute, digits_attribute)
# Register the dataset # Register the dataset
register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig) register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig, BasicArithmeticCurriculum)

View file

@ -1,6 +1,7 @@
import pytest import pytest
from reasoning_gym.arithmetic.basic_arithmetic import ( from reasoning_gym.arithmetic.basic_arithmetic import (
BasicArithmeticCurriculum,
BasicArithmeticDataset, BasicArithmeticDataset,
BasicArithmeticDatasetConfig, BasicArithmeticDatasetConfig,
eval_floordiv, eval_floordiv,
@ -96,3 +97,52 @@ def test_arithmetic_dataset_iteration():
first_items = list(dataset) first_items = list(dataset)
second_items = list(dataset) second_items = list(dataset)
assert first_items == second_items, "Multiple iterations should yield same items" assert first_items == second_items, "Multiple iterations should yield same items"
def test_basic_arithmetic_curriculum():
    """Step BasicArithmeticCurriculum through its levels and check each generated config."""
    curriculum = BasicArithmeticCurriculum()
    base_value = {"size": 150, "seed": 1}

    def expect_ranges(terms, digits):
        # Build a fresh config from the current levels and compare its
        # (min, max) pairs against the expected ones.
        cfg = curriculum.generate_configuration(base_value)
        assert (cfg.min_terms, cfg.max_terms) == terms
        assert (cfg.min_digits, cfg.max_digits) == digits

    # Default levels: size/seed are carried over and both ranges sit at their first level.
    base_cfg: BasicArithmeticDatasetConfig = curriculum.generate_configuration(base_value)
    assert base_cfg.seed == 1
    assert base_cfg.size == 150
    assert (base_cfg.min_terms, base_cfg.max_terms) == (2, 2)
    assert (base_cfg.min_digits, base_cfg.max_digits) == (1, 1)

    # One level up on each attribute widens only the upper bound.
    curriculum.increment_attr_level("num_terms")
    curriculum.increment_attr_level("num_digits")
    expect_ranges(terms=(2, 5), digits=(1, 2))

    # Stepping num_terms back down leaves num_digits untouched.
    curriculum.decrement_attr_level("num_terms")
    expect_ranges(terms=(2, 2), digits=(1, 2))

    # Two increments from level 0 land num_terms on its third level.
    for _ in range(2):
        curriculum.increment_attr_level("num_terms")
    expect_ranges(terms=(2, 10), digits=(1, 2))

    # Walk both attributes back down to their first level.
    for _ in range(2):
        curriculum.decrement_attr_level("num_terms")
    curriculum.decrement_attr_level("num_digits")
    expect_ranges(terms=(2, 2), digits=(1, 1))

    # Increment more times than there are levels: both attributes must stop at
    # their last level instead of failing or running past it.
    for _ in range(5):
        curriculum.increment_attr_level("num_terms")
        curriculum.increment_attr_level("num_digits")
    expect_ranges(terms=(2, 20), digits=(1, 10))