diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index f8da34ce..de78b69d 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -11,7 +11,7 @@ from .decimal_arithmetic import DecimalArithmeticConfig, DecimalArithmeticCurric from .decimal_chain_sum import DecimalChainSumConfig, DecimalChainSumCurriculum, DecimalChainSumDataset from .dice import DiceConfig, DiceCurriculum, DiceDataset from .fraction_simplification import FractionSimplificationConfig, FractionSimplificationDataset -from .gcd import GCDConfig, GCDDataset +from .gcd import GCDConfig, GCDCurriculum, GCDDataset from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig from .lcm import LCMConfig, LCMDataset from .leg_counting import LegCountingConfig, LegCountingCurriculum, LegCountingDataset @@ -34,6 +34,7 @@ __all__ = [ "FractionSimplificationDataset", "GCDConfig", "GCDDataset", + "GCDCurriculum", "LCMConfig", "LCMDataset", "LegCountingConfig", diff --git a/reasoning_gym/arithmetic/gcd.py b/reasoning_gym/arithmetic/gcd.py index a764b2bd..f3a0a61a 100644 --- a/reasoning_gym/arithmetic/gcd.py +++ b/reasoning_gym/arithmetic/gcd.py @@ -6,6 +6,7 @@ from math import gcd from random import Random from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -54,13 +55,50 @@ class GCDDataset(ProceduralDataset): rng = Random(self.seed + idx) numbers, result = self._generate_numbers(rng) + num_terms = len(numbers) numbers_str = ", ".join(str(n) for n in numbers) return { "question": f"Find the Greatest Common Divisor (GCD) of these numbers: {numbers_str}. Give only the GCD as your final answer.", "answer": str(result), - "metadata": {"numbers": numbers, "result": result}, + "metadata": { + "numbers": numbers, + "result": result, + "difficulty": { + "num_terms": num_terms, + "max_value": self.config.max_value, + }, + }, } +class GCDCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(GCDCurriculum.__name__, GCDConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="num_terms", + levels=[2, 3, 4, 5], + default_level=0, + description="number of terms", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_numbers", + upper_field_name="max_numbers", + ), + RangeAttributeDefinition( + name="max_value", + levels=[100, 1000, 10000, 100000], + default_level=0, + description="maximum value", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_value", + upper_field_name="max_value", + ), + ) + + register_dataset("gcd", GCDDataset, GCDConfig) diff --git a/tests/test_gcd.py b/tests/test_gcd.py index 1ed90df6..7178b46c 100644 --- a/tests/test_gcd.py +++ b/tests/test_gcd.py @@ -3,7 +3,7 @@ from math import gcd import pytest -from reasoning_gym.arithmetic import GCDConfig, GCDDataset +from reasoning_gym.arithmetic import GCDConfig, GCDCurriculum, GCDDataset def test_gcd_config_validation(): @@ -115,3 +115,39 @@ def test_gcd_special_cases(): # With enough samples, we should see both coprime and non-coprime numbers assert seen_gcd_1, "Expected to see some coprime numbers (GCD=1)" assert seen_large_gcd, "Expected to see some non-coprime numbers (GCD>1)" + + +def test_gcd_curriculum(): + """Test that curriculum generates correct items""" + curriculum = GCDCurriculum() + + base_value = {"size": 150, "seed": 1} + base_cfg: GCDConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_numbers == 2 and base_cfg.max_numbers == 2 + assert base_cfg.min_value == 100 and base_cfg.max_value == 100 + + curriculum.increment_attr_level("num_terms") + curriculum.increment_attr_level("max_value") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_numbers == 2 and increased_cfg.max_numbers == 3 + assert increased_cfg.min_value == 100 and increased_cfg.max_value == 1000 + + curriculum.increment_attr_level("num_terms") + curriculum.increment_attr_level("max_value") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_numbers == 2 and increased_cfg.max_numbers == 4 + assert increased_cfg.min_value == 100 and increased_cfg.max_value == 10000 + + curriculum.increment_attr_level("num_terms") + curriculum.increment_attr_level("max_value") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_numbers == 2 and increased_cfg.max_numbers == 5 + assert increased_cfg.min_value == 100 and increased_cfg.max_value == 100000 + + curriculum.decrement_attr_level("num_terms") + curriculum.decrement_attr_level("max_value") + decreased_cfg = curriculum.generate_configuration(base_value) + assert decreased_cfg.min_numbers == 2 and decreased_cfg.max_numbers == 4 + assert decreased_cfg.min_value == 100 and decreased_cfg.max_value == 10000