自学内容网



1. TimestepEmbedSequential

In the code, class ResBlock inherits from TimestepBlock because it needs the timestep embedding as an extra input; the other layers do not.

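from abc import abstractmethod

import torch.nn as nn


class TimestepBlock(nn.Module):
    """Any module where forward() takes timestep embeddings as a second argument."""

    @abstractmethod
    def forward(self, x, emb):
        """Apply the module to `x` given `emb` timestep embeddings."""


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)  # e.g. ResBlock: also receives the timestep embedding
            else:
                x = layer(x)  # plain layers (conv, attention, ...) only see x
        return x


class ResBlock(TimestepBlock):
    ...  # body omitted in this excerpt; ResBlock._forward is shown further below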
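To make the dispatch concrete, here is a minimal, self-contained sketch assuming PyTorch. It re-declares the two classes so it runs on its own; AddEmb is a hypothetical toy TimestepBlock standing in for ResBlock, and the layer sizes are arbitrary.

```python
import torch
import torch.nn as nn


class TimestepBlock(nn.Module):
    """Marker base class: forward(x, emb) also receives the timestep embedding."""

    def forward(self, x, emb):
        raise NotImplementedError


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """Passes emb only to children that are TimestepBlocks."""

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x


class AddEmb(TimestepBlock):
    """Hypothetical toy block: projects emb to the feature width and adds it to x."""

    def __init__(self, emb_dim, features):
        super().__init__()
        self.proj = nn.Linear(emb_dim, features)

    def forward(self, x, emb):
        return x + self.proj(emb)


torch.manual_seed(0)
net = TimestepEmbedSequential(
    nn.Linear(8, 16),  # ordinary layer: called as layer(x)
    AddEmb(4, 16),     # TimestepBlock: called as layer(x, emb)
    nn.ReLU(),         # ordinary layer again
)
x = torch.randn(2, 8)
emb = torch.randn(2, 4)
out = net(x, emb)
print(out.shape)  # torch.Size([2, 16])
```

Only AddEmb sees emb; nn.Linear and nn.ReLU are called with x alone. A plain nn.Sequential could not be used here, because its forward accepts a single input.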


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        # CheckpointFunction (a torch.autograd.Function defined alongside this helper)
        # receives the inputs and the parameters as one flat argument list.
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)

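checkpoint runs the forward function of the target operation under torch.no_grad(). This does not change the state of the original leaf nodes: inputs and parameters that require gradients keep requiring them. Only the temporary intermediate tensors computed from those leaves are created without gradient tracking, so the autograd chain through them is cut and their activations are not stored. During the backward pass, CheckpointFunction runs the forward again with gradients enabled and differentiates the recomputed output with respect to the inputs and params. That recomputation is the extra compute mentioned in the docstring, and it is also why params has to be passed explicitly: the backward pass needs to know which parameters to differentiate.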
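The mechanism can be sketched with torch.autograd.Function. This is an illustrative reimplementation under stated assumptions, not the repository's CheckpointFunction: MyCheckpointFunction and my_checkpoint are hypothetical names, the sketch handles a single output tensor only, and the toy function f and layer sizes are arbitrary. The last lines check that checkpointed gradients match an ordinary forward/backward pass.

```python
import torch


class MyCheckpointFunction(torch.autograd.Function):
    """Hypothetical activation-checkpointing Function (single-output sketch)."""

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])  # explicit inputs of run_function
        ctx.input_params = list(args[length:])   # parameters it closes over
        with torch.no_grad():                    # build no graph: activations are not kept
            return run_function(*ctx.input_tensors)

    @staticmethod
    def backward(ctx, grad_output):
        # Fresh leaf copies of the inputs, then recompute the forward with a graph.
        inputs = [t.detach().requires_grad_(True) for t in ctx.input_tensors]
        with torch.enable_grad():
            output = ctx.run_function(*inputs)
        grads = torch.autograd.grad(
            output, inputs + ctx.input_params, grad_output, allow_unused=True
        )
        # No gradient for run_function and length; one entry per tensor in *args.
        return (None, None) + tuple(grads)


def my_checkpoint(func, inputs, params, flag):
    """Same calling convention as checkpoint() above."""
    if flag:
        args = tuple(inputs) + tuple(params)
        return MyCheckpointFunction.apply(func, len(inputs), *args)
    return func(*inputs)


torch.manual_seed(0)
layer = torch.nn.Linear(4, 3)


def f(t):
    return torch.tanh(layer(t)) * 2  # uses layer's parameters without taking them as args


x = torch.randn(5, 4, requires_grad=True)

# Checkpointed pass.
y_ckpt = my_checkpoint(f, (x,), list(layer.parameters()), True).sum()
y_ckpt.backward()
g_x, g_w = x.grad.clone(), layer.weight.grad.clone()

# Ordinary pass for comparison.
x.grad = None
layer.zero_grad()
y_ref = f(x).sum()
y_ref.backward()
print(torch.allclose(g_x, x.grad), torch.allclose(g_w, layer.weight.grad))  # True True
```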


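class AttentionBlock(nn.Module):
    # Self-attention over all spatial positions of a feature map.
    # normalization, conv_nd and zero_module are helpers defined elsewhere in the codebase.
    def __init__(self, channels, num_heads=1, use_checkpoint=False):
        super().__init__()  # required before any submodule is assigned
        self.channels = channels
        self.num_heads = num_heads
        self.use_checkpoint = use_checkpoint

        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)  # 1x1 conv producing q, k, v
        self.attention = QKVAttention()
        # Zero-initialized output projection: the block starts out as the identity (x + 0).
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        # Optionally trade compute for memory with gradient checkpointing.
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
        h = self.attention(qkv)
        h = h.reshape(b, -1, h.shape[-1])
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)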

import math

import numpy as np
import torch as th
import torch.nn as nn


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention.
    """

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x C x T] tensor after attention.
        """
        ch = qkv.shape[1] // 3
        q, k, v = th.split(qkv, ch, dim=1)
        # q and k are each scaled by ch ** -0.25, so q.k ends up scaled by 1 / sqrt(ch).
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        return th.einsum("bts,bcs->bct", weight, v)

    @staticmethod
    def count_flops(model, _x, y):
        # FLOP-counting hook, meant to be registered as a custom op with the thop profiler.
        b, c, *spatial = y[0].shape
        num_spatial = int(np.prod(spatial))
        # We perform two matmuls with the same number of ops.
        # The first computes the weight matrix, the second computes
        # the combination of the value vectors.
        matmul_ops = 2 * b * (num_spatial ** 2) * c
        model.total_ops += th.DoubleTensor([matmul_ops])
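As a quick check of the count_flops formula, matmul_ops = 2 · b · T² · c with T the number of spatial positions, here is a plain-Python sketch. The helper name attention_matmul_ops is hypothetical, and the batch size, channel count and 16×16 map are assumed values chosen only for the arithmetic.

```python
import math


def attention_matmul_ops(b, c, spatial):
    """Same formula as count_flops: two matmuls of b * T * T * c multiply-adds each."""
    t = math.prod(spatial)   # number of positions T
    qk_ops = b * t * t * c   # weight[b, t, s] = sum_c q[b, c, t] * k[b, c, s]
    av_ops = b * t * t * c   # out[b, c, t]   = sum_s weight[b, t, s] * v[b, c, s]
    return qk_ops + av_ops


ops = attention_matmul_ops(b=2, c=64, spatial=(16, 16))
print(ops)  # 16777216
```

Because the cost is quadratic in T, doubling the resolution to 32×32 quadruples T and multiplies the count by 16, which is why attention blocks are usually placed only at the lower resolutions of a UNet.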


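AttentionBlock._forward with the tensor shapes annotated:

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)  # flatten the input to (b, c, N), N = prod(spatial)
        qkv = self.qkv(self.norm(x))  # normalize, then 1x1 conv -> (b, 3*c, N)
        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])  # split heads: (b*heads, 3*c/heads, N)
        h = self.attention(qkv)  # (b*heads, c/heads, N)
        h = h.reshape(b, -1, h.shape[-1])  # merge heads back: (b, c, N)
        h = self.proj_out(h)  # zero-initialized 1x1 conv, still (b, c, N)
        return (x + h).reshape(b, c, *spatial)  # residual connection, restore the spatial shape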
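The same shape bookkeeping, including how num_heads > 1 is handled, can be traced with a small NumPy sketch. Assumptions: b=2, c=8, a 4×4 feature map and 2 heads are arbitrary; the random matrix w_qkv is a hypothetical stand-in for self.qkv (normalization is skipped); and proj_out is omitted, since it is zero-initialized and would output zeros at initialization anyway.

```python
import numpy as np

b, c, hgt, wid, num_heads = 2, 8, 4, 4, 2
rng = np.random.default_rng(0)

x = rng.standard_normal((b, c, hgt, wid))
spatial = x.shape[2:]
x = x.reshape(b, c, -1)                              # (b, c, N), N = 16
w_qkv = rng.standard_normal((3 * c, c))              # stand-in for the 1x1 conv (no bias)
qkv = np.einsum("oc,bcn->bon", w_qkv, x)             # (b, 3c, N)
# Row-major reshape: each head gets a contiguous block of 3c/heads channels,
# which the attention step then splits into its own q, k, v.
qkv = qkv.reshape(b * num_heads, -1, qkv.shape[2])   # (b*heads, 3c/heads, N)
ch = qkv.shape[1] // 3                               # channels per head: c / heads

q, k, v = np.split(qkv, 3, axis=1)                   # each (b*heads, ch, N)
w = np.einsum("bct,bcs->bts", q, k) / np.sqrt(ch)    # (b*heads, N, N)
w = np.exp(w - w.max(axis=-1, keepdims=True))
w /= w.sum(axis=-1, keepdims=True)                   # softmax over key positions
h = np.einsum("bts,bcs->bct", w, v)                  # (b*heads, ch, N)
h = h.reshape(b, -1, h.shape[-1])                    # (b, c, N)
out = (x + h).reshape(b, c, *spatial)                # (b, c, H, W)
print(qkv.shape, h.shape, out.shape)  # (4, 12, 16) (2, 8, 16) (2, 8, 4, 4)
```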

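The forward of class QKVAttention computes standard scaled dot-product attention, Attention(Q, K, V) = softmax(Q Kᵀ / √d) V, with d = ch, the number of channels per head. Instead of dividing the scores by √d afterwards, it multiplies q and k each by d^(-1/4) before the product, which the code comment notes is more stable in float16.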
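To check that the channel-first einsum layout really is the textbook formula, the following NumPy sketch mirrors QKVAttention.forward and compares it with softmax(Q Kᵀ / √d) V written in the usual tokens-as-rows layout. The function names and the random input sizes (N=2, C=8, T=5) are assumptions for illustration.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def qkv_attention(qkv):
    """NumPy mirror of QKVAttention.forward: qkv is [N, 3*C, T], returns [N, C, T]."""
    ch = qkv.shape[1] // 3
    q, k, v = np.split(qkv, 3, axis=1)          # each [N, C, T]
    scale = 1 / np.sqrt(np.sqrt(ch))            # ch ** -0.25, applied to both q and k
    weight = np.einsum("bct,bcs->bts", q * scale, k * scale)
    weight = softmax(weight, axis=-1)           # normalize over key positions s
    return np.einsum("bts,bcs->bct", weight, v)


def textbook_attention(qkv):
    """softmax(Q K^T / sqrt(d)) V with tokens as rows: Q, K, V are [N, T, C]."""
    ch = qkv.shape[1] // 3
    q, k, v = (m.transpose(0, 2, 1) for m in np.split(qkv, 3, axis=1))
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(ch)  # [N, T, T]
    out = softmax(scores, axis=-1) @ v               # [N, T, C]
    return out.transpose(0, 2, 1)                    # back to [N, C, T]


rng = np.random.default_rng(0)
qkv = rng.standard_normal((2, 3 * 8, 5))  # N=2, C=8, T=5
a = qkv_attention(qkv)
b = textbook_attention(qkv)
print(a.shape, np.allclose(a, b))  # (2, 8, 5) True
```

Keeping channels first ([N, C, T]) lets the 1x1 convolutions before and after attention operate on the tensor without any transposes.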


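    # ResBlock._forward (excerpt)
    def _forward(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append singleton dims so emb_out [B, C'] broadcasts over the spatial dims of h.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # out_layers[0] is the normalization layer; the remaining layers follow it.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            # Without scale-shift, the embedding is simply added before out_layers.
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h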

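In ResBlock, use_scale_shift_norm controls how the timestep embedding conditions the features. When it is True, the embedding is turned into a per-channel scale and shift that modulate the output of the normalization layer; when it is False, the embedding is simply added to h before out_layers. The scale-and-shift form, used in diffusion models and other networks that need fine-grained control over intermediate features, is a more flexible transformation than plain addition.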

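Extract the scale and shift parameters: emb_out is the output of self.emb_layers(emb); after the while loop appends singleton dimensions it has shape [B, 2C, 1, 1] for 2-D feature maps. When use_scale_shift_norm is True, emb_layers is built to output twice as many channels as h, so th.chunk(emb_out, 2, dim=1) splits it along the channel dimension (dim=1) into two halves, scale and shift, each of shape [B, C, 1, 1].
Apply scale and shift: h, the output of in_layers, is passed through out_norm (out_layers[0], the normalization layer), and the result is multiplied by (1 + scale) and then shifted by shift. This is a per-channel affine transform of the normalized features whose coefficients depend on the timestep. Using (1 + scale) rather than scale means that a zero scale leaves the normalized features unchanged. The remaining layers, out_rest = out_layers[1:], are applied afterwards.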
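The two branches can be sketched in NumPy to see the broadcasting and the role of (1 + scale). Assumptions: group_norm is a hypothetical GroupNorm without learnable affine parameters, 2 groups and the tensor sizes are arbitrary, emb_out is drawn at random instead of coming from emb_layers, and the layers after the norm in out_layers are omitted.

```python
import numpy as np


def group_norm(h, groups, eps=1e-5):
    """GroupNorm without learnable affine parameters for a (B, C, H, W) array."""
    bsz, c, hh, ww = h.shape
    g = h.reshape(bsz, groups, c // groups, hh, ww)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(bsz, c, hh, ww)


def apply_emb(h, emb_out, use_scale_shift_norm, groups=2):
    """Mirror of the emb_out branch in ResBlock._forward; layers after the norm are omitted."""
    while emb_out.ndim < h.ndim:  # (B, k) -> (B, k, 1, 1), broadcasts over H and W
        emb_out = emb_out[..., None]
    if use_scale_shift_norm:
        scale, shift = np.split(emb_out, 2, axis=1)  # like th.chunk(emb_out, 2, dim=1)
        return group_norm(h, groups) * (1 + scale) + shift
    return group_norm(h + emb_out, groups)  # additive path: add, then normalize


rng = np.random.default_rng(0)
h = rng.standard_normal((2, 4, 3, 3))  # B=2, C=4, 3x3 feature map

# Scale-shift path: emb_out carries 2*C values per sample.
zero_emb = np.zeros((2, 2 * 4))
print(np.allclose(apply_emb(h, zero_emb, True), group_norm(h, 2)))  # True
out_ss = apply_emb(h, rng.standard_normal((2, 2 * 4)), True)

# Additive path: emb_out carries C values per sample.
out_add = apply_emb(h, rng.standard_normal((2, 4)), False)
print(out_ss.shape, out_add.shape)  # (2, 4, 3, 3) (2, 4, 3, 3)
```

The zero-embedding check shows the effect of the "1 +": with scale = shift = 0 the block passes the normalized features through untouched, so the embedding only perturbs them rather than having to learn the identity.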

