mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-23 16:55:05 +00:00
implement decimal precision
This commit is contained in:
parent
59229bd2d2
commit
17ba950c1a
2 changed files with 144 additions and 81 deletions
|
|
@ -1,6 +1,8 @@
|
|||
import ast
|
||||
from dataclasses import dataclass
|
||||
from decimal import ROUND_HALF_UP, Decimal, getcontext
|
||||
from random import Random
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ..factory import ProceduralDataset, register_dataset
|
||||
|
||||
|
|
@ -11,114 +13,173 @@ class DecimalArithmeticDatasetConfig:
|
|||
|
||||
min_num_decimal_places: int = 6
|
||||
max_num_decimal_places: int = 6
|
||||
precision: int = 28
|
||||
terms: int = 6
|
||||
seed: Optional[int] = None
|
||||
size: int = 500 # Virtual dataset size
|
||||
size: int = 500
|
||||
|
||||
# def validate(self) -> None:
|
||||
# """Validate configuration parameters"""
|
||||
# assert self.num_decimal_places > 0, "num_decimal_places must be positive"
|
||||
def validate(self):
    """Validate configuration parameters.

    Checks that the decimal-place range is well formed, that at least one
    term is requested, and that ``precision`` leaves room for the integer
    part of every generated number.

    Raises:
        AssertionError: If any of the constraints is violated.
    """
    assert self.min_num_decimal_places >= 0, "min_num_decimal_places must be non-negative"
    assert (
        self.max_num_decimal_places >= self.min_num_decimal_places
    ), "max_num_decimal_places must be greater than or equal to min_num_decimal_places"
    assert self.terms >= 1, "terms must be at least 1"
    # Generated numbers have up to two integer digits ("10") in front of the
    # decimal places, hence the "+ 1" together with the strict comparison.
    assert (
        self.precision > self.max_num_decimal_places + 1
    ), "precision must be 2 or more higher than max_num_decimal_places"
|
||||
|
||||
|
||||
def generate_arithmetic_problem(rng, min_num_decimal_places, max_num_decimal_places, terms=2, operations=None):
    """
    Generate a simple arithmetic problem with decimal numbers (as a string), each
    formatted to an exact number of decimal places.

    Every term is drawn as an integer in ``[1, 10 * 10**ndp]`` and rendered with
    ``ndp`` digits after the decimal point, so a term is never zero and never
    larger than 10. Formatting uses integer arithmetic only, so the result does
    not depend on the ambient ``decimal`` context precision.

    Parameters:
        rng: Random number generator providing ``randint`` and ``choice``.
        min_num_decimal_places (int): Minimum number of decimal places per term.
        max_num_decimal_places (int): Maximum number of decimal places per term.
        terms (int): Number of numbers in the arithmetic expression (at least 1).
        operations (list): Operators to choose from (default: ['+', '-', '*', '/']).

    Returns:
        str: The expression, terms and operators joined without spaces,
        followed by " = ?".

    Raises:
        ValueError: If ``terms`` is smaller than 1.
    """
    if terms < 1:
        raise ValueError("terms must be at least 1")
    if operations is None:
        operations = ["+", "-", "*", "/"]

    max_integer_part = 10  # Largest whole-number part a generated term can have.
    tokens = []
    for i in range(terms):
        # Draw order (decimal places, value, operator) is kept fixed so that a
        # seeded rng always reproduces the same problem.
        ndp = rng.randint(min_num_decimal_places, max_num_decimal_places)
        scale = 10**ndp
        raw_int = rng.randint(1, max_integer_part * scale)

        # Split into whole and fractional digits; zero-pad the fraction to ndp digits.
        whole, frac = divmod(raw_int, scale)
        tokens.append(f"{whole}.{frac:0{ndp}d}" if ndp > 0 else str(whole))

        # An operator follows every term except the last one.
        if i < terms - 1:
            tokens.append(rng.choice(operations))

    return "".join(tokens) + " = ?"
|
||||
|
||||
|
||||
def eval_floordiv(exp: str) -> float:
    """Evaluate a generated problem string, treating every "/" as floor division.

    Args:
        exp: Arithmetic expression, optionally ending with " = ?".

    Returns:
        The numeric result. Float operands produce a float even under floor
        division; an expression made only of integers yields an int.
    """
    expression = exp.replace(" = ?", "").replace("/", "//")
    # NOTE(review): eval() is only acceptable for internally generated
    # expressions. Builtins are disabled so plain names such as __import__ do
    # not resolve, but this is not a sandbox; never feed it untrusted text.
    return eval(expression, {"__builtins__": {}}, {})
