Diffusion Transformer模型结构解析(DiT、SD3、Flux)
Diffusion Transformer模型结构解析(DiT、SD3、Flux)
本文将通过 DiT、SD3、Flux 三个 DiT 相关工作,介绍 Diffusion 中的 Transformer 结构的应用与演进。注意 SD3 和 Flux 采用的 Flow Matching 的扩散模型形式化当然是很关键的改进,但是本文主要聚焦于它们在模型结构方面的改进,因此不对 Flow Matching 作过多介绍。
DiT
Diffusion 仅要求其去噪网络是一个输入输出等尺寸的 image-to-image 模型,基于 CNN 的 UNet 是一个很自然的选择。但近些年来,凭借着更强的全局理解能力和 scaling 能力,Transformer 在视觉领域也大放异彩。DiT 正是使用 Transformer 替换掉 UNet,验证了 Transformer 在生图模型上的 scaling 能力。并试验了 in-context、cross-attention、adaLN 等不同的条件注入方式。
patchify
Transformer 模型处理的都是一维的 token 序列。要将二维图像送入 Transformer 处理,首先要对图片进行 patchify。所谓 patchify,其实就是将图片拆分为小图块,通过线性层对齐模型的维度,再加上位置编码用于表征各 token 的位置信息。记图像的尺寸为 I × I × C I\times I\times C I×I×C,取 patch size 为 p p p,那么就会得到 T = ( I / p ) 2 T=(I/p)^2 T=(I/p)2 个图块,再经过线性层,得到 T T T 个维度为 d d d 的视觉 token。之后,DiT 中加入了不可学习的 sin-cos 位置编码来表达各 token 的空间位置信息。
conditioning
在 patchify 之后,就可以把各个 patch token 送到 DiT 中进行处理,进行无条件图片生成了。但我们一般需要做条件生成(如文生图),因此还需要将条件进行编码并输入到模型中。这里 DiT 考虑的是相对简单的将类别作为条件,而不是将文本作为条件,但其实从模型结构的角度来看影响不大,都是将条件提取 embedding,然后注入到模型中。在 LDM 中,是通过 cross-attention 的方式输入到 UNet 模型中的。那么在 Transformer 结构中,如何把条件注入到模型中呢?
DiT 尝试了以下几类方式:
- In-Context Conditioning:将两个 embedding 作为两个 special tokens 拼接到图像块 token 后,类似 ViT 中的 cls token,实现起来比较简单,基本没有额外的计算量。
- Cross-Attention:将两个 embedding 拼接起来,然后在 transformer block 中插入一个 cross attention,将 embedding 作为 cross attention的 K 和 V;这也是,该方法需要引入的额外计算量最大,约增加 15%。
- Adaptive Layer Norm (adaLN):adaLN 在 GAN 这类生成模型中的应用非常广泛。将常规的 LN 替换为 adaLN,回归 scale 和 shift 两个参数,这种方式也基本不增加计算量。
- adaLN-Zero:即采用零初始化,将 adaLN 的线性层参数初始化为零,网络初始化时 transformer block 的残差模块就是一个 identity 函数。除了回归scale 和 shift,还在每个残差模块结束之前回归一个 scale。
实验结果显示,adaLN-Zero 是最优的条件方式。
SD3
整体上,SD3 的模型结构还是延续了 SD 系列 “LDM” 的思路,即先通过 VAE 将像素空间的图片压缩到隐层空间,在隐层空间进行去噪训练。生成时,也是先生成隐层 latent,在解码回像素空间。从整体模型结构上来看,SD3 还是由扩散模型、VAE、文本编码器三大部分组成,但各部分都有一定的改进。其中 VAE 是增加了通道数,提升了细节还原能力。而文本编码器和扩散模型,以及它们的搭配方式,是 SD3 改进的重点,我们接下来详细介绍。
文本编码器
作为文生图模型,对文本条件的理解和遵循是评价效果的关键指标之一。在之前的 SD 中,文本条件是使用预训练的 CLIP 模型进行特征提取,然后通过 cross attention 的形式将文本 embedding 注入到 UNet 的各层中,这个过程中的两步都还有可以提升的空间:
-
首先,可以使用更强的语言模型进行文本特征的编码,Imagen 早在 2022 就通过指出了更大更强的文本编码器对于文生图模型在语义理解和画面质量上的重要作用,但是 SD 系列模型迟迟没有采用足够强大的大语言模型来编码文本特征,在 SD3 中,终于用上了 T5,与两个 CLIP 模型搭配,共同编码文本条件。
-
其次,注入文本条件的方式也有改进的空间。之前 cross attention 的形式比较适合于 UNet 这类 CNN 模型,在切换到 Transformer 之后,SD3 文本条件的注入方式也有所改变。
其实直觉上一种更直接的条件注入方式是直接将文本 embedding 与 latent 拼接到一起。
在 SD3 中,一共使用了三个预训练的文本编码器,分别是 CLIP ViT-L (~124M)、OpenCLIP ViT-bigG (~695M)、T5-XXL (~4.7B)。其中,三个模型都是输出 77 个 tokens。
首先两个 CLIP 分别对文本编码得到 77 × 768 77\times768 77×768 和 77 × 1280 77\times 1280 77×1280 的特征,T5 则得到 77 × 4096 77\times 4096 77×4096 的特征,这三组文本特征通过不同的方式组合,得到两个文本特征,它们分别会在 MM-DiT 中通过不同的方式应用。具体来说,
- 一方面,两组 CLIP 特征分别在 token 维度经过池化,得到特征向量,并拼接起来,得到 1 × 2048 ( = 1280 + 768 ) 1\times 2048(=1280+768) 1×2048(=1280+768) 的特征向量,这就是图中的文本特征,该特征会与时间步向量加和后得到 y y y;
- 另一方面,两组 CLIP 特征直接拼接后得到 77 × 2048 77\times 2048 77×2048 的特征,经过 zero padding 后,与 T5 的特征形状相同为 77 × 4096 77\times 4096 77×4096,将 CLIP 特征和 T5 特征再拼接,得到形状为 154 ( = 77 + 77 ) × 4096 154(=77+77)\times4096 154(=77+77)×4096 的文本特征 c c c(注:代码里实际上把 CLIP 和 T5 的特征直接拼起来了,并没有限制 T5 的 seqlen 只有 77,即 333 ( = 77 + 256 ) × 4096 333(=77+256)\times4096 333(=77+256)×4096)。
MM DiT
然后就是我们的重点:MM-DiT。一个 MM-DiT Block 的详细结构如上图右侧所示,看起来非常复杂,我们拆解开来看。
实际上,每个 block 只有 y , c , x y,c,x y,c,x 三个输入,其中 y , c y,c y,c 是我们上面刚介绍的 CLIP 和 T5 编码出的两个文本特征,而 x x x 就是噪声图经过 patchify 得到的 token 序列。具体的 patchify 方式与上面 DiT 类似,不再赘述。
我们先看 y y y,图中看着线很乱,但其实 y y y 并没有与 block 的核心结构发生作用,而是只在左右两侧处理。 y y y 实际的作用就是进行 DiT 中提到过的 adaLN modulation,用于计算一共 α , β , γ ; δ , ϵ , ξ \alpha,\beta,\gamma;\delta,\epsilon,\xi α,β,γ;δ,ϵ,ξ 六组参数,每组两个分别对应 c c c 和 x x x,每个 block 内共 12 个参数,该参数会用于 c , x c,x c,x 各自 modulation 的计算。除了计算这 12 个参数, y y y 不参与 block 中的任何其他操作了,仔细对比 DiT 和 MM-DiT 的图示,这一部分其实与 DiT 中的 adaLN 完全一致,所谓 modulation (Mod) 就是 DiT 中的 scale,shift。
再看 c c c 和 x x x,分别是一个文本输入和一个图片输入,所以称为 MM (MultiModal) DiT。它们以双流的形式在 block 中分别进行处理,但是一起做 Attention。以 x x x 为例( c c c 的处理是对称的),在每个 block 中,首先经过一个 LN,然后根据 y y y 计算出的参数进行 modulation 控制,之后经过一个线性层之后,与 c c c 对应拼接起来共同计算 QKV Attention。之后再依次经过 linear、scale、residual、layernorm、scale+shift (mod)、mlp、scale、residual 就完成了一个 block 的完整计算。
在 SD3 对模型结构的改进上,还有两处细节需要注意。一是在 Attention 的 Q 和 K 的处理时,有一个 RMS Norm,这称为 QK-Normalization,是为了在混合精度训练时提升训练稳定性。避免在模型变大、分辨率变高时 attention logit 出现 nan。二是 SD3 采用了扩展+插值的 2D 位置编码方式来提升训练分辨率并适应可变长宽比。
Flux
Flux 由 SD3 原班人马出走后创立的 BFL 公司打造。在模型结构上,Flux 整体上延续了 SD3 MM DiT 的设计(如下图所示,图源),但是在几处细节上又有改进,我们着重介绍这些 Flux 与 SD3 的不同点:
- 文本编码器部分,Flux 取消了 SD3 中的 1 个 CLIP 模型,只使用了 1 个 CLIP 模型和 1 个 T5 模型来编码文本条件;
- Flux 采用了更先进的旋转位置编码 RoPE,详情可参考苏剑林大佬的博客;
- Flux 的在 MM-DiT(DoubleStreamTransformer)之后,将文本和图像拼接,在送入到 SingleStreamTransformer 中进行处理。这样能够降低单层的参数量,增大网络深度;
- 为了进行 CFG 蒸馏,Flux dev 版本的 DiT 需要显式地直接接受 guidance scale 作为条件。这个条件与 timestep 条件类似,分别经过正弦 embedding 后加在一起
总结
从 UNet 迁移到 DiT,可以利用 Transformer 模型的 scaling 能力,通过增大参数量来提升出图的质量。在这个迁移过程中,我们需要考虑 Transformer 应用于扩散模型时的 patchify、positional encoding、conditioning 等几个重要环节。DiT 首先提出,确立了 patchify 的方法,并实验得到 adaLN 是比较适合 Diffusion Transformer 的 conditioning 方式;SD3 中引入了多个强大的文本编码器,参考 DiT 中 adaLN 的 conditioning 方式,结合文本特征设计出了多模态的 MM-DiT;Flux 则进一步引入了 RoPE,并通过将 Transformer 的后几层改为 SingleStream 的 block 来提升参数效率。
原文地址:https://blog.csdn.net/weixin_44966641/article/details/143772347
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!