From 1888fe2bb467374f6703f0537f2cb4c7b03cd8ae Mon Sep 17 00:00:00 2001
From: joesharratt1229 <118444587+joesharratt1229@users.noreply.github.com>
Date: Fri, 7 Mar 2025 22:54:49 +0100
Subject: [PATCH] added basic arith curricula (#276)

* added basic arith curricula

* register BasicArithmeticCurriculum

---------

Co-authored-by: Andreas Koepf
---
Reviewer notes (text between the three-dash separator and the first
"diff --git" line is commentary and is not applied by `git am`):

* NOTE(review): line breaks, blank lines and indentation in this patch were
  restored from the hunk line counts and from Python syntax; compare against
  commit 1888fe2b before applying.
* Adds BasicArithmeticCurriculum with two range attributes, both starting at
  level 0: num_terms (levels [2, 5, 10, 20], fields min_terms/max_terms) and
  num_digits (levels [1, 2, 5, 10], fields min_digits/max_digits). The class
  is exported from reasoning_gym.arithmetic and passed as a fourth argument
  to register_dataset.
* Item metadata changes shape: "num_terms" and "num_digits" move from the
  top level of "metadata" into a nested "difficulty" dict. Code that reads
  metadata["num_terms"] must switch to metadata["difficulty"]["num_terms"].
* In test_basic_arithmetic_curriculum, the step commented "trying to
  decrement below level 0" takes num_terms from level 2 to 0 and num_digits
  from level 1 to 0, so no decrement is ever issued while already at level 0.
  The lower bound is therefore not exercised; the upper bound is (five
  increments against four levels).

 reasoning_gym/arithmetic/__init__.py         |  3 +-
 reasoning_gym/arithmetic/basic_arithmetic.py | 33 +++++++++++--
 tests/test_basic_arithmetic.py               | 50 ++++++++++++++++++++
 3 files changed, 82 insertions(+), 4 deletions(-)

diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py
index b2b7f95a..e895caeb 100644
--- a/reasoning_gym/arithmetic/__init__.py
+++ b/reasoning_gym/arithmetic/__init__.py
@@ -2,7 +2,7 @@
 Arithmetic tasks for training reasoning capabilities:
 """
 
-from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
+from .basic_arithmetic import BasicArithmeticCurriculum, BasicArithmeticDataset, BasicArithmeticDatasetConfig
 from .bitwise_arithmetic import BitwiseArithmeticConfig, BitwiseArithmeticDataset
 from .calendar_arithmetic import CalendarArithmeticConfig, CalendarArithmeticDataset
 from .chain_sum import ChainSumConfig, ChainSumDataset
@@ -24,6 +24,7 @@ from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset
 __all__ = [
     "BasicArithmeticDataset",
     "BasicArithmeticDatasetConfig",
+    "BasicArithmeticCurriculum",
     "ChainSumDataset",
     "ChainSumConfig",
     "CalendarArithmeticConfig",
diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py
index 0c3ee345..b966e00f 100644
--- a/reasoning_gym/arithmetic/basic_arithmetic.py
+++ b/reasoning_gym/arithmetic/basic_arithmetic.py
@@ -2,6 +2,7 @@ from dataclasses import dataclass
 from random import Random
 from typing import Any, Literal, Optional
 
+from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
 from ..factory import ProceduralDataset, register_dataset
 
 
@@ -94,9 +95,8 @@ class BasicArithmeticDataset(ProceduralDataset):
             "question": question,
             "answer": str(result),
             "metadata": {
-                "num_terms": num_terms,
-                "num_digits": num_digits,
                 "expression": expression,
+                "difficulty": {"num_terms": num_terms, "num_digits": num_digits},
             },
         }
 
@@ -233,5 +233,32 @@ class BasicArithmeticDataset(ProceduralDataset):
         return template.format(expression)
 
 
+class BasicArithmeticCurriculum(BaseCurriculum):
+    def __init__(self):
+        super().__init__(name=BasicArithmeticCurriculum.__name__, config_cls=BasicArithmeticDatasetConfig)
+        self._define_attributes(
+            RangeAttributeDefinition(
+                name="num_terms",
+                levels=[2, 5, 10, 20],
+                default_level=0,
+                description="Number of terms in the expression",
+                attr_type=AttributeType.APPEND,
+                min_value=2,
+                lower_field_name="min_terms",
+                upper_field_name="max_terms",
+            ),
+            RangeAttributeDefinition(
+                name="num_digits",
+                levels=[1, 2, 5, 10],
+                default_level=0,
+                description="Number of digits in the numbers",
+                attr_type=AttributeType.APPEND,
+                min_value=1,
+                lower_field_name="min_digits",
+                upper_field_name="max_digits",
+            ),
+        )
+
+
 # Register the dataset
-register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig)
+register_dataset("basic_arithmetic", BasicArithmeticDataset, BasicArithmeticDatasetConfig, BasicArithmeticCurriculum)
diff --git a/tests/test_basic_arithmetic.py b/tests/test_basic_arithmetic.py
index 6eda876a..b60f28bf 100644
--- a/tests/test_basic_arithmetic.py
+++ b/tests/test_basic_arithmetic.py
@@ -1,6 +1,7 @@
 import pytest
 
 from reasoning_gym.arithmetic.basic_arithmetic import (
+    BasicArithmeticCurriculum,
     BasicArithmeticDataset,
     BasicArithmeticDatasetConfig,
     eval_floordiv,
@@ -96,3 +97,52 @@ def test_arithmetic_dataset_iteration():
     first_items = list(dataset)
     second_items = list(dataset)
     assert first_items == second_items, "Multiple iterations should yield same items"
+
+
+def test_basic_arithmetic_curriculum():
+    """Test the BasicArithmeticCurriculum functionality"""
+    curriculum = BasicArithmeticCurriculum()
+
+    base_value = {"size": 150, "seed": 1}
+
+    base_cfg: BasicArithmeticDatasetConfig = curriculum.generate_configuration(base_value)
+    assert base_cfg.seed == 1
+    assert base_cfg.size == 150
+    assert base_cfg.min_terms == 2 and base_cfg.max_terms == 2
+    assert base_cfg.min_digits == 1 and base_cfg.max_digits == 1
+
+    # Test incrementing attribute levels
+    curriculum.increment_attr_level("num_terms")
+    curriculum.increment_attr_level("num_digits")
+    increased_cfg = curriculum.generate_configuration(base_value)
+    assert increased_cfg.min_terms == 2 and increased_cfg.max_terms == 5
+    assert increased_cfg.min_digits == 1 and increased_cfg.max_digits == 2
+
+    # Test decrementing attribute level for num_terms
+    curriculum.decrement_attr_level("num_terms")
+    partially_decreased_cfg = curriculum.generate_configuration(base_value)
+    assert partially_decreased_cfg.min_terms == 2 and partially_decreased_cfg.max_terms == 2
+    assert partially_decreased_cfg.min_digits == 1 and partially_decreased_cfg.max_digits == 2
+
+    # Test additional increments to ensure levels work as expected
+    curriculum.increment_attr_level("num_terms")
+    curriculum.increment_attr_level("num_terms")
+    higher_level_cfg = curriculum.generate_configuration(base_value)
+    assert higher_level_cfg.min_terms == 2 and higher_level_cfg.max_terms == 10
+    assert higher_level_cfg.min_digits == 1 and higher_level_cfg.max_digits == 2
+
+    # Test boundary conditions - trying to decrement below level 0
+    curriculum.decrement_attr_level("num_terms")
+    curriculum.decrement_attr_level("num_terms")
+    curriculum.decrement_attr_level("num_digits")
+    lower_bound_cfg = curriculum.generate_configuration(base_value)
+    assert lower_bound_cfg.min_terms == 2 and lower_bound_cfg.max_terms == 2
+    assert lower_bound_cfg.min_digits == 1 and lower_bound_cfg.max_digits == 1
+
+    # Test boundary conditions - trying to increment above max level
+    for _ in range(5):
+        curriculum.increment_attr_level("num_terms")
+        curriculum.increment_attr_level("num_digits")
+    upper_bound_cfg = curriculum.generate_configuration(base_value)
+    assert upper_bound_cfg.min_terms == 2 and upper_bound_cfg.max_terms == 20
+    assert upper_bound_cfg.min_digits == 1 and upper_bound_cfg.max_digits == 10