基于BERT的中文问答系统(28)
为了使GUI界面更加人性化,我们从以下几个方面进行改进:
美化界面:使用更现代的样式和布局,增加图标和颜色。
用户反馈:增加更多的提示信息和反馈,让用户知道当前的操作状态。
功能增强:增加一些实用的功能,如查看和保存历史记录、显示训练进度等。
用户体验:优化交互流程,使操作更加流畅和直观。
下面是改进后的代码:
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
# Project root = the directory holding this script; data/, models/ and logs/
# are all resolved relative to it.
PROJECT_ROOT = os.path.split(os.path.abspath(__file__))[0]

# Log directory: timestamped log files are written here. It is created at
# import time so setup_logging() can open its file immediately.
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)
def setup_logging():
    """Configure root logging at INFO level to a timestamped file and stderr.

    The file lives in LOGS_DIR and is named ``YYYY-mm-dd_HH-MM-SS_羲和.txt``.
    It is written as UTF-8 because most log messages in this module are Chinese.
    """
    # Only the ASCII timestamp goes through strftime; the non-ASCII suffix is
    # appended afterwards, since non-ASCII strftime formats can misbehave on
    # some Windows/locale combinations.
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    log_file = os.path.join(LOGS_DIR, f'{timestamp}_羲和.txt')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # Explicit UTF-8 so Chinese messages are not subject to the platform default encoding.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),
        ],
    )


setup_logging()
# Dataset
class XihuaDataset(Dataset):
    """Tokenised question / human-answer / ChatGPT-answer triples.

    Each record is expected to be a dict with a ``'question'`` string and
    non-empty ``'human_answers'`` / ``'chatgpt_answers'`` lists; only the first
    answer of each list is used. Records that are malformed or fail to
    tokenise are logged and skipped at access time.
    """

    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        """Read records from a ``.jsonl`` or ``.json`` file (extension is case-insensitive).

        - Invalid JSONL lines are logged and skipped.
        - An undecodable ``.json`` file is logged and yields ``[]``.
        - Unsupported extensions yield ``[]``.
        - A missing file raises, because the caller chose the path.
        """
        data = []
        lowered = file_path.lower()
        if lowered.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                while True:
                    try:
                        data.append(reader.read())
                    except EOFError:
                        break
                    except jsonlines.InvalidLineError as e:
                        # read() itself raises this; the bad line is already
                        # consumed, so the loop continues with the next line.
                        logging.warning(f"跳过无效行 {getattr(e, 'lineno', '?')}: {e}")
        elif lowered.endswith('.json'):
            # Explicit UTF-8: the platform default (e.g. GBK on Windows) would
            # mis-decode or reject UTF-8 Chinese text.
            with open(file_path, 'r', encoding='utf-8') as f:
                try:
                    data = json.load(f)
                except (json.JSONDecodeError, UnicodeDecodeError) as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        else:
            logging.warning(f"不支持的文件类型: {file_path}")
        return data

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenise *text* padded/truncated to max_length; return (input_ids, attention_mask) squeezed."""
        enc = self.tokenizer(text, return_tensors='pt', padding='max_length',
                             truncation=True, max_length=self.max_length)
        return enc['input_ids'].squeeze(), enc['attention_mask'].squeeze()

    def __getitem__(self, idx):
        """Return the encoded sample at *idx*, or the next valid one after it.

        Out-of-range indices raise IndexError, as a list would. The sequence
        iteration protocol relies on this to terminate.

        Raises:
            IndexError: if *idx* is out of range.
            RuntimeError: if no record in the dataset can be encoded.
        """
        size = len(self.data)
        if not -size <= idx < size:
            raise IndexError(f"index {idx} out of range for dataset of size {size}")
        start = idx % size
        # Walk forward (wrapping around) at most once over the whole dataset
        # instead of recursing, so an all-invalid dataset cannot recurse forever.
        for offset in range(size):
            current = (start + offset) % size
            try:
                item = self.data[current]
                question = item['question']
                human_answer = item['human_answers'][0]
                chatgpt_answer = item['chatgpt_answers'][0]
                input_ids, attention_mask = self._encode(question)
                human_ids, human_mask = self._encode(human_answer)
                chatgpt_ids, chatgpt_mask = self._encode(chatgpt_answer)
            except Exception as e:  # best-effort: skip any malformed record
                logging.warning(f"跳过无效项 {current}: {e}")
                continue
            return {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'human_input_ids': human_ids,
                'human_attention_mask': human_mask,
                'chatgpt_input_ids': chatgpt_ids,
                'chatgpt_attention_mask': chatgpt_mask,
                'human_answer': human_answer,
                'chatgpt_answer': chatgpt_answer,
            }
        raise RuntimeError("数据集中没有可用的样本")
# Data-loader factory
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    """Wrap an XihuaDataset built from *file_path* in a shuffling DataLoader.

    Args:
        file_path: .jsonl or .json file of question/answer records.
        tokenizer: tokenizer forwarded to XihuaDataset.
        batch_size: samples per batch.
        max_length: pad/truncate length for every tokenised text.
    """
    loader_options = {'batch_size': batch_size, 'shuffle': True}
    return DataLoader(XihuaDataset(file_path, tokenizer, max_length), **loader_options)
# Model definition
class XihuaModel(torch.nn.Module):
    """BERT encoder followed by a one-output linear head.

    The head maps BERT's pooled output to a single unnormalised score per
    sequence. In this file it is trained with BCEWithLogitsLoss, with human
    answers labelled 1 and ChatGPT answers labelled 0.
    """

    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        hidden_size = self.bert.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one logit per input sequence (last dimension of size 1)."""
        pooled = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        return self.classifier(pooled)
# Training loop
def train(model, data_loader, optimizer, criterion, device, progress_var=None):
    """Run one training epoch and return the mean loss of the successful batches.

    Each batch contributes two loss terms. Human answers are labelled 1 and
    ChatGPT answers are labelled 0. The criterion is applied to raw logits
    (e.g. BCEWithLogitsLoss). A batch that raises is logged and skipped, so
    one bad batch does not abort the epoch.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` returning logits.
        data_loader: sized iterable of batch dicts as produced by XihuaDataset.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``(logits, targets)``.
        device: device the batch tensors are moved to.
        progress_var: optional object with ``set(float)`` (e.g. ``tk.DoubleVar``).
            It receives the epoch's completion percentage after every batch,
            including skipped ones.

    Returns:
        Mean loss over batches that completed, or NaN if none did.
    """
    model.train()
    total_loss = 0.0
    completed = 0
    num_batches = len(data_loader)
    for batch_idx, batch in enumerate(data_loader):
        try:
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)
            optimizer.zero_grad()
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)
            # Targets match the logits' shape, dtype and device.
            loss = (criterion(human_logits, torch.ones_like(human_logits))
                    + criterion(chatgpt_logits, torch.zeros_like(chatgpt_logits)))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            completed += 1
        except Exception as e:  # best-effort: skip the bad batch, keep training
            logging.warning(f"跳过无效批次: {e}")
        if progress_var is not None:
            progress_var.set((batch_idx + 1) / num_batches * 100)
    # Average only over batches that contributed to total_loss; an empty or
    # fully failed epoch yields NaN instead of ZeroDivisionError.
    return total_loss / completed if completed else float('nan')
