自学内容网 自学内容网

BERT的中文问答系统(羲和1.0)

为了确保项目目录结构清晰,我们可以按照以下方式组织代码并生成项目目录。整个项目分为几个主要部分:数据、模型、日志、图标、源代码等。

项目目录结构
code

project_root/
├── data/
│   └── train_data.jsonl
├── models/
│   └── xihua_model.pth
├── logs/
│   └── <date_time>/
│       └── 羲和.txt
├── icons/
│   ├── xihe.png
│   └── ling.png
├── src/
│   ├── main.py
│   ├── xihua_dataset.py
│   ├── xihua_model.py
│   ├── xihua_gui.py
│   ├── utils.py
│   └── train.py
└── README.md

代码拆分
1.
main.py
主入口文件,负责启动GUI。

python

import os
import sys
import tkinter as tk

# Directory that contains this script.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Resolve the project root. The documented layout places main.py inside
# src/, in which case the root is one level above the script; if the script
# is instead kept directly in the project root, the script directory already
# is the root.
PROJECT_ROOT = (
    os.path.dirname(_SCRIPT_DIR)
    if os.path.basename(_SCRIPT_DIR) == 'src'
    else _SCRIPT_DIR
)
SRC_DIR = os.path.join(PROJECT_ROOT, 'src')

# `from src.xihua_gui import ...` can only be resolved when the directory that
# contains `src/` is on sys.path. Running `python src/main.py` puts `src/`
# itself there, so add the project root explicitly.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Imported after the sys.path adjustment on purpose.
from src.xihua_gui import XihuaChatbotGUI  # noqa: E402


def main():
    """Create the Tk root window, attach the chatbot GUI and run the event loop."""
    root = tk.Tk()
    # Bind the GUI object to a name so it stays referenced while the loop runs.
    app = XihuaChatbotGUI(root)
    root.mainloop()


if __name__ == "__main__":
    main()

xihua_dataset.py
数据集类的定义。

python

import os
import json
import jsonlines
from transformers import BertTokenizer
import logging

class XihuaDataset:
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        if self.validate_item(item):
                            data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {
     i + 1}: {
     e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                try:
                    data = [item for item in json.load(f) if self.validate_item(item)]
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {
     file_path}: {
     e}")
        return data

    def validate_item(self, item):
        """Return True when ``item`` is a dict carrying every required key.

        The required keys are ``question``, ``xihe_answers`` and
        ``ling_answers``. Anything else is rejected and a warning is logged.

        Args:
            item: One parsed record from the data file.

        Returns:
            bool: True if the record may be used, False otherwise.
        """
        required_keys = ['question', 'xihe_answers', 'ling_answers']
        # The dict check matters: `key in item` on a str is a substring test,
        # and on an int or None it raises TypeError.
        if isinstance(item, dict) and all(key in item for key in required_keys):
            return True
        logging.warning(f"跳过无效项: 缺少必要键 {required_keys}")
        return False

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']

        # 确保 xihe_answers 和 ling_answers 都有值
        xihe_answer = item.get('xihe_answers', [])
        ling_answer = item.get('ling_answers', [])

        if not xihe_answer and ling_answer:
            xihe_answer = ling_answer
        elif not ling_answer and xihe_answer:
            ling_answer = xihe_answer

        xihe_answer = xihe_answer[0] if xihe_answer else ""
        ling_answer = ling_answer[0] if ling_answer else ""

        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            xihe_inputs = self.tokenizer(xihe_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            ling_inputs = self.tokenizer(ling_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        except Exception as e:
            logging.warning(f"跳过无效项 {
     idx}: {
     e}")
            return self.__getitem__((idx + 1) % len(self.data))

        return {
   
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            

原文地址:https://blog.csdn.net/weixin_54366286/article/details/142851138

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!