diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 6bfad5f9..9ea7ffc2 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from typing import Optional from ..dataset import ProceduralDataset +from ..factory import register_dataset @dataclass @@ -143,3 +144,6 @@ def chain_sum_dataset( size=size, ) return ChainSum(config) + +# Register the dataset +register_dataset("chain_sum", ChainSum, ChainSumConfig)