diff --git a/tests/test_basic_arithmetic.py b/tests/test_basic_arithmetic.py index 406e4617..757d3c2f 100644 --- a/tests/test_basic_arithmetic.py +++ b/tests/test_basic_arithmetic.py @@ -64,11 +64,7 @@ def test_arithmetic_dataset_format_styles(): max_digits=2, ) dataset = BasicArithmeticDataset(config) - assert all(item["question"].endswith("=") for item in dataset) - - config.format_style = "natural" - dataset = BasicArithmeticDataset(config) - assert all("=" in item["question"] for item in dataset) + assert all(item["question"].strip().endswith(".") for item in dataset) def test_arithmetic_dataset_iteration():