自学内容网 自学内容网

w~Transformer~合集11

我自己的原文哦~    https://blog.51cto.com/whaosoft/12472192

#LightSeq

最高加速9倍!字节跳动开源8比特混合精度Transformer引擎,近年来,Transformer 已经成为了 NLP 和 CV 等领域的主流模型,但庞大的模型参数限制了它的高效训练和推理。于是字节跳动在 2019 年 12 月和 2021 年 6 月分别推出了高效推理和训练引擎 LightSeq,大大加速了 Transformer 系列模型的训练和推理,也打通了 Transformer 从训练到推理的整个流程,极大优化了用户使用体验。最近,LightSeq 训练引擎相关论文[1],被录用难度极高的超算领域国际顶会 SC22 接收,得到了学术界的广泛认可!

  • SC22 接收论文:https://sc22.supercomputing.org/presentation/?id=pap211&sess=sess154
  • 代码地址:https://github.com/bytedance/lightseq

如何继续提升速度?降低计算精度是比较直接的方法。2017 年以来,fp16 混合精度技术 [2] 获得了广泛应用。在对模型效果无损的前提下,将模型训练和推理的速度提升了 50% 以上。而为了维持模型效果,更低精度的方法(例如 int8)通常需要使用如下传统方案:

  1. 首先使用 fp16 混合精度将模型训练至收敛;
  2. 然后在模型计算密集型算子的权重、输入和输出位置处,插入伪量化结点,进行量化感知训练;
  3. 最后将带有伪量化结点的模型计算图转换到专用的 int8 推理引擎中,进行服务部署和模型推理。

虽然在多数任务上,上述方案可以实现模型效果无损,但还是存在以下问题:

  1. 使用方法复杂。例如要多一次量化感知训练 [4] 的过程,并且带有伪量化节点的计算图转换复杂。
  2. 训练速度慢。由于目前流行的深度学习框架不支持 int8 精度,所以量化感知训练需要插入 fp16 的伪量化结点来模拟 int8 量化,导致量化感知训练反而比 fp16 混合精度训练慢 2-3 倍。
  3. 推理部署难且加速比低。对比 fp32、fp16 等类型,int8 硬件和底层软件库优化相对滞后。例如在 NVIDIA GPU 上,int8 矩阵乘法加速受限于硬件架构和特定 shape,实际加速比远远低于理论值。

在下文中,如无特殊说明,量化都是指的 int8 精度的量化

针对这些问题,字节跳动推出了全新版本的 LightSeq GPU 量化训练与推理引擎。支持 Transformer 系列模型的量化训练与推理,并做到了开箱即用,用户友好。LightSeq 快准狠地实现了 int8 精度的量化训练和推理:

  1. 快:A100 多卡训练最高加速 5.2 倍,T4 单卡推理最高加速 8.9 倍。
  2. 准:训练和推理效果基本无损。
  3. 狠:相同数据量下,显存占用最高减少 68%,模型存储空间减少 75%。

总体来说,LightSeq 新版量化训练与推理引擎具有如下几个优点:

1. 丰富的支持

支持完整的 Transformer 模块和多种解码算法,支持 Transformer、BERT、GPT、BART、ViT 等多种模型结构,支持 Fairseq、Hugging Face、NeurST 等多种训练框架接入量化训练、导出模型以及量化推理,提供了丰富的样例供用户参考。

2. 卓越的性能

相比于 fp16 精度的 LightSeq 推理引擎,int8 量化还可以进一步加速最高 70%,相比于 PyTorch 推理更是达到了最高 8.9 倍的加速比。同时显存占用相比 fp16 推理引擎降低了 30% 左右,模型存储空间只需要原来的四分之一。最后经过多个任务的验证,推理效果几乎无损。

3. 便捷的使用

LightSeq 已经针对多个训练库进行了量化支持,可以一键开启量化训练,然后轻松导出为 LightSeq 支持的模型格式,最后实现量化推理。除此之外,LightSeq 还支持训练后量化,无需额外训练即可体验量化推理。

如上图所示,为了最大程度减小量化带来的损失,首先需要用 fp16 精度训练一个浮点数模型,将模型效果训到最好。然后开启量化进行 finetune,得到微调过的量化模型,此时模型效果已经基本恢复到浮点数模型的水平。接着将量化模型转换为 LightSeq 支持的 PB 或者 HDF5 模型格式,最后用 LightSeq 进行量化推理。

安装方法

LightSeq 安装非常简单,只需要一行命令即可:

pip install lightseq

量化训练

LightSeq 支持 Fairseq、Hugging Face、NeurST 等训练框架的量化接入,同时也可以自定义模型并开启量化训练。以 encoder 层为例,只需要先定义浮点数模型,然后开启量化即可:

from lightseq.training import LSTransformerEncoderLayer
from lightseq.training.ops.pytorch.quantization import enable_quant
config = LSTransformerEncoderLayer.get_config(
    model="bert-base",
    max_batch_tokens=4096,
    max_seq_len=512,
    fp16=True,
    local_rank=0,
)
layer = LSTransformerEncoderLayer(config)
# 开启量化
layer.apply(enable_quant)

量化推理

LightSeq 提供了便捷的 python 推理接口,只需要三行代码即可实现快速的量化推理:

import lightseq.inference as lsi
model = lsi.QuantTransformer(pb_path, batch_size)
result = model.infer(input)

此外 LightSeq 还提供了 BERT、GPT、ViT 等模型的 python 接口,分别调用 QuantBert、QuantGpt 和 QuanVit 即可体验。

梯度通信量化

LightSeq 支持 Transformer 模型的梯度通信量化[5],使用 Fairseq 或者 Hugging Face 即可轻松开启分布式量化训练,并同时支持浮点数模型和量化模型。在构建模型后,只需要为模型注册一个 communication hook 即可开启梯度通信量化,再开始训练过程。

from lightseq.training.gradient_comm_quantization import encode_and_decode, GCQState
from torch.nn.parallel import DistributedDataParallel 
# model could be from Fairseq or Hugging Face, wrapped by DDP
model = DistributedDataParallel(model)
state =  GCQState(process_group)
# register hook
model.register_comm_hook(state=state, hook=encode_and_decode)

性能测试

LightSeq 在多个任务上测试了量化训练、量化推理和梯度通信量化的速度,并且分析了显存占用情况和量化模型的效果。

量化训练速度

LightSeq 在 8 张 A100 显卡上进行了训练实验,主要对比对象是 Fairseq 的 Transformer、Hugging Face 的 BERT、GPT2 和 ViT。

可以看出,四种模型结构加速趋势都是类似的,加速比都会随着数据量的增大而减小,原因有三点:

  1. 随着数据量的增大,矩阵乘法 GEMM 的占比会明显增加,因此 PyTorch QAT 增加的额外的伪量化结点时间占比会逐渐减小,最后速度会和 PyTorch fp16 无限接近。
  2. 与此同时,随着 GEMM 占比升高,LightSeq fp16 自定义算子的提速效果也逐渐减小,因此时间上也会和 PyTorch fp16 无限接近。
  3. 由于 Ampere 架构显卡上 int8 GEMM 在 shape 较小时甚至不如 fp16 GEMM 快,在大 shape 下才能稍快一点,因此随着数据量增大,LightSeq int8 也会无限接近 LightSeq fp16 的速度。

量化推理速度

LightSeq 在单张 T4 显卡上进行了推理实验,主要对比对象是 Hugging Face 的 Transformer、BERT、GPT2 和 ViT。

可以看出,随着输入数据量的增大,LightSeq 与 PyTorch 的差距会逐渐减小,这也是 GEMM 占比升高造成的。比较 LightSeq fp16 和 LightSeq int8,可以看出随着数据量的增大,LightSeq int8 越来越快。这是因为在 T4 显卡上,int8 GEMM 的加速会随着 shape 的增大而有明显增加。因此在 T4 显卡上进行量化推理时,输入数据量越大,加速效果越好。

LightSeq 还针对机器翻译多个语向和多个测试集,测试了不同 batch size 下,LightSeq int8 推理相对于 LightSeq fp16 推理的加速比,实验同样是在单张 T4 显卡上进行的,采用的模型都是标准的 Transformer-Big。

可以得到和上文中相同的结论,随着 batch size 的增大,量化推理的加速比会逐渐升高。相比于 LightSeq fp16,最高还可以再加速近 70%,这极大地缩短了线上翻译模型的推理延时。

最后如上图所示,为了展示自动 GEMM 调优技术的效果,LightSeq 测试对比了 A100 显卡上 Transformer 和 BERT 模型 fp16、int8 调优前和 int8 调优后的延时。可以看出调优前某些 shape 的 int8 GEMM 速度甚至比 fp16 还要慢,而调优后全面超越了 fp16。

显存占用

LightSeq 分析了不同 batch size 下,量化模型相对于浮点数模型显存占用的加速比。可以看出随着 batch size 的增大,量化模型的显存占用优势更明显,最高可以减少 30% 左右。而 LightSeq fp16 引擎相对于 PyTorch 模型也极大程度减少了显存占用,因此 LightSeq int8 引擎最终能够减少最多 68% 左右的显存。

量化模型效果

 针对机器翻译多个语向和多个测试集,LightSeq 测试了量化模型推理相对于浮点数模型 BLEU 的损失,采用的模型都是标准的 Transformer-Big。

在数据量较大的语向 en2zh 上,LightSeq int8 相对 BLEU 损失较大些,最大达到了 - 0.4。而在数据量较小的语向 en2es 上,LightSeq int8 不仅没有任何效果损失,反而比浮点数模型更好。总体而言,int8 量化模型的平均 BLEU 相比浮点数模型基本无损。在 GLUE 和 SQuAD 等多个任务上,LightSeq 也验证了量化模型的效果。

梯度通信量化

由于在多机多卡场景下通信瓶颈更加明显,所以梯度通信量化主要应用在分布式训练场景。因此 LightSeq 在 2 机 8 卡的 A100 上进行了分布式训练的速度测试。

可以看出,梯度通信量化的训练加速效果整体上随着输入数据的增大而减弱。这主要是因为随着输入数据的增大,计算时间占比升高,梯度通信时间占比减少,梯度量化的收益也随之减小。

LightSeq 还额外增加了不同数量网卡(NIC)下的训练速度测试。可以看到使用梯度通信量化的分布式训练速度相比原始的 LightSeq fp16 有大幅度提升。

量化技术

int8 量化的加速收益主要来自如下几个方面:

  1. GEMM 精度从 fp16 降低到 int8 后,计算时间缩短;
  2. 自定义算子采用 int8 输入输出后,数据读写时间缩短;
  3. 梯度采用 int8 存储后,多机之间通信时间缩短。

以 Transformer 模型为例,经过 LightSeq fp16 引擎加速后,自定义算子时间大大缩短,而 GEMM 时间占比提升到了 90% 左右,因此优化的重点转移到了 GEMM 提速。将 fp16 GEMM 替换为 int8 GEMM 不仅可以缩短 GEMM 时间,还可以减小前后算子的输入输出位宽,从而减小读写数据的时间。最后多机训练的瓶颈主要在梯度的通信,将梯度量化为 int8 精度可以大大加快分布式训练的速度。

量化原理

为了弥补量化带来的精度损失,通常需要用量化感知训练来模拟量化过程。如上图所示,量化感知训练就是将 float GEMM 的两个 float 输入分别做一遍量化和反量化(称之为伪量化结点),离散化成分段的浮点数输入,然后进行 float GEMM 运算。得到结果后再次进行量化与反量化,得到最终的浮点数结果。而量化的过程是不可导的,因此需要用 STE 方法来估计量化参数的梯度。之所以量化感知训练中需要插入伪量化结点,然后用 float GEMM 去模拟量化过程,是因为 TensorFlow 和 PyTorch 等训练框架不支持 int8 GEMM。

而 LightSeq 量化训练直接采用 int8 GEMM 来真实还原量化过程,因此相比传统的实现要更快,且更加节省显存。在推理的时候,同样采用离散化后的整数进行 int8 GEMM 运算,最后再反量化回浮点数结果。量化推理过程和量化训练完全一致,并且和传统的量化感知训练是完全等价的。

量化位置

整个量化 Transformer 的网络结构如上图所示,红色箭头表示需要加上量化和反量化结点的位置。

首先所有 int8 GEMM 的输入和输出都需要进行量化。由于 int8 GEMM 的 shape 限制,部分 GEMM(例如注意力分数的计算)仍然采用 float GEMM。此外第二层 FFN 的 GEMM 采用的是 int32 的输出,因为它的 GEMM 输入是 ReLU 激活函数的输出结果,只包含正数,非对称,因此如果采用 int8 输出的 GEMM,将无法反量化为正确的浮点数结果。

然后所有的模型权重 weight 都需要存储为 int8 类型,因此需要对 weight 做量化。而权重 bias 参数量较小,无需量化,保留 float 精度反而可以提升模型效果。

最后需要对 decoder 端的 cache 进行量化。因为在推理时,decoder 端的 cache 需要频繁进行读写,因此将 cache 量化为 int8 可以大大加快解码的速度。

量化策略

将一个浮点数矩阵量化为 int8 整数矩阵有很多方法,LightSeq 采用的是对称量化,即将正负数范围对称的浮点数区间等比例地映射到整数区间 [-127, 127] 上。

而实际上浮点数矩阵的数值范围通常并不对称,存在极少的离群值。如果直接按照离群值的范围来量化矩阵,会影响到量化后的精度,所以需要先对矩阵进行数值截断。

LightSeq 采用 PACT 方法进行截断[6],将截断的范围当作模型可学习的参数,然后利用 STE 算法去估计参数的梯度,并进行反向传播优化。根据实践经验,权重 weight 的初始截断范围设为[-1, 1],中间结果的初始截断范围设为[-16, 16],可以在大部分任务上达到最好的效果。最后经过截断范围和其他模型参数的联合优化,量化模型的效果可以达到基本无损。

梯度通信量化

针对分布式训练场景,LightSeq 推出了梯度量化压缩技术。即对浮点精度的梯度进行 int8 量化,以减少梯度通信的时间消耗,从而加速训练,这就是梯度通信量化(GCQ)。

如上图所示,梯度通信量化的主要流程如下:

  1. 计算每张卡上各自梯度的截断范围;
  2. 对截断范围执行 all-reduce max 操作;
  3. 每张卡使用统一的截断范围对各自梯度进行 int8 量化;
  4. 对 int8 梯度执行 all-reduce sum 操作;
  5. 每张卡对 all-reduce 后的梯度进行反量化,还原为浮点数梯度,并进行参数更新。

为了解决 int8 梯度在 all-reduce 过程中溢出的问题,LightSeq 首先将每张卡上的浮点数梯度除以卡数,再使用除之前的截断范围进行量化,最后进行 all-reduce 操作。这样每张卡上量化后的 int8 整数 all-reduce 完就不会溢出,但是单卡实际用于量化的比特数也因此而减少,所以目前方案在 2 机 8 卡效果几乎无损,但随着卡数的上涨,训练效果会有所下降。以 en2de 和 en2fr 翻译任务为例,在 4 机 8 卡上进行分布式量化训练,BLEU 值分别会下降 0.4 和 1.5 左右。未来 LightSeq 将会持续探索更好的方法来解决这一问题。

通用技术

除了上一章节中提到的量化技术以外,此次更新 LightSeq 还提出了几种通用的优化技术,不仅可以应用在量化模型中,也适用于其它所有精度模型的训练与推理。

算子融合

上图是 encoder 模块量化训练的计算图,LightSeq 将两次 GEMM 运算之间的所有操作融合成一个算子[7],减少了 kernel 调用的次数,因此减少了总的计算时间。

图中黄色矩形表示 int8 GEMM,绿色矩形表示 float GEMM。这里采用 float GEMM 是由于 shape 的限制,不适合使用 int8 GEMM 加速。红色箭头表示流动数据的类型是 int8,绿色箭头表示第二层 FFN 的 GEMM 输出是 int32 数据类型。int8 GEMM 输入输出的量化与反量化操作都被融合到了前后 kernel 里,这不仅可以减少数据搬运,还可以减小显存占用。

在推理时,LightSeq 还针对 decoder 做了优化。如上图所示,在计算 self-attention 时,注意力得分的维度是(batch size, 1, sequence length)。因此在计算 value 乘积时,可以不采用 GEMM 运算,而直接手写加权求和的算子,从而将图中虚线框中的计算融合成一个 kernel。

自动显存管理

模型量化引入了更复杂的张量类型和张量依赖关系,这给显存管理带来新的挑战。为此,LightSeq 设计了新的显存管理机制。如上图所示,主要包括以下过程:

  1. 训练启动前,根据每个算子的拓扑依赖关系,自动计算每个张量的生命周期及显存空间大小。其中,包含动态维度的张量按照此维度的最大量进行计算,例如机器翻译任务中的最大句长和最大 batch 句子数量。这些最大量在训练前已被指定;
  2. 张量确定生命周期和大小后,分析显存复用关系。其中,无生命周期重合的张量可以共用一片显存空间,所有显存空间都是无数据类型的,可以被分配到任意数据类型的张量上;
  3. 根据张量显存复用关系,申请多段显存空间,为每个张量分配实际的显存起止地址。

张量显存复用的分析,LightSeq 借鉴了论文 [3] 中提出的 Greedy by Size for Offset Calculation 方法,做了三个改进:      

  1. 支持了整个训练过程的显存复用(forward/backward);
  2. 不同数据类型能做到显存复用(int8/fp16/fp32);
  3. 在多段显存空间上容纳所有张量,而非一段非常大的显存空间,这样能有效提升显存利用率。

自动 GEMM 调优

LightSeq 的 int8 GEMM 采用了 NVIDIA 的 cuBLASLt 库,这也是目前 NVIDIA 显卡上最为高效的矩阵运算库。但是输入数据的 shape 或者显卡不同的话,GEMM 所采用的最优配置(例如数据排布、GEMM 算法等等)也可能不同,因此需要进行自动选取。LightSeq 采取的自动调优方案如下:

  1. 在多种型号显卡上(例如 T4 和 A100)进行不同 shape 的 GEMM 最优配置搜索,并将结果保存到配置文件中,用户只需要下载即可;
  2. 模型初始化时,加载对应型号显卡的配置文件,解析并保存到键值对为 (shape, 最优配置) 的字典中。如果没有对应型号显卡的配置文件,或者没有需要的 GEMM shape,那么用户可以选择自己搜索并保存,或者直接使用默认配置;
  3. 模型前向或后向计算时,根据输入的 shape 在字典中寻找最优配置,然后进行 GEMM 计算。如果没有找到对应的 shape,那么直接采用默认的配置。

未来工作

未来 LightSeq 还将继续探索移动端的低精度量化、反向传播中梯度的量化、大模型量化等方向。

#SCTNet

SCTNet,一种带有transformer语义信息的单分支CNN用于实时分割。借助于提出的transformer类CNN块CFBlock和语义信息对齐模块,SCTNet可以在训练中从transformer分支捕获丰富的语义信息。 80.5mIoU+62.8FPS! 华科与美团联合提出单分支推理分割架构

最新的实时语义分割方法通常采用额外的语义分支来追求丰富的长距离上下文。然而,额外的分支会带来不必要的计算开销,并减缓推理速度。为了消除这一困境,我们提出了SCTNet,一种带有transformer语义信息的单分支CNN用于实时分割

​https://arxiv.org/abs/2312.17071​

​https://github.com/xzz777/SCTNet​

SCTNet在保留轻量级单分支CNN高效性的同时,还拥有语义分支的丰富语义表示。考虑到transformer提取长距离上下文的卓越能力,SCTNet将transformer作为仅用于训练的语义分支。借助于提出的transformer类CNN块CFBlock和语义信息对齐模块,SCTNet可以在训练中从transformer分支捕获丰富的语义信息。在推理过程中,只需要部署单分支CNN。我们在Cityscapes,ADE20K和COCO-Stuff-10K上进行了广泛的实验,结果表明,我们的方法达到了新的最先进水平。

本文贡献主要包含以下三点:

  • 我们提出了一种新的单支实时分割网络SCTNet。通过学习从Transformer到CNN的语义信息对齐来提取丰富的语义信息,SCTNet在保持轻量级单支CNN快速推理速度的同时,具有Transformer的高准确性。
  • 为了缓解CNN特征和Transformer特征之间的语义鸿沟,我们设计了CFBlock(ConvFormer Block),它可以仅使用卷积操作捕获长距离上下文。此外,我们提出了SIAM(语义信息对齐模块),以更有效地对齐特征。
  • 在Cityscapes、ADE20K和COCO-Stuff-10K上的大量实验结果表明,所提的SCTNet在实时语义分割方面优于现有的最新方法. SCTNet为提高实时语义切分的速度和性能提供了一个新的视角

本文方案

降低计算成本的同时,获得丰富的语义信息,我们将现在流行的两个分支架构拆解为:

  • 一个CNN分支进行推断;
  • 一个Transformer分支用于训练阶段语义对齐。

Backbone 为了提高推理速度,SCTNet采用了典型的分层CNN骨干。SCTNet的Stem模块由两个3×3卷积构成;前两个阶段是由堆叠的残积模块组成的;后两个阶段则是由所提CFBlock构成。CFBlock采用了几个精心设计的卷积操作来执行类似于Transformer块的远程上下文捕获功能。

Decoder Head 解码头由DAPPM与分割头构成,为进一步丰富上下文信息,作者在Stage4后面添加了DAPPM。然后,作者将S2和S4输出进行拼接并送入分割头。

Training Phase 众所周知,Transformer在捕获全局语义上下文方面表现出色。另一方面,CNN已被证明比变换器更适合于对分层局部信息进行建模。受Transformer和CNN优点的启发,我们探索配备一个具有这两种优点的实时分割网络。我们提出了一个单分支CNN,它学习将其特征与强大的Transformer的特征对齐。这种特征对齐使单分支CNN能够提取丰富的全局上下文和详细的空间信息。具体而言,SCTNet采用了一个仅作用在训练阶段的Transformer作为语义分支来提取强大的全局语义上下文,语义信息对齐模块监督卷积分支以对齐来自Transformer的高质量全局上下文

Inference Phase 为了避免两个分支的巨大计算成本,在推理阶段只部署了CNN分支。利用transformer对齐的语义信息,单分支CNN可以生成准确的分割结果,而无需额外的语义或昂贵的密集融合。更具体地说,输入图像被送入到单分支层次卷积主干中,解码器头拾取主干中的特征并进行简单的拼接进行像素分类.   

本文实验

上图与表为Cityscapes语义分割上不同方案的性能对比,从中可以看到:

  • 所提SCTNet以大幅优势优于其他实时分割方案,取得了最佳的速度-精度均衡;
  • 所提SCTNet-B-Seg100去的了80.5%mIoU且速度达62.8FPS,达成实时分割新SOTA
  • 所提SCTNet-B-Seg75取得了79.8%mIoU,比RTFormer-B与DDRnet-23精度更高,同时速度快两倍;
  • 在所有输入分辨率下,所提SCTNet-B均比其他方案指标更优;此外,SCTNet-S同样取得了比STDC2、RTFormer-S、SeaFormer-B、TopFormer-B更优的性能均衡。

上表为ADE20K与COCO-Stuff-10K两个数据集上不同分割方案的性能对比,很明显:所提SCTNet同样取得了更优的速度-精度均衡。 

#STKET

作者提出了一种基于时空知识嵌入的 Transformer(STKET)将先验时空知识纳入多头交叉注意机制中,从而学习更多有代表性的视觉关系表示。 基于时空知识的视频场景图生成

视频场景图生成(VidSGG)旨在识别视觉场景中的对象并推断它们之间的视觉关系。该任务不仅需要全面了解分散在整个场景中的每个对象,还需要深入研究它们在时序上的运动和交互。为此,我们进行了相关的探索,并发现:每对物体组合及其它们之间的关系在每个图像内具有空间共现相关性,并且在不同图像之间具有时间一致性/转换相关性。基于这些先验知识,我们提出了一种基于时空知识嵌入的 Transformer(STKET)将先验时空知识纳入多头交叉注意机制中,从而学习更多有代表性的视觉关系表示。具体来说,我们首先以统计方式学习空间共现和时间转换相关性。然后,我们设计了时空知识嵌入层对视觉表示与知识之间的交互进行充分探索,分别生成空间和时间知识嵌入的视觉关系表示。最后,我们聚合这些特征,以预测最终的语义标签及其视觉关系。大量实验表明,我们所提出的框架大幅优于当前竞争算法。目前,该论文已经被人工智能顶级期刊 IEEE T-IP接收。

论文链接:https://arxiv.org/abs/2309.13237

代码链接:https://github.com/HCPLab-SYSU/STKET

1. 概述

随着场景理解领域的快速发展,许多研究者们开始尝试利用各种框架解决场景图生成(Scene Graph Generation, SGG)任务,并已取得了不俗的进展。但是,这些方法往往只考虑单张图像的情况,忽略了时序中存在着的大量的上下文信息,导致现有大部分场景图生成算法在无法准确地识别所给定的视频中包含的动态视觉关系。因此,许多研究者致力于开发视频场景图生成(Video Scene Graph Generation, VidSGG)算法来解决这个问题。

目前的工作主要关注从空间和时间角度聚合对象级视觉信息,以学习对应的视觉关系表示。然而,由于各类物体与交互动作的视觉外表方差大以及视频收集所导致的视觉关系显著的长尾分布,单纯的仅用视觉信息容易导致模型预测错误的视觉关系。

针对上述问题,我们做了以下两方面的工作:

  • 首先,我们提出挖掘训练样本中包含的先验时空知识用以促进视频场景图生成领域。其中,先验时空知识包括:1)空间共现相关性:某些对象类别之间的关系倾向于特定的交互。2)时间一致性/转换相关性:给定对的关系在连续视频剪辑中往往是一致的,或者很有可能转换到另一个特定关系。
  • 其次,我们提出了一种新颖的基于时空知识嵌入的 Transformer (Spatial-Temporal Knowledge-Embedded Transformer, STKET) 框架。该框架将先验时空知识纳入多头交叉注意机制中,从而学习更多有代表性的视觉关系表示。根据在测试基准上得到的比较结果,我们发现我们所提出的 STKET 框架优于以前的最先进方法。

图 1. 由于视觉外表多变和视觉关系的长尾分布,导致视频场景图生成充满挑战

2. 基于时空知识嵌入的 Transformer

2.1. 时空知识表示

在推断视觉关系时,人类不仅利用视觉线索,还利用积累的先验知识 [1, 2]。受此启发,我们提出直接从训练集中提取先验时空知识,以促进视频场景图生成任务。其中,空间共现相关性具体表现为当给定物体组合后其视觉关系分布将高度倾斜(例如,“人”与“杯子”之间的视觉关系的分布明显不同于“狗”与“玩具”之间的分布)和时间转移相关性具体表现为当给定前一时刻的视觉关系后各个视觉关系的转换概率将大幅变化(例如,当已知前一时刻的视觉关系为“吃”时,下一时刻视觉关系转移为“书写”的概率大幅下降)。如图 2 所示,我们可以直观地感受到给定物体组合或之前的视觉关系后,预测空间可以被大幅的缩减。       

图 2. 视觉关系的空间共现概率 [3] 与时间转移概率

图 3. 学习空间 (a) 和时间 (b) 知识表示的过程

2.2. 知识嵌入注意力层

空间知识通常包含有关实体之间的位置、距离和关系的信息。另一方面,时间知识涉及动作之间的顺序、持续时间和间隔。鉴于它们独特的属性,单独处理它们可以允许专门的建模更准确地捕获固有模式。因此,我们设计了时空知识嵌入层,彻底探索视觉表示与时空知识之间的相互作用。

图 4. 空间 (左侧) 和时间 (右侧) 知识嵌入层

2.3. 时空聚合模块

如前所述,空间知识嵌入层探索每个图像内的空间共现相关性,时间知识嵌入层探索不同图像之间的时间转移相关性,以此充分探索了视觉表示和时空知识之间的相互作用。尽管如此,这两层忽略了长时序的上下文信息,而这对于识别大部分动态变化的视觉关系具有帮助。为此,我们进一步设计了时空聚合(STA)模块来聚合每个对象对的这些表示,以预测最终的语义标签及其关系。它将不同帧中相同主客体对的空间和时间嵌入关系表示作为输入。具体来说,我们将同一对象对的这些表示连接起来以生成上下文表示。然后,为了在不同帧中找到相同的主客体对,我们采用预测的对象标签和 IoU(即并集交集)来匹配帧中检测到的相同主客体对 。最后,考虑到帧中的关系在不同批次中有不同的表示,我们选择滑动窗口中最早出现的表示。

3. 实验结果

为了全面评估所提出的框架的性能,我们除了对比现有的视频场景图生成方法(STTran, TPI, APT)外,我们也选取了先进的图像场景图生成方法(KERN, VCTREE, ReIDN, GPS-Net)进行比较。其中,为确保对比的公平,图像场景图生成方法通过对每一帧图像进行识别,从而达到对所给定视频生成对应场景图的目标。

#SeTformer

这里提出了SeTformer,一种新的transformer,其中DPSA完全被Self-optimal Transport (SeT)取代,以实现更好的性能和计算效率。在小型和基准尺寸模型下,SeTformer在ImageNet-1K上实现了令人印象深刻的84.7%和86.2%的top-1准确率。 

论文链接:https://arxiv.org/pdf/2401.03540.pdf

Transformer(变压器)最初是用于自然语言处理(NLP)的技术,在视觉领域得到了显著的流行,这要归功于Vision Transformer(ViT)的开创性工作它的优势已经在各种视觉任务中得到了证明,包括图像分类、目标检测、分割等。对于捕获长距离依赖关系,点积自注意力(DPSA)与softmax归一化在transformer中起着至关重要的作用。然而,该模型的计算导致了二次时间和内存复杂度,使得训练长序列模型变得困难。

简介

本文提出了SeTformer,一种新的transformer,其中DPSA完全被Self-optimal Transport (SeT)取代,以实现更好的性能和计算效率。SeT基于两个基本的softmax属性:保持非负的注意力矩阵和使用非线性的重新加权机制来强调输入序列中重要的标记。通过引入一个用于最优传输的核成本函数,SeTformer有效地满足了这些属性。特别是,在小型和基准尺寸模型下,SeTformer在ImageNet-1K上实现了令人印象深刻的84.7%和86.2%的top-1准确率。在目标检测方面, SeTformer-base相比FocalNet同类产品超出2.2 mAP, 使用的参数和浮点运算数分别减少了38%和29%。在语义分割方面, 我们的基准模型相比NAT超出了3.5 mIoU,并且参数减少了33%。SeTformer在GLUE基准测试中也取得了最先进的语言建模结果。这些发现凸显了SeTformer在视觉和语言任务中的适用性。

方法与模型

我们的目标是开发一种强大而高效的自注意力模型,尤其注重简单性。我们不添加任何复杂模块,如卷积、平移窗口或注意力偏置,以提高视觉任务的性能。事实上,我们采用了不同的策略。SeT利用了softmax的重要性质,包括非负性和重新加权机制,同时在设计中也注重了效率。使用具有正定(PD)核的RKHS避免了聚合负相关信息。SeT通过OT引入了非线性的重新加权方案。这涉及在RKHS中计算输入和参考集之间的对齐得分。这个过程引入了对齐得分的非线性,给元素分配权重以突出它们的重要性。这有助于模型捕捉复杂关系并强调局部相关性。

SeTformer 架构首先是一个下采样的卷积层,然后是包含多个 SeT 块的四个序列阶段。连续的阶段通过降采样层相连,降低空间尺寸同时加倍深度。在右边,我们展示了我们的注意力计算:将 x 和 y 元素映射到RKHS,然后通过 x 和 y 之间的 OT 计算聚合 x 特征,如果它们与相应的参考对齐良好。

我们使用Swin作为我们的基线模型,用我们的SeT模块替换其自注意力。我们的模型由四个阶段组成,每个阶段具有不同的空间分辨率,结果是输入图像的1/4大小。输入使用两层3×3卷积和2×2步幅进行嵌入。在每个阶段之后,除了最后一层外,都有一个通过3×3卷积和2×2步幅进行下采样的模块。这与Swin不同,Swin使用的是非重叠的2×2卷积。

1 Representing local image neighborhoods in an RKHS

为 了 保 持 线 性 计 算, 我 们 将 输 入 特 征 向 量 嵌 入到 一 个RKHS中, 其 中 点 评 估 采 用 线 性 函 数 的 形式。核方法使我们能够通过一个正定核函数K,将数据从其原始空间X映射到一个高维希尔伯特空间(特征空间)F中。对于函数u:X → F(特征映射),正定核函数表示为K(x, x′) = 〈u(x), u(x′)〉F。鉴于u(x)可以是无穷维的,核技术允 许 从Rk中 导 出 一 个 有 限 维 度 的 表 示v(x),其 中 内 积〈v(xi), v(x′j)〉表 示K(x, x′)。正 如所示,如果K是正定的,对于任意的x和x′,我们有K(x, x′) ≥ 0,这与softmax算子的非负性质一致。

2 Optimal transport (OT)

我们模型中的一个基本作用是通过学习它们之间的映射将相关令牌进行聚合。我们的加权聚合依赖于被视为不同测度或加权点云的元素x和x′之间的输运计划。OT在对齐问题中得到了广泛应用,并且具有捕捉数据几何形状的出色能力。在本文中,我们专注于Kantorovich形式的OT ´ ,其中使用熵正则化来平滑输运计划

3 Self-optimal Transport (SeT)

对 于 一 个 输 入 特 征 向 量x和 一 个 位 于X中 的 参考ym,我们进行以下步骤:(i)将特征向量x和y表示为RKHSF中 的 元 素, (ii)使 用OT将x的 元 素 与y对 齐,(iii)对x的元素进行加权聚合,得到一个对齐矩阵A。我们使用参考y来实现高效的元素聚合。参考集合中的每个元素都作为一个”对齐单元”,输入特征通过加权求和在这些单元中进行聚合。这些权重指示了输入和参考之间的对应关系,通过OT计算得出。假设我们有一个输入特征向量x = {x1, . . . , xn},其中x属于X ∈ Rd,是从输入图像中随机提取的。在Nystrom¨ 近似方法的背景下,y的样本是通过对训练集X中的特征向量进行K-means聚类来获得的质心,从而我们得到y = {y1, . . . , ym},其中m ≤ n。使用参考集合有助于优化计算过程,并使模型能够有效地处理更长的输入序列。设k是一个正定的核函数,如定义在RKHS上的高斯核函数,以及映射u : Rd → F。我们创建一个大小为n × m的矩阵k,用于存储比较k(xi, yj )的结果。接下来,我们根据公式(2)计算x和y之间的传输计划,得到大小为n × m的矩阵T(x, y)。传输计划找到将输入特征与参考元素对齐的最佳方法,同时最小化对齐成本。

4 Projecting onto a linear subspace

当处理有限维度的u(x)时, Ay(x)可 以 直 接 计 算, 而 不 会 引 起 重 大的计算开销。对于无限维或高维的u(x),Nystrom¨ 算法提 供了 一 种 有 效 的 近 似 方 法 来 嵌 入væRd → Rk。Nystrom¨ 算 法 通 过 对 列 和 行 进 行 采 样, 并 将 输 入从 特 征 空 间F投 影 到 线 性 子 空 间F1上 来 近 似 计 算传 输 计 划, 从 而 得 到 嵌 入〈v(xi), v(x′j)〉F1。子 空间F1由k个中心u(z1), . . . , u(zk)张成。显式公式v(xi) =k(z, z)−1/2k(z, xi)表示将z = z1, . . . , zk作为中心来进行新的嵌入。这种高效的方法只需要执行K-means聚类并计算逆平方根矩阵。

5 Linear positional encoding

为了将位置信息融入我们的模型中,我们采用了的方法,在输入集和参考集之间的相似性上应用了指数惩罚,基于它们的位置距离。这涉及到对T(v(x), y)与一个距离矩阵M进行乘法运算,其中Mij = e(− 1τ2)(α−β)2,其中α = i/n,β = j/m,τ表示平滑参数。我们考虑了内容和位置信息的相似性权重与其他位置编码方法相比取得了优秀的性能。

实验与结果

我 们 在 图 像 和 语 言 领 域 进 行 了 实 验, 包括ImageNet、COCO和ADE20K,以及GLUE,以展示我们的模型的影响。我们对超参数进行了微调,例如参考数量(m),OT中的熵正则化ϵ,以及位置嵌入中的τ。我们观察到ϵ和τ在任务之间表现稳定,但对于值m的选

224x224分辨率ImageNet-1K的分类准确率

SeTformer模型以较小的模型大小、Flops和吞吐量稳定优于ConvNeXt。我们的Mini模型的准确率超过Swin-T模型0.4%,参数量减少40%(28M → 16M),Flops减少37%。我们的Tiny模型(83.9%)在性能上超过CSWin 1.2%,并具有类似的模型大小,速度提升12%(从701/s到785/s)。与FocalNet-T模型相比,它在性能上表现更优,提高了1.6%。使用更大的模型,我们在较少的参数和较低的计算成本下实现了最先进的性能。例如,SeTformer-B模型在超过24%和36%的Flops和参数减少的情况下,将NAT-B模型(84.4%)的准确率提高了1.8%。我们还注意到,吞吐量是在V100 GPU上测量的。 

COCO数据集上Mask R-CNN目标检测结果

SeTformer在卷积神经网络(如ResNet)和Transformer骨 干 网 络 (如CSWin、 NAT、 MViTv2)方 面 表 现 优 异。例 如, SeTformer-T的APb为49.3,APm为44.0, 相 较 于NAT-T增 加 了1.6和1.4个 百 分 点,同时计算量更小, 模型尺寸更小。在扩展规模方面,SeTformer-B的APb为51.9,相比于CSWin-B的50.8,增加了1.1个百分点,同时参数减少28%,计算量减少33%。

ADE20K数据集上的语义分割结果

语义分割任务上我们的模型优于现有最先进的方法;例如,相比于CSWin的对应模型,SeTformer-T和SeTformer-S的mIoU(SS/MS)分别提高了+1.3 / +0.7和+0.7 / +0.4,同时具有更轻、更低复杂度的优势。

#Find+Replace Transforme

论文新提出了一种名为“Find+Replace Transformer”的多 Transformer 架构,并证明了通过集成多个Transformer,能够解决单一 Transformer 无法胜任的任务。 

ICLR 匿名研究:单一 Transformer 不具备图灵完备性,但多 Transformer 可以。

Transformer 自 2017 年出世以来就在 AI 领域高举高打,ChatGPT 引发全球大型语言模型热潮后更是在 NLP 领域被赋予了神话般的地位。

但近日,一篇正在审核中的 ICLR 2023 投稿论文(如下)经研究后提出一个观点:单一 Transformer 并不具备图灵完备性,其计算能力存在理论上的局限性,在圈内引起关注。

由于该论文正在审核中,作者信息没有被公开。

论文链接:https://openreview.net/pdf?id=MGWsPGogLH

与此同时,该论文新提出了一种名为“Find+Replace Transformer”的多 Transformer 架构,并证明了通过集成多个Transformer,能够解决单一 Transformer 无法胜任的任务。

这项研究直接对标并超越了当前最先进的GPT-4模型,在一系列极具挑战性的基准测试中展现了显著的优势和潜力。

1 被神化的 Transformer 局限在哪里?

图灵完备性是评判一个计算系统强大与否的关键指标。如果一个系统被确认为图灵完备,则理论上只要赋予其充足的运行时间和内存资源,即可以执行任何可计算的算法。

在实际应用中,尽管 Transformer 模型在诸多自然语言处理任务上表现卓越,但其能力受到设计上的固有限制,例如固定的上下文窗口长度和有限的词汇表大小。这意味着 Transformer 模型并不具备解决所有类型计算问题的能力,特别是那些需要无限存储空间或无限制迭代过程的问题。

在论文中,研究团队特别指出,基础的语言模型工作原理在于根据前 k 个词语的概率来预测下一个词语。在 NLP 领域,通常会构建一些专门针对固定长度输入输出序列设计的模型集合或框架,并将这类模型归入 MF_SMF 类别。

Transformer 作为 MF_SMF 这一框架下的具体实例,其图灵完备性的缺失得到了该研究团队的理论论证。他们基于以下逻辑:

首先,回顾计算理论的基础:图灵停机问题是不可判定的,意味着不存在一个通用的方法来判断任意给定程序何时终止运行,就如同无法找到一把万能钥匙预测每一场棋局结束时间一样。这一原理同样适用于评估模型是否会在执行过程中陷入无尽循环而无法自拔。

研究者进而分析了 MF_S(这里假设 MF_S 代表 MF_SMF 中的子集)集合中的模型:

  • 假设可以构建一个算法H,它可以准确判断MF_S中任意模型m是否终止。
  • 假设MF_S集合中存在一个模型m’,它足够强大以至于能够模拟任何图灵机的计算过程,包括那些永远不会停止的图灵机。
  • 根据算法H的假设能力,如果MF_S集合中的模型m’能够模拟那些不会停止的图灵机,那么算法H应该能够预测m’在模拟这些图灵机时是否会停止。
  • 然而,根据图灵的停机问题不可判定定理,我们知道实际上不可能存在这样一个算法H,因为它会与图灵的定理相矛盾。
  • 因此,MF_S集合中不可能存在能够模拟所有图灵机行为的模型m’,也就是说,MF_S中没有任何模型是图灵完备的。

Transformer便属于 MF_SMF,所以 Transformer 不具备图灵完备性。

研究人员指出,Transformer在处理自然语言任务,尤其是在机器翻译方面,有明显的优势。这类模型能够通过递归的方式输入序列并生成更新后的序列,从而逐个预测下一个符号。

但是,尽管Transformer模型能够基于之前的字符序列连续生成新的字符序列,每次接收一段输入字符后产出相应的输出字符,并利用新产生的字符序列进行迭代计算,它还是受到了上下文长度k和词汇表大小v的限制。这意味着它能够处理的不同字符组合的数量不会超过v^k种。

例如,当 Transformer 遇到重复输入时,由于它的无状态特性(这有利于并行训练多个序列),模型必须保证对同一输入产生一致的输出结果。这可能导致在某些情况下,模型陷入无限循环的模式,即只能生成有限数量的、最多为v^k种不同的输出序列,或者在自我复制的过程中无法停止。

与Transformer相比,图灵在1936年提出的图灵机概念具有无限的计算潜力,不受这些结构性的限制,能够模拟任何可计算的过程,确保不会陷入类似的有限循环困境。

2 如何超越 GPT-4?

实验结果显示,单个 Transformer 架构并不具备图灵完备性,而多 Transformer 则有能力实现图灵完备(如论文中所提出的 Find+Replace Transformer)、并执行如 GPT-4 等最先进的 Transformer 模型所无法解决的问题。

论文中创新性地将 Find Transformer 与 Replace Transformer 相结合,构建了Find+Replace Transformer体系结构——这是一个能在任意长度序列上运行的多Transformer系统,在论文中被形象地比喻为“磁带”(Tape)。

该系统由 Find Transformer、Replace Transformer 以及 Map 三部分组成,其中 Map 是一个从 Replace Transformer 到 Find Transformer 所涉及的有序集合的函数映射关系。

具体运作时,Find Transformer 会在输入序列中定位并标识出需要由 Replace Transformer 处理的部分内容。这两个组件各自具有固定的上下文长度 k,并依次对“磁带”上的每个长度为k的子序列进行分析,Find Transformer 会选择那些在最终层产生最高激活值的特定子序列。

随后,Replace Transformer 会接收 Find Transformer 标识出的子序列作为输入,并基于此生成一个新的长度为k的输出序列,这个过程利用了 Map 关联的 f∈Map(r) 规则,确保了两个 Transformer 之间的协同工作及信息传递。

那这个 Find+Replace Transformer 的多 Transformer 系统是如何可以实现图灵完备的呢?

简单来说,Find+Replace Transformer 是一个学习简化的机器。在编程语言的基石 λ 演算 中,有三条被称为“归约”(Reduction)的规则:   

  • Alpha Reduction:这是一个绑定变量的重命名。它被用来避免命名冲突。例如,在λ 演算的项 λx.x,我们可以化简成 λy.y,且不改变其意思。
  • Beta Reduction:这是将函数应用于其参数的过程。例如,在λ项(λx.x)y(表示将函数λx.x作用于参数y),我们可以化简成 y。
  • Eta Reduction:这是对函数和参数的简化。如果你有一个函数比如λx.(fx),而x不出现在f中,那么这个就可以化简为f。

Find+Replace Transformer 的多Transformer 系统之所以能够实现图灵完备性,关键在于其架构设计和训练方式允许模型通过一系列组合操作模拟类似于 λ 演算中的归约规则。尽管单个 Transformer 受限于上下文长度、词汇表大小等因素,但通过构建一个多 Transformer 协作的框架,并结合特定的学习机制,这些简单且局部的“查找与替换”操作得以在更复杂的计算任务中累积并形成强大的综合效应。

具体来说,在Find+Replace Transformer中,多个 Transformer 可能被专门设计来分别或协同地处理不同类型的简化(归约)任务,例如模拟 Alpha Reduction 进行变量重命名、模拟 Beta Reduction 执行函数应用以及模拟 Eta Reduction进行函数简化等。每个 Transformer 可能专注于理解和学习如何执行这类简单的转换操作,并将结果传递给下一个Transformer,从而逐步构建起复杂问题的解决方案。

虽然单个 Transformer 不具备图灵完备性,但当它们以特定的方式组织起来并协同工作时,可以模拟通用图灵机的逻辑行为,进而实现对任意可计算问题的解决能力。这样的体系结构让Find+Replace Transformer在处理大规模、多层次的复杂问题时展现出超越传统单一Transformer的性能表现,实现了更高阶的计算能力。

2023年当OpenAI 发布GPT-4时,微软研究院的研究人员发表了一篇题为“Sparks of Artificial General Intelligence(Bubeck et al., 2023)”的论文,阐述了早期AGI所面临的局限性。

研究者们以汉诺塔问题为例进行了说明。汉诺塔是一个经典的递归问题,要求玩家将按照大小顺序堆叠的圆盘从一根柱子移动到另一根柱子上,期间只能移动一个圆盘且任何时候大盘不能位于小盘之上,借助第三根柱子作为中转。

GPT-4无法解决这个复杂的推理问题,从而突显了当前Transformer在推理过程中缺乏规划能力。

研究者对比了几种模型在解决完整汉诺塔问题上的表现。随着问题规模增大,其难度呈指数级上升:规模为n的问题其解决方案需要2^n - 1步操作。Find+Replace Transformer在此任务上表现出色,甚至能生成比GPT-4至少长18倍的正确解决方案。

除了在汉诺塔这个GPT-4都难以解决的问题上表现优越之外,在其他AI任务,如创作满足特定条件的诗歌等,Find+Replace Transformer都能超越GPT-4,这反映了其在泛化能力上的优势。

3 结语

Find+Replace Transformer模型通过创新性地结合多个Transformer单元,并模拟λ演算中的归约规则,在处理如汉诺塔问题等复杂组合任务时展现出了超越传统单个Transformer的优越性能。

这一研究成果揭示了多Transformer系统在实现图灵完备性方面的潜力,也证明了在面对特定计算难题时,提高模型的逻辑推理和抽象表达能力的重要性。

而纵观整个人工智能技术的发展,从深度学习兴起到大模型浪潮来袭,每一次技术迭代,人们都对于新技术报以极大的热情与崇拜。

然而,无论是深度学习还是Transformer架构,亦或是如今新出现Find+Replace Transformer架构,所带给我们的启示是,在研究和应用深度学习技术时,都需要避免过分神化任何技术,应该理性地看待每一项技术,关注其优势和局限,并结合实际问题来选择和调整合适的技术。只有这样,才能不断地在通往人工通用智能(AGI)的道路上迈进。

#Soft MoE

本文提出了一种可微的稀疏混合专家 Transformer 模型 (fully-differentiable sparse Transformer) Soft MoE 来解决端到端训练困难的问题,同时也能够保持 MoE 方法的优势,即以较低的推理成本更大的模型容量。 

Soft MoE 提出了一种新的可微稀疏混合专家模型,稀疏混合专家 (Sparse Mixture of Experts, MoE) 是一种在保证模型训练和推理的成本不显著增加的情况下,大幅度提升模型容量的方法。

MoE 方法已经有很长的一段历史了,是一种扩大模型容量的经典高效的做法,但是它的缺点是:

  1. 训练不稳定
  2. Token Dropping 的问题
  3. 较难扩展 Expert 的数量
  4. 低效率的微调

造成以上问题的一个原因是 MoE 的端到端训练困难,因此,本文提出了一种可微的稀疏混合专家 Transformer 模型 (fully-differentiable sparse Transformer) Soft MoE 来解决端到端训练困难的问题,同时也能够保持 MoE 方法的优势,即以较低的推理成本更大的模型容量。Soft MoE 的特点是给每个专家输入不同 token 的权重混合。

视觉实验结果证明,Soft MoE 大大优于标准 ViT 和流行的 MoE 方法,比如 128 个 Expert,16 个 MoE 层的 Soft MoE-Huge/14 模型参数比 ViT-Huge/14 多 40 倍,但推理时间成本仅增长 2%,同时性能要好得多。

1 Soft MoE:一种完全可微的稀疏 Transformer

论文名称: From Sparse to Soft Mixtures of Experts

论文地址:​​ https://arxiv.org/pdf/2308.00951.pdf​​

  • 1 Soft MoE 论文解读:

1.1 背景:把离散优化问题变为可微的优化问题

稀疏混合专家 (Sparse Mixture of Experts, MoE) 是一种在保证模型训练和推理的成本不显著增加的情况下,大幅度提升模型容量的方法。在视觉,语言和多模态任务中都取得了成功,代表像视觉的 V-MoE[1],文本的 Switch Transformer[2]和多模态的 LIMoE[3]。

如下图1左所示,稀疏 MoE Transformer 的核心是一个离散优化问题,即:模型需要决定每个输入 token 应该输入哪些 Expert 里面,这些 Expert 一般是 MLP 模块。输入 token 和 Expert 之间的匹配 (token-to-expert match) 是 MoE 中要考虑的很重要的问题之一,之前也有各种各样的方法尝试解决此问题,比如基于线性规划的[4],比如基于 RL 算法的[5],比如基于固定规则的[6],比如基于最优传输理论的[7],和基于贪婪匹配的[8]。总之,解决好稀疏 MoE 的这个离散优化问题的确是件不容易的事情。稀疏 MoE 的缺点有:

  1. 训练不稳定
  2. Token Dropping 的问题
  3. 较难扩展 Expert 的数量
  4. 低效率的微调

图1:Sparse MoE 和 Soft MoE 的区别:左:Sparse MoE,给每个 Expert 分配一定的输入 token。右:Soft MoE,给每个 Expert 分配的是所有输入 token 的加权平均值

如下图1右所示,Soft MoE 把稀疏 MoE Transformer 的这个离散优化问题变成了可微的优化问题。Soft MoE 觉得不必一定要 "hard" 地找到输入 token 和 Expert 之间的一一匹配,而是可以 "Soft" 地混合输入 token 并且分给每一个 Expert。Soft MoE 给每个 Expert 分配的不是某几个输入 token,而是所有输入 token 的加权平均值 (权重取决于 token 和 Expert),然后由这个对应的 Expert 去处理这个加权平均值。

1.2 变为可微的优化问题之后,解决了之前稀疏 MoE 的什么问题?

问题1: 精心设计的 Expert-to-token 的路由机制通常并不比随机固定路由好。

Soft MoE 可以避免这个问题,因为每个路由的参数都是基于每个输入 token 直接更新的。

问题2: 训练不稳定 (LIMoE[3]这个工作观察到在训练期间,可能有大部分 token 改变路由,给训练带来一定挑战) 导致很多稀疏 MoE 方法的 Expert 都不可以设置得很多。

Soft MoE 可以避免这个问题,扩展到数千个 Expert。

1.3 Soft MoE 算法描述

参数配置:

整个过程如下图2所示。

图2:Soft MoE 算法流程图

遵循稀疏 MoE 的常用设计思想,作者用 Soft MoE 块替换了 Transformer 的一部分 MoE 块。slot 的总数是 Soft MoE 的关键超参数,因为时间复杂度取决于 slot 的数量,而不是 Expert 的数量。比如,可以设置等于输入序列长度的 slot 数以匹配等效密集 Transformer 的 FLOP。

Soft MoE 的 JAX 代码:

def soft_moe_layer(X, Phi, experts):
    # Compute the dispatch and combine weights.
    logits = jnp.einsum('md,dnp->mnp', X, Phi)
    D = jax.nn.softmax(logits, axis=(0,))
    C = jax.nn.softmax(logits, axis=(1, 2))
    # The input slots are a weighted average of all the input tokens,
    # given by the dispatch weights.
    Xs = jnp.einsum('md,mnp->npd', X, D)
    # Apply the corresponding expert function to each input slot.
    Ys = jnp.stack([
    f_i(Xs[i, :, :]) for i, f_i in enumerate(experts)],
    axis=0)
    # The output tokens are a weighted average of all the output slots,
    # given by the combine weights.
    Y = jnp.einsum('npd,mnp->md', Ys, C)
    return Y

全部代码:

​https://github.com/google-research/vmoegithub.com/google-research/vmoe​

1.4 Soft MoE 的一些关键性质

1) 完全可微:

Sparse MoE 算法的通病是 token 和 Expert 之间存在的分配问题,有时精心设计的 Expert-to-token 的路由机制通常并不比随机固定路由好。输入 token 和 Expert 之间的匹配 (token-to-expert match) 是 MoE 中要考虑的很重要的问题之一,之前也有各种各样的方法尝试解决此问题,比如基于线性规划的[4],比如基于 RL 算法的[5],比如基于固定规则的[6],比如基于最优传输理论的[7],和基于贪婪匹配的[8][9]。所有这些方法本质上都是离散,不可微的。

Soft MoE 可以避免这个问题,因为每个路由的参数都是基于每个输入 token 直接更新的。

2) 可以避免掉 Token Dropping 和 Expert Unbalance 的问题

MoE 算法里面每个 Expert 都会处理一些 token,很自然地就会带来 Token Dropping (有的 token 不会分配给任何一个 Expert) 和 Expert Unbalance (一些 Expert 会比另一些 Expert 分配到更多 token) 的问题。

Soft MoE 可以避免这个问题,因为每个 slot 的输入都是所有 token 的加权平均值。

3) 运算速度快

Soft MoE 的主要优点是完全避免了之前算法中的 token 排序或 top-k 操作,因为这些操作的速度慢,而且不太适合硬件加速器。因此,Soft MoE 明显快于大多数 Sparse MoE 算法。

4) Soft MoE 算法是密集的 MoE 算法还是稀疏的 MoE 算法?

要回答这个问题我们需要首先搞明白为什么 Sparse MoE 算法是稀疏的。Sparse MoE 是稀疏的这件事的根本原因是每个 Expert 的输入特征仅仅是一部分的 token,而 Soft MoE 的输入是所有输入 token 的加权平均值,因此不能算作是稀疏的。

Soft MoE 也不能算作是 Dense MoE 算法,因为每个 Expert 仅仅会处理输入 token 的子集。

5) Soft MoE 算法需要归一化

Transformers 中,MoE 层通常用于替换每个编码器块中的 FFN 层,因此如果去遵循大部分 Transformer 架构的 Pre-Normalization 方法,就需要使用归一化,这里 Soft MoE 针对  的操作是:

l2_normalize(X, axis=1)

scale * l2_normalize(Phi, axis=0)

其中,scale 是可学习的参数,l2_normalize 的定义是:

def l2_normalize(x, axis, eps=1e-6):
    norm = jnp.sqrt(jnp.square(x).sum(axis=axis, keepdims=True))
    return x * jnp.reciprocal(norm + eps)

6) Soft MoE 算法和注意力机制 (Multi-Head Self-Attention) 的区别和联系?

1.5 Soft MoE 算法的局限性

  • 自回归解码 (Auto-regressive decoding):

因为 Soft MoE 算法要在运行过程中合并所有的输入 token,因此很难实现自回归。因为自回归必须在训练期间保留过去的 token 和未来 token 之间的因果关系 (Causality)。

Self-Attention 解决这个问题的手段是依赖于注意力的掩码 (Mask) 机制。如果想在 Soft MoE 中实现这一点就需要特别小心 token 之间的依赖和相关关系。总之研究 Soft MoE 算法的自回归解码是个很有价值的方向。

  • 内存消耗

Soft MoE 倾向于利用大量 Expert,而其成本和 Dense Backbone 类似,使得模型的内存需求可能变大。

1.6 图像分类实验结果

训练数据集

预训练数据集: JFT-4B:一个私有数据集,其最新版本包含超过 4B 张图像,涵盖超过 29k 个类。预训练的过程中评价指标是 JFT-4B 上的上游验证精度 Precision-at-1 和 ImageNet 10-shot 精度 (冻结模型权重,并用一个新的权重来计算的,该数据集仅在包含来自 ImageNet-1K 的每个类包含 10 张图像的数据集上进行训练)。

微调数据集: ImageNet-1K 训练集。

验证集: ImageNet-1K 验证集。

模型尺寸:

ViT-S/8, ViT-S/16, ViT-S/32, ViT-B/16, ViT-B/32, ViT-L/16, ViT-L/32, ViT-H/14。

方法:

Token Choice, Expert Choice 和本文的 Soft MoE。

训练策略:

300k steps, Batch Size 4096

Pareto Model 实验结果:

如下图3所示是四种方法 Soft MoE, Experts Choice, Tokens Choice, Dense 在预训练过程中的 JFT-4B Precision-at-1 的结果和 ImageNet 10-shot 的精度的训练成本/性能帕累托边界。Soft MoE 算法在这两个指标上都优于之前的方法。

图3:四种方法在预训练过程中的 JFT-4B Precision-at-1 的结果和 ImageNet 10-shot 的精度的训练成本/性能帕累托边界

更长的训练结果:

本文还测试在更长的训练 step 下模型的性能如何,把从 Small 到 Huge 的模型训练了 4K steps,用 128 个 Expert 的 Soft MoE 替换 ViT S/16、B/16、L/16 和 H/14 中的最后一半 Block 中的 FFN,每个 Expert 使用一个 slot。

由于模型并行性所需的额外数据传输,Large Soft MoE 模型产生的 wall-clock time overhead 很小。所有变体都训练了 4M 步,除了 H/14,出于成本原因训练了 2M 步,实验结果如下图4和5所示。

如下图4所示是 Soft MoE 和 ViT 的 JFT-4B 精度、ImageNet 10-shot 精度和 ImageNet 微调精度与 ExaFLOPS 的训练成本。

图4:不同模型更长的训练 step 下的 JFT-4B 精度

如下图5所示是所有结果。对于给定的计算预算,Soft MoE 模型大大优于 Vision Transformer 模型。比如 Soft MoE-S/16 在 JFT-4B 和 ImageNet 10-shot 上的表现优于 ViT-B/16,它还提高了完整 ImageNet 数据的微调分数,即使它的训练 (和推理) 成本要小得多。同样,Soft MoE-B/16 在上游任务 JFT-4B 和 ImageNet 10 shot 的表现优于 ViT-L/16,微调后仅落后 0.5,同时速度快 3 倍,所需的 FLOP 减少了近 4 倍。最后,Soft MoE-L/16 模型优于 Dense H/14 模型,同时在推理速度又快 3 倍左右。

图5:不同模型更长的训练 step 下的实验结果

根据前面的实验结果,较小的 Soft MoE 的性能可以匹配较大的视觉 Transformer,作者因此继续训练小模型 Backbone,希望以非常低的训练成本获得更高质量的模型。

作者观察到对于 Soft MoE 方法而言,较长的 cooldown (学习率线性减小到零的时期) 可以很好地适用于 Soft MoE,因此将 cooldown 从 50k steps 增加到 500k steps。

