自学内容网 自学内容网

ResNeSt

paper:ResNeSt: Split-Attention Networks

official implementation:https://github.com/zhanghang1989/ResNeSt

third-party implementation:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/resnest.py

ResNeXt(具体介绍见https://blog.csdn.net/ooooocj/article/details/122394742)引入了"cardinality"超参,将特征图沿通道分为多个cardinal group,具体实现时对应组卷积的group数。本文又引入了一个"radix"超参,即在每个cardinal group内又将特征图进一步split成多个radix group,下图是ResNeSt block和SE-Net block以及SK-Net block的对比,关于这两篇文章的介绍见https://blog.csdn.net/ooooocj/article/details/122485342https://blog.csdn.net/ooooocj/article/details/122683493。其中各个网络的split attention的具体实现方式不同。

下图是一个cardinal group内的split attention,首先通过卷积将特征图沿通道划分为多个splits,然后相加并进行全局平均池化,经过两个dense linear层并取softmax后每个radix split得到一个对应的权重,然后加权相加后得到最终输出。这样看起来其实和SENet比较像,区别在于SENet对整个输入特征图进行操作,然后特征图的每个通道得到一个对应的权重。而ResNeSt将特征图先分成多个cardinal group然后又将每个cardinal group划分成多个radix group,然后split attention是在每个cardinal group内进行的,最后每个cardinal group得到一个对应的权重。 

ResNeXt有三种等价的实现方式,其中组卷积的实现最简洁。这里为了实现简便,作者也将原始的先cardinal group后radix group的方式等价转换成了下图的方式,其中将不同cardinal group中同一radix索引的group放到一起,变成了先radix后cardinal的分组形式,这样就可以通过组卷积来实现。

这里以timm中的实现为例,对照代码和图4讲解下具体实现。模型选择"resnest50d_4s2x40d",即cardinality=2,radix=4。输入大小为(1, 3, 224, 224)。 

首先看ResNestBottleneck类的forward函数,这里经过前面的stem处理后输入shape=(1, 64, 56, 56)。其中self.conv1为Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False),对应的就是图4每个radix group里第一个1x1卷积,按照图4实际应该有2x4=8列,即8个输出通道数为80/2/4=10的1x1卷积,这里将8个卷积合并到一起了。

接下来self.conv2就是split attention,代码如下。forward函数中self.conv对应的是图4中第二行蓝框的3x3卷积,这里每个卷积的输出通道数为c'/k=80/2=40,这里也是通过组卷积将8个卷积合并到一起了。然后当radix>1时先reshape将radix维度分离出来,然后沿radix维度相加,对应图4中间那个+。然后x_gap.mean就是global pooling。self.fc1和self.fc2对应两个dense层,具体通过groups=k=2的1x1组卷积实现。然后通过radixsoftmax沿radix维度取softmax,这样每个radix group都得到一个对应的权重,最后x * x_attn加权求和得到split attention的最终输出。

class RadixSoftmax(nn.Module):
    def __init__(self, radix, cardinality):
        super(RadixSoftmax, self).__init__()
        self.radix = radix  # 2
        self.cardinality = cardinality  # 1

    def forward(self, x):  # (1,320,1,1)
        batch = x.size(0)
        if self.radix > 1:
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)  # (1,2,4,40)->(1,4,2,40)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)  # (1,320)
        else:
            x = torch.sigmoid(x)
        return x


class SplitAttn(nn.Module):
    """Split-Attention (aka Splat)
    """
    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
                 dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
                 act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs):
        super(SplitAttn, self).__init__()
        out_channels = out_channels or in_channels  # 80
        self.radix = radix
        mid_chs = out_channels * radix  # 80*4=320
        if rd_channels is None:  # None
            attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)  # 80*4*0.25, 8; 80
        else:
            attn_chs = rd_channels * radix

        padding = kernel_size // 2 if padding is None else padding
        self.conv = nn.Conv2d(
            in_channels, mid_chs, kernel_size, stride, padding, dilation,
            groups=groups * radix, bias=bias, **kwargs)  # 2*4
        self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        self.act0 = act_layer(inplace=True)
        self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)  # 80,80,2
        self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
        self.act1 = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)  # 80,320,2
        self.rsoftmax = RadixSoftmax(radix, groups)  # 4,2

    def forward(self, x):  # (1,80,56,56)
        # Conv2d(80, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
        x = self.conv(x)  # (1,320,56,56), c'/k=80/2=40, 这里通过group=8的组卷积每组的输入输出为c'/k=40
        x = self.bn0(x)
        x = self.drop(x)
        x = self.act0(x)

        B, RC, H, W = x.shape  # 320
        if self.radix > 1:
            x = x.reshape((B, self.radix, RC // self.radix, H, W))  # (1,4,80,56,56)
            x_gap = x.sum(dim=1)  # (1,80,56,56)
        else:
            x_gap = x
        x_gap = x_gap.mean((2, 3), keepdim=True)  # (1,80,1,1), c'=80
        # Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1), groups=2)
        x_gap = self.fc1(x_gap)  # (1,80,1,1), c''=80
        x_gap = self.bn1(x_gap)
        x_gap = self.act1(x_gap)
        # Conv2d(80, 320, kernel_size=(1, 1), stride=(1, 1), groups=2)
        x_attn = self.fc2(x_gap)  # (1,320,1,1), c'r=80x4=320

        x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)  # (1,320)->(1,320,1,1)
        if self.radix > 1:
            out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)  # (1,4,80,56,56)*(1,4,80,1,1)->(1,4,80,56,56)->(1,80,56,56)
        else:
            out = x * x_attn
        return out.contiguous()


原文地址:https://blog.csdn.net/ooooocj/article/details/140610836

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!