diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 016d252f..f8da34ce 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -16,7 +16,7 @@ from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetCon from .lcm import LCMConfig, LCMDataset from .leg_counting import LegCountingConfig, LegCountingCurriculum, LegCountingDataset from .number_format import NumberFormatConfig, NumberFormatDataset -from .power_function import PowerFunctionConfig, PowerFunctionDataset +from .power_function import PowerFunctionConfig, PowerFunctionCurriculum, PowerFunctionDataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset from .products import ProductsConfig, ProductsDataset from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset @@ -41,6 +41,7 @@ __all__ = [ "LegCountingCurriculum", "PowerFunctionConfig", "PowerFunctionDataset", + "PowerFunctionCurriculum", "PrimeFactorizationConfig", "PrimeFactorizationDataset", "ProductsDataset", diff --git a/reasoning_gym/arithmetic/power_function.py b/reasoning_gym/arithmetic/power_function.py index 321da833..aae9fb52 100644 --- a/reasoning_gym/arithmetic/power_function.py +++ b/reasoning_gym/arithmetic/power_function.py @@ -6,6 +6,7 @@ from math import pow from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """Your task is to compute an exponentiation of a number. 
@@ -21,7 +22,7 @@ class PowerFunctionConfig: min_base: float = -1e3 # Minimum base value max_base: float = 1e3 # Maximum base value - min_exponent: int = -8 # Minimum exponent value + min_exponent: int = 0 # Minimum exponent magnitude (the sign is randomized when an item is generated) max_exponent: int = 8 # Maximum exponent value size: int = 500 # Virtual dataset size @@ -63,13 +64,33 @@ class PowerFunctionDataset(ProceduralDataset): base = round(rng.uniform(self.config.min_base, self.config.max_base), 4) exponent = rng.randint(self.config.min_exponent, self.config.max_exponent) + + if rng.random() < 0.5: + exponent = -exponent + answer = pow(base, exponent) return { "question": QUESTION_TEMPLATE.format(base=base, exponent=exponent), "answer": str(answer), - "metadata": {"base": base, "exponent": exponent, "solution": answer}, + "metadata": {"base": base, "exponent": exponent, "solution": answer, "difficulty": {"exponent": exponent}}, } -register_dataset("power_function", PowerFunctionDataset, PowerFunctionConfig) +class PowerFunctionCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(PowerFunctionCurriculum.__name__, PowerFunctionConfig) + self._define_attributes( + RangeAttributeDefinition( + name="exponent", + levels=[2, 4, 6, 10], + default_level=0, + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_exponent", + upper_field_name="max_exponent", + ), + ) + + +register_dataset("power_function", PowerFunctionDataset, PowerFunctionConfig, PowerFunctionCurriculum) diff --git a/tests/test_power_function.py b/tests/test_power_function.py index 4a7b8382..22c5327e 100644 --- a/tests/test_power_function.py +++ b/tests/test_power_function.py @@ -62,3 +62,23 @@ def test_power_function_score_function(): for item in dataset: answer = item["answer"] assert dataset.score_answer(answer, item) == 1.0 + + +def test_power_function_curriculum(): + """Test PowerFunctionCurriculum configuration generation and attribute manipulation""" + from reasoning_gym.arithmetic import PowerFunctionCurriculum 
+ curriculum = PowerFunctionCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_exponent == 2 and base_cfg.max_exponent == 2 + + # Test incrementing the level of the exponent attribute (the only attribute this curriculum defines) + curriculum.increment_attr_level("exponent") + + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_exponent == 2 and increased_cfg.max_exponent == 4