add random paren grouping

This commit is contained in:
Rich Jones 2025-02-20 10:46:01 +01:00
parent a76e56fccc
commit eb64e3a2b8

View file

@ -25,10 +25,33 @@ class DecimalArithmeticDatasetConfig:
), "precision must be 2 or more higher than max_num_decimal_places"
def build_grouped_expression(operands, operators, rng):
    """
    Recursively build an arithmetic expression string with random parentheses.

    A random split point divides the operands into a left group and a right
    group; the operator located at that split becomes the root that joins the
    two recursively built sub-expressions. Every combined sub-expression
    (including the outermost one) is wrapped in parentheses whenever
    ``rng.choice([True, False])`` returns True.

    Parameters:
        operands: Non-empty sequence of operand strings.
        operators: Sequence of operator strings holding exactly
            ``len(operands) - 1`` entries; ``operators[i]`` sits between
            ``operands[i]`` and ``operands[i + 1]``.
        rng: Random number generator exposing ``randint(a, b)`` (expected to
            return an integer with a <= result <= b) and ``choice(seq)``.

    Returns:
        The expression string. Operands and operators keep their original
        left-to-right order; only parentheses are inserted. A single operand
        is returned unchanged.

    Raises:
        ValueError: If ``operands`` is empty or the number of operators is
            not exactly one less than the number of operands.
    """
    n_operands = len(operands)
    # Fail early with a clear message: an empty operand list would otherwise
    # surface as an error from rng.randint(1, -1), and surplus operators would
    # be dropped silently.
    if n_operands == 0:
        raise ValueError("operands must contain at least one element")
    if len(operators) != n_operands - 1:
        raise ValueError(
            f"expected {n_operands - 1} operators for {n_operands} operands, "
            f"got {len(operators)}"
        )

    if n_operands == 1:
        return operands[0]

    # Randomly choose a split point (1 <= split < len(operands)).
    split = rng.randint(1, n_operands - 1)

    # The left group owns the operators before the root, the right group the
    # ones after it, so both recursive calls keep the "one fewer operator
    # than operands" invariant.
    left_expr = build_grouped_expression(operands[:split], operators[: split - 1], rng)
    right_expr = build_grouped_expression(operands[split:], operators[split:], rng)

    # The operator at position (split - 1) is the one combining the two groups.
    expr = left_expr + operators[split - 1] + right_expr

    # Randomly decide to add parentheses around this subexpression. The call
    # order (randint, left, right, choice) is kept so seeded output is stable.
    if rng.choice([True, False]):
        expr = "(" + expr + ")"
    return expr
def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_places, terms=2, operations=None):
"""
Generates a simple arithmetic problem with decimal numbers (as a string) formatted
to a specific number of decimal places.
to a specific number of decimal places, with random parenthesis grouping.
Parameters:
rng: Random number generator.
@ -43,10 +66,11 @@ def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_pla
if operations is None:
operations = ["+", "-", "*", "/"]
tokens = []
# Build the expression by alternating numbers and operators.
operands = []
operators = []
for i in range(terms):
# Choose a number of decimal places for this term.
# Choose a random number of decimal places for this term.
ndp = rng.randint(min_num_decimal_places, max_num_decimal_places)
max_integer_part = 10 # Maximum whole number before the decimal
max_value = max_integer_part * (10**ndp)
@ -57,12 +81,13 @@ def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_pla
num = num.quantize(Decimal(quantize_str), rounding=ROUND_HALF_UP)
# Format the number as a string with exactly ndp decimals.
num_str = f"{num:.{ndp}f}"
tokens.append(num_str)
operands.append(num_str)
if i < terms - 1:
op = rng.choice(operations)
tokens.append(op)
operators.append(op)
problem_str = "".join(tokens) + " = ?"
expr = build_grouped_expression(operands, operators, rng)
problem_str = expr + " = ?"
return problem_str
@ -105,8 +130,6 @@ def _eval_ast(node) -> Decimal:
else:
raise ValueError(f"Unsupported unary operator: {node.op}")
elif isinstance(node, ast.Constant): # For Python 3.8+
# Although ast converts numeric literals to floats,
# converting via str helps us get a Decimal with the intended value.
return Decimal(str(node.value))
elif isinstance(node, ast.Num): # For older Python versions
return Decimal(str(node.n))
@ -144,7 +167,7 @@ class DecimalArithmeticDataset(ProceduralDataset):
expr = problem_str.replace(" = ?", "").strip()
answer = evaluate_expression(expr)
problem_str = problem_str = (
problem_str = (
f"Please solve this problem to a maximum of {str(self.config.precision)} significant digits, rounding up from the half. Only reply with the final value.\n"
+ problem_str
)
@ -157,8 +180,6 @@ class DecimalArithmeticDataset(ProceduralDataset):
Instead of requiring exact equality, we allow an error up to one unit in the
least significant digit as determined by the level of precision (max_num_decimal_places).
For example, if max_num_decimal_places is 6, then an error of up to 1e-6 is accepted.
Returns:
float: 1.0 if the user's answer is within tolerance; otherwise, 0.01.
"""
@ -170,7 +191,6 @@ class DecimalArithmeticDataset(ProceduralDataset):
correct_ans = entry["answer"]
# Determine tolerance based on the desired precision.
# Here, we allow a difference of 1 in the last decimal place.
precision = self.config.max_num_decimal_places
tol = Decimal(10) ** (-precision)
if abs(user_ans - correct_ans) <= tol: