diff --git a/tests/test_basic_arithmetic.py b/tests/test_basic_arithmetic.py index c1035af9..6eda876a 100644 --- a/tests/test_basic_arithmetic.py +++ b/tests/test_basic_arithmetic.py @@ -74,7 +74,7 @@ def test_arithmetic_dataset_format_styles(): max_digits=2, ) dataset = BasicArithmeticDataset(config) - assert all(item["question"].strip().endswith(".") for item in dataset) + assert all(item["question"].strip().endswith(".") or item["question"].strip().endswith("?") for item in dataset) def test_arithmetic_dataset_iteration():