基于BERT的中文问答系统(羲和1.0)
为了确保项目目录结构清晰,我们可以按以下步骤组织代码并生成项目目录结构。我们将项目分为几个主要部分:数据、模型、日志、图标、源代码等。
项目目录结构
code
project_root/
├── data/
│ └── train_data.jsonl
├── models/
│ └── xihua_model.pth
├── logs/
│ └── <date_time>/
│ └── 羲和.txt
├── icons/
│ ├── xihe.png
│ └── ling.png
├── src/
│ ├── main.py
│ ├── xihua_dataset.py
│ ├── xihua_model.py
│ ├── xihua_gui.py
│ ├── utils.py
│ └── train.py
└── README.md
代码拆分
1. main.py
主入口文件,负责启动GUI。
python
"""Entry point for the Xihe chatbot: resolve the project root and launch the Tk GUI."""
import os
import sys


def _resolve_project_root(script_path):
    """Return the project root directory for a script located at *script_path*.

    The project layout places ``main.py`` inside ``src/``, while the GUI is
    imported as ``src.xihua_gui``, which only resolves when the *project root*
    is on ``sys.path``. Both placements are accepted: if the script lives in a
    directory named ``src`` the root is that directory's parent, otherwise the
    script's own directory is taken as the root.
    """
    script_dir = os.path.dirname(os.path.abspath(script_path))
    if os.path.basename(script_dir) == 'src':
        return os.path.dirname(script_dir)
    return script_dir


# Project root and source directory, independent of where main.py is placed.
PROJECT_ROOT = _resolve_project_root(__file__)
SRC_DIR = os.path.join(PROJECT_ROOT, 'src')


def main():
    """Create the Tk root window, attach the chatbot GUI and run the event loop."""
    # Python puts the *script's* directory on sys.path, not the project root;
    # add the root so `src.<module>` imports work from any launch location.
    if PROJECT_ROOT not in sys.path:
        sys.path.insert(0, PROJECT_ROOT)

    # Imported lazily so this module can be imported (tests, headless
    # environments) without Tk or the GUI package being available.
    import tkinter as tk
    from src.xihua_gui import XihuaChatbotGUI

    root = tk.Tk()
    # Keep a reference so the GUI object stays alive for the whole mainloop.
    app = XihuaChatbotGUI(root)
    root.mainloop()
    return app


if __name__ == "__main__":
    main()
2. xihua_dataset.py
数据集类的定义。
python
import os
import json
import jsonlines
from transformers import BertTokenizer
import logging
class XihuaDataset:
def __init__(self, file_path, tokenizer, max_length=128):
    """Keep the tokenizer settings and eagerly load the validated samples.

    Args:
        file_path: Path to the data file, handed straight to ``load_data``.
        tokenizer: Callable used to encode questions and answers (the module
            imports ``BertTokenizer``; presumably an instance of it -- confirm
            with the caller).
        max_length: Length every encoded sequence is padded/truncated to.
    """
    self.tokenizer, self.max_length = tokenizer, max_length
    self.data = self.load_data(file_path)
def load_data(self, file_path):
    """Load the samples from a ``.jsonl`` or ``.json`` file and keep the valid ones.

    Args:
        file_path: Path to the data file. A ``.jsonl`` file holds one JSON
            value per line; a ``.json`` file holds a single JSON array.

    Returns:
        list: Items accepted by ``self.validate_item``, in file order.
        Malformed lines, malformed files, non-array ``.json`` files and
        unsupported extensions are logged and skipped rather than raised.
    """
    data = []
    if file_path.endswith('.jsonl'):
        # utf-8-sig decodes UTF-8 and tolerates the BOM some editors prepend.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue  # blank lines carry no sample
                # Parse per line inside the try so a single bad line is
                # skipped; previously the parse error surfaced from the
                # reader's iteration, outside the try, and aborted the load.
                try:
                    item = json.loads(line)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效行 {line_no}: {e}")
                    continue
                if self.validate_item(item):
                    data.append(item)
    elif file_path.endswith('.json'):
        # Explicit encoding: the platform default (e.g. GBK on Windows)
        # would mis-decode UTF-8 Chinese text.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            try:
                items = json.load(f)
            except json.JSONDecodeError as e:
                logging.warning(f"跳过无效文件 {file_path}: {e}")
                return data
        if not isinstance(items, list):
            # Iterating a top-level object would walk its keys, not records.
            logging.warning(f"跳过无效文件 {file_path}: 顶层应为 JSON 数组")
            return data
        data = [item for item in items if self.validate_item(item)]
    else:
        logging.warning(f"不支持的文件格式: {file_path}")
    return data
def validate_item(self, item):
    """Return True if *item* is a dict containing every required key.

    Only key presence is checked; the values (e.g. possibly empty answer
    lists) are not inspected here. Rejected items are logged with the reason.

    Args:
        item: One decoded JSON value from the data file.

    Returns:
        bool: True when the item can be used as a sample, False otherwise.
    """
    required_keys = ['question', 'xihe_answers', 'ling_answers']
    # A non-dict would either crash `in` (ints) or be substring-matched (str).
    if not isinstance(item, dict):
        logging.warning(f"跳过无效项: 期望 JSON 对象, 实际为 {type(item).__name__}")
        return False
    missing_keys = [key for key in required_keys if key not in item]
    if not missing_keys:
        return True
    # Report only the keys that are actually absent.
    logging.warning(f"跳过无效项: 缺少必要键 {missing_keys}")
    return False
def __len__(self):
    """Return how many validated samples the dataset holds."""
    samples = self.data
    return len(samples)
def __getitem__(self, idx):
item = self.data[idx]
question = item['question']
# 确保 xihe_answers 和 ling_answers 都有值
xihe_answer = item.get('xihe_answers', [])
ling_answer = item.get('ling_answers', [])
if not xihe_answer and ling_answer:
xihe_answer = ling_answer
elif not ling_answer and xihe_answer:
ling_answer = xihe_answer
xihe_answer = xihe_answer[0] if xihe_answer else ""
ling_answer = ling_answer[0] if ling_answer else ""
try:
inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
xihe_inputs = self.tokenizer(xihe_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
ling_inputs = self.tokenizer(ling_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
except Exception as e:
logging.warning(f"跳过无效项 {
idx}: {
e}")
return self.__getitem__((idx + 1) % len(self.data))
return {
'input_ids': inputs['input_ids'].squeeze(),
'attention_mask': inputs['attention_mask'].squeeze(),
原文地址:https://blog.csdn.net/weixin_54366286/article/details/142851138
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!