自学内容网 自学内容网

【Pytorch】15.自定义验证照片测试自定义的CIFAR10网络

输入照片的处理

因为CIFAR10对输入照片要求的格式为tensor(1,3,32,32)而我们在网上找到的图片基本都不满足要求,所以我们需要对网络上找到的图片先进行处理

比如我们找到一个猫的照片
在这里插入图片描述

img_path = "../imgs/cat.jpeg"

img = Image.open(img_path)

# <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=358x312 at 0x1009DAA90>
# 这里我们可以看到默认的格式不是RGB格式的 ,而我们训练出的数据集只能处理三通道,所以我们需要对通道数由RGBA转化为RGB形式
print(img)

# 将图片转化为RGB格式
img = img.convert('RGB')

# <PIL.Image.Image image mode=RGB size=358x312 at 0x103002BE0>
print(img)

# 定义一个转化规则为transform,将图像转化为32x32像素,并且转化为tensor格式
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                            torchvision.transforms.ToTensor()])
img_tensor = transform(img)

# torch.Size([3, 32, 32])
print(img_tensor.shape)

img_tensor = torch.reshape(img_tensor, (1, 3, 32, 32))
# torch.Size([1, 3, 32, 32])
print(img_tensor.shape)

我们需要上面一系列的操作才能将图片的地址转化为tensor[1,3,32,32]

导入训练好的神经网络

在上一节的训练中,我们已经成功获得了训练30轮的CIFAR10神经网络,我们需要将训练好的网络加载到当前文件中,具体可以看【Pytorch】12.网络模型的加载、修改与保存

model = CIFAR10Model()
model = torch.load('../models/cifar10_model30.pth', map_location='cpu')

output = model(img_tensor)

经过这个步骤,我们就可以得到当前图片在CIFAR10数据集中10分类的哪个概率最大了,然后我们通过output.argmax(1)来获取最大概率的下标,然后根据下标来对应数据集的元素

print(output.argmax(1))

dataset = torchvision.datasets.CIFAR10(root='../dataset', train=False, download=False,
                                       transform=torchvision.transforms.ToTensor())
print(dataset.classes[output.argmax(1)])

我们可以看到输出了cat
在这里插入图片描述

完整代码

import torch
import torchvision
from PIL import Image
from CIFAR10Model import *

# 图片目录
img_path = "../imgs/dog.png"

img = Image.open(img_path)

# <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=358x312 at 0x1009DAA90>
# 这里我们可以看到默认的格式不是RGB格式的 ,而我们训练出的数据集只能处理三通道,所以我们需要对通道数由RGBA转化为RGB形式
print(img)

# 将图片转化为RGB格式
img = img.convert('RGB')

# <PIL.Image.Image image mode=RGB size=358x312 at 0x103002BE0>
print(img)

# 定义一个转化规则为transform,将图像转化为32x32像素,并且转化为tensor格式
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                            torchvision.transforms.ToTensor()])
img_tensor = transform(img)

# torch.Size([3, 32, 32])
print(img_tensor.shape)

img_tensor = torch.reshape(img_tensor, (1, 3, 32, 32))
# torch.Size([1, 3, 32, 32])
print(img_tensor.shape)

model = CIFAR10Model()
model = torch.load('../models/cifar10_model30.pth', map_location='cpu')

output = model(img_tensor)
# tensor([2])
print(output.argmax(1))

dataset = torchvision.datasets.CIFAR10(root='../dataset', train=False, download=False,
                                       transform=torchvision.transforms.ToTensor())
print(dataset.classes[output.argmax(1)])

原文地址:https://blog.csdn.net/Elephant_King/article/details/139032158

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!