基于 BERT 的中文问答系统（45）
确保 search_360_baike 函数能够从 360 百科页面中提取描述信息，并将其显示在 Text 组件中。
import os
import json
import logging
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
from difflib import SequenceMatcher
from datetime import datetime
from urllib.parse import quote

import jsonlines
import requests
import torch
import torch.optim as optim
from bs4 import BeautifulSoup
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
# Project root: the directory containing this script.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
# Log directory under the project root, created at import time.
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)


def setup_logging():
    """Configure the root logger to write INFO+ records to a timestamped file and stderr.

    The file lives in LOGS_DIR and is named ``YYYY-mm-dd_HH-MM-SS_羲和.txt``.
    It is written as UTF-8 so Chinese log messages do not depend on the
    platform's locale encoding. ``delay=True`` defers creating the file until
    the first record is emitted, so no empty file is left behind when
    ``logging.basicConfig`` is a no-op (root logger already has handlers).
    """
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8', delay=True),
            logging.StreamHandler(),
        ],
    )


setup_logging()
# 数据集类
class XihuaDataset(Dataset):
def __init__(self, file_path, tokenizer, max_length=128):
self.tokenizer = tokenizer
self.max_length = max_length
self.data = self.load_data(file_path)
def load_data(self, file_path):
data = []
if file_path.endswith('.jsonl'):
with jsonlines.open(file_path) as reader:
for i, item in enumerate(reader):
try:
data.append(item)
except jsonlines.jsonlines.InvalidLineError as e:
logging.warning(f"跳过无效行 {i + 1}: {e}")
elif file_path.endswith('.json'):
with open(file_path, 'r') as f:
try:
data = json.load(f)
except json.JSONDecodeError as e:
logging.warning(f"跳过无效文件 {file_path}: {e}")
return data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
question = item['question']
human_answer = item['human_answers'][0]
chatgpt_answer = item['chatgpt_answers'][0]
try:
inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
except Exception as e:
logging.warning(f"跳过无效项 {idx}: {e}")
return self.__getitem__((idx + 1) % len(self.data))
return {
'input_ids': inputs['input_ids'].squeeze(),
'attention_mask': inputs['attention_mask'].squeeze(),
'human_input_ids': human_inputs['input_ids'].squeeze(),
'human_attention_mask': human_inputs['attention_mask'].squeeze(),
'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
'human_answer': human_answer,
'chatgpt_answer': chatgpt_answer
}
# Data loader factory
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    """Return a shuffling DataLoader over an XihuaDataset built from *file_path*."""
    return DataLoader(
        XihuaDataset(file_path, tokenizer, max_length),
        batch_size=batch_size,
        shuffle=True,
    )
# Model definition
class XihuaModel(torch.nn.Module):
    """BERT encoder followed by a single-output linear head.

    ``forward`` passes BERT's pooled output through the linear layer and
    returns raw logits (no sigmoid). The training code pairs them with
    ``BCEWithLogitsLoss``.
    """

    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        hidden_size = self.bert.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one logit per input sequence."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(encoded.pooler_output)
# Training loop
def train(model, data_loader, optimizer, criterion, device, progress_var=None):
    """Run one training epoch and return the mean loss of the batches that succeeded.

    For every batch, the human answers are labelled 1 and the ChatGPT answers
    0. The batch loss is ``criterion`` on the human logits plus ``criterion``
    on the ChatGPT logits. The question tensors in the batch are not used.

    Args:
        model: callable as ``model(input_ids, attention_mask)`` returning logits.
        data_loader: iterable of batch dicts with ``human_input_ids``,
            ``human_attention_mask``, ``chatgpt_input_ids`` and
            ``chatgpt_attention_mask``; must support ``len()``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``(logits, targets)``, e.g. BCEWithLogitsLoss.
        device: device the tensors are moved to.
        progress_var: optional object with ``set(percent)`` (e.g. ``tk.DoubleVar``).
            It is updated after every batch, including skipped ones.

    Returns:
        Mean loss over successfully processed batches; 0.0 if none succeeded.
    """
    model.train()
    total_loss = 0.0
    completed = 0
    num_batches = len(data_loader)
    for batch_idx, batch in enumerate(data_loader):
        try:
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)
            optimizer.zero_grad()
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)
            loss = (criterion(human_logits, torch.ones_like(human_logits))
                    + criterion(chatgpt_logits, torch.zeros_like(chatgpt_logits)))
            loss.backward()
            optimizer.step()
        except Exception as e:
            # Best-effort: a bad batch is logged and skipped so one malformed
            # sample does not abort the whole epoch.
            logging.warning(f"跳过无效批次 {batch_idx + 1}: {e}")
        else:
            total_loss += loss.item()
            completed += 1
        if progress_var is not None:
            progress_var.set((batch_idx + 1) / num_batches * 100)
    if completed == 0:
        logging.warning("本轮没有成功训练的批次")
        return 0.0
    # Average only over batches that contributed a loss; skipped batches
    # would otherwise drag the reported loss down.
    return total_loss / completed
