自学内容网 自学内容网

【ViT】对图片进行分类(论文复现)

【ViT】对图片进行分类(论文复现)

本文所涉及所有资源均在传知代码平台可获取

概述

Transformer架构虽然已经成为自然语言处理任务的标准,但是它在计算机视觉的应用仍然有限,先前的视觉任务中,注意力大多与卷积结合使用。ViT模型的出现,证明了对CNN的依赖是不必要的,直接应用于图像补丁序列的纯Transformer架构可以在图像分类任务中表现良好

模型结构
模型总体框架

在这里插入图片描述

上述是ViT模型的基本框架,可以大致分为三个主要部分

  • Patch_embed(将图片分成一系列的patches)
  • Transformer Encoder(建模不同序列之间的相关性)
  • MLP Head(用于最终的分类结构)
Patch_embed

在标准的Transformer模块中,输入的格式为二维矩阵 [num_token,token_dim] ,但对于图像数据而言,其输入数据的格式为[H,W,C] 的三维矩阵,明显不是Transformer架构需要的。所以需要Patch_embed结构将其转换为Transformer架构的输入。

针对于ViT-B/16而言,将输入图片(224x224)按照大小为(16x16) 的Patch进行划分,生成196个Patch。此时通过线性映射将每个Patch映射到一个长度为768 (16x16x3) 一维向量中。这一步可以通过卷积核大小为16x16,步距为 16 的卷积来实现。最后将长宽进行展平,则得到Transformer需要的输入格式。具体的维度变换如下所示:
[224,224,3] -> [14,14,768] -> [196,768]

在输入到Transformer Encoder之前还需要加上 [class]token(z00=xclass*z*00=*x**c**l**a**s**s*),它在Transformer 编码器 zL0z**L0 输出处的状态用作图像表示 yy , 在预训练和微调过程中,zL0z**L0 处都具有一个分类头。

同时需要将Position Embeddin[197,768]叠加(add)到上述的token上

在这里插入图片描述

如上图所示,第一行第一列的位置编码上与其自身的余弦相似度最高,其次是与第一行和第一列的余弦相似度更高,这符合常理

Transformer Encoder

Transformer Encoder 本身是堆叠Encoder Block L 次,ViT-B/16是12次。主要有以下几部分组成:

  • Layer Norm: 针对NLP领域提出,因为在RNN这类时序网络中,时序的长度并不一定是一个定值,Layer Norm在每个样本的每个特征维度上进行归一化,使得每个特征的均值为0,方差为1,从而有助于提高模型的训练效果和泛化能力。
  • Multi-head Attention: 使用多头注意力机制能够联合来自不同head部分学习到的信息。
  • MLP Block:由全连接+GELU激活函数+Dropout组成,在ViT-B/16的模型结构中,第一个全连接层将输入节点的个数翻4倍,第二个全连接层键还原节点的个数
MLP Head

通过Transfomer Encoder后输入的shape和输出的shape保持不变,由于我们只需要分类信息,因此只需要提取[class]token 的结果 zL0z**L0 ,之后通过MLP Head得到最后的分类结果。

模型的公式如下,其中E表示token的个数

在这里插入图片描述

演示效果

可视化输入图片的形式

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

可视化模型运行结果

在这里插入图片描述

核心逻辑

对输入图片进行分块处理

class PatchEmbed(nn.Module):
    def __init__(self,img_size=224,patch_size=16,in_c=3,embed_dim=768,norm_layer=None):
        super(PatchEmbed,self).__init__()
        img_size = (img_size,img_size)
        patch_size = (patch_size,patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size[0]//patch_size[0])*(img_size[1]//patch_size[1])
        self.proj = nn.Conv2d(in_c,embed_dim,kernel_size=patch_size,
                               stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim) if norm_layer else nn.Identity()
    
    def forward(self,x):
        # 首先需要判断输入图片的大小符合我们的预期
        B,C,H,W=x.shape
        assert H==self.img_size[0] and W==self.img_size[1],\
            f"input image{H}x{W} does not model {self.img_size[0]}x{self.img_size[1]}"
        # [N,in_c,H,W]->[N,embed_dim,H//16,W//16]->[N,embed_dim,H//16*W//16]
        x = self.proj(x).flatten(2).transpose(1,2)
        
        x = self.norm(x)
        
        return x

多头注意力机制

class Attention(nn.Module):
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0,
                 proj_drop_ratio=0):
        super(Attention,self).__init__()
        self.head_dim = dim//num_heads
        self.num_heads = num_heads
        self.dim = dim
        self.scale = qk_scale or self.head_dim**(0.5)
        
        
        self.qkv = nn.Linear(dim,dim*3,bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim,dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)
        
    def forward(self,x):
        # [batch_size,num_patches+class_token,channel:HxW]
        B,N,C = x.shape
        
        # 将其进行投影,也就是多头自注意力机制所说的矩阵相乘
        # reshape [B,N,C]->[B,N,3,heads,head_dim]->[3,B,heads,N,head_dim]
        qkv = self.qkv(x).reshape(B,N,3,self.num_heads,self.head_dim).permute(2,0,3,1,4)
        # [B,heads,N,head_dim]
        q,k,v=qkv[0],qkv[1],qkv[2]
        # [B,heads,N,N]
        attn = (q@k.transpose(-2,-1))*self.scale
        
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # [B,heads,N,head_dim]->[B,N,heads,head_dim]->[B,N,heads*head_dim]
        # x = attn@v.permute(0,2,1,3).flatten(2)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        
        x = self.proj_drop(self.proj(x))
        return x

