mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-30 17:40:42 +00:00
feat(bootcamp): Add BBEHHyperbaton (verified)
This commit is contained in:
parent
2edd36bc97
commit
8678ab0e3b
9 changed files with 823 additions and 1 deletions
307
internbootcamp/libs/bbeh_hyperbaton/bbeh_hyperbaton_generator.py
Normal file
307
internbootcamp/libs/bbeh_hyperbaton/bbeh_hyperbaton_generator.py
Normal file
|
|
@ -0,0 +1,307 @@
|
|||
import json
|
||||
import random
|
||||
from typing import List, Dict, Any, Tuple, Set
|
||||
import logging
|
||||
|
||||
# Module-level logger for this generator module.
# Handler and level configuration is intentionally left to the importing
# application: calling logging.basicConfig() from library code would silently
# reconfigure the root logger of every program that imports this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
class HyperbatonGenerator:
    """Generate "hyperbaton" (adjective-order) multiple-choice tasks.

    A task shows example noun phrases whose adjectives follow
    ``standard_order`` and asks which of ten options (A-J) also follow it.
    All randomness comes from the global ``random`` module, so callers can
    make output reproducible with ``random.seed(...)`` before use.
    """

    def __init__(self):
        # Adjective categories, listed in the order they must appear in a
        # correctly ordered phrase.
        self.adjective_types = [
            "age",
            "quality",
            "size",
            "shape",
            "color",
            "material",
            "nationality",
            "activity",
        ]
        # The ordering every generated "correct" sentence follows.
        self.standard_order = self.adjective_types.copy()

        # Vocabulary for each adjective category.
        self.adjectives = {
            "color": ["red", "blue", "green", "yellow", "black", "white", "cyan", "magenta",
                      "violet", "brown", "gray", "pink", "crimson", "indigo", "maroon",
                      "teal", "aqua"],
            "size": ["big", "small", "tiny", "enormous", "huge", "massive", "large", "little",
                     "medium-size", "midsize", "normal-size"],
            "shape": ["square", "circular", "triangular", "rectangular", "spherical", "pyramidal",
                      "prismlike"],
            "age": ["old", "new", "ancient", "old-fashioned", "brand-new", "archaic"],
            "material": ["wood", "plastic", "steel", "iron", "glass", "paper", "cloth",
                         "cardboard", "rubber", "leather", "wool", "fiberglass"],
            "nationality": ["chinese", "american", "russian", "japanese", "vietnamese", "iranian",
                            "turkish", "mexican", "brazilian", "german", "filipino", "ethiopian",
                            "indian", "egyptian", "nigerian", "thai", "indonesian", "congolese",
                            "bangladeshi", "pakistani"],
            "activity": ["smoking", "hiking", "driving", "walking", "eating", "drinking",
                         "typing", "whittling", "snorkeling", "exercise"],
            "quality": ["good", "terrible", "lovely", "awful", "nice", "wonderful", "repulsive",
                        "obnoxious", "mysterious", "ridiculous", "silly"],
        }

        # Nouns that may end a phrase (some are two words). Each noun is
        # listed once so random.choice() picks them with equal probability.
        self.nouns = [
            "ball", "box", "bag", "chair", "table", "car", "book", "pen", "pencil", "bottle",
            "cup", "knife", "fork", "spoon", "plate", "flower pot", "hat", "shirt", "shoe",
            "guitar", "piano", "clock", "watch", "camera", "phone", "computer", "sofa", "bed",
            "lamp", "umbrella", "car key", "wallet", "necklace", "ring", "earring", "brush",
            "scissors", "candle", "vase", "banana", "apple", "dog", "bird", "hammer", "drill",
            "screwdriver", "plier", "stapler", "ruler", "calculator", "saw", "trash can",
            "fire extinguisher", "wrench", "bicycle", "speaker", "marker", "toolbox", "jar",
            "bowl", "sunglasses", "canvas", "key", "house", "tool",
        ]

        # A shuffled category order and some pairwise constraints derived from
        # it. No method of this class reads them; they are kept as public
        # attributes for callers that may.
        self.randomized_order = random.sample(self.adjective_types, len(self.adjective_types))
        self.partial_order_pairs = self.generate_partial_order_pairs()
        # The global RNG is deliberately NOT reseeded here: reseeding from
        # randint(1, 10000) would restrict it to 10,000 possible states.

    def generate_partial_order_pairs(self) -> Set[Tuple[str, str]]:
        """Return ``(earlier, later)`` category pairs taken from ``randomized_order``.

        Pairs are drawn at random until ``len(adjective_types) + 2`` distinct
        ones have been collected.
        """
        pairs = set()
        min_pairs = len(self.adjective_types) + 2
        positions = range(len(self.randomized_order))

        while len(pairs) < min_pairs:
            first, second = sorted(random.sample(positions, 2))
            pairs.add((self.randomized_order[first], self.randomized_order[second]))

        return pairs

    def _build_sentence(self, adjective_types: List[str], noun: str) -> str:
        """Pick one random adjective per category and append ``noun``."""
        words = [
            random.choice(self.adjectives[adj_type])
            for adj_type in adjective_types
            if self.adjectives[adj_type]  # skip categories with no vocabulary
        ]
        words.append(noun)
        return " ".join(words)

    def generate_example_sentence(self, noun: str, num_adjectives: int, use_partial: bool = False) -> str:
        """Return a correctly ordered phrase ``"<adjectives> <noun>"``.

        Args:
            noun: Noun placed at the end of the phrase.
            num_adjectives: Number of adjectives wanted; values larger than
                the number of categories are clamped to it.
            use_partial: If False, use the first ``num_adjectives`` categories
                of ``standard_order``. If True, use a random subset of
                categories, arranged in ``standard_order``.

        Raises:
            ValueError: If ``num_adjectives`` is smaller than 1.
        """
        if num_adjectives < 1:
            raise ValueError("num_adjectives must be at least 1")
        count = min(num_adjectives, len(self.standard_order))

        if use_partial:
            chosen_types = random.sample(self.standard_order, count)
            # random.sample returns the subset in random order; the phrase is
            # presented as correctly ordered, so restore the standard order.
            chosen_types.sort(key=self.standard_order.index)
        else:
            chosen_types = self.standard_order[:count]

        return self._build_sentence(chosen_types, noun)

    def generate_example_sentences(self, num_sentences: int = 100) -> List[str]:
        """Return shuffled example phrases numbered ``"(1) ..."``, ``"(2) ..."``.

        Thirty prefix-order and fifty subset-order phrases are always produced,
        so the result has ``max(num_sentences, 80)`` entries; any remainder is
        filled with shorter phrases of either kind.
        """
        min_full_order = 30
        min_partial_order = 50
        sentences = []

        # Phrases built from a prefix of the standard order.
        for _ in range(min_full_order):
            noun = random.choice(self.nouns)
            num_adjectives = random.randint(3, 5)
            sentences.append(self.generate_example_sentence(noun, num_adjectives, use_partial=False))

        # Phrases built from a random subset of categories.
        for _ in range(min_partial_order):
            noun = random.choice(self.nouns)
            num_adjectives = random.randint(2, 4)
            sentences.append(self.generate_example_sentence(noun, num_adjectives, use_partial=True))

        # Fill up to the requested size with short phrases of either kind.
        for _ in range(num_sentences - len(sentences)):
            noun = random.choice(self.nouns)
            num_adjectives = random.randint(1, 3)
            use_partial = random.choice([True, False])
            sentences.append(self.generate_example_sentence(noun, num_adjectives, use_partial))

        # Shuffle first and number afterwards so the labels read 1, 2, 3, ...
        random.shuffle(sentences)
        return [f"({number}) {sentence}" for number, sentence in enumerate(sentences, 1)]

    def generate_test_sentence(self, correct: bool = True) -> str:
        """Return a phrase with 3-7 adjectives for use as an answer option.

        Args:
            correct: If True the categories are a prefix of ``standard_order``.
                If False the categories are a random subset guaranteed NOT to
                be in ``standard_order``.
        """
        noun = random.choice(self.nouns)
        num_adjectives = random.randint(3, 7)

        if correct:
            selected_types = self.standard_order[:num_adjectives]
        else:
            selected_types = random.sample(self.standard_order, num_adjectives)
            rank = self.standard_order.index
            # A random sample can come out in increasing order by chance
            # (e.g. age, size, color). Such a phrase is valid, so reshuffle
            # until at least one pair of categories is out of order.
            while selected_types == sorted(selected_types, key=rank):
                random.shuffle(selected_types)

        return self._build_sentence(selected_types, noun)

    def _unique_test_sentences(self, count: int, correct: bool) -> List[str]:
        """Return ``count`` distinct test phrases, in generation order."""
        seen = set()
        collected = []
        while len(collected) < count:
            sentence = self.generate_test_sentence(correct=correct)
            if sentence not in seen:
                seen.add(sentence)
                collected.append(sentence)
        return collected

    def generate_options(self, num_options: int = 10, num_correct: int = None) -> Tuple[List[str], List[bool]]:
        """Return shuffled answer options and a parallel list of correctness flags.

        Args:
            num_options: Total number of options.
            num_correct: Number of correctly ordered options; when None a
                value in 3-6 is drawn (capped at ``num_options``).

        Raises:
            ValueError: If ``num_correct`` is not within ``0..num_options``.
        """
        if num_correct is None:
            num_correct = min(random.randint(3, 6), num_options)
        if not 0 <= num_correct <= num_options:
            raise ValueError("num_correct must be between 0 and num_options")

        # Lists (not sets) keep the pre-shuffle order independent of string
        # hashing, so a seeded RNG gives reproducible options.
        labelled = [(s, True) for s in self._unique_test_sentences(num_correct, True)]
        labelled += [(s, False) for s in self._unique_test_sentences(num_options - num_correct, False)]

        random.shuffle(labelled)
        options = [sentence for sentence, _ in labelled]
        is_correct = [flag for _, flag in labelled]
        return options, is_correct

    def format_options(self, options: List[str]) -> List[str]:
        """Prefix each option with its letter label: ``(A) ...``, ``(B) ...``."""
        return [f"({chr(65 + i)}) {option}" for i, option in enumerate(options)]

    def get_answer_string(self, is_correct: List[bool]) -> str:
        """Concatenate the letters of the correct options, or ``"K"`` if there are none."""
        letters = [chr(65 + i) for i, flag in enumerate(is_correct) if flag]
        return "".join(letters) if letters else "K"

    def generate_task(self) -> Dict[str, Any]:
        """Build one task as ``{"input": <prompt text>, "target": <answer letters>}``.

        Generation is retried until the answer does not contain ``"CEJ"`` and
        no option duplicates an example phrase.
        """
        while True:
            # Draw the number of examples once per attempt; re-drawing it on
            # every loop iteration would skew the count towards the low end.
            target_count = random.randint(50, 180)
            example_sentences = []
            seen_examples = set()
            while len(example_sentences) < target_count:
                sentence = self.generate_example_sentence(
                    random.choice(self.nouns),
                    random.randint(3, 5)
                )
                if sentence not in seen_examples:
                    seen_examples.add(sentence)
                    example_sentences.append(sentence)

            options, is_correct = self.generate_options(10)
            formatted_options = self.format_options(options)
            answer = self.get_answer_string(is_correct)

            # NOTE(review): answers containing "CEJ" are rejected; the reason
            # is not visible in this file - confirm before removing.
            if "CEJ" in answer:
                continue

            # Options are already distinct from each other, so only overlap
            # with the example phrases has to be ruled out.
            if not seen_examples.isdisjoint(options):
                continue

            # NOTE(review): examples are joined with single spaces and are not
            # numbered, so phrase boundaries are implicit in the prompt.
            examples_text = ' '.join(example_sentences)
            options_text = ' '.join(formatted_options)
            task_input = (
                "In a variant of English, we are given that the following sentences have correct adjective order:\n"
                f"{examples_text}\n\n"
                "In this variant of English, which of the following sentences (Options A-J) use the correct adjective order? "
                "If none of the sentences (Options A-J) use the correct adjective order, select option K. Select all that apply.\n"
                f"{options_text}\n"
                "(K) None of the above\n\n"
                "Provide your final answer as a concatenation of all the correct choices. For example, if B and C have correct adjective order, "
                "then your final answer must be \"BC\"."
            )

            return {
                "input": task_input,
                "target": answer
            }

    def validate_sentence(self, sentence: str) -> bool:
        """Check that ``sentence`` is one or more known adjectives followed by a known noun.

        Only vocabulary is checked, not adjective order. Multi-word nouns such
        as "flower pot" are matched against the end of the sentence.
        """
        words = sentence.split()
        if len(words) < 2:  # need at least one adjective and a noun
            return False

        known_adjectives = {adj for group in self.adjectives.values() for adj in group}
        for noun in set(self.nouns):
            noun_words = noun.split()
            size = len(noun_words)
            # The noun must match the tail and leave at least one word before it.
            if size >= len(words) or words[-size:] != noun_words:
                continue
            if all(word in known_adjectives for word in words[:-size]):
                return True
        return False

    def generate_task_batch(self, batch_size: int = 1) -> List[Dict[str, Any]]:
        """Return ``batch_size`` freshly generated tasks (nothing is written to disk)."""
        return [self.generate_task() for _ in range(batch_size)]
|
||||
if __name__ == "__main__":
    generator = HyperbatonGenerator()

    # Smoke test: generate a single example phrase.
    test_sentence = generator.generate_example_sentence("ball", 3, False)
    print(f"Test sentence: {test_sentence}")

    # Generate the full dataset and save it under an "examples" key.
    # (The class has no save_dataset() method, so the file is written here.)
    dataset = {"examples": generator.generate_task_batch(100)}
    try:
        with open("hyperbaton_dataset.json", "w", encoding="utf-8") as f:
            json.dump(dataset, f, ensure_ascii=False, indent=2)
    except OSError as e:
        # Only file-system failures are reported here; programming errors
        # propagate with a full traceback instead of being swallowed.
        print(f"Error: {str(e)}")
    else:
        print("数据集已生成并保存到 hyperbaton_dataset.json")
|
||||
|
||||
218
internbootcamp/libs/bbeh_hyperbaton/bbeh_hyperbaton_solver.py
Normal file
218
internbootcamp/libs/bbeh_hyperbaton/bbeh_hyperbaton_solver.py
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
import json
|
||||
from typing import List, Dict, Set, Tuple, Any
|
||||
import re
|
||||
|
||||
|
||||
class HyperbatonSolver:
    """Solve adjective-order multiple-choice puzzles against a fixed category order.

    An option is accepted when the categories of its recognised adjectives
    appear in strictly increasing ``adjective_types`` order.
    """

    # Matches "(X) text" for labels A-K, where text runs up to the next label
    # or the end of the input. DOTALL lets an option span line breaks.
    _OPTION_PATTERN = re.compile(r'\(([A-K])\) (.*?)(?=\([A-K]\)|$)', re.DOTALL)
    # Markers that delimit the example section of a puzzle prompt.
    _EXAMPLES_START = "correct adjective order:\n"
    _EXAMPLES_END = "\n\nIn this variant"

    def __init__(self):
        # Adjective categories in the order they must appear.
        self.adjective_types = [
            "age",
            "quality",
            "size",
            "shape",
            "color",
            "material",
            "nationality",
            "activity",
        ]

        # Vocabulary of every category.
        self.adjectives = {
            "age": {
                "old", "new", "ancient", "old-fashioned", "brand-new", "archaic"
            },
            "quality": {
                "good", "terrible", "lovely", "awful", "nice", "wonderful",
                "repulsive", "obnoxious", "mysterious", "ridiculous", "silly"
            },
            "size": {
                "big", "small", "tiny", "enormous", "huge", "massive", "large",
                "little", "medium-size", "midsize", "normal-size"
            },
            "shape": {
                "square", "circular", "triangular", "rectangular", "spherical",
                "pyramidal", "prismlike"
            },
            "color": {
                "red", "blue", "green", "yellow", "black", "white", "cyan",
                "magenta", "violet", "brown", "gray", "pink", "crimson",
                "indigo", "maroon", "teal", "aqua"
            },
            "material": {
                "wood", "plastic", "steel", "iron", "glass", "paper", "cloth",
                "cardboard", "rubber", "leather", "wool", "fiberglass"
            },
            "nationality": {
                "chinese", "american", "russian", "japanese", "vietnamese",
                "iranian", "turkish", "mexican", "brazilian", "german",
                "filipino", "ethiopian", "indian", "egyptian", "nigerian",
                "thai", "indonesian", "congolese", "bangladeshi", "pakistani"
            },
            "activity": {
                "smoking", "hiking", "driving", "walking", "eating", "drinking",
                "typing", "whittling", "snorkeling", "exercise"
            }
        }

        # Reverse lookup: adjective -> category name.
        self.word_to_type = {
            word: type_name
            for type_name, words in self.adjectives.items()
            for word in words
        }

    def extract_examples(self, text: str) -> List[str]:
        """Return the non-empty lines of the example section of ``text``.

        The section starts after ``"correct adjective order:\\n"`` and ends
        before ``"\\n\\nIn this variant"``. An empty list is returned when the
        start marker is missing; a missing end marker means "to end of text".
        """
        marker_pos = text.find(self._EXAMPLES_START)
        if marker_pos == -1:
            return []
        start = marker_pos + len(self._EXAMPLES_START)
        end = text.find(self._EXAMPLES_END, start)
        if end == -1:
            end = len(text)
        section = text[start:end].strip()
        return [line.strip() for line in section.split('\n') if line.strip()]

    def extract_options(self, text: str) -> List[Tuple[str, str]]:
        """Return ``(label, sentence)`` pairs for every ``(A)``..``(K)`` option in ``text``."""
        return [
            (match.group(1), match.group(2).strip())
            for match in self._OPTION_PATTERN.finditer(text)
        ]

    def get_adjective_sequence(self, words: List[str]) -> List[str]:
        """Return the categories of the recognised adjectives in ``words``.

        The last word is treated as the noun and ignored; words that are not
        in the vocabulary are skipped.
        """
        sequence = []
        for word in words[:-1]:
            word_type = self.word_to_type.get(word.lower())
            if word_type:
                sequence.append(word_type)
        return sequence

    def is_valid_adjective_order(self, words: List[str]) -> bool:
        """Return True if the adjectives in ``words`` follow ``adjective_types`` order.

        A phrase with no recognised adjective, or with a repeated category,
        is invalid.
        """
        sequence = self.get_adjective_sequence(words)
        if not sequence:
            return False

        rank = {adj_type: index for index, adj_type in enumerate(self.adjective_types)}
        previous = -1
        for adj_type in sequence:
            current = rank.get(adj_type)
            # Unknown category, or not strictly after the previous one.
            if current is None or current <= previous:
                return False
            previous = current
        return True

    def parse_puzzle(self, puzzle: str) -> Tuple[List[str], List[Tuple[str, str]]]:
        """Split a puzzle prompt into its example lines and its labelled options."""
        return self.extract_examples(puzzle), self.extract_options(puzzle)

    def learn_from_examples(self, examples: List[str]) -> List[List[str]]:
        """Return the category sequence of every example with a recognised adjective."""
        sequences = []
        for example in examples:
            sequence = self.get_adjective_sequence(example.split())
            if sequence:
                sequences.append(sequence)
        return sequences

    def get_word_type(self, word: str) -> str:
        """Return the category of ``word`` (case-insensitive), or None if unknown."""
        return self.word_to_type.get(word.lower())

    def solve_puzzle(self, puzzle: str) -> str:
        """Return the sorted letters of all correctly ordered options, or ``"K"`` if none.

        The result is never altered after validation: an answer such as
        ``"CEJ"`` is returned as-is instead of having a valid option removed.
        """
        _examples, options = self.parse_puzzle(puzzle)

        valid_options = [
            label
            for label, sentence in options
            # "(K) None of the above" is the fallback, not a candidate phrase.
            if label != 'K' and self.is_valid_adjective_order(sentence.split())
        ]

        if not valid_options:
            return "K"
        return ''.join(sorted(valid_options))

    def solve_batch(self, tasks: List[Dict[str, Any]]) -> List[str]:
        """Solve every task's ``'input'`` prompt and return the answers in order."""
        return [self.solve_puzzle(task['input']) for task in tasks]

    def solve_dataset(self, dataset_file: str) -> List[Tuple[str, str]]:
        """Solve a JSON dataset file and return ``(solved, expected)`` pairs.

        The file must hold an object whose ``"examples"`` list contains
        entries with ``"input"`` and ``"target"`` keys.
        """
        with open(dataset_file, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        return [
            (self.solve_puzzle(example['input']), example['target'])
            for example in dataset['examples']
        ]
|
||||
|
||||
|
||||
def debug_adjective_order(solver: HyperbatonSolver, sentence: str):
    """Debug helper: pair each word before the final one with its adjective category.

    Returns a list of ``(word, category)`` tuples in sentence order. The
    category is looked up case-insensitively in ``solver.word_to_type`` and is
    None for words that are not in that mapping. The last word of the sentence
    is left out.
    """
    tokens = sentence.split()
    lookup = solver.word_to_type.get
    return [(token, lookup(token.lower())) for token in tokens[:-1]]
|
||||
|
||||
|
||||
def main():
    """Solve every puzzle in ``hyperbaton_dataset.json`` and print a per-puzzle report.

    The dataset is read once; it must be a JSON object whose ``"examples"``
    list holds entries with ``"input"`` and ``"target"`` keys. For each wrong
    answer the options that should have been accepted are printed together
    with their adjective categories.
    """
    solver = HyperbatonSolver()
    with open('hyperbaton_dataset.json', 'r', encoding='utf-8') as f:
        examples = json.load(f)['examples']

    print("求解结果:")
    correct = 0
    total = len(examples)

    for i, example in enumerate(examples, 1):
        expected = example['target']
        solved = solver.solve_puzzle(example['input'])
        is_correct = solved == expected
        status = "✓" if is_correct else "✗"
        if is_correct:
            correct += 1
        print(f"谜题 {i}: 预期答案={expected}, 求解结果={solved} {status}")

        # On a mismatch, show the options that were expected to be accepted.
        if not is_correct:
            print("调试信息:")
            for label, sentence in solver.extract_options(example['input']):
                if label in expected:
                    print(f"应该正确的选项 {label}: {sentence}")
                    print("形容词序列:", debug_adjective_order(solver, sentence))

    # Guard against an empty dataset, which would otherwise divide by zero.
    accuracy = (correct / total) * 100 if total else 0.0
    print(f"\n准确率: {accuracy:.2f}% ({correct}/{total})")


if __name__ == "__main__":
    main()
|
||||
158
internbootcamp/libs/bbeh_hyperbaton/bbeh_hyperbaton_validor.py
Normal file
158
internbootcamp/libs/bbeh_hyperbaton/bbeh_hyperbaton_validor.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
import json
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
|
||||
class HyperbatonValidator:
    """Score predicted answers for hyperbaton tasks against their targets.

    ``validate_batch`` works on in-memory tasks. The file- and index-based
    methods work on ``self.dataset``, a dict whose ``"examples"`` list holds
    entries with a ``"target"`` key.
    """

    # The only characters kept from a raw prediction.
    _ANSWER_LETTERS = "ABCDEFGHIJK"

    def __init__(self, dataset_file: Optional[str] = None):
        """Create a validator, optionally loading a JSON dataset file.

        Args:
            dataset_file: Path to a JSON file with an ``"examples"`` list.
                When omitted, the dataset starts empty and only
                ``validate_batch`` / ``generate_report`` are useful.
        """
        self.dataset: Dict[str, Any] = {"examples": []}
        if dataset_file is not None:
            with open(dataset_file, 'r', encoding='utf-8') as f:
                self.dataset = json.load(f)

    @classmethod
    def _clean_prediction(cls, prediction: Any) -> str:
        """Keep only the characters A-K of ``prediction``, in their original order."""
        if prediction is None:
            return ""
        return ''.join(c for c in str(prediction) if c in cls._ANSWER_LETTERS)

    def _score(self, predictions: List[Any], targets: List[str], id_key: str) -> Dict[str, Any]:
        """Compare cleaned predictions with targets and build the summary dict."""
        details = []
        for index, (prediction, expected) in enumerate(zip(predictions, targets)):
            cleaned = self._clean_prediction(prediction)
            details.append({
                id_key: index,
                "prediction": cleaned,
                "target": expected,
                "is_correct": cleaned == expected
            })

        correct_count = sum(1 for entry in details if entry["is_correct"])
        total_count = len(details)
        return {
            # An empty input scores 0 instead of dividing by zero.
            "accuracy": correct_count / total_count if total_count > 0 else 0.0,
            "correct_count": correct_count,
            "total_count": total_count,
            "detailed_results": details
        }

    def validate_batch(self, tasks: List[Dict[str, Any]], predictions: List[str]) -> Dict[str, Any]:
        """Score ``predictions`` against the ``'target'`` of each task.

        Returns a dict with ``accuracy``, ``correct_count``, ``total_count``
        and ``detailed_results`` (one entry per task, keyed by ``task_id``).

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(predictions) != len(tasks):
            raise ValueError("预测结果数量与任务数量不匹配")
        return self._score(predictions, [task['target'] for task in tasks], "task_id")

    def validate_predictions(self, predictions_file: str) -> Dict[str, Any]:
        """Score a JSON list of predictions against the loaded dataset.

        Raises:
            ValueError: If the file does not contain a list, or its length
                differs from the number of dataset examples.
        """
        with open(predictions_file, 'r', encoding='utf-8') as f:
            predictions = json.load(f)

        if not isinstance(predictions, list):
            raise ValueError("预测结果必须是列表格式")

        examples = self.dataset["examples"]
        if len(predictions) != len(examples):
            raise ValueError(f"预测结果数量({len(predictions)})与数据集样例数量({len(self.dataset['examples'])})不匹配")

        return self._score(predictions, [example["target"] for example in examples], "example_id")

    def validate_single_prediction(self, example_id: int, prediction: str) -> Dict[str, Any]:
        """Score one prediction against dataset example ``example_id``.

        Raises:
            ValueError: If ``example_id`` is outside the dataset.
        """
        if example_id < 0 or example_id >= len(self.dataset["examples"]):
            raise ValueError(f"例子ID {example_id} 超出范围 [0, {len(self.dataset['examples']) - 1}]")

        expected = self.dataset["examples"][example_id]["target"]
        cleaned_pred = self._clean_prediction(prediction)

        return {
            "example_id": example_id,
            "prediction": cleaned_pred,
            "target": expected,
            "is_correct": cleaned_pred == expected
        }

    def get_example(self, example_id: int) -> Optional[Dict[str, str]]:
        """Return dataset example ``example_id``, or None if the index is out of range."""
        if example_id < 0 or example_id >= len(self.dataset["examples"]):
            return None
        return self.dataset["examples"][example_id]

    def generate_report(self, results: Dict[str, Any], output_file: Optional[str] = None):
        """Render ``results`` as a Markdown report and return the text.

        Accepts the output of both ``validate_predictions`` (rows keyed by
        ``example_id``) and ``validate_batch`` (rows keyed by ``task_id``).
        The report is also written to ``output_file`` when one is given.
        """
        report = [
            "# BBEH Hyperbaton任务验证报告",
            "",
            f"总体准确率: {results['accuracy']:.2%} ({results['correct_count']}/{results['total_count']})",
            "",
            "## 详细结果",
            "",
            "| 例子ID | 预测 | 目标 | 正确? |",
            "| ------ | ---- | ---- | ------ |",
        ]

        for result in results["detailed_results"]:
            correct_mark = "✓" if result["is_correct"] else "✗"
            # validate_batch labels rows with "task_id" rather than "example_id".
            row_id = result["example_id"] if "example_id" in result else result.get("task_id", "")
            report.append(f"| {row_id} | {result['prediction']} | {result['target']} | {correct_mark} |")

        report_text = "\n".join(report)

        if output_file:
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(report_text)

        return report_text
|
||||
|
||||
|
||||
# Usage example
if __name__ == "__main__":
    validator = HyperbatonValidator()

    # Attach the dataset explicitly so the demo does not depend on the
    # constructor accepting a file path.
    try:
        with open("hyperbaton_dataset.json", "r", encoding="utf-8") as f:
            validator.dataset = json.load(f)
    except FileNotFoundError:
        print("数据集文件不存在,请先生成数据集")
        raise SystemExit(1)

    # Score a predictions file, if one exists.
    try:
        results = validator.validate_predictions("predictions.json")
        report = validator.generate_report(results, "validation_report.md")
        print("验证报告已生成并保存到 validation_report.md")
    except FileNotFoundError:
        print("预测文件不存在,请先生成预测结果")

    # As a demonstration, score a single prediction (needs at least one example).
    if validator.dataset["examples"]:
        single_result = validator.validate_single_prediction(0, "E")
        print(f"单个验证结果: {single_result}")
|
||||
Loading…
Add table
Add a link
Reference in a new issue