自学内容网 自学内容网

加载完整pytorch .pt模型可能出现的问题

我们在保留训练好的模型model时,可以保留model的参数,也可以直接将完整model全部保留下来。两种方式如下所示:
torch.save(model.state_dict(),PATH)
torch.save(model,PATH)
在加载完整模型的时候,很可能出错,特别是当我们对模型里面数据流经的途径进行了更改的时候。加载的模型会按照最新的数据流定义进行计算,说多无益,举例子如下:

def Net(nn.Module):
 def __init__(self,class_num ):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 100, kernel_size=1, bias=False)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.linear(100,class_num)
 def forward(self,x):
        x = self.conv1(x)
        x = self.gap(x)
        x = self.classifier(x)
        return x

假如我们按照上述的模型定义了模型并且训练好的模型,将其保留为如下形式:

model = Net(10)
训练过程省略
torch.save(model,'zuiniubidemoxing.pt')

现在我们兴致勃勃的昭告天下,我们产生了最牛逼的模型,然后想要在别的地方使用它。但是此处要注意,别以为保存了完整结构的模型就能直接加载使用,不会出错了。例如按照下面这样加载:

model = torch.load("zuiniubidemoxing.pt")

如果我们不对Net()进行更改,那么没问题,但是当我们更改了Net()的数据流时,就很有可能发生问题。例如更改了如下的代码:

def Net(nn.Module):
 def __init__(self,class_num ):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 100, kernel_size=1, bias=False)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.linear(100,class_num)
 def forward(self,x):
        x = self.conv1(x)
        x = self.gap(x)
        x = self.classifier(x)+0.05
        return x

那么现在再load zuiniubidemoxing.pt 使用时就会出问题,所以要注意,一定要保证数据流的形式没问题,要不然直接加载使用也会出问题。


原文地址:https://blog.csdn.net/t20134297/article/details/139094418

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!