diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 41e12f67..4be180d0 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -7,7 +7,7 @@ Algorithmic tasks for training reasoning capabilities: """ from reasoning_gym.arithmetic.basic_arithmetic import arithmetic_dataset -from reasoning_gym.arithmetic.chain_sum import chain_sum +from reasoning_gym.arithmetic.chain_sum import chain_sum_dataset from .base_conversion import BaseConversionConfig, BaseConversionDataset, base_conversion_dataset from .letter_counting import LetterCountingConfig, LetterCountingDataset, letter_counting_dataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset, number_filtering_dataset @@ -18,7 +18,7 @@ __all__ = [ "BaseConversionConfig", "BaseConversionDataset", "base_conversion_dataset", - "chain_sum", + "chain_sum_dataset", "LetterCountingConfig", "LetterCountingDataset", "letter_counting_dataset", diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index a9cbe9bf..46e5cfdf 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -7,7 +7,7 @@ Arithmetic tasks for training reasoning capabilities: """ from .basic_arithmetic import ArithmeticDataset, ArithmeticDatasetConfig, arithmetic_dataset -from .chain_sum import ChainSum, ChainSumConfig, chain_sum +from .chain_sum import ChainSum, ChainSumConfig, chain_sum_dataset from .leg_counting import LegCountingConfig, LegCountingDataset, leg_counting_dataset __all__ = [ @@ -16,7 +16,7 @@ __all__ = [ "arithmetic_dataset", "ChainSum", "ChainSumConfig", - "chain_sum", + "chain_sum_dataset", "LegCountingConfig", "LegCountingDataset", "leg_counting_dataset" diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 5fc22387..0b8fdf7d 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -125,7 +125,7 @@ class ChainSum: return expression, result -def chain_sum( +def chain_sum_dataset( min_terms: int = 2, max_terms: int = 6, min_digits: int = 1,