diff --git a/reasoning_gym/arithmetic/decimal_arithmetic.py b/reasoning_gym/arithmetic/decimal_arithmetic.py index 3d90a7d2..c97c7230 100644 --- a/reasoning_gym/arithmetic/decimal_arithmetic.py +++ b/reasoning_gym/arithmetic/decimal_arithmetic.py @@ -1,6 +1,8 @@ +import ast from dataclasses import dataclass +from decimal import ROUND_HALF_UP, Decimal, getcontext from random import Random -from typing import Any, Dict, Literal, Optional +from typing import Any, Dict, Optional from ..factory import ProceduralDataset, register_dataset @@ -11,114 +13,173 @@ class DecimalArithmeticDatasetConfig: min_num_decimal_places: int = 6 max_num_decimal_places: int = 6 + precision: int = 28 terms: int = 6 seed: Optional[int] = None - size: int = 500 # Virtual dataset size + size: int = 500 - # def validate(self) -> None: - # """Validate configuration parameters""" - # assert self.num_decimal_places > 0, "num_decimal_places must be positive" + def validate(self): + """Validate configuration parameters""" + assert ( + self.precision > self.max_num_decimal_places + 1 + ), "precision must be 2 or more higher than max_num_decimal_places" def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_places, terms=2, operations=None): """ - Generates simple arithmetic problems with decimal numbers formatted to a specific number of decimal places. + Generates a simple arithmetic problem with decimal numbers (as a string) formatted + to a specific number of decimal places. Parameters: - rng - num_problems (int): Number of problems to generate - num_decimal_places (int): Number of decimal places for the numbers - operations (list): List of operations to use (default: ['+', '-', '*', '/']) + rng: Random number generator. + min_num_decimal_places (int): Minimum number of decimal places. + max_num_decimal_places (int): Maximum number of decimal places. + terms (int): Number of numbers in the arithmetic expression. + operations (list): List of operations to use (default: ['+', '-', '*', '/']). Returns: - list: List of formatted arithmetic problem strings + str: A formatted arithmetic expression ending with " = ?" """ if operations is None: operations = ["+", "-", "*", "/"] - problem = "" + tokens = [] + # Build the expression by alternating numbers and operators. + for i in range(terms): + # Choose a number of decimal places for this term. + ndp = rng.randint(min_num_decimal_places, max_num_decimal_places) + max_integer_part = 10 # Maximum whole number before the decimal + max_value = max_integer_part * (10**ndp) + raw_int = rng.randint(1, max_value) + # Create the Decimal number and quantize it to exactly ndp decimal places. + num = Decimal(raw_int) / (Decimal(10) ** ndp) + quantize_str = "1." + "0" * ndp + num = num.quantize(Decimal(quantize_str), rounding=ROUND_HALF_UP) + # Format the number as a string with exactly ndp decimals. + num_str = f"{num:.{ndp}f}" + tokens.append(num_str) + if i < terms - 1: + op = rng.choice(operations) + tokens.append(op) - for term in range(0, terms): - - # Generate random numbers with exact decimal places - ndp1 = rng.randint(min_num_decimal_places, max_num_decimal_places) - max_integer_part = 10 # Maximum whole number portion before decimal - max_value = max_integer_part * (10**ndp1) - num1 = rng.randint(1, max_value) / (10**ndp1) - - # Select random operation - op = rng.choice(operations) - op = op if (term <= terms - 2) else "" - - # Format numbers to ensure exact decimal places - formatted_num1 = f"{num1:.{ndp1}f}" - - problem = problem + f"{formatted_num1} { op }" + " " - - problem = problem + "= ?" - print(problem) - return problem + problem_str = "".join(tokens) + " = ?" + return problem_str -def eval_floordiv(exp: str) -> int: - return eval(exp.replace("/", "//").replace(" = ?", "")) +def evaluate_expression(expr: str) -> Decimal: + """ + Safely evaluates a simple arithmetic expression using AST parsing, performing + all arithmetic in the Decimal context. + + Args: + expr: A string containing the arithmetic expression. + + Returns: + Decimal: The computed result. + """ + tree = ast.parse(expr, mode="eval") + return _eval_ast(tree.body) + + +def _eval_ast(node) -> Decimal: + """Recursively evaluate an AST node using Decimal arithmetic.""" + if isinstance(node, ast.BinOp): + left = _eval_ast(node.left) + right = _eval_ast(node.right) + if isinstance(node.op, ast.Add): + return left + right + elif isinstance(node.op, ast.Sub): + return left - right + elif isinstance(node.op, ast.Mult): + return left * right + elif isinstance(node.op, ast.Div): + return left / right + else: + raise ValueError(f"Unsupported operator: {node.op}") + elif isinstance(node, ast.UnaryOp): + operand = _eval_ast(node.operand) + if isinstance(node.op, ast.UAdd): + return operand + elif isinstance(node.op, ast.USub): + return -operand + else: + raise ValueError(f"Unsupported unary operator: {node.op}") + elif isinstance(node, ast.Constant): # For Python 3.8+ + # Although ast converts numeric literals to floats, + # converting via str helps us get a Decimal with the intended value. + return Decimal(str(node.value)) + elif isinstance(node, ast.Num): # For older Python versions + return Decimal(str(node.n)) + else: + raise ValueError(f"Unsupported expression component: {node}") class DecimalArithmeticDataset(ProceduralDataset): - """Dataset that generates basic arithmetic tasks with configurable complexity""" + """Dataset that generates basic arithmetic tasks using Decimal arithmetic and proper operator precedence.""" def __init__(self, config: DecimalArithmeticDatasetConfig): super().__init__(config=config, seed=config.seed, size=config.size) def __getitem__(self, idx: int) -> dict[str, Any]: - """Generate a single arithmetic task - - Args: - idx: Index of the item to generate + """ + Generate a single arithmetic task. Returns: - dict with keys: - - question: str, the formatted arithmetic expression - - answer: str, the ground truth result - - metadata: dict with generation parameters + dict: Contains: + - 'question': The formatted arithmetic expression as a string. + - 'answer': The computed Decimal result. + - 'metadata': Additional metadata (currently empty). """ - # Create deterministic RNG from base seed and idx - rng = Random(self.seed + idx) + # Create a deterministic RNG from base seed and index. + rng = Random(self.seed + idx if self.seed is not None else None) + getcontext().prec = self.config.precision - decimal_problem = generate_arithmetic_problem( + problem_str = generate_arithmetic_problem( rng, self.config.min_num_decimal_places, self.config.max_num_decimal_places, terms=self.config.terms, ) - answer = eval_floordiv(decimal_problem) + # Remove the trailing " = ?" to obtain the pure arithmetic expression. + expr = problem_str.replace(" = ?", "").strip() + answer = evaluate_expression(expr) - return {"question": decimal_problem, "answer": answer, "metadata": {}} + problem_str = problem_str = ( + f"Please solve this problem to a maximum of {str(self.config.precision)} significant digits, rounding up from the half. Only reply with the final value.\n" + + problem_str + ) - def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: - """Determine if the solution provided solves the Sokoban task. + return {"question": problem_str, "answer": answer, "metadata": {}} - The function awards 1.0 for a correct answer. + def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: + """ + Compares the user's answer (converted to Decimal) with the correct answer. + Instead of requiring exact equality, we allow an error up to one unit in the + least significant digit as determined by the level of precision (max_num_decimal_places). - Args: - answer (Optional[str]): The user's answer. - entry (Dict[str, any]): The original dataset entry containing the correct answer. + For example, if max_num_decimal_places is 6, then an error of up to 1e-6 is accepted. Returns: - float: The computed score between 0.0 and 1.0. + float: 1.0 if the user's answer is within tolerance; otherwise, 0.01. """ - - if answer == None: + if answer is None: return 0.0 try: - if float(answer) == entry["answer"]: + user_ans = Decimal(answer) + correct_ans = entry["answer"] + + # Determine tolerance based on the desired precision. + # Here, we allow a difference of 1 in the last decimal place. + precision = self.config.max_num_decimal_places + tol = Decimal(10) ** (-precision) + if abs(user_ans - correct_ans) <= tol: return 1.0 - except Exception as e: + except Exception: return 0.01 return 0.01 -# Register the dataset +# Register the dataset with the factory. register_dataset("decimal_arithmetic", DecimalArithmeticDataset, DecimalArithmeticDatasetConfig) diff --git a/tests/test_decimal_arithmetic.py b/tests/test_decimal_arithmetic.py index 1b2eb464..60a9d91f 100644 --- a/tests/test_decimal_arithmetic.py +++ b/tests/test_decimal_arithmetic.py @@ -8,7 +8,7 @@ def test_decimal_arithmetic(): # Easy config = DecimalArithmeticDatasetConfig( - seed=42, size=999000, min_num_decimal_places=3, max_num_decimal_places=13, terms=13 + seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=3, precision=5, terms=3 ) dataset = DecimalArithmeticDataset(config) @@ -18,31 +18,33 @@ def test_decimal_arithmetic(): assert "answer" in item assert "metadata" in item - print(item["answer"]) - # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 - # # M - # config = DecimalArithmeticDatasetConfig(seed=42, size=2000, num_decimal_places=8) - # dataset = DecimalArithmeticDataset(config) + # M + config = DecimalArithmeticDatasetConfig( + seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=6, precision=8, terms=6 + ) + dataset = DecimalArithmeticDataset(config) - # for item in dataset: - # assert isinstance(item, dict) - # assert "question" in item - # assert "answer" in item - # assert "metadata" in item + for item in dataset: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item - # assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 - # # H - # config = DecimalArithmeticDatasetConfig(seed=42, size=2000, num_decimal_places=15) - # dataset = DecimalArithmeticDataset(config) + # H + config = DecimalArithmeticDatasetConfig( + seed=42, size=2000, min_num_decimal_places=3, max_num_decimal_places=13, precision=15, terms=10 + ) + dataset = DecimalArithmeticDataset(config) - # for item in dataset: - # assert isinstance(item, dict) - # assert "question" in item - # assert "answer" in item - # assert "metadata" in item + for item in dataset: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item - # assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0