ViT(Vision Transformer)简介
ViT(Vision Transformer)简介
ViT,全称Vision Transformer,是一种将Transformer模型应用于计算机视觉任务的创新方法。Transformer模型最初是为自然语言处理任务设计的,但ViT成功地将其扩展到图像分类等视觉任务中。本文将详细介绍ViT的结构、原理,并结合医学图像分类的具体应用场景,提供一个完整的代码示例。
ViT的结构和原理
ViT的核心思想是将图像分割成小块(patches),然后将这些小块的序列输入到Transformer模型中进行处理。这个过程可以分为以下几个步骤:
图像分割(Image Patching)
首先,将输入图像分割成固定大小的小块。例如,将224x224的图像分割成16x16的小块,共有(224/16)²=196个图像块。
线性投影(Linear Projection of Flattened Patches)
每个图像块展平为一个向量,然后通过一个线性变换映射到更高维的向量空间。假设每个图像块展平成一个长度为N的向量,通过线性投影将其映射到D维空间。这个过程可以表示为:
z
0
i
=
x
p
i
E
,
其中
x
p
i
∈
R
N
,
E
∈
R
N
×
D
z_0^i = x_p^iE, \quad \text{其中} \quad x_p^i \in \mathbb{R}^{N}, \quad E \in \mathbb{R}^{N \times D}
z0i=xpiE,其中xpi∈RN,E∈RN×D
其中,
x
p
i
x_p^i
xpi是第i个图像块的向量,
E
E
E是可学习的线性投影矩阵。
添加位置编码(Positional Encoding)
由于Transformer模型不具备位置感知能力,需要为每个图像块添加位置信息。位置编码可以使用固定的正弦和余弦函数,也可以是可学习的。公式如下:
z
0
i
=
x
p
i
E
+
E
p
o
s
i
,
其中
E
p
o
s
i
∈
R
D
z_0^i = x_p^iE + E_{pos}^i, \quad \text{其中} \quad E_{pos}^i \in \mathbb{R}^{D}
z0i=xpiE+Eposi,其中Eposi∈RD
其中,
E
p
o
s
i
E_{pos}^i
Eposi是第i个图像块的位置信息。
Transformer编码器(Transformer Encoder)
带有位置信息的图像块序列输入到标准的Transformer编码器中。Transformer编码器包含多层多头自注意力机制和前馈神经网络。每一层的自注意力机制可以表示为:
Attention
(
Q
,
K
,
V
)
=
softmax
(
Q
K
T
d
k
)
V
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
Attention(Q,K,V)=softmax(dkQKT)V
其中,
Q
Q
Q、
K
K
K、
V
V
V分别是查询、键、值矩阵,
d
k
d_k
dk是缩放因子。
分类头(Classification Head)
Transformer编码器的输出经过一个分类头,用于图像分类任务。通常会在图像块序列的开头添加一个分类令牌(CLS token),最后的分类结果基于该分类令牌的输出。
公式总结
- 线性投影:
z 0 i = x p i E z_0^i = x_p^iE z0i=xpiE - 位置编码:
z 0 i = x p i E + E p o s i z_0^i = x_p^iE + E_{pos}^i z0i=xpiE+Eposi - 自注意力机制:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
ViT的使用场景
ViT主要用于图像分类任务,但其应用范围远不止于此。ViT还可以用于目标检测、图像分割等其他计算机视觉任务,表现出比传统卷积神经网络(CNN)更优越的性能。
医学图像分类应用示例
为了更好地理解ViT的应用场景,我们以一个具体的应用场景为例:在医学图像分类中的应用。假设我们需要构建一个模型来自动识别医学图像中的不同类型的病变,如区分正常组织和癌变组织。使用ViT可以帮助我们利用自注意力机制,更好地捕捉图像中的细微差异,提高分类精度。
数据集介绍
我们使用的是Kaggle上的皮肤病变分类数据集,该数据集包含各种类型的皮肤病变图像,每个图像都标注了具体的病变类型。我们将使用ViT模型对这些图像进行分类。
代码示例
以下是一个详细的代码示例,展示如何使用ViT进行医学图像分类。
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split
from transformers import ViTForImageClassification, ViTFeatureExtractor
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
# 定义数据变换
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
# 自定义数据集类
class MedicalDataset(torch.utils.data.Dataset):
def __init__(self, csv_file, root_dir, transform=None):
self.annotations = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.annotations)
def __getitem__(self, idx):
img_path = f"{self.root_dir}/{self.annotations.iloc[idx, 0]}"
image = Image.open(img_path).convert("RGB")
label = int(self.annotations.iloc[idx, 1])
if self.transform:
image = self.transform(image)
return image, label
# 加载数据集
dataset = MedicalDataset(csv_file='path/to/annotations.csv', root_dir='path/to/images', transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
# 加载ViT模型和特征提取器
model = ViTForImageClassification.from_pretrained('path/to/vit-base-patch16-224-in21k', num_labels=2)
feature_extractor = ViTFeatureExtractor.from_pretrained('path/to/vit-base-patch16-224-in21k')
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# 训练模型
for epoch in range(10):
model.train()
for batch in train_loader:
images, labels = batch
inputs = feature_extractor(images, return_tensors="pt")['pixel_values']
outputs = model(inputs)
loss = criterion(outputs.logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f'Epoch {epoch+1}, Loss: {loss.item()}')
# 验证模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch in val_loader:
images, labels = batch
inputs = feature_extractor(images, return_tensors="pt")['pixel_values']
outputs = model(inputs)
_, predicted = torch.max(outputs.logits, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Validation Accuracy: {100 * correct / total}%')
详细解释
-
数据集准备:
我们首先定义了一个自定义数据集类MedicalDataset
,用于加载医学图像数据。数据集的标注文件是一个CSV文件,其中包含图像文件名和对应的标签。 -
数据变换:
我们使用transforms.Compose
定义了一系列数据变换,包括将图像调整为224x224大小,并转换为张量。 -
数据加载:
使用random_split
将数据集分为训练集和验证集,并使用DataLoader
加载数据。 -
模型加载:
我们使用ViTForImageClassification
加载预训练的ViT模型,并设置输出类别数为2(正常和癌变)。同时,加载特征提取器ViTFeatureExtractor
。 -
训练和验证:
定义了损失函数和优化器,并在训练过程中通过反向传播更新模型参数。在每个epoch结束后,打印当前的损失值。验证阶段,我们计算模型在验证集上的准确率。
总结
ViT通过将图像分割成小块,并使用Transformer模型进行处理,在图像分类任务中取得了显著的效果。希望本文对你有所帮助!
原文地址:https://blog.csdn.net/mieshizhishou/article/details/140377305
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!