自学内容网 自学内容网

第34周:生成对抗网络(GAN)入门

目录

前言

一、理论基础

1.1 生成器

1.2 判别器

1.3 基本原理

二、前期准备工作

2.1 定义超参数

2.2 下载数据

2.3 配置数据

三、定义模型

3.1 定义鉴别器

3.2 定义生成器

四、训练模型

4.1 创建实例

4.2 训练模型

4.3 保存模型

总结


前言

说在前面

本周任务:基础任务——了解什么是生成对抗网络,生成对抗网络结果是怎么样的,学习本文代码并跑通代码;进阶任务——调用训练好的模型生成新图像

我的环境:Python3.6、Pycharm2020、TensorFlow2.4.0


一、理论基础

        生成对抗网络(Generative Adversarial Networks, GAN)是近年来深度学习领域的一个热点方向,GAN并不指代某一个具体的神经网络,而是指一类基于博弈思想而设计的神经网络。GAN由两个分别被称为生成器和判别器的神经网络组成。其中生成器从某种噪声分布中随机采样作为输入,输出与训练集中真实样本非常相似的人工样本;判别器的输入则是真实样本或人工样本,其目的是将人工样本与真实样本尽可能地区分出来。生成器和判别器交替运行,相互博弈,各自的能力都得到提升。理想情况下,经过足够次数的博弈之后,判别器无法判断给定样本的真实性,即对于所有样本1都输出50%真、50%假的判断。此时,生成器输出的人工样本以及逼真到使判别器无法分辨真假,停止博弈。这样就可以得到一个具有“伪造”真实样本能力的生成器。

1.1 生成器

        GANs中,生成器G选取随机噪声z作为输入,通过生成器的不断拟合,最终输出一个和真实样本尺寸相同,分布相似的伪造样本G(z)。生成器的本质是一个使用生成式方法的模型,他对数据的分布假设和分布参数进行学习,然后根据学习到的模型重新采样出新的样本。

        从数学上来说,生成式方法对于给定的真实数据,首先需要对数据的显式变量或隐含变量做分布假设,然后再将真实数据输入到模型中对变量、参数进行训练;最后得到一个学习后的近似分布,这个分布可以用来生成新的数据。从机器学习的角度来说,模型不会去做分布假设,而是通过不断地学习真实数据,对模型进行修正,最后也可以得到一个学习后的模型来做样本生成任务,这种方法不同于数学方法,学习的过程对人类的理解较不直观。

1.2 判别器

        GANs中,判别器D对于输入的样本x,输出一个[0,1]之间的概率数值D(x)。x可能是来自于原始数据集中的真实样本x,也可能是来自于生成器G的人工样本G(z)。通常约定,概率值D(x)越接近于1就代表此样本为真实样本的可能性更大;反之概率值越小则此样本为伪造样本的可能性越大。也就是说,这里的判别器是一个二分类的神经网络分类器,目的不是判定输入数据的原始类别,而是区分输入样本的真伪。可以注意到,不管在生成器还是判别器中,样本的类别信息都没有用到,也表明GAN是一个无监督学习的过程。

1.3 基本原理

        GAN是博弈论和机器学习相结合的产物,于2014年lan Goodfellow的论文中问世,一经问世即火爆足以看出人们对于这种算法的认可和狂热的研究热忱。想要更详细的了解GAN,就要知道它是怎么来的,以及这种算法出现的意义是什么。研究者最初想要通过计算机完成自动生成数据的功能,例如通过训练某种算法模型,让某模型学习过一些苹果的图片后能够自动生成苹果的图片,具备这些功能的算法即任务具有生成功能。但是GAN不是第一个生成算法,而是以往的生成算法在衡量生成图片和真实图片的差距时,采用均方误差作为损失函数,但是研究者发现有时均方误差一样的两张生成图片效果却截然不同,鉴于此不足lan Goodfellow提出来GAN。

      那么GAN是如何完成生成图片这项功能的呢,如上图所示,GAN是由两个模型组成的:生成模型G和判别模型D。首先第一代生成模型1G的输入是随机噪声z,然后生成模型会生成一张初级照片,训练一代判别器模型1D令其进行二分类操作,将生成的图片判别为0,而真实图片判别为1;为了期满一代鉴别器,于是一代生成模型开始优化,然后它进阶成了二代,当它生成的数据成功欺瞒1D时,鉴别模型也会优化更新,进而升级为2D,按照同样的过程也会不断更新出N代的G和D。

二、前期准备工作

2.1 定义超参数

  • n_epochs:这个参数决定了模型训练的总轮数,轮数越多,模型有更多机会学习数据中的模型,但也可能导致过拟合。
  • batch_size:批次大小影响模型每次更新时使用的数据量。较小批次可能导致训练过程波动较大,但可能有助于模型逃离局部最小值;较大的批次则可能使训练更稳定,但需要更多的内存空间。
  • lr:学习率控制着模型权重更新的步长,学习率过大可能导致模型在最优解附近震荡甚至发散;学习率过小则可能导致模型收敛速度缓慢或陷入局部最小值。
  • b1和b2:这两个参数是Adam优化器的一部分,分别控制一阶矩(梯度的指数移动平均)和二阶矩(梯度平方的指数移动平均)的指数衰减率。它们影响模型更新的稳定性和收敛速度。
  • n_cpu:这个参数指定了用于数据加载的CPU数量,可以影响数据预处理和加载的速度,从而影响训练的效率。
  • latent_dim:随机向量的维度,它影响生成器生成图像的多样性和质量,维度过低可能导致生成图像缺乏多样性,而维度过高可能导致模型难以训练。
  • img_size:图像的大小直接影响模型的感受野和所需计算资源,图像尺寸越大,模型可能需要更多的计算资源和更长的训练时间。
  • channels:图像的通道数,对于彩色图像通常是3(RGB),对于灰度图像是1,通道数影响模型处理的信息量。
  • sample_interval:保存生成图像的间隔,这个参数决定了我们在训练过程中多久保存一次生成的图像,用于监控生成图像的质量。
  • cuda:是否使用GPU进行计算,使用GPU可以显著加速模型的训练过程,因为GPU在并行处理大量计算时更为高效。

代码如下:

import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch

# 一、前期准备工作
# 1.1 定义超参数
# 创建文件夹
os.makedirs("./images", exist_ok=True)  # 记录训练过程的图片效果
os.makedirs("./save", exist_ok=True)  # 训练完成时模型保存的位置
os.makedirs("./datasets/minist", exist_ok=True)  # 下载数据集存放的位置

# 超参数配置
n_epochs = 50
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
n_cpu = 2
latent_dim = 100
img_size = 28
channels = 1
sample_interval = 500

# 图像的尺寸:(1,28,28), 和图像的像素面积:(784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)
#print(img_area)
# 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
print(cuda)

2.2 下载数据

代码如下:

# 1.2下载数据
mnist = datasets.MNIST(
    root='./datasets', train=True, download=True, transform=transforms.Compose(
        [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),
)

2.3 配置数据

代码如下:

# 1.3 配置数据
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

三、定义模型

3.1 定义鉴别器

这段代码定义了一个名为Discriminator的类,它继承自nn.Module。这个类是一个判别器模型,用于判断输入图像是否为真实图像。下面是对代码中每一行的详细解释:

  • class Discriminator(nn.Module):名为Discriminator的类,它继承自nn.Module。nn.Module是Pytorch中的一个基类,用于构建神经网络模型。
  • def __init__(self):定义类的构造函数,用于初始化模型的参数和层。
  • super(Discriminator,self).__init__():调用父类nn.Module的构造函数,以确保正确地初始化模型。
  • self.model = nn.Sequential():创建一个nn.Sequential对象,它是一个容器,用于按顺序堆叠的多个神经网络层。 
  • nn.Linear(img_area,512):添加一个线性层,输入大小为img_area(图像区域的像素数),输出大小为512。这个层用于将输入图像展平并映射到一个新的特征空间。
  • nn.LeakyReLu(0.2, inplace=True):添加一个Leaky ReLu激活函数,其负斜率为0.2,inplace=True表示在原始数据上进行操作,以节省内存。
  • nn.Linear(512,256):添加一个线性层,输入大小为512,输出大小为256.这个层用于进一步将特征映射到更小的特征空间。
  • nn.LeakyReLu(0.2, inplace=True):再添加一个Leaky ReLu激活函数。
  • nn.Linear(512,256):添加一个线性层,输入大小为256,输出大小为1。这个层用于将特征映射到一个标量值,用于表示输入图像的真实性。
  • nn.Sigmoid():添加一个Sigmoid激活函数,将输出值限制在0到1之间,这可以解释为输入图像为真实图像的概率。
  • def forward(self, img):定义模型的前向传播函数,用于计算输入图像的输出。
  • img_flat=img.view(img.size(0),-1):将输入图像img展平为一个一维向量,img.size(0)表示批量大小,-1表示自动计算剩余维度的大小。
  • validity = self.model(img_flat):将展平后的图像传递给之前定义的nn.Sequential模型,得到一个表示图像真实性的标量值。
  • return validity:返回计算得到的图像真实性值。

代码如下:

# 二、定义模型
# 2.1 定义鉴别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        #print("img_flat:",img_flat.size())
        validity = self.model(img_flat)
        return validity

3.2 定义生成器

代码如下:

# 2.2 定义鉴别器
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # 模型中间块儿
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # prod():返回给定轴上的数组元素的乘积:1*28*28=784
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, img_area),
            nn.Tanh()
        )

    def forward(self, z):
        imgs = self.model(z)
        imgs = imgs.view(imgs.size(0), *img_shape)
        return imgs

四、训练模型

4.1 创建实例

代码如下:

# 三、训练模型
# 3.1 创建实例
# 创建生成器,判别器对象
generator = Generator()
discriminator = Discriminator()

criterion = torch.nn.BCELoss()  # loss函数,二分类的交叉熵
# 定义优化函数(优化函数的学习率为0.0003)
# betas:用于计算梯度以及梯度平方的运行平均值的系数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion = criterion.cuda()

4.2 训练模型

代码如下:

# 3.2 训练模型
for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # =========训练判别器=========
        # view():相当于numpy中的reshape,重新定义矩阵形状,相当于reshape(128,784),原来是(128,1,28,28)
        imgs = imgs.view(imgs.size(0), -1)
        # 将tensor变成Variable放入计算图中,tensor变成variable之后才能进行反向传播求梯度
        real_img = Variable(imgs).cuda()
        #print("imgs:",imgs.size())
        #print("real_imgs:", real_img.size())
        # 定义真实图片label为1,假的图片为0
        real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()
        fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()
        # ----------------------------------------
        # Train Discriminator
        # 分为两部分:1、真的图像判别为真,假的图像判别为家
        # ----------------------------------------
        real_out = discriminator(real_img)  # 将真实图片放入到判别器中
        loss_real_D = criterion(real_out, real_label)  # 得到真实图片的loss
        real_scores = real_out  # 得到真实图片的判别值,输出的值越接近1越好
        # 计算假的图片的损失
        # !!!detach():从当前计算图中分离下来避免梯度传递到G,因为G不用更新
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out, fake_label)
        fake_scores = fake_out
        # 损失函数和优化
        loss_D = loss_real_D + loss_fake_D  # 损失包括判真损失和判假损失
        optimizer_D.zero_grad()  # 在反向传播之前,先将梯度归0
        loss_D.backward()
        optimizer_D.step()
        # ----------------------------------------
        # Train Generator
        # 原理:目的是希望生成的假的图片被判别器判断为真的图片
        # 在此过程中将判别器固定,将假的图片传入判别器的结果与真实的label对应
        # 下面的反向传播更新的参数是生成网络里面的参数
        # 这样可以通过更新G的参数来训练网络,使得生成的图片让判别器以为是真的以达到对抗目的
        # ----------------------------------------
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()
        fake_img = generator(z)
        output = discriminator(fake_img)
        # 损失函数和优化
        loss_G = criterion(output, real_label)
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # 打印训练过程中的日志
        # item():取出单元素张量的元素值并返回该值,保持原元素类型不变
        if (i + 1) % 300 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(),
                   fake_scores.data.mean())
            )

        # 保存训练过程中都的图像
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)

训练过程如下:

[Epoch 0/50] [Batch 299/938] [D loss: 1.012438] [G loss: 1.301834] [D real: 0.699438] [D fake: 0.474318]
[Epoch 0/50] [Batch 599/938] [D loss: 1.008257] [G loss: 0.819791] [D real: 0.527896] [D fake: 0.264144]
[Epoch 0/50] [Batch 899/938] [D loss: 1.153856] [G loss: 1.398770] [D real: 0.655207] [D fake: 0.486927]
[Epoch 1/50] [Batch 299/938] [D loss: 0.967525] [G loss: 1.812901] [D real: 0.731250] [D fake: 0.448682]
[Epoch 1/50] [Batch 599/938] [D loss: 1.224676] [G loss: 0.697846] [D real: 0.430522] [D fake: 0.137575]
[Epoch 1/50] [Batch 899/938] [D loss: 0.903827] [G loss: 2.397436] [D real: 0.797271] [D fake: 0.469886]
[Epoch 2/50] [Batch 299/938] [D loss: 0.867260] [G loss: 1.746334] [D real: 0.717585] [D fake: 0.352441]
[Epoch 2/50] [Batch 599/938] [D loss: 0.988614] [G loss: 1.898680] [D real: 0.717136] [D fake: 0.422032]
[Epoch 2/50] [Batch 899/938] [D loss: 0.857187] [G loss: 2.397475] [D real: 0.823763] [D fake: 0.459199]
[Epoch 3/50] [Batch 299/938] [D loss: 1.252721] [G loss: 3.494035] [D real: 0.878717] [D fake: 0.645408]
[Epoch 3/50] [Batch 599/938] [D loss: 1.053324] [G loss: 3.200003] [D real: 0.906355] [D fake: 0.586196]
[Epoch 3/50] [Batch 899/938] [D loss: 0.888583] [G loss: 1.750376] [D real: 0.685471] [D fake: 0.263957]
[Epoch 4/50] [Batch 299/938] [D loss: 0.626337] [G loss: 1.632535] [D real: 0.868374] [D fake: 0.339547]
[Epoch 4/50] [Batch 599/938] [D loss: 0.760465] [G loss: 3.926685] [D real: 0.896429] [D fake: 0.438739]
[Epoch 4/50] [Batch 899/938] [D loss: 1.112041] [G loss: 3.600408] [D real: 0.895314] [D fake: 0.588828]
[Epoch 5/50] [Batch 299/938] [D loss: 0.760230] [G loss: 3.093684] [D real: 0.819037] [D fake: 0.373181]
[Epoch 5/50] [Batch 599/938] [D loss: 0.655095] [G loss: 4.203346] [D real: 0.936321] [D fake: 0.412457]
[Epoch 5/50] [Batch 899/938] [D loss: 0.868438] [G loss: 1.858977] [D real: 0.775516] [D fake: 0.376645]
... ...
[Epoch 47/50] [Batch 299/938] [D loss: 0.633114] [G loss: 1.890121] [D real: 0.769088] [D fake: 0.172500]
[Epoch 47/50] [Batch 599/938] [D loss: 0.605113] [G loss: 2.451130] [D real: 0.830360] [D fake: 0.274797]
[Epoch 47/50] [Batch 899/938] [D loss: 0.738247] [G loss: 2.078587] [D real: 0.817741] [D fake: 0.342395]
[Epoch 48/50] [Batch 299/938] [D loss: 0.791022] [G loss: 1.390561] [D real: 0.742090] [D fake: 0.247390]
[Epoch 48/50] [Batch 599/938] [D loss: 0.643103] [G loss: 2.504039] [D real: 0.859238] [D fake: 0.314199]
[Epoch 48/50] [Batch 899/938] [D loss: 0.853483] [G loss: 1.335287] [D real: 0.684199] [D fake: 0.110085]
[Epoch 49/50] [Batch 299/938] [D loss: 0.718962] [G loss: 2.062187] [D real: 0.786376] [D fake: 0.297801]
[Epoch 49/50] [Batch 599/938] [D loss: 0.704778] [G loss: 2.398799] [D real: 0.837675] [D fake: 0.317781]
[Epoch 49/50] [Batch 899/938] [D loss: 0.450702] [G loss: 1.835477] [D real: 0.787361] [D fake: 0.083240]

训练过程中生成的图片(部分截取)

➡️➡️

4.3 保存模型

代码如下:

# 3.3 保存模型
torch.save(generator.state_dict(), './save/generator.pth')
torch.save(discriminator.state_dict(), './save/discriminator.pth')

总结

本周了解了一下对抗生成网络的定义和概念,同时熟悉了生成器和鉴别器的搭建以及GAN的训练过程。跑通了利用GAN进行Mnist数据集中手写字图片的生成。


原文地址:https://blog.csdn.net/weixin_46620278/article/details/144281799

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!