自学内容网 自学内容网

昇思学习打卡营第32天|基于ResNet50的中药炮制饮片质量判断模型

背景介绍

        中药炮制是根据中医药理论,依照临床用药需求,通过调剂和制剂要求,将中药材制备成中药饮片的过程。老百姓日常使用的中药饮片,是中药炮制技术的成果。中药炮制过程中,尤其是涉及到水火处理时,必须注重“程度适中”。如果炮制火候不足,则无法发挥最好的药效;而火候过度则会使药效丧失。因此,判断炮制程度的准确性直接影响中药的质量和疗效。

        传统上,中药炮制程度主要依赖于经验丰富的老药工判断。然而,随着老药工的减少,经验传承面临挑战。人工智能的兴起为这一问题提供了解决方案,通过图像分类技术,尤其是使用深度学习中的ResNet50模型,我们能够有效判断饮片的炮制状态,智能化再现药工的经验。

ResNet50网络简介

        ResNet50网络由何恺明等人在2015年提出,是ILSVRC 2015年图像分类竞赛的冠军模型。传统的卷积神经网络随着层数加深会出现退化问题,而ResNet网络通过引入残差结构,成功训练了数百甚至上千层的深度神经网络。ResNet50则是基于Bottleneck残差块的50层深度网络,在多种图像分类任务中展现了优异的性能。

准备阶段
配置实验环境

        本实验基于MindSpore框架和华为Ascend平台进行。以下是环境配置的必要步骤:

!pip install mindspore==2.3.0
数据集介绍

        我们使用的中药炮制饮片数据集由成都中医药大学提供,包含三类药材(蒲黄、山楂、王不留行)的不同炮制程度图片:生品、不及、适中、太过。每种状态下包含500张图片,总共12类5000张图像。

数据预处理

        我们将原始4K的图片缩放到1000x1000像素,以适应ResNet50的输入需求。

from PIL import Image
import os

def resize_images(data_dir, target_size=(1000, 1000)):
    for root, dirs, files in os.walk(data_dir):
        for file in files:
            if file.endswith('.jpg'):
                img = Image.open(os.path.join(root, file))
                img = img.resize(target_size)
                img.save(os.path.join(root, file))

resize_images('dataset/zhongyiyao/')
数据加载与划分

        为了训练和验证模型,我们将数据集分为训练集、验证集和测试集。通过使用sklearn中的train_test_split函数,我们将数据按比例分配,并保证每类样本均匀分布。

from sklearn.model_selection import train_test_split

def split_dataset(data_dir):
    classes = os.listdir(data_dir)
    for class_name in classes:
        images = os.listdir(os.path.join(data_dir, class_name))
        train, test = train_test_split(images, test_size=0.2, random_state=42)
        # 进一步划分验证集
        train, val = train_test_split(train, test_size=0.2, random_state=42)
        # 保存划分后的数据
        # ...

split_dataset('dataset1/zhongyiyao/')
ResNet50模型构建

        在处理完数据后,我们选择了ResNet50作为基础网络,并对其进行微调。我们将最后的全连接层调整为输出12个类别,以适应中药饮片的分类任务。

from mindspore import nn
from mindspore import Model

def build_resnet50(num_classes=12):
    network = resnet50(pretrained=True)
    in_channels = network.fc.in_channels
    network.fc = nn.Dense(in_channels, num_classes)
    return network

network = build_resnet50()
数据加载函数定义

        为了训练模型,我们需要定义一个数据加载器。此函数加载图片并执行图像增强等预处理步骤。

from mindspore.dataset import GeneratorDataset
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
from mindspore import dtype as mstype

class Iterable:
    def __init__(self,data_path):
        self._data = []
        self._label = []
        if data_path.endswith(('JPG','jpg','png','PNG')):
            # 用作推理,所以没有label
            image = Image.open(data_path)
            self._data.append(image)
            self._label.append(0)
        else:
            classes = os.listdir(data_path)
            if '.ipynb_checkpoints' in classes:
                classes.remove('.ipynb_checkpoints')
            for (i,class_name) in enumerate(classes):
                new_path =  data_path+"/"+class_name
                for image_name in os.listdir(new_path):
                    try:
                        image = Image.open(new_path + "/" + image_name)
                        self._data.append(image)
                        self._label.append(i)
                    except:
                        pass
                
    def __getitem__(self, index):
        return self._data[index], self._label[index]

    def __len__(self):
        return len(self._data)

