diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py
index 2cea91fd..9a97f35b 100644
--- a/reasoning_gym/arithmetic/basic_arithmetic.py
+++ b/reasoning_gym/arithmetic/basic_arithmetic.py
@@ -141,6 +141,33 @@ class ArithmeticDataset:
         expression = " ".join(expression_parts)
         return expression, result
 
+    def __iter__(self):
+        """Return a fresh iterator over the first ``config.size`` items.
+
+        Each call returns an independent generator, so nested or repeated
+        loops over the same dataset do not disturb one another. The cursor
+        used by ``__next__`` is also rewound so that the
+        ``iter(ds); next(ds)`` calling pattern keeps starting at index 0.
+        """
+        self._current_idx = 0
+        return (self[idx] for idx in range(self.config.size))
+
+    def __next__(self):
+        """Return the item at the instance cursor and advance it.
+
+        Raises:
+            StopIteration: once ``config.size`` items have been returned.
+        """
+        # Fall back to 0 so next(ds) works before iter(ds) was ever called.
+        idx = getattr(self, "_current_idx", 0)
+        if idx >= self.config.size:
+            raise StopIteration
+        item = self[idx]
+        # Advance only after a successful lookup, as the cursor must not
+        # skip an index whose lookup raised.
+        self._current_idx = idx + 1
+        return item
+
     def _format_question(self, expression: str) -> str:
         """Format the expression according to config style"""
         if self.config.format_style == "simple":
diff --git a/tests/test_arithmetic.py b/tests/test_arithmetic.py
index 2ca1d6f1..d95e88db 100644
--- a/tests/test_arithmetic.py
+++ b/tests/test_arithmetic.py
@@ -70,3 +70,33 @@ def test_arithmetic_dataset_format_styles():
     config.format_style = "natural"
     dataset = ArithmeticDataset(config)
     assert all("=" not in item["question"] for item in dataset)
+
+
+def test_arithmetic_dataset_iteration():
+    """Test that iteration respects dataset size"""
+    config = ArithmeticDatasetConfig(
+        min_terms=2,
+        max_terms=2,
+        size=5,  # Small size for testing
+        seed=42,
+    )
+    dataset = ArithmeticDataset(config)
+
+    # Test manual iteration
+    items = []
+    for item in dataset:
+        items.append(item)
+    assert len(items) == config.size, "Iterator should yield exactly size items"
+
+    # Test list conversion
+    items = list(dataset)
+    assert len(items) == config.size, "Iterator should yield exactly size items"
+
+    # Test multiple iterations
+    first_items = list(dataset)
+    second_items = list(dataset)
+    assert first_items == second_items, "Multiple iterations should yield same items"
+
+    # Test nested iteration: the inner loop must not reset the outer one
+    pairs = [(a, b) for a in dataset for b in dataset]
+    assert len(pairs) == config.size**2, "Nested iteration should be independent"