diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index cd168229..72ae5e27 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -17,7 +17,7 @@ from .lcm import LCMConfig, LCMDataset from .leg_counting import LegCountingConfig, LegCountingCurriculum, LegCountingDataset from .number_format import NumberFormatConfig, NumberFormatCurriculum, NumberFormatDataset from .power_function import PowerFunctionConfig, PowerFunctionCurriculum, PowerFunctionDataset -from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset +from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationCurriculum, PrimeFactorizationDataset from .products import ProductsConfig, ProductsDataset from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset @@ -45,6 +45,7 @@ __all__ = [ "PowerFunctionCurriculum", "PrimeFactorizationConfig", "PrimeFactorizationDataset", + "PrimeFactorizationCurriculum", "ProductsDataset", "ProductsConfig", "GSMSymbolicDatasetConfig", diff --git a/reasoning_gym/arithmetic/prime_factorization.py b/reasoning_gym/arithmetic/prime_factorization.py index f6b03845..5795fbef 100644 --- a/reasoning_gym/arithmetic/prime_factorization.py +++ b/reasoning_gym/arithmetic/prime_factorization.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -82,8 +83,29 @@ class PrimeFactorizationDataset(ProceduralDataset): f"(Example: for 12 the answer would be: 2 × 2 × 3)" ), "answer": answer, - "metadata": {"number": number, "factors": factors}, + "metadata": {"number": number, "factors": factors, "difficulty": {"value": number}}, } -register_dataset("prime_factorization", PrimeFactorizationDataset, PrimeFactorizationConfig) +class PrimeFactorizationCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(PrimeFactorizationCurriculum.__name__, PrimeFactorizationConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="value", + levels=[10, 1_000, 10_000, 50_000], + default_level=1, + description="Number to factorize", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_value", + upper_field_name="max_value", + ) + ) + + +register_dataset( + "prime_factorization", PrimeFactorizationDataset, PrimeFactorizationConfig, PrimeFactorizationCurriculum +) diff --git a/tests/test_prime_factorization.py b/tests/test_prime_factorization.py index 70c463b7..9ca0a0b1 100644 --- a/tests/test_prime_factorization.py +++ b/tests/test_prime_factorization.py @@ -2,7 +2,11 @@ import pytest -from reasoning_gym.arithmetic.prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset +from reasoning_gym.arithmetic.prime_factorization import ( + PrimeFactorizationConfig, + PrimeFactorizationCurriculum, + PrimeFactorizationDataset, +) def test_prime_factorization_config_validation(): @@ -124,3 +128,24 @@ def is_prime(n: int) -> bool: if n % i == 0: return False return True + + +def test_prime_factorization_curriculum(): + curriculum = PrimeFactorizationCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: PrimeFactorizationConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_value == 10 and base_cfg.max_value == 1_000 + + # test incrementing attribute levels + curriculum.increment_attr_level("value") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_value == 10 and increased_cfg.max_value == 10_000 + + # test decrementing attribute level for value again + curriculum.decrement_attr_level("value") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_value == 10 and partially_decreased_cfg.max_value == 1_000