diff --git a/reasoning_gym/arithmetic/decimal_arithmetic.py b/reasoning_gym/arithmetic/decimal_arithmetic.py index d9465f23..d72b4d03 100644 --- a/reasoning_gym/arithmetic/decimal_arithmetic.py +++ b/reasoning_gym/arithmetic/decimal_arithmetic.py @@ -2,7 +2,7 @@ import ast from dataclasses import dataclass from decimal import ROUND_HALF_UP, Decimal, getcontext from random import Random -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List from ..factory import ProceduralDataset, register_dataset @@ -18,14 +18,16 @@ class DecimalArithmeticConfig: seed: Optional[int] = None size: int = 500 - def validate(self): + def validate(self) -> None: """Validate configuration parameters""" assert ( self.precision > self.max_num_decimal_places + 1 ), "precision must be 2 or more higher than max_num_decimal_places" -def build_grouped_expression(operands, operators, rng): +def build_grouped_expression( + operands: List[str], operators: List[str], rng: Random +) -> str: """ Recursively build an arithmetic expression string from operands and operators, inserting parentheses at random. @@ -37,18 +39,24 @@ def build_grouped_expression(operands, operators, rng): if len(operands) == 1: return operands[0] # Randomly choose a split point (1 <= split < len(operands)). - split = rng.randint(1, len(operands) - 1) - left_expr = build_grouped_expression(operands[:split], operators[: split - 1], rng) - right_expr = build_grouped_expression(operands[split:], operators[split:], rng) + split: int = rng.randint(1, len(operands) - 1) + left_expr: str = build_grouped_expression(operands[:split], operators[: split - 1], rng) + right_expr: str = build_grouped_expression(operands[split:], operators[split:], rng) # The operator at position (split - 1) is the one combining the two groups. - expr = left_expr + operators[split - 1] + right_expr + expr: str = left_expr + operators[split - 1] + right_expr # Randomly decide to add parentheses around this subexpression. if rng.choice([True, False]): expr = "(" + expr + ")" return expr -def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_places, terms=2, operations=None): +def generate_arithmetic_problem( + rng: Random, + min_num_decimal_places: int, + max_num_decimal_places: int, + terms: int = 2, + operations: Optional[List[str]] = None +) -> str: """ Generates a simple arithmetic problem with decimal numbers (as a string) formatted to a specific number of decimal places, with random parenthesis grouping. @@ -66,28 +74,28 @@ def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_pla if operations is None: operations = ["+", "-", "*", "/"] - operands = [] - operators = [] + operands: List[str] = [] + operators: List[str] = [] for i in range(terms): # Choose a random number of decimal places for this term. - ndp = rng.randint(min_num_decimal_places, max_num_decimal_places) - max_integer_part = 10 # Maximum whole number before the decimal - max_value = max_integer_part * (10**ndp) - raw_int = rng.randint(1, max_value) + ndp: int = rng.randint(min_num_decimal_places, max_num_decimal_places) + max_integer_part: int = 10 # Maximum whole number before the decimal + max_value: int = max_integer_part * (10 ** ndp) + raw_int: int = rng.randint(1, max_value) # Create the Decimal number and quantize it to exactly ndp decimal places. - num = Decimal(raw_int) / (Decimal(10) ** ndp) - quantize_str = "1." + "0" * ndp + num: Decimal = Decimal(raw_int) / (Decimal(10) ** ndp) + quantize_str: str = "1." + "0" * ndp num = num.quantize(Decimal(quantize_str), rounding=ROUND_HALF_UP) # Format the number as a string with exactly ndp decimals. - num_str = f"{num:.{ndp}f}" + num_str: str = f"{num:.{ndp}f}" operands.append(num_str) if i < terms - 1: - op = rng.choice(operations) + op: str = rng.choice(operations) operators.append(op) - expr = build_grouped_expression(operands, operators, rng) - problem_str = expr + " = ?" + expr: str = build_grouped_expression(operands, operators, rng) + problem_str: str = expr + " = ?" return problem_str @@ -102,15 +110,15 @@ def evaluate_expression(expr: str) -> Decimal: Returns: Decimal: The computed result. """ - tree = ast.parse(expr, mode="eval") + tree: ast.Expression = ast.parse(expr, mode="eval") return _eval_ast(tree.body) -def _eval_ast(node) -> Decimal: +def _eval_ast(node: ast.AST) -> Decimal: """Recursively evaluate an AST node using Decimal arithmetic.""" if isinstance(node, ast.BinOp): - left = _eval_ast(node.left) - right = _eval_ast(node.right) + left: Decimal = _eval_ast(node.left) + right: Decimal = _eval_ast(node.right) if isinstance(node.op, ast.Add): return left + right elif isinstance(node.op, ast.Sub): @@ -122,7 +130,7 @@ def _eval_ast(node) -> Decimal: else: raise ValueError(f"Unsupported operator: {node.op}") elif isinstance(node, ast.UnaryOp): - operand = _eval_ast(node.operand) + operand: Decimal = _eval_ast(node.operand) if isinstance(node.op, ast.UAdd): return operand elif isinstance(node.op, ast.USub): @@ -140,10 +148,10 @@ def _eval_ast(node) -> Decimal: class DecimalArithmeticDataset(ProceduralDataset): """Dataset that generates basic arithmetic tasks using Decimal arithmetic and proper operator precedence.""" - def __init__(self, config: DecimalArithmeticConfig): + def __init__(self, config: DecimalArithmeticConfig) -> None: super().__init__(config=config, seed=config.seed, size=config.size) - def __getitem__(self, idx: int) -> dict[str, Any]: + def __getitem__(self, idx: int) -> Dict[str, Any]: """ Generate a single arithmetic task. @@ -154,18 +162,18 @@ class DecimalArithmeticDataset(ProceduralDataset): - 'metadata': Additional metadata (currently empty). """ # Create a deterministic RNG from base seed and index. - rng = Random(self.seed + idx if self.seed is not None else None) + rng: Random = Random(self.seed + idx if self.seed is not None else None) getcontext().prec = self.config.precision - problem_str = generate_arithmetic_problem( + problem_str: str = generate_arithmetic_problem( rng, self.config.min_num_decimal_places, self.config.max_num_decimal_places, terms=self.config.terms, ) # Remove the trailing " = ?" to obtain the pure arithmetic expression. - expr = problem_str.replace(" = ?", "").strip() - answer = evaluate_expression(expr) + expr: str = problem_str.replace(" = ?", "").strip() + answer: Decimal = evaluate_expression(expr) problem_str = ( f"Please solve this problem to a maximum of {str(self.config.precision)} significant digits, rounding up from the half. Only reply with the final value.\n" @@ -187,12 +195,12 @@ class DecimalArithmeticDataset(ProceduralDataset): return 0.0 try: - user_ans = Decimal(answer) - correct_ans = entry["answer"] + user_ans: Decimal = Decimal(answer) + correct_ans: Decimal = entry["answer"] # Determine tolerance based on the desired precision. - precision = self.config.max_num_decimal_places - tol = Decimal(10) ** (-precision) + precision: int = self.config.max_num_decimal_places + tol: Decimal = Decimal(10) ** (-precision) if abs(user_ans - correct_ans) <= tol: return 1.0 except Exception: