
[Block Summary] TAdaConv: Temporally-Adaptive Convolution, a Lightweight and Efficient Convolution for Temporal Modeling | Plug-and-Play

Paper walkthrough: Temporally-Adaptive Models for Efficient Video Understanding

Paper Information

  • Title: Temporally-Adaptive Models for Efficient Video Understanding

  • Published: 2023

  • Authors: Ziyuan Huang et al.

  • Paper link: arXiv paper

  • Code link: GitHub - TAdaConv

Innovations

  1. Temporally-Adaptive Convolution (TAdaConv): a new convolution operation that gives spatial convolutions temporal modeling capability by dynamically adjusting the kernel weights for each frame, allowing the model to capture complex temporal dynamics in video effectively.

  2. Factorized kernel weights: TAdaConv decomposes the convolution kernel into a base weight and a calibration weight, with the calibration weight generated dynamically from the input. This design preserves the weights of pretrained models and reduces training cost.

  3. Modular design: the TAdaBlock module can be flexibly plugged into existing convolutional networks and vision Transformers to improve their temporal understanding.

Method

  • Factorized kernel calibration: TAdaConv calibrates the convolution kernel of each frame in a factorized manner, using both the local and the global temporal context of that frame. Concretely, the kernel weight is decomposed into a base weight $W_b$ and a calibration weight $\alpha_t$:
    $W_t = \alpha_t \cdot W_b$
    where the calibration weight is generated from a global descriptor of the input features (a minimal sketch follows this list).

  • Multi-head self-attention: the improved version, TAdaConvV2, introduces multi-head self-attention to strengthen the modeling of global temporal information.

  • Temporal feature aggregation: an efficient temporal feature aggregation scheme (e.g., T-Pool) is provided, which combines temporal downsampling to further cut computation.
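
As a minimal sketch of the factorization above (the shapes and variable names here are illustrative assumptions, not the paper's exact implementation), the per-frame kernel can be formed by scaling a shared base weight with a frame-dependent calibration factor:

import torch

# Sketch: factorized per-frame kernel calibration, W_t = alpha_t * W_b.
B, T, C_out, C_in, K = 2, 8, 16, 16, 3           # batch, frames, channels, spatial kernel size
base_weight = torch.randn(C_out, C_in, K, K)     # shared base weight W_b (e.g., from pretraining)
alpha = torch.rand(B, T, C_in)                   # per-frame calibration weights alpha_t

# Broadcasting alpha over the kernel gives every frame its own calibrated kernel W_t.
calibrated = alpha[:, :, None, :, None, None] * base_weight   # (B, T, C_out, C_in, K, K)
print(calibrated.shape)                                       # torch.Size([2, 8, 16, 16, 3, 3])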

Factorized Kernel Calibration

The factorized kernel weights in TAdaConv are realized by decomposing the convolution kernel into a base weight and a calibration weight. This is meant to strengthen the model's ability to capture temporal information while keeping computation efficient. The concrete procedure is as follows:

  1. Kernel weight decomposition: in TAdaConv, the convolution kernel weight of each frame is decomposed into two parts:

    • Base weight: a shared weight that is identical for all frames.
    • Calibration weight: a dynamically generated weight that is adjusted according to the temporal characteristics of the input.

    This decomposition can be written as:
    $W_t = \alpha_t \cdot W_b$
    where $W_t$ is the kernel weight for frame $t$, $W_b$ is the base weight, and $\alpha_t$ is the calibration weight generated dynamically from the temporal context.

  2. Dynamic calibration mechanism: TAdaConv uses a dedicated temporally-adaptive module (the routing function) to generate and update the calibration weights. This module adapts to both the local and the global temporal features of the input, so the kernel can better fit features at different temporal scales.

  3. Low computational overhead: thanks to the factorized design, TAdaConv gains substantial temporal reasoning ability with almost no extra computation. This makes it effective on long sequences while remaining efficient.

  4. Tucker decomposition: TAdaConv can also use Tucker decomposition to factorize the convolution weights efficiently. Tucker decomposition splits a four-dimensional tensor (the convolution kernel) into several smaller tensors, capturing interactions between channels and allowing the kernel shape and parameters to be adjusted dynamically to different temporal characteristics.

With these techniques, TAdaConv captures the temporal characteristics of the input effectively, improving both performance and generalization. This factorized formulation offers a new way around the limitations of standard convolutional networks on temporal data.
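
A minimal sketch of how such per-frame kernels can be applied without looping over frames (this mirrors the grouped-convolution trick used in TAdaConv2dV2.forward further below; the shapes and names here are illustrative):

import torch
import torch.nn.functional as F

# Sketch: apply a different kernel to every frame with a single grouped conv2d.
B, T, C, H, W, K = 2, 4, 8, 16, 16, 3
feat = torch.randn(B, C, T, H, W)
weight_per_frame = torch.randn(B, T, C, C, K, K)   # calibrated kernels W_t, assumed already computed

x = feat.permute(0, 2, 1, 3, 4).reshape(1, B * T * C, H, W)   # fold batch and time into channels
w = weight_per_frame.reshape(B * T * C, C, K, K)              # one filter bank per (batch, frame)
out = F.conv2d(x, w, padding=K // 2, groups=B * T)            # each group corresponds to one frame
out = out.view(B, T, C, H, W).permute(0, 2, 1, 3, 4)          # back to (B, C, T, H, W)
print(out.shape)                                              # torch.Size([2, 8, 4, 16, 16])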

Multi-Head Self-Attention

In TAdaConvV2, multi-head self-attention (MHSA) plays a key role, mainly in the following respects:

  1. Stronger modeling of global temporal information: by introducing multi-head self-attention, TAdaConvV2 can capture global temporal information in video. Plain convolutions are limited to local features along time, whereas multi-head self-attention lets the model attend to different parts of the input sequence in parallel across different representation subspaces, giving a more complete picture of the dynamics in a video.

  2. Parallel processing: the multi-head design lets the model run several attention heads at once, each learning different features and relations. This improves expressiveness and also speeds up computation, which suits complex temporal data.

  3. Capturing diverse dependencies: with multiple heads, TAdaConvV2 can capture diverse dependencies between frames. For example, some heads may focus on fast motion while others attend to slower changes or background information, making the model more flexible and accurate when interpreting video content.

  4. Better adaptability: combined with multi-head self-attention, TAdaConvV2 adapts more readily to different kinds of video data and tasks; it can redistribute attention according to the input features and therefore performs well across scenarios.

  5. Synergy with other modules: in TAdaConvV2, multi-head self-attention works together with the other components (the dynamic convolution and temporal feature aggregation) to form an efficient temporal modeling framework, so the model can exploit the strengths of each component on complex video understanding tasks.

In short, multi-head self-attention strengthens TAdaConvV2's temporal modeling and improves its efficiency and accuracy on complex video data, providing solid support for video understanding tasks.
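
As a rough sketch of what multi-head self-attention over per-frame descriptors looks like (a simplified stand-in for RouteFuncwTransformer.forward_attention below; the use of nn.MultiheadAttention and the shapes are assumptions for illustration):

import torch
import torch.nn as nn

# Sketch: multi-head self-attention across the temporal axis of per-frame descriptors,
# so each frame's global descriptor can attend to every other frame.
B, T, C = 2, 8, 64                                # batch, frames, descriptor channels
desc = torch.randn(B, T, C)                       # per-frame global descriptors (e.g., from pooling)

mhsa = nn.MultiheadAttention(embed_dim=C, num_heads=4, batch_first=True)
attended, attn_weights = mhsa(desc, desc, desc)   # self-attention over the T frames
print(attended.shape, attn_weights.shape)         # torch.Size([2, 8, 64]) torch.Size([2, 8, 8])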

Experimental Results

  • Ablation studies: the experiments show that the dynamic calibration weights clearly improve temporal modeling, and relaxing invariance along the temporal dimension further improves classification accuracy.

  • Performance gains: on several video understanding benchmarks, TAdaConv performs strongly on action recognition and localization, improving a range of baseline models at negligible computational cost.

  • Comparison with conventional models: while keeping computational cost low, TAdaConv matches or exceeds conventional, heavier temporal models such as 3D convolutional networks and complex Transformers.

Summary

TAdaConv and its improved version TAdaConvV2 dynamically calibrate convolution kernel weights using local and global temporal context, which markedly strengthens temporal modeling. The modular design lets TAdaConv be plugged into modern convolutional networks and Transformers, improving performance on action recognition and localization. The design raises efficiency while keeping computational cost low, delivering performance on par with or better than conventional, heavier temporal models. Code:

import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _triple


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)

class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            if len(x.shape) == 5:
                x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
            elif len(x.shape) == 3:
                x = self.weight[:, None] * x + self.bias[:, None]
            return x


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class RouteFuncwTransformer(nn.Module):
    """
    The routing function for generating the calibration weights.
    """

    def __init__(self, c_in, ratio, kernels, with_bias_cal=False, bn_eps=1e-5, bn_mmt=0.1, zero_init_cal=True,
                 head_dim=64):
        """
        Args:
            c_in (int): number of input channels.
            ratio (int): reduction ratio for the routing function.
            kernels (list): temporal kernel size of the stacked 1D convolutions
        """
        super().__init__()
        self.c_in = c_in
        self.head_dim = head_dim
        self.with_bias_cal = with_bias_cal
        self.avgpool = nn.AdaptiveAvgPool3d((None, 1, 1))
        self.globalpool = nn.AdaptiveAvgPool3d(1)
        self.a = nn.Conv3d(
            in_channels=c_in,
            out_channels=int(c_in // ratio),
            kernel_size=[kernels[0], 1, 1],
            padding=[kernels[0] // 2, 0, 0],
        )

        self.norm = LayerNorm(int(c_in // ratio), eps=1e-6, data_format="channels_first")
        self.norm_transformer = LayerNorm(int(c_in // ratio), eps=1e-6, data_format="channels_first")
        self.gelu = QuickGELU()

        self.scale = int(c_in // ratio) ** -0.5
        self.qkv_proj = nn.Conv3d(
            in_channels=int(c_in // ratio),
            out_channels=int(c_in // ratio) * 3,
            kernel_size=1,
            padding=0,
        )

        self.attn_out = nn.Conv3d(
            in_channels=int(c_in // ratio),
            out_channels=int(c_in // ratio),
            kernel_size=1,
            padding=0,
        )

        self.b = nn.Conv3d(
            in_channels=int(c_in // ratio),
            out_channels=c_in,
            kernel_size=[kernels[1], 1, 1],
            padding=[kernels[1] // 2, 0, 0],
            bias=False
        )
        self.zero_init_cal = zero_init_cal
        if zero_init_cal:
            self.b.skip_init = True
            self.b.weight.data.zero_()  # to make sure the initial values
            # for the output is 1.
        if with_bias_cal:
            self.b_bias = nn.Conv3d(
                in_channels=int(c_in // ratio),
                out_channels=c_in,
                kernel_size=[kernels[1], 1, 1],
                padding=[kernels[1] // 2, 0, 0],
                bias=False
            )
            if zero_init_cal:
                self.b_bias.skip_init = True
                self.b_bias.weight.data.zero_()  # to make sure the initial values
                # for the output is 1.

    def forward(self, x):
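        # Spatially average-pool each frame into a global descriptor (the temporal axis is preserved).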
        x = self.avgpool(x)
        x = self.a(x)
        x = self.norm(x)
        x = self.gelu(x)

        x = x + self.forward_attention(self.norm_transformer(x))

        if self.with_bias_cal:
            if self.zero_init_cal:
                return [self.b(x) + 1, self.b_bias(x) + 1]
            else:
                return [self.b(x), self.b_bias(x)]
        else:
            if self.zero_init_cal:
                return self.b(x) + 1
            else:
                return self.b(x)

    def forward_attention(self, x):
        b, c, t, _, _ = x.shape
        qkv = self.qkv_proj(x)[:, :, :, 0, 0].view(b, 3, self.head_dim, c // self.head_dim, t).permute(1, 0, 3, 4, 2)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(-2, -1).reshape(b, c, t)[:, :, :, None, None]

        x = self.attn_out(x)

        return x


class TAdaConv2dV2(nn.Module):
    """
    Performs temporally adaptive 2D convolution.
    Currently, only application on 5D tensors is supported, which makes TAdaConv2d
        essentially a 3D convolution with temporal kernel size of 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 cal_dim="cin", num_frames=None, rf_r=4, rf_k=[3, 3], head_dim=64,
                 internal_rf_func=True, internal_temp_aggr=True):
        super().__init__()
        """
        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
            kernel_size (list): kernel size of TAdaConv2d. 
            stride (list): stride for the convolution in TAdaConv2d.
            padding (list): padding for the convolution in TAdaConv2d.
            dilation (list): dilation of the convolution in TAdaConv2d.
            groups (int): number of groups for TAdaConv2d. 
            bias (bool): whether to use bias in TAdaConv2d.
            cal_dim (str): calibrated dimension in TAdaConv2d. 
                Supported input "cin", "cout".
            head_dim (int): head dimension for MHSA in the routing function.
        """

        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)

        assert kernel_size[0] == 1
        assert stride[0] == 1
        assert padding[0] == 0
        assert dilation[0] == 1
        assert cal_dim in ["cin", "cout"]

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.cal_dim = cal_dim

        self.num_frames = num_frames

        if internal_rf_func:
            self.rf_func = RouteFuncwTransformer(
                c_in=out_channels,
                ratio=rf_r,
                kernels=rf_k,
                with_bias_cal=bias,
                zero_init_cal=False,
                head_dim=head_dim
            )

        if internal_temp_aggr:
            self.bn_a = nn.BatchNorm3d(out_channels)
            self.bn_b = nn.BatchNorm3d(out_channels)
            self.bn_b.skip_init = True
            self.bn_b.weight.data.zero_()
            self.bn_b.bias.data.zero_()

            self.avgpool = nn.AvgPool3d(kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0))

        # base weights (W_b)
        self.weight = nn.Parameter(
            torch.Tensor(1, 1, out_channels, in_channels // groups, kernel_size[1], kernel_size[2])
        )
        if bias:
            self.bias = nn.Parameter(torch.Tensor(1, 1, out_channels))
        else:
            self.register_parameter('bias', None)

        trunc_normal_(self.weight, std=.02)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv3d, nn.Linear)):
            if hasattr(m, "skip_init") and m.skip_init:
                return
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, feat, reshape_required=True, alpha=None):
        """
        Args:
            feat (tensor): feature to perform convolution on.
            reshape_required (bool): True if input feat is of shape (L, N, C),
                where N=B*T
        """
        if reshape_required:
            assert self.num_frames is not None
            h = w = int(math.sqrt(feat.shape[0]))
            # L, N, C -> H, W, B, T, C
            feat = feat.reshape(h, w, -1, self.num_frames, feat.shape[-1]).permute(2, 4, 3, 0, 1)

        # generate calibration factors
        if alpha is None:
            alpha = self.rf_func(feat)

        if isinstance(alpha, list):
            w_alpha, b_alpha = alpha[0], alpha[1]
        else:
            w_alpha = alpha
            b_alpha = None

        _, _, c_out, c_in, kh, kw = self.weight.size()
        b, c_in, t, h, w = feat.size()
        feat = feat.permute(0, 2, 1, 3, 4).reshape(1, -1, h, w)

        if self.cal_dim == "cin":
            # w_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, 1, C, H(1), W(1)
            # corresponding to calibrating the input channel
            weight = (w_alpha.permute(0, 2, 1, 3, 4).unsqueeze(2) * self.weight).reshape(-1, c_in // self.groups, kh,
                                                                                         kw)
        elif self.cal_dim == "cout":
            # w_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, C, 1, H(1), W(1)
            # corresponding to calibrating the output channel
            weight = (w_alpha.permute(0, 2, 1, 3, 4).unsqueeze(3) * self.weight).reshape(-1, c_in // self.groups, kh,
                                                                                         kw)

        bias = None
        if self.bias is not None:
            if b_alpha is not None:
                # b_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, C
                bias = (b_alpha.permute(0, 2, 1, 3, 4).squeeze() * self.bias).reshape(-1)
            else:
                bias = self.bias.repeat(b, t, 1).reshape(-1)
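        # A single grouped conv2d applies a distinct calibrated kernel to every (batch, frame) pair.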
        output = F.conv2d(
            feat, weight=weight, bias=bias, stride=self.stride[1:], padding=self.padding[1:],
            dilation=self.dilation[1:], groups=self.groups * b * t)

        output = output.view(b, t, c_out, h, w).permute(0, 2, 1, 3, 4)
        if hasattr(self, "bn_a") and hasattr(self, "bn_b"):
            output = self.bn_a(output) + self.bn_b(self.avgpool(output))
        if reshape_required:
            output = output.permute(3, 4, 0, 2, 1).reshape(h * w, b * t, c_out)

        return output

    def __repr__(self):
        return f"TAdaConv2dV2({self.in_channels}, {self.out_channels}, kernel_size={self.kernel_size}, " + \
            f"stride={self.stride}, padding={self.padding}, bias={self.bias is not None}, cal_dim=\"{self.cal_dim}\")"


if __name__ == '__main__':
    B, C, T, H, W = 1, 64, 8, 40, 40
    input_tensor = torch.randn(B, C, T, H, W)  # random input tensor
    alpha_tensor = torch.rand(1, 64, 8, 1, 1)  # externally supplied calibration weights
    # Instantiate TAdaConv2dV2 with matching input and output channels
    dim = C
    block = TAdaConv2dV2(in_channels=dim, out_channels=dim, kernel_size=[1, 3, 5], stride=[1, 1, 1], padding=[0, 1, 2])
    # Move the module to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    block = block.to(device)
    print(block)
    input_tensor = input_tensor.to(device)
    alpha_tensor = alpha_tensor.to(device)
    # Forward pass (reshape_required=False since the input is already (B, C, T, H, W))
    output = block(input_tensor, False, alpha_tensor)
    # Print input and output shapes
    print(f"Input: {input_tensor.shape}")
    print(f"Output: {output.shape}")



Original article: https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/145264444
