InternBootcamp/internbootcamp/libs/bbeh_multistep_arithmetic/multistep_arithmetic.py
2025-05-23 15:27:15 +08:00

510 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import random
import json
import math
import re
import logging
import time
# 定义极限值 - 使用Python支持的最大范围
MAX_VALUE = 1.7976931348623157e+308 # 最大浮点数
MIN_VALUE = -1.7976931348623157e+308 # 最小浮点数
EPSILON = 1e-10 # 用于浮点数比较
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 定义更多操作符
OPERATORS = ['+', '-', '*', '/', '><', ';', '@', '<>', '[]', '#', '!', '~', '&', ':', '][']
# 定义优先级
PRECEDENCE = {
'+': 1, '-': 1,
'*': 2, '/': 2,
'><': 3, ';': 3,
'@': 4, '<>': 4, '[]': 4,
'#': 5, '!': 5,
'~': 6, '&': 6,
':': 7, '][': 7
}
# 添加数字单词映射
NUMBER_WORDS = {
'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5,
'six': 6, 'seven': 7, 'eight': 8, 'nine': 9, 'ten': 10
}
# 定义极限值
MAX_VALUE = 1e100
MIN_VALUE = -1e100
def safe_pop(stack, error_msg="Empty stack"):
"""安全的pop操作包含错误处理"""
if not stack:
raise ValueError(error_msg)
return stack.pop()
def perform_operation(op, a, b):
"""执行操作,改进的错误处理和大数处理"""
try:
# 处理无穷大和非数值的情况
if math.isinf(a) or math.isinf(b):
if op in ['+', '*', '@', '<>', '[]', '#', '&']:
return float('inf') if (a > 0 or b > 0) else float('-inf')
elif op in ['-', '/', '><', ';', '!', '~', ':', '][']:
return float('inf')
# 处理零和接近零的情况
if op == '/' and abs(b) < EPSILON:
return float('inf')
# 防止溢出的预检查
def safe_operation(func):
try:
result = func()
# 检查结果是否在有效范围内
if math.isinf(result) or math.isnan(result):
return float('inf') if result > 0 else float('-inf')
if abs(result) > MAX_VALUE:
return float('inf') if result > 0 else float('-inf')
if abs(result) != 0 and abs(result) < MIN_VALUE:
return 0.0
return result
except (OverflowError, ValueError, ZeroDivisionError):
return float('inf')
# 使用安全操作来执行各种运算
if op == '+':
result = safe_operation(lambda: a + b)
elif op == '-':
result = safe_operation(lambda: a - b)
elif op == '*':
result = safe_operation(lambda: a * b)
elif op == '/':
result = safe_operation(lambda: a / b if abs(b) > EPSILON else float('inf'))
elif op == '><':
result = safe_operation(lambda: a - b if a * b > 0 else a + b)
elif op == ';':
result = safe_operation(lambda: (a - b) if a + b > 0 else a - b)
elif op == '@':
result = safe_operation(lambda: a * b if a > b else a + b)
elif op == '<>':
result = safe_operation(lambda: abs(a - b))
elif op == '[]':
result = safe_operation(lambda: max(a, b))
elif op == '#':
result = safe_operation(lambda: (a * b) if a > b else min(a, b))
elif op == '!':
result = safe_operation(lambda: (a - b) if a * b > 0 else (a + b))
elif op == '~':
result = safe_operation(lambda: (2 * b - a) if a > b else b)
elif op == '&':
result = safe_operation(lambda: (a + b) if a * b > 0 else (a - b))
elif op == ':':
result = safe_operation(lambda: (a * b) if a + b > 0 else (-a * -b))
elif op == '][':
try:
if abs(a) < 1e10 and abs(b) < 1e10:
gcd_value = gcd(abs(int(a)), abs(int(b)))
return safe_operation(lambda: (a - b) if gcd_value == 1 else float(gcd_value))
else:
return float('inf')
except (OverflowError, ValueError):
return float('inf')
else:
raise ValueError(f"Unknown operator: {op}")
return result
except Exception as e:
logger.error(f"Error in operation {op} with operands {a}, {b}: {str(e)}")
return float('inf')
def validate_expression(tokens):
"""验证表达式的有效性"""
stack = []
for token in tokens:
if token == '(':
stack.append(token)
elif token == ')':
if not stack:
return False
stack.pop()
return len(stack) == 0
def tokenize(expr):
"""改进的tokenize函数支持科学记数法"""
try:
tokens = []
i = 0
while i < len(expr):
char = expr[i]
# 处理空格
if char.isspace():
i += 1
continue
# 处理数字单词
if char.isalpha() and (i == 0 or not expr[i - 1].isdigit()): # 确保不是科学记数法中的'e'
word = ''
while i < len(expr) and expr[i].isalpha():
word += expr[i]
i += 1
if word in NUMBER_WORDS:
tokens.append(str(NUMBER_WORDS[word]))
else:
logger.warning(f"Unknown word encountered: {word}")
raise ValueError(f"Unknown word: {word}")
continue
# 处理数字(包括科学记数法)
if char.isdigit() or (char == '-' and (not tokens or tokens[-1] in OPERATORS + ['('])):
num = char
i += 1
# 处理整数和小数部分
while i < len(expr) and (expr[i].isdigit() or expr[i] == '.'):
num += expr[i]
i += 1
# 处理科学记数法
if i < len(expr) and (expr[i] == 'e' or expr[i] == 'E'):
num += expr[i]
i += 1
if i < len(expr) and (expr[i] == '+' or expr[i] == '-'):
num += expr[i]
i += 1
while i < len(expr) and expr[i].isdigit():
num += expr[i]
i += 1
try:
float(num) # 验证数字格式是否正确
tokens.append(num)
except ValueError:
raise ValueError(f"Invalid number format: {num}")
continue
# 处理括号
if char in '()':
tokens.append(char)
i += 1
continue
# 处理运算符
if i < len(expr):
# 尝试匹配最长的运算符
max_op_len = 3
matched = False
for length in range(max_op_len, 0, -1):
if i + length <= len(expr):
potential_op = expr[i:i + length]
if potential_op in OPERATORS:
tokens.append(potential_op)
i += length
matched = True
break
if not matched:
raise ValueError(f"Invalid character: {char}")
if not validate_expression(tokens):
raise ValueError("Invalid expression: Mismatched parentheses")
return tokens
except Exception as e:
logger.error(f"Error in tokenization: {str(e)}")
raise
def validate_expression_structure(tokens):
"""验证表达式的结构是否合法"""
stack = []
operand_count = 0
operator_count = 0
for token in tokens:
if token == '(':
stack.append(token)
elif token == ')':
if not stack:
return False
stack.pop()
elif token in OPERATORS:
operator_count += 1
else:
operand_count += 1
# 检查括号是否匹配
if stack:
return False
# 检查操作数和操作符的数量关系
# 对于二元运算符操作数应该比操作符多1
return operand_count == operator_count + 1
def evaluate_expression(expression):
"""改进的表达式求值函数"""
try:
tokens = tokenize(expression)
if not tokens:
logger.error(f"Empty expression: {expression}")
return float('inf')
# 验证表达式结构
if not validate_expression_structure(tokens):
logger.error(f"Invalid expression structure: {expression}")
return float('inf')
values = []
operators = []
for token in tokens:
try:
if token == '(':
operators.append(token)
elif token == ')':
while operators and operators[-1] != '(':
if len(values) < 2:
raise ValueError(f"Not enough operands for operator {operators[-1]}")
op = operators.pop()
b = safe_float_conversion(values.pop())
a = safe_float_conversion(values.pop())
result = perform_operation(op, a, b)
values.append(result)
if operators:
operators.pop() # 移除 '('
elif token in OPERATORS:
while (operators and operators[-1] != '(' and
PRECEDENCE.get(operators[-1], 0) >= PRECEDENCE.get(token, 0)):
if len(values) < 2:
raise ValueError(f"Not enough operands for operator {operators[-1]}")
op = operators.pop()
b = safe_float_conversion(values.pop())
a = safe_float_conversion(values.pop())
result = perform_operation(op, a, b)
values.append(result)
operators.append(token)
else:
# 改进的数值转换部分
try:
if isinstance(token, str):
token_lower = token.lower()
if token_lower in NUMBER_WORDS:
# 处理数字单词
value = float(NUMBER_WORDS[token_lower])
else:
# 处理数字字符串(包括科学记数法)
value = safe_float_conversion(token)
values.append(value)
else:
raise ValueError(f"Invalid token type: {type(token)}")
except ValueError as ve:
logger.error(f"Error converting number: {token}")
raise ValueError(f"Invalid number format: {token}")
except Exception as e:
logger.error(f"Error processing token {token}: {str(e)}")
return float('inf')
# 处理剩余的操作符
while operators:
if len(values) < 2:
raise ValueError("Not enough operands for remaining operators")
op = operators.pop()
if op == '(':
raise ValueError("Mismatched parentheses")
b = safe_float_conversion(values.pop())
a = safe_float_conversion(values.pop())
result = perform_operation(op, a, b)
values.append(result)
if len(values) != 1:
raise ValueError("Invalid expression: too many values")
return values[0]
except Exception as e:
logger.error(f"Error evaluating expression '{expression}': {str(e)}")
return float('inf')
def safe_float_conversion(value):
"""安全地将值转换为浮点数,改进的版本"""
try:
if isinstance(value, (int, float)):
if math.isinf(value) or math.isnan(value):
return float('inf') if value > 0 else float('-inf')
return float(value)
if isinstance(value, str):
try:
result = float(value)
# 处理特殊情况
if math.isinf(result) or math.isnan(result):
return float('inf') if result > 0 else float('-inf')
# 处理超大数和超小数
if abs(result) > MAX_VALUE:
return float('inf') if result > 0 else float('-inf')
if abs(result) != 0 and abs(result) < MIN_VALUE:
return 0.0
return result
except ValueError:
# 处理数字单词
if value.lower() in NUMBER_WORDS:
return float(NUMBER_WORDS[value.lower()])
raise
raise ValueError(f"Cannot convert {type(value)} to float")
except (ValueError, TypeError) as e:
logger.error(f"Error in safe_float_conversion: {str(e)}")
return float('inf')
def gcd(a, b):
while b:
a, b = b, a % b
return a
def generate_expression(min_depth=3, max_depth=6, max_length=50):
"""改进的表达式生成函数"""
def generate_subexpression(depth, current_length):
if depth == 0 or current_length >= max_length:
if random.random() < 0.3:
return random.choice(list(NUMBER_WORDS.keys())), current_length + 1
return str(random.randint(-1000000, 1000000)), current_length + 1
choice = random.random()
if choice < 0.4:
# 生成简单的二元运算表达式
op = random.choice(OPERATORS)
left, left_length = generate_subexpression(depth - 1, current_length)
right, right_length = generate_subexpression(depth - 1, left_length + 1)
return f"({left} {op} {right})", right_length + 3
else:
# 生成带括号的表达式
op = random.choice(OPERATORS)
left, left_length = generate_subexpression(depth - 1, current_length)
right, right_length = generate_subexpression(depth - 1, left_length + 1)
return f"({left} {op} {right})", right_length + 3
while True:
try:
expression, _ = generate_subexpression(random.randint(min_depth, max_depth), 0)
# 验证生成的表达式
tokens = tokenize(expression)
if validate_expression_structure(tokens):
answer = evaluate_expression(expression)
if not math.isinf(answer):
return expression, answer
except:
continue
def generate_dataset(num_samples=100):
dataset = []
for _ in range(num_samples):
expression, answer = generate_expression()
dataset.append({"expression": expression, "answer": answer})
return dataset
def solve_expression(expression):
return evaluate_expression(expression)
def validate_solution(expression, expected_answer, calculated_answer):
"""改进的解决方案验证函数"""
# 处理无穷大的情况
if math.isinf(expected_answer) and math.isinf(calculated_answer):
return expected_answer * calculated_answer > 0 # 确保符号相同
# 处理零附近的值
if abs(expected_answer) < EPSILON and abs(calculated_answer) < EPSILON:
return True
# 处理普通情况
if abs(expected_answer) > EPSILON:
relative_error = abs((expected_answer - calculated_answer) / expected_answer)
return relative_error < EPSILON
return abs(expected_answer - calculated_answer) < EPSILON
def performance_test(num_expressions=1000):
start_time = time.time()
for _ in range(num_expressions):
expression, _ = generate_expression()
evaluate_expression(expression)
end_time = time.time()
avg_time = (end_time - start_time) / num_expressions
print(f"Average time per expression: {avg_time:.6f} seconds")
def test_system():
# 从文件加载数据集
with open("bbeh_arithmetic_dataset.json", "r") as f:
dataset = json.load(f)
total_tests = len(dataset)
passed_tests = 0
for item in dataset:
expression = item["expression"]
expected_answer = item["answer"]
calculated_answer = solve_expression(expression)
is_correct = validate_solution(expression, expected_answer, calculated_answer)
if is_correct:
passed_tests += 1
else:
print(f"Failed test: {expression}")
print(f"Expected: {expected_answer}, Calculated: {calculated_answer}")
success_rate = passed_tests / total_tests * 100
print(f"Passed {passed_tests} out of {total_tests} tests.")
print(f"Success rate: {success_rate:.2f}%")
# 添加边界测试
edge_cases = [
# 基本运算的边界情况
"1e308 + 1", # 接近最大值
"-1e308 - 1", # 接近最小值
"1e-307 * 1e-307", # 接近最小正数
"1e307 * 2", # 溢出测试
"1 / 1e-307", # 除法边界
# 自定义运算符的边界情况
"1e100 >< -1e100",
"1e100 ; 1e100",
"1e100 + 1e100",
"-1e100 - 1e100",
"1e-100 * 1e-100",
"1 / 1e-100",
"1e100 >< -1e100",
"1e100 ; 1e100",
"1e100 @ 1e100",
"1e100 <> -1e100",
"1e100 [] -1e100",
"1e100 # -1e100",
"1e100 ! -1e100",
"1e100 ~ -1e100",
"1e100 & -1e100",
"1e100 : -1e100",
"1e100 ][ -1e100"
]
print("\nTesting edge cases:")
for case in edge_cases:
result = evaluate_expression(case)
print(f"{case} = {result}")
# 运行性能测试
print("\nRunning performance test:")
performance_test()
# 生成数据集并保存到文件
dataset = generate_dataset()
with open("bbeh_arithmetic_dataset.json", "w") as f:
json.dump(dataset, f, indent=2)
# 运行测试
test_system()