diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 355bb57b..0ddfc49f 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -72,6 +72,19 @@ class ChainSum: } } + def __iter__(self): + """Make the dataset iterable""" + self._current_idx = 0 + return self + + def __next__(self): + """Get next item in iteration""" + if self._current_idx >= self.config.size: + raise StopIteration + item = self[self._current_idx] + self._current_idx += 1 + return item + def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max_value: int) -> tuple[str, int]: """Generate a chain sum task diff --git a/tests/test_chain_sum.py b/tests/test_chain_sum.py index ac01825f..aff20a6c 100644 --- a/tests/test_chain_sum.py +++ b/tests/test_chain_sum.py @@ -125,3 +125,29 @@ def test_chain_sum_negation(): # With enough samples and allow_negation=True, we should see both positive and negative numbers assert has_positive and has_negative, "Expected both positive and negative numbers with allow_negation=True" + + +def test_chain_sum_iteration(): + """Test that iteration respects dataset size""" + config = ChainSumConfig( + min_terms=2, + max_terms=2, + size=5, # Small size for testing + seed=42 + ) + dataset = ChainSum(config) + + # Test manual iteration + items = [] + for item in dataset: + items.append(item) + assert len(items) == config.size, "Iterator should yield exactly size items" + + # Test list conversion + items = list(dataset) + assert len(items) == config.size, "Iterator should yield exactly size items" + + # Test multiple iterations + first_items = list(dataset) + second_items = list(dataset) + assert first_items == second_items, "Multiple iterations should yield same items"