mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-19 12:58:04 +00:00
369 lines
14 KiB
Python
369 lines
14 KiB
Python
import json
|
||
import pandas as pd
|
||
import numpy as np
|
||
import re
|
||
pd.set_option('future.no_silent_downcasting', True)
|
||
from typing import Dict, List, Any, Union, Tuple
|
||
|
||
|
||
class BBEHBuggyTablesSolver:
|
||
def __init__(self):
|
||
"""初始化求解器"""
|
||
pass
|
||
|
||
def _fix_missing_null_values(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
|
||
"""修复缺失的null值问题,使用统计方法推断和恢复缺失值"""
|
||
df = pd.DataFrame(table)
|
||
|
||
# 对每一列进行处理
|
||
for col in df.columns:
|
||
# 检查是否为数值列
|
||
numeric_col = pd.to_numeric(df[col], errors='coerce')
|
||
if numeric_col.notna().sum() > 0: # 如果列包含数值
|
||
# 计算列的统计特征
|
||
mean_val = numeric_col.mean()
|
||
std_val = numeric_col.std()
|
||
|
||
# 检测异常值的阈值
|
||
lower_bound = mean_val - 2 * std_val
|
||
upper_bound = mean_val + 2 * std_val
|
||
|
||
# 找出可能缺失null值的位置(数据异常断点)
|
||
values = numeric_col.dropna().values
|
||
gaps = np.diff(sorted(values))
|
||
median_gap = np.median(gaps)
|
||
|
||
# 如果发现异常大的间隔,在这些位置插入null
|
||
large_gaps = np.where(gaps > 3 * median_gap)[0]
|
||
if len(large_gaps) > 0:
|
||
df[col] = df[col].astype(object)
|
||
for gap_idx in large_gaps:
|
||
# 在异常间隔处插入null值
|
||
df.loc[len(df)] = {col: np.nan for col in df.columns}
|
||
|
||
return df.sort_index()
|
||
|
||
def _fix_appended_random_values(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
|
||
"""修复添加的随机值"""
|
||
df = pd.DataFrame(table)
|
||
|
||
if "每行末尾添加了一个随机值列" in bug_description:
|
||
# 随机值被添加为新列,识别并删除它
|
||
match = re.search(r"'(.+?)'", bug_description)
|
||
if match:
|
||
random_col = match.group(1)
|
||
if random_col in df.columns:
|
||
df = df.drop(columns=[random_col])
|
||
else:
|
||
# 如果没有找到列名,尝试删除最后一列
|
||
df = df.iloc[:, :-1]
|
||
|
||
elif "表格末尾添加了一行随机值" in bug_description:
|
||
# 随机值被添加为新行,删除最后一行
|
||
df = df.iloc[:-1]
|
||
|
||
return df
|
||
|
||
def _fix_merged_rows(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
|
||
"""修复合并的行"""
|
||
df = pd.DataFrame(table)
|
||
unmerged_data = []
|
||
|
||
for _, row in df.iterrows():
|
||
for col in df.columns:
|
||
value = row[col]
|
||
# 检查是否是字符串并包含逗号
|
||
if isinstance(value, str) and "," in value:
|
||
# 这一行是合并的,需要拆分
|
||
split_values = {}
|
||
for c in df.columns:
|
||
vals = str(row[c]).split(",")
|
||
split_values[c] = [v if v != "null" else None for v in vals]
|
||
|
||
# 确定每行有多少个值
|
||
max_vals = max(len(split_values[c]) for c in df.columns)
|
||
|
||
# 创建拆分后的行
|
||
for i in range(max_vals):
|
||
new_row = {}
|
||
for c in df.columns:
|
||
vals = split_values[c]
|
||
if i < len(vals):
|
||
new_row[c] = vals[i]
|
||
else:
|
||
new_row[c] = None
|
||
unmerged_data.append(new_row)
|
||
|
||
break
|
||
else:
|
||
# 这行没有合并值,直接添加
|
||
unmerged_data.append(row.to_dict())
|
||
|
||
return pd.DataFrame(unmerged_data)
|
||
|
||
def _fix_rotated_data(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
|
||
"""修复旋转的数据"""
|
||
df = pd.DataFrame(table)
|
||
|
||
if "行被旋转" in bug_description and "最后一行被移到了第一行" in bug_description:
|
||
# 将第一行移到最后一行
|
||
rows = df.values.tolist()
|
||
if len(rows) > 1:
|
||
rows = rows[1:] + [rows[0]]
|
||
df = pd.DataFrame(rows, columns=df.columns)
|
||
|
||
elif "列被旋转" in bug_description and "最后一列被移到了第一列" in bug_description:
|
||
# 将第一列移到最后一列
|
||
cols = list(df.columns)
|
||
if len(cols) > 1:
|
||
new_order = cols[1:] + [cols[0]]
|
||
df = df[new_order]
|
||
|
||
return df
|
||
|
||
def _fix_replaced_values(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
|
||
df = pd.DataFrame(table)
|
||
|
||
for col in df.columns:
|
||
df[col] = df[col].astype(object)
|
||
|
||
error_mask = df[col] == "ERROR"
|
||
if error_mask.any():
|
||
numeric_values = pd.to_numeric(df[~error_mask][col], errors='coerce')
|
||
|
||
if numeric_values.notna().sum() > 0:
|
||
mean_val = numeric_values.mean()
|
||
median_val = numeric_values.median()
|
||
|
||
for idx in df[error_mask].index:
|
||
prev_vals = df.loc[:idx - 1, col]
|
||
next_vals = df.loc[idx + 1:, col]
|
||
|
||
# 使用 concat 替代 append
|
||
nearby_values = pd.to_numeric(
|
||
pd.concat([prev_vals.tail(2), next_vals.head(2)]),
|
||
errors='coerce'
|
||
).dropna()
|
||
|
||
if len(nearby_values) > 0:
|
||
df.loc[idx, col] = nearby_values.mean()
|
||
else:
|
||
df.loc[idx, col] = median_val
|
||
else:
|
||
df[col] = df[col].replace("ERROR", np.nan)
|
||
|
||
return df
|
||
|
||
def fix_table(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
|
||
"""根据bug描述修复表格"""
|
||
if "空值(null)被错误地移除" in bug_description:
|
||
return self._fix_missing_null_values(table, bug_description)
|
||
elif "添加了一个随机值列" in bug_description or "添加了一行随机值" in bug_description:
|
||
return self._fix_appended_random_values(table, bug_description)
|
||
elif "每两行被错误地合并" in bug_description:
|
||
return self._fix_merged_rows(table, bug_description)
|
||
elif "行被旋转" in bug_description or "列被旋转" in bug_description:
|
||
return self._fix_rotated_data(table, bug_description)
|
||
elif "值被错误地替换为'ERROR'" in bug_description:
|
||
return self._fix_replaced_values(table, bug_description)
|
||
else:
|
||
# 未识别的bug类型,返回原始表格
|
||
return pd.DataFrame(table)
|
||
|
||
def _convert_values(self, df: pd.DataFrame) -> pd.DataFrame:
|
||
"""改进的值转换方法,更好地处理混合类型数据"""
|
||
for col in df.columns:
|
||
# 保存原始数据类型
|
||
original_dtype = df[col].dtype
|
||
|
||
# 尝试转换为数值类型
|
||
try:
|
||
numeric_values = pd.to_numeric(df[col], errors='coerce')
|
||
non_numeric = df[col][numeric_values.isna() & df[col].notna()]
|
||
|
||
# 检查是否主要是数值类型
|
||
if numeric_values.notna().sum() / len(df) > 0.5:
|
||
df[col] = numeric_values
|
||
|
||
# 特殊处理非数值项
|
||
if len(non_numeric) > 0:
|
||
for idx in non_numeric.index:
|
||
if df.loc[idx, col] == 'null':
|
||
df.loc[idx, col] = np.nan
|
||
elif isinstance(df.loc[idx, col], str):
|
||
try:
|
||
# 尝试清理并转换字符串
|
||
cleaned_val = df.loc[idx, col].strip().replace(',', '')
|
||
df.loc[idx, col] = float(cleaned_val)
|
||
except:
|
||
df.loc[idx, col] = np.nan
|
||
except:
|
||
# 如果转换失败,保持原始类型
|
||
df[col] = df[col].astype(original_dtype)
|
||
|
||
# 统一处理null值
|
||
df[col] = df[col].replace(['null', 'NULL', 'None', ''], np.nan)
|
||
|
||
return df
|
||
|
||
def apply_condition(self, df: pd.DataFrame, condition: Dict) -> pd.DataFrame:
|
||
"""应用查询条件"""
|
||
if not condition:
|
||
return df
|
||
|
||
column = condition.get('column')
|
||
operator = condition.get('operator')
|
||
value = condition.get('value')
|
||
|
||
if not column or not operator or column not in df.columns:
|
||
return df
|
||
|
||
if operator == ">":
|
||
return df[df[column] > value]
|
||
elif operator == "<":
|
||
return df[df[column] < value]
|
||
elif operator == "==":
|
||
return df[df[column] == value]
|
||
elif operator == "between":
|
||
low = condition.get('low')
|
||
high = condition.get('high')
|
||
if low is not None and high is not None:
|
||
return df[(df[column] >= low) & (df[column] <= high)]
|
||
|
||
return df
|
||
|
||
def execute_query(self, df: pd.DataFrame, query_info: Dict) -> float:
|
||
query_type = query_info.get('type')
|
||
column = query_info.get('column')
|
||
condition = query_info.get('condition', {})
|
||
|
||
if not query_type or not column or column not in df.columns:
|
||
return np.nan
|
||
|
||
try:
|
||
# 创建副本避免 SettingWithCopyWarning
|
||
filtered_df = self.apply_condition(df, condition).copy()
|
||
filtered_df.loc[:, column] = pd.to_numeric(filtered_df[column], errors='coerce')
|
||
|
||
values = filtered_df[column].dropna()
|
||
|
||
if len(values) == 0:
|
||
return np.nan
|
||
|
||
# 执行查询
|
||
if query_type == "count":
|
||
return len(values)
|
||
elif query_type == "sum":
|
||
return float(values.sum())
|
||
elif query_type == "mean":
|
||
return float(values.mean())
|
||
elif query_type == "stdev":
|
||
return float(values.std()) if len(values) > 1 else np.nan
|
||
elif query_type == "median":
|
||
return float(values.median())
|
||
except Exception as e:
|
||
print(f"Query execution error: {str(e)}")
|
||
return np.nan
|
||
|
||
return np.nan
|
||
|
||
def parse_query_string(self, query_string: str) -> Dict:
|
||
"""从查询字符串解析查询信息"""
|
||
query_info = {}
|
||
|
||
# 匹配查询类型和列名
|
||
query_match = re.match(r"(\w+)\((.+?)\)(\s+where\s+(.+))?", query_string)
|
||
if not query_match:
|
||
return query_info
|
||
|
||
query_type, column = query_match.group(1), query_match.group(2)
|
||
query_info["type"] = query_type
|
||
query_info["column"] = column
|
||
|
||
# 提取条件(如果有)
|
||
if query_match.group(4):
|
||
condition_str = query_match.group(4)
|
||
|
||
# 处理不同类型的条件
|
||
between_match = re.match(r"(\d+)\s*<=\s*(.+?)\s*<=\s*(\d+)", condition_str)
|
||
gt_match = re.match(r"(.+?)\s*>\s*(\d+)", condition_str)
|
||
lt_match = re.match(r"(.+?)\s*<\s*(\d+)", condition_str)
|
||
eq_match = re.match(r"(.+?)\s*==\s*(\d+)", condition_str)
|
||
|
||
if between_match:
|
||
low, col, high = between_match.groups()
|
||
query_info["condition"] = {
|
||
"column": col,
|
||
"operator": "between",
|
||
"low": int(low),
|
||
"high": int(high)
|
||
}
|
||
elif gt_match:
|
||
col, val = gt_match.groups()
|
||
query_info["condition"] = {
|
||
"column": col,
|
||
"operator": ">",
|
||
"value": int(val)
|
||
}
|
||
elif lt_match:
|
||
col, val = lt_match.groups()
|
||
query_info["condition"] = {
|
||
"column": col,
|
||
"operator": "<",
|
||
"value": int(val)
|
||
}
|
||
elif eq_match:
|
||
col, val = eq_match.groups()
|
||
query_info["condition"] = {
|
||
"column": col,
|
||
"operator": "==",
|
||
"value": int(val)
|
||
}
|
||
|
||
return query_info
|
||
|
||
def solve(self, example: Dict) -> float:
|
||
"""解决一个示例问题"""
|
||
# 提取输入数据
|
||
table = example['input']['table']
|
||
bug_description = example['input']['bug_description']
|
||
query = example['input']['query']
|
||
|
||
# 修复表格
|
||
fixed_df = self.fix_table(table, bug_description)
|
||
|
||
# 转换数据类型
|
||
fixed_df = self._convert_values(fixed_df)
|
||
|
||
# 解析查询信息
|
||
if 'query_info' in example:
|
||
query_info = example['query_info']
|
||
else:
|
||
query_info = self.parse_query_string(query)
|
||
|
||
# 执行查询
|
||
result = self.execute_query(fixed_df, query_info)
|
||
|
||
# 根据查询类型做适当的四舍五入
|
||
if result is not None:
|
||
if query_info.get('type') in ['mean', 'stdev']:
|
||
result = round(result, 2)
|
||
elif query_info.get('type') in ['sum', 'median']:
|
||
result = round(result, 1)
|
||
|
||
return result
|
||
|
||
|
||
# 使用示例
|
||
if __name__ == "__main__":
|
||
# 从文件加载测试数据
|
||
with open("generated_dataset.json", 'r', encoding='utf-8') as f:
|
||
test_data = json.load(f)
|
||
|
||
solver = BBEHBuggyTablesSolver()
|
||
|
||
# 解决每个示例
|
||
for i, example in enumerate(test_data['examples']):
|
||
result = solver.solve(example)
|
||
print(f"示例 {i + 1} 的计算结果: {result}")
|
||
|