|
||||
def evaluate_expression(expr: str) -> Decimal:
|
||||
"""
|
||||
Safely evaluates a simple arithmetic expression using AST parsing, performing
|
||||
all arithmetic in the Decimal context.
|
||||
|
||||
Args:
|
||||
expr: A string containing the arithmetic expression.
|
||||
|
||||
Returns:
|
||||
Decimal: The computed result.
|
||||
"""
|
||||
tree = ast.parse(expr, mode="eval")
|
||||
return _eval_ast(tree.body)
|
||||
|
||||
|
||||
def _eval_ast(node) -> Decimal:
|
||||
"""Recursively evaluate an AST node using Decimal arithmetic."""
|
||||
if isinstance(node, ast.BinOp):
|
||||
left = _eval_ast(node.left)
|
||||
right = _eval_ast(node.right)
|
||||
if isinstance(node.op, ast.Add):
|
||||
return left + right
|
||||
elif isinstance(node.op, ast.Sub):
|
||||
return left - right
|
||||
elif isinstance(node.op, ast.Mult):
|
||||
return left * right
|
||||
elif isinstance(node.op, ast.Div):
|
||||
return left / right
|
||||
else:
|
||||
raise ValueError(f"Unsupported operator: {node.op}")
|
||||
elif isinstance(node, ast.UnaryOp):
|
||||
operand = _eval_ast(node.operand)
|
||||
if isinstance(node.op, ast.UAdd):
|
||||
return operand
|
||||
elif isinstance(node.op, ast.USub):
|
||||
return -operand
|
||||
else:
|
||||
raise ValueError(f"Unsupported unary operator: {node.op}")
|
||||
elif isinstance(node, ast.Constant): # For Python 3.8+
|
||||
# Although ast converts numeric literals to floats,
|
||||
# converting via str helps us get a Decimal with the intended value.
|
||||
return Decimal(str(node.value))
|
||||
elif isinstance(node, ast.Num): # For older Python versions
|
||||
return Decimal(str(node.n))
|
||||
else:
|
||||
raise ValueError(f"Unsupported expression component: {node}")
|
||||
|
||||
|
||||
class DecimalArithmeticDataset(ProceduralDataset):
    """Dataset that generates basic arithmetic tasks using Decimal arithmetic and proper operator precedence."""

    def __init__(self, config: DecimalArithmeticDatasetConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """
        Generate a single arithmetic task.

        Args:
            idx: Index of the item to generate.

        Returns:
            dict: Contains:
                - 'question': Instruction line followed by the arithmetic expression.
                - 'answer': The computed Decimal result.
                - 'metadata': Additional metadata (currently empty).
        """
        from decimal import localcontext

        # Create a deterministic RNG from base seed and index.
        rng = Random(self.seed + idx if self.seed is not None else None)

        # Apply the configured precision only for this computation rather than
        # permanently changing the thread-global decimal context.
        with localcontext() as ctx:
            ctx.prec = self.config.precision

            problem_str = generate_arithmetic_problem(
                rng,
                self.config.min_num_decimal_places,
                self.config.max_num_decimal_places,
                terms=self.config.terms,
            )
            # Remove the trailing " = ?" to obtain the pure arithmetic expression.
            expr = problem_str.replace(" = ?", "").strip()
            answer = evaluate_expression(expr)

        question = (
            f"Please solve this problem to a maximum of {self.config.precision} significant digits, "
            "rounding up from the half. Only reply with the final value.\n" + problem_str
        )

        return {"question": question, "answer": answer, "metadata": {}}

    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
        """
        Compare the user's answer (converted to Decimal) with the correct answer.

        Instead of requiring exact equality, an error of up to one unit in the
        last decimal place given by ``max_num_decimal_places`` is allowed.
        For example, if max_num_decimal_places is 6, an error of up to 1e-6 is accepted.

        Args:
            answer (Optional[str]): The user's answer.
            entry (Dict[str, Any]): The dataset entry containing the correct
                answer under 'answer' (a Decimal, or its string form).

        Returns:
            float: 1.0 if the answer is within tolerance, 0.0 if no answer was
            given, otherwise 0.01 (including answers that are not valid numbers).
        """
        if answer is None:
            return 0.0

        from decimal import localcontext

        try:
            user_ans = Decimal(answer)
            correct_ans = entry["answer"]
            if not isinstance(correct_ans, Decimal):
                # Entries may have been serialised; accept the textual form too.
                correct_ans = Decimal(str(correct_ans))

            with localcontext() as ctx:
                ctx.prec = self.config.precision
                # Allow a difference of 1 in the last decimal place.
                tol = Decimal(10) ** (-self.config.max_num_decimal_places)
                if abs(user_ans - correct_ans) <= tol:
                    return 1.0
        except (ArithmeticError, KeyError, TypeError, ValueError):
            # Unparseable or non-comparable answers (e.g. "abc", "NaN") score as wrong.
            return 0.01

        return 0.01
|
||||
|
||||
|
||||
# Register the dataset
|
||||
# Register the dataset with the factory.
|
||||
# Pass the name "decimal_arithmetic", the dataset class and its config class to the
# factory's register_dataset (presumably enabling lookup by name; confirm in ..factory).
register_dataset("decimal_arithmetic", DecimalArithmeticDataset, DecimalArithmeticDatasetConfig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue