diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 45134332..2cea91fd 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -14,7 +14,7 @@ class ArithmeticDatasetConfig: allow_parentheses: bool = True allow_negation: bool = True seed: Optional[int] = None - size: int = 10000 # Virtual dataset size + size: int = 500 # Virtual dataset size format_style: Literal["simple", "natural"] = "simple" def validate(self): @@ -102,14 +102,14 @@ class ArithmeticDataset: add_terms(num_terms) - # Add random spaces + # Add at most one random space between parts space_parts = [] for p in parts: - while rng.random() < 0.15: + if rng.random() < 0.15: space_parts.append(" ") space_parts.append(p) - expression = " ".join(space_parts) + expression = " ".join(space_parts).strip() result = eval(expression) # Note: eval is safe here as we control the input return expression, result diff --git a/tests/test_arithmetic.py b/tests/test_arithmetic.py index 57a794ab..2ca1d6f1 100644 --- a/tests/test_arithmetic.py +++ b/tests/test_arithmetic.py @@ -58,7 +58,11 @@ def test_arithmetic_dataset_format_styles(): config = ArithmeticDatasetConfig( size=10, seed=42, - format_style="simple" + format_style="simple", + min_terms=2, + max_terms=3, # Keep expressions simple for testing + min_digits=1, + max_digits=2 ) dataset = ArithmeticDataset(config) assert all(item["question"].endswith("=") for item in dataset)