# 自学内容网 (zxcms.com)
#
# model_MTS2ONet

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchgan.layers import SpectralNorm2d

from ssim import msssim

from vggloss_1 import VGGLoss
import torchvision.models as models

import numpy as np
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import argparse
import lpips

from CBAM import CBAM
import lpips


class SelfAttention(nn.Module):  # diff: added self-attention module
    """Query/key/value self-attention over every spatial position of a feature map.

    Queries and keys are projected to ``in_dim // 8`` channels with 1x1
    convolutions; values keep ``in_dim`` channels. The attended features are
    scaled by a learnable ``gamma`` that starts at zero and are added back to
    the input, so the freshly built module is an identity mapping.

    Args:
        in_dim: number of input (and output) channels.
    """

    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        reduced = in_dim // 8
        self.query_conv = nn.Conv2d(in_dim, reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attention(x) + x``; the output has the shape of ``x``."""
        batch, channels, rows, cols = x.size()
        positions = rows * cols  # N: the spatial grid flattened

        queries = self.query_conv(x).view(batch, -1, positions).permute(0, 2, 1)  # B x N x C'
        keys = self.key_conv(x).view(batch, -1, positions)  # B x C' x N
        attention = self.softmax(torch.bmm(queries, keys))  # B x N x N, rows sum to 1
        values = self.value_conv(x).view(batch, -1, positions)  # B x C x N

        # Output position j is the attention-weighted sum of values over all positions.
        attended = torch.bmm(values, attention.permute(0, 2, 1))  # B x C x N
        attended = attended.view(batch, channels, rows, cols)

        return self.gamma * attended + x


class h_sigmoid(nn.Module):
    """Hard sigmoid: ``ReLU6(x + 3) / 6``.

    The output is 0 for ``x <= -3``, 1 for ``x >= 3`` and linear in between.

    Args:
        inplace: forwarded to ``nn.ReLU6``. ReLU6 acts on the temporary
            ``x + 3``, so the caller's tensor is never modified.
    """

    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        shifted = x + 3
        return self.relu(shifted) / 6


class h_swish(nn.Module):
    """Hard swish: ``x * h_sigmoid(x)``.

    Args:
        inplace: forwarded to the internal ``h_sigmoid``.
    """

    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        gate = self.sigmoid(x)
        return x * gate


class CoordAtt(nn.Module):
    """Coordinate attention: gates the input with separate per-row and per-column weights.

    The input is average-pooled across the width (one value per row) and
    across the height (one value per column). Both descriptors are stacked
    and passed through a shared 1x1 conv + BatchNorm + h_swish bottleneck of
    ``max(8, inp // reduction)`` channels. They are then split again and mapped
    by two 1x1 convolutions to sigmoid gates that multiply the input.

    Args:
        inp: channels of the input.
        oup: channels of each gate. It must broadcast against ``inp``; the
            callers in this file pass ``oup == inp``.
        reduction: divisor used to size the bottleneck.
    """

    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # -> N x C x H x 1
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # -> N x C x 1 x W

        hidden = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, hidden, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(hidden)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(hidden, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(hidden, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        _, _, height, width = x.size()

        row_desc = self.pool_h(x)  # N x C x H x 1
        col_desc = self.pool_w(x).permute(0, 1, 3, 2)  # N x C x W x 1

        # Run both descriptors through the shared bottleneck in one pass.
        stacked = torch.cat([row_desc, col_desc], dim=2)
        mixed = self.act(self.bn1(self.conv1(stacked)))

        row_feat, col_feat = torch.split(mixed, [height, width], dim=2)
        col_feat = col_feat.permute(0, 1, 3, 2)  # back to N x C' x 1 x W

        gate_h = self.conv_h(row_feat).sigmoid()
        gate_w = self.conv_w(col_feat).sigmoid()

        return x * gate_w * gate_h





class OutConv(nn.Sequential):
    """Output head: 1x1 convolution followed by ``tanh``.

    Args:
        in_channels: channels of the incoming feature map.
        num_classes: channels of the produced map.
    """

    def __init__(self, in_channels, num_classes):
        layers = [
            nn.Conv2d(in_channels, num_classes, kernel_size=1),  # index 0
            nn.Tanh(),  # index 1
        ]
        super(OutConv, self).__init__(*layers)

class inConv(nn.Sequential):
    """Input projection: bias-free 1x1 convolution followed by InstanceNorm2d.

    Args:
        in_channels: channels of the input.
        out_channels: channels after projection.
    """

    def __init__(self, in_channels, out_channels):
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
        norm = nn.InstanceNorm2d(out_channels)
        super(inConv, self).__init__(projection, norm)


class Sub_Res_down(nn.Module):
    """Residual conv block followed by 2x2 max-pooling, which halves H and W.

    Main path:
        reflect-pad + 3x3 conv -> InstanceNorm -> Mish -> Dropout(0.1)
        reflect-pad + 3x3 conv -> InstanceNorm -> Dropout(0.1)
        coordinate attention
    Skip path: identity, or 1x1 conv + InstanceNorm when the channel count
    changes. The sum passes through Mish and then max-pooling.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
    """

    def __init__(self, in_channels, out_channels):
        super(Sub_Res_down, self).__init__()
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Mish(inplace=True),
            nn.Dropout(0.1),
        )
        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Dropout(0.1),
        )
        # Coordinate attention is used in place of an earlier CBAM(out_channels, 8, 7);
        # the attribute keeps the name ``cbam`` so checkpoint keys stay the same.
        self.cbam = CoordAtt(out_channels, out_channels)
        self.relu = nn.Mish(inplace=True)

        if in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.InstanceNorm2d(out_channels),
            )

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        body = self.cbam(self.conv2(self.conv1(x)))
        merged = body + self.shortcut(x)
        return self.maxpool(self.relu(merged))

class Sub_Res_up(nn.Module):
    """2x transposed-conv upsampling followed by a residual conv block.

    ``ConvT`` (kernel 2, stride 2) maps ``in_channels -> out_channels`` and
    doubles H and W. Everything after it works on ``out_channels``.

    Main path:
        reflect-pad + 3x3 conv -> InstanceNorm -> Mish -> Dropout(0.1)
        reflect-pad + 3x3 conv -> InstanceNorm -> Dropout(0.1)
        coordinate attention
    Skip path (on the upsampled tensor): identity, or a 1x1
    ``out_channels -> out_channels`` conv + InstanceNorm when
    ``in_channels != out_channels``. The sum passes through Mish.

    Args:
        in_channels: channels of the input (before upsampling).
        out_channels: channels of the output.
    """

    def __init__(self, in_channels, out_channels):
        super(Sub_Res_up, self).__init__()
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Mish(inplace=True),
            nn.Dropout(0.1),
        )
        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Dropout(0.1),
        )

        # Coordinate attention is used in place of an earlier CBAM(out_channels, 8, 7);
        # the attribute keeps the name ``cbam`` so checkpoint keys stay the same.
        self.cbam = CoordAtt(out_channels, out_channels)
        self.relu = nn.Mish(inplace=True)

        if in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            # The skip sees the already-upsampled tensor, hence out -> out.
            self.shortcut = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.InstanceNorm2d(out_channels),
            )

        self.ConvT = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        upsampled = self.ConvT(x)
        body = self.cbam(self.conv2(self.conv1(upsampled)))
        merged = body + self.shortcut(upsampled)
        return self.relu(merged)

class ResNetBlock(nn.Module):
    """Residual block with GroupNorm and coordinate attention; spatial size is preserved.

    Main path:
        reflect-pad + 3x3 conv -> GroupNorm(32) -> Mish -> Dropout(0.1)
        reflect-pad + 3x3 conv -> Dropout(0.1)
        coordinate attention
    Skip path: identity, or 1x1 conv + GroupNorm(32) when the channel count
    changes. The sum passes through Mish.

    NOTE(review): unlike the Sub_Res_* blocks, the second conv has no
    normalization layer. Confirm this is intended.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output. It must be divisible by 32
            because of ``nn.GroupNorm(32, out_channels)``.
    """

    def __init__(self, in_channels, out_channels):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, bias=False),
            nn.GroupNorm(32, out_channels),
            nn.Mish(inplace=True),
            nn.Dropout(0.1),
        )
        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.Dropout(0.1),
        )
        # Coordinate attention is used in place of an earlier CBAM(out_channels, 8, 7);
        # the attribute keeps the name ``cbam`` so checkpoint keys stay the same.
        self.cbam = CoordAtt(out_channels, out_channels)
        self.relu = nn.Mish(inplace=True)

        if in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.GroupNorm(32, out_channels),
            )

    def forward(self, x):
        body = self.cbam(self.conv2(self.conv1(x)))
        merged = body + self.shortcut(x)
        return self.relu(merged)






class Gen(nn.Module):
    """Two-stage generator.

    Stage 1 works on the difference ``c - a``, which must have 2 channels. It
    consists of two ``Sub_Res_down`` blocks, two ``Sub_Res_up`` blocks, a 1x1
    ``inConv`` skip added to their output, and an ``OutConv`` head that
    produces a 4-channel tanh map ``x_6``.

    Stage 2 concatenates ``x_6`` with ``b`` along the channel dimension. It then
    runs a 4-level encoder/decoder of ``ResNetBlock``s with skip connections,
    applying self-attention to the deepest pooled features before the center
    block. An ``OutConv`` head produces the tanh output.

    Spatial sizes: stage 2 pools four times, so H and W must be divisible by
    16. Otherwise the upsampled maps do not match the skip tensors in
    ``torch.cat``.

    Args:
        in_channels: channels of ``cat([x_6, b])``, i.e. ``4 + channels(b)``.
        out_channels: channels of the final output. This value used to be
            ignored, and the head was hard-coded to 4 channels. The default of
            4 keeps the previous behaviour.
    """

    def __init__(self, in_channels=8, out_channels=4):
        super(Gen, self).__init__()

        # Stage 1: residual encoder/decoder on the 2-channel difference image.
        self.down_1 = Sub_Res_down(2, 64)
        self.down_2 = Sub_Res_down(64, 128)
        self.up_1 = Sub_Res_up(128, 64)
        self.up_2 = Sub_Res_up(64, 32)

        self.OutConv = OutConv(32, 4)
        # Final head: honour ``out_channels`` instead of a hard-coded 4.
        self.OutConv_1 = OutConv(64, out_channels)
        self.InConv = inConv(2, 32)

        # Stage 2 encoder.
        self.conv1 = ResNetBlock(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = ResNetBlock(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = ResNetBlock(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = ResNetBlock(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2 center.
        self.center = ResNetBlock(512, 1024)

        # Stage 2 decoder: each up-conv halves channels, then concatenation
        # with the matching encoder output doubles them again.
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv_decode4 = ResNetBlock(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv_decode3 = ResNetBlock(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv_decode2 = ResNetBlock(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv_decode1 = ResNetBlock(128, 64)

        # diff: added self-attention module (applied to the 512-channel pool4 output).
        self.attn = SelfAttention(512)

    def forward(self, a, b, c):
        """Run both stages.

        Args:
            a, c: tensors whose difference ``c - a`` has 2 channels.
            b: tensor concatenated with the stage-1 map. It must have
                ``in_channels - 4`` channels and the same spatial size.

        Returns:
            ``(x_6, output)``: the 4-channel stage-1 map and the
            ``out_channels``-channel final map, both tanh-activated.
        """
        # Stage 1.
        x_0 = c - a
        x_1 = self.down_1(x_0)
        x_2 = self.down_2(x_1)
        x_3 = self.up_1(x_2)
        x_4 = self.up_2(x_3)
        x_r = self.InConv(x_0)
        x_5 = x_r + x_4
        x_6 = self.OutConv(x_5)

        y = torch.cat([x_6, b], dim=1)

        # Stage 2 encoder.
        conv1 = self.conv1(y)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)

        attn_output = self.attn(pool4)

        # Center.
        center = self.center(attn_output)

        # Decoder with skip connections.
        conv_decode4 = self.conv_decode4(torch.cat([self.up4(center), conv4], dim=1))
        conv_decode3 = self.conv_decode3(torch.cat([self.up3(conv_decode4), conv3], dim=1))
        conv_decode2 = self.conv_decode2(torch.cat([self.up2(conv_decode3), conv2], dim=1))
        conv_decode1 = self.conv_decode1(torch.cat([self.up1(conv_decode2), conv1], dim=1))

        output = self.OutConv_1(conv_decode1)

        return x_6, output


class ReconstructionLoss(nn.Module):
    """Weighted sum of a VGG loss, a cosine term and an MS-SSIM term.

        loss = alpha * VGGLoss(prediction, target)
             + gamma * (1 - mean(cosine_similarity(prediction, target, dim=1)))
             + beta  * (1 - msssim(prediction, target, normalize=True))

    Args:
        alpha: weight of the VGG term.
        beta: weight of the MS-SSIM term.
        gamma: weight of the cosine-similarity term.
        g: accepted but not used anywhere.
    """

    def __init__(self, alpha=1.0, beta=1.0, gamma=1.0, g=1.0):
        super(ReconstructionLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.vggloss = VGGLoss(4)

    def forward(self, prediction, target):
        perceptual = self.vggloss(prediction, target)
        cosine = 1.0 - torch.mean(F.cosine_similarity(prediction, target, 1))
        structural = 1.0 - msssim(prediction, target, normalize=True)
        return self.alpha * perceptual + self.gamma * cosine + self.beta * structural


class ResidulBlockWithSpectralNorm_1(nn.Module):
    """Residual discriminator block with stride-2 convolutions (halves even H/W).

    residual: BN -> Mish -> 4x4 conv (stride 2, pad 1) -> BN -> Mish -> 1x1 conv
    transform (skip): 4x4 conv (stride 2, pad 1)
    Every conv is wrapped in torchgan's ``SpectralNorm2d``.
    Output = transform(inputs) + residual(inputs).

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
    """

    def __init__(self, in_channels, out_channels):
        super(ResidulBlockWithSpectralNorm_1, self).__init__()
        branch = [
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, in_channels, 4, 2, 1)),
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, out_channels, 1)),
        ]
        self.residual = nn.Sequential(*branch)
        self.transform = SpectralNorm2d(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        )

    def forward(self, inputs):
        skip = self.transform(inputs)
        return skip + self.residual(inputs)

class ResidulBlockWithSpectralNorm_2(nn.Module):
    """Residual discriminator block with stride-1 4x4 convolutions.

    With kernel 4, stride 1 and padding 1, each spatial dimension shrinks by 1.

    residual: BN -> Mish -> 4x4 conv (stride 1, pad 1) -> BN -> Mish -> 1x1 conv
    transform (skip): 4x4 conv (stride 1, pad 1)
    Every conv is wrapped in torchgan's ``SpectralNorm2d``.
    Output = transform(inputs) + residual(inputs).

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
    """

    def __init__(self, in_channels, out_channels):
        super(ResidulBlockWithSpectralNorm_2, self).__init__()
        branch = [
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, in_channels, 4, 1, 1)),
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, out_channels, 1)),
        ]
        self.residual = nn.Sequential(*branch)
        self.transform = SpectralNorm2d(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, inputs):
        skip = self.transform(inputs)
        return skip + self.residual(inputs)


class Discriminator(nn.Sequential):
    """Discriminator that outputs a 1-channel map of per-location scores in (0, 1).

    For ``channels = (c0, c1, ..., ck)`` the modules are:
        0 .. k-2:  ResidulBlockWithSpectralNorm_1(c_i, c_{i+1}) for i = 0 .. k-2
                   (stride-2 blocks up to channel c_{k-1})
        k-1:       Sequential(ResidulBlockWithSpectralNorm_2(c_{k-1}, c_k),
                              ResidulBlockWithSpectralNorm_2(c_k, 1),
                              Sigmoid())

    Args:
        channels: sequence of at least two channel counts; the first is the
            input channel count.
    """

    def __init__(self, channels):
        downsampling = [
            ResidulBlockWithSpectralNorm_1(c_in, c_out)
            for c_in, c_out in zip(channels[:-2], channels[1:-1])
        ]
        head = nn.Sequential(
            ResidulBlockWithSpectralNorm_2(channels[-2], channels[-1]),
            ResidulBlockWithSpectralNorm_2(channels[-1], 1),
            nn.Sigmoid(),
        )
        super(Discriminator, self).__init__(*downsampling, head)

    def forward(self, inputs):
        """Return the sigmoid score map as-is (not flattened to a vector)."""
        return super(Discriminator, self).forward(inputs)



class MSDiscriminator(nn.Module):
    """Multi-scale discriminator: three ``Discriminator``s at 1x, 1/2x and 1/4x resolution.

    Inputs are expected to have 12 channels. Each coarser branch has one fewer
    stride-2 block, so when H and W are divisible by 8 the three score maps
    have the same size and are added elementwise.

    NOTE(review): each branch ends in a sigmoid, so the sum lies in (0, 3)
    rather than (0, 1). Confirm that the adversarial loss used with this
    module expects that range.
    """

    def __init__(self):
        super(MSDiscriminator, self).__init__()
        self.d1 = Discriminator((12, 64, 128, 256, 512))
        self.d2 = Discriminator((12, 128, 256, 512))
        self.d3 = Discriminator((12, 256, 512))

    def forward(self, inputs):
        """Return the elementwise sum of the three branch score maps."""
        full_scale = self.d1(inputs)
        half_scale = self.d2(F.interpolate(inputs, scale_factor=0.5))
        quarter_scale = self.d3(F.interpolate(inputs, scale_factor=0.25))
        return full_scale + half_scale + quarter_scale

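# 原文地址 (original article): https://blog.csdn.net/yyfhq/article/details/145191661
#
# 免责声明:本站文章内容转载自网络资源,如侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!
# Disclaimer: this content was reposted from online sources; if it infringes the
# original author's rights, contact the site to have it removed. More at 自学内容网 (zxcms.com).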