diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index eed3bb98..e8a1cb94 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -65,7 +65,7 @@ class BasicArithmeticDataset(ProceduralDataset): def __init__(self, config: BasicArithmeticDatasetConfig): super().__init__(config=config, seed=config.seed, size=config.size) self.added_instruction = ( - "Ensure to report the answer as an integer. Please do not add commas to the integer answers reported." + " Ensure to report the answer as an integer. Do not add commas to the integer answers reported." ) def __getitem__(self, idx: int) -> dict[str, Any]: @@ -226,15 +226,14 @@ class BasicArithmeticDataset(ProceduralDataset): return expression, result def _format_question(self, rng: Random, expression: str) -> str: - """Format the expression with clear answer positioning""" - # answer_instruction = "Put your final answer after '=' without additional text." + """Format the question with the arithmetic expression""" if self.config.format_style == "simple": - return f"Calculate {expression}. " + return f"Calculate {expression}." else: - templates = ["What is {0}. ", "Solve {0}. ", "Compute {0}. ", "Evaluate: {0}. 
"] - template = rng.choice(templates).format(expression) - return f"{template}" + templates = ["What is {0}?", "Solve {0}.", "Compute {0}.", "Evaluate: {0}."] + template = rng.choice(templates) + return template.format(expression) # Register the dataset diff --git a/tests/test_basic_arithmetic.py b/tests/test_basic_arithmetic.py index 757d3c2f..c1035af9 100644 --- a/tests/test_basic_arithmetic.py +++ b/tests/test_basic_arithmetic.py @@ -1,5 +1,3 @@ -from random import Random - import pytest from reasoning_gym.arithmetic.basic_arithmetic import ( @@ -66,6 +64,18 @@ def test_arithmetic_dataset_format_styles(): dataset = BasicArithmeticDataset(config) assert all(item["question"].strip().endswith(".") for item in dataset) + config = BasicArithmeticDatasetConfig( + size=10, + seed=42, + format_style="natural", + min_terms=2, + max_terms=3, # Keep expressions simple for testing + min_digits=1, + max_digits=2, + ) + dataset = BasicArithmeticDataset(config) + assert all(item["question"].strip().endswith(".") for item in dataset) + def test_arithmetic_dataset_iteration(): """Test that iteration respects dataset size"""