自学内容网 自学内容网

深度学习之基于Pytorch框架的Unet分割模型搭建学习

1. U-Net模型框架简介

论文地址:U-Net: Convolutional Networks for Biomedical Image Segmentation
代码地址:https://github.com/laisimiao/Unet

自2015年诞生以来,U-Net便以其卓越的性能在生物医学图像分割领域崭露头角。作为FCN的一种变体,U-Net凭借其Encoder-Decoder的精巧结构,不仅在医学图像分析中大放异彩,更在卫星图像分割、工业瑕疵检测等多个领域展现出强大的应用能力。U-Net是一种常用于图像分割的卷积神经网络架构,其特点在于其U型结构,包括一个收缩路径(下采样/编码器)和一个扩展路径(上采样/解码器)。这种结构使得U-Net能够在捕获上下文信息的同时,也能精确地定位到目标边界。

(1)编码器Encoder:通过连续的卷积和池化操作,逐步减小特征图的尺寸,从而捕获到图像的上下文信息。

(2)解码器Decoder:通过上采样操作逐步恢复特征图的尺寸,并与Encoder中对应尺度的特征图进行拼接(concatenate),以融合不同尺度的特征信息。
(3)跳跃连接:U-Net中的跳跃连接使得Decoder能够利用到Encoder中的高分辨率特征,从而提高了分割的精度。

(4)输出层:U-Net的输出层通常是一个1x1的卷积层,用于将特征图转换为与输入图像相同尺寸的分割图。

在这里插入图片描述

2. 安装pytorch并导入所需要库

这里采用的是pytorch框架,需要安装相应的pytorch版本,电脑支持cuda的话安装cuda版本的(注意:如果安装cuda版的pytorch需要先注意pytorch、torchvision、cuda以及python的版本兼容)。安装完成后导入需要用到的模块。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

3. 数据集准备与加载

准备好模型训练及测试的数据集,并定义相关的数据集加载程序

3.1 数据集准备

用于模型训练及测试的数据集一方面可以从开源网站上下载,另外也可以自己制作数据集。用于分割的数据集包含:原始图片+对应的mask图像。自己制作数据集时,mask图像通常采用labelme标记感兴趣区域。
在这里插入图片描述

3.2 定义数据集加载程序

数据准备好后,需要自定义数据集类用于读取数据。
Dataset和DataLoader都是用来帮助我们加载数据集的两个重要工具类。Dataset和DataLoader是一起使用的,在模型训练的过程中不断为模型提供数据,同时,使用Dataset加载出来的数据集也是DataLoader的第一个参数。所以,DataLoader本质上就是用来将已经加载好的数据以模型能够接收的方式输入到即将训练的模型中去。

Dataset简介及用法
Dataset本质上就是一个抽象类,可以把数据封装成Python可以识别的数据结构。
Dataset类不能实例化,所以在使用Dataset的时候,我们需要定义自己的数据集类,也是Dataset的子类,来继承Dataset类的属性和方法。
Dataset可作为DataLoader的参数传入DataLoader,实现基于张量的数据预处理。

Dataset主要有两种类型,分别为Map-style datasets和Iterable-style datasets
Map-style datasets类型 实现了__getitem__()和__len__()方法,它代表数据的索引到真正数据样本的映射。读取的数据并非直接把所有数据读取出来,而是读取的数据的索引或者键值这种类型是使用最多的类型,采用这种访问数据的方式可以大大节约训练时需要的内存数量,提高模型的训练效率 。
Iterable-styledatasets类型 实现了__iter__()方法,与上述类型不同之处在于,他会将真实的数据全部载入,然后在整个数据集上进行迭代, 这种读取数据的方式比较适合处理流数据

自己定义数据集子类
上面我们提到,Dataset作为一个抽象类,需要定义其子类来实例化。所以我们需要自己定义其子类或者使用已经定义好的子类。必须要继承已经内置的抽象类dataset 必须要重写其中的__init__()方法、getitem()方法和__len__()方法 其中__getitem__()方法实现通过给定的索引遍历数据样本,len()方法实现返回数据的条数

load_data.py

from torch.utils.data import Dataset
import os
import cv2
import numpy as np


class MyDataset(Dataset):
    def __init__(self, train_path, transform=None):
        self.images = os.listdir(train_path + '/last')
        self.labels = os.listdir(train_path + '/last_msk')
        assert len(self.images) == len(self.labels), 'Number does not match'
        self.transform = transform
        self.images_and_labels = []    # 存储图像和标签路径
        for i in range(len(self.images)):
            self.images_and_labels.append((train_path + '/last/' + self.images[i], train_path + '/last_msk/' + self.labels[i]))

    def __getitem__(self, item):
        img_path, lab_path = self.images_and_labels[item]
        img = cv2.imread(img_path)
        img = cv2.resize(img, (224, 224))
        lab = cv2.imread(lab_path, 0)
        lab = cv2.resize(lab, (224, 224))
        lab = lab / 255    # 转换成0和1
        lab = lab.astype('uint8')    # 不为1的全置为0
        lab = np.eye(2)[lab]    # one-hot编码
        lab = np.array(list(map(lambda x: abs(x-1), lab))).astype('float32')   # 将所有0变为1(1对应255, 白色背景),所有1变为0(黑色,目标)
        lab = lab.transpose(2, 0, 1)  # [224, 224, 2] => [2, 224, 224]
        if self.transform is not None:
            img = self.transform(img)
        return img, lab

    def __len__(self):
        return len(self.images)


if __name__ == '__main__':
    img = cv2.imread('data/train/last_msk/150.jpg', 0)
    img = cv2.resize(img, (16, 16))
    img2 = img/255
    img3 = img2.astype('uint8')
    hot1 = np.eye(2)[img3]
    hot2 = np.array(list(map(lambda x: abs(x-1), hot1)))
    print(hot2.shape)
    print(hot2.transpose(2, 0, 1))

原文链接:https://blog.csdn.net/jiebaoshayebuhui/article/details/130439027

4. 模型搭建

模型结构可以从网上copy来直接用,也可以参考一下自己写。
model.py

import torch
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.encode1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)
        )
        self.encode2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)
        )
        self.encode3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)
        )
        self.encode4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)
        )
        self.encode5 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)
        )
        self.decode1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True)
        )
        self.decode2 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, 2, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True)
        )
        self.decode3 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, 2, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )
        self.decode4 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, 2, 1, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(True)
        )
        self.decode5 = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, 2, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(True)
        )
        self.classifier = nn.Conv2d(16, 2, kernel_size=1)

    def forward(self, x):           # b: batch_size
        out = self.encode1(x)       # [b, 3, 224, 224]  =>  [b, 64, 112, 112]
        out = self.encode2(out)     # [b, 64, 112, 112] =>  [b, 128, 56, 56]
        out = self.encode3(out)     # [b, 128, 56, 56]  =>  [b, 256, 28, 28]
        out = self.encode4(out)     # [b, 256, 28, 28]  =>  [b, 512, 14, 14]
        out = self.encode5(out)     # [b, 512, 14, 14]  =>  [b, 512, 7, 7]
        out = self.decode1(out)     # [b, 512, 7, 7]    =>  [b, 256, 14, 14]
        out = self.decode2(out)     # [b, 256, 14, 14]  =>  [b, 128, 28, 28]
        out = self.decode3(out)     # [b, 128, 28, 28]  =>  [b, 64, 56, 56]
        out = self.decode4(out)     # [b, 64, 56, 56]   =>  [b, 32, 112, 112]
        out = self.decode5(out)     # [b, 32, 112, 112] =>  [b, 16, 224, 224]
        out = self.classifier(out)  # [b, 16, 224, 224] =>  [b, 2, 224, 224]   2表示类别数,目标和非目标两类
        return out


if __name__ == '__main__':
    img = torch.randn(2, 3, 224, 224)
    net = Net()
    sample = net(img)
    print(sample.shape)

5. 模型训练

5.1 加载数据集

Dataset用来构造支持索引的数据集,在训练时需要使用DataLoader在全部样本中拿出小批量数据参与每次的训练。

DataLoader函数参数:
Dataset:通过上一节Dataset加载出来的数据集
batch_size:每个batch加载多少个样本
shuffle:是否打乱输入数据的顺序
例如:
Data_size=10,batch_size=3,一次Epoch需要四次Iteration,第一列为所有样本,第二列为打乱之后的所有样本,由于batch_size=3,所以通过DataLoader输入了4个batch,包括最后一个数量已经不够3个的Batch4,里边只包含sample3

data_loader = DataLoader(MyDataset(data_path), batch_size=1, shuffle=True)

5.2 实例化模型。

num_classes = 2  # +1是背景也为一类
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(num_classes).to(device)

另外,在初始化模型时,可以加载之前训练的参数进行迁移学习

if os.path.exists(weight_path):
    net.load_state_dict(torch.load(weight_path))
    print('successful load weight!')
else:
    print('not successful load weight')

5.3 定义损失函数。

loss_fun = nn.CrossEntropyLoss()

5.4 定义优化器。

opt = optim.Adam(net.parameters())

5.5 训练模型(循环迭代数据、计算损失、优化参数等)。

    epoch = 1
    while epoch < 100:
        for i, (image, segment_image) in enumerate(tqdm.tqdm(data_loader)):
            # plt.imshow(segment_image[0,:,:])
            # plt.show()
            image, segment_image = image.to(device), segment_image.to(device)
            out_image = net(image)
            train_loss = loss_fun(out_image, segment_image.long())
            opt.zero_grad()
            train_loss.backward()
            opt.step()

            if i % 1 == 0:
                print(f'{epoch}-{i}-train_loss===>>{train_loss.item()}')

            _image = image[0]
            _segment_image = torch.unsqueeze(segment_image[0], 0) * 255
            _out_image = torch.argmax(out_image[0], dim=0).unsqueeze(0) * 255

            img = torch.stack([_segment_image, _out_image], dim=0)
            save_image(img, f'{save_path}/{i}.png')
        if epoch % 20 == 0:
            torch.save(net.state_dict(), weight_path)
            print('save successfully!')
        epoch += 1

5.6 保存模型(模型+权重)

if epoch % 20 == 0:
      torch.save(net.state_dict(), weight_path)
      print('save successfully!')

train.py

import os

import tqdm
from torch import nn, optim
import torch
from torch.utils.data import DataLoader
from data import *
from model import *
from torchvision.utils import save_image
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weight_path = 'params/unet_metal.pth'
data_path = r'data'
save_path = 'train_image'
if __name__ == '__main__':
    num_classes = 2  # +1是背景也为一类
    data_loader = DataLoader(MyDataset(data_path), batch_size=1, shuffle=True)
    net = UNet(num_classes).to(device)
    if os.path.exists(weight_path):
        net.load_state_dict(torch.load(weight_path))
        print('successful load weight!')
        print(net)
    else:
        print('not successful load weight')

    opt = optim.Adam(net.parameters())
    loss_fun = nn.CrossEntropyLoss()

    epoch = 1
    while epoch < 100:
        for i, (image, segment_image) in enumerate(tqdm.tqdm(data_loader)):
            # plt.imshow(segment_image[0,:,:])
            # plt.show()
            image, segment_image = image.to(device), segment_image.to(device)
            out_image = net(image)
            train_loss = loss_fun(out_image, segment_image.long())
            opt.zero_grad()
            train_loss.backward()
            opt.step()

            if i % 1 == 0:
                print(f'{epoch}-{i}-train_loss===>>{train_loss.item()}')

            _image = image[0]
            _segment_image = torch.unsqueeze(segment_image[0], 0) * 255
            _out_image = torch.argmax(out_image[0], dim=0).unsqueeze(0) * 255

            img = torch.stack([_segment_image, _out_image], dim=0)
            save_image(img, f'{save_path}/{i}.png')
        if epoch % 20 == 0:
            torch.save(net.state_dict(), weight_path)
            print('save successfully!')
        epoch += 1

6. 模型测试

在训练阶段,每隔20个epoch保存一个模型,并计算每个模型的IOU,其中IOU越大,并不代表该模型就是最佳的,最终需要结合模型测试结果进行评估。也即选取训练比较高的IOU进行测试,并择优选取。

test.py

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch

from net import *
from utils import *
from data import *
from torchvision.utils import save_image
from PIL import Image
num_classes = 2
net = UNet(num_classes).cuda()

weights = 'params/unet_metal.pth'
if os.path.exists(weights):
    net.load_state_dict(torch.load(weights))
    print('successfully')
else:
    print('no loading')

_input = r'.\data\TestData\3.png'
img = keep_image_size_open_rgb(_input)
plt.imshow(img)
# plt.show()
img_data = transform(img).cuda()
img_data = torch.unsqueeze(img_data, dim=0)
net.eval()
out = net(img_data)
# out = torch.squeeze(out, dim=0)  # 维度压缩

out = torch.argmax(out, dim=1)
out = torch.squeeze(out, dim=0)
out = out.unsqueeze(dim=0)
out = (out).permute((1, 2, 0)).cpu().detach().numpy()

out = np.array(out, dtype='uint8')*255
img = np.array(img, dtype='uint8')
# out = cv2.cvtColor(out, cv2.COLOR_GRAY2BGR)
# out[:, :, 1] = 0
# out = cv2.addWeighted(out, 0.1, img, 0.9, 1)
plt.figure()
plt.imshow(out)
plt.show()
from PIL import Image


def keep_image_size_open(path, size=(256, 256)):
    img = Image.open(path)
    temp = max(img.size)
    mask = Image.new('P', (temp, temp))
    mask.paste(img, (0, 0))
    mask = mask.resize(size)
    return mask


def keep_image_size_open_rgb(path, size=(256, 256)):
    img = Image.open(path)
    temp = max(img.size)
    mask = Image.new('RGB', (temp, temp))
    mask.paste(img, (0, 0))
    mask = mask.resize(size)
    return mask

7. 模型部署(应用推理)

模型部署可以采用tensorRT,Caffe等推理框架,也可以采用OpenCV4来实现模型的推理。另外,在模型部署阶段,为了使模型轻量化,可以采用剪枝、知识蒸馏的手段对模型进行处理。
在这里插入图片描述
在模型部署时,通常需要将pth文件转换为onnx文件,pth转onnx的python代码如下:

import os
import torch
import torchvision
from net import *
from utils import *
from data import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
weight_path = 'params/unet.pth'
num_classes = 3
net = UNet(num_classes).to(device)
if os.path.exists(weight_path):
    net.load_state_dict(torch.load(weight_path))
    print('successful load weight!')
    # print(net)
else:
    print('not successful load weight')

img=r'2007_001834.jpg'  # 或者随机生成大小256*256的数据
img_data=transform(img).cuda()
img_data=torch.unsqueeze(img_data, dim=0)
# 将模型导出为 ONNX 格式
torch.onnx.export(net, img_data, r"params/net_model.onnx", verbose=True)

8. 参考文献

[1] pytorch搭建分类网络并进行训练和测试
[2] Pytorch搭建训练简单的图像分割模型
[3] 十七、完整神经网络模型训练步骤
【PyTorch 实战2:UNet 分割模型】10min揭秘 UNet 分割网络如何工作以及pytorch代码实现(详细代码实现)


原文地址:https://blog.csdn.net/qq_44924694/article/details/137028349

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!