# Headless training entry point
def main_train(retrain=False, *, pretrained_model_name='F:/models/bert-base-chinese', num_epochs=30):
    """Fine-tune XihuaModel on data/train_data.jsonl and save it to models/xihua_model.pth.

    Args:
        retrain: if True and a saved model exists, continue from its weights
            instead of starting from the pre-trained BERT checkpoint.
        pretrained_model_name: path or name of the BERT checkpoint used for
            both the tokenizer and the encoder.
        num_epochs: number of passes over the training data.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device: {device}')
    tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)
    model = XihuaModel(pretrained_model_name=pretrained_model_name).to(device)
    model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
    if retrain:
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    criterion = torch.nn.BCEWithLogitsLoss()
    train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'),
                                        tokenizer, batch_size=8, max_length=128)
    for epoch in range(num_epochs):
        train_loss = train(model, train_data_loader, optimizer, criterion, device)
        logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.8f}')
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)
    logging.info("模型训练完成并保存")
# GUI
class XihuaChatbotGUI:
    """Tkinter front end for the Xihua (羲和) chatbot.

    On start-up it loads:
    - the tokenizer;
    - the classifier, plus fine-tuned weights if models/xihua_model.pth exists;
    - the training corpus used for answer lookup.

    The user can ask a question, rate the last answer, (re)train the model on
    a chosen data file, and view or save the session history.
    """

    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")
        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)
        self.load_model()
        self.model.eval()
        # Training corpus, used by get_specific_answer() for fuzzy question lookup.
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))
        # Session history: one dict per answered question
        # (question, answer_type, specific_answer, accuracy).
        self.history = []
        self.create_widgets()

    @staticmethod
    def _model_path():
        """Return the path of the fine-tuned weights file."""
        return os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')

    def create_widgets(self):
        """Build the question, answer, action, progress-bar and log widgets."""
        font = ("Arial", 12)

        # Top: question entry and ask button.
        top_frame = tk.Frame(self.root)
        top_frame.pack(pady=10)
        self.question_label = tk.Label(top_frame, text="问题:", font=font)
        self.question_label.grid(row=0, column=0, padx=10)
        self.question_entry = tk.Entry(top_frame, width=50, font=font)
        self.question_entry.grid(row=0, column=1, padx=10)
        # Enter submits the question, same as the button.
        self.question_entry.bind('<Return>', lambda event: self.get_answer())
        self.answer_button = tk.Button(top_frame, text="获取回答", command=self.get_answer, font=font)
        self.answer_button.grid(row=0, column=2, padx=10)

        # Middle: answer display.
        middle_frame = tk.Frame(self.root)
        middle_frame.pack(pady=10)
        self.answer_label = tk.Label(middle_frame, text="回答:", font=font)
        self.answer_label.grid(row=0, column=0, padx=10)
        self.answer_text = tk.Text(middle_frame, height=10, width=70, font=font)
        self.answer_text.grid(row=1, column=0, padx=10)

        # Bottom: rating/training buttons, progress bar, training log, history tools.
        bottom_frame = tk.Frame(self.root)
        bottom_frame.pack(pady=10)
        self.correct_button = tk.Button(bottom_frame, text="准确", command=self.mark_correct, font=font)
        self.correct_button.grid(row=0, column=0, padx=10)
        self.incorrect_button = tk.Button(bottom_frame, text="不准确", command=self.mark_incorrect, font=font)
        self.incorrect_button.grid(row=0, column=1, padx=10)
        self.train_button = tk.Button(bottom_frame, text="训练模型", command=self.train_model, font=font)
        self.train_button.grid(row=0, column=2, padx=10)
        self.retrain_button = tk.Button(bottom_frame, text="重新训练模型",
                                        command=lambda: self.train_model(retrain=True), font=font)
        self.retrain_button.grid(row=0, column=3, padx=10)
        self.progress_var = tk.DoubleVar()
        # Training runs on the Tk thread, so force a redraw whenever progress
        # changes; otherwise the bar only repaints after training finishes.
        self.progress_var.trace_add('write', lambda *args: self.root.update_idletasks())
        self.progress_bar = ttk.Progressbar(bottom_frame, variable=self.progress_var, maximum=100, length=200)
        self.progress_bar.grid(row=1, column=0, columnspan=4, pady=10)
        self.log_text = tk.Text(bottom_frame, height=10, width=70, font=font)
        self.log_text.grid(row=2, column=0, columnspan=4, pady=10)
        self.evaluate_button = tk.Button(bottom_frame, text="评估模型", command=self.evaluate_model, font=font)
        self.evaluate_button.grid(row=3, column=0, padx=10, pady=10)
        self.history_button = tk.Button(bottom_frame, text="查看历史记录", command=self.view_history, font=font)
        self.history_button.grid(row=3, column=1, padx=10, pady=10)
        self.save_history_button = tk.Button(bottom_frame, text="保存历史记录", command=self.save_history, font=font)
        self.save_history_button.grid(row=3, column=2, padx=10, pady=10)

    def _append_log(self, message):
        """Append *message* to the on-screen log, scroll to it and repaint."""
        self.log_text.insert(tk.END, message)
        self.log_text.see(tk.END)
        self.root.update_idletasks()

    def get_answer(self):
        """Classify the typed question and show the matching stored answer."""
        question = self.question_entry.get().strip()
        if not question:
            messagebox.showwarning("输入错误", "请输入问题")
            return
        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        # Inference only: make sure dropout is off regardless of earlier training.
        self.model.eval()
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)
        # A positive logit leans towards the "human" label (trained as 1).
        answer_type = "羲和回答" if logits.item() > 0 else "零回答"
        specific_answer = self.get_specific_answer(question, answer_type)
        self.answer_text.delete(1.0, tk.END)
        self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")
        self.history.append({
            'question': question,
            'answer_type': answer_type,
            'specific_answer': specific_answer,
            'accuracy': None,  # not rated yet
        })

    def get_specific_answer(self, question, answer_type):
        """Return a stored answer for the training question most similar to *question*.

        Similarity is ``difflib.SequenceMatcher.ratio()``. For "羲和回答" the
        matched item's first human answer is returned; otherwise its first
        ChatGPT answer is returned.

        A canned reply is returned instead when:
        - no question is similar at all;
        - the matched item has no answer of the requested kind.

        Items without a string 'question' are ignored.
        """
        fallback = "这个我也不清楚,你问问零吧"
        best_match = None
        best_ratio = 0.0
        for item in self.data:
            candidate = item.get('question') if isinstance(item, dict) else None
            if not isinstance(candidate, str):
                continue
            ratio = SequenceMatcher(None, question, candidate).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item
        if best_match is None:
            return fallback
        key = 'human_answers' if answer_type == "羲和回答" else 'chatgpt_answers'
        answers = best_match.get(key) or []
        return answers[0] if answers else fallback

    def load_data(self, file_path):
        """Load QA records from a ``.jsonl`` or ``.json`` file.

        Returns ``[]`` when the file is missing, undecodable or of an
        unsupported type, so a missing corpus does not stop the GUI from
        starting. Invalid JSONL lines are logged and skipped.
        """
        data = []
        if not os.path.exists(file_path):
            logging.warning(f"数据文件不存在: {file_path}")
            return data
        lowered = file_path.lower()
        if lowered.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                while True:
                    try:
                        data.append(reader.read())
                    except EOFError:
                        break
                    except jsonlines.InvalidLineError as e:
                        # Raised by read() itself; the bad line is already consumed.
                        logging.warning(f"跳过无效行 {getattr(e, 'lineno', '?')}: {e}")
        elif lowered.endswith('.json'):
            # Explicit UTF-8: the platform default may not decode Chinese UTF-8 text.
            with open(file_path, 'r', encoding='utf-8') as f:
                try:
                    data = json.load(f)
                except (json.JSONDecodeError, UnicodeDecodeError) as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        else:
            logging.warning(f"不支持的文件类型: {file_path}")
        return data

    def load_model(self):
        """Load fine-tuned weights into self.model if the weights file exists."""
        model_path = self._model_path()
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    def train_model(self, retrain=False):
        """Train the model on a user-selected data file and save the weights.

        With *retrain*, training starts from the saved weights when they exist.
        Without it, training continues from the weights currently in memory.
        The model is always returned to eval mode afterwards.
        """
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if not file_path:
            messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")
            return
        model_path = self._model_path()
        try:
            dataset = XihuaDataset(file_path, self.tokenizer)
            if len(dataset) == 0:
                messagebox.showwarning("数据错误", "所选文件中没有可用的数据")
                return
            data_loader = DataLoader(dataset, batch_size=8, shuffle=True)
            if retrain:
                if os.path.exists(model_path):
                    self.model.load_state_dict(torch.load(model_path, map_location=self.device))
                    logging.info("加载现有模型")
                else:
                    logging.info("没有找到现有模型,将使用预训练模型")
                    self._append_log("没有找到现有模型,将使用预训练模型\n")
            self.model.to(self.device)
            self.model.train()
            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            num_epochs = 30
            for epoch in range(num_epochs):
                self.progress_var.set(0)  # progress bar shows per-epoch progress
                train_loss = train(self.model, data_loader, optimizer, criterion, self.device, self.progress_var)
                message = f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}'
                logging.info(message)
                self._append_log(message + '\n')
            # torch.save does not create missing parent directories.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            torch.save(self.model.state_dict(), model_path)
            logging.info("模型训练完成并保存")
            self._append_log("模型训练完成并保存\n")
            messagebox.showinfo("训练完成", "模型训练完成并保存")
        except Exception as e:  # top-level UI boundary: report instead of crashing the GUI
            logging.error(f"模型训练失败: {e}")
            self._append_log(f"模型训练失败: {e}\n")
            messagebox.showerror("训练失败", f"模型训练失败: {e}")
        finally:
            # Leave the model in inference mode so get_answer() is not affected by dropout.
            self.model.eval()

    def evaluate_model(self):
        """Placeholder: model evaluation is not implemented yet."""
        messagebox.showinfo("评估结果", "模型评估功能暂未实现")

    def _mark_last(self, accurate, message):
        """Store the user's rating on the most recent answer and confirm it."""
        if not self.history:
            messagebox.showwarning("评价失败", "还没有可以评价的回答")
            return
        self.history[-1]['accuracy'] = accurate
        messagebox.showinfo("评价成功", message)

    def mark_correct(self):
        """Rate the most recent answer as accurate."""
        self._mark_last(True, "您认为这次回答是准确的")

    def mark_incorrect(self):
        """Rate the most recent answer as inaccurate."""
        self._mark_last(False, "您认为这次回答是不准确的")

    def view_history(self):
        """Open a read-only window listing every answered question and its rating."""
        history_window = tk.Toplevel(self.root)
        history_window.title("历史记录")
        history_text = tk.Text(history_window, height=20, width=80, font=("Arial", 12))
        history_text.pack(padx=10, pady=10)
        rating_lines = {None: "评价: 未评价\n", True: "评价: 准确\n", False: "评价: 不准确\n"}
        if not self.history:
            history_text.insert(tk.END, "暂无历史记录\n")
        for entry in self.history:
            history_text.insert(tk.END, f"问题: {entry['question']}\n")
            history_text.insert(tk.END, f"回答类型: {entry['answer_type']}\n")
            history_text.insert(tk.END, f"具体回答: {entry['specific_answer']}\n")
            history_text.insert(tk.END, rating_lines[entry['accuracy']])
            history_text.insert(tk.END, "-" * 50 + "\n")
        # Viewer only: edits in this window would not be saved anyway.
        history_text.config(state=tk.DISABLED)

    def save_history(self):
        """Save the session history as UTF-8 JSON to a user-chosen file."""
        file_path = filedialog.asksaveasfilename(defaultextension=".json", filetypes=[("JSON files", "*.json")])
        if not file_path:
            return
        try:
            # ensure_ascii=False writes raw Chinese characters; the platform
            # default encoding may not be able to encode all of them.
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(self.history, f, ensure_ascii=False, indent=4)
        except OSError as e:
            logging.error(f"保存历史记录失败: {e}")
            messagebox.showerror("保存失败", f"保存历史记录失败: {e}")
            return
        messagebox.showinfo("保存成功", "历史记录已保存到文件")
# Entry point: create the Tk root window, attach the chatbot GUI and run the event loop.
if __name__ == "__main__":
    main_window = tk.Tk()
    app = XihuaChatbotGUI(main_window)
    main_window.mainloop()
改进点总结
美化界面:
使用更大的字体和更现代的布局。
增加了顶部、中部和底部框架,使界面更加整洁。
用户反馈:
在主要操作(如评价回答、训练模型、保存历史记录)完成后增加了提示信息,让用户知道当前的操作状态。
使用messagebox显示操作结果,增加用户的反馈体验。
功能增强:
增加了历史记录查看和保存功能。
增加了日志输出,方便用户了解模型训练的状态(模型评估功能尚未实现)。
用户体验:
优化了交互流程,使操作更加流畅和直观。
增加了进度条,让用户了解模型训练的进度。
希望这些改进能让你的聊天机器人界面更加友好和易用!
原文地址:https://blog.csdn.net/weixin_54366286/article/details/143505132
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!