InternBootcamp/internbootcamp/libs/bbeh_buggy_tables/bbeh_buggy_tables_solver.py
2025-05-23 15:27:15 +08:00

369 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import pandas as pd
import numpy as np
import re
pd.set_option('future.no_silent_downcasting', True)
from typing import Dict, List, Any, Union, Tuple
class BBEHBuggyTablesSolver:
    """Repair a deliberately corrupted table and evaluate an aggregate query on it.

    The workflow implemented by :meth:`solve` is:

    1. pick a repair routine from keywords in the (Chinese) bug description,
    2. normalise cell values to numbers / NaN,
    3. parse a query such as ``"mean(price) where 10 <= qty <= 20"``,
    4. filter the rows and compute the aggregate.

    The repair routines are heuristics; they do not claim to restore the
    original table exactly.
    """

    # Textual placeholders that are treated as missing values.
    _NULL_TOKENS = ['null', 'NULL', 'None', '']
    # Signed integer or decimal literal used inside query conditions.
    _NUMBER = r"-?\d+(?:\.\d+)?"
    # "<type>(<column>)" optionally followed by " where <condition>".
    _QUERY_RE = re.compile(r"(\w+)\((.+?)\)(\s+where\s+(.+))?")
    # "<low> <= <column> <= <high>"
    _BETWEEN_RE = re.compile(
        r"(" + _NUMBER + r")\s*<=\s*(.+?)\s*<=\s*(" + _NUMBER + r")"
    )
    # "<column> <op> <number>"; two-character operators are tried first so
    # that ">=" is not read as ">" followed by garbage.
    _COMPARE_RE = re.compile(r"(.+?)\s*(>=|<=|==|>|<)\s*(" + _NUMBER + r")")

    def __init__(self):
        """Create a solver; it holds no per-instance state."""

    # ------------------------------------------------------------------
    # Repair routines
    # ------------------------------------------------------------------
    def _fix_missing_null_values(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
        """Append all-NaN rows where a numeric column shows unusually large gaps.

        For every column with at least two numeric values, the sorted values
        are differenced; each gap larger than three times the median gap is
        taken as evidence of one removed null, and one all-NaN row is appended
        for it. Columns that trigger this are cast to ``object`` dtype.

        Args:
            table: Rows of the table as a list of dicts.
            bug_description: Unused here; kept for a uniform handler signature.

        Returns:
            The table as a DataFrame, with the extra NaN rows at the end.
        """
        df = pd.DataFrame(table)
        missing_rows = 0
        for col in list(df.columns):
            numeric = pd.to_numeric(df[col], errors='coerce').dropna()
            if len(numeric) < 2:
                # No gaps can be computed from fewer than two values.
                continue
            gaps = np.diff(np.sort(numeric.to_numpy()))
            large_gap_count = int(np.count_nonzero(gaps > 3 * np.median(gaps)))
            if large_gap_count == 0:
                continue
            df[col] = df[col].astype(object)
            missing_rows += large_gap_count
        if missing_rows:
            # Reindexing past the end adds rows filled with NaN in one step.
            df = df.reindex(range(len(df) + missing_rows))
        return df.sort_index()

    def _fix_appended_random_values(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
        """Remove the random column or random row that was appended to the table.

        * Column case: the column name is read from the first single-quoted
          token in ``bug_description`` and dropped if present. When the
          description contains no quoted name, the last column is dropped.
        * Row case: the last row is dropped.

        Args:
            table: Rows of the table as a list of dicts.
            bug_description: Text describing which kind of value was appended.

        Returns:
            The repaired DataFrame (unchanged if neither case applies).
        """
        df = pd.DataFrame(table)
        if "每行末尾添加了一个随机值列" in bug_description:
            match = re.search(r"'(.+?)'", bug_description)
            if match is None:
                # No column name given: fall back to dropping the last column.
                return df.iloc[:, :-1]
            random_col = match.group(1)
            if random_col in df.columns:
                df = df.drop(columns=[random_col])
        elif "表格末尾添加了一行随机值" in bug_description:
            df = df.iloc[:-1]
        return df

    @staticmethod
    def _split_merged_cell(value: Any) -> List[Any]:
        """Split one merged cell on commas, mapping missing/"null" parts to None."""
        if value is None or (isinstance(value, float) and np.isnan(value)):
            return [None]
        return [
            None if part.strip() == "null" else part
            for part in str(value).split(",")
        ]

    def _fix_merged_rows(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
        """Split rows whose cells hold several comma-joined values.

        A row counts as merged when any of its cells is a string containing a
        comma. Every cell of such a row is split on commas and the pieces are
        distributed over new rows; cells with fewer pieces are padded with
        ``None``. Split pieces stay strings. Other rows are copied unchanged.

        Args:
            table: Rows of the table as a list of dicts.
            bug_description: Unused here; kept for a uniform handler signature.

        Returns:
            A DataFrame with the merged rows expanded, same columns as input.
        """
        df = pd.DataFrame(table)
        columns = list(df.columns)
        records = []
        for _, row in df.iterrows():
            merged = any(
                isinstance(row[c], str) and "," in row[c] for c in columns
            )
            if not merged:
                records.append(row.to_dict())
                continue
            pieces = {c: self._split_merged_cell(row[c]) for c in columns}
            height = max(len(parts) for parts in pieces.values())
            for i in range(height):
                records.append({
                    c: pieces[c][i] if i < len(pieces[c]) else None
                    for c in columns
                })
        # Passing columns keeps the schema even when there are no rows.
        return pd.DataFrame(records, columns=columns)

    def _fix_rotated_data(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
        """Undo a one-step rotation of rows or of columns.

        If the description says the last row was moved to the front, the first
        row is moved back to the end. If it says the last column was moved to
        the front, the first column is moved back to the end.

        Args:
            table: Rows of the table as a list of dicts.
            bug_description: Text describing the rotation.

        Returns:
            The re-ordered DataFrame (unchanged if no rotation is described).
        """
        df = pd.DataFrame(table)
        if "行被旋转" in bug_description and "最后一行被移到了第一行" in bug_description:
            if len(df) > 1:
                # concat keeps column dtypes, unlike a values.tolist() round-trip.
                df = pd.concat([df.iloc[1:], df.iloc[:1]], ignore_index=True)
        elif "列被旋转" in bug_description and "最后一列被移到了第一列" in bug_description:
            cols = list(df.columns)
            if len(cols) > 1:
                df = df[cols[1:] + cols[:1]]
        return df

    def _fix_replaced_values(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
        """Replace ``"ERROR"`` cells with an estimate from neighbouring values.

        Each ``"ERROR"`` cell becomes the mean of the numeric values among the
        two cells above and the two cells below it (values imputed earlier in
        the same column count as numeric). If none of those is numeric, the
        median of the column's numeric non-error values is used. In a column
        without any numeric value, ``"ERROR"`` becomes NaN. All columns are
        returned with ``object`` dtype.

        Args:
            table: Rows of the table as a list of dicts.
            bug_description: Unused here; kept for a uniform handler signature.

        Returns:
            The DataFrame with ``"ERROR"`` cells replaced.
        """
        df = pd.DataFrame(table)
        for col in df.columns:
            series = df[col].astype(object)
            error_mask = (series == "ERROR").to_numpy(dtype=bool)
            if error_mask.any():
                known = pd.to_numeric(series[~error_mask], errors='coerce')
                if known.notna().sum() == 0:
                    series[error_mask] = np.nan
                else:
                    fallback = known.median()
                    # Work by position so the result does not depend on the
                    # index labels being 0..n-1.
                    for pos in np.flatnonzero(error_mask):
                        window = pd.concat([
                            series.iloc[max(pos - 2, 0):pos],
                            series.iloc[pos + 1:pos + 3],
                        ])
                        nearby = pd.to_numeric(window, errors='coerce').dropna()
                        series.iloc[pos] = nearby.mean() if len(nearby) > 0 else fallback
            df[col] = series
        return df

    def fix_table(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
        """Dispatch to the repair routine that matches ``bug_description``.

        The first matching keyword group wins; an unrecognised description
        returns the table unmodified.

        Args:
            table: Rows of the table as a list of dicts.
            bug_description: Text describing the injected bug.

        Returns:
            The repaired table as a DataFrame.
        """
        handlers = (
            (("空值(null)被错误地移除",), self._fix_missing_null_values),
            (("添加了一个随机值列", "添加了一行随机值"), self._fix_appended_random_values),
            (("每两行被错误地合并",), self._fix_merged_rows),
            (("行被旋转", "列被旋转"), self._fix_rotated_data),
            (("值被错误地替换为'ERROR'",), self._fix_replaced_values),
        )
        for keywords, handler in handlers:
            if any(keyword in bug_description for keyword in keywords):
                return handler(table, bug_description)
        return pd.DataFrame(table)

    # ------------------------------------------------------------------
    # Value normalisation
    # ------------------------------------------------------------------
    def _convert_values(self, df: pd.DataFrame) -> pd.DataFrame:
        """Convert mostly-numeric columns to numbers and null tokens to NaN.

        A column is converted when more than half of its cells parse as
        numbers. Remaining string cells in such a column get a second chance
        after stripping whitespace and thousands separators (``"1,234"`` ->
        ``1234.0``); anything still unparseable becomes NaN. In every column
        the tokens ``'null'``, ``'NULL'``, ``'None'`` and ``''`` become NaN.

        Args:
            df: Table to normalise. It is modified in place.

        Returns:
            The same DataFrame object, for chaining.
        """
        row_count = len(df)
        for col in df.columns:
            raw = df[col]
            try:
                numeric = pd.to_numeric(raw, errors='coerce')
            except (TypeError, ValueError):
                # Cells that pandas cannot even attempt to parse: keep as-is.
                numeric = None
            if (
                numeric is not None
                and row_count > 0
                and numeric.notna().sum() / row_count > 0.5
            ):
                # Retry the cells that failed to parse, using the ORIGINAL
                # text (the coerced series already holds NaN there).
                unparsed = (numeric.isna() & raw.notna()).to_numpy(dtype=bool)
                for pos in np.flatnonzero(unparsed):
                    cell = raw.iloc[pos]
                    if not isinstance(cell, str):
                        continue
                    try:
                        numeric.iloc[pos] = float(cell.strip().replace(',', ''))
                    except ValueError:
                        pass  # stays NaN, e.g. for "null" or free text
                df[col] = numeric
            df[col] = df[col].replace(self._NULL_TOKENS, np.nan)
        return df

    # ------------------------------------------------------------------
    # Query evaluation
    # ------------------------------------------------------------------
    def apply_condition(self, df: pd.DataFrame, condition: Dict) -> pd.DataFrame:
        """Filter ``df`` by a single-column condition.

        Args:
            df: Table to filter.
            condition: Dict with ``column`` and ``operator`` (one of ``>``,
                ``<``, ``>=``, ``<=``, ``==``, ``between``) plus ``value``, or
                ``low``/``high`` for ``between``. Ordering comparisons are done
                on the numerically coerced column, so non-numeric cells never
                match; ``==`` compares the raw cell values.

        Returns:
            The matching rows. ``df`` itself is returned when the condition is
            empty, incomplete, names an unknown column or an unknown operator.
        """
        if not condition:
            return df
        column = condition.get('column')
        operator = condition.get('operator')
        if not column or not operator or column not in df.columns:
            return df

        value = condition.get('value')
        if operator == "==":
            return df[df[column] == value]

        if operator not in (">", "<", ">=", "<=", "between"):
            return df
        numbers = pd.to_numeric(df[column], errors='coerce')
        if operator == "between":
            low = condition.get('low')
            high = condition.get('high')
            if low is None or high is None:
                return df
            return df[(numbers >= low) & (numbers <= high)]
        if operator == ">":
            return df[numbers > value]
        if operator == "<":
            return df[numbers < value]
        if operator == ">=":
            return df[numbers >= value]
        return df[numbers <= value]

    def execute_query(self, df: pd.DataFrame, query_info: Dict) -> Union[int, float]:
        """Compute an aggregate over one column, after optional filtering.

        Args:
            df: Table to query.
            query_info: Dict with ``type`` (``count``, ``sum``, ``mean``,
                ``stdev`` or ``median``), ``column`` and optional ``condition``
                (see :meth:`apply_condition`).

        Returns:
            The aggregate; ``count`` yields an ``int``, the others a ``float``
            (``stdev`` is the sample standard deviation). NaN is returned when
            the query is incomplete, the column is unknown, the type is
            unknown, no numeric value remains after filtering (this includes
            ``count``), or evaluation raises.
        """
        query_type = query_info.get('type')
        column = query_info.get('column')
        condition = query_info.get('condition', {})
        if not query_type or not column or column not in df.columns:
            return np.nan
        try:
            filtered_df = self.apply_condition(df, condition)
            # Aggregate the coerced Series directly instead of writing it back
            # into the frame, which may keep an object dtype.
            values = pd.to_numeric(filtered_df[column], errors='coerce').dropna()
            if len(values) == 0:
                return np.nan
            if query_type == "count":
                return len(values)
            if query_type == "sum":
                return float(values.sum())
            if query_type == "mean":
                return float(values.mean())
            if query_type == "stdev":
                return float(values.std()) if len(values) > 1 else np.nan
            if query_type == "median":
                return float(values.median())
        except Exception as e:
            # Best-effort boundary: report the problem and yield NaN.
            print(f"Query execution error: {str(e)}")
            return np.nan
        return np.nan

    @staticmethod
    def _parse_number(text: str) -> Union[int, float]:
        """Turn a numeric literal into ``int`` when it has no decimal point."""
        return float(text) if "." in text else int(text)

    def parse_query_string(self, query_string: str) -> Dict:
        """Parse ``"<type>(<column>) [where <condition>]"`` into a query dict.

        Supported conditions are ``<low> <= <column> <= <high>`` and
        ``<column> <op> <number>`` with ``op`` in ``>``, ``<``, ``>=``, ``<=``,
        ``==``. Numbers may be negative and may have a decimal part.

        Args:
            query_string: The textual query.

        Returns:
            ``{}`` if the string does not look like a query; otherwise a dict
            with ``type`` and ``column`` and, when a condition was recognised,
            ``condition``.
        """
        query_info: Dict[str, Any] = {}
        query_match = self._QUERY_RE.match(query_string)
        if not query_match:
            return query_info
        query_info["type"] = query_match.group(1)
        query_info["column"] = query_match.group(2)

        condition_str = query_match.group(4)
        if not condition_str:
            return query_info

        between_match = self._BETWEEN_RE.match(condition_str)
        if between_match:
            low, col, high = between_match.groups()
            query_info["condition"] = {
                "column": col,
                "operator": "between",
                "low": self._parse_number(low),
                "high": self._parse_number(high),
            }
            return query_info

        compare_match = self._COMPARE_RE.match(condition_str)
        if compare_match:
            col, operator, val = compare_match.groups()
            query_info["condition"] = {
                "column": col,
                "operator": operator,
                "value": self._parse_number(val),
            }
        return query_info

    def solve(self, example: Dict) -> float:
        """Repair the table of one example and answer its query.

        Args:
            example: Dict with ``input.table`` and ``input.bug_description``.
                The query is taken from ``query_info`` when that key exists,
                otherwise it is parsed from ``input.query``.

        Returns:
            The query result: ``mean``/``stdev`` rounded to 2 decimals,
            ``sum``/``median`` rounded to 1 decimal, ``count`` unrounded; NaN
            when the query cannot be answered.
        """
        payload = example['input']
        fixed_df = self.fix_table(payload['table'], payload['bug_description'])
        fixed_df = self._convert_values(fixed_df)

        if 'query_info' in example:
            query_info = example['query_info']
        else:
            query_info = self.parse_query_string(payload['query'])

        result = self.execute_query(fixed_df, query_info)
        if result is not None:
            if query_info.get('type') in ['mean', 'stdev']:
                result = round(result, 2)
            elif query_info.get('type') in ['sum', 'median']:
                result = round(result, 1)
        return result
# Usage example
if __name__ == "__main__":
    # Read the generated examples from disk.
    with open("generated_dataset.json", 'r', encoding='utf-8') as dataset_file:
        dataset = json.load(dataset_file)

    table_solver = BBEHBuggyTablesSolver()

    # Solve each example in turn and print its 1-based number with the result.
    for i, sample in enumerate(dataset['examples']):
        result = table_solver.solve(sample)
        print(f"示例 {i + 1} 的计算结果: {result}")