diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index 3aaedfe9..b2b7f95a 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -14,7 +14,7 @@ from .fraction_simplification import FractionSimplificationConfig, FractionSimpl from .gcd import GCDConfig, GCDDataset from .gsm_symbolic.gsm_symbolic import GSMSymbolicDataset, GSMSymbolicDatasetConfig from .lcm import LCMConfig, LCMDataset -from .leg_counting import LegCountingConfig, LegCountingDataset +from .leg_counting import LegCountingConfig, LegCountingCurriculum, LegCountingDataset from .number_format import NumberFormatConfig, NumberFormatDataset from .power_function import PowerFunctionConfig, PowerFunctionDataset from .prime_factorization import PrimeFactorizationConfig, PrimeFactorizationDataset @@ -36,6 +36,7 @@ __all__ = [ "LCMDataset", "LegCountingConfig", "LegCountingDataset", + "LegCountingCurriculum", "PowerFunctionConfig", "PowerFunctionDataset", "PrimeFactorizationConfig", diff --git a/reasoning_gym/arithmetic/leg_counting.py b/reasoning_gym/arithmetic/leg_counting.py index b68e133d..3733a80f 100644 --- a/reasoning_gym/arithmetic/leg_counting.py +++ b/reasoning_gym/arithmetic/leg_counting.py @@ -4,9 +4,7 @@ from dataclasses import dataclass from random import Random from typing import Optional -from reasoning_gym.coaching.attributes import AttributeType, RangeAttributeDefinition -from reasoning_gym.coaching.base_curriculum import BaseCurriculum - +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset ANIMALS = { @@ -69,6 +67,7 @@ class LegCountingConfig: min_animals: int = 3 # Minimum number of animals in problem max_animals: int = 10 # Maximum number of animals + min_instances: int = 1 # Minimum instances of each animal max_instances: int = 15 # Maximum instances of each animal seed: Optional[int] = None size: int = 500 # Virtual dataset size @@ -77,7 +76,8 @@ class LegCountingConfig: """Validate configuration parameters""" assert self.min_animals > 0, "min_animals must be positive" assert self.max_animals >= self.min_animals, "max_animals must be >= min_animals" - assert self.max_instances > 0, "max_instances must be positive" + assert self.min_instances > 0, "min_instances must be positive" + assert self.max_instances >= self.min_instances, "max_instances must be >= min_instances" class LegCountingDataset(ProceduralDataset): @@ -94,7 +94,7 @@ class LegCountingDataset(ProceduralDataset): # Select random animals selected_animals = rng.sample(list(ANIMALS.keys()), num_types) for animal in selected_animals: - count = rng.randint(1, self.config.max_instances) + count = rng.randint(self.config.min_instances, self.config.max_instances) animals[animal] = count return animals @@ -136,13 +136,23 @@ class LegCountingCurriculum(BaseCurriculum): RangeAttributeDefinition( name="num_animals", levels=list(range(1, 20)), - default_level=0, # Start with 2 terms + default_level=0, description="Number of animals in question", attr_type=AttributeType.APPEND, min_value=1, # Ensure at least 1 animal lower_field_name="min_animals", upper_field_name="max_animals", ), + RangeAttributeDefinition( + name="num_instances", + levels=[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024], + default_level=0, + description="Number of instances of each animal", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_instances", + upper_field_name="max_instances", + ), ) diff --git a/tests/test_leg_counting.py b/tests/test_leg_counting.py index 31191bda..266cb8f5 100644 --- a/tests/test_leg_counting.py +++ b/tests/test_leg_counting.py @@ -2,7 +2,7 @@ import pytest -from reasoning_gym.arithmetic.leg_counting import ANIMALS, LegCountingConfig, LegCountingDataset +from reasoning_gym.arithmetic.leg_counting import ANIMALS, LegCountingConfig, LegCountingCurriculum, LegCountingDataset def test_leg_counting_config_validation(): @@ -84,3 +84,28 @@ def test_leg_counting_animal_validation(): assert ANIMALS["dog"] == 4 assert ANIMALS["chicken"] == 2 assert ANIMALS["snake"] == 0 + + +def test_leg_counting_curriculum(): + curriculum = LegCountingCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: LegCountingConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_animals == 1 and base_cfg.max_animals == 1 + assert base_cfg.min_instances == 2 and base_cfg.max_instances == 2 + + # test incrementing attribute levels + curriculum.increment_attr_level("num_animals") + curriculum.increment_attr_level("num_instances") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_animals == 1 and increased_cfg.max_animals == 2 + assert increased_cfg.min_instances == 2 and increased_cfg.max_instances == 4 + + # test decrementing attribute level for num_animals again + curriculum.decrement_attr_level("num_animals") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_animals == 1 and partially_decreased_cfg.max_animals == 1 + assert partially_decreased_cfg.min_instances == 2 and partially_decreased_cfg.max_instances == 4