
PixArt-α Notes

PixArt-α is a text-to-image (T2I) model released by Huawei.

  • Training strategy decomposition: three distinct training steps separately optimize pixel dependency, text-image alignment, and image aesthetic quality.
  • Efficient T2I Transformer: cross-attention is incorporated into the Diffusion Transformer (DiT) to inject the text condition, and the compute-heavy class-condition branch is streamlined away.
  • High-information data: emphasizes the importance of concept density in text-image pairs and uses LLaVA to auto-generate dense pseudo-captions to aid learning.

For a detailed introduction, see 文生图模型之PixArt-α and 华为·文生图:PixArt模型系列 (two linked write-ups). This post focuses on the code.

PixArt-α uses T5-XXL as its text encoder. Since the T5 model is large, this post follows the 8GB-GPU-VRAM recipe from the PixArt-α GitHub repo:
first compute the text embeddings with T5, then offload the model, and only then run the generation stage.
Running the script as-is, loading the model fails with "does not have a parameter or a buffer named y_embedding."
This is because the model was trained in three stages: the first stage used a class condition, which is unused at inference time but whose parameters are still saved in the checkpoint, so the weights cannot be loaded directly. The checkpoint has to be converted manually with the method provided on GitHub.
Convert .pth checkpoint into diffusers version

python tools/convert_pixart_alpha_to_diffusers.py --image_size your_img_size --multi_scale_train (True if you use PixArtMS else False) --orig_ckpt_path path/to/pth --dump_path path/to/diffusers --only_transformer=True

# These two sets of parameters are dropped during conversion
state_dict.pop("pos_embed")  # recomputed by get_2d_sincos_pos_embed
state_dict.pop("y_embedder.y_embedding")  # leftover of the stage-1 class condition

Running the PixArtAlphaPipeline in under 8GB GPU VRAM

from diffusers import PixArtAlphaPipeline
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
import torch
from transformers import T5EncoderModel

import gc
import os

def flush():
    gc.collect()
    torch.cuda.empty_cache()

# 8-bit quantization config for the T5 text encoder
quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)

# Load the local T5-XXL checkpoint in 8-bit; os.path.expanduser resolves "~",
# which from_pretrained does not expand on its own
text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    os.path.expanduser("~/.cache/modelscope/hub/AI-ModelScope/t5-v1_1-xxl"),
    quantization_config=quant_config,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Stage 1: build the pipeline with only the text encoder (transformer=None)
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    text_encoder=text_encoder_2_8bit,
    transformer=None,
    device_map="balanced"
)

with torch.no_grad():
    prompt = "A cute cat"
    prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)

# Offload T5 before loading the transformer so the two never share VRAM
del text_encoder_2_8bit
del pipe
flush()

# Stage 2: reload the pipeline with only the transformer (text_encoder=None)
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    text_encoder=None,
    torch_dtype=torch.float16,
).to("cuda")

# Denoise to latents; decoding is deferred so the VAE runs after the transformer is freed
latents = pipe(
    negative_prompt=None, 
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    num_images_per_prompt=1,
    output_type="latent",
).images

# Free the transformer before VAE decoding
del pipe.transformer
flush()

with torch.no_grad():
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="pil")
image[0].save("cat1.png")
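
To sanity-check the sub-8GB claim on your own machine, you can read PyTorch's peak-allocation counter at the end of the script (my addition, not part of the original recipe):

print(f"peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")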

PixArt-α Model Architecture

(Figure: PixArt-α model architecture)

PixArt-α is still a diffusion model and follows the standard DiffusionPipeline flow, so only the parts PixArt-α changes are listed below.
At each noise-prediction step, the timestep embedding is computed by AdaLayerNormSingle; then, inside each of PixArt-α's BasicTransformerBlocks,
the shift, scale, and gate values are obtained by adding that timestep embedding to a learnable scale_shift_table initialized from a normal distribution.

timestep, embedded_timestep = self.adaln_single(
    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)

self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
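
For context, here is a simplified, self-contained sketch (not the exact diffusers source) of how the six chunks modulate a block: the LayerNorm output is scaled and shifted before attention and before the MLP, and each sub-layer's output is gated before the residual addition. attn and ff are placeholders for the real attention and feed-forward modules.

import torch
import torch.nn as nn

dim, batch_size, seq_len = 1152, 2, 16
hidden_states = torch.randn(batch_size, seq_len, dim)
timestep = torch.randn(batch_size, 6 * dim)            # output of adaln_single
scale_shift_table = torch.randn(6, dim) / dim**0.5     # learnable in the real model

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)

norm1 = nn.LayerNorm(dim, elementwise_affine=False)
norm2 = nn.LayerNorm(dim, elementwise_affine=False)
attn = lambda x: x   # placeholder for self-attention
ff = lambda x: x     # placeholder for the feed-forward MLP

# attention branch: modulate -> attend -> gate -> residual
h = norm1(hidden_states) * (1 + scale_msa) + shift_msa
hidden_states = gate_msa * attn(h) + hidden_states

# MLP branch: same pattern with the second set of parameters
h = norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
hidden_states = gate_mlp * ff(h) + hidden_states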

Inside AdaLayerNormSingle, the timestep embedding is computed by PixArtAlphaCombinedTimestepSizeEmbeddings.

class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
        )

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # No modulation happening here.
        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
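
A quick shape check of AdaLayerNormSingle (assuming a recent diffusers version where it is importable from diffusers.models.normalization): the first return value is 6 × embedding_dim wide, matching the six chunks split off above.

import torch
from diffusers.models.normalization import AdaLayerNormSingle

adaln = AdaLayerNormSingle(embedding_dim=1152, use_additional_conditions=False)
t = torch.tensor([500, 500])
temb, embedded_timestep = adaln(t, batch_size=2, hidden_dtype=torch.float32)
print(temb.shape)               # torch.Size([2, 6912]), i.e. 6 * 1152
print(embedded_timestep.shape)  # torch.Size([2, 1152])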

In PixArtAlphaCombinedTimestepSizeEmbeddings, the timestep embedding is computed by Timesteps followed by TimestepEmbedding, and in the default configuration (use_additional_conditions enabled) resolution and aspect_ratio embeddings are added on top.

class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if use_additional_conditions:
            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D) (2,1152)

        if self.use_additional_conditions:
            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
        else:
            conditioning = timesteps_emb

        return conditioning
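
Note the dimensions: size_emb_dim is embedding_dim // 3, the resolution embedding covers two values (height and width, so 2 × size_emb_dim after the reshape) and the aspect ratio one (1 × size_emb_dim), so their concatenation is exactly embedding_dim wide and can be added to timesteps_emb. A quick check with hypothetical inputs (assuming the class is importable from diffusers.models.embeddings in your version):

import torch
from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings

emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
    embedding_dim=1152, size_emb_dim=1152 // 3, use_additional_conditions=True
)
t = torch.tensor([500, 500])
resolution = torch.tensor([[1024.0, 1024.0], [1024.0, 1024.0]])  # (height, width)
aspect_ratio = torch.tensor([[1.0], [1.0]])
cond = emb(t, resolution, aspect_ratio, batch_size=2, hidden_dtype=torch.float32)
print(cond.shape)  # torch.Size([2, 1152]): 768 (resolution) + 384 (aspect ratio)

TimestepEmbedding itself is a small two-layer MLP: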
class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample

(Figure: adaLN-single re-parameterization)
Re-parameterization: to reuse the pretrained DiT weights without the class condition c that DiT training requires, the per-block adaLN MLPs are replaced by a single global MLP plus a block-specific trainable embedding (the scale_shift_table above). The embeddings are initialized with the values the pretrained adaLN produces at t = 500, so the modulation matches the pretrained model at that timestep.
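
A hypothetical sketch of this initialization, with small placeholder modules standing in for the pretrained networks (the real code reads the actual DiT checkpoint):

import torch
import torch.nn as nn

dim = 1152

# Placeholders: shared timestep embedder, the single global MLP of adaLN-single,
# and the pretrained DiT per-block adaLN MLP being replaced
time_mlp = nn.Sequential(nn.Linear(256, dim), nn.SiLU(), nn.Linear(dim, dim))
global_mlp = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))
pretrained_block_adaln = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))

t_feat = torch.randn(1, 256)  # stands in for the sinusoidal features of t = 500
t_emb = time_mlp(t_feat)

with torch.no_grad():
    s_global = global_mlp(t_emb)                  # adaLN-single output at t = 500
    s_pretrained = pretrained_block_adaln(t_emb)  # pretrained adaLN output at t = 500
    # initialize the block's trainable table so the two match at t = 500
    scale_shift_table = (s_pretrained - s_global).reshape(6, dim)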


Original article: https://blog.csdn.net/zhilaizhiwang/article/details/145215320