# Command-line (non-GUI) training entry point
def main_train(retrain=False, num_epochs=30, pretrained_model_name='F:/models/bert-base-chinese'):
    """Train XihuaModel on data/train_data.jsonl and save the weights to models/xihua_model.pth.

    Args:
        retrain: if True and a saved model exists, continue training from its
            weights; otherwise start from the pretrained BERT checkpoint.
        num_epochs: number of passes over the training data.
        pretrained_model_name: path or name of the BERT checkpoint used for
            both the tokenizer and the encoder.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'使用设备: {device}')
    tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)
    model = XihuaModel(pretrained_model_name=pretrained_model_name).to(device)
    model_path = os.path.join(PROJECT_ROOT, 'models', 'xihua_model.pth')
    if retrain:
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")
    # torch.save does not create parent directories; create models/ up front
    # so a fresh checkout does not lose the whole run at the final save.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    criterion = torch.nn.BCEWithLogitsLoss()
    train_data_loader = get_data_loader(
        os.path.join(PROJECT_ROOT, 'data', 'train_data.jsonl'), tokenizer, batch_size=8, max_length=128
    )
    for epoch in range(num_epochs):
        train_loss = train(model, train_data_loader, optimizer, criterion, device)
        logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.10f}')
    torch.save(model.state_dict(), model_path)
    logging.info("模型训练完成并保存")
# Web search helpers
def search_baidu(query):
    """Return the first result abstract of a Baidu web search for *query*.

    Returns "没有找到相关信息" for an empty query, a failed request (logged),
    or a results page with no abstract.
    """
    if not query or not query.strip():
        return "没有找到相关信息"
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }
    try:
        # params= lets requests percent-encode the query; interpolating it
        # into the URL broke on characters such as '&', '#' or spaces.
        response = requests.get('https://www.baidu.com/s', params={'wd': query},
                                headers=headers, timeout=10)
        response.raise_for_status()
    except requests.RequestException as e:
        logging.warning(f"百度搜索请求失败: {e}")
        return "没有找到相关信息"
    # Parse raw bytes so BeautifulSoup detects the page charset; response.text
    # falls back to ISO-8859-1 when the header omits a charset, garbling Chinese.
    soup = BeautifulSoup(response.content, 'html.parser')
    # NOTE(review): Baidu's result markup changes often; confirm the
    # 'c-abstract' class is still present on current result pages.
    result = soup.find('div', class_='c-abstract')
    if result:
        return result.get_text().strip()
    return "没有找到相关信息"
# Baidu Baike lookup
def search_baidu_baike(query):
    """Return the meta description of the Baidu Baike entry for *query*.

    Returns "没有找到相关信息" for an empty query, a failed request (logged),
    or a page without a non-empty description.
    """
    if not query or not query.strip():
        return "没有找到相关信息"
    # Encode the query as a single path segment so '/', '?', '#' etc. cannot
    # change the URL structure.
    url = f"https://baike.baidu.com/item/{quote(query.strip(), safe='')}"
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }
    try:
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
    except requests.RequestException as e:
        logging.warning(f"百度百科请求失败: {e}")
        return "没有找到相关信息"
    # Parse raw bytes so BeautifulSoup detects the page charset itself.
    soup = BeautifulSoup(response.content, 'html.parser')
    meta_description = soup.find('meta', attrs={'name': 'description'})
    # .get avoids a KeyError when the tag exists without a content attribute.
    content = (meta_description.get('content') or '').strip() if meta_description else ''
    return content or "没有找到相关信息"
# 360 Baike lookup
def search_360_baike(query):
    """Return a short description of *query* from 360 Baike (baike.so.com).

    Tries the page's meta description first. If that is missing or empty, it
    falls back to the first three non-empty paragraphs inside
    ``div.lemma-main-content``, joined by newlines. Returns "没有找到相关信息"
    for an empty query, a failed request (logged), or a page with neither.
    """
    if not query or not query.strip():
        return "没有找到相关信息"
    # NOTE(review): 360 Baike document pages are usually addressed by numeric
    # id (/doc/ID.html); confirm a keyword in this path resolves to the entry.
    url = f"https://baike.so.com/doc/{quote(query.strip(), safe='')}.html"
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }
    try:
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
    except requests.RequestException as e:
        logging.warning(f"360百科请求失败: {e}")
        return "没有找到相关信息"
    # Parse raw bytes so BeautifulSoup detects the page charset; response.text
    # may decode Chinese as ISO-8859-1 when the header has no charset.
    soup = BeautifulSoup(response.content, 'html.parser')
    meta_description = soup.find('meta', attrs={'name': 'description'})
    if meta_description:
        content = (meta_description.get('content') or '').strip()
        if content:
            return content
    # Fallback: body text of the entry.
    main_content = soup.find('div', class_='lemma-main-content')
    if main_content:
        texts = [p.get_text().strip() for p in main_content.find_all('p')]
        texts = [t for t in texts if t][:3]
        if texts:
            return '\n'.join(texts)
    return "没有找到相关信息"
# GUI
class XihuaChatbotGUI:
    """Tkinter front end for the Xihua chatbot.

    It classifies each question with XihuaModel and answers with the closest
    training question (fuzzy string match) from data/train_data.jsonl. The
    user can rate answers, fetch Baidu/360 Baike references, view and save
    the history, and (re)train the model from a chosen data file.
    """

    # Display names for the reference sources accepted by get_reference_answer.
    _SOURCE_LABELS = {'baidu_baike': '百度百科', '360_baike': '360百科'}

    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")
        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)
        self.load_model()
        self.model.eval()
        # Training records kept in memory for fuzzy answer lookup.
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))
        # One dict per answered question (keys are set in get_answer).
        self.history = []
        self.create_widgets()

    @staticmethod
    def _model_path():
        """Return the path trained weights are loaded from and saved to."""
        return os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')

    def create_widgets(self):
        """Build the question bar, chat transcript, and training/history controls."""
        style = ttk.Style()
        style.theme_use('clam')

        # Top row: question entry and "get answer" button.
        top_frame = ttk.Frame(self.root)
        top_frame.pack(pady=10)
        self.question_label = ttk.Label(top_frame, text="问题:", font=("Arial", 12))
        self.question_label.grid(row=0, column=0, padx=10)
        self.question_entry = ttk.Entry(top_frame, width=50, font=("Arial", 12))
        self.question_entry.grid(row=0, column=1, padx=10)
        self.answer_button = ttk.Button(top_frame, text="获取回答", command=self.get_answer, style='TButton')
        self.answer_button.grid(row=0, column=2, padx=10)

        # Middle: chat transcript (user lines right/blue, bot lines left/green).
        middle_frame = ttk.Frame(self.root)
        middle_frame.pack(pady=10)
        self.chat_text = tk.Text(middle_frame, height=20, width=100, font=("Arial", 12), wrap='word')
        self.chat_text.grid(row=0, column=0, padx=10, pady=10)
        self.chat_text.tag_configure("user", justify='right', foreground='blue')
        self.chat_text.tag_configure("xihua", justify='left', foreground='green')

        # Bottom: rating, training, progress, log and history controls.
        bottom_frame = ttk.Frame(self.root)
        bottom_frame.pack(pady=10)
        self.correct_button = ttk.Button(bottom_frame, text="准确", command=self.mark_correct, style='TButton')
        self.correct_button.grid(row=0, column=0, padx=10)
        self.incorrect_button = ttk.Button(bottom_frame, text="不准确", command=self.mark_incorrect, style='TButton')
        self.incorrect_button.grid(row=0, column=1, padx=10)
        self.train_button = ttk.Button(bottom_frame, text="训练模型", command=self.train_model, style='TButton')
        self.train_button.grid(row=0, column=2, padx=10)
        self.retrain_button = ttk.Button(bottom_frame, text="重新训练模型", command=lambda: self.train_model(retrain=True), style='TButton')
        self.retrain_button.grid(row=0, column=3, padx=10)
        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(bottom_frame, variable=self.progress_var, maximum=100, length=200, mode='determinate')
        self.progress_bar.grid(row=1, column=0, columnspan=4, pady=10)
        self.log_text = tk.Text(bottom_frame, height=10, width=70, font=("Arial", 12))
        self.log_text.grid(row=2, column=0, columnspan=4, pady=10)
        self.evaluate_button = ttk.Button(bottom_frame, text="评估模型", command=self.evaluate_model, style='TButton')
        self.evaluate_button.grid(row=3, column=0, padx=10, pady=10)
        self.history_button = ttk.Button(bottom_frame, text="查看历史记录", command=self.view_history, style='TButton')
        self.history_button.grid(row=3, column=1, padx=10, pady=10)
        self.save_history_button = ttk.Button(bottom_frame, text="保存历史记录", command=self.save_history, style='TButton')
        self.save_history_button.grid(row=3, column=2, padx=10, pady=10)

    def get_answer(self):
        """Answer the question in the entry box and record it in the history."""
        question = self.question_entry.get().strip()
        if not question:
            messagebox.showwarning("输入错误", "请输入问题")
            return
        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)
        # NOTE(review): the model is trained on answer texts (human=1,
        # ChatGPT=0) but is applied to the question here; confirm intended.
        answer_type = "羲和回答" if logits.item() > 0 else "零回答"
        specific_answer = self.get_specific_answer(question, answer_type)
        self.chat_text.insert(tk.END, f"用户: {question}\n", "user")
        self.chat_text.insert(tk.END, f"羲和: {specific_answer}\n", "xihua")
        self.chat_text.see(tk.END)
        self.history.append({
            'question': question,
            'answer_type': answer_type,
            'specific_answer': specific_answer,
            'accuracy': None,      # not yet rated
            'baidu_baike': None,   # no Baidu Baike reference fetched yet
            '360_baike': None,     # no 360 Baike reference fetched yet
        })

    def get_specific_answer(self, question, answer_type):
        """Return an answer from the training record whose question best matches *question*.

        Matching uses difflib's SequenceMatcher ratio. Records that are not
        dicts with a string 'question' are ignored. "羲和回答" selects the first
        human answer; any other type selects the first ChatGPT answer. The
        fallback text is returned when nothing matches or the chosen list is
        missing or empty.
        """
        best_match = None
        best_ratio = 0.0
        for item in self.data:
            candidate = item.get('question') if isinstance(item, dict) else None
            if not isinstance(candidate, str):
                continue
            ratio = SequenceMatcher(None, question, candidate).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item
        if best_match is not None:
            key = 'human_answers' if answer_type == "羲和回答" else 'chatgpt_answers'
            answers = best_match.get(key) or []
            if answers:
                return answers[0]
        return "这个我也不清楚,你问问零吧"

    def load_data(self, file_path):
        """Load records from a UTF-8 ``.jsonl`` or ``.json`` file.

        Invalid ``.jsonl`` lines and blank lines are skipped. A missing file,
        an undecodable ``.json`` file, or an unsupported extension yields an
        empty list (with a logged warning) instead of raising.
        """
        data = []
        if not file_path.endswith(('.jsonl', '.json')):
            logging.warning(f"不支持的数据文件格式: {file_path}")
            return data
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                if file_path.endswith('.jsonl'):
                    for line_no, line in enumerate(f, start=1):
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            data.append(json.loads(line))
                        except json.JSONDecodeError as e:
                            logging.warning(f"跳过无效行 {line_no}: {e}")
                else:
                    try:
                        data = json.load(f)
                    except json.JSONDecodeError as e:
                        logging.warning(f"跳过无效文件 {file_path}: {e}")
        except FileNotFoundError:
            # Let the GUI start even without training data; answers then use the fallback text.
            logging.warning(f"数据文件不存在: {file_path}")
        return data

    def load_model(self):
        """Load saved weights into self.model if the model file exists."""
        model_path = self._model_path()
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    def _append_log(self, message):
        """Append one line to the log panel and scroll it into view."""
        self.log_text.insert(tk.END, f"{message}\n")
        self.log_text.see(tk.END)

    def train_model(self, retrain=False):
        """Train the model on a user-chosen data file for 30 epochs and save it.

        With ``retrain=True`` the saved weights are reloaded first (when they
        exist). Training runs on the Tk thread. The model is always put back
        into eval mode afterwards, even if training fails.
        """
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if not file_path:
            messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")
            return
        try:
            dataset = XihuaDataset(file_path, self.tokenizer)
            if len(dataset) == 0:
                raise ValueError("数据文件中没有有效的样本")
            data_loader = DataLoader(dataset, batch_size=8, shuffle=True)
            if retrain:
                # Same policy as main_train: fall back to current weights if nothing was saved.
                self.load_model()
            self.model.to(self.device)
            self.model.train()
            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            num_epochs = 30
            for epoch in range(num_epochs):
                train_loss = train(self.model, data_loader, optimizer, criterion, self.device, self.progress_var)
                message = f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.10f}'
                logging.info(message)
                self._append_log(message)
                # Training blocks the event loop; repaint so progress is visible.
                self.root.update_idletasks()
            model_path = self._model_path()
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            torch.save(self.model.state_dict(), model_path)
            logging.info("模型训练完成并保存")
            self._append_log("模型训练完成并保存")
            messagebox.showinfo("训练完成", "模型训练完成并保存")
        except Exception as e:
            logging.exception(f"模型训练失败: {e}")
            self._append_log(f"模型训练失败: {e}")
            messagebox.showerror("训练失败", f"模型训练失败: {e}")
        finally:
            # train() switches the model to training mode (dropout on); restore
            # inference mode so get_answer stays deterministic.
            self.model.eval()

    def evaluate_model(self):
        """Placeholder: model evaluation is not implemented yet."""
        messagebox.showinfo("评估结果", "模型评估功能暂未实现")

    def mark_correct(self):
        """Mark the latest answer as accurate."""
        if not self.history:
            messagebox.showwarning("评价失败", "还没有可评价的回答")
            return
        self.history[-1]['accuracy'] = True
        messagebox.showinfo("评价成功", "您认为这次回答是准确的")

    def mark_incorrect(self):
        """Mark the latest answer as inaccurate and offer reference lookups."""
        if not self.history:
            messagebox.showwarning("评价失败", "还没有可评价的回答")
            return
        self.history[-1]['accuracy'] = False
        self.show_reference_options(self.history[-1]['question'])

    def show_reference_options(self, question):
        """Open a window with buttons to fetch a Baidu/360 Baike reference for *question*."""
        reference_window = tk.Toplevel(self.root)
        reference_window.title("参考答案")
        reference_label = ttk.Label(reference_window, text="请选择参考答案来源:", font=("Arial", 12))
        reference_label.pack(pady=10)
        baidu_button = ttk.Button(reference_window, text="百度百科", command=lambda: self.get_reference_answer(question, 'baidu_baike'), style='TButton')
        baidu_button.pack(pady=5)
        so_button = ttk.Button(reference_window, text="360百科", command=lambda: self.get_reference_answer(question, '360_baike'), style='TButton')
        so_button.pack(pady=5)

    def _history_entry_for(self, question):
        """Return the most recent history entry asked with *question*, or None."""
        for entry in reversed(self.history):
            if entry.get('question') == question:
                return entry
        return None

    def get_reference_answer(self, question, source):
        """Fetch a reference for *question* from *source*, show it, and store it in the history.

        *source* is 'baidu_baike' or '360_baike'; other values are ignored.
        The result is attached to the latest history entry for *question*, not
        blindly to the newest entry, because the user may have asked another
        question while the reference window was open.
        """
        searchers = {'baidu_baike': self.search_baidu_baike, '360_baike': self.search_360_baike}
        if source not in searchers:
            return
        label = self._SOURCE_LABELS[source]
        try:
            baike_answer = searchers[source](question)
        except Exception as e:
            logging.exception(f"获取{label}结果失败: {e}")
            messagebox.showerror("参考答案", f"获取{label}结果失败: {e}")
            return
        self.chat_text.insert(tk.END, f"{label}结果: {baike_answer}\n", "xihua")
        self.chat_text.see(tk.END)
        entry = self._history_entry_for(question)
        if entry is not None:
            entry[source] = baike_answer
        messagebox.showinfo("参考答案", f"已获取{label}的结果")

    def search_baidu_baike(self, query):
        """Delegate to the module-level Baidu Baike lookup."""
        return search_baidu_baike(query)

    def search_360_baike(self, query):
        """Delegate to the module-level 360 Baike lookup (no duplicated scraping code)."""
        return search_360_baike(query)

    def view_history(self):
        """Show all history entries, with ratings and references, in a new window."""
        history_window = tk.Toplevel(self.root)
        history_window.title("历史记录")
        history_text = tk.Text(history_window, height=20, width=80, font=("Arial", 12))
        history_text.pack(padx=10, pady=10)
        for entry in self.history:
            history_text.insert(tk.END, f"问题: {entry['question']}\n")
            history_text.insert(tk.END, f"回答类型: {entry['answer_type']}\n")
            history_text.insert(tk.END, f"具体回答: {entry['specific_answer']}\n")
            if entry['accuracy'] is None:
                history_text.insert(tk.END, "评价: 未评价\n")
            elif entry['accuracy']:
                history_text.insert(tk.END, "评价: 准确\n")
            else:
                history_text.insert(tk.END, "评价: 不准确\n")
            if entry['baidu_baike']:
                history_text.insert(tk.END, f"百度百科结果: {entry['baidu_baike']}\n")
            if entry['360_baike']:
                history_text.insert(tk.END, f"360百科结果: {entry['360_baike']}\n")
            history_text.insert(tk.END, "-" * 50 + "\n")

    def save_history(self):
        """Write the history to records/ as a .txt transcript and a .json file.

        Both files share one timestamp so the pair can be matched up. The JSON
        records use the training-data keys (question / human_answers /
        chatgpt_answers) plus any reference answers.
        """
        if not self.history:
            messagebox.showwarning("保存失败", "没有可保存的历史记录")
            return
        records_dir = os.path.join(PROJECT_ROOT, 'records')
        os.makedirs(records_dir, exist_ok=True)
        # Take the timestamp once; two datetime.now() calls can straddle a
        # second boundary and give the .txt and .json different names.
        stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        file_path = os.path.join(records_dir, f'{stamp}.txt')
        with open(file_path, 'w', encoding='utf-8') as f:
            for entry in self.history:
                f.write(f"用户: {entry['question']}\n")
                f.write(f"羲和: {entry['specific_answer']}\n")
                if entry['baidu_baike']:
                    f.write(f"百度百科结果: {entry['baidu_baike']}\n")
                if entry['360_baike']:
                    f.write(f"360百科结果: {entry['360_baike']}\n")
                f.write("-" * 50 + "\n")
        json_records = [
            {
                "question": entry['question'],
                "human_answers": [entry['specific_answer']] if entry['answer_type'] == "羲和回答" else [],
                "chatgpt_answers": [entry['specific_answer']] if entry['answer_type'] == "零回答" else [],
                "baidu_baike": entry['baidu_baike'],
                "360_baike": entry['360_baike'],
            }
            for entry in self.history
        ]
        json_file_path = os.path.join(records_dir, f'{stamp}.json')
        with open(json_file_path, 'w', encoding='utf-8') as f:
            json.dump(json_records, f, ensure_ascii=False, indent=4)
        messagebox.showinfo("保存成功", f"历史记录已保存到 {file_path} 和 {json_file_path}")
# Script entry point
if __name__ == "__main__":
    # Create the Tk root window, attach the chatbot UI, then block in Tk's
    # event loop until the window is closed.
    window = tk.Tk()
    app = XihuaChatbotGUI(window)
    window.mainloop()
主要改进点
search_360_baike 函数:
首先尝试从 meta description 标签（即 name="description" 的 meta 标签）中提取描述信息。
如果没有找到该 meta 标签，则查找页面中的主要内容区域（通常是类名为 lemma-main-content 的 div 标签）。
从主要内容区域中提取前三个段落的内容,并返回这些内容。
get_reference_answer 方法:
调用 search_360_baike 函数获取 360 百科的参考答案,并将其插入到 Text 组件中。
将获取到的参考答案保存到历史记录中。
通过这些修改,360百科 的内容将能够正确地显示在 Text 组件中。希望这对你有帮助!
原文地址:https://blog.csdn.net/weixin_54366286/article/details/144016326
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!