mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-23 16:55:05 +00:00
type hints
blind roboting
This commit is contained in:
parent
da0b882b87
commit
921c9b1d7b
1 changed files with 43 additions and 35 deletions
|
|
@ -2,7 +2,7 @@ import ast
|
|||
from dataclasses import dataclass
|
||||
from decimal import ROUND_HALF_UP, Decimal, getcontext
|
||||
from random import Random
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, List
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -18,14 +18,16 @@ class DecimalArithmeticConfig:
|
|||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
|
||||
def validate(self):
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert (
|
||||
self.precision > self.max_num_decimal_places + 1
|
||||
), "precision must be 2 or more higher than max_num_decimal_places"
|
||||
|
||||
|
||||
def build_grouped_expression(operands, operators, rng):
|
||||
def build_grouped_expression(
|
||||
operands: List[str], operators: List[str], rng: Random
|
||||
) -> str:
|
||||
"""
|
||||
Recursively build an arithmetic expression string from operands and operators,
|
||||
inserting parentheses at random.
|
||||
|
|
@ -37,18 +39,24 @@ def build_grouped_expression(operands, operators, rng):
|
|||
if len(operands) == 1:
|
||||
return operands[0]
|
||||
# Randomly choose a split point (1 <= split < len(operands)).
|
||||
split = rng.randint(1, len(operands) - 1)
|
||||
left_expr = build_grouped_expression(operands[:split], operators[: split - 1], rng)
|
||||
right_expr = build_grouped_expression(operands[split:], operators[split:], rng)
|
||||
split: int = rng.randint(1, len(operands) - 1)
|
||||
left_expr: str = build_grouped_expression(operands[:split], operators[: split - 1], rng)
|
||||
right_expr: str = build_grouped_expression(operands[split:], operators[split:], rng)
|
||||
# The operator at position (split - 1) is the one combining the two groups.
|
||||
expr = left_expr + operators[split - 1] + right_expr
|
||||
expr: str = left_expr + operators[split - 1] + right_expr
|
||||
# Randomly decide to add parentheses around this subexpression.
|
||||
if rng.choice([True, False]):
|
||||
expr = "(" + expr + ")"
|
||||
return expr
|
||||
|
||||
|
||||
def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_places, terms=2, operations=None):
|
||||
def generate_arithmetic_problem(
|
||||
rng: Random,
|
||||
min_num_decimal_places: int,
|
||||
max_num_decimal_places: int,
|
||||
terms: int = 2,
|
||||
operations: Optional[List[str]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generates a simple arithmetic problem with decimal numbers (as a string) formatted
|
||||
to a specific number of decimal places, with random parenthesis grouping.
|
||||
|
|
@ -66,28 +74,28 @@ def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_pla
|
|||
if operations is None:
|
||||
operations = ["+", "-", "*", "/"]
|
||||
|
||||
operands = []
|
||||
operators = []
|
||||
operands: List[str] = []
|
||||
operators: List[str] = []
|
||||
|
||||
for i in range(terms):
|
||||
# Choose a random number of decimal places for this term.
|
||||
ndp = rng.randint(min_num_decimal_places, max_num_decimal_places)
|
||||
max_integer_part = 10 # Maximum whole number before the decimal
|
||||
max_value = max_integer_part * (10**ndp)
|
||||
raw_int = rng.randint(1, max_value)
|
||||
ndp: int = rng.randint(min_num_decimal_places, max_num_decimal_places)
|
||||
max_integer_part: int = 10 # Maximum whole number before the decimal
|
||||
max_value: int = max_integer_part * (10 ** ndp)
|
||||
raw_int: int = rng.randint(1, max_value)
|
||||
# Create the Decimal number and quantize it to exactly ndp decimal places.
|
||||
num = Decimal(raw_int) / (Decimal(10) ** ndp)
|
||||
quantize_str = "1." + "0" * ndp
|
||||
num: Decimal = Decimal(raw_int) / (Decimal(10) ** ndp)
|
||||
quantize_str: str = "1." + "0" * ndp
|
||||
num = num.quantize(Decimal(quantize_str), rounding=ROUND_HALF_UP)
|
||||
# Format the number as a string with exactly ndp decimals.
|
||||
num_str = f"{num:.{ndp}f}"
|
||||
num_str: str = f"{num:.{ndp}f}"
|
||||
operands.append(num_str)
|
||||
if i < terms - 1:
|
||||
op = rng.choice(operations)
|
||||
op: str = rng.choice(operations)
|
||||
operators.append(op)
|
||||
|
||||
expr = build_grouped_expression(operands, operators, rng)
|
||||
problem_str = expr + " = ?"
|
||||
expr: str = build_grouped_expression(operands, operators, rng)
|
||||
problem_str: str = expr + " = ?"
|
||||
return problem_str
|
||||
|
||||
|
||||
|
|
@ -102,15 +110,15 @@ def evaluate_expression(expr: str) -> Decimal:
|
|||
Returns:
|
||||
Decimal: The computed result.
|
||||
"""
|
||||
tree = ast.parse(expr, mode="eval")
|
||||
tree: ast.Expression = ast.parse(expr, mode="eval")
|
||||
return _eval_ast(tree.body)
|
||||
|
||||
|
||||
def _eval_ast(node) -> Decimal:
|
||||
def _eval_ast(node: ast.AST) -> Decimal:
|
||||
"""Recursively evaluate an AST node using Decimal arithmetic."""
|
||||
if isinstance(node, ast.BinOp):
|
||||
left = _eval_ast(node.left)
|
||||
right = _eval_ast(node.right)
|
||||
left: Decimal = _eval_ast(node.left)
|
||||
right: Decimal = _eval_ast(node.right)
|
||||
if isinstance(node.op, ast.Add):
|
||||
return left + right
|
||||
elif isinstance(node.op, ast.Sub):
|
||||
|
|
@ -122,7 +130,7 @@ def _eval_ast(node) -> Decimal:
|
|||
else:
|
||||
raise ValueError(f"Unsupported operator: {node.op}")
|
||||
elif isinstance(node, ast.UnaryOp):
|
||||
operand = _eval_ast(node.operand)
|
||||
operand: Decimal = _eval_ast(node.operand)
|
||||
if isinstance(node.op, ast.UAdd):
|
||||
return operand
|
||||
elif isinstance(node.op, ast.USub):
|
||||
|
|
@ -140,10 +148,10 @@ def _eval_ast(node) -> Decimal:
|
|||
class DecimalArithmeticDataset(ProceduralDataset):
|
||||
"""Dataset that generates basic arithmetic tasks using Decimal arithmetic and proper operator precedence."""
|
||||
|
||||
def __init__(self, config: DecimalArithmeticConfig):
|
||||
def __init__(self, config: DecimalArithmeticConfig) -> None:
|
||||
super().__init__(config=config, seed=config.seed, size=config.size)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, Any]:
|
||||
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a single arithmetic task.
|
||||
|
||||
|
|
@ -154,18 +162,18 @@ class DecimalArithmeticDataset(ProceduralDataset):
|
|||
- 'metadata': Additional metadata (currently empty).
|
||||
"""
|
||||
# Create a deterministic RNG from base seed and index.
|
||||
rng = Random(self.seed + idx if self.seed is not None else None)
|
||||
rng: Random = Random(self.seed + idx if self.seed is not None else None)
|
||||
getcontext().prec = self.config.precision
|
||||
|
||||
problem_str = generate_arithmetic_problem(
|
||||
problem_str: str = generate_arithmetic_problem(
|
||||
rng,
|
||||
self.config.min_num_decimal_places,
|
||||
self.config.max_num_decimal_places,
|
||||
terms=self.config.terms,
|
||||
)
|
||||
# Remove the trailing " = ?" to obtain the pure arithmetic expression.
|
||||
expr = problem_str.replace(" = ?", "").strip()
|
||||
answer = evaluate_expression(expr)
|
||||
expr: str = problem_str.replace(" = ?", "").strip()
|
||||
answer: Decimal = evaluate_expression(expr)
|
||||
|
||||
problem_str = (
|
||||
f"Please solve this problem to a maximum of {str(self.config.precision)} significant digits, rounding up from the half. Only reply with the final value.\n"
|
||||
|
|
@ -187,12 +195,12 @@ class DecimalArithmeticDataset(ProceduralDataset):
|
|||
return 0.0
|
||||
|
||||
try:
|
||||
user_ans = Decimal(answer)
|
||||
correct_ans = entry["answer"]
|
||||
user_ans: Decimal = Decimal(answer)
|
||||
correct_ans: Decimal = entry["answer"]
|
||||
|
||||
# Determine tolerance based on the desired precision.
|
||||
precision = self.config.max_num_decimal_places
|
||||
tol = Decimal(10) ** (-precision)
|
||||
precision: int = self.config.max_num_decimal_places
|
||||
tol: Decimal = Decimal(10) ** (-precision)
|
||||
if abs(user_ans - correct_ans) <= tol:
|
||||
return 1.0
|
||||
except Exception:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue