自学内容网 自学内容网

【大模型】通俗解读变分自编码器VAE

目录

写在前面

一、VAE结构

二、损失函数

三、代码实现

1.训练代码

2.推理生成图片

3.插值编辑图片

四、总结


写在前面

        论文地址:https://arxiv.org/abs/1312.6114

        大模型已经有了突破性的进展,图文的生成质量都越来越高,可控性也越来越强。很多阅读大模型源码的小伙伴会发现,大部分大模型,尤其是CV模型都会用到一个子模型:变分自编码器(VAE),这篇文章就以图像生成为例介绍一下VAE,并且解释它问什么天生适用于图像生成。配合代码尽量做到通俗易懂。

        变分自编码器(VAE)是一种生成模型,旨在通过学习数据的潜在表示(Latent)来生成新样本。VAE 的训练目标是最大化变分下界,这意味着在学习潜在空间时,保持生成样本与真实数据的相似性,并尽量让潜在变量的分布接近标准正态分布。这样一来,模型就能有效地生成多样化的新图像。

        上面那段话似乎不容易理解,我用白话解释一遍。VAE 的最大作用是尽量简单的生成“能看的”图片。现在达到的效果是输入一段标准高斯分布的Latent,就能生成自然连贯的图像。而且生成的图像有如下三个特点:        

1.这个图像是全新的(也许跟某些训练数据相似);

2.通过编辑Latent可以一定程度上控制生成图像中的内容;

3.Latent空间中的结构化使得生成的图像自然且连贯,也就是说输入虽然是随机的,但输出是“能看的”,不是无意义的图像。

一、VAE结构

        VAE由如下三块组成:

        1.编码器(Encoder):输入数据通过编码器转换为潜在空间的分布。编码器通常由几层神经网络组成,输出潜在变量的均值和方差(其实是对数方差)。

        2.重参数化层(Reparameterize):从编码器输出的均值和方差中进行重参数化采样,生成潜在变量。这一过程使得模型能够在训练时进行反向传播。

        3.解码器(Decoder):解码器接收潜在变量并将其转换回原始数据的分布。解码器同样由神经网络组成,目的是重构输入数据。

        可以看到和AE相比,VAE的结构差别主要集中在编码器和潜在空间的处理。编码器有两个输出均值和方差(其实是对数方差);中间的重参数化层根据均值和方差重采样得到Latent,我们一般管他叫做z。

        下面我们使用MNIST数据集模拟一个VAE的结构,编码器和解码器使用最简单的全连接,Hidden维度400,Latent维度20,batch_size=128。

        可以看到,编码器的输出是两个128x20的特征图,用于重参数化;重参数化的输出是128x20,也就是每一个点都根据对应的均值和方差采样得来。

二、损失函数

        (VAE)的损失函数主要由两部分组成:

        1.重构损失(Reconstruction Loss):衡量模型生成的样本与原始输入之间的差异,通常使用均方误差(MSE)或二元交叉熵(Binary Cross-Entropy)作为度量。这部分确保生成的样本尽量忠实于输入数据。

        2.KL散度(Kullback-Leibler Divergence):衡量编码器输出的潜在分布与先验分布(通常是标准正态分布)之间的差异。目标是使得 q(z|x)逼近标准正态分布N(0,1),使得采样变得更加合理。

        重构损失没什么可说的,下面给出KL散度的公式:

D_{KL}(q(z|x)||p(z))=-0.5\cdot (1+log(\sigma ^2)-\mu ^2-\sigma ^2)

        KL散度代码实现:在代码实现的时候编码器的输出其实是均值mu和对数方差log_var,这一点在上图也能看出来:

KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        其中log_var 对应对数方差log(\sigma^2),使用对数方差的形式可以保证数值稳定性、避免负值以及计算便利性,这种做法在许多深度学习模型中都得到了广泛应用,尤其是在处理概率分布时。;mu 是均值\mu\sigma^2=exp(log(\sigma^2)),在代码中就是log_var.exp()。

        KL散度会在下一篇文章详细介绍,这里到此为止。

三、代码实现

1.训练代码

        下面是训练的全部代码,很简单,没什么可说的,重点是重参数化层和损失函数中的KL散度。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# 定义 VAE 模型
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2)  # 输出均值和对数方差
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # 输出为 [0, 1]
        )

    def encode(self, x):
        """
        编码器
        :param x:
        :return:
        """
        h = self.encoder(x)
        mu, log_var = h.chunk(2, dim=-1)
        return mu, log_var

    @staticmethod
    def reparameterize(mu, log_var):
        """
        重参数化
        :param mu:
        :param log_var:
        :return:
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """
        解码器
        :param z:
        :return:
        """
        return self.decoder(z)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var


def loss_function(recon_x, x, mu, log_var):
    """
    重构损失和 KL 散度
    :param recon_x:
    :param x:
    :param mu:
    :param log_var:
    :return:
    """
    BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD


def train(model, train_loader, optimizer, epoch):
    """
    训练模型
    :param model:
    :param train_loader:
    :param optimizer:
    :param epoch:
    :return:
    """
    model.train()
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(data.size(0), -1)  # 展平输入
        optimizer.zero_grad()
        recon_batch, mu, log_var = model(data)
        loss = loss_function(recon_batch, data, mu, log_var)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}]: Loss: {loss.item()}')


# 超参数
input_dim = 28 * 28  # MNIST
hidden_dim = 400
latent_dim = 20
batch_size = 128
learning_rate = 1e-3
num_epochs = 200

# 数据加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))  # 展平
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 初始化模型和优化器
model = VAE(input_dim, hidden_dim, latent_dim)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
for epoch in range(1, num_epochs + 1):
    train(model, train_loader, optimizer, epoch)
    if epoch % 20 == 0:
        # 保存模型
        torch.save(model.state_dict(), 'model_data/vae_mnist_{}.pth'.format(epoch))

2.推理生成图片

        下面是推理代码,理论上一个训练好的解码器,只需要标准高斯分布的随机噪声作为输入即可。我们来试一下,只使用解码器,输入是标准高斯分布的采样数据,输出是数字图片。

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np


# 定义 VAE 模型
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # 输出为 [0, 1]
        )

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        pass


def ran_demo():
    with torch.no_grad():
        z = torch.randn(64, latent_dim).to(device)  # 随机采样
        sample = model.decode(z).cpu()

    # 绘制生成的样本
    fig, axes = plt.subplots(8, 8, figsize=(8, 8))
    for i in range(64):
        axes[i // 8, i % 8].imshow(sample[i].view(28, 28), cmap='gray')
        axes[i // 8, i % 8].axis('off')
    plt.show()



if __name__ == '__main__':
    # 超参数
    input_dim = 28 * 28  # MNIST
    hidden_dim = 400
    latent_dim = 20
    # hidden_dim = 1024
    # latent_dim = 128
    batch_size = 128
    learning_rate = 1e-3
    num_epochs = 500
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 初始化模型和优化器
    model = VAE()

    # 加载模型并生成图像
    model.load_state_dict(torch.load('model_data/vae_mnist_1000.pth', map_location=torch.device('cpu')))
    # model.load_state_dict(torch.load('model_data/vae_mnist_200.pth', map_location=torch.device('cpu')))
    model.eval()

    # 随机输入
    ran_demo()

         输出结果如下:大部分是能看出来的数字的。毕竟只是一个简单的demo,就不要在意细节了。(#^.^#)

3.插值编辑图片

        下面玩一个有意思的,既然不同的Latent分布控制着不同的图像特征,那么我们试试把一个数字的Latent通过插值慢慢混入另一个数字的Latent,看看会发生什么。我们在数字6的Latent中慢慢混入7.

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np


# 定义 VAE 模型
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2)  # 输出均值和对数方差
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # 输出为 [0, 1]
        )

    def encode(self, x):
        h = self.encoder(x)
        mu, log_var = h.chunk(2, dim=-1)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        print(eps)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var


def interpolate_demo(from_num, to_num):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1))  # 展平
    ])
    dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

    def interpolate(z1, z2, num_steps=10):
        return [(1 - alpha) * z1 + alpha * z2 for alpha in np.linspace(0, 1, num_steps)]

    # 找到数字“1”和“7”的潜在向量
    def get_latent_vector(digit):
        model.eval()
        with torch.no_grad():
            for data, labels in data_loader:
                if labels[0] == digit:
                    data = data.to(device)
                    mu, log_var = model.encode(data.view(-1, input_dim))
                    return mu.mean(0).cpu().numpy()  # 返回均值作为潜在向量
    # 获取两个数字的向量
    latent_1 = get_latent_vector(from_num)
    latent_7 = get_latent_vector(to_num)
    # 计算插值向量
    interpolated_latents = interpolate(latent_1, latent_7)
    # 使用解码器生成图像
    with torch.no_grad():
        generated_images = [model.decode(torch.tensor(latent).float().to(device)).view(28, 28).cpu().numpy() for latent
                            in interpolated_latents]

    # 可视化生成的图像
    fig, axs = plt.subplots(1, len(generated_images), figsize=(15, 3))
    for i, img in enumerate(generated_images):
        axs[i].imshow(img, cmap='gray')
        axs[i].axis('off')
    plt.show()


if __name__ == '__main__':
    # 超参数
    input_dim = 28 * 28  # MNIST
    hidden_dim = 400
    latent_dim = 20
    # hidden_dim = 1024
    # latent_dim = 128
    batch_size = 128
    learning_rate = 1e-3
    num_epochs = 500
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 初始化模型和优化器
    model = VAE()

    # 加载模型并生成图像
    model.load_state_dict(torch.load('model_data/vae_mnist_200.pth', map_location=torch.device('cpu')))
    model.eval()

    # 插值demo
    interpolate_demo(6, 7)

         可以看到数字6慢慢变成了数字7,中间的几张图既有6的特征又有7的特征。通过控制Latent确实可以控制输出图像的特征。那是不是也可以把一个人的脸慢慢变成另一个人的脸呢,我感觉可以试试。

四、总结

        1.与AE模型相比,VAE主要有两处修改:

        (1)编码器输出均值和方差(对数方差),经过重参数化层重采样后得到Latent,再进行解码;

        (2)损失函数加入了KL散度,衡量编码器输出的Latent分布与先验分布(通常是标准正态分布)之间的差异,同时起到正则化的目的,使码器输出的Latent分布尽量符合标准高斯分布。

        2.为什么VAE适合用在生成任务?

        (1)容易生成的“能看的”图像:解码器只需接受标准高斯分布的采样数据就能生成自然连贯的图像,这意味着我们不再为生成的图像过于抽象而烦恼;

        (2)生成图像的属性可以编辑:图像的各种属性特征都蕴含在Latent里,只要找到方法对齐并组合这些特征,我们就能控制输出图像的内容,比如:长着牛头的企鹅。这就是为什么当今很多生成模型吧VAE作为一个模块来使用,同时还需要配合其它模型来完成特定的生成任务,这点今天不做过多讨论。

        总之VAE极大推动了生成任务,是很有研究价值的,小伙伴们快玩起来吧。

        VAE就介绍到这,关注不迷路(*^__^*) 

  关注订阅号了解更多精品文章

交流探讨请加微信


原文地址:https://blog.csdn.net/xian0710830114/article/details/142487069

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!