实验结果如下图6和7所示。Soft MoE-B/16 训练了 1k TPUv3 Days,优于在相似时间预算上训练的 ViT-H/14,而 Soft MoE-B 模型的 FLOPs 要低 10 倍,wall-clock time 低 5.7 倍。即使将 ViT-H/14 的训练代价加倍,Soft MoE-B 模型的性能也可以与之相匹配。Soft MoE-L/16 模型的在推断上比 ViT H/14 快近 2 倍的同时性能大大优于所有模型。

图6:不同训练代价和尺寸的 Soft MoE 模型和 ViT 的 JFT-4B Precision-at-1 性能和 ImageNet 10-shot 性能       

图7:Soft MoE 模型和 ViT 的实验结果

视觉-文本对比学习实验结果

作者还验证了 Soft MoE 得到的模型在其他任务的性能。具体而言作者探索了一种流行的范式,即图像语言对比学习,这里遵循的是 LiT[10] 方法,其中图像塔在图像分类任务上进行了预训练,然后在在图像-文本对数据集上训练文本编码器时冻结。

视觉编码器作者重用了在 JFT 上训练的模型,对比学习在 WebLI 上训练,这是一个专有数据集,由 10B 图像和从互联网上抓取的 ALT 文本组成。图像编码器被冻结,而文本编码器从头开始训练。实验结果如下图8所示,Soft MoE -L/16 在 Imagenet 和 Cifar-100 零样本上的性能分别比 ViT-L/16 高出 1% 和 2% 以上。

图8:对比学习实验结果

#Transformers18~ Diffusion

还是Transformers,来自 UC 伯克利的 William Peebles 以及纽约大学的谢赛宁撰文揭秘扩散模型中架构选择的意义,并为未来的生成模型研究提供经验基线。

近几年,在 Transformer 的推动下,机器学习正在经历复兴。过去五年中,用于自然语言处理、计算机视觉以及其他领域的神经架构在很大程度上已被 transformer 所占据。

不过还有许多图像级生成模型仍然不受这一趋势的影响,例如过去一年扩散模型在图像生成方面取得了惊人的成果,几乎所有这些模型都使用卷积 U-Net 作为主干。这有点令人惊讶!在过去的几年中,深度学习的大事件一直是跨领域的 Transformer 的主导地位。U-Net 或卷积是否有什么特别之处使它们在扩散模型中表现得如此出色?

将 U-Net 主干网络首次引入扩散模型的研究可追溯到 Ho 等人,这种设计模式继承了自回归生成模型 PixelCNN++,只是稍微进行了一些改动。而 PixelCNN++ 由卷积层组成,其包含许多的 ResNet 块。其与标准的 U-Net 相比,PixelCNN++ 附加的空间自注意力块成为 transformer 中的基本组件。不同于其他人的研究,Dhariwal 和 Nichol 等人消除了 U-Net 的几种架构选择,例如使用自适应归一化层为卷积层注入条件信息和通道计数。

本文中来自 UC 伯克利的 William Peebles 以及纽约大学的谢赛宁撰文《 Scalable Diffusion Models with Transformers 》,目标是揭开扩散模型中架构选择的意义,并为未来的生成模型研究提供经验基线。该研究表明,U-Net 归纳偏置对扩散模型的性能不是至关重要的,并且可以很容易地用标准设计(如 transformer)取代。

这一发现表明,扩散模型可以从架构统一趋势中受益,例如,扩散模型可以继承其他领域的最佳实践和训练方法,保留这些模型的可扩展性、鲁棒性和效率等有利特性。标准化架构也将为跨领域研究开辟新的可能性。

  • 论文地址:https://arxiv.org/pdf/2212.09748.pdf
  • 项目地址:https://github.com/facebookresearch/DiT
  • 论文主页:https://www.wpeebles.com/DiT

该研究专注于一类新的基于 Transformer 的扩散模型:Diffusion Transformers(简称 DiTs)。DiTs 遵循 Vision Transformers (ViTs) 的最佳实践,有一些小但重要的调整。DiT 已被证明比传统的卷积网络(例如 ResNet )具有更有效地扩展性。

具体而言,本文研究了 Transformer 在网络复杂度与样本质量方面的扩展行为。研究表明,通过在潜在扩散模型 (LDM) 框架下构建 DiT 设计空间并对其进行基准测试,其中扩散模型在 VAE 的潜在空间内进行训练,可以成功地用 transformer 替换 U-Net 主干。本文进一步表明 DiT 是扩散模型的可扩展架构:网络复杂性(由 Gflops 测量)与样本质量(由 FID 测量)之间存在很强的相关性。通过简单地扩展 DiT 并训练具有高容量主干(118.6 Gflops)的 LDM,可以在类条件 256 × 256 ImageNet 生成基准上实现 2.27 FID 的最新结果。

Diffusion Transformers

DiTs 是一种用于扩散模型的新架构,目标是尽可能忠实于标准 transformer 架构,以保留其可扩展性。DiT 保留了 ViT 的许多最佳实践,图 3 显示了完整 DiT 体系架构。

DiT 的输入为空间表示 z(对于 256 × 256 × 3 图像,z 的形状为 32 × 32 × 4)。DiT 的第一层是 patchify,该层通过将每个 patch 线性嵌入到输入中,以此将空间输入转换为一个 T token 序列。patchify 之后,本文将标准的基于 ViT 频率的位置嵌入应用于所有输入 token。

patchify 创建的 token T 的数量由 patch 大小超参数 p 决定。如图 4 所示,将 p 减半将使 T 翻四倍,因此至少能使 transformer Gflops 翻四倍。本文将 p = 2,4,8 添加到 DiT 设计空间。

DiT 块设计:在 patchify 之后,输入 token 由一系列 transformer 块处理。除了噪声图像输入之外,扩散模型有时还会处理额外的条件信息,例如噪声时间步长 t、类标签 c、自然语言等。本文探索了四种以不同方式处理条件输入的 transformer 块变体。这些设计对标准 ViT 块设计进行了微小但重要的修改。所有模块的设计如图 3 所示。

本文尝试了四种因模型深度和宽度而异的配置:DiT-S、DiT-B、DiT-L 和 DiT-XL。这些模型配置范围从 33M 到 675M 参数,Gflops 从 0.4 到 119 。

实验

研究者训练了四个最高 Gflop 的 DiT-XL/2 模型,每个模型使用不同的 block 设计 ——in-context(119.4Gflops)、cross-attention(137.6Gflops)、adaptive layer norm(adaLN,118.6Gflops)或 adaLN-zero(118.6Gflops)。然后在训练过程中测量 FID,图 5 为结果。

扩展模型大小和 patch 大小。图 2(左)给出了每个模型的 Gflops 和它们在 400K 训练迭代时的 FID 概况。可以发现,增加模型大小和减少 patch 大小会对扩散模型产生相当大的改进。

 图 6(顶部)展示了 FID 是如何随着模型大小的增加和 patch 大小保持不变而变化的。在四种设置中,通过使 Transformer 更深、更宽,训练的所有阶段都获得了 FID 的明显提升。同样,图 6(底部)展示了 patch 大小减少和模型大小保持不变时的 FID。研究者再次观察到,在整个训练过程中,通过简单地扩大 DiT 处理的 token 数量,并保持参数的大致固定,FID 会得到相当大的改善。 

 图 8 中展示了 FID-50K 在 400K 训练步数下与模型 Gflops 的对比: 

SOTA 扩散模型 256×256 ImageNet。在对扩展分析之后,研究者继续训练最高 Gflop 模型 DiT-XL/2,步数为 7M。图 1 展示了该模型的样本,并与类别条件生成 SOTA 模型进行比较,表 2 中展示了结果。

当使用无分类器指导时,DiT-XL/2 优于之前所有的扩散模型,将之前由 LDM 实现的 3.60 的最佳 FID-50K 降至 2.27。如图 2(右)所示,相对于 LDM-4(103.6 Gflops)这样的潜在空间 U-Net 模型来说,DiT-XL/2(118.6 Gflops)计算效率高得多,也比 ADM(1120 Gflops)或 ADM-U(742 Gflops)这样的像素空间 U-Net 模型效率高很多。

表 3 展示了与 SOTA 方法的比较。XL/2 在这一分辨率下再次胜过之前的所有扩散模型,将 ADM 之前取得的 3.85 的最佳 FID 提高到 3.04。

#Diffusion Transformers (DiTs)

本文探索了一类新的基于 Transformer 的扩散模型 Diffusion Transformers (DiTs)。本文训练 latent diffusion models 时,使用 Transformer 架构替换常用的 UNet 架构,且 Transformer 作用于 latent patches 上。

本文探索了一类新的基于 Transformer 的扩散模型 Diffusion Transformers (DiTs)。本文训练 latent diffusion models 时,使用 Transformer 架构替换常用的 UNet 架构,且 Transformer 作用于 latent patches 上。

作者探索了 DiT 的缩放性,发现具有较高 GFLOPs 的 DiT 模型,通过增加 Transformer 宽度或者深度或者输入 token 数量,始终有更好的 FID 值。最大的 DiT-XL/2 模型在 ImageNet 512×512 和 256×256 的测试中优于所有先前的扩散模型,实现了 2.27 的 FID 值。

做了什么工作

  1. 探索了一类新的基于 Transformer 的 Diffusion Model,称为 Diffusion Transformers (DiTs)。
  2. 研究了 DiT 对于模型复杂度 (GFLOPs) 和样本质量 (FID) 的缩放性。
  3. 证明了通过使用 Latent Diffusion Models (LDMs)[1]框架,Diffusion Model 中的 U-Net 架构可以被 Transformer 替换。

1 DiT:Transformer 构建扩散模型

论文名称:Scalable Diffusion Models with Transformers (ICCV 2023, Oral)

论文地址:https//arxiv.org/pdf/2212.09748.pdf

论文主页:https//www.wpeebles.com/DiT.html

1 DiT 论文解读:

1.1 把 Transformer 引入 Diffusion Models

机器学习正经历着 Transformer 架构带来的复兴:NLP,CV 等许多领域正在被 Transformer 模型覆盖。尽管 Transformer 在 Autoregressive Model 中得到广泛应用[2][3][4][5],但是这种架构在生成式模型中较少采用。比如,作为图像领域生成模型的经典方法,Diffusion Models[6][7]却一直使用基于卷积的 U-Net 架构作为骨干网络。

Diffusion Models 的开创性工作 DDPM [8]首次引入了基于 U-Net 骨干网络的扩散模型。U-Net 继承自 PixelCNN++[9][10],变化很少。与标准 U-Net[11]相比,额外的空间 Self-Attention 块 (Transformer 中必不可少的组件) 以较低分辨率穿插。[12]这个工作探索了 U-Net 的几种架构选择,例如自适应归一化层 (Adaptive Normalization Layer[13]为卷积层注入条件信息和通道计数。然而,DDPM 里面 U-Net 的高级设计在很大程度上都保持不变。

本文的目的是探索 Diffusion Models 架构选择的重要性,并为未来生成式模型的研究提供基线。本文的结论表明 U-Net 架构设计对 Diffusion Models 的性能并不重要,并且它们可以很容易地替换为 Transformers。

本文证明了 Diffusion Models 也可以受益于 Transformer 架构,受益于其训练方案,受益于其可扩展性,受益于其鲁棒性和效率等等。标准化架构还将为跨域研究开辟了新的可能性。

1.2 Diffusion Models 简介

DDPM

高斯扩散模型假设有一个前向的加噪过程 (Forward Noising Process),在这个过程中逐渐将噪声应用于真实数据:

这个优化的目标函数比较复杂,最后通过 variational lower bound 方法得到的结论是优化下式 (此处详细推导可以参考开创性工作 DDPM[8]):

1.3 DiT 架构介绍

1.3.1 Patchify 过程

图1:图片的 Patchify 操作。当 Patch 的大小 p 越小时,token 的数量 T 越大

1.3.2 DiT Block 设计

在 Patchify 之后,输入的 tokens 开始进入一系列 Transformer Block 中。除了噪声图像输入之外,Diffusion Model 有时会处理额外的条件信息,比如噪声时间步长 ttt , 类标签 ccc , 自然语言。

作者探索了4种不同类型的 Transformer Block,以不同的方式处理条件输入。这些设计都对标准 ViT Block 进行了微小的修改,所有 Block 的设计如下图2所示。

图2:Diffusion Transformer (DiT) 架构

  • In-Context Conditioning

作者将以上几种方法 In-Context Conditioning,Cross-Attention Block,Adaptive Layer Norm (adaLN) Block,adaLN-Zero Block 的做法列入 DiT 的设计空间中。

1.3.3 模型尺寸

图3:DiT 模型的详细配置

作者将以上几种配置列入了 DiT 的设计空间中。

1.3.4 Transformer Decoder

在最后一个 DiT Block 之后,需要将 image tokens 的序列解码为输出噪声以及对角的协方差矩阵的预测结果。

最终,完整 DiT 的设计空间是 Patch Size、DiT Block 的架构和模型大小。

1.4 DiT 训练策略

1.4.1 训练配方

作者在 ImageNet 数据集上训练了 class-conditional latent DiT 模型,标准的实验设置。

数据增强技术只使用 horizontal flips。

作者发现 learning rate warmup 和 regularization,对训练 DiT 模型而言不是必须的。

作者使用了 exponential moving average (EMA),参数为 0.9999 。

训练超参数基本都来自 ADM,不调学习率, decay/warm-up schedules, Adam 参数以及 weight decay.

1.4.2 扩散模型配置

作者保留了 ADM 中使用的超参数。   

1.5.1 DiT 架构设计

作者首先探索的是不同 Conditioning 策略的对比。对于一个 DiT-XL/2 模型,其计算复杂度分别是:in-context (119.4 Gflops), cross-attention (137.6 Gflops), adaptive layer norm (adaLN, 118.6 Gflops), adaLN-zero (118.6 Gflops)。实验结果如下图4所示。

adaLN-Zero 的 Block 架构设计取得了最低的 FID 结果,同时在计算量上也是最高效的。在 400K 训练迭代中,adaLN-Zero Block 架构得到的 FID 几乎是 In-Context 的一半,表明 Condition 策略会严重影响模型的质量。

初始化同样也重要:adaLN-Zero Block 架构在初始化时相当于恒等映射,其性能也大大优于 adaLN Block 架构。

因此,在后续实验中,DiT 将一直使用 adaLN-Zero Block 架构。

图4:不同 Conditioning 策略对比

1.5.2 缩放模型尺寸和 Patch Size

作者训练了12个 DiT 模型 (尺寸为 S, B, L, XL,Patch Size 为 8,4,2)。下图是不同 DiT 模型的尺寸和 FID-50K 性能。如下图5所示是不同大小 DiT 模型的 GFLOPs 以及在 400K 训练迭代中的 FID 值。可以发现,在增加模型大小或者减小 Batch Size 时可以显著改善 DiT 的性能。

图5:不同尺寸 DiT 模型的 GFLOPs 以及它们在 400K 训练迭代中的 FID

下图6上方是 Patch Size 不变,增加模型规模时 FID 的变化。当模型变深变宽时,FID 会下降。

下方是模型规模不变,减小 Patch Size 时 FID 的变化。当 Patch Size 下降时,FID 出现显著改善。

图5:缩放 DiT 模型可以改善训练各个阶段的 FID

1.5.3 GFLOPs 对性能很重要

上图5的结果表明,参数量并不能唯一确定 DiT 模型的质量。当 Patch Size 减小时,参数量仅仅是略有下降,只有 GFLOPs 明显增加。这些结果都表明了缩放模型的 GFLOPs 才是性能提升的关键。为了印证这一点,作者在下图6中绘制了不同 GFLOPs 模型在 400K 训练步骤时候的 FID-50K 结果。这些结果表明,当不同 DiT 模型的总 GFLOPs 相似时,它们的 FID 值也相似,比如 DiT-S/2 和 DiT-B/4。

作者还发现 DiT 模型的 GFLOPs 和 FID-50K 之间存在很强的负相关关系。

图6:GFLOPs 与 FID 密切相关

1.5.4 大模型更加计算高效

图7:大模型更加计算高效

1.5.5 缩放结果可视化

图8:缩放对于视觉质量的影响

1.6 DiT 实验结果

作者将 DiT 与最先进的生成模型进行了比较,结果如图9所示。DiT-XL/2 优于所有先前的扩散模型,将 LDM 实现的先前最佳 FID-50K 降低到 2.27。图5右侧显示 DiT-XL/2 (118.6 GFLOPs) 相对于 LDM-4 (103.6 GFLOPs) 等Latent Space U-Net 模型的计算效率很高,并且比 Pixel Space U-Net 模型更高效,例如 ADM (1120 GFLOPs) 或 ADM-U (742 GFLOPs)。

图9:ImageNet 256×256 图像生成结果

作者在 ImageNet 上训练了一个新的 DiT-XL/2,这次分辨率是 512×512,3M training iterations,超参数与 256×256 模型相同。这个模型 latent 的维度是 64×64×4,然后 Patch Size 为2,这样 Transformer 模型需要处理的 token 的数量就是 1024。如下图10所示是比较结果。DiT-XL/2 在此分辨率下再次优于所有先前的扩散模型,将 ADM 实现的先前最佳 FID 提高了 3.85 到 3.04。即使 token 的数量增加了,DiT-XL/2 的计算效率依然很高,比如 ADM 使用 1983 GFLOPs,ADM-U 使用 2813 GFLOPs,DiT-XL/2 仅仅使用 524.6 GFLOPs。

图10:ImageNet 512×512 图像生成结果

缩放模型大小还是采样次数?

Diffusion Model 的一个独特之处是它们可以通过在生成图像时增加采样步骤的数量来在训练期间使用额外的计算。也就是扩散模型的计算量既可以来自模型本身的缩放,也可以来自采样次数的增加。因此,作者在这里研究了通过使用更多的采样计算,较小的 DiT 模型是否可以胜过更大的模型。

作者计算了所有的 12 个 DiT 模型在 400K training iteration 时候的 FID 值,每张图分别使用 [16, 32, 64, 128, 256, 1000] sampling steps。

实验结果如下图11所示,考虑使用 1000 个采样步骤的 DiT-L/2 和使用 128 步的 DiT-XL/2。在这种情况下:

  • DiT-L/2 使用 80.7 TFLOPs 对每张图像进行采样。
  • DiT-XL/2 使用 15.2 TFLOPs 对每张图像进行采样。

但尽管如此,DiT-XL/2 具有更好的 FID-10K 结果。说明增加采样的计算量也无法弥补模型本身计算量的缺失。

图11:增加采样的计算量也无法弥补模型本身计算量的缺失

#eventful-transformers

如何降低视觉Transformer计算成本?时间冗余方法让人大吃一惊

在为语言领域带来变革之后,Transformer 正在进军视觉领域,但其也有着高计算成本的问题。近日,威斯康星大学麦迪逊分校一个研究团队提出了 Eventful Transformer,可通过在视觉 Transformer 中利用时间冗余来节省成本。

Transformer 一开始是为自然语言处理任务设计的,但现在却已经被广泛用于视觉任务。视觉 Transformer 在一系列视觉识别任务上实现了出色的准确度,并在图像分类、视频分类和目标检测等任务上取得了当前最优的表现。

视觉 Transformer 的一大缺点是计算成本高。典型的卷积网络(CNN)处理每张图像需要数十 GFlops,而视觉 Transformer 所需的往往会多上一个数量级,达到每张图像数百 GFlops。在处理视频时,由于数据量巨大,这个问题更为严重。高昂的计算成本让视觉 Transformer 难以被部署到资源有限或有严格延迟需求的设备上,这就限制了这项技术的应用场景,否则我们已经有一些激动人心的应用了。

在近期一篇论文中,威斯康星大学麦迪逊分校的三位研究者 Matthew Dutson、Yin Li 和 Mohit Gupta 首先提出可以在后续输入之间使用时间冗余来降低视觉 Transformer 在视频应用中的成本。他们也发布了模型代码,其中包含用于构建 Eventful Transformer 的 PyTorch 模块。

  • 论文地址:https://arxiv.org/pdf/2308.13494.pdf
  • 项目地址:http://wisionlab.com/project/eventful-transformers

时间冗余:首先假设有一个视觉 Transformer,其可以逐帧或逐视频片段地处理视频序列。这个 Transformer 可能是简单的逐帧处理的模型(如目标检测器)或是某个时空模型的中间步骤(如 ViViT 的分解式模型的第一步)。不同于一个输入就是一个完整序列的语言处理 Transformer,在这里,研究者的做法是随时间为 Transformer 提供多个不同的输入(帧或视频片段)。

自然视频包含显著的时间冗余,即后续帧之间的差异很小。尽管如此,包括 Transformer 在内的深度网络通常都会「从头开始」计算每一帧。该方法会丢弃之前推理获得的潜在相关信息,浪费极大。故而这三位研究者设想:是否可以复用之前计算步骤的中间计算结果来提升处理冗余序列的效率?

自适应推理:对于视觉 Transformer 以及一般意义上的深度网络而言,推理成本通常由架构决定。然而在现实应用中,可用的资源可能会随时间而变化,比如可能因为存在相竞争的进程或电源发生变化。如此一来,可能就存在运行时修改模型计算成本的需求。在这项新成果中,研究者设定的一大主要设计目标便是适应性 —— 其方法可实现对计算成本的实时控制。下图 1(底部)给出了在视频处理过程中修改计算预算的示例。

Eventful Transformer:本文提出了 Eventful Transformer,这类 Transformer 能利用输入之间的时间冗余来实现高效且自适应的推理。Eventful 这一术语的灵感来自事件相机(event camera),这种传感器能在场景变化时离散地记录影像。Eventful Transformer 会跟踪随时间发生的 token 层面的变化情况,并在每个时间步骤有选择性地更新 token 表征和自注意力映射图。Eventful Transformer 的模块中包含一种门控模块,用于控制运行时间被更新 token 的数量。

该方法可用于现成的模型(通常无需再训练)并且兼容许多视频处理任务。研究者也进行了实验论证,结果表明 Eventful Transformer 可用于现有的当前最佳模型,在极大降低它们的计算成本的同时还能维持其原有的准确度。

Eventful Transformer

这项研究的目标加速用于视频识别的视觉 Transformer。在这个场景中,视觉 Transformer 需要反复处理视频帧或视频片段,具体的任务包括视频目标检测和视频动作识别等。这里提出的关键思想是利用时间冗余,即复用之前时间步骤的计算结果。下面将详细描述如何通过修改 Transformer 模块来使其具备感知时间冗余的能力。

token 门控:检测冗余

这一小节将介绍研究者提出的两种新模块:token 门和 token 缓冲器。这些模块让模型可以识别和更新自上次更新后有明显变化的 token。

门模块:该门会从输入 token N 中选择一部分 M 发送给下游层执行计算。其记忆中维护着一个参照 token 集,记为 u。这种参照向量包含每个 token 在其最近一次更新时的值。在每个时间步骤,比较各个 token 与其对应的参照值,其中与参照值相差较大的 token 获得更新。

现在将该门的当前输入记为 c。在每个时间步骤,按照以下流程更新门的状态并决定其输出(见下图 2):

构建可感知冗余的 Transformer

为了利用上述时间冗余,研究者提出了一种对 Transformer 模块的修改方案。下图 4 展示了 Eventful Transformer 模块的设计。该方法可以加速针对各个 token 的运算(如 MLP)以及查询 - 键值和注意力 - 值乘法。

在针对各个 token 的运算 Transformer 模块中,很多运算都是针对各个 token 的,也就是说它们不涉及到 token 之间的信息交换,其中包括 MLP 和 MSA 中的线性变换。为了节省计算成本,研究者表示可以跳过未被门选取的 token 的面向 token 的运算。由于 token 之间的独立性,这不会改变对所选 token 的运算结果。参见图 3。

具体来说,针对各个 token 的运算(包括 W_qkv 变换、W_p 变换和 MLP)的连续序列,研究者使用了一对门 - 缓冲器。注意,他们还在 skip 连接之前添加了缓冲器以确保两个加法操作数的 token 是正确对齐的。

针对各个 token 的运算的成本正比于 token 的数量。门可将这个数量从 N 降至 M,也就将下游的针对各个 token 的运算的计算成本降低了 N/M 倍。

查询 - 键值的积:现在来看看查询 - 键值积 B = q k^T。

下图 5 展示了稀疏地更新查询 - 键值积 B 中一部分元素的方法。

这些更新的总体成本为 2NMD,相较而言,从头开始计算 B 的成本为 N^2D。注意,新方法的成本正比于 M,即门选取的 token 的数量。当 M < N/2 时(此时更新的 token 不到总量一半),可节省计算量。

注意力 - 值的积:研究者为此提出了一种基于增量 ∆ 的更新策略。

下图 6 展示了新提出的高效计算三个增量项的方法。

同样当 M < N/2 时,可节省计算量。

token 选取策略

Eventful Transformer 的一大重要设计是其 token 选取策略。给定一个门误差张量 e,这样一个策略的目标是生成一个掩码 m,其中指示了应当被更新的 token。具体的策略包括:

Top-r 策略:该策略选取 r 个误差 e 有最大范数的 token(这里使用的是 L2 范数)。

阈值策略:该策略选取误差 e 的范数超过一个阈值 h 的所有 token。

其它策略:更复杂精细的 token 选取策略可实现更好的准确度 - 成本权衡,比如可以使用一个轻量级策略网络来学习一个策略。但是,训练策略的决策机制的难度可能很大,因为二元掩码 m 一般是不可微分的。另一个思路是使用重要度分数作为选取的参考信息。但这些想法都还有待未来研究。

实验

研究者用实验评估了新提出的方法,具体使用的任务是视频目标检测和视频动作识别。

下图 7 展示了视频目标检测的实验结果。其中正轴是计算节省率,负轴是新方法的 mAP50 分数的相对减少量。可以看到,新方法用少量的准确度牺牲换来了显著的计算量节省。

 下图 8 给出了在视频目标检测任务上的方法比较和消融实验结果。

下图 9 给出了视频动作识别的实验结果。

下表 2 给出了在一台 CPU(Xeon Silver 4214, 2.2 GHz)和一台 GPU(NVIDIA RTX3090)上运行时间(毫秒)结果,可以看到时间冗余在 GPU 上带来的速度提升可达 1.74 倍,在 CPU 上带来的提升可达 2.47 倍。

#Llama~transformers搭建

本例从零开始基于transformers库逐模块搭建和解读Llama模型源码(中文可以翻译成羊驼)。

并且训练它来实现一个有趣的实例:两数之和。

输入输出类似如下:

输入:"12345+54321="

输出:"66666"

我们把这个任务当做一个文本生成任务来进行。输入是一个序列的上半部分,输出其下半部分.

这和文本生成的输入输出结构是类似的,所以可以用Llama来做。

目前大部分开源LLM模型都是基于transformers库来做的,它们的结构大部分都和Llama大同小异。

俗话说,魔鬼隐藏在细节中,深入理解Llama模型的的源码细节,将会帮助你打通和开源LLM模型相关的基础原理(如旋转位置编码以及长度外推),并让你熟悉各种参数的配置和使用(如past_key_value,attention_mask的使用等等)。

一,准备数据

import math  
from typing import List, Optional, Tuple, Union  
  
import torch  
import torch.nn.functional as F  
import torch.utils.checkpoint  
from torch import nn  
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss  
  
from transformers.activations import ACT2FN  
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast  
from transformers.modeling_utils import PreTrainedModel  
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings  
  
from transformers.models.llama.configuration_llama  import LlamaConfig  
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING,LLAMA_START_DOCSTRING  
  
logger = logging.get_logger('llama')  
  
config = LlamaConfig(  
    vocab_size=len(vocab),  
    hidden_size=512,  
    intermediate_size=2752,  
    num_hidden_layers=8,  
    num_attention_heads=16,  
    hidden_act='silu',  
    max_position_embeddings=128,  
    initializer_range=0.02,  
    rms_norm_eps=1e-06,  
    use_cache=True,  
    pad_token_id=0,  
    bos_token_id=1,  
    eos_token_id=2,  
    tie_word_embeddings=False  
)
<BOS>3914835626735057733+318829464988=3914835945564522721<EOS>
# 定义数据集  
class TwoSumDataset(torch.utils.data.Dataset):  
    def __init__(self,size = 100000, min_length=10,max_length=20):  
        super(Dataset, self).__init__()  
        self.size = size  
        self.min_length=min_length  
        self.max_length=max_length  
  
    def __len__(self):  
        return self.size  
  
    def __getitem__(self, i):  
        x,y = self.get(i)  
          
        # 编码成token  
        context_ids = [vocab[i] for i in x]  
        target_ids = [vocab[i] for i in y]  
          
        input_ids = context_ids + target_ids  
          
        #-100标志位后面会在计算loss时会被忽略不贡献损失,我们集中优化target部分生成的loss  
        labels = [-100]*len(context_ids)+ target_ids  
        masks = [0 if t==vocab['<PAD>'] else 1 for t in input_ids]  
          
        example = {'input_ids':input_ids,  
                  'labels':labels,'attention_mask':masks}  
          
        return example  
      
    def get(self,i):  
        return get_data(self.min_length,self.max_length)  
      
      
    def show_example(self,example):  
        input_ids,labels = example['input_ids'],example['labels']  
        x = ''.join([vocab_r[a] for a,b in zip(input_ids,labels) if b==-100])  
        y = ''.join([vocab_r[a] for a,b in zip(input_ids,labels) if b!=-100])  
        print(x+y)  
          
          
      
ds_train = TwoSumDataset(size = 100000,min_length=10,max_length=20)  
ds_val = TwoSumDataset(size = 10000,min_length=10,max_length=20)  
example = ds_train[0]  
ds_train.show_example(example)
<BOS>12878683929048906366+11274414130675477=12889958343179581843<EOS>
def data_collator(examples: list):  
    len_ids = [len(example["input_ids"]) for example in examples]  
    longest = max(len_ids) #之后按照batch中最长的input_ids进行padding  
      
    input_ids = []  
    labels_list = []  
    masks_list = []  
      
    for length, example in sorted(zip(len_ids, examples), key=lambda x: -x[0]):  
        ids = example["input_ids"]  
        labs = example["labels"]  
        masks = example['attention_mask']  
          
        ids = [vocab['<PAD>']] * (longest - length)+ids   
        labs = [-100] * (longest - length)+labs  
        masks = [0]*(longest - length)+masks  
          
        input_ids.append(torch.LongTensor(ids))  
        labels_list.append(torch.LongTensor(labs))  
        masks_list.append(torch.LongTensor(masks))  
            
    input_ids = torch.stack(input_ids)  
    labels = torch.stack(labels_list)  
    attention_mask = torch.stack(masks_list)  
    return {  
        "input_ids": input_ids,  
        "labels": labels,  
        "attention_mask":attention_mask  
    }  
  
# 数据加载器  
dl_train = DataLoader(dataset=ds_train,  
         batch_size=200,  
         drop_last=True,  
         shuffle=True,  
         collate_fn = data_collator          
        )  
  
dl_val = DataLoader(dataset=ds_val,  
         batch_size=200,  
         drop_last=True,  
         shuffle=False,  
         collate_fn = data_collator    
        )  
  
  
for batch in dl_train:  
    break
batch   

{'input_ids': tensor([[ 1, 11,  6,  ...,  7, 11,  2],         [ 0,  1,  6,  ...,  5,  4,  2],         [ 0,  1,  7,  ...,  8,  8,  2],         ...,         [ 0,  0,  0,  ..., 10, 11,  2],         [ 0,  0,  0,  ..., 12,  3,  2],         [ 0,  0,  0,  ..., 11, 12,  2]]), 'labels': tensor([[-100, -100, -100,  ...,    7,   11,    2],         [-100, -100, -100,  ...,    5,    4,    2],         [-100, -100, -100,  ...,    8,    8,    2],         ...,         [-100, -100, -100,  ...,   10,   11,    2],         [-100, -100, -100,  ...,   12,    3,    2],         [-100, -100, -100,  ...,   11,   12,    2]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],         [0, 1, 1,  ..., 1, 1, 1],         [0, 1, 1,  ..., 1, 1, 1],         ...,         [0, 0, 0,  ..., 1, 1, 1],         [0, 0, 0,  ..., 1, 1, 1],         [0, 0, 0,  ..., 1, 1, 1]])}

二,定义模型

下面,我们会像搭积木建城堡那样从低往高地构建LLaMA模型。先构建4个基础组件:旋转位置编码,多头注意力、前馈网络、层归一化。类似用最基础的积木块搭建了 墙壁,房顶,房门,窗户 这样的模块。然后用这4个基础组件构建中间成品: 解码层。类似用基础组件构建了房间。接着用多个中间成品解码层的堆叠组装成了LlamaModel完整模型,相当于通过构建多个房间建成了城堡的主体结构。最后我们在LlamaModel基础上设计了两种不同的输出head,一种是语言模型Head,得到了LlamaForCausalLM,可用于文本生成。另外一种是分类head,得到了LlamaForSequenceClassification,可用于文本分类。相当于我们在城堡主体结构完成的基础上设计了两种不同的装修风格,一种是加装了一些游乐设施以便用于商业活动,另一种则是加装了一些武器以便用于军事活动。

1, 旋转位置编码: RoPE (使用旋转矩阵实现的绝对位置编码,可以起到相对位置编码的效果)

2, 多头注意力: LlamaAttention (用于融合不同token之间的信息)

3, 前馈网络: LlamaMLP (用于逐位置将多头注意力融合后的信息进行高维映射变换)

4, 层归一化: LlamaRMSNorm (用于稳定输入,相当于保持每个词向量的方向不变,但对模长标准化。)

5, Llama解码层: LlamaDecoderLayer (同时具备信息融合,信息转换功能的基本结构单元)

6, Llama解码器: LlamaModel (多个解码层的堆叠)7,Llama语言模型: LlamaForCausalLM (解码器加上语言模型head,可用于文本生成)8,Llama分类模型: LlamaForSequenceClassification (解码器加上分类head,可用于文本分类)

import math  
from typing import List, Optional, Tuple, Union  
  
import torch  
import torch.nn.functional as F  
import torch.utils.checkpoint  
from torch import nn  
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss  
  
from transformers.activations import ACT2FN  
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast  
from transformers.modeling_utils import PreTrainedModel  
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings  
  
from transformers.models.llama.configuration_llama  import LlamaConfig  
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING,LLAMA_START_DOCSTRING  
  
logger = logging.get_logger('llama')  
  
config = LlamaConfig(  
    vocab_size=len(vocab),  
    hidden_size=512,  
    intermediate_size=2752,  
    num_hidden_layers=8,  
    num_attention_heads=16,  
    hidden_act='silu',  
    max_position_embeddings=128,  
    initializer_range=0.02,  
    rms_norm_eps=1e-06,  
    use_cache=True,  
    pad_token_id=0,  
    bos_token_id=1,  
    eos_token_id=2,  
    tie_word_embeddings=False  
)

1,旋转位置编码 RoPE

旋转位置编码即使用旋转矩阵表示位置编码(Rotary Position Encoding),简称RoPE。

关于RoPE的3个核心要点知识如下:

RoPE的设计思想是使用绝对位置编码来达到相对位置编码的效果。

RoPE的实现方式是使用旋转矩阵来表示绝对位置编码。

使用NTK扩展方法可以让RoPE在短文本上训练并在长文本上做预测。

参考文章:

《博采众长的旋转式位置编码》https://kexue.fm/archives/8265

《RoPE是一种进制编码》https://kexue.fm/archives/9675

(1)绝对位置编码和相对位置编码

位置编码一般可以分成绝对位置编码和相对位置编码。

绝对位置编码的优点是计算简单高效,缺点是一般效果不如相对位置编码。

相对位置编码的优点是效果较好,缺点是计算效率不如绝对位置编码。

在相对位置编码中,注意力权重的结果仅仅和参与注意力计算的token向量的相对位置有关,不和绝对位置直接关联。

这符合NLP领域在序列长度方向上具有平移不变性的特点,所以相对位置编码一般效果会优于绝对位置编码。

不过绝对位置编码并非一无是处,绝对位置编码只需要初始化时对序列的每个位置(数量正比于序列长度)赋予位置编码即可,后续无需干预。

而相对位置编码要在计算过程中获取许多个(数量正比于序列长度平方)相对位置。

因此绝对位置编码更加简单高效。

(2)使用旋转矩阵表示位置编码

上述讨论可以看到,绝对位置编码和相对位置编码互有优劣,那么有没有什么办法能够对二者进行取长补短呢?

有的,这个方法就是RoPE,它的设计思想就是使用绝对位置编码来达到相对位置编码的效果。

那么旋转位置编码如何使用绝对位置编码来达到相对位置编码的效果的呢?答案是使用旋转矩阵来表示位置编码。

由于旋转矩阵是稀疏矩阵,直接使用乘法计算会很浪费算力,可以将旋转位置编码过程由矩阵乘法运算简化成两次向量的哈达玛积求和。 

(3)旋转位置编码的长度扩展

在LLM的应用中,有一个非常重要的参数,叫做LLM支持的上下文长度(max context length)。

更长的上下文长度允许我们进行更多轮次的对话,允许我们对更长的本文进行总结分析,也允许我们生成更长的文章。

但是在训练LLM的时候,我们的训练语料大部分是不够长的,许多LLM训练时候设计的最大文本长度都是只有2k,也就是最长2048个token。

那么,能否在训练的时候使用较短的文本,而在推理的时候扩展到长文本上呢?

是有可能的,我们可以对RoPE进行长度扩展。

我们介绍3种扩展方案。

第一种是直接外推:直接外推其实就是继续沿用现有的位置编码公式,不做任何修改。

在扩展长度不太长的时候,例如由2k扩展到2.5k时,这种方法可能对性能的影响并不大。

因为旋转位置编码只和相对位置m-n的大小有关,一般具有远程衰减性,即相对距离越大的两个token,其相关性一般越弱。

因此如果我们的模型已经从训练数据那里学习到了token之间的相关性相对于相对距离在0-2k的一个合适的衰减规律的时候,可以设想把这个规律应用到0-2.5k也是没有太大的问题的。

但是如果我们要扩展到更长的长度,例如从2k扩展到32k,这种直接外推的方案通常会严重地影响性能。因为我们学习到的衰减规律有可能在5k的那里就完全衰减截断基本降为0了,这样我们就无法捕捉相对距离长于5k的两个token之间的相互作用,外推就会导致性能下降。

总结一下,直接外推对衰减规律在长距离情况下的使用容易出现问题,导致性能下降。

为了减少长度外推对性能的影响,我们可以让训练好的模型在更长的上下文上做少许步骤的微调。

第二种是线性内插:线性内插需要改变位置编码公式,等效于将位置序号等比例缩小。

线性内插没有改变模型学习到的衰减规律的应用范围,不考虑微调的话,其效果一般好于直接外推方案。

但是,扩展倍数非常大的时候,例如从2k扩展到32k,其性能也会明显的受到影响。

因为在这种情况下,衰减规律在短距离情况下的使用会受到较严重的影响,本来距离为1的两个token,长度扩展后相当于变成了距离为1/16,衰减规律在短距离时可能具有非常大的变化率,因此对相关性的评估可能会极端地偏离合理值。

应用线性内插时,在长文本上做少许步骤的微调也能够明显地改善性能。

第三种是NTK扩展方式:这种方式综合了外推和内插的优点,做长度扩展后即使不微调也能够保持较好的性能。

前面的分析我们知道直接外推对衰减规律在长距离情况下的使用容易出问题,在短距离情况下的使用不受影响。

而线性内插对衰减规律在短距离情况下的使用容易出现问题,在长距离的情况下影响较小。

我们能否将它们综合起来,在短距离情况下具有外推特性(与扩展前基本一致),在长距离情况下具有内插特性(缩放到扩展前的范围),从而使得长距离情况下和短距离情况下衰减规律的使用都不太受到影响呢。

NTK扩展方式的要点是高频外推,低频内插,实现方法是直接对底数base进行缩放,类似进制编码转换。

采用NTK扩展到长文本,即使不做微调,性能会只会略有下降。

下面是RoPE以及三种长度扩展方式的实现。

class LlamaRotaryEmbedding(torch.nn.Module):  
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):  
        super().__init__()  
        self.dim = dim  
        self.max_position_embeddings = max_position_embeddings  
        self.base = base  
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))  
        self.register_buffer("inv_freq", inv_freq, persistent=False) #persistent=False将不会作为state_dict  
  
        # Build here to make `torch.jit.trace` work.  
        self._set_cos_sin_cache(  
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()  
        )  
  
    def _set_cos_sin_cache(self, seq_len, device, dtype):  
        self.max_seq_len_cached = seq_len  
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)  
  
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
        # Different from paper, but it uses a different permutation in order to obtain the same calculation  
        emb = torch.cat((freqs, freqs), dim=-1)  
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)  
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)  
  
    def forward(self, x, seq_len=None):  
        # x: [bs, num_attention_heads, seq_len, head_size]  
        #超过预设的max_position_embeddings则重新计算更大的Rope缓存,否则直接在缓存上切片  
        if seq_len > self.max_seq_len_cached:   
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)  
  
        return (  
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),  
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),  
        )  
  
      
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):  
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""  
  
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):  
        self.scaling_factor = scaling_factor  
        super().__init__(dim, max_position_embeddings, base, device)  
  
    def _set_cos_sin_cache(self, seq_len, device, dtype):  
        self.max_seq_len_cached = seq_len  
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)  
        t = t / self.scaling_factor #线性内插相当于将位置序号等比例缩小  
  
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
        # Different from paper, but it uses a different permutation in order to obtain the same calculation  
        emb = torch.cat((freqs, freqs), dim=-1)  
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)  
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)  
  
  
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):  
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""  
  
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):  
        self.scaling_factor = scaling_factor  
        super().__init__(dim, max_position_embeddings, base, device)  
  
    def _set_cos_sin_cache(self, seq_len, device, dtype):  
        self.max_seq_len_cached = seq_len  
  
        if seq_len > self.max_position_embeddings:  
            base = self.base * (  
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)  
            ) ** (self.dim / (self.dim - 2))  #NTK扩展方式直接对base进行缩放  
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))  
            self.register_buffer("inv_freq", inv_freq, persistent=False)  
  
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)  
  
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
          
        #此处处理逻辑与原始的ROPE有差异,原始逻辑如下  
        #emb = torch.cat((freqs, freqs), dim=-1)  
        #emb[...,0::2]=freqs  
        #emb[...,1::2]=freqs  
          
          
        # Different from paper, but it uses a different permutation in order to obtain the same calculation  
        emb = torch.cat((freqs, freqs), dim=-1)  
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)  
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)  
          
          
def rotate_half(x):  
    """Rotates half the hidden dims of the input."""  
      
    #此处逻辑与原始的ROPE有所差异,原始逻辑如下  
    #x1 = x[..., 0::2]   
    #x2 = x[..., 1::2]  
    #res = torch.cat((x1, x2), dim=-1)  
    #res[...,0::2]=-x2  
    #res[...,1::2]=x1  
    #return res  
      
    x1 = x[..., : x.shape[-1] // 2]   
    x2 = x[..., x.shape[-1] // 2 :]  
    return torch.cat((-x2, x1), dim=-1)  
  
  
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):  
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.  
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]  
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]  
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]  
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]  
    q_embed = (q * cos) + (rotate_half(q) * sin)  
    k_embed = (k * cos) + (rotate_half(k) * sin)  
    return q_embed, k_embed  
  
x = torch.randn(1,8,4,2)  
rope = LlamaRotaryEmbedding(dim=8)  
cos,sin = rope.forward(x,seq_len=4)  
print(cos.shape)   
print(cos)  
torch.Size([1, 1, 4, 8])tensor([[[[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,            1.0000],          [ 0.5403,  0.9950,  0.9999,  1.0000,  0.5403,  0.9950,  0.9999,            1.0000],          [-0.4161,  0.9801,  0.9998,  1.0000, -0.4161,  0.9801,  0.9998,            1.0000],          [-0.9900,  0.9553,  0.9996,  1.0000, -0.9900,  0.9553,  0.9996,            1.0000]]]])

2,多头注意力 LlamaAttention

这里的LlamaAttention 基本上和《Attention Is All You Need》论文里的是一致的,主要差异有以下一些。

1,k和v的head数量可以是q的head数量的几分之一,类似分组卷积的思想,可以减少参数规模。

2,rope位置编码是每次做多头注意力时都进行一次,而不是原论文只在输入的时候进行一次。

3,允许传入key和value的states的缓存past_key_value,这在多轮对话中可以减少重复计算,起到加速效果。

4,attention_mask是通过加法形式作用到softmax之前的attention矩阵上的。

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:  
    """  
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,  
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)  
    """  
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape  
    if n_rep == 1:  
        return hidden_states  
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)  
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)  
  
  
class LlamaAttention(nn.Module):  
    """Multi-headed attention from 'Attention Is All You Need' paper"""  
  
    def __init__(self, config: LlamaConfig):  
        super().__init__()  
        self.config = config  
        self.hidden_size = config.hidden_size  
        self.num_heads = config.num_attention_heads  
        self.head_dim = self.hidden_size // self.num_heads  
        self.num_key_value_heads = config.num_key_value_heads  
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads  
        self.max_position_embeddings = config.max_position_embeddings  
  
        if (self.head_dim * self.num_heads) != self.hidden_size:  
            raise ValueError(  
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"  
                f" and `num_heads`: {self.num_heads})."  
            )  
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)  
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)  
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)  
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)  
        self._init_rope()  
  
    def _init_rope(self):  
        if self.config.rope_scaling is None:  
            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)  
        else:  
            scaling_type = self.config.rope_scaling["type"]  
            scaling_factor = self.config.rope_scaling["factor"]  
            if scaling_type == "linear":  
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(  
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor  
                )  
            elif scaling_type == "dynamic":  
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(  
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor  
                )  
            else:  
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")  
  
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):  
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()  
  
    def forward(  
        self,  
        hidden_states: torch.Tensor,  
        attention_mask: Optional[torch.Tensor] = None,  
        position_ids: Optional[torch.LongTensor] = None,  
        past_key_value: Optional[Tuple[torch.Tensor]] = None,  
        output_attentions: bool = False,  
        use_cache: bool = False,  
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:  
        bsz, q_len, _ = hidden_states.size()  
  
        if self.config.pretraining_tp > 1:  
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp  
            query_slices = self.q_proj.weight.split(  
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0  
            )  
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)  
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)  
  
            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]  
            query_states = torch.cat(query_states, dim=-1)  
  
            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]  
            key_states = torch.cat(key_states, dim=-1)  
  
            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]  
            value_states = torch.cat(value_states, dim=-1)  
  
        else:  
            query_states = self.q_proj(hidden_states)  
            key_states = self.k_proj(hidden_states)  
            value_states = self.v_proj(hidden_states)  
  
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)  
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)  
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)  
  
        kv_seq_len = key_states.shape[-2]  
        if past_key_value is not None:  
            kv_seq_len += past_key_value[0].shape[-2]  
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)  
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)  
  
        if past_key_value is not None:  
            # reuse k, v, self_attention  
            key_states = torch.cat([past_key_value[0], key_states], dim=2)  
            value_states = torch.cat([past_key_value[1], value_states], dim=2)  
  
        past_key_value = (key_states, value_states) if use_cache else None  
  
        # repeat k/v heads if n_kv_heads < n_heads  
        key_states = repeat_kv(key_states, self.num_key_value_groups)  
        value_states = repeat_kv(value_states, self.num_key_value_groups)  
  
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)  
  
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):  
            raise ValueError(  
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"  
                f" {attn_weights.size()}"  
            )  
  
        if attention_mask is not None:  
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):  
                raise ValueError(  
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"  
                )  
            attn_weights = attn_weights + attention_mask  
  
        # upcast attention to fp32  
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)  
        attn_output = torch.matmul(attn_weights, value_states)  
  
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):  
            raise ValueError(  
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"  
                f" {attn_output.size()}"  
            )  
  
        attn_output = attn_output.transpose(1, 2).contiguous()  
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)  
  
        if self.config.pretraining_tp > 1:  
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)  
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)  
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])  
        else:  
            attn_output = self.o_proj(attn_output)  
  
        if not output_attentions:  
            attn_weights = None  
  
        return attn_output, attn_weights, past_key_value

3,前馈网络 LlamaMLP

前馈网络是一个2层的感知机MLP。

先从hidden_size维度up_proj到intermediate_size维度,然后再down_proj还原为hidden_size维度。

这里的主要特色是引入了一个gate_proj配合激活函数来实现一个门控注意力的作用。

class LlamaMLP(nn.Module):  
    def __init__(self, config):  
        super().__init__()  
        self.config = config  
        self.hidden_size = config.hidden_size  
        self.intermediate_size = config.intermediate_size  
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)  
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)  
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)  
        self.act_fn = ACT2FN[config.hidden_act]  
  
    def forward(self, x):  
        if self.config.pretraining_tp > 1:  
            slice = self.intermediate_size // self.config.pretraining_tp  
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)  
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)  
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)  
  
            gate_proj = torch.cat(  
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1  
            )  
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)  
  
            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)  
            down_proj = [  
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)  
            ]  
            down_proj = sum(down_proj)  
        else:  
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))  
  
        return down_proj

4,层归一化 LlamaRMSNorm

这里的层归一化叫做RMSNorm,和标准的LayerNorm有少许差异。

首先是没有移除均值,直接除的RootMeanSquare,然后也没有加上bias。

这两个小的修正可以保证在层归一化不会改变hidden_states对应的词向量的方向,只会改变其模长。

在一定的意义上具有合理性。

class LlamaRMSNorm(nn.Module):  
    def __init__(self, hidden_size, eps=1e-6):  
        """  
        LlamaRMSNorm is equivalent to T5LayerNorm  
        """  
        super().__init__()  
        self.weight = nn.Parameter(torch.ones(hidden_size))  
        self.variance_epsilon = eps  
  
    def forward(self, hidden_states):  
        input_dtype = hidden_states.dtype  
        hidden_states = hidden_states.to(torch.float32)  
        variance = hidden_states.pow(2).mean(-1, keepdim=True)  
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)  
        return self.weight * hidden_states.to(input_dtype)

5,Llama解码层

解码层LlamaDecoderLayer由LlamaAttention,LlamaMLP,以及两个LlamaRMSNorm组成,并使用了两次残差结构。

class LlamaDecoderLayer(nn.Module):  
    def __init__(self, config: LlamaConfig):  
        super().__init__()  
        self.hidden_size = config.hidden_size  
        self.self_attn = LlamaAttention(config=config)  
        self.mlp = LlamaMLP(config)  
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)  
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)  
  
    def forward(  
        self,  
        hidden_states: torch.Tensor,  
        attention_mask: Optional[torch.Tensor] = None,  
        position_ids: Optional[torch.LongTensor] = None,  
        past_key_value: Optional[Tuple[torch.Tensor]] = None,  
        output_attentions: Optional[bool] = False,  
        use_cache: Optional[bool] = False,  
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:  
        """  
        Args:  
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`  
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size  
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.  
            output_attentions (`bool`, *optional*):  
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under  
                returned tensors for more detail.  
            use_cache (`bool`, *optional*):  
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding  
                (see `past_key_values`).  
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states  
        """  
  
        residual = hidden_states  
  
        hidden_states = self.input_layernorm(hidden_states)  
  
        # Self Attention  
        hidden_states, self_attn_weights, present_key_value = self.self_attn(  
            hidden_states=hidden_states,  
            attention_mask=attention_mask,  
            position_ids=position_ids,  
            past_key_value=past_key_value,  
            output_attentions=output_attentions,  
            use_cache=use_cache,  
        )  
        hidden_states = residual + hidden_states  
  
        # Fully Connected  
        residual = hidden_states  
        hidden_states = self.post_attention_layernorm(hidden_states)  
        hidden_states = self.mlp(hidden_states)  
        hidden_states = residual + hidden_states  
  
        outputs = (hidden_states,)  
  
        if output_attentions:  
            outputs += (self_attn_weights,)  
  
        if use_cache:  
            outputs += (present_key_value,)  
  
        return outputs

6,Llama解码器

LlamaModel由多个Llama解码层堆叠而成。

有几个理解上的要点:

1,_make_causal_mask用于构造下三角这种mask结构以实现语言模型的单向注意力。

2,_expand_mask用于将传入的等特殊符号相关的mask信息展开成和attention矩阵相同的张量结构。

3,设置gradient_checkpointing=True可以节约显存。其主要应用了torch.utils.checkpoint.checkpoint方法。它的原理非常简单,在对decoder_layer进行forward时不保存中间激活值从而节约显存,backward时重新计算相关值,从而通过时间换取了空间。

4,gradient_checkpointing和use_cache不能同时设置为True,前者是为了节约显存时间换空间的,后者是为了节约时间空间换时间。

# Copied from transformers.models.bart.modeling_bart._make_causal_mask  
def _make_causal_mask(  
    input_ids_shape: torch.Size, dtype: torch.dtype,   
    device: torch.device, past_key_values_length: int = 0  
):  
    """  
    Make causal mask used for bi-directional self-attention.  
    """  
    bsz, tgt_len = input_ids_shape  
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)  
    mask_cond = torch.arange(mask.size(-1), device=device)  
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)  
    mask = mask.to(dtype)  
  
    if past_key_values_length > 0:  
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)  
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)  
  
  
# Copied from transformers.models.bart.modeling_bart._expand_mask  
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):  
    """  
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.  
    """  
    bsz, src_len = mask.size()  
    tgt_len = tgt_len if tgt_len is not None else src_len  
  
    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)  
    inverted_mask = 1.0 - expanded_mask  
  
    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)  
  
  
@add_start_docstrings(  
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",  
    LLAMA_START_DOCSTRING,  
)  
class LlamaPreTrainedModel(PreTrainedModel):  
    config_class = LlamaConfig  
    base_model_prefix = "model"  
    supports_gradient_checkpointing = True  
    _no_split_modules = ["LlamaDecoderLayer"]  
    _skip_keys_device_placement = "past_key_values"  
  
    def _init_weights(self, module):  
        std = self.config.initializer_range  
        if isinstance(module, nn.Linear):  
            module.weight.data.normal_(mean=0.0, std=std)  
            if module.bias is not None:  
                module.bias.data.zero_()  
        elif isinstance(module, nn.Embedding):  
            module.weight.data.normal_(mean=0.0, std=std)  
            if module.padding_idx is not None:  
                module.weight.data[module.padding_idx].zero_()  
  
    def _set_gradient_checkpointing(self, module, value=False):  
        if isinstance(module, LlamaModel):  
            module.gradient_checkpointing = value  
  
  
@add_start_docstrings(  
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",  
    LLAMA_START_DOCSTRING,  
)  
class LlamaModel(LlamaPreTrainedModel):  
    """  
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]  
  
    Args:  
        config: LlamaConfig  
    """  
  
    def __init__(self, config: LlamaConfig):  
        super().__init__(config)  
        self.padding_idx = config.pad_token_id  
        self.vocab_size = config.vocab_size  
  
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)  
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])  
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)  
  
        self.gradient_checkpointing = False  
        # Initialize weights and apply final processing  
        self.post_init()  
  
    def get_input_embeddings(self):  
        return self.embed_tokens  
  
    def set_input_embeddings(self, value):  
        self.embed_tokens = value  
  
    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask  
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):  
        # create causal mask  
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]  
        combined_attention_mask = None  
        if input_shape[-1] > 1:  
            combined_attention_mask = _make_causal_mask(  
                input_shape,  
                inputs_embeds.dtype,  
                device=inputs_embeds.device,  
                past_key_values_length=past_key_values_length,  
            )  
  
        if attention_mask is not None:  
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]  
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(  
                inputs_embeds.device  
            )  
            combined_attention_mask = (  
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask  
            )  
  
        return combined_attention_mask  
  
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)  
    def forward(  
        self,  
        input_ids: torch.LongTensor = None,  
        attention_mask: Optional[torch.Tensor] = None,  
        position_ids: Optional[torch.LongTensor] = None,  
        past_key_values: Optional[List[torch.FloatTensor]] = None,  
        inputs_embeds: Optional[torch.FloatTensor] = None,  
        use_cache: Optional[bool] = None,  
        output_attentions: Optional[bool] = None,  
        output_hidden_states: Optional[bool] = None,  
        return_dict: Optional[bool] = None,  
    ) -> Union[Tuple, BaseModelOutputWithPast]:  
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions  
        output_hidden_states = (  
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  
        )  
        use_cache = use_cache if use_cache is not None else self.config.use_cache  
  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict  
  
        # retrieve input_ids and inputs_embeds  
        if input_ids is not None and inputs_embeds is not None:  
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")  
        elif input_ids is not None:  
            batch_size, seq_length = input_ids.shape  
        elif inputs_embeds is not None:  
            batch_size, seq_length, _ = inputs_embeds.shape  
        else:  
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")  
  
        seq_length_with_past = seq_length  
        past_key_values_length = 0  
  
        if past_key_values is not None:  
            past_key_values_length = past_key_values[0][0].shape[2]  
            seq_length_with_past = seq_length_with_past + past_key_values_length  
  
        if position_ids is None:  
            device = input_ids.device if input_ids is not None else inputs_embeds.device  
            position_ids = torch.arange(  
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device  
            )  
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)  
        else:  
            position_ids = position_ids.view(-1, seq_length).long()  
  
        if inputs_embeds is None:  
            inputs_embeds = self.embed_tokens(input_ids)  
        # embed positions  
        if attention_mask is None:  
            attention_mask = torch.ones(  
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device  
            )  
        attention_mask = self._prepare_decoder_attention_mask(  
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length  
        )  
  
        hidden_states = inputs_embeds  
  
        if self.gradient_checkpointing and self.training:  
            if use_cache:  
                logger.warning_once(  
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."  
                )  
                use_cache = False  
  
        # decoder layers  
        all_hidden_states = () if output_hidden_states else None  
        all_self_attns = () if output_attentions else None  
        next_decoder_cache = () if use_cache else None  
  
        for idx, decoder_layer in enumerate(self.layers):  
            if output_hidden_states:  
                all_hidden_states += (hidden_states,)  
  
            past_key_value = past_key_values[idx] if past_key_values is not None else None  
  
            if self.gradient_checkpointing and self.training:  
  
                def create_custom_forward(module):  
                    def custom_forward(*inputs):  
                        # None for past_key_value  
                        return module(*inputs, output_attentions, None)  
  
                    return custom_forward  
  
                layer_outputs = torch.utils.checkpoint.checkpoint(  
                    create_custom_forward(decoder_layer),  
                    hidden_states,  
                    attention_mask,  
                    position_ids,  
                    None,  
                )  
            else:  
                layer_outputs = decoder_layer(  
                    hidden_states,  
                    attention_mask=attention_mask,  
                    position_ids=position_ids,  
                    past_key_value=past_key_value,  
                    output_attentions=output_attentions,  
                    use_cache=use_cache,  
                )  
  
            hidden_states = layer_outputs[0]  
  
            if use_cache:  
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)  
  
            if output_attentions:  
                all_self_attns += (layer_outputs[1],)  
  
        hidden_states = self.norm(hidden_states)  
  
        # add hidden states from the last decoder layer  
        if output_hidden_states:  
            all_hidden_states += (hidden_states,)  
  
        next_cache = next_decoder_cache if use_cache else None  
        if not return_dict:  
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)  
        return BaseModelOutputWithPast(  
            last_hidden_state=hidden_states,  
            past_key_values=next_cache,  
            hidden_states=all_hidden_states,  
            attentions=all_self_attns,  
        )

7,Llama语言模型

Llama语言模型 LlamaForCausalLM是在Llama解码器LlamaModel的基础上增加了一个lm_head作为Generator。

从而实现了一个完整的语言模型。

除此之外,Llama语言模型还实现了以下重要功能。

1,loss计算功能。当forward方法中传入labels时,会自动计算语言模型的交叉熵损失。注意labels中的-100会被忽略不参与计算。

2,文本生成generate方法。这个方法继承自PreTrainedModel,可以设置model.generation_config.num_beams选择束搜索的束宽度,默认为1即贪心搜索。

_CONFIG_FOR_DOC = "LlamaConfig"  
  
class LlamaForCausalLM(LlamaPreTrainedModel):  
    _tied_weights_keys = ["lm_head.weight"]  
  
    def __init__(self, config):  
        super().__init__(config)  
        self.model = LlamaModel(config)  
        self.vocab_size = config.vocab_size  
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)  
  
        # Initialize weights and apply final processing  
        self.post_init()  
  
    def get_input_embeddings(self):  
        return self.model.embed_tokens  
  
    def set_input_embeddings(self, value):  
        self.model.embed_tokens = value  
  
    def get_output_embeddings(self):  
        return self.lm_head  
  
    def set_output_embeddings(self, new_embeddings):  
        self.lm_head = new_embeddings  
  
    def set_decoder(self, decoder):  
        self.model = decoder  
  
    def get_decoder(self):  
        return self.model  
  
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)  
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)  
    def forward(  
        self,  
        input_ids: torch.LongTensor = None,  
        attention_mask: Optional[torch.Tensor] = None,  
        position_ids: Optional[torch.LongTensor] = None,  
        past_key_values: Optional[List[torch.FloatTensor]] = None,  
        inputs_embeds: Optional[torch.FloatTensor] = None,  
        labels: Optional[torch.LongTensor] = None,  
        use_cache: Optional[bool] = None,  
        output_attentions: Optional[bool] = None,  
        output_hidden_states: Optional[bool] = None,  
        return_dict: Optional[bool] = None,  
    ) -> Union[Tuple, CausalLMOutputWithPast]:  
  
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions  
        output_hidden_states = (  
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  
        )  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict  
  
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)  
        outputs = self.model(  
            input_ids=input_ids,  
            attention_mask=attention_mask,  
            position_ids=position_ids,  
            past_key_values=past_key_values,  
            inputs_embeds=inputs_embeds,  
            use_cache=use_cache,  
            output_attentions=output_attentions,  
            output_hidden_states=output_hidden_states,  
            return_dict=return_dict,  
        )  
  
        hidden_states = outputs[0]  
        if self.config.pretraining_tp > 1:  
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)  
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]  
            logits = torch.cat(logits, dim=-1)  
        else:  
            logits = self.lm_head(hidden_states)  
        logits = logits.float()  
  
        loss = None  
        if labels is not None:  
            # Shift so that tokens < n predict n  
            shift_logits = logits[..., :-1, :].contiguous()  
            shift_labels = labels[..., 1:].contiguous()  
            # Flatten the tokens  
            loss_fct = CrossEntropyLoss()  
            shift_logits = shift_logits.view(-1, self.config.vocab_size)  
            shift_labels = shift_labels.view(-1)  
            # Enable model parallelism  
            shift_labels = shift_labels.to(shift_logits.device)  
            loss = loss_fct(shift_logits, shift_labels)  
  
        if not return_dict:  
            output = (logits,) + outputs[1:]  
            return (loss,) + output if loss is not None else output  
  
        return CausalLMOutputWithPast(  
            loss=loss,  
            logits=logits,  
            past_key_values=outputs.past_key_values,  
            hidden_states=outputs.hidden_states,  
            attentions=outputs.attentions,  
        )  
  
    def prepare_inputs_for_generation(  
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs  
    ):  
        if past_key_values:  
            input_ids = input_ids[:, -1:]  
  
        position_ids = kwargs.get("position_ids", None)  
        if attention_mask is not None and position_ids is None:  
            # create position_ids on the fly for batch generation  
            position_ids = attention_mask.long().cumsum(-1) - 1  
            position_ids.masked_fill_(attention_mask == 0, 1)  
            if past_key_values:  
                position_ids = position_ids[:, -1].unsqueeze(-1)  
  
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step  
        if inputs_embeds is not None and past_key_values is None:  
            model_inputs = {"inputs_embeds": inputs_embeds}  
        else:  
            model_inputs = {"input_ids": input_ids}  
  
        model_inputs.update(  
            {  
                "position_ids": position_ids,  
                "past_key_values": past_key_values,  
                "use_cache": kwargs.get("use_cache"),  
                "attention_mask": attention_mask,  
            }  
        )  
        return model_inputs  
  
    @staticmethod  
    def _reorder_cache(past_key_values, beam_idx):  
        reordered_past = ()  
        for layer_past in past_key_values:  
            reordered_past += (  
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),  
            )  
        return reordered_past

8,Llama分类模型

LlamaForSequenceClassification是一个序列分类模型。

这个分类模型可以用来训练RLHF流程中的Reward模型。

@add_start_docstrings(  
    """  
    The LLaMa Model transformer with a sequence classification head on top (linear layer).  
  
    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models  
    (e.g. GPT-2) do.  
  
    Since it does classification on the last token, it requires to know the position of the last token. If a  
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If  
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the  
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in  
    each row of the batch).  
    """,  
    LLAMA_START_DOCSTRING,  
)  
class LlamaForSequenceClassification(LlamaPreTrainedModel):  
    def __init__(self, config):  
        super().__init__(config)  
        self.num_labels = config.num_labels  
        self.model = LlamaModel(config)  
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)  
  
        # Initialize weights and apply final processing  
        self.post_init()  
  
    def get_input_embeddings(self):  
        return self.model.embed_tokens  
  
    def set_input_embeddings(self, value):  
        self.model.embed_tokens = value  
  
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)  
    def forward(  
        self,  
        input_ids: torch.LongTensor = None,  
        attention_mask: Optional[torch.Tensor] = None,  
        position_ids: Optional[torch.LongTensor] = None,  
        past_key_values: Optional[List[torch.FloatTensor]] = None,  
        inputs_embeds: Optional[torch.FloatTensor] = None,  
        labels: Optional[torch.LongTensor] = None,  
        use_cache: Optional[bool] = None,  
        output_attentions: Optional[bool] = None,  
        output_hidden_states: Optional[bool] = None,  
        return_dict: Optional[bool] = None,  
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:  
        r"""  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):  
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,  
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If  
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).  
        """  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict  
  
        transformer_outputs = self.model(  
            input_ids,  
            attention_mask=attention_mask,  
            position_ids=position_ids,  
            past_key_values=past_key_values,  
            inputs_embeds=inputs_embeds,  
            use_cache=use_cache,  
            output_attentions=output_attentions,  
            output_hidden_states=output_hidden_states,  
            return_dict=return_dict,  
        )  
        hidden_states = transformer_outputs[0]  
        logits = self.score(hidden_states)  
  
        if input_ids is not None:  
            batch_size = input_ids.shape[0]  
        else:  
            batch_size = inputs_embeds.shape[0]  
  
        if self.config.pad_token_id is None and batch_size != 1:  
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")  
        if self.config.pad_token_id is None:  
            sequence_lengths = -1  
        else:  
            if input_ids is not None:  
                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(  
                    logits.device  
                )  
            else:  
                sequence_lengths = -1  
  
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]  
  
        loss = None  
        if labels is not None:  
            labels = labels.to(logits.device)  
            if self.config.problem_type is None:  
                if self.num_labels == 1:  
                    self.config.problem_type = "regression"  
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):  
                    self.config.problem_type = "single_label_classification"  
                else:  
                    self.config.problem_type = "multi_label_classification"  
  
            if self.config.problem_type == "regression":  
                loss_fct = MSELoss()  
                if self.num_labels == 1:  
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())  
                else:  
                    loss = loss_fct(pooled_logits, labels)  
            elif self.config.problem_type == "single_label_classification":  
                loss_fct = CrossEntropyLoss()  
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))  
            elif self.config.problem_type == "multi_label_classification":  
                loss_fct = BCEWithLogitsLoss()  
                loss = loss_fct(pooled_logits, labels)  
        if not return_dict:  
            output = (pooled_logits,) + transformer_outputs[1:]  
            return ((loss,) + output) if loss is not None else output  
  
        return SequenceClassifierOutputWithPast(  
            loss=loss,  
            logits=pooled_logits,  
            past_key_values=transformer_outputs.past_key_values,  
            hidden_states=transformer_outputs.hidden_states,  
            attentions=transformer_outputs.attentions,  
        )

三,训练模型

下面,我们来训练一个LlamaForCausalLM 实现两数之和的任务。

config = LlamaConfig(  
    vocab_size=len(vocab),  
    hidden_size=512,  
    intermediate_size=2752,  
    num_hidden_layers=8,  
    num_attention_heads=16,  
    num_key_value_heads=4,  
    rope_scaling = None,  
    hidden_act='silu',  
    max_position_embeddings=128,  
    initializer_range=0.02,  
    rms_norm_eps=1e-06,  
    use_cache=True,  
    pad_token_id=0,  
    bos_token_id=1,  
    eos_token_id=2,  
    tie_word_embeddings=False,  
    pretraining_tp = 1,  
    max_new_tokens = 100  
)   
  
#试算一下  
model = LlamaForCausalLM(config)  
out = model.forward(**batch)  
print(out.loss)  
tensor(2.7630, grad_fn=)

from torchkeras import KerasModel   
from accelerate import Accelerator   
  
class StepRunner:  
    def __init__(self, net, loss_fn, accelerator=None, stage = "train", metrics_dict = None,   
                 optimizer = None, lr_scheduler = None  
                 ):  
        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage  
        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler  
        self.accelerator = accelerator if accelerator is not None else Accelerator()   
        if self.stage=='train':  
            self.net.train()   
        else:  
            self.net.eval()  
      
    def __call__(self, batch):  
          
        #loss  
        with self.accelerator.autocast():  
            loss = self.net(**batch).loss  
  
        #backward()  
        if self.stage=="train" and self.optimizer is not None:          
            self.accelerator.backward(loss)  
            if self.accelerator.sync_gradients:  
                self.accelerator.clip_grad_norm_(self.net.parameters(), 1.0)  
            self.optimizer.step()  
            if self.lr_scheduler is not None:  
                self.lr_scheduler.step()  
            self.optimizer.zero_grad()  
              
        all_loss = self.accelerator.gather(loss).sum()  
          
        #losses (or plain metrics that can be averaged)  
        step_losses = {self.stage+"_loss":all_loss.item()}  
          
        #metrics (stateful metrics)  
        step_metrics = {}  
          
        if self.stage=="train":  
            if self.optimizer is not None:  
                step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']  
            else:  
                step_metrics['lr'] = 0.0  
        return step_losses,step_metrics  
      
KerasModel.StepRunner = StepRunner   
keras_model = KerasModel(model,loss_fn = None,  
        optimizer=torch.optim.AdamW(model.parameters(),lr=3e-5))  
  
  
#加载 之前训练过的权重  
ckpt_path = 'llama_twosum'  
  
keras_model.fit(train_data = dl_train,  
                val_data = dl_val,  
                epochs=100,patience=5,  
                monitor='val_loss',mode='min',  
                ckpt_path = ckpt_path,  
                mixed_precision='fp16'  
               )

四,使用模型 

from transformers.generation.utils import GenerationConfig  
model.generation_config = GenerationConfig.from_dict({'num_beams':1,  
                            'max_new_tokens':100,  
                            'max_length':200})  
model.generation_config.num_beams=1  
model.generation_config.max_new_tokens = 100   
model.generation_config.max_length=200  
def get_ans(tensor) ->"str":  
    s = "".join([vocab_r[i] for i in tensor.tolist()])  
    ans = s[s.find('=')+1:s.find('<EOS>')].replace('<BOS>','').replace('<EOS>','')  
    return ans  
x,y = get_data()   
print('x: '+''.join(x).replace('<BOS>',''))  
print('y: '+''.join(y).replace('<EOS>',''))

x: 3481340050+90157504501803=

y: 90160985841853

input_ids = torch.tensor([[vocab[i] for i in x]])   
out = model.generate(inputs=input_ids)  
out

tensor([[ 1,  5,  6, 10,  3,  5,  6, 12, 12,  7, 12, 13, 11, 12,  3,  7,  9,  7, 12,  6,  7, 12,  3, 10, 12,  5, 14, 11, 12,  3,  8, 12, 11, 10,  7, 10, 6,  3, 10,  7,  5,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2, 2,  2,  2,  2,  2,  2,  2,  2,  2,  2, 12,  2,  2,  2,  2,  2,  2,  2, 2, 12,  3, 12,  3]])

get_ans(out[0])

'90160985841853'

五,评估模型

from tqdm import tqdm   
loop = tqdm(range(1,201))  
correct = 0  
for i in loop:  
    x,y = get_data()   
    input_ids = torch.tensor([[vocab[i] for i in x]])   
    out = model.generate(inputs=input_ids)  
    pred = get_ans(out[0])  
    gt = ''.join(y).replace('<EOS>','')  
    if pred==gt:  
        correct+=1  
    loop.set_postfix(acc = correct/i)  
      
print("acc=",correct/len(loop))

acc= 0.99漂亮,我们的测试准确率达到了99%!

#Replacing softmax with ReLU in Vision Transformers

对于视觉 Transformer,将其 Self-Attention 中的 Softmax 操作替换为 ReLU/序列长度 (seqlen) 之后,性能的下降问题有所缓解。本文在 ImageNet-21K 上训练了从 Small 级别到 Large 级别的视觉 Transformer,证明了 ReLU-attention 可以在缩放性上接近或者匹配 Softmax-attention 的性能。Google出品,使用ReLU取代Softmax,ViT性能不退化

本文的研究结论是:对于视觉 Transformer,将其 Self-Attention 中的 Softmax 操作替换为 ReLU/序列长度 (seqlen) 之后,性能的下降问题有所缓解。本文在 ImageNet-21K 上训练了从 Small 级别到 Large 级别的视觉 Transformer,证明了 ReLU-attention 可以在缩放性上接近或者匹配 Softmax-attention 的性能。

1 在 ViT 中使用 ReLU 取代 Softmax

论文名称: Replacing softmax with ReLU in Vision Transformers (Arxiv 2023)

论文地址:https//arxiv.org/pdf/2309.08586.pdf

1.1 ReLU-attention 的新发现

Transformer 架构[1]在现代机器学习中无处不在。Attention 是 Transformer 的核心组件,包括一个 Softmax 操作,它在 token 上产生概率分布。Softmax 操作涉及到内部的计算所有输入的指数之和,它的计算代价相当昂贵,使得 Transformer 架构的并行化具有挑战性[2]。

本文作者探索了 Softmax 操作的 Point-wise 的替代方案,该操作不一定输出概率分布。本文的核心贡献是观察到:ReLU/序列长度(seqlen) ,可以在缩放性方面接近或匹配传统的 Softmax 操作。这一结果为并行化提供了新的机会,因为 ReLU-attention 相比传统的 Softmax-attention 可以使用更少的 gather 操作在序列长度维度实现并行化。

1.2 去掉 Softmax 的相关工作

替换 Softmax 的研究:

  • ReLU 和 squared ReLU:[3][4]把 Softmax 替换成了 ReLU,[5]把 Softmax 替换成了 squared ReLU。但是这些方法不会除以序列长度,本文通过实验发现对于达到与 Softmax 相当的准确度很重要。
  • [6]仍然需要对序列长度轴进行归一化,以确保注意力权重之和为1,这依然需要 gather。

去掉激活函数的研究:

1.3 ReLU-attention 方法

在进行 Self-attention 的操作时,首先计算注意力权重:

图1:Scaled point-wise attention 实验结果

Sequence length scaling

1.4 实验结果

作者在 ImageNet-21K 上训练了 30 Epochs,在 ImageNet-1K 上训练了 300 Epochs。作者使用了 ViT-22B[10]中提出的 qk-norm 技术,因为这个技术被验证在扩大视觉模型时有益于优化稳定性,但是作者发现在本文量级的模型这一技术没那么重要。 

如下图2所示说明了 ReLU-attention 与 ImageNet-21K 训练的 Softmax-attention 的缩放趋势相匹配。x 轴表示实验所需的总 core hours。ReLU-attention 的优势是能够以比 Softmax-attention 以更少的 gather 操作对序列长度维度进行并行化。

图2:Softmax 操作替换为 ReLU/seqlen 的缩放性能与传统带有 qk-layernorm 的 Transformer 的缩放性能匹配

1.5 qk-norm 实验结果

本文主要实验使用了 qk-norm,其中 query 和 key 在计算注意力权重之前通过 LayerNorm 传递,作者发现有必要在扩大模型大小时防止不稳定性。如图3所示是 qk-layernorm 的实验结果。结果表明,qk-norm 对这些模型没有很大的影响。

图3:qk-norm 实验结果

1.6 添加 gate 的影响

[11]这个工作删除了 Softmax 之后,添加了一个门控单元,并且不按序列长度缩放。具体而言,在门控注意力单元中,通过额外的投影层产生输出,该输出在输出映射之前与注意力的结果做 Element-wise 的乘法。

如图4所示是添加 gate 的影响实验结果。作者研究了 gate 的存在是否消除了序列长度缩放的需要。总体而言,作者观察到无论有没有 gate 的存在,使用序列长度缩放都实现了最佳精度。注意到对于带有 ReLU 的 S/8 模型,添加 gate 操作将实验所需的 core hour 增加了大约 9.3%。

图4:添加 gate 的影响

#Transformer~目标检测算法汇总

都到了13了 ~~ 还是基于这个的么办法 自从VIT横空出世以来,Transformer在CV界掀起了一场革新,各个上下游任务都得到了长足的进步,然后盘点一下基于Transformer的端到端目标检测算法!

原始Tranformer检测器

DETR(ECCV2020)

开山之作!DETR!

代码链接:https://github.com/facebookresearch/detr

论文提出了一种将目标检测视为直接集预测问题的新方法。DETR简化了检测流程,有效地消除了对许多人工设计组件的需求,如NMS或anchor生成。新框架的主要组成部分,称为DEtection TRansformer或DETR,是一种基于集合的全局损失,通过二分匹配强制进行一对一预测,以及一种transformer encoder-decoder架构。给定一组固定的学习目标查询,DETR分析了目标和全局图像上下文之间的关系,以直接并行输出最后一组预测。与许多其他检测器不同,新模型概念简单,不需要专门的库。DETR在具有挑战性的COCO目标检测数据集上展示了与成熟且高度优化的Faster RCNN基线相当的准确性和运行时间。此外,DETR可以很容易地推广到以统一的方式输出全景分割。

DETR的网络结构如下图所示,从图中可以看出DETR由四个主要模块组成:backbone,编码器,解码器以及预测头。主干网络是经典的CNN,输出降采样32倍的feature。

实验结果如下所示,性能上倒是还不错,就是训练太慢了,300 epochs。

DETR还展示了COCO上的全景分割结果,可以看出实例区分能力还是比较有限,中间的Bus。  

Pix2seq(谷歌Hinton)

代码链接:https://github.com/google-research/pix2seq

一句话总结:一个简单而通用的目标检测新框架,其将目标检测转换为语言建模任务,大大简化了pipeline,性能可比肩Faster R-CNN和DETR!还可扩展到其他任务。

论文提出Pix2Seq,一个简单而通用的目标检测框架!!!与显式集成关于任务的先验知识的现有方法不同,Pix2seq将目标检测作为一个基于观察到的像素输入的语言建模任务。目标描述(例如,边界框和类标签)表示为离散token,训练神经网络来感知图像并生成所需序列。Pix2seq主要基于这样一种直觉,即如果神经网络知道目标的位置和内容,我们只需要教它如何read them out。除了使用特定于任务的数据扩充,Pix2seq对任务的假设最少,但与高度专业化和优化的检测算法相比,它在具有挑战性的COCO数据集上取得了有竞争力的结果。

网络主要包含四个组件:

  • 图像增强:正如在训练计算机视觉模型中常见的那样,论文使用图像增强来丰富一组固定的训练示例(例如,使用随机缩放和裁剪);
  • 序列构造和扩充:由于图像的目标注释通常表示为一组边界框和类标签,论文将它们转换为一系列离散token;
  • 架构:使用编码器-解码器模型,其中编码器感知像素输入,解码器生成目标序列(一次一个token);
  • 目标/损失函数:对模型进行训练,以最大化基于图像和先前token的token的对数似然性(使用softmax cross-entropy loss)。

 序列构造示意图:

 训练300 epochs,实验结果:

稀疏注意力Deformable DETR(ICLR 2021)

代码链接:https://github.com/fundamentalvision/Deformable-DETR

最近提出了DETR,以消除在物体检测中对许多手动设计部件的需要,同时证明了良好的性能。然而,由于Transformer注意力模块在处理图像特征图时的限制,它存在收敛速度慢和特征空间分辨率有限的问题。为了缓解这些问题,论文提出了Deformable DETR,其注意力模块只关注参考周围的一小组关键采样点。Deformable DETR可以实现比DETR更好的性能(特别是在小目标上),训练时间减少10倍。COCO基准的大量实验证明了算法的有效性。

 DETR存在的问题

  • 训练周期长,相比faster rcnn慢10-20倍!
  • 小目标性能差!通常用多尺度特征来解小目标,然而高分辨率的特征图大大提高DETR复杂度!

- 存在上述问题的原因

  • 初始化时,attention model对于特征图上所有像素权重几乎是统一的(即一个query与所有的k相乘的贡献图比较均匀,理想状况是q与高度相关且稀疏的k相关性更强),因此需要长时间学习更好的attention map;
  • 处理高分辨率特征存在计算量过大,存储复杂的特点;

- Motivation

  • 让encoder初始化的权重不再是统一分布,即不再与所有key计算相似度,而是与更有意义的key计算相似度可变形卷积就是一种有效关注稀疏空间定位的方式;
  • 提出deformable DETR,融合deformable conv的稀疏空间采样与transformer相关性建模能力在整体feature map像素中,模型关注小序列的采样位置作为预滤波,作为key。

实验结果

End-to-End Object Detection with Adaptive Clustering Transformer(北大&港中文)代码链接:https://github.com/gaopengcuhk/SMCA-DETR/DETR 

本文的主要贡献如下:

  • 开发了一种称为自适应聚类Transformer(ACT)的新方法,该方法可以降低DETR的推理成本。ACT可以降低原始Transformer的二次复杂度,同时ACT与原始Transformer完全兼容;
  • 将DETR的FLOPS从73.4 Gflops减少到58.2 Gflops(不包括骨干Resnet FLOPS),而无需任何训练过程,而AP的损失仅为0.7%;
  • 通过多任务知识蒸馏(MTKD)进一步将AP的损失降低到0.2%,该技术实现了ACT和原始Transformer之间的无缝切换。

实验结果如下:

PnP-DETR(ICCV 2021)

论文链接:GitHub - twangnh/pnp-detr: Implementation of ICCV21 paper: PnP-DETR: Towards Efficient Visual Analysis with Transformers

DETR虽然有效,但由于在某些区域(如背景)上的冗余计算,转换完整的特征图可能代价高昂。在这项工作中,论文将减少空间冗余的思想封装到一个新的poll and pool(PnP)采样模块中,利用该模块构建了一个端到端PnP DETR架构,该架构自适应地在空间上分配其计算,以提高效率。具体地说,PnP模块将图像特征映射抽象为精细的前景目标特征向量和少量粗略的背景上下文特征向量。Transformer对精细-粗糙特征空间内的信息交互进行建模,并将特征转换为检测结果。此外,通过改变采样特征长度,PnP增强模型可以立即在单个模型的性能和计算之间实现各种期望的权衡,而不需要像现有方法那样训练多个模型。因此,它为具有不同计算约束的不同场景中的部署提供了更大的灵活性。论文进一步验证了PnP模块在全景分割上的泛化性以及最近基于Transformer的图像识别模型ViT[7],并显示出一致的效率增益。论文认为PnP-DETR为使用Transformer进行有效的视觉分析迈出了一步,其中通常观察到空间冗余。

本文的主要贡献如下:

  • 分析了DETR模型中图像特征图的空间冗余问题,该问题导致transformer网络计算量过大。因此,提出对特征映射进行抽象,以显著降低模型运算量;
  • 设计了一种新颖的两步轮询池采样模块提取特征。该算法首先利用poll采样器提取前景精细特征向量,然后利用pool采样器获取上下文粗特征向量;
  • 构建了PnP-DETR,该变换在抽象的细粗特征空间上进行操作,并自适应地将计算分布在空间域。通过改变精细特征集的长度,PnP-DETR算法效率更高,在单一模型下实现了即时计算和性能折衷。
  • PnP抽样模块是通用的,是端到端学习的,没有像RPN那样的明确监督。论文进一步在全景分割和最近的ViT模型上对其进行了验证,并显示出一致的效率增益。这种方法为未来研究使用transformer的视觉任务的有效解决方案提供了有用的见解。实验结果如下: 

Sparse DETR(ICLR 2022)

代码链接:https://github.com/kakaobrain/sparse-detr

Deformable DETR使用多尺度特征来改善性能,然而,与DETR相比,encoder tokens的数量增加了20倍,encoder注意力的计算成本仍然是一个瓶颈。在本文的初步实验中,发现即使只更新了encoder tokens的一部分,检测性能也几乎不会恶化。受这一观察的启发,论文提出了Sparse DETR,它只选择性地更新decoder预期引用的令牌,从而帮助模型有效地检测目标。此外,在encoder中对所选token应用辅助检测损失可以提高性能,同时最小化计算开销。本文验证了Sparse DETR即使在COCO数据集上只有10%的encoder tokens,也比Deformable DETR获得更好的性能。尽管只有encoder tokens被稀疏化,但与Deformable DETR相比,总计算成本降低了38%,FPS增加了42%。

论文的主要贡献如下:

  • 提出了一种有效的端到端目标检测器的编码器token稀疏化方法,通过该方法减轻了编码器中的注意力复杂性。这种效率使得能够堆叠比Deformable DETR更多的编码器层,从而在相同的计算量下提高性能;
  • 提出了两个新的稀疏化标准来从整个token集合中采样信息子集:Objectness Score(OS)和Decoder cross-Attention Map(DAM)。基于decoder cross-attention map标准,稀疏模型即使在仅使用整个token的10%时也保持了检测性能;
  • 仅对所选token采用编码器辅助损失。这种额外的损失不仅稳定了学习过程,而且大大提高了性能,只略微增加了训练时间。

 实验结果如下: 

空间先验Fast Convergence of DETR with Spatially Modulated Co-Attention(ICCV 2021)

DETR的收敛速度较慢。从头开始训练DETR[4]需要500个epoch才能获得高精度。为了加速其收敛,本文提出了一种简单而有效的改进DETR框架的方案,即Spatially Modulated Co-Attention(SMCA)机制。SMCA的核心思想是通过将co-attention响应限制在初始估计的边界框位置附近的较高区域,在DETR中进行regression-aware co-attention。本文提出的SMCA通过替换decoder中的原始co-attention,同时保持DETR中的其他操作不变,提高了DETR的收敛速度。此外,通过将multi-head和scale-selection注意力设计集成到SMCA中,与基于空洞卷积的主干的DETR相比,本文的SMCA可以实现更好的性能。论文对COCO数据集进行了广泛的消融研究,以验证所提出的SMCA的有效性。

主要贡献如下:

  • 提出了一种新的空间调制共同注意(SMCA),它可以通过进行位置约束目标回归来加速DETR的收敛。SMCA是原始DETR中的即插即用模块。没有多尺度特征和多头注意力的SMCA的基本版本已经可以在50个epoch达到41.0 mAP,在108个时期达到42.7 mAP。将SMCA的基本版本训练50个时期需要265个V100 GPU小时。
  • 完整SMCA进一步集成了多尺度特征和多头空间调制,这可以通过更少的训练迭代进一步显著改进和超越DETR。SMCA在50个epoch可达到43.7mAP,在108个时期可实现45.6mAP,而DETR-DC5在500个时期可获得43.3mAP。将完整的SMCA训练50个epoch需要600 V100 GPU小时。
  • 对COCO 2017数据集进行了广泛的消融研究,以验证所提出的SMCA模块和网络设计。

动机

为了加速DETR收敛,本文通过动态预测一个2D的空间高斯weight map,来跟co-attention feature maps相乘来达到加快收敛速度的目的。即插即用,让DETR涨点明显。性能优于可变形DETR、DETR等网络。实验结果如下:

Conditional DETR(ICCV 2021)

本文针对DETR训练收敛缓慢这一关键问题,提出了一种用于快速DETR训练的conditional cross-attention机制。动机是DETR中的cross-attention高度依赖内容嵌入来定位和预测box,这增加了对高质量内容嵌入的需求,从而增加了训练难度。

本文的方法称为Conditional DETR,从解码器嵌入中学习条件空间query,用于解码器multi-head cross-attention。好处在于,通过条件空间query,每个交叉注意力头能够关注包含不同区域的band,例如,一个目标末端或目标框内的区域。这缩小了用于定位目标分类和box回归的不同区域的空间范围,从而放松了对内容嵌入的依赖,并简化了训练。实验结果表明,对于主干R50和R101,Conditional DETR收敛速度快6.7倍,对于更强的主干DC5-R50和DC5-R101,收敛速度快10倍。

动机

为了分析 DETR 为什么收敛慢,论文对 DETR decoder cross-attention 中的 spatial attention map 进行了可视化。

每个 head 的 spatial attention map 都在尝试找物体的一个 extremity 区域。论文认为,DETR 在计算 cross-attention 时,query 中的 content embedding 要同时和 key 中的 content embedding 以及 key 中的 spatial embedding 做匹配,这就对 content embedding 的质量要求非常高。而训练了 50 epoch 的DETR,因为 content embedding 质量不高,无法准确地缩小搜寻物体的范围,导致收敛缓慢。所以用一句话总结 DETR 收敛慢的原因,就是DETR 高度依赖高质量的 content embedding 去定位物体的 extremity 区域,而这部分区域恰恰是定位和识别物体的关键

基于此,提出Conditional DETR!

实验结果如下:

Anchor DETR(AAAI 2022)

代码链接:https://github.com/megvii-research/AnchorDETR

本文提出了一种新的基于Transfomrer的目标检测查询机制。在以前的基于Transfomrer的检测器中,object query是一组学习的嵌入。然而,每个学习到的嵌入都没有明确的物理意义,我们无法解释它将集中在哪里。由于每个object query的预测slot没有特定的模式,因此很难进行优化。换句话说,每个object query都不会关注特定区域。为了解决这些问题,在本文的query设计中,object query基于anchor point,这在基于CNN的检测器中被广泛使用。因此,每个object query都集中在anchor附近的目标上。此外,本文的query设计可以在一个位置预测多个目标以解决困难:“一个区域,多个目标”。此外,本文设计了一种注意力变体,它可以降低内存成本,同时实现与DETR中的标准注意力相似或更好的性能。由于query设计和注意力变体,本文方法名为Anchor DETR,可以实现比DETR更好的性能,并且运行速度比DETR更快。

回顾基于CNN的检测器,anchor与位置高度相关,包含可解释的意义。受此启发,作者提出了一种基于锚点(anchor points)的查询设计,即将anchor points编码为目标查询。查询是锚点坐标的编码,因此每个目标查询都具有显式的物理意义。    

但是,这个解决方案还有一个限制:多个目标可能出现在一个位置 。在这种情况下,只有这个位置的一个查询不能预测多个目标,因此来自其他位置的查询必须协同预测这些目标。它将导致每个目标查询负责一个更大的区域。因此,作者通过向每个锚点添加多个模式(multiple patterns,即一个锚点可以检测多个目标)来改进目标查询设计,以便每个锚点都可以预测多个目标

除了查询设计之外,作者还设计了一个attention变体—行列解耦注意(Row-Column Decouple Attention,RCDA) 。它将二维key特征解耦为一维行特征和一维列特征,然后依次进行行注意力和列注意力。RCDA可以降低计算成本,同时实现与DETR中的标准注意力相似甚至更好的性能。

实验结果如下: 

Efficient DETR(旷视)

DETR和Deformable DETR,具有堆叠6个解码器层的级联结构,以迭代更新object query,否则它们的性能会严重下降。本文研究了目标容器(包括object query和reference point)的随机初始化主要负责多次迭代的需求。基于论文的发现提出了Efficient DETR,这是一种用于端到端目标检测的简单高效的管道。通过利用密集检测和稀疏集合检测,Efficient DETR在初始化目标容器之前利用密集先验,并消除了1解码器结构和6解码器结构之间的差距。在MS COCO上进行的实验表明,本文的方法仅具有3个编码器层和1个解码器层,与最先进的目标检测方法相比,可以获得具有竞争力的性能。Efficient DETR在拥挤的场景中也很强大。它在CrowdHuman数据集上大大优于当期检测器。

 实验结果如下:

Dynamic DETR(ICCV 2021)

本文提出了一种新的Dynamic DETR(Transfomrer检测)方法,将动态注意力引入DETR的编码器和解码器阶段,以打破其在小特征分辨率和训练收敛慢方面的两个限制。为了解决第一个限制,这是由于Transformer编码器中的自注意力模块的二次计算复杂性,论文提出了一种动态编码器,以使用具有各种注意力类型的基于卷积的动态编码器来近似Transformer编码器的注意力机制。这种编码器可以基于诸如尺度重要性、空间重要性和表示(即,特征维度)重要性的多个因素来动态调整注意力。为了减轻学习难度的第二个限制,论文引入了一个动态解码器,通过在Transformer解码器中使用基于ROI的动态注意力来替换交叉注意力模块。这种解码器有效地帮助Transfomrer从coarse-to-fine地关注ROI,并显著降低学习难度,从而实现更快的收敛。论文进行了一系列实验来证明我们的优势。Dynamic DETR显著缩短了训练时间(减少了14倍),但性能要好得多(mAP提升3.6)。

本文的主要贡献如下:

  • 提出了一种新的Dynamic DETR方法,它相干地结合了基于动态卷积的编码器和基于动态Transformer的解码器。该方法显著提高了目标检测头的表示能力和学习效率,而无需任何计算开销。
  • 与原始的DETR相比,Dynamic DETR大大减少了训练时间(减少了14倍),但却显著提高了性能(3.6 mAP),如图1所示;
  • 是第一个在标准1x设置中实现优于传统性能的端到端方法,采用ResNet-50主干,42.9mAP。

 实验结果如下: 

结构重新设计Rethinking Transformer-based Set Prediction for Object Detection(ICCV 2021)

代码链接:GitHub: Let’s build from hereEdward-Sun/TSP-Detection

DETR是最近提出的一种基于Transformer的方法,它将目标检测视为一个集合预测问题,并实现了最先进的性能,但需要额外的训练时间来收敛。本文研究了DETR训练中优化困难的原因,揭示了导致DETR缓慢收敛的几个因素,主要是匈牙利损失和Transformer中co-attention的问题。为了克服这些问题,本文提出了两种解决方案,即TSP-FCOS(使用FCOS的基于Transformer的集合预测)和TSP-RCNN(使用RCNN的基于Transformer集合预测)。实验结果表明,所提出的方法不仅比原始DETR收敛更快,而且在检测精度方面显著优于DETR和其他基线。

  • TSP-FCOS:在backbone和encoder之间加上了head;
  • TSP-RCNN:在backbone和encoder之间加上了RoIAlign;

实验结果如下:

You Only Look at One Sequence: Rethinking Transformer in Vision through Object Detection(NeurIPS 2021)

代码链接:GitHub - hustvl/YOLOS: You Only Look at One Sequence (NeurIPS 2021)

Transformer能否在对2D空间结构了解最少的情况下,从纯sequence-to-sequence的角度进行2D目标和区域级别的识别?为了回答这个问题,论文提出了“你只看一个序列”(YOLOS),这是一系列基于朴素视觉Transformer的目标检测模型,具有最少的可能修改、区域优先级以及目标任务的归纳偏差。论文发现只有在中型ImageNet-1k数据集上预训练的YOLOS才能在COCO目标检测基准上获得相当有竞争力的性能,例如,直接采用BERT-Base架构的YOLOS-Base可以在COCO值上获得42.0 box AP。论文还通过YOLOS讨论了当前预训练方案和Transformer模型缩放策略的影响和局限性。 

本文的主要贡献如下:

  • 使用中等大小的ImageNet-1k[51]作为唯一的预训练数据集,并表明可以成功地迁移到普通ViT[21],以执行复杂的目标检测任务,并在COCO[36]基准上以最少的可能修改(即,only looking at one sequence(YOLOS))输出有竞争力的结果;
  • 首次证明,通过将固定大小的非重叠图像块序列作为输入,可以以纯序列到序列的方式完成2D目标检测。在现有的物体检测器中,YOLOS利用最小的2D感应偏置。
  • 对于朴素ViT,论文发现目标检测结果对预训练方案非常敏感,并且检测性能远未饱和。因此,所提出的YOLOS也可以作为一项具有挑战性的基准任务,以评估不同的(标签监督和自监督)ViT预训练策略。

实验结果如下:

匹配优化DN-DETR(CVPR 2022)

代码链接:https://github.com/FengLi-ust/DN-DETR

本文提出了一种新的去噪训练方法,以加速DETR(DEtection TRansformer)训练,并加深了对类DETR方法的收敛慢问题的理解。本文认为收敛缓慢是由于二分匹配的不稳定性导致的,这在早期训练阶段导致了不一致的优化目标。为了解决这个问题,除了匈牙利损失外,论文还将带有噪声的GT框输入Transformer解码器,并训练模型以重建原始框,这有效地降低了二分匹配的难度,并可以更快的收敛。本文的方法是通用的,可以通过添加几十行代码轻松地插入到任何类DETR的方法中,以实现显著的改进。因此,DN-DETR在相同的设置下产生了显著的改进(+1.9AP)。与相同设置下的基线相比,DN-DETR在50%的训练时间内实现了可比的性能。

本文的主要贡献如下:

  • 设计了一种新的训练方法来加速DETR训练。实验结果表明,我们的方法不仅加快了训练收敛,而且导致了显著更好的训练结果—在12个epoch设置下,在所有检测算法中获得最佳结果。此外,我们的方法显示出比基线DAB-DETR显著的改进(+1.9AP),并且可以很容易地集成到其他类DETR的方法中;
  • 从一个新的角度分析了DETR的缓慢收敛,并对DETR训练有了更深入的理解。设计了一个度量来评估二分匹配的不稳定性,并验证了我们的方法可以有效地降低不稳定性;
  • 进行了一系列消融研究,以分析我们模型中不同组件的有效性,如噪声、标签嵌入和注意力mask。

实验结果后如下: 

DINO

代码链接:https://github.com/IDEACVR/DINO

本文提出DINO,这是一种先进的端到端目标检测器。DINO通过使用对比的去噪训练方法、anchor初始化的混合query选择方法和box预测的look forward twice方案,在性能和效率上改进了以前的类DETR模型。DINO在具有ResNet-50主干和多尺度特征的COCO上实现了12个epoch的49.4 AP和24个epoch的51.3AP,与之前最好的类DETR的模型DN-DETR相比,分别显著提高了+6.0 AP和+2.7 AP。DINO在模型大小和数据大小方面都具有很好的扩展性。没有任何trick,在使用SwinL主干的Objects365数据集上进行预训练后,DINO在COCO val 2017(63.2AP)和测试集(63.3AP)上都获得了最好的结果。与排行榜上的其他模型相比,DINO显著减少了其模型大小和预训练数据大小,同时获得了更好的结果。

本文的主要贡献如下:

  • 设计了一种新的端到端类DETR的目标检测器,采用了几种新技术,包括对比DN训练、混合查询选择,并对DINO模型的不同部分进行了两次前向。
  • 进行了深入的消融研究,以验证DINO中不同设计选择的有效性。因此,DINO通过ResNet-50和多尺度特征在12个epoch内达到49.4AP,在24个epoch内实现51.3AP,显著优于之前最好的类DETR的模型。特别是,在12个epoch训练的DINO在小目标上表现出更显著的改善,提高了+7.5AP。
  • 不用任何trick,DINO可以在公共基准上取得最好的成绩。在使用SwinL[23]主干对Objects365[33]数据集进行预训练后,DINO在COCO val2017(63.2AP)和测试集(63.3AP)基准上都取得了最好的结果。据我们所知,这是端到端Transformer检测首次在COCO排行榜上超过最先进(SOTA)模型[1]。实验结果如下:


原文地址:https://blog.csdn.net/weixin_49587977/article/details/145161322

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!