diff --git a/reasoning_gym/arithmetic/decimal_arithmetic.py b/reasoning_gym/arithmetic/decimal_arithmetic.py index c97c7230..b5877dba 100644 --- a/reasoning_gym/arithmetic/decimal_arithmetic.py +++ b/reasoning_gym/arithmetic/decimal_arithmetic.py @@ -25,10 +25,33 @@ class DecimalArithmeticDatasetConfig: ), "precision must be 2 or more higher than max_num_decimal_places" +def build_grouped_expression(operands, operators, rng): + """ + Recursively build an arithmetic expression string from operands and operators, + inserting parentheses at random. + + The expression is built by choosing a random split among the operands; + the operator at that split becomes the “root” of the subexpression. + With 50% chance, the resulting combination is wrapped in parentheses. + """ + if len(operands) == 1: + return operands[0] + # Randomly choose a split point (1 <= split < len(operands)). + split = rng.randint(1, len(operands) - 1) + left_expr = build_grouped_expression(operands[:split], operators[: split - 1], rng) + right_expr = build_grouped_expression(operands[split:], operators[split:], rng) + # The operator at position (split - 1) is the one combining the two groups. + expr = left_expr + operators[split - 1] + right_expr + # Randomly decide to add parentheses around this subexpression. + if rng.choice([True, False]): + expr = "(" + expr + ")" + return expr + + def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_places, terms=2, operations=None): """ Generates a simple arithmetic problem with decimal numbers (as a string) formatted - to a specific number of decimal places. + to a specific number of decimal places, with random parenthesis grouping. Parameters: rng: Random number generator. @@ -43,10 +66,11 @@ def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_pla if operations is None: operations = ["+", "-", "*", "/"] - tokens = [] - # Build the expression by alternating numbers and operators. + operands = [] + operators = [] + for i in range(terms): - # Choose a number of decimal places for this term. + # Choose a random number of decimal places for this term. ndp = rng.randint(min_num_decimal_places, max_num_decimal_places) max_integer_part = 10 # Maximum whole number before the decimal max_value = max_integer_part * (10**ndp) @@ -57,12 +81,13 @@ def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_pla num = num.quantize(Decimal(quantize_str), rounding=ROUND_HALF_UP) # Format the number as a string with exactly ndp decimals. num_str = f"{num:.{ndp}f}" - tokens.append(num_str) + operands.append(num_str) if i < terms - 1: op = rng.choice(operations) - tokens.append(op) + operators.append(op) - problem_str = "".join(tokens) + " = ?" + expr = build_grouped_expression(operands, operators, rng) + problem_str = expr + " = ?" return problem_str @@ -105,8 +130,6 @@ def _eval_ast(node) -> Decimal: else: raise ValueError(f"Unsupported unary operator: {node.op}") elif isinstance(node, ast.Constant): # For Python 3.8+ - # Although ast converts numeric literals to floats, - # converting via str helps us get a Decimal with the intended value. return Decimal(str(node.value)) elif isinstance(node, ast.Num): # For older Python versions return Decimal(str(node.n)) @@ -144,7 +167,7 @@ class DecimalArithmeticDataset(ProceduralDataset): expr = problem_str.replace(" = ?", "").strip() answer = evaluate_expression(expr) - problem_str = problem_str = ( + problem_str = ( f"Please solve this problem to a maximum of {str(self.config.precision)} significant digits, rounding up from the half. Only reply with the final value.\n" + problem_str ) @@ -157,8 +180,6 @@ class DecimalArithmeticDataset(ProceduralDataset): Instead of requiring exact equality, we allow an error up to one unit in the least significant digit as determined by the level of precision (max_num_decimal_places). - For example, if max_num_decimal_places is 6, then an error of up to 1e-6 is accepted. - Returns: float: 1.0 if the user's answer is within tolerance; otherwise, 0.01. """ @@ -170,7 +191,6 @@ class DecimalArithmeticDataset(ProceduralDataset): correct_ans = entry["answer"] # Determine tolerance based on the desired precision. - # Here, we allow a difference of 1 in the last decimal place. precision = self.config.max_num_decimal_places tol = Decimal(10) ** (-precision) if abs(user_ans - correct_ans) <= tol: