mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-24 17:05:00 +00:00
init-commit
This commit is contained in:
commit
18a552597a
3461 changed files with 1150579 additions and 0 deletions
|
|
@ -0,0 +1,267 @@
|
|||
import json
|
||||
import random
|
||||
import copy
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Any, Tuple, Union
|
||||
|
||||
|
||||
class BBEHBuggyTablesGenerator:
    def __init__(self, task_file_path: str):
        """Initialize the generator and load the task.json file."""
        self.task_data = self._load_task_data(task_file_path)
        # Bug kinds this generator knows how to inject into a clean table.
        self.bug_types = [
            "missing_null_values",     # null values wrongly removed
            "appended_random_values",  # random values appended
            "merged_rows",             # rows wrongly merged together
            "rotated_data",            # data rotated
            "replaced_values",         # values replaced
        ]
        # Aggregations a generated query may ask for.
        self.query_types = ["count", "sum", "mean", "stdev", "median"]
||||
def _load_task_data(self, file_path: str) -> Dict:
    """Load the task.json file; return {} (after logging) on any failure."""
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            return json.load(handle)
    except Exception as e:
        print(f"加载任务数据文件失败: {e}")
        return {}
||||
def _extract_examples(self) -> List[Dict]:
    """Return the 'examples' list from the loaded task data, or [] if absent."""
    return self.task_data.get('examples', [])
