自学内容网 自学内容网

手撕Diffusion系列 - 第三期 - Dataset

手撕Diffusion系列 - 第三期 - Dataset

DDPM 原理图

​ DDPM包括两个过程:前向过程(forward process)反向过程(reverse process),其中前向过程又称为扩散过程(diffusion process),如下图所示。无论是前向过程还是反向过程都是一个参数化的马尔可夫链(Markov chain),其中反向过程可以用来生成图片。

在这里插入图片描述

DDPM 整体大概流程

​ 图中,由高斯随机噪声 x T x_T xT 生成原始图片 x 0 x_0 x0 为反向过程,反之为前向过程(噪音扩散)。

MNIST数据集介绍

MNIST(Modified National Institute of Standards and Technology)是一个经典的图像数据集,广泛用于训练和评估机器学习和计算机视觉模型,尤其是手写数字识别任务。该数据集包含了70000张灰度图像,其中60000张用于训练,10000张用于测试。每张图像是28x28像素,表示从0到9的手写数字。MNIST数据集的每个图像对应一个标签,表示图像中的数字类别。

主要特点

  1. 图像尺寸:每张图像大小为28x28像素,单通道灰度图。
  2. 图像标签:每个图像都有一个标签,表示图像所对应的数字(从0到9)。
  3. 数据分割:数据集分为训练集和测试集。训练集包含60000个图像,测试集包含10000个图像。
  4. 应用场景:MNIST数据集广泛用于测试图像分类算法、深度学习模型等,尤其是在手写数字的分类问题中。

数据预处理

在使用MNIST数据集时,常见的预处理步骤包括:

  • 调整图像大小:将图像调整为统一的尺寸,例如28x28像素。
  • 转换为Tensor格式:将图像转换为Tensor格式以便于神经网络处理,通常还会进行归一化(如将像素值缩放到[0, 1]范围)。
  • 数据增强:在某些应用中,可以对图像进行旋转、翻转、裁剪等操作来增加数据多样性。

Dataset 代码

Part1 引入相关库函数

# 该模块实现的是引用一个MINIST训练数据集,需要预处理

'''
# Part1 引入一些库函数
'''

import torch
from torch.utils import data # 对数据集获取后的一些操作
import torchvision # 数据集存储的地方
from torchvision import transforms # 对于数据的一些处理手段
from config import * # 一些参数

Part2 MNIST数据集的预处理以及获取

# 显示图像
import matplotlib.pyplot as plt

PiltoTensor_action=transforms.Compose([
    transforms.Resize((IMAGE_SIZE,IMAGE_SIZE)), # 改变图像的大小
    transforms.ToTensor()] # 对图像转化为Tensor类型,归一化,并且把通道提前(H,W,C)->(C,H,W)
)

minist_train=torchvision.datasets.MNIST(root='dataset',train=True,transform=PiltoTensor_action,download=True)

minist_loader=data.DataLoader(dataset=minist_train,batch_size=2,shuffle=True) # (image,label)

Part3 测试

# 为了可以简单的按照pillow去展示自己的图像,那么需要一个把tensor转回为Pillow的操作。主要分为三步。
TenosrtoPil_action=transforms.Compose([
    transforms.Lambda(lambda t:t*255),
    transforms.Lambda(lambda t:t.type(torch.uint8)),
    transforms.ToPILImage()
])

if __name__ == '__main__':
    Tensor_imag=minist_train[0][0]
    print(Tensor_imag)
    plt.figure(figsize=(5,5)) # 绘制画布
    pil_imag=TenosrtoPil_action(Tensor_imag)
    plt.imshow(pil_imag) # 绘制图像,和plot,bar等函数一样用于绘制不同的图像
    plt.show() # 展示图像

参考

视频讲解:diffusion训练数据集_哔哩哔哩_bilibili

Diffusion原理:手撕Diffusion系列 - 第一期 - DDPM原理-CSDN博客


原文地址:https://blog.csdn.net/m0_62030579/article/details/145310437

免责声明:本站文章内容转载自网络资源,如侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!