Mindspore框架循环神经网络RNN模型实现情感分类|(六)模型加载和推理(情感分类模型资源下载)
Mindspore框架循环神经网络RNN模型实现情感分类
Mindspore框架循环神经网络RNN模型实现情感分类|(一)IMDB影评数据集准备
Mindspore框架循环神经网络RNN模型实现情感分类|(二)预训练词向量
Mindspore框架循环神经网络RNN模型实现情感分类|(三)RNN模型构建
Mindspore框架循环神经网络RNN模型实现情感分类|(四)损失函数与优化器
Mindspore框架循环神经网络RNN模型实现情感分类|(五)模型训练
Mindspore框架循环神经网络RNN模型实现情感分类|(六)模型加载和推理(情感分类模型资源下载)
Mindspore框架循环神经网络RNN模型实现情感分类|(七)模型导出ONNX与应用部署
一、模型资源下载
- RNN升级版LSTM模型:本项目训练好的情感分类模型-下载训练好的IMDB分类模型。
二、模型加载与推理
class RNN(nn.Cell):
def __init__(self, embeddings, hidden_dim, output_dim, n_layers,
bidirectional, pad_idx):
super().__init__()
vocab_size, embedding_dim = embeddings.shape
self.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings),
padding_idx=pad_idx)
self.rnn = nn.LSTM(embedding_dim,
hidden_dim,
num_layers=n_layers,
bidirectional=bidirectional,
batch_first=True)
weight_init = HeUniform(math.sqrt(5))
bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))
self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)
def construct(self, inputs):
embedded = self.embedding(inputs)
_, (hidden, _) = self.rnn(embedded)
hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)
output = self.fc(hidden)
return output
编写预测接口:test_interface
def predict_sentiment(model, vocab, sentence):
score_map = {
1: "Positive",
0: "Negative"
}
model.set_train(False)
tokenized = sentence.lower().split()
indexed = vocab.tokens_to_ids(tokenized)
tensor = ms.Tensor(indexed, ms.int32)
tensor = tensor.expand_dims(0)
prediction = model(tensor)
return score_map[int(np.round(ops.sigmoid(prediction).asnumpy()))]
def test_interface():
# train()
score_map = {
1: "Positive",
0: "Negative"
}
ckpt_file_name = './IMDB/IMDB/sentiment-analysis.ckpt'
# 预训练词向量表
glove_path = r"./IMDB/IMDB/glove.6B.zip"
vocab, embeddings = load_glove(glove_path) # 预定义词向量表
hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
pad_idx = vocab.tokens_to_ids('<pad>')
model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
param_dict = ms.load_checkpoint(ckpt_file_name)
ms.load_param_into_net(model, param_dict)
# 预测
while True:
try:
print("go on!")
sentence = input("请输入:")
res = predict_sentiment(model, vocab, sentence)
print("用户输入的内容为:", sentence, "评价结果是:", res)
except:
break
def load_glove(glove_path):
glove_100d_path = os.path.join(cache_dir, 'glove.6B.100d.txt') # 保存数据词典
if not os.path.exists(glove_100d_path):
glove_zip = zipfile.ZipFile(glove_path)
glove_zip.extractall(cache_dir)
embeddings = []
tokens = []
with open(glove_100d_path, encoding='utf-8') as gf:
for glove in gf:
word, embedding = glove.split(maxsplit=1)
tokens.append(word)
embeddings.append(np.fromstring(embedding, dtype=np.float32, sep=' '))
# 添加 <unk>, <pad> 两个特殊占位符对应的embedding
embeddings.append(np.random.rand(100))
embeddings.append(np.zeros((100,), np.float32))
vocab = ds.text.Vocab.from_list(tokens, special_tokens=["<unk>", "<pad>"], special_first=False)
embeddings = np.array(embeddings).astype(np.float32)
return vocab, embeddings
预测推理:
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import os
import zipfile
import numpy as np
test_interface()
预测结果。
原文地址:https://blog.csdn.net/beauthy/article/details/140675039
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!