自学内容网 自学内容网

于BERT的中文问答系统10

实现了一个基于BERT的中文问答系统,该系统能够区分给定问题的回答是由人类还是由ChatGPT生成的。以下是代码的主要功能和一些需要注意的事项:

主要功能

数据集加载:

XihuaDataset 类用于从 .json 或 .jsonl 文件中加载数据,并使用 BERT 的 tokenizer 对文本进行编码。
支持两种格式的数据文件:JSON 和 JSON Lines。
数据加载器:

get_data_loader 函数创建一个 DataLoader,用于批量加载数据集中的数据。
模型定义:

XihuaModel 类继承自 torch.nn.Module,使用预训练的 BERT 模型作为基础,并在其上添加了一个线性分类层,用于二分类任务(判断回答是人类的还是 ChatGPT 的)。
训练函数:

train 函数负责模型的训练过程,包括前向传播、损失计算、反向传播和参数更新。
使用 BCEWithLogitsLoss 作为损失函数,适用于二分类任务。
主训练函数:

main_train 函数初始化模型、优化器和数据加载器,并执行多个训练周期。
训练完成后,将模型的权重保存到文件中。
GUI 界面:

XihuaChatbotGUI 类使用 Tkinter 创建了一个简单的图形用户界面。
用户可以输入问题,点击“获取回答”按钮后,模型会判断回答是人类的还是 ChatGPT 的,并显示结果。
还提供了一个“训练模型”按钮,允许用户选择新的数据文件并重新训练模型。

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog
import logging

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# 数据集类
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for item in reader:
                    data.append(item)
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                data = json.load(f)
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        human_answer = item['human_answers'][0]
        chatgpt_answer = item['chatgpt_answers'][0]

        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True

原文地址:https://blog.csdn.net/weixin_54366286/article/details/142732313

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!