def create_dataset_zhongyao(dataset_dir,usage,resize,batch_size,workers):
    data = Iterable(dataset_dir)
    data_set = GeneratorDataset(data,column_names=['image','label'])
    trans = []
    if usage == "train":
        trans += [
            vision.RandomCrop(700, (4, 4, 4, 4)),
            vision.RandomHorizontalFlip(prob=0.5)
        ]

    trans += [
        vision.Resize((resize,resize)),
        vision.Rescale(1.0 / 255.0, 0.0),
        vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        vision.HWC2CHW()
    ]

    target_trans = transforms.TypeCast(mstype.int32)
    data_set = data_set.map(
        operations=trans,
        input_columns='image',
        num_parallel_workers=workers)

    data_set = data_set.map(
        operations=target_trans,
        input_columns='label',
        num_parallel_workers=workers)

    data_set = data_set.batch(batch_size,drop_remainder=True)
    return data_set
模型训练

      我们采用交叉熵作为损失函数,Momentum优化器进行模型参数优化。通过MindSpore的Model接口进行训练。

from mindspore import Model
from mindspore import nn

loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.001, momentum=0.9)
model = Model(network, loss_fn=loss_fn, optimizer=optimizer, metrics={"accuracy"})

dataset_train = create_dataset_zhongyao('dataset1/zhongyiyao/train', 'train', 224, 32, 4)
dataset_val = create_dataset_zhongyao('dataset1/zhongyiyao/valid', 'valid', 224, 32, 4)

model.train(epochs=50, train_dataset=dataset_train, valid_dataset=dataset_val)
模型评估与推理

        在训练过程中,我们可以实时评估模型的性能,并保存训练效果最好的模型。

from mindspore import save_checkpoint, load_checkpoint, load_param_into_net

def evaluate_and_save_best_model(model, dataset_val, best_ckpt_path):
    best_acc = 0
    for epoch in range(50):
        acc = model.eval(dataset_val, dataset_sink_mode=False)['accuracy']
        print(f"Epoch {epoch}, Accuracy: {acc}")
        if acc > best_acc:
            best_acc = acc
            save_checkpoint(network, best_ckpt_path)
            print("Best model saved.")

evaluate_and_save_best_model(model, dataset_val, 'best_model.ckpt')

        推理部分代码如下,加载训练好的最佳模型,并对新的图片进行分类:

best_ckpt_path = 'best_model.ckpt'
net = resnet50(num_classes=12)
param_dict = load_checkpoint(best_ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

def predict_one(input_img):
    dataset_one = create_dataset_zhongyao(input_img, 'test', 224, 1, 1)
    data = next(dataset_one.create_tuple_iterator())
    output = model.predict(ms.Tensor(data[0]))
    pred = output.asnumpy().argmax(axis=1)
    return pred

print(predict_one('dataset1/zhongyiyao/test/sz_tg/IMG_0001.JPG'))
结果可视化

        我们可以通过可视化训练过程中准确率和损失的变化,直观展示模型的训练效果。

import matplotlib.pyplot as plt

def plot_training_results(acc_list, loss_list):
    epochs = range(1, len(acc_list) + 1)
    plt.subplot(1, 2, 1)
    plt.plot(epochs, acc_list, label="Accuracy")
    plt.title("Accuracy over Epochs")
    plt.subplot(1, 2, 2)
    plt.plot(epochs, loss_list, label="Loss")
    plt.title("Loss over Epochs")
    plt.show()

plot_training_results(acc_list, loss_list)
结语

        通过本次实验,我们成功构建并应用了ResNet50模型,对中药炮制饮片的质量进行了精准的智能化分类判断。中药炮制作为中医药的重要组成部分,其炮制火候的判断历来依赖老药工的丰富经验。然而,随着人工智能技术的迅速发展,我们借助深度学习模型有效地实现了这一经验的传承和智能化,解决了传统经验判断可能失传的问题。通过数据集的准备、网络的构建、模型的训练与验证,我们发现ResNet50在中药饮片分类任务中展现了出色的表现,准确率极高,进一步验证了其在图像分类领域的优势。

        未来,我们将继续探索更多深度学习模型的应用与改进,优化网络结构和算法,进一步提升模型在多样化炮制饮片中的判断能力。同时,也将尝试引入其他先进的算法,例如Transformer等新兴模型,探索它们在中药智能化领域的应用潜力。希望在这个过程中,能够与大家一起不断学习和进步,共同推动中医药智能化发展的新前景,助力中药现代化与人工智能的深度融合,实现更广泛的创新应用。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!


原文地址:https://blog.csdn.net/ljd939952281/article/details/142702335

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!