12.06 深度学习-预训练
# 使用更深的神经网络 经典神经网络
import torch
import cv2
from torchvision.models import resnet18,ResNet18_Weights
from torch import optim,nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
def demo1():
data_train=CIFAR10(root="assets",download=False,train=True,transform=transforms.Compose([transforms.ToTensor()]))
# # 获取权重
# weight=ResNet18_Weights.IMAGENET1K_V1 # 1000分类的权重文件
# net1=resnet18(weights=weight) #设置这个模型的权重
# # 把权重保存了 这里不能直接训练 因为这个net1的fc还不是 10输出
# torch.save(net1.state_dict(),"assets/model_pre.pt")
# return
# 获取模型
net1=resnet18(weights=None)
# 获取fc 的输入特征数 迁移学习 是网络结构有变化的 如果没有变化就是继续训练 就不是迁移学习
in_features=net1.fc.in_features
# 可以去改模型的层次结构 根据自己的数据来 被改的层次都要重新进行训练 不仅是fc了 被改的层次不能冻结 而且权重参数也要删掉
net1.fc=nn.Linear(in_features=in_features,out_features=10,bias=True)
net1.conv1=nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), bias=False)
# 加载预训练权重
state_dict=torch.load("assets/model_pre.pt",weights_only=True)
# 线性层 的预训练权重不需要
state_dict.pop("fc.weight")
state_dict.pop("fc.bias")
state_dict.pop("conv1.weight")
# state_dict.pop("conv1.bias")
# 更新模型的权重参数
# net1.load_state_dict(state_dict) # 会少keys 不能用
my_weight=net1.state_dict()
my_weight.update(state_dict)
net1.load_state_dict(my_weight)
# 冻结层的使用 现在这个模型net1 只有fc需要进行训练 其他层都训练好了 给其他层冻结了只训练fc层 model.parameters()返回model的每一层权重和偏置的tensor 的迭代器 可以遍历它 named_parameters多返回一个名字
# 把要冻结的层的权重和偏置的tensor的requires_grad=True 全设为Flase 在把model.parameters()给优化器 要过滤掉 requires_grad=Flase的
for name,param in net1.named_parameters():
param.requires_grad=False
for name,param in net1.named_parameters():
if name =="fc.weight" or name =="fc.bias" :
param.requires_grad=True
for name,param in net1.named_parameters():
if name =="conv1.weight" :
param.requires_grad=True
# 过滤掉 requires_grad=Flase的 权重参数
true_weight=filter(lambda p:p.requires_grad,net1.parameters())
dataLoader1=DataLoader(data_train,batch_size=16,shuffle=True)
# 循环轮次
epochs=2
# 优化器
optim1=optim.Adam(true_weight,lr=0.01)
# 损失函数
loss_func=nn.CrossEntropyLoss()
# 开始训练
for i in range(epochs):
for x_train,y_train in dataLoader1:
# 前向传播
y_pre=net1(x_train)
# 损失
loss=loss_func(y_pre,y_train)
# 清空梯度
optim1.zero_grad()
# 反向
loss.backward()
# 梯度更新
optim1.step()
torch.save(net1.state_dict(),"assets/model3.pt")
# 预训练 先用一组数据 对模型进行训练 然后在 把这个模型拿出来继续训练
# resnet18 有一个1000分类的预训练数据 这个数据拿过来改 把resnet18模型的线性层改为10分类 然后再把1000分类的预训练数据初始化给这个模型 进行再训练
# 需要注意的是 先初始化一个这个1000分类的模型 然后保存他的权重
# # # 获取权重
# weight=ResNet18_Weights.IMAGENET1K_V1
# net1=resnet18(weights=weight)
# # 把权重保存了 这里不能直接训练 因为这个net1的fc还不是 10输出
# torch.save(net1.state_dict(),"assets/model_pre.pt")
print("完成")
# 在初始化另一个来改fc
pass
def demo2(): # 用训练的模型 对图片进行分类
# 获得模型
net1=resnet18(weights=None)
in_features=net1.fc.in_features
net1.fc=nn.Linear(in_features=in_features,out_features=10,bias=True)
# 加载模型数据
net1.load_state_dict(torch.load("assets/model3.pt",weights_only=True))
# 加载图片数据 训练数据 是一个二维的 数组 RGB
img=cv2.imread("assets/qw.jpg")
img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
img=cv2.resize(img,(32,32))
# 转为tensor
img=torch.tensor(img,dtype=torch.float32)
# 换维度
img=img.permute(2,0,1)
# 升一个维度
img=img.unsqueeze(0)
# print(img.shape)
# 推理
net1.eval()
with torch.no_grad():
res=net1(img)
func=nn.Softmax()
res=func(res)
print(res)
print(torch.argmax(res,dim=1))
pass
def demo3():
# 获得模型
net1=resnet18(weights=None)
in_features=net1.fc.in_features
net1.fc=nn.Linear(in_features=in_features,out_features=10,bias=True)
# 加载模型数据
net1.load_state_dict(torch.load("assets/model3.pt",weights_only=True))
data_test=CIFAR10(root="assets",download=False,train=False,transform=transforms.Compose([transforms.ToTensor()]))
data_loader1=DataLoader(data_test,shuffle=True,batch_size=32)
acc=0
i=0
for x_test,y_test in data_loader1:
# 推理
net1.eval()
with torch.no_grad():
res=net1(x_test)
func=nn.Softmax()
res=func(res)
res=torch.argmax(res,dim=1)
acc+=sum(res==y_test)/len(y_test)
i+=1
print(acc/i)
pass
if __name__=="__main__":
demo1()
# demo2()
# demo3()
pass
原文地址:https://blog.csdn.net/2401_86807530/article/details/144298090
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!