type hints

blind roboting
This commit is contained in:
Rich Jones 2025-02-20 12:10:28 +01:00 committed by GitHub
parent da0b882b87
commit 921c9b1d7b

View file

@ -2,7 +2,7 @@ import ast
from dataclasses import dataclass
from decimal import ROUND_HALF_UP, Decimal, getcontext
from random import Random
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, List
from ..factory import ProceduralDataset, register_dataset
@ -18,14 +18,16 @@ class DecimalArithmeticConfig:
seed: Optional[int] = None
size: int = 500
def validate(self):
def validate(self) -> None:
"""Validate configuration parameters"""
assert (
self.precision > self.max_num_decimal_places + 1
), "precision must be 2 or more higher than max_num_decimal_places"
def build_grouped_expression(operands, operators, rng):
def build_grouped_expression(
operands: List[str], operators: List[str], rng: Random
) -> str:
"""
Recursively build an arithmetic expression string from operands and operators,
inserting parentheses at random.
@ -37,18 +39,24 @@ def build_grouped_expression(operands, operators, rng):
if len(operands) == 1:
return operands[0]
# Randomly choose a split point (1 <= split < len(operands)).
split = rng.randint(1, len(operands) - 1)
left_expr = build_grouped_expression(operands[:split], operators[: split - 1], rng)
right_expr = build_grouped_expression(operands[split:], operators[split:], rng)
split: int = rng.randint(1, len(operands) - 1)
left_expr: str = build_grouped_expression(operands[:split], operators[: split - 1], rng)
right_expr: str = build_grouped_expression(operands[split:], operators[split:], rng)
# The operator at position (split - 1) is the one combining the two groups.
expr = left_expr + operators[split - 1] + right_expr
expr: str = left_expr + operators[split - 1] + right_expr
# Randomly decide to add parentheses around this subexpression.
if rng.choice([True, False]):
expr = "(" + expr + ")"
return expr
def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_places, terms=2, operations=None):
def generate_arithmetic_problem(
rng: Random,
min_num_decimal_places: int,
max_num_decimal_places: int,
terms: int = 2,
operations: Optional[List[str]] = None
) -> str:
"""
Generates a simple arithmetic problem with decimal numbers (as a string) formatted
to a specific number of decimal places, with random parenthesis grouping.
@ -66,28 +74,28 @@ def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_pla
if operations is None:
operations = ["+", "-", "*", "/"]
operands = []
operators = []
operands: List[str] = []
operators: List[str] = []
for i in range(terms):
# Choose a random number of decimal places for this term.
ndp = rng.randint(min_num_decimal_places, max_num_decimal_places)
max_integer_part = 10 # Maximum whole number before the decimal
max_value = max_integer_part * (10**ndp)
raw_int = rng.randint(1, max_value)
ndp: int = rng.randint(min_num_decimal_places, max_num_decimal_places)
max_integer_part: int = 10 # Maximum whole number before the decimal
max_value: int = max_integer_part * (10 ** ndp)
raw_int: int = rng.randint(1, max_value)
# Create the Decimal number and quantize it to exactly ndp decimal places.
num = Decimal(raw_int) / (Decimal(10) ** ndp)
quantize_str = "1." + "0" * ndp
num: Decimal = Decimal(raw_int) / (Decimal(10) ** ndp)
quantize_str: str = "1." + "0" * ndp
num = num.quantize(Decimal(quantize_str), rounding=ROUND_HALF_UP)
# Format the number as a string with exactly ndp decimals.
num_str = f"{num:.{ndp}f}"
num_str: str = f"{num:.{ndp}f}"
operands.append(num_str)
if i < terms - 1:
op = rng.choice(operations)
op: str = rng.choice(operations)
operators.append(op)
expr = build_grouped_expression(operands, operators, rng)
problem_str = expr + " = ?"
expr: str = build_grouped_expression(operands, operators, rng)
problem_str: str = expr + " = ?"
return problem_str
@ -102,15 +110,15 @@ def evaluate_expression(expr: str) -> Decimal:
Returns:
Decimal: The computed result.
"""
tree = ast.parse(expr, mode="eval")
tree: ast.Expression = ast.parse(expr, mode="eval")
return _eval_ast(tree.body)
def _eval_ast(node) -> Decimal:
def _eval_ast(node: ast.AST) -> Decimal:
"""Recursively evaluate an AST node using Decimal arithmetic."""
if isinstance(node, ast.BinOp):
left = _eval_ast(node.left)
right = _eval_ast(node.right)
left: Decimal = _eval_ast(node.left)
right: Decimal = _eval_ast(node.right)
if isinstance(node.op, ast.Add):
return left + right
elif isinstance(node.op, ast.Sub):
@ -122,7 +130,7 @@ def _eval_ast(node) -> Decimal:
else:
raise ValueError(f"Unsupported operator: {node.op}")
elif isinstance(node, ast.UnaryOp):
operand = _eval_ast(node.operand)
operand: Decimal = _eval_ast(node.operand)
if isinstance(node.op, ast.UAdd):
return operand
elif isinstance(node.op, ast.USub):
@ -140,10 +148,10 @@ def _eval_ast(node) -> Decimal:
class DecimalArithmeticDataset(ProceduralDataset):
"""Dataset that generates basic arithmetic tasks using Decimal arithmetic and proper operator precedence."""
def __init__(self, config: DecimalArithmeticConfig):
def __init__(self, config: DecimalArithmeticConfig) -> None:
super().__init__(config=config, seed=config.seed, size=config.size)
def __getitem__(self, idx: int) -> dict[str, Any]:
def __getitem__(self, idx: int) -> Dict[str, Any]:
"""
Generate a single arithmetic task.
@ -154,18 +162,18 @@ class DecimalArithmeticDataset(ProceduralDataset):
- 'metadata': Additional metadata (currently empty).
"""
# Create a deterministic RNG from base seed and index.
rng = Random(self.seed + idx if self.seed is not None else None)
rng: Random = Random(self.seed + idx if self.seed is not None else None)
getcontext().prec = self.config.precision
problem_str = generate_arithmetic_problem(
problem_str: str = generate_arithmetic_problem(
rng,
self.config.min_num_decimal_places,
self.config.max_num_decimal_places,
terms=self.config.terms,
)
# Remove the trailing " = ?" to obtain the pure arithmetic expression.
expr = problem_str.replace(" = ?", "").strip()
answer = evaluate_expression(expr)
expr: str = problem_str.replace(" = ?", "").strip()
answer: Decimal = evaluate_expression(expr)
problem_str = (
f"Please solve this problem to a maximum of {str(self.config.precision)} significant digits, rounding up from the half. Only reply with the final value.\n"
@ -187,12 +195,12 @@ class DecimalArithmeticDataset(ProceduralDataset):
return 0.0
try:
user_ans = Decimal(answer)
correct_ans = entry["answer"]
user_ans: Decimal = Decimal(answer)
correct_ans: Decimal = entry["answer"]
# Determine tolerance based on the desired precision.
precision = self.config.max_num_decimal_places
tol = Decimal(10) ** (-precision)
precision: int = self.config.max_num_decimal_places
tol: Decimal = Decimal(10) ** (-precision)
if abs(user_ans - correct_ans) <= tol:
return 1.0
except Exception: