
A Code-Level, Plain-Language Guide to Latent Diffusion Models (LDM)

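1. Introduction: The AI Image Generation Revolution

Suppose you type in a sentence such as "a giant panda floating in space, hugging the Earth," and within seconds an AI produces exactly that picture. This seemingly magical ability comes from a breakthrough technique in artificial intelligence from recent years: the diffusion model.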


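Next, we will start from scratch and break down, step by step, the principles, pipeline, and implementation details of the latent diffusion model, so you can see what this technology is capable of and why it has become a core engine of today's generative AI. Ready? Let's begin.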


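2. What Is a Diffusion Model?

Imagine that one day you walk into an exhibition where every painting on the wall was generated by AI, each so detailed and lifelike that you can hardly tell it from human work. A question comes to mind: where do these images come from? What happens on the way from random noise to such a creative picture? The answer lies in today's topic: the diffusion model.

At its core, a diffusion model consists of two processes, "destruction" and "creation." Picture an impulsive painter who tears a painting to pieces, and a careful restorer who puts it back together bit by bit, even adding details that were never there. A diffusion model follows a similar process to generate high-resolution images from pure random noise.

Two questions will guide us into the world of diffusion models:

  1. How do we turn an image into random noise, step by step? (forward diffusion)
  2. How do we go the other way and create a clear image from random noise? (reverse diffusion)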

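2.1 Destroy First: From Image to Noise

Today's cameras produce sharp, high-resolution photos. But what happens if we keep adding noise to one until the picture is completely destroyed? It ends up as "snow" on the screen, that is, white noise. This is the first stage of a diffusion model: forward diffusion.

Let's walk through the gradual destruction with a concrete scenario:

  1. You start with a clear, bright landscape photo;
  2. You deliberately "splash some ink" (noise) on it at a small ratio (0.1), and the image blurs slightly;
  3. You add more noise (0.5), and the background becomes almost unrecognizable;
  4. Finally, at 100% noise, the image turns into a completely random, disordered field.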


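Visual effect

The output images go from clear content to gradually blurrier, and finally to complete destruction. You can see each stage clearly:

  • 0% noise → the original image, perfectly sharp;
  • 25% noise → colors are still rich, but details start to disappear;
  • 70% noise → the content is barely recognizable, and most regions are covered in grainy texture;
  • 100% noise → the picture is nothing but white noise made of random specks.

This raises a natural question: what can we possibly do with a field of meaningless noise?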


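2.2 Then Restore: From Noise to Image

So the painter has destroyed the painting. Is there no way to repair it? The diffusion model says: "Not so fast. From these fragments we can recover the original, and even make it more artistic." This brings us to the second stage, reverse diffusion: going from pure random noise back to a clear, attractive image.

It is like being handed the frame of a jigsaw puzzle: although the pieces are scattered at random, you can rely on learned patterns to put them back in place one by one. What makes reverse diffusion interesting is that while it restores content, it is also quietly "creating something new." How does it do that?

Imagine the following restoration attempt:

  1. Start from pure white noise with no content at all (the initial state);
  2. In the first step, a little noise is removed and the outline of an object faintly appears;
  3. In the second step, more noise is removed and the subject becomes sharper;
  4. In the end, a complete scene emerges, possibly more vivid than any photo you have seen.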


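Visual effect

This is the most striking part of the whole diffusion model. You will see:

  • 100% noise → random and disordered, with noise covering the entire picture;
  • Early denoising → blurry lines begin to appear;
  • Mid-way denoising → the main subject becomes better defined while the background is still blurry;
  • Fully restored → a complete, sharp image appears.

In this way, a diffusion model can not only restore "images that already existed" but also generate brand-new high-resolution images from purely random initial noise.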


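2.3 The Logic Behind Diffusion

Let's review the process: from image to noise, then from noise back to image. The core of a diffusion model is "repair after destruction," and it consists of two tightly linked stages:

  1. Forward diffusion (destruction): keep adding random noise so that the image gradually degrades and finally becomes pure noise;
  2. Reverse diffusion (generation): start from random noise and gradually recover the main content, or even generate entirely new content.

The following simple Python code illustrates the basic principle.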

  1. First, the forward diffusion process:
import numpy as np

def forward_diffusion(image, timesteps=1000):
    # Noise schedule: beta_t is the variance of the noise added at step t.
    # It grows linearly from 0.0001 to 0.02, so every single step is small.
    betas = np.linspace(0.0001, 0.02, timesteps)

    # Keep the image at every timestep
    noisy_images = []
    current_image = image.copy()

    # Add noise step by step
    for t in range(timesteps):
        # Draw standard Gaussian noise
        noise = np.random.normal(size=image.shape)

        # Mix the noise into the image:
        # x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * noise
        current_image = np.sqrt(1 - betas[t]) * current_image + np.sqrt(betas[t]) * noise
        noisy_images.append(current_image)

    return noisy_images
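  2. Then, the reverse diffusion process:
def reverse_diffusion(noise, model, timesteps=1000):
    # Use the same noise schedule as the forward process
    betas = np.linspace(0.0001, 0.02, timesteps)
    alphas = 1.0 - betas
    # Cumulative product: how much of the original signal is left at step t
    alphas_cumprod = np.cumprod(alphas)

    # Start from pure noise
    current_image = noise.copy()

    # Remove noise step by step
    for t in reversed(range(timesteps)):
        # Use the trained model to predict the noise in the current image
        predicted_noise = model.predict_noise(current_image, t)

        # DDPM posterior mean:
        # x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * predicted_noise) / sqrt(alpha_t)
        current_image = (
            current_image - betas[t] / np.sqrt(1 - alphas_cumprod[t]) * predicted_noise
        ) / np.sqrt(alphas[t])

        # Every step except the last adds back a little fresh noise (variance beta_t)
        if t > 0:
            current_image += np.sqrt(betas[t]) * np.random.normal(size=current_image.shape)

    return current_image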

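Key points:

# In practice the model is much more complex, usually a U-Net
class DiffusionModel:
    def predict_noise(self, noisy_image, timestep):
        # The heart of the model: predict the noise contained in the image
        # at the given timestep. A real implementation uses a deep neural network.
        raise NotImplementedError

    def train(self, image_dataset):
        # Training procedure:
        # 1. Add known random noise to the original images
        # 2. Have the model learn to predict that known noise
        # 3. Optimize the model parameters over many iterations
        raise NotImplementedError

Trained on a large amount of data, the model learns to "guess" and remove noise, and can eventually generate new images that follow the training distribution, starting from random noise.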
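The training procedure outlined above (add known noise, then learn to predict it) can be sketched in a few lines. This is a minimal illustration in plain Python rather than the article's own training code: the helper names `make_alpha_bar`, `noisy_sample`, and `noise_prediction_loss` are hypothetical, the "image" is a four-value list, and the "model" is a stand-in that always predicts zero. It relies on the standard closed form that lets us jump directly to any timestep instead of looping through all earlier ones.

```python
import math
import random

def make_alpha_bar(timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linear schedule."""
    alpha_bar, running = [], 1.0
    for t in range(timesteps):
        beta = beta_start + (beta_end - beta_start) * t / (timesteps - 1)
        running *= 1.0 - beta
        alpha_bar.append(running)
    return alpha_bar

def noisy_sample(x0, t, alpha_bar, rng):
    """Jump straight to step t: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    a = math.sqrt(alpha_bar[t])
    b = math.sqrt(1.0 - alpha_bar[t])
    x_t = [a * x + b * e for x, e in zip(x0, eps)]
    return x_t, eps

def noise_prediction_loss(predicted_eps, true_eps):
    """Mean squared error between predicted and true noise (the training objective)."""
    return sum((p - e) ** 2 for p, e in zip(predicted_eps, true_eps)) / len(true_eps)

rng = random.Random(0)
alpha_bar = make_alpha_bar()
x0 = [0.5, -0.25, 1.0, 0.0]           # a toy 4-pixel "image"
t = rng.randrange(len(alpha_bar))      # training picks a random timestep per sample
x_t, eps = noisy_sample(x0, t, alpha_bar, rng)

# Stand-in "model" that always predicts zero noise; a real network would be
# trained to drive this loss down.
loss = noise_prediction_loss([0.0] * len(x0), eps)
print(f"t={t}, alpha_bar_t={alpha_bar[t]:.4f}, loss={loss:.4f}")
```

Because the noise is known at training time, the loss needs no labels beyond the images themselves, which is what makes this objective practical on large datasets.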


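2.4 The Magic of Diffusion Models

If we had to sum up diffusion models in one sentence, it would be: the art of creating order out of disorder. By controlling noise precisely, a diffusion model can both restore images that once existed and generate entirely new content. This is why diffusion models are used for image generation, restoration, editing, artistic creation, and more, giving AI more flexible and powerful capabilities.

This part covered the basic logic of diffusion models. Next, we turn to the optimization behind the latent diffusion model (LDM) and see how it speeds up this process and cuts its computational cost dramatically.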


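3. Why Do We Need Latent Diffusion Models (LDM)?

  • Limitations of diffusion models:
    • High-resolution images make the amount of computation explode.
    • Training directly in pixel space is inefficient.
  • The core innovation of LDM:
    • Introduce a latent space and compress images into a low-dimensional representation.
  • An analogy for the latent space: going from a "high-resolution map" to a "simple sketch."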
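To put rough numbers on the points above, here is a small back-of-the-envelope calculation. It assumes a 512×512 RGB image and an autoencoder that downsamples by a factor of 8 into 4 latent channels, which matches the 4×64×64 latents used in the code later in this article; the function names are illustrative only.

```python
def tensor_size(height, width, channels):
    """Number of values the denoiser has to process for one image."""
    return height * width * channels

def latent_shape(height, width, downsample_factor=8, latent_channels=4):
    """Shape of the latent produced by an autoencoder with the given downsampling factor."""
    return height // downsample_factor, width // downsample_factor, latent_channels

pixel_h, pixel_w = 512, 512
lat_h, lat_w, lat_c = latent_shape(pixel_h, pixel_w)

pixel_values = tensor_size(pixel_h, pixel_w, 3)       # RGB image
latent_values = tensor_size(lat_h, lat_w, lat_c)

# Self-attention compares every spatial position with every other one,
# so its cost grows with the square of the number of positions.
pixel_pairs = (pixel_h * pixel_w) ** 2
latent_pairs = (lat_h * lat_w) ** 2

print(f"values per image: {pixel_values} vs {latent_values} "
      f"({pixel_values // latent_values}x fewer in latent space)")
print(f"attention pairs : {pixel_pairs // latent_pairs}x fewer in latent space")
```

Under these assumptions the denoiser handles 48 times fewer values per image, and any self-attention layer compares 4096 times fewer position pairs.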

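4. The Complete Workflow: From Input to Generation

  • Step by step:
    1. Map the high-resolution image into the latent space.
    2. Add noise in the latent space and use the diffusion model to remove it.
    3. Decode the latent representation to produce a high-resolution image.
  • Compare the efficiency of a plain pixel-space diffusion model with that of LDM.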
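The three steps above can be traced end to end with a toy example. Everything here is a deliberately simplified stand-in, not the real LDM: the "encoder" is block averaging over a 1-D list, the "decoder" repeats values, and the noise predictor is an oracle that cheats by knowing the clean latent (a trained UNet would take its place). The reverse loop uses a deterministic DDIM-style update, chosen here only because it makes the result easy to check.

```python
import math
import random

def encode(image, factor=2):
    """Toy encoder: average each block of `factor` neighbouring values."""
    return [sum(image[i:i + factor]) / factor for i in range(0, len(image), factor)]

def decode(latent, factor=2):
    """Toy decoder: repeat each latent value `factor` times."""
    return [value for value in latent for _ in range(factor)]

def make_alpha_bar(steps=1000, beta_start=1e-4, beta_end=0.02):
    out, running = [], 1.0
    for t in range(steps):
        running *= 1.0 - (beta_start + (beta_end - beta_start) * t / (steps - 1))
        out.append(running)
    return out

def denoise(x_t, predict_noise, alpha_bar):
    """Deterministic reverse loop driven by a noise predictor."""
    x = list(x_t)
    for t in reversed(range(len(alpha_bar))):
        a_t = alpha_bar[t]
        a_prev = alpha_bar[t - 1] if t > 0 else 1.0
        eps = predict_noise(x, t)
        # Estimate the clean latent, then re-noise it to the previous noise level
        x0_pred = [(xi - math.sqrt(1 - a_t) * ei) / math.sqrt(a_t) for xi, ei in zip(x, eps)]
        x = [math.sqrt(a_prev) * x0i + math.sqrt(1 - a_prev) * ei for x0i, ei in zip(x0_pred, eps)]
    return x

image = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.5, 0.5]
rng = random.Random(0)
alpha_bar = make_alpha_bar()

# Step 1: image -> latent
z0 = encode(image)

# Step 2a: noise the latent all the way to the last timestep
T = len(alpha_bar) - 1
noise = [rng.gauss(0.0, 1.0) for _ in z0]
z_T = [math.sqrt(alpha_bar[T]) * z + math.sqrt(1 - alpha_bar[T]) * n for z, n in zip(z0, noise)]

# Step 2b: denoise in latent space (oracle predictor, for illustration only)
def oracle(x, t):
    return [(xi - math.sqrt(alpha_bar[t]) * z) / math.sqrt(1 - alpha_bar[t]) for xi, z in zip(x, z0)]

z_rec = denoise(z_T, oracle, alpha_bar)

# Step 3: latent -> image
reconstruction = decode(z_rec)
print([round(v, 3) for v in reconstruction])
```

Note that the reconstruction matches the decoded latent, not the original image: whatever detail the encoder throws away cannot be brought back by the diffusion model, which is why the quality of the autoencoder matters.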


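5. The Core Architecture of LDM

  • The LDM architecture breaks down into three modules:
    1. Autoencoder: how is an image compressed into a latent representation?
    2. U-Net diffusion model: adds and removes noise in the latent space.
    3. Conditioning encoder: how do conditions such as text or scene descriptions steer the generated result?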
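The question in the third module, how text steers the result, comes down to cross-attention: each position in the latent asks the text tokens which of them is relevant and mixes in their information accordingly. The sketch below shows only that mechanism on tiny hand-picked vectors in plain Python; in a real model the queries, keys, and values come from learned linear projections, which are omitted here.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def cross_attention(queries, keys, values):
    """Each query (a latent position) takes a weighted average of the values (text tokens)."""
    scale = 1.0 / math.sqrt(len(keys[0]))
    outputs, all_weights = [], []
    for q in queries:
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
        weights = softmax(scores)
        mixed = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]
        outputs.append(mixed)
        all_weights.append(weights)
    return outputs, all_weights

# Two latent positions (queries) attend over three text tokens (keys/values).
queries = [[1.0, 0.0], [0.0, 1.0]]
keys = [[4.0, 0.0], [0.0, 4.0], [0.0, 0.0]]
values = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

outputs, weights = cross_attention(queries, keys, values)
for position, w in enumerate(weights):
    print(f"latent position {position} attends to tokens with weights {[round(x, 3) for x in w]}")
```

Each latent position ends up dominated by the token whose key best matches its query, which is how different regions of an image can follow different words of the prompt.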



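6. Code Implementation: Breaking Down the Key Components of LDM

6.1 The Core Components of LDM and a Complete Implementation

LDM is made up of the following parts:

  1. VAE (AutoencoderKL): compresses images into the latent space
  2. CLIP Text Encoder: processes the text prompt
  3. UNet: the core noise-prediction network
  4. Scheduler: controls the denoising process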
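As a rough sketch of how these four components are obtained in practice, the function below loads them with diffusers and transformers from a single checkpoint. The function name `load_ldm_components` is hypothetical, and the subfolder names assume the usual layout of the CompVis/stable-diffusion-v1-4 repository; calling it downloads several gigabytes of weights.

```python
def load_ldm_components(model_name="CompVis/stable-diffusion-v1-4"):
    """Load the four LDM components from one pretrained checkpoint (downloads weights)."""
    # Imported lazily so that merely defining this function needs no heavy dependencies.
    from diffusers import AutoencoderKL, PNDMScheduler, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTokenizer

    # 1. VAE: image <-> latent
    vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
    # 2. Text encoder and its tokenizer: prompt -> embeddings
    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
    # 3. UNet: noise prediction in latent space
    unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
    # 4. Scheduler: controls the denoising steps
    scheduler = PNDMScheduler.from_pretrained(model_name, subfolder="scheduler")

    return {
        "vae": vae,
        "tokenizer": tokenizer,
        "text_encoder": text_encoder,
        "unet": unet,
        "scheduler": scheduler,
    }
```

A typical call would be `components = load_ldm_components()`, after which `components["unet"]` is the network whose structure the rest of this section rebuilds by hand.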

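Let's first look at the overall structure of the noise-prediction network, written by hand and modeled on the UNet2DConditionModel from diffusers: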

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

class TimeEmbedding(nn.Module):
    """
    Timestep embedding: turns an integer timestep into a feature vector
    """
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        # Linear layers applied on top of the sinusoidal embedding
        self.linear_1 = nn.Linear(dim, dim * 4)
        self.linear_2 = nn.Linear(dim * 4, dim)

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        # Sinusoidal position embedding (built on the same device as the input)
        half_dim = self.dim // 2
        embeddings = math.log(10000.0) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=time.device) * -embeddings)
        embeddings = time[:, None].float() * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)

        # Pass through a small MLP
        embeddings = self.linear_1(embeddings)
        embeddings = F.silu(embeddings)
        embeddings = self.linear_2(embeddings)
        return embeddings

class CrossAttention(nn.Module):
    """
    Cross-attention: lets the features attend to conditioning information
    (such as text features). Without a context it acts as self-attention.
    """
    def __init__(self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8):
        super().__init__()
        inner_dim = query_dim
        context_dim = context_dim if context_dim is not None else query_dim

        self.heads = heads
        self.scale = (query_dim // heads) ** -0.5

        # Projection matrices for Q, K, V
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        context = context if context is not None else x

        # Compute Q, K, V
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)

        # Reshape to (batch, heads, tokens, head_dim) for multi-head attention
        q = q.reshape(x.shape[0], -1, self.heads, q.shape[-1] // self.heads).permute(0, 2, 1, 3)
        k = k.reshape(context.shape[0], -1, self.heads, k.shape[-1] // self.heads).permute(0, 2, 1, 3)
        v = v.reshape(context.shape[0], -1, self.heads, v.shape[-1] // self.heads).permute(0, 2, 1, 3)

        # Attention weights
        attention = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attention = attention.softmax(dim=-1)

        # Apply the attention weights and merge the heads again
        out = torch.matmul(attention, v)
        out = out.permute(0, 2, 1, 3).reshape(x.shape[0], -1, q.shape[-1] * self.heads)
        return self.to_out(out)

class ResnetBlock(nn.Module):
    """
    Residual block with timestep conditioning and an optional attention layer
    (used here as self-attention over spatial positions)
    """
    def __init__(self, in_channels: int, out_channels: int, temb_channels: int, 
                 groups: int = 32, use_attention: bool = False):
        super().__init__()
        self.norm1 = nn.GroupNorm(groups, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(groups, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # Projects the timestep embedding to the block's channel count
        self.temb_proj = nn.Linear(temb_channels, out_channels)

        # Residual connection (1x1 conv when the channel count changes)
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

        if use_attention:
            self.attn = CrossAttention(out_channels)
        else:
            self.attn = None

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        h = x
        h = self.norm1(h)
        h = F.silu(h)
        h = self.conv1(h)

        # Add the timestep embedding, broadcast over height and width
        temb = self.temb_proj(F.silu(temb))[:, :, None, None]
        h = h + temb

        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)

        # Apply attention (if enabled): flatten H x W into a token sequence and back
        if self.attn is not None:
            h = h.reshape(h.shape[0], h.shape[1], -1).transpose(1, 2)
            h = self.attn(h)
            h = h.transpose(1, 2).reshape(x.shape[0], -1, x.shape[2], x.shape[3])

        return h + self.shortcut(x)

class UNet2DConditionModel(nn.Module):
    """
    Conditional UNet: the noise-prediction network of the diffusion model.
    Simplified: the attention layers here are self-attention only, so `context`
    is accepted for interface compatibility but is not used.
    """
    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 4,
        model_channels: int = 320,
        time_embed_dim: int = 1280,
        context_dim: int = 768,
        attention_levels: Tuple[bool, ...] = (False, True, True, True),
    ):
        super().__init__()

        # Timestep embedding
        self.time_embed = TimeEmbedding(model_channels)
        self.time_embed_mlp = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        # Initial convolution
        self.input_blocks = nn.ModuleList([
            nn.Conv2d(in_channels, model_channels, kernel_size=3, padding=1)
        ])

        # Downsampling path. skip_channels records the channel count of every
        # feature map that is concatenated back in on the way up.
        current_channels = model_channels
        channel_multipliers = [1, 2, 4, 4]
        skip_channels = [model_channels]
        for level, use_attention in enumerate(attention_levels):
            # Width of this level is model_channels times its multiplier
            level_channels = model_channels * channel_multipliers[level]
            # Two residual blocks per level
            for _ in range(2):
                layers = [ResnetBlock(
                    current_channels,
                    level_channels,
                    time_embed_dim,
                    use_attention=use_attention
                )]
                current_channels = level_channels
                self.input_blocks.append(nn.ModuleList(layers))
                skip_channels.append(current_channels)

            # Downsample (except at the last level)
            if level != len(attention_levels) - 1:
                self.input_blocks.append(nn.Conv2d(
                    current_channels, current_channels,
                    kernel_size=3, stride=2, padding=1
                ))
                skip_channels.append(current_channels)

        # Middle block
        self.middle_block = nn.ModuleList([
            ResnetBlock(current_channels, current_channels, time_embed_dim, use_attention=True),
            ResnetBlock(current_channels, current_channels, time_embed_dim, use_attention=False),
        ])

        # Upsampling path: three residual blocks per level, each consuming one skip
        self.output_blocks = nn.ModuleList([])
        num_levels = len(attention_levels)
        for level, use_attention in enumerate(reversed(attention_levels)):
            level_channels = model_channels * channel_multipliers[num_levels - 1 - level]
            for _ in range(3):
                layers = [ResnetBlock(
                    current_channels + skip_channels.pop(),
                    level_channels,
                    time_embed_dim,
                    use_attention=use_attention
                )]
                current_channels = level_channels
                self.output_blocks.append(nn.ModuleList(layers))

            # Upsample (except at the last level)
            if level != num_levels - 1:
                self.output_blocks.append(nn.Upsample(
                    scale_factor=2, mode='nearest'
                ))

        # Final output layers
        self.out = nn.Sequential(
            nn.GroupNorm(32, current_channels),
            nn.SiLU(),
            nn.Conv2d(current_channels, out_channels, kernel_size=3, padding=1),
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # 1. Timestep embedding
        temb = self.time_embed(timesteps)
        temb = self.time_embed_mlp(temb)

        # 2. Downsampling path (every output is kept as a skip connection)
        h = x
        hs = []
        for module in self.input_blocks:
            if isinstance(module, nn.ModuleList):
                for layer in module:
                    h = layer(h, temb)
            else:
                h = module(h)
            hs.append(h)

        # 3. Middle block
        for module in self.middle_block:
            h = module(h, temb)

        # 4. Upsampling path (concatenate the matching skip before each block)
        for module in self.output_blocks:
            if isinstance(module, nn.ModuleList):
                h = torch.cat([h, hs.pop()], dim=1)
                for layer in module:
                    h = layer(h, temb)
            else:
                h = module(h)

        # 5. Output
        return self.out(h)

# Usage example
def test_unet():
    # Create the model
    model = UNet2DConditionModel(
        in_channels=4,
        out_channels=4,
        model_channels=320,
        attention_levels=(False, True, True, True)
    )

    # Create test inputs
    batch_size = 4
    x = torch.randn(batch_size, 4, 64, 64)
    timesteps = torch.randint(0, 1000, (batch_size,))
    context = torch.randn(batch_size, 77, 768)  # text features

    # Forward pass
    output = model(x, timesteps, context)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

if __name__ == "__main__":
    test_unet()
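To make the shape bookkeeping in the UNet easier to follow, the short helper below traces the channel width and spatial resolution at each level of the downsampling path. It assumes the common convention that each level's width is `model_channels` times its multiplier and that resolution halves between levels; `unet_shape_trace` is an illustrative name, not part of any library.

```python
def unet_shape_trace(model_channels=320, channel_multipliers=(1, 2, 4, 4), latent_size=64):
    """Return (channels, resolution) for every level of the encoder half of the U-Net."""
    trace = []
    size = latent_size
    for level, mult in enumerate(channel_multipliers):
        trace.append((model_channels * mult, size))
        # Resolution halves after every level except the last
        if level != len(channel_multipliers) - 1:
            size //= 2
    return trace

for channels, size in unet_shape_trace():
    print(f"{channels:>5} channels at {size}x{size}")
```

The upsampling path visits the same levels in reverse, which is why each skip connection can be concatenated with a feature map of matching resolution.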

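6.2 Pretrained Modules vs. Modules That Need Training

6.2.1 Pretrained modules (parameters frozen)

  1. VAE (AutoencoderKL)

    • Role: compresses high-dimensional images into a low-dimensional latent space
    • Why frozen: the VAE has already been pretrained on a large number of images and compresses and reconstructs well
    • Pretrained model used: CompVis/stable-diffusion-v1-4
  2. CLIP Text Encoder

    • Role: converts text into vector representations
    • Why frozen: the CLIP model has already learned strong text understanding
    • Pretrained model used: openai/clip-vit-large-patch14

6.2.2 Modules that need training

  1. UNet
    • Role: predicts the noise; it is the core of the whole system
    • Why it needs training: it has to learn to denoise for the target domain
    • Training objective: accurately predict the noise that was added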

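6.3 Implementing the Core Modules by Hand

To make things easier to understand, we implement two key modules by hand: the PNDM scheduler and UNet2DConditionModel.

  1. The PNDM scheduler
    The scheduler controls the denoising process. Below is a simplified implementation: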
import torch
from typing import List, Optional, Tuple

class PNDMScheduler:
    """
    PNDM (Pseudo Numerical Methods for Diffusion Models) scheduler, simplified.
    It combines the diffusion update with a linear multistep (PLMS) method so
    that sampling needs far fewer steps than were used for training.
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,        # total number of training timesteps
        beta_start: float = 0.00085,            # first beta value
        beta_end: float = 0.012,                # last beta value
        beta_schedule: str = "scaled_linear",   # how the betas are spaced
        skip_prk_steps: bool = True,            # whether to skip the Runge-Kutta warm-up
    ):
        """
        Initialize the PNDM scheduler

        Args:
            num_train_timesteps: total number of timesteps used in training
            beta_start: start of the beta range
            beta_end: end of the beta range
            beta_schedule: beta schedule type ("linear" or "scaled_linear")
            skip_prk_steps: whether to skip the Runge-Kutta steps that precede PLMS
        """
        self.num_train_timesteps = num_train_timesteps
        self.skip_prk_steps = skip_prk_steps

        # 1. Build the beta sequence
        if beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps)
        elif beta_schedule == "scaled_linear":
            # Scaled-linear schedule used by Stable Diffusion:
            # linear in sqrt(beta), then squared
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps) ** 2
        else:
            raise ValueError(f"Unknown beta schedule: {beta_schedule}")

        # 2. Key quantities of the diffusion process
        # alphas: fraction of the signal kept at each single step
        self.alphas = 1.0 - self.betas
        # alphas_cumprod: cumulative product of the alphas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # 3. Coefficients of the closed-form forward process
        # sqrt_alphas_cumprod scales the clean sample
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        # sqrt_one_minus_alphas_cumprod scales the added noise
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # 4. PNDM-specific state
        self.ets: List[torch.Tensor] = []  # recent noise predictions used by PLMS
        self.counter = 0                   # step counter
        self.step_ratio = 1                # spacing between inference timesteps
        self.prk_timesteps = None          # Runge-Kutta timesteps
        self.plms_timesteps = None         # PLMS timesteps
        self.timesteps = None              # all timesteps

        # Scale of the initial noise: sampling starts from unit-variance Gaussian noise
        self.init_noise_sigma = 1.0

    def set_timesteps(self, num_inference_steps: int):
        """
        Set the timesteps used for inference

        Args:
            num_inference_steps: number of steps to use at inference time
        """
        # 1. Evenly spaced timesteps, visited from most noisy to least noisy
        self.step_ratio = self.num_train_timesteps // num_inference_steps
        timesteps = torch.arange(0, num_inference_steps) * self.step_ratio
        timesteps = timesteps.flip(0)

        # 2. PNDM-specific split of the timesteps
        if not self.skip_prk_steps:
            # First 4 steps are reserved for Runge-Kutta; note that this simplified
            # version does not implement the Runge-Kutta update itself
            self.prk_timesteps = timesteps[:4]
            self.plms_timesteps = timesteps[4:]
        else:
            # Skip Runge-Kutta and use PLMS throughout
            self.prk_timesteps = None
            self.plms_timesteps = timesteps

        self.timesteps = timesteps
        self.ets = []
        self.counter = 0

    def scale_model_input(self, sample: torch.Tensor, timestep: int) -> torch.Tensor:
        """
        Scale the model input for the given timestep (a no-op for PNDM)

        Args:
            sample: input sample
            timestep: current timestep
        """
        return sample

    def _get_prev_sample(self, sample: torch.Tensor, timestep: int, 
                        model_output: torch.Tensor) -> torch.Tensor:
        """
        Compute the sample at the previous inference timestep

        Args:
            sample: current sample
            timestep: current timestep
            model_output: (combined) noise prediction
        """
        timestep = int(timestep)
        prev_timestep = timestep - self.step_ratio

        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_prev = (
            self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else torch.tensor(1.0)
        )

        # Estimate the clean sample x_0 from the current sample and the noise
        pred_original = (sample - (1 - alpha_prod_t).sqrt() * model_output) / alpha_prod_t.sqrt()
        # Re-noise that estimate to the noise level of the previous timestep
        prev_sample = alpha_prod_prev.sqrt() * pred_original \
            + (1 - alpha_prod_prev).sqrt() * model_output
        return prev_sample

    def step(self, model_output: torch.Tensor, timestep: int, 
            sample: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Perform one PLMS sampling step

        Args:
            model_output: noise predicted by the model
            timestep: current timestep
            sample: current sample

        Returns:
            prev_sample: sample at the previous timestep
            model_output: the current model output
        """
        # 1. Remember this prediction; PLMS uses at most the last four
        self.ets.append(model_output)
        self.ets = self.ets[-4:]

        # 2. Combine the stored predictions with Adams-Bashforth coefficients,
        #    using lower-order formulas until four predictions are available
        if len(self.ets) == 1:
            combined = self.ets[-1]
        elif len(self.ets) == 2:
            combined = (3 * self.ets[-1] - self.ets[-2]) / 2
        elif len(self.ets) == 3:
            combined = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
        else:
            combined = (55 * self.ets[-1] - 59 * self.ets[-2]
                        + 37 * self.ets[-3] - 9 * self.ets[-4]) / 24

        # 3. Move to the previous timestep using the combined prediction
        prev_sample = self._get_prev_sample(sample, timestep, combined)

        # 4. Update the counter
        self.counter += 1

        return prev_sample, model_output


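# Usage example:

# Create a scheduler instance
scheduler = PNDMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear"
)

# Set the number of inference steps
scheduler.set_timesteps(num_inference_steps=50)

# Use it during generation.
# `model` is assumed to be a trained noise-prediction network, for example the
# UNet2DConditionModel from section 6.1, called as model(latents, timesteps).
latents = torch.randn(1, 4, 64, 64)  # initial random noise
latents = latents * scheduler.init_noise_sigma

# Denoising loop
with torch.no_grad():
    for t in scheduler.timesteps:
        # 1. Predict the noise (the model expects one timestep per batch element)
        noise_pred = model(latents, t.expand(latents.shape[0]))

        # 2. Compute the sample at the previous timestep
        latents, _ = scheduler.step(noise_pred, t, latents)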
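Two ingredients of the scheduler are easy to check in isolation: the scaled-linear beta schedule with its evenly spaced inference timesteps, and the fourth-order linear multistep combination that PLMS applies to recent noise predictions. The sketch below reproduces both in plain Python with scalar stand-ins for the noise tensors; the function names are illustrative, and the coefficients (55, -59, 37, -9)/24 are the standard Adams-Bashforth ones.

```python
def scaled_linear_betas(n=1000, beta_start=0.00085, beta_end=0.012):
    """Interpolate linearly between sqrt(beta_start) and sqrt(beta_end), then square."""
    lo, hi = beta_start ** 0.5, beta_end ** 0.5
    return [(lo + (hi - lo) * i / (n - 1)) ** 2 for i in range(n)]

def inference_timesteps(num_train_timesteps=1000, num_inference_steps=50):
    """Evenly spaced training timesteps, visited from most noisy to least noisy."""
    step_ratio = num_train_timesteps // num_inference_steps
    return [i * step_ratio for i in reversed(range(num_inference_steps))]

def plms_combine(history):
    """Fourth-order Adams-Bashforth combination of the last four noise
    predictions (history[-1] is the newest)."""
    e0, e1, e2, e3 = history[-1], history[-2], history[-3], history[-4]
    return (55 * e0 - 59 * e1 + 37 * e2 - 9 * e3) / 24

betas = scaled_linear_betas()
timesteps = inference_timesteps()
print(timesteps[:3], "...", timesteps[-3:])

# A constant history is returned unchanged; a rising one is extrapolated forward.
print(plms_combine([0.2, 0.2, 0.2, 0.2]))
print(plms_combine([0.1, 0.2, 0.3, 0.4]))
```

The coefficients sum to one, so the combination never changes the overall scale of the prediction; it only extrapolates the recent trend, which is what lets the sampler take larger steps.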

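  2. The conditional UNet
    The UNet is the core of the whole system and is responsible for noise prediction. Its hand-written implementation (TimeEmbedding, CrossAttention, ResnetBlock, and UNet2DConditionModel, together with the test_unet usage example) is the listing given in section 6.1.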


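7. LDM in Practice

Below is a complete diffusers-based LDM training example, covering dataset construction, the training loop, and the rest of the workflow.

7.1 The complete code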

import os
import torch
import requests
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm.auto import tqdm
from accelerate import Accelerator
import logging
import numpy as np

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ImageTextDataset(Dataset):
    """
    Custom dataset of image-text pairs
    """
    def __init__(self, data_pairs, tokenizer, image_size=512):
        """
        Args:
            data_pairs: List[Dict] of image-text pairs, in the form
                       [{"image_url": "http://...", "text": "caption"}, ...]
            tokenizer: CLIP tokenizer
            image_size: side length of the square training images
        """
        self.data_pairs = data_pairs
        self.tokenizer = tokenizer
        self.image_size = image_size

        # Image preprocessing: resize, crop to a square, scale to [-1, 1]
        self.image_transforms = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        item = self.data_pairs[idx]

        # Download and preprocess the image
        try:
            response = requests.get(item["image_url"], timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
            image = self.image_transforms(image)
        except Exception as e:
            logger.error(f"Error loading image {item['image_url']}: {e}")
            # Fall back to a random-noise image of the configured size
            image = torch.randn(3, self.image_size, self.image_size)

        # Tokenize the text
        text = item["text"]
        encoded_text = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        )

        return {
            "image": image,
            "input_ids": encoded_text.input_ids[0],
            "attention_mask": encoded_text.attention_mask[0]
        }

class LDMTrainer:
    def __init__(
        self,
        pretrained_model_name="CompVis/stable-diffusion-v1-4",
        device="cuda",
        train_batch_size=1,
        eval_batch_size=2,  # validation batch size
        gradient_accumulation_steps=4,
        mixed_precision="fp16",
        learning_rate=1e-5,
        max_train_steps=50000,
        num_warmup_steps=100,
        save_steps=1000,    # save a checkpoint every this many steps
        eval_steps=500,     # run validation every this many steps
        output_dir="./ldm_model",  # output directory
    ):
        self.pretrained_model_name = pretrained_model_name
        self.device = device
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.mixed_precision = mixed_precision
        self.learning_rate = learning_rate
        self.max_train_steps = max_train_steps
        self.num_warmup_steps = num_warmup_steps
        self.save_steps = save_steps
        self.eval_steps = eval_steps
        self.output_dir = output_dir

        # Create the output directory
        os.makedirs(output_dir, exist_ok=True)

        # Initialize the accelerator
        self.accelerator = Accelerator(
            gradient_accumulation_steps=gradient_accumulation_steps,
            mixed_precision=mixed_precision,
        )

        # Initialize the model components
        self._init_models()

    def _init_models(self):
        """Initialize all model components"""
        # 1. VAE (frozen)
        self.vae = AutoencoderKL.from_pretrained(
            self.pretrained_model_name,
            subfolder="vae"
        )
        self.vae.requires_grad_(False)

        # 2. Text encoder (frozen)
        self.tokenizer = CLIPTokenizer.from_pretrained(
            "openai/clip-vit-large-patch14"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            "openai/clip-vit-large-patch14"
        )
        self.text_encoder.requires_grad_(False)

        # 3. UNet (the only trainable component)
        self.unet = UNet2DConditionModel.from_pretrained(
            self.pretrained_model_name,
            subfolder="unet"
        )
    
    def prepare_data(self, train_pairs, val_pairs):
        """Build the training and validation data loaders."""
        train_dataset = ImageTextDataset(train_pairs, self.tokenizer)
        val_dataset = ImageTextDataset(val_pairs, self.tokenizer)
        
        self.train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=4
        )
        
        self.val_dataloader = DataLoader(
            val_dataset,
            batch_size=self.eval_batch_size,
            shuffle=False,
            num_workers=4
        )
        
    def prepare_training(self):
        """Set up the optimizer and learning-rate scheduler used for training."""
        # Optimizer: only the UNet parameters are updated
        self.optimizer = torch.optim.AdamW(
            self.unet.parameters(),
            lr=self.learning_rate
        )
        
        # Learning-rate scheduler
        self.lr_scheduler = get_scheduler(
            "cosine",
            optimizer=self.optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=self.max_train_steps,
        )
        
        # Let the accelerator wrap the model, optimizer, data loaders, and scheduler
        self.unet, self.optimizer, self.train_dataloader, self.val_dataloader, self.lr_scheduler = \
            self.accelerator.prepare(
                self.unet, self.optimizer, self.train_dataloader, 
                self.val_dataloader, self.lr_scheduler
            )
     
        # Move the frozen models to the device
        self.vae = self.vae.to(self.accelerator.device)
        self.text_encoder = self.text_encoder.to(self.accelerator.device)

    def validate(self):
        """Run one pass over the validation set and return the mean loss."""
        self.unet.eval()
        val_loss = 0
        val_steps = 0
        
        for batch in self.val_dataloader:
            with torch.no_grad():
                # 1. Encode the images into the latent space
                latents = self.vae.encode(
                    batch["image"].to(dtype=self.vae.dtype)
                ).latent_dist.sample()
                latents = latents * 0.18215  # VAE latent scaling factor
                
                # 2. Compute the text embeddings
                encoder_hidden_states = self.text_encoder(
                    batch["input_ids"].to(self.accelerator.device)
                )[0]
                
                # 3. Add noise at a random timestep per sample
                noise = torch.randn_like(latents)
                timesteps = torch.randint(
                    0, self.noise_scheduler.config.num_train_timesteps,
                    (latents.shape[0],), device=latents.device
                )
                noisy_latents = self.noise_scheduler.add_noise(
                    latents, noise, timesteps
                )
                
                # 4. Predict the noise
                noise_pred = self.unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states
                ).sample
                
                # 5. Compute the loss
                loss = torch.nn.functional.mse_loss(
                    noise_pred.float(),
                    noise.float(),
                    reduction="mean"
                )
                
                val_loss += loss.detach().item()
                val_steps += 1
                
        # max(..., 1) guards against an empty validation set
        avg_val_loss = val_loss / max(val_steps, 1)
        self.unet.train()
        return avg_val_loss

    def save_checkpoint(self, step, val_loss):
        """Save a checkpoint."""
        if self.accelerator.is_main_process:
            checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
            os.makedirs(checkpoint_dir, exist_ok=True)
            
            # Save the UNet weights
            unwrapped_unet = self.accelerator.unwrap_model(self.unet)
            unwrapped_unet.save_pretrained(os.path.join(checkpoint_dir, "unet"))
            
            # Save the optimizer and scheduler state
            torch.save({
                "step": step,
                "optimizer_state_dict": self.optimizer.state_dict(),
                "lr_scheduler_state_dict": self.lr_scheduler.state_dict(),
                "val_loss": val_loss
            }, os.path.join(checkpoint_dir, "optimizer_state.pt"))
            
            logger.info(f"Saved checkpoint at step {step} to {checkpoint_dir}")
        
    def train(self):
        """Training loop."""
        # Progress bar, shown only on the local main process
        progress_bar = tqdm(
            total=self.max_train_steps,
            disable=not self.accelerator.is_local_main_process
        )
        
        global_step = 0
        best_val_loss = float("inf")  # lowest validation loss seen so far
        
        # The outer loop restarts the data loader until max_train_steps is reached
        while global_step < self.max_train_steps:
            for batch in self.train_dataloader:
                with self.accelerator.accumulate(self.unet):
                    # 1. Encode the images into the latent space
                    with torch.no_grad():
                        latents = self.vae.encode(
                            batch["image"].to(dtype=self.vae.dtype)
                        ).latent_dist.sample()
                        latents = latents * 0.18215  # VAE latent scaling factor
                    
                    # 2. Compute the text embeddings
                    with torch.no_grad():
                        encoder_hidden_states = self.text_encoder(
                            batch["input_ids"].to(self.accelerator.device)
                        )[0]
                    
                    # 3. Add noise at a random timestep per sample
                    noise = torch.randn_like(latents)
                    timesteps = torch.randint(
                        0, self.noise_scheduler.config.num_train_timesteps,
                        (latents.shape[0],), device=latents.device
                    )
                    noisy_latents = self.noise_scheduler.add_noise(
                        latents, noise, timesteps
                    )
                    
                    # 4. Predict the noise
                    noise_pred = self.unet(
                        noisy_latents,
                        timesteps,
                        encoder_hidden_states
                    ).sample
                    
                    # 5. Compute the loss
                    loss = torch.nn.functional.mse_loss(
                        noise_pred.float(),
                        noise.float(),
                        reduction="mean"
                    )
                    
                    # 6. Backpropagate and update the weights
                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(self.unet.parameters(), 1.0)
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()
                
                # sync_gradients is True only when an optimizer step was taken,
                # so global_step counts optimizer steps, not micro-batches
                if self.accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1
                    
                    # Log the loss and learning rate every 100 steps
                    if global_step % 100 == 0:
                        logs = {
                            "loss": loss.detach().item(),
                            "lr": self.lr_scheduler.get_last_lr()[0]
                        }
                        progress_bar.set_postfix(**logs)
                    
                    # Periodic validation
                    val_loss = None
                    saved = False
                    if global_step % self.eval_steps == 0:
                        val_loss = self.validate()
                        logger.info(f"Step {global_step}: Validation Loss: {val_loss:.4f}")
                        
                        # Save a checkpoint when this is the best model so far
                        if val_loss < best_val_loss:
                            best_val_loss = val_loss
                            self.save_checkpoint(global_step, val_loss)
                            saved = True
                    
                    # Periodic checkpoint, independent of the validation result
                    if not saved and global_step % self.save_steps == 0:
                        self.save_checkpoint(global_step, val_loss)
                    
                    if global_step >= self.max_train_steps:
                        break
        
        # Save the final model
        if self.accelerator.is_main_process:
            self.unet = self.accelerator.unwrap_model(self.unet)
            self.save_model(self.output_dir)
            
    def save_model(self, output_dir=None):
        """Save the UNet; defaults to the trainer's output_dir."""
        output_dir = output_dir or self.output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.unet.save_pretrained(os.path.join(output_dir, "unet"))
        logger.info(f"Model saved to {output_dir}")

# Usage example
if __name__ == "__main__":
    # Prepare the training and validation data
    train_pairs = [
        {
            "image_url": "https://example.com/train_image1.jpg",
            "text": "a cute cat playing on the grass"
        },
        # ... more training data
    ]
    
    val_pairs = [
        {
            "image_url": "https://example.com/val_image1.jpg",
            "text": "a beach landscape at sunset"
        },
        # ... more validation data
    ]
    
    # Initialize the trainer
    trainer = LDMTrainer(
        train_batch_size=1,
        eval_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=1e-5,
        max_train_steps=10000,
        save_steps=1000,
        eval_steps=500,
        output_dir="./ldm_model"
    )
    
    # Prepare the data
    trainer.prepare_data(train_pairs, val_pairs)
    
    # Prepare the optimizer and scheduler
    trainer.prepare_training()
    
    # Start training
    trainer.train()

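To use this code, you need to:

  1. Prepare the dataset
data_pairs = [
    {
        "image_url": "URL of an actual image",
        "text": "the matching caption"
    },
    # ... more image-text pairs
]
  2. Install the dependencies
pip install diffusers transformers accelerate torch torchvision pillow requests tqdm
  3. Tune the hyperparameters
  • train_batch_size: adjust to fit the available GPU memory
  • gradient_accumulation_steps: if GPU memory is tight, lower train_batch_size and raise this value to keep the same effective batch size
  • learning_rate: adjust for the task at hand
  • max_train_steps: number of training steps (the loop counts optimizer steps, so one step covers all accumulated micro-batches)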
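
The learning rate does not stay at `learning_rate` for the whole run: the training code asks for a "cosine" schedule with `num_warmup_steps` of warmup. As a rough illustration, and assuming the commonly used shape (linear warmup from zero, then half a cosine wave down to zero), the multiplier applied to the base learning rate can be sketched in plain Python. The function name `cosine_with_warmup` is made up for this sketch and is not a library API.

```python
import math


def cosine_with_warmup(step: int, warmup_steps: int, total_steps: int) -> float:
    """Multiplier applied to the base learning rate at a given optimizer step.

    Assumed shape: linear warmup from 0 to 1, then half a cosine wave from 1 to 0.
    """
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


if __name__ == "__main__":
    base_lr = 1e-5              # learning_rate from the usage example
    warmup, total = 100, 10000  # num_warmup_steps and max_train_steps
    for step in (0, 50, 100, 5050, 10000):
        lr = base_lr * cosine_with_warmup(step, warmup, total)
        print(f"step {step:>5}: lr = {lr:.2e}")
```

One practical consequence: if you shorten `max_train_steps`, the decay gets steeper, so the learning rate and the step count are best tuned together.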

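This implementation covers the complete training workflow, and you can modify and extend it to suit your needs. Remember to monitor GPU memory usage during training and adjust the batch size and gradient accumulation steps accordingly.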
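
When trading batch size against gradient accumulation, the number to keep an eye on is the effective batch size per optimizer step. Here is a small sketch of that arithmetic; the helper names are hypothetical, and `num_processes` (the number of GPUs or processes) is an assumption that is not a parameter of the trainer above.

```python
def effective_batch_size(train_batch_size: int,
                         gradient_accumulation_steps: int,
                         num_processes: int = 1) -> int:
    """Number of samples that contribute to a single optimizer step."""
    return train_batch_size * gradient_accumulation_steps * num_processes


def samples_seen(max_train_steps: int, batch: int) -> int:
    """Total samples processed when max_train_steps counts optimizer steps."""
    return max_train_steps * batch


if __name__ == "__main__":
    # Settings from the usage example: batch size 1, accumulate over 4 micro-batches
    batch = effective_batch_size(train_batch_size=1, gradient_accumulation_steps=4)
    print("effective batch size:", batch)
    print("samples seen in 10000 steps:", samples_seen(10000, batch))

    # Same effective batch size, but it needs roughly 4x the activation memory per step
    print("alternative:", effective_batch_size(train_batch_size=4,
                                               gradient_accumulation_steps=1))
```

Two configurations with the same effective batch size should behave similarly during optimization, so when you run out of memory you can usually halve `train_batch_size` and double `gradient_accumulation_steps` without touching the learning rate.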


8. Full LDM inference with the model we just trained

8.1 Inference code:

import torch
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    LMSDiscreteScheduler,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler
)
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
import numpy as np
import logging
from typing import Optional, Union, List, Dict
import os
from tqdm.auto import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class CustomLDMInference:
    def __init__(
        self,
        checkpoint_dir: str,
        base_model_path: str = "runwayml/stable-diffusion-v1-5",
        scheduler_type: str = "dpm++",
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.float16,
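    ):
        """
        Initialize the custom LDM inference class.
        
        Args:
            checkpoint_dir: directory of the trained checkpoint (must contain a "unet" subfolder)
            base_model_path: base model that provides the VAE, text encoder, and tokenizer;
                ideally the same base model the UNet was fine-tuned from
            scheduler_type: scheduler type ("dpm++", "euler", or "lms")
            device: device to run on
            torch_dtype: data type of the model weights
        """
        self.device = device
        self.torch_dtype = torch_dtype
        
        # Load the model components
        self._load_models(checkpoint_dir, base_model_path)
        
        # Set up the scheduler
        self._setup_scheduler(scheduler_type)
        
        # Move the models to the target device
        self._to_device()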

    def _load_models(self, checkpoint_dir: str, base_model_path: str):
        """Load the model components."""
        try:
            # Load the VAE from the base model
            self.vae = AutoencoderKL.from_pretrained(
                base_model_path,
                subfolder="vae",
                torch_dtype=self.torch_dtype
            )
            
            # Load the text encoder from the base model
            self.text_encoder = CLIPTextModel.from_pretrained(
                base_model_path,
                subfolder="text_encoder",
                torch_dtype=self.torch_dtype
            )
            
            # Load the tokenizer from the base model
            self.tokenizer = CLIPTokenizer.from_pretrained(
                base_model_path,
                subfolder="tokenizer"
            )
            
            # Load the fine-tuned UNet from the training checkpoint
            unet_path = os.path.join(checkpoint_dir, "unet")
            self.unet = UNet2DConditionModel.from_pretrained(
                unet_path,
                torch_dtype=self.torch_dtype
            )
            
            logger.info("Successfully loaded all model components")
            
        except Exception as e:
            logger.error(f"Error loading models: {str(e)}")
            raise

    def _setup_scheduler(self, scheduler_type: str):
        """Set up the sampling scheduler."""
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.00085,
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear"
        }
        
        if scheduler_type == "dpm++":
            self.scheduler = DPMSolverMultistepScheduler.from_config(
                config,
                algorithm_type="dpmsolver++",
                solver_order=2
            )
        elif scheduler_type == "euler":
            self.scheduler = EulerDiscreteScheduler.from_config(config)
        elif scheduler_type == "lms":
            self.scheduler = LMSDiscreteScheduler.from_config(config)
        else:
            raise ValueError(f"Unsupported scheduler type: {scheduler_type}")

    def _to_device(self):
        """Move the models to the target device."""
        self.vae = self.vae.to(self.device)
        self.text_encoder = self.text_encoder.to(self.device)
        self.unet = self.unet.to(self.device)

    @torch.no_grad()
    def generate(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]] = "",
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        num_images_per_prompt: int = 1,
        seed: Optional[int] = None,
        output_path: Optional[str] = None,
    ) -> List[Image.Image]:
        """
        Generate images.

        Args:
            prompt: positive prompt(s)
            negative_prompt: negative prompt(s); a single string is shared by all prompts
            height: output image height in pixels
            width: output image width in pixels
            num_inference_steps: number of denoising steps
            guidance_scale: classifier-free guidance scale
            num_images_per_prompt: number of images generated per prompt
            seed: random seed
            output_path: directory to save the images to
        """
        # Set the random seed
        if seed is not None:
            torch.manual_seed(seed)
            
        # Make sure both prompts are lists
        if isinstance(prompt, str):
            prompt = [prompt]
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]

        # A single negative prompt is reused for every prompt so that the
        # conditional and unconditional batches have the same size
        if len(negative_prompt) == 1:
            negative_prompt = negative_prompt * len(prompt)
        if len(negative_prompt) != len(prompt):
            raise ValueError("negative_prompt must be one string or match prompt in length")

        # Total number of images to generate
        batch_size = len(prompt) * num_images_per_prompt

        # Encode the positive prompts
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        )
        text_embeddings = self.text_encoder(
            text_inputs.input_ids.to(self.device)
        )[0]

        # Encode the negative prompts
        uncond_input = self.tokenizer(
            negative_prompt,
            padding="max_length",
            max_length=text_inputs.input_ids.shape[-1],
            truncation=True,
            return_tensors="pt"
        )
        uncond_embeddings = self.text_encoder(
            uncond_input.input_ids.to(self.device)
        )[0]

        # Tile the embeddings to match batch_size (order: p0, p1, ..., p0, p1, ...)
        text_embeddings = text_embeddings.repeat(num_images_per_prompt, 1, 1)
        uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 1, 1)

        # Concatenate unconditional and conditional embeddings
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # Sample the initial noise in latent space (8x smaller than the image)
        latents = torch.randn(
            (batch_size, 4, height // 8, width // 8),
            device=self.device,
            dtype=self.torch_dtype
        )

        # Configure the scheduler and scale the noise to its starting sigma
        self.scheduler.set_timesteps(num_inference_steps)
        latents = latents * self.scheduler.init_noise_sigma

        # Denoising loop
        for t in tqdm(self.scheduler.timesteps):
            # Duplicate the latents: one copy per branch of classifier-free guidance
            latent_model_input = torch.cat([latents] * 2)
            # Required by sigma-based schedulers such as Euler and LMS
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            
            # Predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings
            ).sample

            # Apply classifier-free guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            # Step to the previous (less noisy) latents
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Decode the latents back into images
        latents = 1 / 0.18215 * latents
        images = self.vae.decode(latents).sample
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()

        # Convert to PIL images
        pil_images = []
        for image_array in images:
            image = Image.fromarray((image_array * 255).round().astype("uint8"))
            pil_images.append(image)

        # Save the images
        if output_path:
            os.makedirs(output_path, exist_ok=True)
            for i, image in enumerate(pil_images):
                image.save(os.path.join(output_path, f"generated_{i}.png"))

        return pil_images

    def encode_prompt(self, prompt: str) -> torch.Tensor:
        """Encode a prompt into text embeddings."""
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        )
        return self.text_encoder(text_inputs.input_ids.to(self.device))[0]

# Usage example
if __name__ == "__main__":
    # Initialize the inference wrapper
    ldm = CustomLDMInference(
        checkpoint_dir="./ldm_model/checkpoint-10000",  # path to the trained checkpoint
        scheduler_type="dpm++",
        device="cuda",
        torch_dtype=torch.float16
    )
    
    # Generate images
    images = ldm.generate(
        prompt="a cute cat playing on the grass",
        negative_prompt="blurry, low quality",
        num_inference_steps=30,
        guidance_scale=7.5,
        num_images_per_prompt=4,
        seed=42,
        output_path="./outputs"
    )
    
    print(f"Successfully generated {len(images)} images")
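
The guidance step inside the denoising loop is easy to miss, so here is the same formula on toy numbers. This is only an illustration in plain Python: the function name `apply_guidance` and the three-element "predictions" are made up, while the formula itself is the one used in `generate`.

```python
from typing import List


def apply_guidance(noise_uncond: List[float],
                   noise_text: List[float],
                   guidance_scale: float) -> List[float]:
    """Compute uncond + scale * (text - uncond), element by element."""
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]


if __name__ == "__main__":
    uncond = [0.0, 1.0, -1.0]  # toy prediction for the negative (or empty) prompt
    text = [1.0, 1.0, 0.0]     # toy prediction for the positive prompt
    for scale in (0.0, 1.0, 7.5):
        print(f"scale={scale}:", apply_guidance(uncond, text, scale))
```

With a scale of 0 the positive prompt is ignored, with a scale of 1 you get the plain text-conditioned prediction, and larger values push the result further away from the negative prompt's prediction. Where the two predictions agree (the middle element), the scale has no effect.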




# Examples of different usage patterns

# Initialize the inference wrapper
ldm = CustomLDMInference(
    checkpoint_dir="path/to/your/checkpoint",  # path to the training checkpoint
    base_model_path="runwayml/stable-diffusion-v1-5",  # base model
    scheduler_type="dpm++"
)

# Generate a single image
images = ldm.generate(
    prompt="your prompt",
    negative_prompt="what to avoid",
    num_inference_steps=30,
    seed=42
)

# Batch generation
images = ldm.generate(
    prompt=["prompt 1", "prompt 2"],
    num_images_per_prompt=2,
    output_path="./outputs"
)
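
With batch generation it helps to know which prompt produced which file. Because `generate` tiles the embeddings with `repeat(num_images_per_prompt, 1, 1)`, the outputs come back in the order prompt 1, prompt 2, prompt 1, prompt 2, and so on. The sketch below reproduces that mapping without running the model; `prompt_for_image` and `label_outputs` are hypothetical helpers, and the mapping holds only as long as the tiling in `generate` stays as written.

```python
from typing import List, Tuple


def prompt_for_image(index: int, prompts: List[str]) -> str:
    """Prompt behind image `index`, assuming the tiled order p0, p1, ..., p0, p1, ..."""
    return prompts[index % len(prompts)]


def label_outputs(prompts: List[str],
                  num_images_per_prompt: int) -> List[Tuple[str, str]]:
    """Pair each saved file name with the prompt that produced it."""
    total = len(prompts) * num_images_per_prompt
    return [(f"generated_{i}.png", prompt_for_image(i, prompts)) for i in range(total)]


if __name__ == "__main__":
    for filename, prompt in label_outputs(["prompt 1", "prompt 2"], 2):
        print(filename, "<-", prompt)
```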

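Notes:

  1. Make sure checkpoint_dir points to the correct training checkpoint directory
  2. Check that enough GPU memory is available
  3. Adjust the generation parameters (number of steps, guidance scale, etc.) as needed
  4. For batch generation, keep an eye on memory usage

This implementation is built for custom-trained models: it loads the fine-tuned UNet from the checkpoint, takes the VAE, text encoder, and tokenizer from the base model, and can run inference directly on the output of the training code above.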
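
If a long prompt list does not fit in GPU memory at once, one simple option is to split it into smaller chunks and call `generate` once per chunk. A minimal sketch of the chunking, with a hypothetical helper name `chunked`:

```python
from typing import Iterator, List


def chunked(items: List[str], chunk_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of `items` with at most `chunk_size` elements."""
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    for start in range(0, len(items), chunk_size):
        yield items[start:start + chunk_size]


if __name__ == "__main__":
    prompts = ["prompt 1", "prompt 2", "prompt 3", "prompt 4", "prompt 5"]
    for n, chunk in enumerate(chunked(prompts, 2)):
        # Each chunk could be passed on as ldm.generate(prompt=chunk, ...)
        print(f"chunk {n}:", chunk)
```

Keep in mind that `generate` names its files `generated_{i}.png` starting from 0 on every call, so give each chunk its own `output_path` to avoid overwriting earlier results.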


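Over to you: share your wildest ideas! 🚀

By now, do you have a deeper understanding of how latent diffusion models (LDM) work and where they are used? Whether you are new to AI or a seasoned expert, we hope this article sparked some ideas! If you found it interesting and useful, don't forget to like 👍 and bookmark 📂 it so you can find it again next time!

We would also love to hear your thoughts in the comments:

  1. What do you find most appealing about LDM? Has it really changed how you think about image generation?
  2. If you could type one prompt for the AI to turn into an image, what would you most like to generate? (Let your imagination run wild; "a pancake breakfast on Mars" is totally fine 🤯)
  3. Any technical questions left after reading this article? From beginner to hardcore, we are happy to answer them all!

AI is changing the world fast, and now is a great time to explore its possibilities! We look forward to chatting with you in the comments and imagining the future of technology and creativity together. Let the "diffusion" in the comment section begin! 🎉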


Original article: https://blog.csdn.net/duyuan6949/article/details/145117713
