add random paren grouping

This commit is contained in:
Rich Jones 2025-02-20 10:46:01 +01:00
parent a76e56fccc
commit eb64e3a2b8

View file

@ -25,10 +25,33 @@ class DecimalArithmeticDatasetConfig:
), "precision must be 2 or more higher than max_num_decimal_places" ), "precision must be 2 or more higher than max_num_decimal_places"
def build_grouped_expression(operands, operators, rng):
    """
    Recursively build an arithmetic expression string from operands and
    operators, inserting parentheses at random.

    The operand list is split at a random position; the two halves are built
    recursively and joined with the operator sitting at the split. The joined
    text is then wrapped in parentheses or left bare, as decided by
    ``rng.choice([True, False])``.

    Note that a sub-expression left without parentheses is plain text: when
    the final string is evaluated, normal operator precedence applies to it,
    so the split only fixes the grouping where parentheses were emitted.

    Parameters:
        operands: Non-empty sequence of operand strings, in left-to-right order.
        operators: Sequence of operator strings; ``operators[i]`` is placed
            between ``operands[i]`` and ``operands[i + 1]``, so it is expected
            to hold ``len(operands) - 1`` entries.
        rng: Random number generator providing ``randint(a, b)`` and
            ``choice(seq)``.

    Returns:
        str: The expression text. Operands and operators keep their original
        order; only parentheses are added.

    Raises:
        ValueError: If ``operands`` is empty.
    """
    # Fail with a clear message instead of letting rng.randint(1, -1) raise
    # an opaque "empty range" error further down.
    if not operands:
        raise ValueError("build_grouped_expression requires at least one operand")

    # A single operand is a leaf: nothing to combine and nothing to wrap.
    if len(operands) == 1:
        return operands[0]

    # Randomly choose a split point (1 <= split < len(operands)) so that both
    # sides receive at least one operand.
    split = rng.randint(1, len(operands) - 1)

    # The left side owns the operators strictly before the split operator,
    # the right side owns the ones strictly after it.
    left_expr = build_grouped_expression(operands[:split], operators[: split - 1], rng)
    right_expr = build_grouped_expression(operands[split:], operators[split:], rng)

    # The operator at position (split - 1) is the one combining the two groups.
    expr = left_expr + operators[split - 1] + right_expr

    # Randomly decide to add parentheses around this subexpression.
    if rng.choice([True, False]):
        expr = "(" + expr + ")"
    return expr
def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_places, terms=2, operations=None): def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_places, terms=2, operations=None):
""" """
Generates a simple arithmetic problem with decimal numbers (as a string) formatted Generates a simple arithmetic problem with decimal numbers (as a string) formatted
to a specific number of decimal places. to a specific number of decimal places, with random parenthesis grouping.
Parameters: Parameters:
rng: Random number generator. rng: Random number generator.
@ -43,10 +66,11 @@ def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_pla
if operations is None: if operations is None:
operations = ["+", "-", "*", "/"] operations = ["+", "-", "*", "/"]
tokens = [] operands = []
# Build the expression by alternating numbers and operators. operators = []
for i in range(terms): for i in range(terms):
# Choose a number of decimal places for this term. # Choose a random number of decimal places for this term.
ndp = rng.randint(min_num_decimal_places, max_num_decimal_places) ndp = rng.randint(min_num_decimal_places, max_num_decimal_places)
max_integer_part = 10 # Maximum whole number before the decimal max_integer_part = 10 # Maximum whole number before the decimal
max_value = max_integer_part * (10**ndp) max_value = max_integer_part * (10**ndp)
@ -57,12 +81,13 @@ def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_pla
num = num.quantize(Decimal(quantize_str), rounding=ROUND_HALF_UP) num = num.quantize(Decimal(quantize_str), rounding=ROUND_HALF_UP)
# Format the number as a string with exactly ndp decimals. # Format the number as a string with exactly ndp decimals.
num_str = f"{num:.{ndp}f}" num_str = f"{num:.{ndp}f}"
tokens.append(num_str) operands.append(num_str)
if i < terms - 1: if i < terms - 1:
op = rng.choice(operations) op = rng.choice(operations)
tokens.append(op) operators.append(op)
problem_str = "".join(tokens) + " = ?" expr = build_grouped_expression(operands, operators, rng)
problem_str = expr + " = ?"
return problem_str return problem_str
@ -105,8 +130,6 @@ def _eval_ast(node) -> Decimal:
else: else:
raise ValueError(f"Unsupported unary operator: {node.op}") raise ValueError(f"Unsupported unary operator: {node.op}")
elif isinstance(node, ast.Constant): # For Python 3.8+ elif isinstance(node, ast.Constant): # For Python 3.8+
# Although ast converts numeric literals to floats,
# converting via str helps us get a Decimal with the intended value.
return Decimal(str(node.value)) return Decimal(str(node.value))
elif isinstance(node, ast.Num): # For older Python versions elif isinstance(node, ast.Num): # For older Python versions
return Decimal(str(node.n)) return Decimal(str(node.n))
@ -144,7 +167,7 @@ class DecimalArithmeticDataset(ProceduralDataset):
expr = problem_str.replace(" = ?", "").strip() expr = problem_str.replace(" = ?", "").strip()
answer = evaluate_expression(expr) answer = evaluate_expression(expr)
problem_str = problem_str = ( problem_str = (
f"Please solve this problem to a maximum of {str(self.config.precision)} significant digits, rounding up from the half. Only reply with the final value.\n" f"Please solve this problem to a maximum of {str(self.config.precision)} significant digits, rounding up from the half. Only reply with the final value.\n"
+ problem_str + problem_str
) )
@ -157,8 +180,6 @@ class DecimalArithmeticDataset(ProceduralDataset):
Instead of requiring exact equality, we allow an error up to one unit in the Instead of requiring exact equality, we allow an error up to one unit in the
least significant digit as determined by the level of precision (max_num_decimal_places). least significant digit as determined by the level of precision (max_num_decimal_places).
For example, if max_num_decimal_places is 6, then an error of up to 1e-6 is accepted.
Returns: Returns:
float: 1.0 if the user's answer is within tolerance; otherwise, 0.01. float: 1.0 if the user's answer is within tolerance; otherwise, 0.01.
""" """
@ -170,7 +191,6 @@ class DecimalArithmeticDataset(ProceduralDataset):
correct_ans = entry["answer"] correct_ans = entry["answer"]
# Determine tolerance based on the desired precision. # Determine tolerance based on the desired precision.
# Here, we allow a difference of 1 in the last decimal place.
precision = self.config.max_num_decimal_places precision = self.config.max_num_decimal_places
tol = Decimal(10) ** (-precision) tol = Decimal(10) ** (-precision)
if abs(user_ans - correct_ans) <= tol: if abs(user_ans - correct_ans) <= tol: