From ba58ecf8ea4eaf987796b6f4285f1d95f7d7868d Mon Sep 17 00:00:00 2001 From: joesharratt1229 Date: Sun, 16 Feb 2025 12:01:54 +0000 Subject: [PATCH] corrected failing airthmetic test --- reasoning_gym/arithmetic/basic_arithmetic.py | 9 +++++---- tests/test_basic_arithmetic.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index a65ea295..efe9d465 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -227,15 +227,16 @@ class BasicArithmeticDataset(ProceduralDataset): answer_instruction = "Put your final answer after '=' without additional text." if self.config.format_style == "simple": - return f"Calculate {expression} =" + return f"{answer_instruction} Calculate {expression} =" else: templates = [ - "What is {0}? =", - "Solve {0} and write answer after =", + "What is {0} =", + "Solve {0}=", "Compute {0} =", "Evaluate: {0} =" ] - return rng.choice(templates).format(expression) + f" {answer_instruction}" + template = rng.choice(templates).format(expression) + return f"{answer_instruction} {template}" # Register the dataset diff --git a/tests/test_basic_arithmetic.py b/tests/test_basic_arithmetic.py index 3d3d08b5..406e4617 100644 --- a/tests/test_basic_arithmetic.py +++ b/tests/test_basic_arithmetic.py @@ -68,7 +68,7 @@ def test_arithmetic_dataset_format_styles(): config.format_style = "natural" dataset = BasicArithmeticDataset(config) - assert all("=" not in item["question"] for item in dataset) + assert all("=" in item["question"] for item in dataset) def test_arithmetic_dataset_iteration():