diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index de78b69d..cd168229 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -15,7 +15,7 @@ from .gcd import GCDConfig, GCDCurriculum, GCDDataset from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig from .lcm import LCMConfig, LCMDataset from .leg_counting import LegCountingConfig, LegCountingCurriculum, LegCountingDataset -from .number_format import NumberFormatConfig, NumberFormatDataset +from .number_format import NumberFormatConfig, NumberFormatCurriculum, NumberFormatDataset from .power_function import PowerFunctionConfig, PowerFunctionCurriculum, PowerFunctionDataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset from .products import ProductsConfig, ProductsDataset @@ -59,6 +59,7 @@ __all__ = [ "DiceCurriculum", "NumberFormatConfig", "NumberFormatDataset", + "NumberFormatCurriculum", "DecimalArithmeticConfig", "DecimalArithmeticDataset", "DecimalArithmeticCurriculum", diff --git a/reasoning_gym/arithmetic/number_format.py b/reasoning_gym/arithmetic/number_format.py index a85bf7cb..00a9d732 100644 --- a/reasoning_gym/arithmetic/number_format.py +++ b/reasoning_gym/arithmetic/number_format.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """Your task is to pick the largest/smallest number out of several options. @@ -18,20 +19,24 @@ Now, pick the {size} number of the following candidates: {numbers} class NumberFormatConfig: """Configuration for Count Bits dataset generation""" + min_num_candidates: int = 2 # Minimum number of candidates max_num_candidates: int = 5 # Maximum number of candidates min_n: float = 1_000 # Lower bound for the numbers max_n: float = 1_000_000_000 # Upper bound for the numbers - max_delta: int = 1_000 + max_delta: float = 10.0 size: int = 500 # Virtual dataset size seed: Optional[int] = None def validate(self): """Validate configuration parameters""" - assert 2 <= self.max_num_candidates, "max_num_candidates must be at least 2" + assert 2 <= self.min_num_candidates, "min_num_candidates must be at least 2" + assert ( + self.min_num_candidates <= self.max_num_candidates + ), "min_num_candidates must be less than max_num_candidates" assert 1 <= self.min_n, "min_n must be at least 1" assert self.min_n < self.max_n, "min_n must be less than max_n" - assert 1 <= self.max_delta, "max_delta must be at least 1" + assert 0 < self.max_delta, "max_delta must be greater than 0" class NumberFormatDataset(ProceduralDataset): @@ -78,7 +83,7 @@ class NumberFormatDataset(ProceduralDataset): """Generate a single Count Bits question""" rng = Random(self.seed + idx) - num_candidates = rng.randint(2, self.config.max_num_candidates) + num_candidates = rng.randint(self.config.min_num_candidates, self.config.max_num_candidates) candidates = self._get_candidates(rng, num_candidates) formatted_candidates = self._transform_candidates(rng, candidates) @@ -93,8 +98,50 @@ class NumberFormatDataset(ProceduralDataset): "solution": answer, "formatted_candidates": formatted_candidates, "size": size, + "difficulty": { + "num_candidates": num_candidates, + "n": (self.config.min_n, self.config.max_n), + "min_delta": self.config.max_delta, + }, }, } -register_dataset("number_format", NumberFormatDataset, NumberFormatConfig) +class NumberFormatCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(NumberFormatCurriculum.__name__, NumberFormatConfig) + + self._define_attributes( + RangeAttributeDefinition( + name="num_candidates", + levels=[5, 25, 100, 500], + default_level=1, + description="Number of candidates", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_num_candidates", + upper_field_name="max_num_candidates", + ), + RangeAttributeDefinition( + name="n", + levels=[10, 1_000, 1_000_000, 1_000_000_000], + default_level=1, + description="Magnitude of the values", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_n", + upper_field_name="max_n", + ), + ScalarAttributeDefinition( + name="max_delta", + field_name="max_delta", + levels=[1e1, 1e0, 1e-3, 1e-6], + default_level=0, + description="Max delta", + attr_type=AttributeType.STATIC, + min_value=1e-6, + ), + ) + + +register_dataset("number_format", NumberFormatDataset, NumberFormatConfig, NumberFormatCurriculum) diff --git a/tests/test_number_format.py b/tests/test_number_format.py index 76ae834b..f0819bd8 100644 --- a/tests/test_number_format.py +++ b/tests/test_number_format.py @@ -2,17 +2,21 @@ import pytest -from reasoning_gym.arithmetic.number_format import NumberFormatConfig, NumberFormatDataset +from reasoning_gym.arithmetic.number_format import NumberFormatConfig, NumberFormatCurriculum, NumberFormatDataset def test_number_format_config_validation(): """Test that invalid configs raise appropriate errors""" with pytest.raises(AssertionError): - config = NumberFormatConfig(max_num_candidates=0) # Zero not allowed + config = NumberFormatConfig(min_num_candidates=0) # Zero not allowed config.validate() with pytest.raises(AssertionError): - config = NumberFormatConfig(max_num_candidates=1) # One not allowed + config = NumberFormatConfig(min_num_candidates=1) # One not allowed + config.validate() + + with pytest.raises(AssertionError): + config = NumberFormatConfig(max_num_candidates=5, min_num_candidates=6) # min > max config.validate() with pytest.raises(AssertionError): @@ -119,3 +123,32 @@ def test_number_format_answer(): # Answer is unparsable model_answer = "test" assert dataset.score_answer(model_answer, entry) == 0.0 + + +def test_number_format_curriculum(): + curriculum = NumberFormatCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: NumberFormatConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_num_candidates == 5 and base_cfg.max_num_candidates == 25 + assert base_cfg.min_n == 10 and base_cfg.max_n == 1_000 + assert base_cfg.max_delta == 1e1 + + # test incrementing attribute levels + curriculum.increment_attr_level("num_candidates") + curriculum.increment_attr_level("n") + curriculum.increment_attr_level("max_delta") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_num_candidates == 5 and increased_cfg.max_num_candidates == 100 + assert increased_cfg.min_n == 10 and increased_cfg.max_n == 1_000_000 + assert increased_cfg.max_delta == 1e0 + + # test decrementing attribute level + curriculum.decrement_attr_level("num_candidates") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_num_candidates == 5 and partially_decreased_cfg.max_num_candidates == 25 + assert partially_decreased_cfg.min_n == 10 and partially_decreased_cfg.max_n == 1_000_000 + assert partially_decreased_cfg.max_delta == 1e0