init-commit

This commit is contained in:
lilinyang 2025-05-23 15:27:15 +08:00
commit 18a552597a
3461 changed files with 1150579 additions and 0 deletions

View file

@ -0,0 +1,369 @@
import json
import pandas as pd
import numpy as np
import re
pd.set_option('future.no_silent_downcasting', True)
from typing import Dict, List, Any, Union, Tuple
class BBEHBuggyTablesSolver:
def __init__(self):
"""初始化求解器"""
pass
def _fix_missing_null_values(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
"""修复缺失的null值问题使用统计方法推断和恢复缺失值"""
df = pd.DataFrame(table)
# 对每一列进行处理
for col in df.columns:
# 检查是否为数值列
numeric_col = pd.to_numeric(df[col], errors='coerce')
if numeric_col.notna().sum() > 0: # 如果列包含数值
# 计算列的统计特征
mean_val = numeric_col.mean()
std_val = numeric_col.std()
# 检测异常值的阈值
lower_bound = mean_val - 2 * std_val
upper_bound = mean_val + 2 * std_val
# 找出可能缺失null值的位置数据异常断点
values = numeric_col.dropna().values
gaps = np.diff(sorted(values))
median_gap = np.median(gaps)
# 如果发现异常大的间隔在这些位置插入null
large_gaps = np.where(gaps > 3 * median_gap)[0]
if len(large_gaps) > 0:
df[col] = df[col].astype(object)
for gap_idx in large_gaps:
# 在异常间隔处插入null值
df.loc[len(df)] = {col: np.nan for col in df.columns}
return df.sort_index()
def _fix_appended_random_values(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
"""修复添加的随机值"""
df = pd.DataFrame(table)
if "每行末尾添加了一个随机值列" in bug_description:
# 随机值被添加为新列,识别并删除它
match = re.search(r"'(.+?)'", bug_description)
if match:
random_col = match.group(1)
if random_col in df.columns:
df = df.drop(columns=[random_col])
else:
# 如果没有找到列名,尝试删除最后一列
df = df.iloc[:, :-1]
elif "表格末尾添加了一行随机值" in bug_description:
# 随机值被添加为新行,删除最后一行
df = df.iloc[:-1]
return df
def _fix_merged_rows(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
"""修复合并的行"""
df = pd.DataFrame(table)
unmerged_data = []
for _, row in df.iterrows():
for col in df.columns:
value = row[col]
# 检查是否是字符串并包含逗号
if isinstance(value, str) and "," in value:
# 这一行是合并的,需要拆分
split_values = {}
for c in df.columns:
vals = str(row[c]).split(",")
split_values[c] = [v if v != "null" else None for v in vals]
# 确定每行有多少个值
max_vals = max(len(split_values[c]) for c in df.columns)
# 创建拆分后的行
for i in range(max_vals):
new_row = {}
for c in df.columns:
vals = split_values[c]
if i < len(vals):
new_row[c] = vals[i]
else:
new_row[c] = None
unmerged_data.append(new_row)
break
else:
# 这行没有合并值,直接添加
unmerged_data.append(row.to_dict())
return pd.DataFrame(unmerged_data)
def _fix_rotated_data(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
"""修复旋转的数据"""
df = pd.DataFrame(table)
if "行被旋转" in bug_description and "最后一行被移到了第一行" in bug_description:
# 将第一行移到最后一行
rows = df.values.tolist()
if len(rows) > 1:
rows = rows[1:] + [rows[0]]
df = pd.DataFrame(rows, columns=df.columns)
elif "列被旋转" in bug_description and "最后一列被移到了第一列" in bug_description:
# 将第一列移到最后一列
cols = list(df.columns)
if len(cols) > 1:
new_order = cols[1:] + [cols[0]]
df = df[new_order]
return df
def _fix_replaced_values(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
df = pd.DataFrame(table)
for col in df.columns:
df[col] = df[col].astype(object)
error_mask = df[col] == "ERROR"
if error_mask.any():
numeric_values = pd.to_numeric(df[~error_mask][col], errors='coerce')
if numeric_values.notna().sum() > 0:
mean_val = numeric_values.mean()
median_val = numeric_values.median()
for idx in df[error_mask].index:
prev_vals = df.loc[:idx - 1, col]
next_vals = df.loc[idx + 1:, col]
# 使用 concat 替代 append
nearby_values = pd.to_numeric(
pd.concat([prev_vals.tail(2), next_vals.head(2)]),
errors='coerce'
).dropna()
if len(nearby_values) > 0:
df.loc[idx, col] = nearby_values.mean()
else:
df.loc[idx, col] = median_val
else:
df[col] = df[col].replace("ERROR", np.nan)
return df
def fix_table(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
"""根据bug描述修复表格"""
if "空值(null)被错误地移除" in bug_description:
return self._fix_missing_null_values(table, bug_description)
elif "添加了一个随机值列" in bug_description or "添加了一行随机值" in bug_description:
return self._fix_appended_random_values(table, bug_description)
elif "每两行被错误地合并" in bug_description:
return self._fix_merged_rows(table, bug_description)
elif "行被旋转" in bug_description or "列被旋转" in bug_description:
return self._fix_rotated_data(table, bug_description)
elif "值被错误地替换为'ERROR'" in bug_description:
return self._fix_replaced_values(table, bug_description)
else:
# 未识别的bug类型返回原始表格
return pd.DataFrame(table)
def _convert_values(self, df: pd.DataFrame) -> pd.DataFrame:
"""改进的值转换方法,更好地处理混合类型数据"""
for col in df.columns:
# 保存原始数据类型
original_dtype = df[col].dtype
# 尝试转换为数值类型
try:
numeric_values = pd.to_numeric(df[col], errors='coerce')
non_numeric = df[col][numeric_values.isna() & df[col].notna()]
# 检查是否主要是数值类型
if numeric_values.notna().sum() / len(df) > 0.5:
df[col] = numeric_values
# 特殊处理非数值项
if len(non_numeric) > 0:
for idx in non_numeric.index:
if df.loc[idx, col] == 'null':
df.loc[idx, col] = np.nan
elif isinstance(df.loc[idx, col], str):
try:
# 尝试清理并转换字符串
cleaned_val = df.loc[idx, col].strip().replace(',', '')
df.loc[idx, col] = float(cleaned_val)
except:
df.loc[idx, col] = np.nan
except:
# 如果转换失败,保持原始类型
df[col] = df[col].astype(original_dtype)
# 统一处理null值
df[col] = df[col].replace(['null', 'NULL', 'None', ''], np.nan)
return df
def apply_condition(self, df: pd.DataFrame, condition: Dict) -> pd.DataFrame:
"""应用查询条件"""
if not condition:
return df
column = condition.get('column')
operator = condition.get('operator')
value = condition.get('value')
if not column or not operator or column not in df.columns:
return df
if operator == ">":
return df[df[column] > value]
elif operator == "<":
return df[df[column] < value]
elif operator == "==":
return df[df[column] == value]
elif operator == "between":
low = condition.get('low')
high = condition.get('high')
if low is not None and high is not None:
return df[(df[column] >= low) & (df[column] <= high)]
return df
def execute_query(self, df: pd.DataFrame, query_info: Dict) -> float:
query_type = query_info.get('type')
column = query_info.get('column')
condition = query_info.get('condition', {})
if not query_type or not column or column not in df.columns:
return np.nan
try:
# 创建副本避免 SettingWithCopyWarning
filtered_df = self.apply_condition(df, condition).copy()
filtered_df.loc[:, column] = pd.to_numeric(filtered_df[column], errors='coerce')
values = filtered_df[column].dropna()
if len(values) == 0:
return np.nan
# 执行查询
if query_type == "count":
return len(values)
elif query_type == "sum":
return float(values.sum())
elif query_type == "mean":
return float(values.mean())
elif query_type == "stdev":
return float(values.std()) if len(values) > 1 else np.nan
elif query_type == "median":
return float(values.median())
except Exception as e:
print(f"Query execution error: {str(e)}")
return np.nan
return np.nan
def parse_query_string(self, query_string: str) -> Dict:
"""从查询字符串解析查询信息"""
query_info = {}
# 匹配查询类型和列名
query_match = re.match(r"(\w+)\((.+?)\)(\s+where\s+(.+))?", query_string)
if not query_match:
return query_info
query_type, column = query_match.group(1), query_match.group(2)
query_info["type"] = query_type
query_info["column"] = column
# 提取条件(如果有)
if query_match.group(4):
condition_str = query_match.group(4)
# 处理不同类型的条件
between_match = re.match(r"(\d+)\s*<=\s*(.+?)\s*<=\s*(\d+)", condition_str)
gt_match = re.match(r"(.+?)\s*>\s*(\d+)", condition_str)
lt_match = re.match(r"(.+?)\s*<\s*(\d+)", condition_str)
eq_match = re.match(r"(.+?)\s*==\s*(\d+)", condition_str)
if between_match:
low, col, high = between_match.groups()
query_info["condition"] = {
"column": col,
"operator": "between",
"low": int(low),
"high": int(high)
}
elif gt_match:
col, val = gt_match.groups()
query_info["condition"] = {
"column": col,
"operator": ">",
"value": int(val)
}
elif lt_match:
col, val = lt_match.groups()
query_info["condition"] = {
"column": col,
"operator": "<",
"value": int(val)
}
elif eq_match:
col, val = eq_match.groups()
query_info["condition"] = {
"column": col,
"operator": "==",
"value": int(val)
}
return query_info
def solve(self, example: Dict) -> float:
"""解决一个示例问题"""
# 提取输入数据
table = example['input']['table']
bug_description = example['input']['bug_description']
query = example['input']['query']
# 修复表格
fixed_df = self.fix_table(table, bug_description)
# 转换数据类型
fixed_df = self._convert_values(fixed_df)
# 解析查询信息
if 'query_info' in example:
query_info = example['query_info']
else:
query_info = self.parse_query_string(query)
# 执行查询
result = self.execute_query(fixed_df, query_info)
# 根据查询类型做适当的四舍五入
if result is not None:
if query_info.get('type') in ['mean', 'stdev']:
result = round(result, 2)
elif query_info.get('type') in ['sum', 'median']:
result = round(result, 1)
return result
# 使用示例
if __name__ == "__main__":
# 从文件加载测试数据
with open("generated_dataset.json", 'r', encoding='utf-8') as f:
test_data = json.load(f)
solver = BBEHBuggyTablesSolver()
# 解决每个示例
for i, example in enumerate(test_data['examples']):
result = solver.solve(example)
print(f"示例 {i + 1} 的计算结果: {result}")