自学内容网 自学内容网

机器学习周报(12.2-12.8)

摘要

本周学习了Vision Transformer (ViT) 的基本原理及其实现,并完成了基于PyTorch的模型训练、验证和预测任务。深入理解了ViT如何将图像分割成patch作为输入序列,并结合Transformer Encoder处理。通过迁移学习在花类数据集上训练模型,并验证了模型在预测任务中的优越性能。

Abstract

This week, I studied the fundamental principles and implementation of Vision Transformer (ViT) and completed model training, validation, and prediction tasks using PyTorch. I gained a deep understanding of how ViT splits an image into patches as input sequences and processes them using the Transformer Encoder. By leveraging transfer learning, I trained the model on a flower dataset and validated its superior performance in prediction tasks.

Vision Transformer

1 原理

  • 数据处理

我认为ViT的关键在于理解怎么将图片当作一个序列输入进模型之中。我们先看看ViT整体结构图,如下图所示

在这里插入图片描述
论文中提到将 224x224x3 的图像作为输入,将图像分为 16x16x3 大小的patch,也就是说将输入图像分为了 224 × 224 × 3 16 × 16 × 3 = 196 \frac{224×224×3}{16×16×3}=196 16×16×3224×224×3=196 个patch。其中每个patch拉直之后的维度为 16×16×3=768维,也就是Linear Projection of Flattened Patches层下面分割的小图像。

在具体实现中,使用卷积核大小为 16x16x3 、步距为16、卷积核个数为768的卷积层,就能将3维图像转换为Transformer所需要的输入token[组数,维度]。

  • 全连接层
    上述[196,768]的token将传入Linear Projection of Flattened Patches层,该层是 768x768 的全连接层,该层输出认为 196x768 。

  • 位置编码
    将经过全连接层后的输出进行位置编码,其位置编码和Transformer中的时序编码有异曲同工之妙,前者可以通过位置编码表示出token之间关于原输入图像的一些位置信息,后者可以表示输入先后的时序信息。
    该模型位置编码通过类似于坐标的形式表达,直接于输入相加,不改变维度大小。如下图所示:
    在这里插入图片描述

进行位置编码后,还需要加上一个特殊字符(最左输入0*),输入总组数从之前的196变为197,传入Transformer Encoder的token为[197,768]。

  • Transformer Encoder
    在这里插入图片描述
    ViT采用的是Transformer中编码器进行叠加,但其中的参数数量有所不同。
    经过位置编码和加入特殊字符的token[197,768]传入编码器,首先经过层归一化,再经过多头自注意力。这里的多头自注意力是采用12个头,也就是将768维分为12份,每份(Q、K、V)64维度,计算之后再进行合并为768维。

ViT中的编码器仍是采用残差连接,再经过一次层归一化后,就进入单个Transformer Encoder的最后一层MLP(多层感知机)。MLP将经过多头自注意力的输出维度升高4倍,即从768变为3072,最后再将维度降至768维

ViT中的编码器输入和输出都是768维,也就是在硬件运行的情况下一直叠加,论文中也是将该模块叠加了L块。

  • 输出
    ViT中的编码器输入和输出都是768维,也就是在硬件运行的情况下一直叠加,论文中也是将该模块叠加了L块。

在这里插入图片描述

最后,通过全连接层和softmax进行概率输出即可

2 代码

在理解完ViT的原理之后,我们来看看PyTorch代码如何实现。这里以ViT-base模型,输入图像 224x224x3,patch大小 16x16x3 为例

花类数据集:
https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzy

训练模型代码如下,需要自行更改数据集路径和权重路径。

import os
import math
import argparse
 
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
 
from my_dataset import MyDataSet
from vit_model import vit_base_patch16_224_in21k as create_model
from utils import read_split_data, train_one_epoch, evaluate
 
 
def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
 
    if os.path.exists("../weights") is False:
        os.makedirs("../weights")
 
    tb_writer = SummaryWriter()
 
    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
 
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}
 
    # 实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])
 
    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])
 
    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)
 
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)
 
    model = create_model(num_classes=args.num_classes, has_logits=False).to(device)
 
    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)
        # 删除不需要的权重
        # del_keys = ['head.weight', 'head.bias'] if model.has_logits \
        #     else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias']
        del_keys = ['head.weight', 'head.bias']
        for k in del_keys:
            del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))
 
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除head, pre_logits外,其他权重全部冻结
            if "head" not in name and "pre_logits" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))
 
    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
 
    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)
 
        scheduler.step()
 
        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)
 
        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)
 
        torch.save(model.state_dict(), "../weights/model-{}.pth".format(epoch))
 
 
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lrf', type=float, default=0.01)
 
    # 数据集所在根目录
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data_path', type=str, default='../data/flower_photos', help='path to dataset')
    parser.add_argument('--model-name', default='', help='create model name')
 
    # 预训练权重路径,如果不想载入就设置为空字符
    parser.add_argument('--weights', type=str, default='../weights/vit_base_patch16_224.pth', help='path to initial weights')
    # 是否冻结权重
    parser.add_argument('--freeze-layers', type=bool, default=True)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
 
    opt = parser.parse_args()
 
    main(opt)

训练结果如下:

在这里插入图片描述

因为是迁移学习的原因,只需要进行微调即可,所以9epoch之后准确率就达到97.9%了。

每训练一个epoch,就会将训练模型保存至weights文件夹,如下图所示

在这里插入图片描述

通过上述代码的训练之后,我们可以将保存的模型model-9.pth引入预测代码进行预测啦!需自行更改权重路径,以及需要测试的图片路径。

import os
import json
 
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
 
from vit_model import vit_base_patch16_224_in21k as create_model
 
 
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
    data_transform = transforms.Compose(
        [transforms.Resize(254),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
 
    # load image
    img_path = "../data/Image/flower.png"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    img2 = img
    plt.imshow(img)
    plt.show()
 
    img = img.convert('RGB')
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)  # [1, 3, 224, 224]
 
    # read class_indict
    json_path = 'class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
 
    with open(json_path, "r") as f:
        class_indict = json.load(f)
 
    # create model
    model = create_model(num_classes=5, has_logits=False).to(device)  # num_classes=5:表示模型将被训练来识别5个不同的类别;has_logits=False:模型不直接输出logits,在实际应用中,这通常意味着模型的输出层之后可能会跟随一个softmax激活函数
    # load model weights
    model_weight_path = "../weights/model-9.pth"  # 采用第10轮训练的参数
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()
 
    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)], predict[i].numpy()))
    plt.imshow(img2)
    plt.show()
 
 
if __name__ == '__main__':
    main()

模型预测结果如下所示:
在这里插入图片描述

在这里插入图片描述

模型预测结果几乎100%为sunflowers


原文地址:https://blog.csdn.net/weixin_51923997/article/details/144315913

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!