diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 7ea7027a..26758a74 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -17,6 +17,7 @@ class ArithmeticDatasetConfig: seed: Optional[int] = None size: int = 500 # Virtual dataset size format_style: Literal["simple", "natural"] = "simple" + whitespace: Literal["no_space", "single", "random"] = "single" # Whitespace style between terms def validate(self): """Validate configuration parameters""" @@ -104,14 +105,18 @@ class ArithmeticDataset: add_terms(num_terms) - # Add at most one random space between parts - space_parts = [] - for p in parts: - if rng.random() < 0.15: - space_parts.append(" ") - space_parts.append(p) - - expression = " ".join(space_parts).strip() + # Add whitespace according to config + if self.config.whitespace == "no_space": + expression = "".join(parts) + elif self.config.whitespace == "single": + expression = " ".join(parts) + else: # random + space_parts = [] + for p in parts: + if rng.random() < 0.15: + space_parts.append(" ") + space_parts.append(p) + expression = "".join(space_parts).strip() result = eval(expression) # Note: eval is safe here as we control the input return expression, result @@ -176,6 +181,7 @@ def arithmetic_dataset( seed: Optional[int] = None, size: int = 500, format_style: Literal["simple", "natural"] = "simple", + whitespace: Literal["no_space", "single", "random"] = "single", ) -> ArithmeticDataset: """Create an ArithmeticDataset with the given configuration. @@ -205,5 +211,6 @@ def arithmetic_dataset( seed=seed, size=size, format_style=format_style, + whitespace=whitespace, ) return ArithmeticDataset(config)