PyTorch深度学习快速入门教程【土堆】基础知识篇
Juptyer
版本:
- Python 3.9.19
- Pytorch 2.4.1
(pytorch0) C:\Users\25694>conda install nb_conda_kernels
(pytorch0) C:\Users\25694>jupyter notebook
使用conda环境的pytorch:
成功解决python.exe无法找到程序入口 无法定位程序输入点
shift+enter:运行这个代码块并跳转到下一个代码块
- 将环境写入Notebook的kernel中:
python -m ipykernel install --user --name 环境名称 --display-name "Python (环境名称)"
- 打开Jupyter notebook,新建Python文件,这时候你就能看见你的创建的环境
Python学习中的两大法宝函数
实战操作:
Python交互模式主要有两种:CPython用>>>作为提示符,而IPython用In [序号]:作为提示符。
如果你是>>>,那么可以回ana黑色窗口控制台输入conda install ipython来使其变成in
如果虚拟环境中没有安装 ipython包,那么默认就是>>>模式
如果当前显示IN[序号],想换回>>>,则在File->Setting中取消勾选下列的框,Apply->OK
下列的框在“构建、执行、部署”(我使用了Pycharm里面的汉化插件)→“控制台”→使用IPython
PyCharm及Jupyter使用及对比
PyTorch加载数据初认识
其中train是训练数据集,val是验证数据集
Dataset??
一般数据和对应label有两种形式:
- 如一个文件夹内存放多个同类的图片:文件夹的名称就是其label
- 数据和label存放在不同的文件夹内
Dataset类代码实战
通常使用pycharm的python console进行一些小的测试!可以方便地查看过程属性
(pytorch0) C:\Users\25694>pip install opencv-python
把hymenoptera_data数据集拷贝到项目文件夹中并重命名
绝对路径 ctrl+shift+c。但是注意windows下路径要使用两个斜杠,来表示转义
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset): # 创建class继承自Dataset
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
tran_dataset = ants_dataset + bees_dataset
img, label = bees_dataset[10]
img.show()
os.listdir() 是 Python 中 os 模块的一个函数,主要用于列出指定目录中的所有文件和子目录的名称。它返回一个包含该目录下所有条目(文件或文件夹)名称的列表(但不会递归子目录),但不会包含文件的完整路径,只返回名称(仅列出名称,不会指明是文件还是目录。如果需要判断某个条目是文件还是目录,可以结合 os.path.isfile() 和 os.path.isdir() 一起使用。)。参数:path:需要列出内容的目录路径。可以是相对路径或绝对路径。如果不传递 path 参数,则默认列出当前工作目录(即 .)。
Python中self用法详解
TensorBoard的使用
ctrl+点按SummaryWriter:
SummaryWriter:直接向log_dir文件夹写入事件文件,这个事件文件可以被TensorBoard解析。需要输入一个文件夹的名称,不输入的话默认文件夹为runs/CURRENT_DATETIME_HOSTNAME。
log_dir:tensorboard文件的存放路径 flush_secs:表示写入tensorboard文件的时间间隔
其他的参数当前并不重要,需要的话可以自己看看。
标量只有大小概念,没有方向的概念。通过一个具体的数值就能表达完整。比如:重量、温度、长度、提及、时间、热量等都数据标量。
安装tensorboard:
(pytorch0) C:\Users\25694>
conda list
conda search numpy
conda install numpy=1.23.1
conda install tensorboard
add_scalar()方法的使用(常用来绘制train/val loss)
def add_scalar(
self,
tag,
scalar_value,
global_step=None,
walltime=None,
new_style=False,
double_precision=False,
):
添加一个标量数据到 Summary 当中,需要参数
- tag:Data指定方式,类似于图表的title
- scalar_value:需要保存的数值(y轴)
- global_step:训练到多少步(x轴)
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
# writer.add_imgae()
# y = 2x
for i in range(100):
writer.add_scalar("y=2x", 2 * i, i)
writer.close()
如何打开生成的事件文件:
注意路径!
(pytorch0) E:\PyCharmProjects\learn_torch>tensorboard --logdir=logs --port=6007
TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.17.0 at http://localhost:6007/ (Press CTRL+C to quit)
这里不指定端口port 默认是6006
在writer中写入新事件,还有上个事件
解决方法:删除logs文件夹下的所有事件,重新运行程序,在terminal中按ctrl+c退出,再按上键打开端口
add_image()方法的使用(常用来观察训练结果)
准备:将练手数据集里的解压到项目目录下新建的data文件夹中
def add_image(self, tag, img_tensor, global_step=None):
- tag:对应图像的title
- img_tensor:图像的数据类型,只能是torch.Tensor、numpy.array、string/blobnaem
- global_step:训练步骤,int 类型
# 打开控制台,其位置就是项目文件夹所在的位置
# 故只需复制相对地址
image_path = "data/train/ants_image/0013035.jpg"
from PIL import Image
img = Image.open(image_path)
print(type(img))
PIL.格式不符合要求。
因此,利用opencv(numpy.array())读取图片,对PIL图片进行转换,活动numpy型图片数据
import numpy as np
img_array=np.array(img)
print(type(img_array)) # numpy.ndarray
在Python控制台输出图片类型:
从PIL到numpy,需要在add_image()中指定shape中每一个数字/维表示的含义
img_tensor默认的图片尺寸格式为(3,H,W),但是一般我们的图片格式为(H,W,3),因此需要对图片格式进行调整
通过print(img_array.shape)以查看img是否为C(通道)H(高度)W(宽度)的形式
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
writer = SummaryWriter("logs")
image_path = "data/train/ants_image/0013035.jpg"
img_PIL = Image.open(image_path)
img_array = np.array(img_PIL)
print(type(img_array))
print(img_array.shape) # (512, 768, 3) (H, W, C)(高度,宽度,通道)
writer.add_image("test", img_array, 1, dataformats="HWC")
# y = 2x
for i in range(100):
writer.add_scalar("y=5x", 5 * i, i)
writer.close()
在一个title下,通过滑块显示每一步的图形,可以直观地观察训练中给model提供了哪些数据,或者想对model进行测试时,可以看到每个阶段的输出结果
如果想要单独显示,重命名一下title即可,即 writer.add_image() 的第一个字符串类型的参数
Tensorforms的使用-主要是对图片进行变换
transforms是torchvision下的一个工具箱,用于格式转化,视觉处理工具,不用于文本
图片经过transforms工具的变换,得到我们想要的一个图像变换结果
解释:根据模具创造工具,使用具体工具根据说明进行输入和输出
按住ctrl,点击transforms
conda install torchvision
它里面有多个工具类:
- Compose类:结合不同的transforms
- ToTensor类:将PIL和numpy类型的图片转为Tensor(可用于训练)
- ToPILImage类:把一个图片转换成PIL Image
- Normalize类:归一化,标准化,用来对数据预处理
- Resize类:尺寸变换
- CenterCrop类:中心裁剪
- Regularize类:正则化,防止模型过拟合的技术
- RandomCrop:随机裁剪。
工具类都有__ call __()方法,具体作用看python中的 call()
在 Python 中,call 是一个特殊的方法,可以让对象像函数一样被调用。换句话说,如果一个类实现了 call 方法,那么它的实例就能像调用普通函数一样被调用。
例子:
class MyClass:
def __init__(self, value):
self.value = value
def __call__(self, x):
return self.value * x
# 创建类的实例
obj = MyClass(10)
# 调用实例,像调用函数一样
result = obj(5) # 等价于 obj.__call__(5)
print(result) # 输出 50
两个问题
python的用法 ——> tensor数据类型
通过 transforms.ToTensor去解决两个问题
- Transforms该如何使用
- Tensor数据类型与其他图片数据类型有什么区别?为什么需要Tensor数据类型
from PIL import Image
from torchvision import transforms
# 绝对路径 D:\PycharmProjects\pythonProject\pytorchlearn\data\train\ants_image\0013035.jpg
# 相对路径 data/train/ants_image/0013035.jpg
img_path="data/train/ants_image/0013035.jpg" #用相对路径,绝对路径里的\在Windows系统下会被当做转义符
# img_path_abs="C:\Users\11842\Desktop\Learn_torch\data\train\ants_image\0013035.jpg",双引号前加r表示转义
img = Image.open(img_path) #Image是Python中内置的图片的库
print(img) # PIL类型
问题一:
transforms 该如何使用(python)
从transforms中选择一个class,对它进行创建,对创建的对象传入图片,即可返回出结果
ToTensor将一个 PIL Image 或 numpy.ndarray 转换为 tensor的数据类型
# 1、Transforms该如何使用
tensor_trans = transforms.ToTensor() #从工具箱transforms里取出ToTensor类,返回tensor_trans对象
tensor_img = tensor_trans(img) #创建出tensor_trans后,传入其需要的参数,即可返回结果。返回一个tensor类型的图片
print(tensor_img)
ctrl+p提示函数参数
问题二:
为什么我们需要 Tensor 数据类型
在Python Console输入:
from PIL import Image
from torchvision import transforms
img_path= "data/train/ants_image/0013035.jpg"
img = Image.open(img_path)
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)
Tensor 数据类型包装了反向神经网络所需要的一些理论基础的参数,如:_backward_hooks、_grad等(先转换成Tensor数据类型,再训练)
下载opencv:
python版本要和opencv版本相对应,否则安装的时候会报错
pip install opencv-python==3.4.11.45
两种读取图片的方式
- PIL Image
from PIL import Image
img_path = "xxx"
img = Image.open(img_path)
img.show()
- numpy.ndarray(通过opencv)
import cv2
cv_img=cv2.imread(img_path)
上节课以 numpy.array 类型为例,这节课使用 torch.Tensor 类型:
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
# python的用法 ——> tensor数据类型
# 通过 transforms.ToTensor去解决两个问题
# 1、Transforms该如何使用
# 2、Tensor数据类型与其他图片数据类型有什么区别?为什么需要Tensor数据类型
# 绝对路径 C:\Users\11842\Desktop\Learn_torch\data\train\ants_image\0013035.jpg
# 相对路径 data/train/ants_image/0013035.jpg
img_path="data/train/ants_image/0013035.jpg" #用相对路径,绝对路径里的\在Windows系统下会被当做转义符
# img_path_abs="C:\Users\11842\Desktop\Learn_torch\data\train\ants_image\0013035.jpg",双引号前加r表示转义
img = Image.open(img_path) #Image是Python中内置的图片的库
#print(img)
writer = SummaryWriter("logs")
# 1、Transforms该如何使用
tensor_trans = transforms.ToTensor() #从工具箱transforms里取出ToTensor类,返回tensor_trans对象
tensor_img = tensor_trans(img) #创建出tensor_trans后,传入其需要的参数,即可返回结果
#print(tensor_img)
writer.add_image("Tensor_img",tensor_img) # .add_image(tag, img_tensor, global_step)
# tag即名称
# img_tensor的类型为torch.Tensor/numpy.array/string/blobname
# global_step为int类型
writer.close()
常见的transforms
图片有不同的格式,打开方式也不同
图片格式 | 打开方式 |
---|---|
PIL | Image.open() ——Python自带的图片打开方式 |
tensor | ToTensor() |
narrays | cv.imread() ——Opencv |
Compose的使用
把不同的 transforms 结合在一起,后面接一个数组,里面是不同的transforms
Example:图片首先要经过中心裁剪,再转换成Tensor数据类型
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.PILToTensor(),
>>> transforms.ConvertImageDtype(torch.float),
>>> ])
Python中 call 的用法
class Person:
def __call__(self, name):
print("call"+name)
def hello(self,name):
print("hi"+name)
person = Person()
# 如果定义了call方法可以对象名(传入参数)来调用
person("zhangsan")
person.hello("lisi")
ToTensor的使用
把 PIL Image 或 numpy.ndarray 类型转换为 tensor 类型(TensorBoard 必须是 tensor 的数据类型)(运行前要先把之前的logs进行删除)
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
writer = SummaryWriter("logs")
img = Image.open("images/pytorch.png")
print(img) # 可以看到类型是PIL
# ToTensor的使用
trans_totensor = transforms.ToTensor() # 将类型转换为tensor
img_tensor = trans_totensor(img) # img变为tensor类型后,就可以放入TensorBoard当中
writer.add_image("ToTensor", img_tensor)
writer.close()
ToPILImage 的使用
把 tensor 数据类型或 ndarray 类型转换成 PIL Image
Normalize 的使用
用平均值/标准差归一化 tensor 类型的 image(输入)
图片RGB三个信道,将每个信道中的输入进行归一化
output[channel] = (input[channel] - mean[channel]) / std[channel]
设置 mean 和 std 都为0.5,则 output= 2*input -1。如果 input 图片像素值为0~1范围内,那么结果就是 -1 ~1之间
加入step值:
第一步
#Normalize的使用
print(img_tensor[0][0][0]) # 第0层第0行第0列
trans_norm = transforms.Normalize([4,6,7],[3,2,6]) # mean,std,因为图片是RGB三信道,故传入三个数
img_norm = trans_norm(img_tensor) # 输入的类型要是tensor
print(img_norm[0][0][0])
writer.add_image("Normalize",img_norm,1)#第一步
第二步
#Normalize的使用
print(img_tensor[0][0][0]) # 第0层第0行第0列
trans_norm = transforms.Normalize([2,6,7],[1,2,2]) # mean,std,因为图片是RGB三信道,故传入三个数
img_norm = trans_norm(img_tensor) # 输入的类型要是tensor
print(img_norm[0][0][0])
writer.add_image("Normalize",img_norm,2)#第二步
Resize 的使用
输入:PIL Image 将输入转变到给定尺寸
序列:(h,w)高度,宽度
一个整数:不改变高和宽的比例,只单纯改变最小边和最长边之间的大小关系。之前图里最小的边将会匹配这个数(等比缩放)
取消首字母匹配:
一般情况下,你需要输入R,才能提示出Resize
我们想设置,即便你输入的是r,也能提示出Resize,也就是忽略了大小写进行匹配提示
File—> Settings—> 搜索case—> Editor-General-Code Completion-去掉Match case前的√—>Apply—>OK
#Resize的使用
print(img.size) # 输入是PIL.Image
trans_resize = transforms.Resize((512,512))
#img:PIL --> resize --> img_resize:PIL
img_resize = trans_resize(img) #输出还是PIL Image
#img_resize:PIL --> totensor --> img_resize:tensor(同名,覆盖)
img_resize = trans_totensor(img_resize)
writer.add_image("Resize",img_resize,0)
print(img_resize)
Compose() 中的参数需要是一个列表,Python中列表的表示形式为[数据1,数据2,…]
在Compose中,数据需要是transforms类型,所以得到Compose([transforms参数1,transforms参数2,…])
#Compose的使用(将输出类型从PIL变为tensor类型,第二种方法)
trans_resize_2 = transforms.Resize(512) # 将图片短边缩放至512,长宽比保持不变
# PIL --> resize --> PIL --> totensor --> tensor
#compose()就是把两个参数功能整合,第一个参数是改变图像大小,第二个参数是转换类型,前者的输出类型与后者的输入类型必须匹配
trans_compose = transforms.Compose([trans_resize_2,trans_totensor])
img_resize_2 = trans_compose(img) # 输入需要是PIL Image
writer.add_image("Resize",img_resize_2,1)
RandomCrop的使用
随机裁剪,输入PIL Image
参数size:
- sequence:(h,w) 高,宽
- int:裁剪一个该整数×该整数的图像
(1)以 int 为例:
#RandomCrop()的使用
trans_random = transforms.RandomCrop(512)
trans_compose_2 = transforms.Compose([trans_random,trans_totensor])
for i in range(10): #裁剪10个
img_crop = trans_compose_2(img) # 输入需要是PIL Image
writer.add_image("RandomCrop",img_crop,i)
(2)以 sequence 为例:
#RandomCrop()的使用
trans_random = transforms.RandomCrop((200,500))
trans_compose_2 = transforms.Compose([trans_random,trans_totensor])
for i in range(10): #裁剪10个
img_crop = trans_compose_2(img)
writer.add_image("RandomCropHW",img_crop,i)
touchvision中的数据集使用
需要学习知识:
- 如何把数据集(多张图片)和 transforms 结合在一起。
- 标准数据集如何下载、查看、使用。
各个模块作用
(1)torchvision.datasets
如:COCO 目标检测、语义分割;MNIST 手写文字;CIFAR 物体识别
(2)torchvision.io
输入输出模块,不常用
(3)torchvision.models
提供一些比较常见的神经网络,有的已经预训练好,比较重要,后面会使用到,如分类模型、语义分割模型、目标检测、视频分类等
(4)torchvision.ops
torchvision提供的一些比较少见的特殊的操作,基本不常用
(5)torchvision.transforms
之前讲解过
(6)torchvision.utils
提供一些常用的小工具,如TensorBoard
本节主要讲解torchvision.datasets,以及它如何跟transforms联合使用
1.数据集如何下载
#如何使用torchvision提供的标准数据集
import torchvision
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True) #root使用相对路径,会在该.py所在位置创建一个叫dataset的文件夹,同时把数据保存进去。用Ctrl加P查看需要参数。
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)
运行后,控制台中就会显示正在下载数据集
数据集下载过慢时:
获得下载链接后,把下载链接放到迅雷中,会首先下载压缩文件tar.gz,之后会对该压缩文件进行解压,里面会有相应的数据集。
采用迅雷下载完毕后,在PyCharm里新建directory,名字也叫dataset,再将下载好的压缩包复制进去,download依然为True,运行后,会自动解压该数据
注意dataset里面不要解压完,要放压缩包,然后在运行这个代码,不然他会重新下载
实际上首先下载的是下面这个压缩文件,然后会对其进行解压
2.数据集如何查看与使用
注意虽然运行代码时文件中还有上面两行,且download=True,但是会自动校验已下载,也就是说不会产生影响,所以可以放着不管
import torchvision
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)
print(test_set[0]) # 查看测试集中的第一个数据,是一个元组:(img, target)
print(test_set.classes) # 列表
img,target = test_set[0]
print(img)
print(target) # 输出:3。输出为列表第几个类别。从0开始数,这里类别为cat列表第四个
print(test_set.classes[target]) # cat
img.show()
在 PyTorch 的 torchvision 库中,CIFAR10 数据集是一个继承自 torch.utils.data.Dataset 的类。这个类实现了 __ getitem__ 方法(定义了如何获取单个样本),因此当你访问 test_set[0] 时,它实际上调用的是 __ getitem__,并返回一个元组。
在 Python 中,元组(tuple) 是一种 不可变的 序列类型,它可以存储多个元素。元组与列表(list)非常相似,不同之处在于元组一旦创建就 不能修改,而列表是可变的。元组常用于存储一组不希望被修改的数据。元组中的元素可以是不同类型的(list也可以不同),比如整数、字符串、浮点数等。
元组的创建:元组使用圆括号 () 创建,元素之间用逗号 , 分隔。
# 创建一个包含多个元素的元组
my_tuple = (1, 2, 3, "apple", 3.14)
# 创建单个元素的元组时,需要加上一个逗号
single_element_tuple = (5,)
# 不加逗号,括号会被认为是表达式
not_a_tuple = (5)
元组的用途:元组常用于函数返回多个值的场景,因为可以一次性返回多个元素。
def get_position():
return (10, 20)
position = get_position()
print(position) # 输出: (10, 20)
函数返回值的类型
- 圆括号 ():表示元组。
- 方括号 []:表示列表。
这个 __ getitem __ 方法属于一个自定义数据集类,通常用在 PyTorch 或其他类似框架中,用来在索引处返回数据集中的样本(如图像和标签)。这是实现自定义数据集的重要方法之一。以下是代码的详细解析:
index: int:该方法接受一个整数 index,表示要获取的数据样本的索引。
返回类型 Tuple[Any, Any]:它返回一个 元组 (image, target),其中:
image 是图像数据,target 是与该图像对应的类别标签。
Any 在类型注解中表示可以是任意类型,通常 image 会是图像对象,而 target 是整数表示的类别标签。
img, target = self.data[index], self.targets[index]
self.data[index]:从数据集中提取索引为 index 的图像数据。
self.targets[index]:提取与该图像对应的目标标签。target 通常是一个表示类别的整数,类似于分类任务中图像的类别编号。
转换为 PIL 图像:
img = Image.fromarray(img)
Image.fromarray(img):将图像数据从数组格式(通常是 NumPy 数组)转换为 PIL 图像对象,这一步是为了保证返回的数据与其他图像处理流程(如数据增强)兼容。PIL 是 Python Imaging Library 的简称,常用于图像处理。
应用图像变换:
if self.transform is not None:
img = self.transform(img)
如果 self.transform 不为空,则将其应用于图像 img。这是常见的图像预处理步骤,transform 通常是数据增强操作,如旋转、裁剪、归一化等。
应用目标变换:
if self.target_transform is not None:
target = self.target_transform(target)
如果 self.target_transform 不为空,则将其应用于目标 target。目标变换通常用于修改标签的格式,比如将类别标签转换为独热编码,或进行其他处理。
返回值:
return img, target
最后返回一个元组 (img, target),其中 img 是经过可能的转换后的图像,target 是与图像对应的标签。
小结:
这个 __ getitem __ 方法的作用是:
根据给定的索引,从数据集中提取图像和标签。
将图像从数组格式转换为 PIL 图像,以与其他图像数据处理兼容。
根据需要对图像和标签应用预处理(transform 和 target_transform)。
返回处理后的图像和标签,作为一个元组 (image, target)。
3.CIFAR10数据集介绍
CIFAR10 数据集包含了6万张32×32像素的彩色图片,图片有10个类别,每个类别有6千张图像,其中有5万张图像为训练图片,1万张为测试图片。
如何把数据集(多张图片)和 transforms 结合在一起
CIFAR10数据集原始图片是PIL Image,如果要给pytorch使用,需要转为tensor数据类型(转成tensor后,就可以用tensorboard了)
transforms 更多地是用在 datasets 里 transform 的选项中
import torchvision
from torch.utils.tensorboard import SummaryWriter
#把dataset_transform运用到数据集中的每一张图片,都转为tensor数据类型
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True) #root使用相对路径,会在该.py所在位置创建一个叫dataset的文件夹,同时把数据保存进去
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)
# print(test_set[0])
writer = SummaryWriter("logs")
#显示测试数据集中的前10张图片
for i in range(10):
img,target = test_set[i]
writer.add_image("test_set",img,i) # img已经转成了tensor类型
writer.close()
Dataloader的使用
Dataloader每次从dataset中取数据
参数介绍
参数如下(大部分有默认值,实际中只需要设置少量的参数即可):
- dataset:只有dataset没有默认值,只需要将之前自定义的dataset实例化,再放到dataloader中即可
- batch_size:每次抓牌抓几张
- shuffle:设置为True在每个 epoch 重新洗牌数据(默认值:False),但一般用True
- num_workers:加载数据时采用单个进程还是多个进程,多进程的话速度相对较快,默认为0(主进程加载)。Windows系统下该值>0会有问题(报错提示:BrokenPipeError)
- drop_last:100张牌每次取3张,最后会余下1张,这时剩下的这张牌是舍去还是不舍去。值为True代表舍去这张牌、不取出,False代表要取出该张牌
import torchvision
from torch.utils.data import DataLoader
#准备的测试数据集
test_data = torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor)
test_loader = DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
#测试数据集中第一张图片及target
img,target = test_data[0]
print(img.shape)
print(target)
输出结果:
torch.Size([3, 32, 32]) #三通道,32×32大小
3 #类别为3
dataset
-
__ getitem() __:return img,target
dataloader(batch_size=4):从dataset中取4个数据 -
img0,target0 = dataset[0]
-
img1,target1 = dataset[1]
-
img2,target2 = dataset[2]
-
img3,target3 = dataset[3]
把 img 0-3 进行打包,记为imgs;target 0-3 进行打包,记为targets;作为dataloader中的返回
for data in test_loader:
imgs,targets = data
print(imgs.shape)
print(targets)
输出:
torch.Size([4, 3, 32, 32]) #4张图片,三通道,32×32
tensor([0, 4, 4, 8]) #4个target进行一个打包
数据是随机取的(断点debug一下,可以看到采样器sampler是随机采样的),所以两次的 target 0 并不一样
batch_size
对于打包的图片展示,使用的方法是add_images()方法,单张图片展示使用add_image()方法
# 用上节课torchvision提供的自定义的数据集
# CIFAR10原本是PIL Image,需要转换成tensor
import torchvision.datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor())
# 加载测试集
test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=False)
#batch_size=4,意味着每次从test_data中取4个数据进行打包
writer = SummaryWriter("dataloader")
step=0
for data in test_loader:
imgs,targets = data #imgs是tensor数据类型
writer.add_images("test_data",imgs,step)
step=step+1
writer.close()
由于 drop_last 设置为 False,所以最后16张图片(没有凑齐64张)显示如下:
drop_last
若将 drop_last 设置为 True,最后16张图片(step 156)会被舍去,结果如图:
shuffle
一个 for data in test_loader 循环,就意味着打完一轮牌(抓完一轮数据),在下一轮再进行抓取时,第二次数据是否与第一次数据一样。值为True的话,会重新洗牌(一般都设置为True)
在外面再套一层 for epoch in range(2) 的循环
shuffle为False的话两轮取的图片顺序是一样的
# shuffle为True
for epoch in range(2):
step=0
for data in test_loader:
imgs,targets = data #imgs是tensor数据类型
writer.add_images("Epoch:{}".format(epoch),imgs,step)
step=step+1
原文地址:https://blog.csdn.net/m0_51448653/article/details/142286300
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!