简单搭建卷积神经网络实现手写数字10分类
搭建卷积神经网络实现手写数字10分类
1.思路流程
1.导入minest数据集
2.对数据进行预处理
3.构建卷积神经网络模型
4.训练模型,评估模型
5.用模型进行训练预测
一.导入minest数据集
MNIST--->raw--->test-->(0,1,2...) 10个文件夹
MNIST--->raw--->train-->(0,1,2...) 10个文件夹
共60000张图片.可自己去网上下载
二.对数据进行预处理
----读取图片,将图片先转为张量。
img=cv2.imread(path)
----将图片进行归一化,即将像素值标准化到0-1之间
img_tensor=util.transforms_train(img)
----裁剪,翻转等,实现数据增强。
数据增强:通过对原始图像进行旋转、翻转等操作,可以增加数据的多样性。这有助于模型学习到更具泛化性的特征,减少对特定方向或位置的依赖,从而提高模型的鲁棒性和准确性
transforms_train=transforms.Compose([ # transforms.CenterCrop(10), # transforms.PILToTensor(), transforms.ToTensor(),#归一化,转tensor transforms.Resize((28,28)), transforms.RandomVerticalFlip() ])
ps:为什么要归一化
-
消除量纲影响:不同图像的像素值范围可能差异很大。归一化可以将像素值范围统一到一个特定的区间,例如 [0, 1] 或 [-1, 1],消除不同图像之间因像素值范围差异带来的影响,使模型更关注图像的特征和结构,而不是像素值的绝对大小。
-
提高训练稳定性:有助于优化算法的收敛性和稳定性。如果像素值范围较大且分布不均匀,可能导致梯度计算不稳定,从而影响模型的训练效率和效果。
-
缓解过拟合:一定程度上可以减少数据中的噪声和异常值对模型的影响,降低模型对某些特定像素值的过度依赖,从而提高模型的泛化能力,减少过拟合的风险。
三.构建卷积神经网络模型
常见卷积神经网络(CNN),主要由卷积,池化,全连接组成。卷积核在输入图像上滑动,通过卷积运算提取局部特征。卷积核在整个图像上重复使用,大大减少了模型的参数数量,降低了计算复杂度,同时也增强了模型对平移不变性的鲁棒性。池化层对特征进行压缩,提取主要特征,减少噪声和冗余信息。
x=torch.randn(2,3,28,28)
用x表示初始图形的信息。为了简单理解,简单表述。其中
2--->两张图片
3--->图片的通道数是3个,即 RGB
28,28--->图片的宽高是28px 28px
采用以上的神经网络conv为卷积操作,maxpool为池化。Linear为全连接。relu为激活函数。
进入全连接层时需要将展平。torch.Size([2, 16, 5, 5])--->torch.Size([2, 400])
x=torch.flatten(x,1)
因为全连接是只进行的线性的变化。所以要把每张图片的维数参数降为1。
使用print(summary(net, x))可查看网络的层次结构。其中-1就表示自己算,是多少张图片就是多少
输入的的是x=torch.randn(2,3,28,28),最终输出的是(2,10)
四.训练模型,评估模型
需要初始化之前的数据和网络,然后选择合适的优化器和损失函数,学习率和加载图片的批次去训练模型。使用loss_avg和accurary来评估模型的性能。对于pytorch来说优化器可以实现自动梯度清0,自动更新参数。我们需要主要的是就实现其中的维度的转化。loss越小越接近真实值。其中计算精度的方法使用one-hot编码。其中0表示[0,0,0,0,0,0,0,0,0,0],1表示[0,1,0,0,0,0,0,0,0,0],2表示[0,0,1,0,0,0,0,0,0,0].。。。其他依次类推。我们把用网络得出的参数,类似[0.1,0.2,0.1,0.5,0,0,0,0,0,0](数据我随便写的),然后用Python的argmax去处最大值的索引与one-hot真实值的索引相比,如果相等就是正确的结果。
----本次实验使用的是MSE损失函数
----lr(学习率)设为0.01
----使用的优化器Adam ,其实其他优化器你也可以随便试试。
Adam 算法的主要优点包括:
-
自适应学习率:能够为每个参数自适应地调整学习率。
-
偏差校正:在初始阶段对梯度估计进行校正,加速初期的学习速率。
-
适应性强:在很多不同的模型和数据集上都表现出良好的性能。
-
实现简单,计算高效,对内存需求少。
使用tensorboard进行可视化
五.用模型进行训练预测
需要读取之前训练好的模型,然后用这个模型来实现预测一个自己手写的图片
# 加载整个模型 loaded_model = torch.load('whole_model.pth') # 保存模型参数 torch.save(loaded_model.state_dict(),'model_params.pth')
代码附上:
dataset.py
import glob import os.path import cv2 import torch import util class DataAndLabel: def __init__(self,path='D:\\0MNIST\\raw',is_train=True): super().__init__() # 拼接路径 #data里面是path,label clas='train' if is_train==True else 'test' path=os.path.join(path,clas) paths=glob.glob(os.path.join(path,'*','*')) # print(paths) # print(path) self.data=[] for path in paths: label=int(path.split('\\')[-2]) self.data.append((path,label)) def __getitem__(self, idx): #返回一个tensor,one-hot path,label =self.data[idx] img=cv2.imread(path) # cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) img_tensor=util.transforms_train(img) one_hot=torch.zeros(10) one_hot[label]=1 return img_tensor,one_hot def __len__(self): return len(self.data) # if __name__ == '__main__': # data=DataAndLabel() # print(data[0]) # print()
lenet5.py
import torch import torch.nn as nn from torchkeras import summary class Net(nn.Module): def __init__(self): super().__init__() self.conv1=nn.Conv2d(3,6,5,1) self.maxpool1=nn.MaxPool2d(2) self.conv2=nn.Conv2d(6,16,3,1) self.maxpool2=nn.MaxPool2d(2) self.layer1=nn.Linear(16*5*5,10) self.layer2=nn.Linear(10,10) self.relu=nn.Softmax() def forward(self,x): x=self.conv1(x) x=self.relu(x) x=self.maxpool1(x) x=self.conv2(x) x=self.relu(x) x=self.maxpool2(x) # print(x.shape) x=torch.flatten(x,1) # print(x.shape) x=self.layer1(x) x=self.layer2(x) return x if __name__ == '__main__': x=torch.randn(2,3,28,28) net=Net() out=net(x) # print(out.shape) # print(summary(net, x))
train_and_test
import torch import tqdm from torch.utils.data import Dataset, DataLoader from torch.utils.tensorboard import SummaryWriter from lenet5 import Net import torch.nn as nn from dataset import DataAndLabel class TrainAndTest(Dataset): def __init__(self): super().__init__() # self.writer=SummaryWriter("logs") net=Net() self.net=net self.loss=nn.MSELoss() self.opt = torch.optim.Adam(net.parameters(), lr=0.1) self.train_data=DataAndLabel(is_train=True) self.test_data=DataAndLabel(is_train=False) self.train_loader=DataLoader(self.train_data,batch_size=100,shuffle=False) self.test_loader=DataLoader(self.test_data,batch_size=100,shuffle=False) # 拿到数据,网络 def train(self,epoch): loss_sum = 0 accurary_sum = 0 for img_tensor, label in tqdm.tqdm(self.train_loader, desc='train...', total=len(self.train_loader)): out = self.net(img_tensor) loss = self.loss(out, label) self.opt.zero_grad() loss.backward() self.opt.step() loss_sum += loss.item() accurary_sum += torch.mean( torch.eq(torch.argmax(label, dim=1), torch.argmax(out, dim=1)).to(torch.float32)).item() loss_avg = loss_sum / len(self.train_loader) accurary_avg = accurary_sum / len(self.train_loader) print(f'train---->loss_avg={round(loss_avg, 3)},accurary_avg={round(accurary_avg, 3)}') # self.writer.add_scalars('loss',{'loss_avg':loss_avg},epoch) def train1(self): sum_loss = 0 sum_acc = 0 for img_tensors, targets in tqdm.tqdm(self.train_loader, desc="train...", total=len(self.train_loader)): out = self.net(img_tensors) loss = self.loss(out, targets) self.opt.zero_grad() loss.backward() self.opt.step() sum_loss += loss.item() pred_cls = torch.argmax(out, dim=1) target_cls = torch.argmax(targets, dim=1) accuracy =torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32)) sum_acc += accuracy.item() avg_loss = sum_loss / len(self.train_loader) avg_acc = sum_acc / len(self.train_loader) print(f'train:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}') def run(self): for epoch in range(10): self.train1() # self.test(epoch) if __name__ == '__main__': tt=TrainAndTest() tt.run()
util.py
from torchvision import transforms transforms_train=transforms.Compose([ # transforms.CenterCrop(10), # transforms.PILToTensor(), transforms.ToTensor(),#归一化,转tensor transforms.Resize((28,28)), transforms.RandomVerticalFlip() ]) transforms_test=transforms.Compose([ transforms.ToTensor(), # 归一化,转tensor transforms.Resize((28, 28)), ])
原文地址:https://blog.csdn.net/m0_53291740/article/details/140425056
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!