自学内容网 自学内容网

YOLO11改进|注意力机制篇|引入注意力机制Shuffle Attention

在这里插入图片描述

一、【Shuffle Attention】注意力机制

1.1【Shuffle Attention】注意力介绍

在这里插入图片描述

下图是【Shuffle Attention】的结构图,让我们简单分析一下运行过程和优势

处理过程

  • 输入和分组(Group):
  • 输入特征 𝑋大小为 𝑤×ℎ×𝑐(即宽度、高度和通道数),首先进行 通道分组(Group) 操作,将输入特征按通道划分为 𝑔组,每组包含 𝑐/𝑔个通道。这一步通过分组降低了单个操作的计算复杂度,并提高了并行处理的能力。
  • 通道分裂(Split):
  • 每个分组进一步被分为两部分,每部分包含 𝑐/2𝑔个通道,分别进入两个不同的特征处理路径(上方绿色路径和下方蓝色路径)。这一步通过分裂通道,使得每一组的通道可以分别进行特征处理,增加特征多样性。
  • 特征融合(Fuse):
  • 对于每个分支,特征会经过一个自适应的加权处理,生成加权后的特征图。这些加权操作通过 元素乘积 和 Sigmoid 函数进行,实现特征的选择性增强。绿色路径使用了全局池化 𝐹𝑔𝑝 ,而蓝色路径则通过组归一化(GN)实现更细粒度的特征处理。
  • 通道融合与打乱(Channel Shuffle):
  • 两条路径的特征在完成各自的加权融合后,通过 通道拼接(Concatenate) 操作整合为一个新的特征图,并使用 通道打乱(Channel Shuffle) 技术重新排列通道顺序。通道打乱有助于消除分组带来的通道隔离问题,增强不同通道之间的信息交互。
  • 特征聚合(Aggregate):
  • 最终,将重新排列的特征图进行聚合(Aggregate)操作,得到输出特征 𝑆。这一过程将所有处理后的特征进行整合,输出经过优化的特征图,准备用于后续任务。
    优势
  • 计算效率高:
  • 通过通道分组(Group)操作,该模块有效降低了每个卷积和操作的计算成本。这使得网络能够在更高效的计算框架下工作,尤其适用于大规模数据和高维度特征的处理。
  • 增强特征表达能力:
  • 通过将通道分裂为不同的路径(绿色和蓝色路径),网络能够从不同维度对特征进行加工和处理,增加了特征的多样性。此外,融合(Fuse)操作通过自适应加权的方式进一步提升了关键特征的表达能力。
  • 通道打乱改善信息流动:
  • 通道打乱(Channel Shuffle)通过重新排列通道,打破了原本通道分组导致的信息孤立问题,增强了不同通道之间的信息交互。这使得分组操作不再限制通道之间的联系,提高了特征的共享和传递效率。
  • 自适应的特征选择:
  • 通过 Sigmoid 激活和加权操作,该模块能够自适应地调整特征的重要性,增强了模型的选择性关注,使得网络能够更加聚焦于有用的特征,减少了不必要的信息干扰。
    在这里插入图片描述

1.2【Shuffle Attention】核心代码

import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter


class ShuffleAttention(nn.Module):

    def __init__(self, channel=512, reduction=16, G=8):
        super().__init__()
        self.G = G
        self.channel = channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)

        # flatten
        x = x.reshape(b, -1, h, w)

        return x

    def forward(self, x):
        b, c, h, w = x.size()
        # group into subfeatures
        x = x.view(b * self.G, -1, h, w)  # bs*G,c//G,h,w

        # channel_split
        x_0, x_1 = x.chunk(2, dim=1)  # bs*G,c//(2*G),h,w

        # channel attention
        x_channel = self.avg_pool(x_0)  # bs*G,c//(2*G),1,1
        x_channel = self.cweight * x_channel + self.cbias  # bs*G,c//(2*G),1,1
        x_channel = x_0 * self.sigmoid(x_channel)

        # spatial attention
        x_spatial = self.gn(x_1)  # bs*G,c//(2*G),h,w
        x_spatial = self.sweight * x_spatial + self.sbias  # bs*G,c//(2*G),h,w
        x_spatial = x_1 * self.sigmoid(x_spatial)  # bs*G,c//(2*G),h,w

        # concatenate along channel axis
        out = torch.cat([x_channel, x_spatial], dim=1)  # bs*G,c//G,h,w
        out = out.contiguous().view(b, -1, h, w)

        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out



二、添加【Shuffle Attention】注意力机制

2.1STEP1

首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个ShuffleAttention.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示在这里插入图片描述

2.2STEP2

在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示在这里插入图片描述

2.3STEP3

找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加在这里插入图片描述

2.4STEP4

定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】

在这里插入图片描述

三、yaml文件与运行

3.1yaml文件

以下是添加【Shuffle Attention】注意力机制在Backbone中的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128,3,2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256,3,2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512,3,2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024,3,2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1,1,ShuffleAttention,[]]
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 14], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 11], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)

  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

以上添加位置仅供参考,具体添加位置以及模块效果以自己的数据集结果为准

3.2运行成功截图

在这里插入图片描述

OK 以上就是添加【Shuffle Attention】注意力机制的全部过程了,后续将持续更新尽情期待

在这里插入图片描述


原文地址:https://blog.csdn.net/A1983Z/article/details/142894214

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!