||||
def _generate_clean_table(self, n_rows: int, n_cols: int) -> pd.DataFrame:
    """Build a clean random table mixing numeric and categorical columns.

    Each column independently gets some cells blanked out (NaN/None) so the
    "missing null" bug has something to remove later.
    """
    columns = {}
    for col_idx in range(n_cols):
        name = f"col_{col_idx}"
        if random.random() > 0.3:
            # Numeric column (70% of the time), with some NaN holes.
            col_values = np.random.randint(1, 100, size=n_rows).astype(float)
            blanks = random.sample(range(n_rows), k=random.randint(0, n_rows // 5))
            col_values[blanks] = np.nan
        else:
            # Categorical column (30% of the time), with some None holes.
            labels = [f"cat_{i}" for i in range(random.randint(3, 7))]
            col_values = [random.choice(labels) for _ in range(n_rows)]
            blanks = random.sample(range(n_rows), k=random.randint(0, n_rows // 5))
            for pos in blanks:
                col_values[pos] = None
        columns[name] = col_values

    return pd.DataFrame(columns)
||||
def _apply_bug(self, df: pd.DataFrame, bug_type: str) -> Tuple[pd.DataFrame, str]:
    """Apply the named bug to a copy of *df*; return (buggy_df, description).

    The Chinese description strings are part of the data contract: the solver
    dispatches on their exact wording, so they must not be altered.
    """
    buggy_df = df.copy(deep=True)  # never mutate the caller's table
    bug_description = ""

    if bug_type == "missing_null_values":
        # Drop every row containing a null, losing the null markers.
        buggy_df = buggy_df.dropna()
        bug_description = "表格中的所有空值(null)被错误地移除了。"

    elif bug_type == "appended_random_values":
        if random.choice([True, False]):
            # Append a random-valued column.
            n_rows, n_cols = buggy_df.shape
            new_col = f"random_col_{n_cols}"
            buggy_df[new_col] = [random.randint(1, 100) for _ in range(n_rows)]
            bug_description = f"表格中每行末尾添加了一个随机值列 '{new_col}'。"
        else:
            # Append a random-valued row, matching each column's dtype.
            random_row = {}
            for col in buggy_df.columns:
                if pd.api.types.is_numeric_dtype(buggy_df[col]):
                    random_row[col] = random.randint(1, 100)
                else:
                    random_row[col] = f"random_value_{random.randint(1, 10)}"
            buggy_df = pd.concat([buggy_df, pd.DataFrame([random_row])], ignore_index=True)
            bug_description = "表格末尾添加了一行随机值。"

    elif bug_type == "merged_rows":
        # Merge each pair of consecutive rows into one, comma-joining values.
        merged_data = []
        n_rows = len(buggy_df)
        for i in range(0, n_rows, 2):
            merged_row = {}
            if i + 1 < n_rows:
                for col in buggy_df.columns:
                    val1 = str(buggy_df.iloc[i][col]) if pd.notna(buggy_df.iloc[i][col]) else "null"
                    val2 = str(buggy_df.iloc[i + 1][col]) if pd.notna(buggy_df.iloc[i + 1][col]) else "null"
                    merged_row[col] = f"{val1},{val2}"
            else:
                # Odd row count: the trailing row is kept unmerged.
                for col in buggy_df.columns:
                    merged_row[col] = str(buggy_df.iloc[i][col]) if pd.notna(buggy_df.iloc[i][col]) else "null"
            merged_data.append(merged_row)
        buggy_df = pd.DataFrame(merged_data)
        bug_description = "表格中每两行被错误地合并成了一行,值之间用逗号分隔。"

    elif bug_type == "rotated_data":
        if random.choice([True, False]):
            # Rotate rows: move the last row to the front.
            rows = buggy_df.values.tolist()
            if len(rows) > 1:
                rows = [rows[-1]] + rows[:-1]
                buggy_df = pd.DataFrame(rows, columns=buggy_df.columns)
            bug_description = "表格的行被旋转了,最后一行被移到了第一行。"
        else:
            # Rotate columns: move the last column to the front.
            cols = list(buggy_df.columns)
            if len(cols) > 1:
                buggy_df = buggy_df[[cols[-1]] + cols[:-1]]
            bug_description = "表格的列被旋转了,最后一列被移到了第一列。"

    elif bug_type == "replaced_values":
        # Replace up to 5 DISTINCT cells with the string "ERROR".
        # BUG FIX: the original drew row/col indices independently per
        # replacement, so the same cell could be hit twice and the count in
        # the description would overstate the number of corrupted cells.
        n_rows, n_cols = buggy_df.shape
        replacements = min(random.randint(1, n_rows), 5)
        all_cells = [(r, c) for r in range(n_rows) for c in range(n_cols)]
        for row_idx, col_idx in random.sample(all_cells, k=min(replacements, len(all_cells))):
            col_name = buggy_df.columns[col_idx]
            # object dtype so a string can live in a numeric column
            buggy_df[col_name] = buggy_df[col_name].astype(object)
            buggy_df.loc[row_idx, col_name] = "ERROR"
        bug_description = f"表格中有{replacements}个值被错误地替换为'ERROR'。"

    return buggy_df, bug_description
||||
def _generate_condition(self, df: pd.DataFrame) -> Tuple[str, Dict]:
    """Pick a numeric column of *df* and build a random filter condition.

    Returns ("", {}) when no numeric column with at least one value exists.
    """
    numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
    if not numeric_cols:
        return "", {}

    selected_col = random.choice(numeric_cols)
    condition_type = random.choice(["gt", "lt", "eq", "between"])

    values = df[selected_col].dropna().values
    if len(values) == 0:
        return "", {}

    min_val, max_val = int(np.min(values)), int(np.max(values))

    if condition_type == "gt":
        threshold = random.randint(min_val, max_val)
        return (f"{selected_col} > {threshold}",
                {"column": selected_col, "operator": ">", "value": threshold})
    if condition_type == "lt":
        threshold = random.randint(min_val, max_val)
        return (f"{selected_col} < {threshold}",
                {"column": selected_col, "operator": "<", "value": threshold})
    if condition_type == "eq":
        # Equality uses an actually-occurring value so it can match rows.
        chosen = int(random.choice(values))
        return (f"{selected_col} == {chosen}",
                {"column": selected_col, "operator": "==", "value": chosen})

    # "between": two random endpoints, ordered.
    a, b = random.randint(min_val, max_val), random.randint(min_val, max_val)
    low, high = min(a, b), max(a, b)
    return (f"{low} <= {selected_col} <= {high}",
            {"column": selected_col, "operator": "between", "low": low, "high": high})
||||
def _generate_query(self, df: pd.DataFrame) -> Tuple[str, Dict]:
    """Build a random aggregation query (plus optional condition) over *df*."""
    numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
    if not numeric_cols:
        return "", {}

    query_type = random.choice(self.query_types)
    selected_col = random.choice(numeric_cols)

    condition_str, condition_dict = self._generate_condition(df)

    query_str = f"{query_type}({selected_col})"
    if condition_str:
        query_str = f"{query_str} where {condition_str}"

    query_dict = {"type": query_type, "column": selected_col}
    if condition_dict:
        query_dict["condition"] = condition_dict

    return query_str, query_dict
||||
def generate_example(self) -> Dict:
    """Generate one buggy-table example paired with its clean reference."""
    # Build the clean source table with random dimensions.
    clean_table = self._generate_clean_table(random.randint(5, 15), random.randint(3, 6))

    # Corrupt it with one randomly chosen bug.
    bug_type = random.choice(self.bug_types)
    buggy_table, bug_description = self._apply_bug(clean_table, bug_type)

    # Queries are generated against the CLEAN table so a reference answer
    # remains computable.
    query_str, query_dict = self._generate_query(clean_table)

    return {
        "input": {
            "table": buggy_table.fillna("null").to_dict(orient="records"),
            "bug_description": bug_description,
            "query": query_str,
        },
        "clean_table": clean_table.fillna("null").to_dict(orient="records"),
        "query_info": query_dict,
    }
||||
def generate_dataset(self, n_examples: int) -> List[Dict]:
    """Generate *n_examples* fresh examples."""
    return [self.generate_example() for _ in range(n_examples)]
||||
def save_dataset(self, dataset: List[Dict], file_path: str) -> None:
    """Write *dataset* to *file_path* as {"examples": [...]} JSON."""
    try:
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump({"examples": dataset}, f, ensure_ascii=False, indent=2)
        print(f"数据集已保存至 {file_path}")
    except Exception as e:
        print(f"保存数据集失败: {e}")
||||
# Usage example
if __name__ == "__main__":
    gen = BBEHBuggyTablesGenerator("task.json")
    examples = gen.generate_dataset(5)  # create 5 fresh examples
    gen.save_dataset(examples, "generated_dataset.json")
    print("生成了5个新的BBEH Buggy Tables示例")
|
@ -0,0 +1,369 @@
|
|||
import json
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import re
|
||||
pd.set_option('future.no_silent_downcasting', True)
|
||||
from typing import Dict, List, Any, Union, Tuple
|
||||
|
||||
|
||||
class BBEHBuggyTablesSolver:
    def __init__(self):
        """Initialize the solver (stateless; nothing to configure)."""
|
||||
def _fix_missing_null_values(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
    """Heuristically restore null rows that were dropped from the table.

    For each (partly) numeric column, look for abnormally large gaps in the
    sorted values; each such gap is taken as evidence of a removed row and an
    all-NaN row is appended.  This is a statistical guess, not a true inverse
    of the bug.
    """
    df = pd.DataFrame(table)

    for col in df.columns:
        numeric_col = pd.to_numeric(df[col], errors='coerce')
        if numeric_col.notna().sum() == 0:
            continue  # non-numeric column: nothing to infer

        values = numeric_col.dropna().values
        if len(values) < 2:
            # BUG FIX: with fewer than 2 values np.diff is empty and the
            # original code took np.median of an empty array
            # (RuntimeWarning + NaN comparisons).  Skip such columns.
            continue

        gaps = np.diff(sorted(values))
        median_gap = np.median(gaps)

        # A gap much larger than typical suggests a value was dropped there.
        large_gaps = np.where(gaps > 3 * median_gap)[0]
        if len(large_gaps) > 0:
            df[col] = df[col].astype(object)
            for _ in large_gaps:
                # Append one all-NaN row per suspicious gap.
                df.loc[len(df)] = {c: np.nan for c in df.columns}

    return df.sort_index()
||||
def _fix_appended_random_values(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
    """Undo an appended random row/column, dispatching on the bug text."""
    df = pd.DataFrame(table)

    if "每行末尾添加了一个随机值列" in bug_description:
        # A random column was appended; recover its name from the quoted
        # fragment of the description.
        name_match = re.search(r"'(.+?)'", bug_description)
        if name_match:
            appended_col = name_match.group(1)
            if appended_col in df.columns:
                df = df.drop(columns=[appended_col])
        else:
            # No name found in the description: assume the last column.
            df = df.iloc[:, :-1]

    elif "表格末尾添加了一行随机值" in bug_description:
        # A random row was appended; drop the final row.
        df = df.iloc[:-1]

    return df
||||
def _fix_merged_rows(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
    """Split rows whose cells hold comma-joined values back into rows."""
    df = pd.DataFrame(table)
    restored = []

    for _, row in df.iterrows():
        # A row is "merged" if any cell is a string containing a comma.
        merged = any(isinstance(row[c], str) and "," in row[c] for c in df.columns)
        if not merged:
            # Plain row: keep as-is.
            restored.append(row.to_dict())
            continue

        # Split every cell; "null" tokens become real None values.
        parts = {
            c: [None if tok == "null" else tok for tok in str(row[c]).split(",")]
            for c in df.columns
        }
        # Columns may split into different counts; pad short ones with None.
        width = max(len(v) for v in parts.values())
        for i in range(width):
            restored.append({
                c: parts[c][i] if i < len(parts[c]) else None
                for c in df.columns
            })

    return pd.DataFrame(restored)
||||
def _fix_rotated_data(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
    """Invert the row/column rotation named in the bug description."""
    df = pd.DataFrame(table)

    if "行被旋转" in bug_description and "最后一行被移到了第一行" in bug_description:
        # Inverse: move the first row back to the end.
        data = df.values.tolist()
        if len(data) > 1:
            df = pd.DataFrame(data[1:] + data[:1], columns=df.columns)

    elif "列被旋转" in bug_description and "最后一列被移到了第一列" in bug_description:
        # Inverse: move the first column back to the end.
        order = list(df.columns)
        if len(order) > 1:
            df = df[order[1:] + order[:1]]

    return df
||||
def _fix_replaced_values(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
    """Replace "ERROR" markers with values interpolated from neighbours.

    Numeric columns: each "ERROR" cell becomes the mean of up to two
    parseable values on either side, falling back to the column median.
    Non-numeric columns: "ERROR" simply becomes NaN.
    (Removed the original's unused ``mean_val`` computation.)
    """
    df = pd.DataFrame(table)

    for col in df.columns:
        # object dtype so string markers and floats can coexist
        df[col] = df[col].astype(object)

        error_mask = df[col] == "ERROR"
        if not error_mask.any():
            continue

        numeric_values = pd.to_numeric(df[~error_mask][col], errors='coerce')

        if numeric_values.notna().sum() > 0:
            median_val = numeric_values.median()

            # NOTE(review): the .loc slices below assume a default integer
            # RangeIndex (idx - 1 / idx + 1 are used as labels) — holds for
            # tables built via pd.DataFrame(list_of_dicts).
            for idx in df[error_mask].index:
                prev_vals = df.loc[:idx - 1, col]
                next_vals = df.loc[idx + 1:, col]

                # Average the nearest (up to 2 + 2) parseable neighbours.
                nearby_values = pd.to_numeric(
                    pd.concat([prev_vals.tail(2), next_vals.head(2)]),
                    errors='coerce'
                ).dropna()

                if len(nearby_values) > 0:
                    df.loc[idx, col] = nearby_values.mean()
                else:
                    df.loc[idx, col] = median_val
        else:
            df[col] = df[col].replace("ERROR", np.nan)

    return df
||||
def fix_table(self, table: List[Dict], bug_description: str) -> pd.DataFrame:
    """Dispatch to the repair routine matching the (Chinese) bug description."""
    # (identifying fragments, handler-method name) — checked in order.
    dispatch = [
        (("空值(null)被错误地移除",), "_fix_missing_null_values"),
        (("添加了一个随机值列", "添加了一行随机值"), "_fix_appended_random_values"),
        (("每两行被错误地合并",), "_fix_merged_rows"),
        (("行被旋转", "列被旋转"), "_fix_rotated_data"),
        (("值被错误地替换为'ERROR'",), "_fix_replaced_values"),
    ]
    for fragments, handler_name in dispatch:
        if any(fragment in bug_description for fragment in fragments):
            return getattr(self, handler_name)(table, bug_description)

    # Unrecognized bug: return the table unrepaired.
    return pd.DataFrame(table)
||||
def _convert_values(self, df: pd.DataFrame) -> pd.DataFrame:
    """Coerce mostly-numeric columns to numbers and normalize null markers.

    A column is adopted as numeric when more than half its cells parse as
    numbers; the stray non-numeric cells are then cleaned (whitespace and
    thousands separators stripped) or set to NaN.  The textual markers
    'null'/'NULL'/'None'/'' become NaN everywhere.
    """
    if len(df) == 0:
        # BUG FIX: the ratio test below divided by len(df) -> ZeroDivisionError.
        return df

    for col in df.columns:
        original_dtype = df[col].dtype

        try:
            numeric_values = pd.to_numeric(df[col], errors='coerce')
            # Cells that are present but did not parse as numbers, with their
            # ORIGINAL values preserved.
            non_numeric = df[col][numeric_values.isna() & df[col].notna()]

            # Mostly numeric -> adopt the coerced numeric column.
            if numeric_values.notna().sum() / len(df) > 0.5:
                df[col] = numeric_values

                # BUG FIX: the original re-read df.loc[idx, col] AFTER the
                # column was overwritten with coerced values, so these cells
                # were already NaN and the string cleanup never ran.  Use the
                # saved original values instead.
                for idx, raw in non_numeric.items():
                    if raw == 'null':
                        df.loc[idx, col] = np.nan
                    elif isinstance(raw, str):
                        try:
                            # Strip whitespace and thousands separators.
                            df.loc[idx, col] = float(raw.strip().replace(',', ''))
                        except ValueError:
                            df.loc[idx, col] = np.nan
        except (TypeError, ValueError):
            # Narrowed from a bare except: keep the original dtype on failure.
            df[col] = df[col].astype(original_dtype)

        # Normalize the textual null markers.
        df[col] = df[col].replace(['null', 'NULL', 'None', ''], np.nan)

    return df
||||
def apply_condition(self, df: pd.DataFrame, condition: Dict) -> pd.DataFrame:
    """Filter *df* by a condition dict; invalid/unknown conditions are no-ops."""
    if not condition:
        return df

    column = condition.get('column')
    operator = condition.get('operator')
    if not column or not operator or column not in df.columns:
        return df

    value = condition.get('value')
    if operator == ">":
        return df[df[column] > value]
    if operator == "<":
        return df[df[column] < value]
    if operator == "==":
        return df[df[column] == value]
    if operator == "between":
        low, high = condition.get('low'), condition.get('high')
        if low is not None and high is not None:
            return df[(df[column] >= low) & (df[column] <= high)]

    return df
||||
def execute_query(self, df: pd.DataFrame, query_info: Dict) -> float:
    """Run the aggregation described by *query_info* against *df*.

    Returns NaN when the query is malformed, matches no rows, or fails.
    """
    query_type = query_info.get('type')
    column = query_info.get('column')

    if not query_type or not column or column not in df.columns:
        return np.nan

    try:
        # Work on a copy to avoid pandas' SettingWithCopyWarning.
        subset = self.apply_condition(df, query_info.get('condition', {})).copy()
        subset.loc[:, column] = pd.to_numeric(subset[column], errors='coerce')
        values = subset[column].dropna()

        if len(values) == 0:
            return np.nan

        aggregators = {
            "count": lambda v: len(v),
            "sum": lambda v: float(v.sum()),
            "mean": lambda v: float(v.mean()),
            "stdev": lambda v: float(v.std()) if len(v) > 1 else np.nan,
            "median": lambda v: float(v.median()),
        }
        if query_type in aggregators:
            return aggregators[query_type](values)
    except Exception as e:
        print(f"Query execution error: {str(e)}")
        return np.nan

    return np.nan
||||
def parse_query_string(self, query_string: str) -> Dict:
    """Parse "type(column) [where <condition>]" into a query-info dict."""
    head = re.match(r"(\w+)\((.+?)\)(\s+where\s+(.+))?", query_string)
    if not head:
        return {}

    query_info = {"type": head.group(1), "column": head.group(2)}

    condition_str = head.group(4)
    if condition_str:
        # Condition shapes, tried in order; "between" must come first because
        # its text would otherwise partially match the simpler patterns.
        patterns = [
            ("between", r"(\d+)\s*<=\s*(.+?)\s*<=\s*(\d+)"),
            (">", r"(.+?)\s*>\s*(\d+)"),
            ("<", r"(.+?)\s*<\s*(\d+)"),
            ("==", r"(.+?)\s*==\s*(\d+)"),
        ]
        for op, pattern in patterns:
            m = re.match(pattern, condition_str)
            if not m:
                continue
            if op == "between":
                low, col, high = m.groups()
                query_info["condition"] = {
                    "column": col,
                    "operator": "between",
                    "low": int(low),
                    "high": int(high),
                }
            else:
                col, val = m.groups()
                query_info["condition"] = {
                    "column": col,
                    "operator": op,
                    "value": int(val),
                }
            break

    return query_info
||||
def solve(self, example: Dict) -> float:
    """Repair the example's buggy table, run its query, round the result."""
    payload = example['input']
    table = payload['table']
    bug_description = payload['bug_description']
    query = payload['query']

    # Repair the table, then normalize cell types for querying.
    repaired = self._convert_values(self.fix_table(table, bug_description))

    # Prefer the structured query info; fall back to parsing the string.
    if 'query_info' in example:
        query_info = example['query_info']
    else:
        query_info = self.parse_query_string(query)

    result = self.execute_query(repaired, query_info)

    # Round by query type: 2 dp for mean/stdev, 1 dp for sum/median.
    if result is not None:
        qtype = query_info.get('type')
        if qtype in ['mean', 'stdev']:
            result = round(result, 2)
        elif qtype in ['sum', 'median']:
            result = round(result, 1)

    return result
||||
|
||||
# Usage example
if __name__ == "__main__":
    # Load the previously generated test data.
    with open("generated_dataset.json", 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    solver = BBEHBuggyTablesSolver()

    # Solve every example and report its result.
    for i, example in enumerate(test_data['examples'], 1):
        result = solver.solve(example)
        print(f"示例 {i} 的计算结果: {result}")
|
@ -0,0 +1,216 @@
|
|||
import json
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Any, Union
|
||||
from internbootcamp.libs.bbeh_buggy_tables.bbeh_buggy_tables_solver import BBEHBuggyTablesSolver
|
||||
|
||||
|
||||
class BBEHBuggyTablesValidator:
    def __init__(self):
        """Initialize the validator with a fresh solver instance."""
        self.solver = BBEHBuggyTablesSolver()
|
||||
def _get_expected_result(self, example: Dict) -> float:
    """Compute the reference answer from the clean table, or None if absent.

    Applies the same rounding rules as the solver so the two results are
    directly comparable.
    """
    # Without a clean table and structured query there is no reference.
    if 'clean_table' not in example or 'query_info' not in example:
        return None

    clean_df = pd.DataFrame(example['clean_table'])

    # Normalize cell types: 'null' markers -> NaN, then coerce to numeric.
    for col in clean_df.columns:
        clean_df[col] = clean_df[col].replace('null', np.nan)
        try:
            clean_df[col] = pd.to_numeric(clean_df[col], errors='coerce')
        except (ValueError, TypeError):
            # Narrowed from a bare except: leave non-coercible columns as-is.
            pass

    expected_result = self.solver.execute_query(clean_df, example['query_info'])

    # Mirror the solver's rounding (2 dp mean/stdev, 1 dp sum/median).
    if expected_result is not None:
        qtype = example['query_info'].get('type')
        if qtype in ['mean', 'stdev']:
            expected_result = round(expected_result, 2)
        elif qtype in ['sum', 'median']:
            expected_result = round(expected_result, 1)

    return expected_result
||||
def validate_example(self, example: Dict) -> Dict:
    """Validate one example: solve it and compare against the reference."""
    actual_result = self.solver.solve(example)
    expected_result = self._get_expected_result(example)

    both_present = actual_result is not None and expected_result is not None
    is_correct = both_present and abs(actual_result - expected_result) < 1e-6

    report = {
        "input": example['input'],
        "expected_result": expected_result,
        "actual_result": actual_result,
        "is_correct": is_correct,
        "error_details": None,
    }

    # Failed examples get a diagnosis attached.
    if not is_correct:
        report["error_details"] = {
            "difference": abs(actual_result - expected_result) if both_present else None,
            "error_type": self._get_error_type(actual_result, expected_result),
        }

    return report
||||
def _get_error_type(self, actual: float, expected: float) -> str:
    """Classify how *actual* relates to *expected* (as a Chinese label)."""
    if actual is None:
        return "求解器未返回结果"
    if expected is None:
        return "无法确定期望结果"

    delta = actual - expected
    if abs(delta) < 1e-6:
        return "结果正确"
    return "计算结果过大" if delta > 0 else "计算结果过小"
||||
def validate_dataset(self, dataset: Dict) -> Dict:
    """Validate every example in *dataset* and build a summary report."""
    error_statistics = {
        "求解器未返回结果": 0,
        "无法确定期望结果": 0,
        "计算结果过大": 0,
        "计算结果过小": 0,
    }

    validation_results = []
    for example in dataset.get('examples', []):
        outcome = self.validate_example(example)
        validation_results.append(outcome)

        # Tally the error type of each failed example.
        if not outcome['is_correct'] and outcome['error_details']:
            kind = outcome['error_details']['error_type']
            error_statistics[kind] = error_statistics.get(kind, 0) + 1

    total = len(validation_results)
    correct = sum(1 for r in validation_results if r['is_correct'])

    return {
        "total_examples": total,
        "correct_examples": correct,
        "accuracy": correct / total if total > 0 else 0,
        "error_statistics": error_statistics,
        "detailed_results": validation_results,
    }
||||
def save_validation_report(self, validation_summary: Dict, file_path: str) -> None:
    """Serialize the validation summary to *file_path* as pretty JSON."""
    try:
        serialized = json.dumps(validation_summary, ensure_ascii=False, indent=2)
        with open(file_path, 'w', encoding='utf-8') as out:
            out.write(serialized)
        print(f"验证报告已保存至 {file_path}")
    except Exception as e:
        print(f"保存验证报告失败: {e}")
||||
def print_validation_summary(self, validation_summary: Dict) -> None:
    """Print a human-readable summary of the validation run."""
    header_lines = [
        "\n=== 验证摘要 ===",
        f"总示例数: {validation_summary['total_examples']}",
        f"正确示例数: {validation_summary['correct_examples']}",
        f"准确率: {validation_summary['accuracy']:.2%}",
    ]
    for line in header_lines:
        print(line)

    print("\n错误统计:")
    # Only error types that actually occurred are listed.
    for error_type, count in validation_summary['error_statistics'].items():
        if count > 0:
            print(f"- {error_type}: {count}次")
||||
def analyze_errors(self, validation_summary: Dict) -> Dict:
    """Group failed examples by bug type and collect per-type samples."""
    error_patterns = {}

    for outcome in validation_summary['detailed_results']:
        if outcome['is_correct']:
            continue

        bug_type = self._extract_bug_type(outcome['input']['bug_description'])

        # First failure of this bug type creates its bucket.
        bucket = error_patterns.setdefault(bug_type, {"count": 0, "examples": []})
        bucket["count"] += 1
        bucket["examples"].append({
            "expected": outcome['expected_result'],
            "actual": outcome['actual_result'],
            "error_details": outcome['error_details'],
        })

    return error_patterns
||||
def _extract_bug_type(self, bug_description: str) -> str:
    """Map a (Chinese) bug description back to its canonical bug-type name."""
    # Fragments mirror the generator's description strings; checked in order.
    fragment_map = [
        (("空值(null)被错误地移除",), "missing_null_values"),
        (("添加了一个随机值列", "添加了一行随机值"), "appended_random_values"),
        (("每两行被错误地合并",), "merged_rows"),
        (("行被旋转", "列被旋转"), "rotated_data"),
        (("值被错误地替换为'ERROR'",), "replaced_values"),
    ]
    for fragments, bug_type in fragment_map:
        if any(fragment in bug_description for fragment in fragments):
            return bug_type
    return "unknown_bug_type"
||||
|
||||
# Usage example
if __name__ == "__main__":
    # Load the generated test data.
    with open("generated_dataset.json", 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    validator = BBEHBuggyTablesValidator()

    # Validate the dataset and show the summary.
    summary = validator.validate_dataset(test_data)
    validator.print_validation_summary(summary)

    # Analyse recurring failure patterns, then persist the full report.
    patterns = validator.analyze_errors(summary)
    validator.save_validation_report(summary, "validation_report.json")

    print("\n=== 错误模式分析 ===")
    for bug_type, pattern in patterns.items():
        print(f"\n{bug_type}:")
        print(f"错误次数: {pattern['count']}")
        if pattern['examples']:
            print("示例错误:")
            # Show at most the first three sample failures per bug type.
            for i, sample in enumerate(pattern['examples'][:3], 1):
                print(f" {i}. 期望值: {sample['expected']}, 实际值: {sample['actual']}")
                print(f" 错误类型: {sample['error_details']['error_type']}")
|
||||
1
internbootcamp/libs/bbeh_buggy_tables/task.json
Normal file
1
internbootcamp/libs/bbeh_buggy_tables/task.json
Normal file
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue