mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-23 16:55:05 +00:00
add random paren grouping
This commit is contained in:
parent
a76e56fccc
commit
eb64e3a2b8
1 changed files with 33 additions and 13 deletions
|
|
@ -25,10 +25,33 @@ class DecimalArithmeticDatasetConfig:
|
|||
), "precision must be 2 or more higher than max_num_decimal_places"
|
||||
|
||||
|
||||
def build_grouped_expression(operands, operators, rng):
    """
    Recursively build an arithmetic expression string from operands and operators,
    inserting parentheses at random.

    A random split point divides the operands into a left group and a right
    group; the operator at that split becomes the "root" that joins the two
    recursively built groups. With 50% chance, each joined subexpression is
    wrapped in parentheses. The left-to-right order of operands and operators
    is never changed; only parentheses are added.

    Parameters:
        operands: Non-empty sequence of operand strings.
        operators: Sequence of operator strings, where operators[i] sits
            between operands[i] and operands[i + 1].
        rng: Random number generator exposing randint(a, b) and choice(seq).
            NOTE(review): randint is assumed to include both bounds, as
            random.Random does -- confirm if another generator is passed.

    Returns:
        str: The expression with randomly placed parentheses.

    Raises:
        ValueError: If operands is empty.
    """
    if not operands:
        # Fail early with a clear message; otherwise the split selection below
        # would be asked for a value in the empty range [1, -1].
        raise ValueError("operands must contain at least one element")

    if len(operands) == 1:
        return operands[0]

    # Randomly choose a split point (1 <= split < len(operands)).
    split = rng.randint(1, len(operands) - 1)

    # The order of rng calls (split, left subtree, right subtree, parentheses)
    # determines the output for a given seed, so keep it as is.
    left_expr = build_grouped_expression(operands[:split], operators[: split - 1], rng)
    right_expr = build_grouped_expression(operands[split:], operators[split:], rng)

    # The operator at position (split - 1) is the one combining the two groups.
    expr = f"{left_expr}{operators[split - 1]}{right_expr}"

    # Randomly decide to add parentheses around this subexpression.
    if rng.choice([True, False]):
        expr = f"({expr})"
    return expr
|
||||
|
||||
|
||||
def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_places, terms=2, operations=None):
|
||||
"""
|
||||
Generates a simple arithmetic problem with decimal numbers (as a string) formatted
|
||||
to a specific number of decimal places.
|
||||
to a specific number of decimal places, with random parenthesis grouping.
|
||||
|
||||
Parameters:
|
||||
rng: Random number generator.
|
||||
|
|
@ -43,10 +66,11 @@ def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_pla
|
|||
if operations is None:
|
||||
operations = ["+", "-", "*", "/"]
|
||||
|
||||
tokens = []
|
||||
# Build the expression by alternating numbers and operators.
|
||||
operands = []
|
||||
operators = []
|
||||
|
||||
for i in range(terms):
|
||||
# Choose a number of decimal places for this term.
|
||||
# Choose a random number of decimal places for this term.
|
||||
ndp = rng.randint(min_num_decimal_places, max_num_decimal_places)
|
||||
max_integer_part = 10 # Maximum whole number before the decimal
|
||||
max_value = max_integer_part * (10**ndp)
|
||||
|
|
@ -57,12 +81,13 @@ def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_pla
|
|||
num = num.quantize(Decimal(quantize_str), rounding=ROUND_HALF_UP)
|
||||
# Format the number as a string with exactly ndp decimals.
|
||||
num_str = f"{num:.{ndp}f}"
|
||||
tokens.append(num_str)
|
||||
operands.append(num_str)
|
||||
if i < terms - 1:
|
||||
op = rng.choice(operations)
|
||||
tokens.append(op)
|
||||
operators.append(op)
|
||||
|
||||
problem_str = "".join(tokens) + " = ?"
|
||||
expr = build_grouped_expression(operands, operators, rng)
|
||||
problem_str = expr + " = ?"
|
||||
return problem_str
|
||||
|
||||
|
||||
|
|
@ -105,8 +130,6 @@ def _eval_ast(node) -> Decimal:
|
|||
else:
|
||||
raise ValueError(f"Unsupported unary operator: {node.op}")
|
||||
elif isinstance(node, ast.Constant): # For Python 3.8+
|
||||
# Although ast converts numeric literals to floats,
|
||||
# converting via str helps us get a Decimal with the intended value.
|
||||
return Decimal(str(node.value))
|
||||
elif isinstance(node, ast.Num): # For older Python versions
|
||||
return Decimal(str(node.n))
|
||||
|
|
@ -144,7 +167,7 @@ class DecimalArithmeticDataset(ProceduralDataset):
|
|||
expr = problem_str.replace(" = ?", "").strip()
|
||||
answer = evaluate_expression(expr)
|
||||
|
||||
problem_str = problem_str = (
|
||||
problem_str = (
|
||||
f"Please solve this problem to a maximum of {str(self.config.precision)} significant digits, rounding up from the half. Only reply with the final value.\n"
|
||||
+ problem_str
|
||||
)
|
||||
|
|
@ -157,8 +180,6 @@ class DecimalArithmeticDataset(ProceduralDataset):
|
|||
Instead of requiring exact equality, we allow an error up to one unit in the
|
||||
least significant digit as determined by the level of precision (max_num_decimal_places).
|
||||
|
||||
For example, if max_num_decimal_places is 6, then an error of up to 1e-6 is accepted.
|
||||
|
||||
Returns:
|
||||
float: 1.0 if the user's answer is within tolerance; otherwise, 0.01.
|
||||
"""
|
||||
|
|
@ -170,7 +191,6 @@ class DecimalArithmeticDataset(ProceduralDataset):
|
|||
correct_ans = entry["answer"]
|
||||
|
||||
# Determine tolerance based on the desired precision.
|
||||
# Here, we allow a difference of 1 in the last decimal place.
|
||||
precision = self.config.max_num_decimal_places
|
||||
tol = Decimal(10) ** (-precision)
|
||||
if abs(user_ans - correct_ans) <= tol:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue