PixArt-α Notes
PixArt-α is a text-to-image model released by Huawei. Its main contributions:
- Training strategy decomposition: three distinct training stages are designed, separately optimizing pixel dependency, text-image alignment, and image aesthetic quality.
- Efficient T2I Transformer: cross-attention is added to the Diffusion Transformer (DiT) to inject the text condition, and the computation-heavy class-condition branch is simplified.
- High-information data: emphasizes the importance of concept density in text-image pairs, and uses LLaVA to automatically label images with dense pseudo-captions to aid learning.
For a more detailed introduction, see 文生图模型之PixArt-α and 华为·文生图:PixArt模型系列. This post mainly walks through the code.
PixArt-α uses T5-XXL as its text encoder. Since the T5 model is large, we follow the 8GB-GPU-VRAM recipe from the PixArt-α GitHub repo:
first compute the text embeddings with T5, then offload the model, and only afterwards run the generation stage.
Running the script as-is, loading the model raises the error "does not have a parameter or a buffer named y_embedding."
This is because the model was trained in three stages: the first stage used a class condition, which is not used at inference time, but its parameters are still stored in the checkpoint, so it cannot be loaded directly. The checkpoint has to be converted manually using the method provided on GitHub.
Convert .pth checkpoint into diffusers version
python tools/convert_pixart_alpha_to_diffusers.py --image_size your_img_size --multi_scale_train (True if you use PixArtMS else False) --orig_ckpt_path path/to/pth --dump_path path/to/diffusers --only_transformer=True
# The conversion removes these two sets of parameters:
state_dict.pop("pos_embed")  # recomputed by get_2d_sincos_pos_embed
state_dict.pop("y_embedder.y_embedding")
Running the PixArtAlphaPipeline in under 8GB GPU VRAM
from diffusers import PixArtAlphaPipeline
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import T5EncoderModel
import torch
import gc


def flush():
    gc.collect()
    torch.cuda.empty_cache()


# Stage 1: load only the 8-bit T5 text encoder (no transformer) and encode the prompt.
quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)
text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    "~/.cache/modelscope/hub/AI-ModelScope/t5-v1_1-xxl",
    # subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    text_encoder=text_encoder_2_8bit,
    transformer=None,
    device_map="balanced",
)

with torch.no_grad():
    prompt = "A cute cat"
    prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)

# Free the text encoder before loading the transformer.
del text_encoder_2_8bit
del pipe
flush()

# Stage 2: reload the pipeline without the text encoder and denoise to latents.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    text_encoder=None,
    torch_dtype=torch.float16,
).to("cuda")

latents = pipe(
    negative_prompt=None,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    num_images_per_prompt=1,
    output_type="latent",
).images

# Stage 3: drop the transformer and decode the latents with the VAE.
del pipe.transformer
flush()

with torch.no_grad():
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="pil")
image[0].save("cat1.png")
# image[0]
PixArt-α Model Architecture
PixArt-α is still a diffusion model and reuses the standard DiffusionPipeline workflow; below, only the parts that PixArt-α changes are covered.
Each time the noise is predicted, the timestep embedding is computed by AdaLayerNormSingle. Then, inside every PixArt-α BasicTransformerBlock, the shift, scale and gate values are obtained by adding the timestep embedding to scale_shift_table, a learnable per-block parameter initialized from a normal distribution.
# in the transformer's forward(): compute the shared timestep projection once
timestep, embedded_timestep = self.adaln_single(
    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)

# in BasicTransformerBlock (ada_norm_single): a per-block learned table ...
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

# ... added to the projected timestep and split into the six modulation terms
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
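The snippet above only produces the six chunks. To see how they are used, here is a minimal, self-contained sketch of the modulation and gating around the attention sub-layer (assumed shapes; not the exact diffusers BasicTransformerBlock code, and the attention call is replaced by a placeholder):

import torch
import torch.nn as nn

dim, batch_size, seq_len = 1152, 2, 16
scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)   # per-block learned table
timestep_emb = torch.randn(batch_size, 6 * dim)                    # output of adaln_single's linear layer

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    scale_shift_table[None] + timestep_emb.reshape(batch_size, 6, -1)
).chunk(6, dim=1)                                                  # each chunk: (batch, 1, dim)

hidden_states = torch.randn(batch_size, seq_len, dim)
norm = nn.LayerNorm(dim, elementwise_affine=False)

attn_input = norm(hidden_states) * (1 + scale_msa) + shift_msa     # modulate the norm output
attn_output = attn_input                                           # stand-in for self.attn1(attn_input)
hidden_states = hidden_states + gate_msa * attn_output             # gated residual

The MLP branch uses shift_mlp, scale_mlp and gate_mlp in exactly the same way.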
Inside AdaLayerNormSingle, the timestep embedding is computed by PixArtAlphaCombinedTimestepSizeEmbeddings (the classes below are excerpted from diffusers):
class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
        )

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # No modulation happening here.
        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
In PixArtAlphaCombinedTimestepSizeEmbeddings, the timestep embedding is computed by TimestepEmbedding (after a sinusoidal Timesteps projection), and in the default configuration use_additional_conditions is enabled, so resolution and aspect_ratio embeddings are added on top.
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if use_additional_conditions:
            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D), e.g. (2, 1152)

        if self.use_additional_conditions:
            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
        else:
            conditioning = timesteps_emb

        return conditioning
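As a quick shape check, a hypothetical example (it assumes embedding_dim = 1152 as in PixArt-XL and that the class can be imported from diffusers.models.embeddings, whose path may vary across diffusers versions):

import torch
from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings

emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
    embedding_dim=1152, size_emb_dim=1152 // 3, use_additional_conditions=True
)
timestep = torch.tensor([999, 999])                # (N,)
resolution = torch.tensor([[1024.0, 1024.0]] * 2)  # (N, 2): height, width
aspect_ratio = torch.tensor([[1.0]] * 2)           # (N, 1)

cond = emb(timestep, resolution, aspect_ratio, batch_size=2, hidden_dtype=torch.float32)
# torch.Size([2, 1152]): 2*384 (resolution) + 384 (aspect ratio) concatenated,
# then added to the 1152-d timestep embedding
print(cond.shape)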
class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
Reparameterization: to reuse the pretrained DiT weights without the class condition c that DiT training requires, the per-block class-conditioned adaLN MLP is replaced by one global MLP plus a block-specific trainable embedding; the embeddings are initialized so that they reproduce the original values at t = 500.
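A rough, hypothetical sketch of what this saves in parameters (assuming 28 blocks and dim = 1152 as in PixArt-XL; the real DiT adaLN also consumes the class embedding, and the t = 500 initialization is omitted here):

import torch
import torch.nn as nn

dim, num_blocks = 1152, 28

# DiT-style adaLN: every block owns an MLP mapping the (timestep + class) embedding to 6*dim.
dit_adaln = nn.ModuleList(
    [nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True)) for _ in range(num_blocks)]
)

# adaLN-single: one shared MLP on the timestep embedding only, plus a small
# trainable (6, dim) table per block (the scale_shift_table shown earlier).
shared_mlp = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
block_tables = nn.ParameterList(
    [nn.Parameter(torch.randn(6, dim) / dim**0.5) for _ in range(num_blocks)]
)

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print(n_params(dit_adaln))                        # ~223M for per-block adaLN MLPs
print(n_params(shared_mlp) + n_params(block_tables))  # ~8.2M for shared MLP + per-block tables

The block-specific tables correspond to the scale_shift_table parameter shown above; the shared MLP corresponds to AdaLayerNormSingle.linear.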
Original article: https://blog.csdn.net/zhilaizhiwang/article/details/145215320