自学内容网 自学内容网

常用的损失函数pytorch实现

梯度损失: 

先利用sobel算子计算梯度,然后计算计算出来梯度的一范数。

实现:
定义Sobelxy类。

class Sobelxy(nn.Module):
    def __init__(self):
        super(Sobelxy, self).__init__()
        kernelx = [[-1, 0, 1],
                  [-2,0 , 2],
                  [-1, 0, 1]]
        kernely = [[1, 2, 1],
                  [0,0 , 0],
                  [-1, -2, -1]]
        kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0)
        kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0)
        self.weightx = nn.Parameter(data=kernelx, requires_grad=False).cuda()
        self.weighty = nn.Parameter(data=kernely, requires_grad=False).cuda()
    def forward(self,x):
        sobelx=F.conv2d(x, self.weightx, padding=1)
        sobely=F.conv2d(x, self.weighty, padding=1)
        return torch.abs(sobelx)+torch.abs(sobely)

实例化类:

sobelconv = Sobelxy()

计算损失:

L_T_grad = sobelconv(L)
gradient_loss = beta * torch.sum(torch.abs(L_T_grad))

备注:计算的是L的梯度损失。

vgg感知损失:

模型的init中:

self.vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.to(opts.device)

在训练的迭代中调用:

vgg_loss = torch.norm(model.vgg(R) - model.vgg(hat_R), p=1)

备注:计算的是用预训练好的vgg提取的R和hat_R的高级特征损失

结构损失:SSIM

调用封装好的SSIM基本都会失败,所以要自己写。
SSIM类:

class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Assume 1 channel for SSIM
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel

        return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)


类的实例化:

ssim_metric = SSIM(window_size=11, size_average=True, val_range=None)
ssim_value = ssim_metric(R, hat_R)

备注:算的是R和hat_R的结构损失。
 


原文地址:https://blog.csdn.net/qq_46012097/article/details/143697169

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!