MLP 模块

class MLP(nn.Module):
    def __init__(self,in_features,hidden_features=None,out_features=None,act_layer=nn.GELU,drop=0.):
        super(MLP,self).__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features,hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features,out_features)
        self.drop = nn.Dropout(drop)
    
    # 根据流程图确定其中的结构,注意是先激活函数之后才是dropout操作   
    def forward(self,x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

Block 结构

class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 drop_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm
                 ):
        super(Block,self).__init__()
        
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim=dim,num_heads=num_heads,qkv_bias=qkv_bias,qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio,proj_drop_ratio=drop_ratio)
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        # Linear都需要是int的类型数据
        self.mlp = MLP(dim,int(dim*mlp_ratio),dim,act_layer,drop_ratio)
        
        self.norm2 = norm_layer(dim)
        
    def forward(self,x):
        x = x+self.drop_path(self.attn(self.norm1(x)))
        x = x+self.drop_path(self.mlp(self.norm2(x)))
        
        return x

ViT 模块

class VisionTransformer(nn.Module):
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_c=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.0,
                 qkv_bias=True,
                 qk_scale=None,
                 representation_size=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 embed_layer=PatchEmbed,
                 norm_layer=None,
                 act_layer=None,
                 ):
        super(VisionTransformer,self).__init__()
        # 首先需要进行初始化操作,还可以对权重进行初始化操作
        self.num_classes = num_classes
        self.embed_dim = self.num_features = embed_dim
        self.num_tokens = 1 
        act_layer = act_layer or nn.GELU
        norm_layer = norm_layer or partial(nn.LayerNorm,eps=1e-6)
        
        self.patch_embed = embed_layer(img_size=img_size,patch_size=patch_size,in_c=in_c,embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        
        # cls_token是针对于每个embed_dim确定一个class
        # pos_embed除了channel 还要针对于每一个patch确定结果
        self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1,num_patches+self.num_tokens,embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)
        
        dpr = [x.item() for x in torch.linspace(0,drop_path_ratio,depth)]
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,qk_scale=qk_scale,
                  drop_ratio=drop_ratio,attn_drop_ratio=attn_drop_ratio,drop_path_ratio=dpr[i],
                  norm_layer=norm_layer,act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
        
        # Pre_logits layer 相当于多添加了一个全连接层
        if representation_size:
            self.has_logits=True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc',nn.Linear(embed_dim,representation_size))
                ('out',nn.Tanh())]
            ))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()
            
        self.head = nn.Linear(self.num_features,self.num_classes) if num_classes>0 else nn.Identity()
        
        # 开始对所有的权重进行初始化操作
        nn.init.trunc_normal_(self.pos_embed,std=0.02)
        nn.init.trunc_normal_(self.cls_token,std=0.02)
        self.apply(_init_vit_weights)
        
    def forward(self,x):
        B,C,H,W = x.shape
        #[B,C,H,W]->[B,N,H*W]
        x = self.patch_embed(x)
        # 每次都需要进行操作,所以不能对其本身进行expand操作
        cls_token = self.cls_token.expand(B,-1,-1)
        
        # 注意到后续一个是cat操作一个是add操作,且位置的先后关系
        x = torch.cat((cls_token,x),dim=1)
        
        # self.pos_embed中针对于一个batch值共享
        x = self.pos_drop(x+self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        
        
        x = self.pre_logits(x[:,0])
        x = self.head(x)
        return x  

文章代码资源点击附件获取


原文地址:https://blog.csdn.net/weixin_62765017/article/details/142815457

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!