diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index b7b1f950..b3958b26 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -191,6 +191,9 @@ class BasicArithmeticDataset(ProceduralDataset): space_parts.append(" ") space_parts.append(p) expression = "".join(space_parts).strip() + # Avoid division-by-zero in final evaluation by converting '/0' patterns to '/1' + if "/ 0" in expression: + expression = expression.replace("/ 0", "/ 1") result = eval_floordiv(expression) # Note: eval is safe here as we control the input return expression, result diff --git a/tests/test_basic_arithmetic.py b/tests/test_basic_arithmetic.py index 0101109f..e6a9de35 100644 --- a/tests/test_basic_arithmetic.py +++ b/tests/test_basic_arithmetic.py @@ -168,3 +168,19 @@ def test_basic_arithmetic_curriculum_upper_bound(): increased_cfg = curriculum.generate_configuration(base_value) assert increased_cfg.min_terms == 2 and increased_cfg.max_terms == 3 assert increased_cfg.min_digits == 1 and increased_cfg.max_digits == 2 + + +def test_arithmetic_dataset_large_random_generation(): + """Stress-test generation of many arithmetic questions to catch random errors""" + config = BasicArithmeticDatasetConfig( + size=100000, + seed=123, + min_terms=2, + max_terms=6, + min_digits=1, + max_digits=3, + ) + dataset = BasicArithmeticDataset(config) + for item in dataset: + assert isinstance(item, dict) + assert "question" in item and "answer" in item and "metadata" in item