model_MTS2ONet
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchgan.layers import SpectralNorm2d
from ssim import msssim
from vggloss_1 import VGGLoss
import torchvision.models as models
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
import argparse
import lpips
from CBAM import CBAM


class SelfAttention(nn.Module):  # diff: add a self-attention module
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)   # B x C' x N
        energy = torch.bmm(proj_query, proj_key)                            # B x N x N
        attention = self.softmax(energy)                                    # B x N x N
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))             # B x C x N
        out = out.view(m_batchsize, C, width, height)
        out = self.gamma * out + x
        return out


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CoordAtt(nn.Module):
    """Coordinate attention: pools along H and W separately, then re-weights the input."""
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x
        n, c, h, w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)
        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        out = identity * a_w * a_h
        return out


class OutConv(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__(
            nn.Conv2d(in_channels, num_classes, kernel_size=1),
            nn.Tanh()
        )


class inConv(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(inConv, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.InstanceNorm2d(out_channels)
        )


class Sub_Res_down(nn.Module):
    """Residual down-sampling block: two 3x3 convs, coordinate attention, then 2x max pooling."""
    def __init__(self, in_channels, out_channels):
        super(Sub_Res_down, self).__init__()
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Mish(inplace=True),
            nn.Dropout(0.1)
        )
        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Dropout(0.1)
        )
        # self.cbam = CBAM(out_channels, 8, 7)
        self.cbam = CoordAtt(out_channels, out_channels)
        self.relu = nn.Mish(inplace=True)
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.InstanceNorm2d(out_channels)
            )
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.cbam(out)
        out += self.shortcut(residual)
        out = self.relu(out)
        out = self.maxpool(out)
        return out


class Sub_Res_up(nn.Module):
    """Residual up-sampling block: 2x transposed conv followed by residual refinement."""
    def __init__(self, in_channels, out_channels):
        super(Sub_Res_up, self).__init__()
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Mish(inplace=True),
            nn.Dropout(0.1)
        )
        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Dropout(0.1)
        )
        # self.cbam = CBAM(out_channels, 8, 7)
        self.cbam = CoordAtt(out_channels, out_channels)
        self.relu = nn.Mish(inplace=True)
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.InstanceNorm2d(out_channels)
            )
        self.ConvT = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        x = self.ConvT(x)
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.cbam(out)
        out += self.shortcut(residual)
        out = self.relu(out)
        return out


class ResNetBlock(nn.Module):
    """Residual block with GroupNorm, coordinate attention and Mish activation."""
    def __init__(self, in_channels, out_channels):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, bias=False),
            nn.GroupNorm(32, out_channels),
            nn.Mish(inplace=True),
            nn.Dropout(0.1)
        )
        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.Dropout(0.1)
        )
        # self.cbam = CBAM(out_channels, 8, 7)
        self.cbam = CoordAtt(out_channels, out_channels)
        self.relu = nn.Mish(inplace=True)
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.GroupNorm(32, out_channels)
            )

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.cbam(out)
        out += self.shortcut(residual)
        out = self.relu(out)
        return out


class Gen(nn.Module):
    """Generator: a shallow residual sub-network maps the difference image (c - a) to a coarse
    4-channel estimate x_6, which is concatenated with b and refined by a U-Net whose
    bottleneck uses self-attention."""
    def __init__(self, in_channels=8, out_channels=4):
        super(Gen, self).__init__()
        self.down_1 = Sub_Res_down(2, 64)
        self.down_2 = Sub_Res_down(64, 128)
        self.up_1 = Sub_Res_up(128, 64)
        self.up_2 = Sub_Res_up(64, 32)
        self.OutConv = OutConv(32, 4)
        self.OutConv_1 = OutConv(64, 4)
        self.InConv = inConv(2, 32)
        # encoder
        self.conv1 = ResNetBlock(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = ResNetBlock(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = ResNetBlock(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = ResNetBlock(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        # center
        self.center = ResNetBlock(512, 1024)
        # decoder
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv_decode4 = ResNetBlock(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv_decode3 = ResNetBlock(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv_decode2 = ResNetBlock(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv_decode1 = ResNetBlock(128, 64)
        self.attn = SelfAttention(512)  # diff: add a self-attention module

    def forward(self, a, b, c):
        # coarse branch on the difference image
        x_0 = c - a
        x_1 = self.down_1(x_0)
        x_2 = self.down_2(x_1)
        x_3 = self.up_1(x_2)
        x_4 = self.up_2(x_3)
        x_r = self.InConv(x_0)
        x_5 = x_r + x_4
        x_6 = self.OutConv(x_5)
        y = torch.cat([x_6, b], dim=1)
        # encoder
        conv1 = self.conv1(y)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)
        attn_output = self.attn(pool4)
        # center
        center = self.center(attn_output)
        # decoder
        up4 = self.up4(center)
        concat4 = torch.cat([up4, conv4], dim=1)
        conv_decode4 = self.conv_decode4(concat4)
        up3 = self.up3(conv_decode4)
        concat3 = torch.cat([up3, conv3], dim=1)
        conv_decode3 = self.conv_decode3(concat3)
        up2 = self.up2(conv_decode3)
        concat2 = torch.cat([up2, conv2], dim=1)
        conv_decode2 = self.conv_decode2(concat2)
        up1 = self.up1(conv_decode2)
        concat1 = torch.cat([up1, conv1], dim=1)
        conv_decode1 = self.conv_decode1(concat1)
        # output
        output = self.OutConv_1(conv_decode1)
        return x_6, output


class ReconstructionLoss(nn.Module):
    """Weighted sum of VGG perceptual loss, cosine-similarity loss and MS-SSIM loss."""
    def __init__(self, alpha=1.0, beta=1.0, gamma=1.0, g=1.0):
        super(ReconstructionLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.vggloss = VGGLoss(4)

    def forward(self, prediction, target):
        loss = (self.alpha * self.vggloss(prediction, target)
                + self.gamma * (1.0 - torch.mean(F.cosine_similarity(prediction, target, dim=1)))
                + self.beta * (1.0 - msssim(prediction, target, normalize=True)))
        return loss


class ResidulBlockWithSpectralNorm_1(nn.Module):
    """Spectral-normalized residual block that halves the spatial resolution (stride 2)."""
    def __init__(self, in_channels, out_channels):
        super(ResidulBlockWithSpectralNorm_1, self).__init__()
        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, in_channels, 4, 2, 1)),
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, out_channels, 1))
        )
        self.transform = SpectralNorm2d(nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1))

    def forward(self, inputs):
        return self.transform(inputs) + self.residual(inputs)


class ResidulBlockWithSpectralNorm_2(nn.Module):
    """Spectral-normalized residual block that keeps stride 1."""
    def __init__(self, in_channels, out_channels):
        super(ResidulBlockWithSpectralNorm_2, self).__init__()
        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, in_channels, 4, 1, 1)),
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, out_channels, 1)),
        )
        self.transform = SpectralNorm2d(nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=1, padding=1))

    def forward(self, inputs):
        return self.transform(inputs) + self.residual(inputs)


class Discriminator(nn.Sequential):
    def __init__(self, channels):
        modules = []
        for i in range(1, len(channels) - 1):
            modules.append(ResidulBlockWithSpectralNorm_1(channels[i - 1], channels[i]))
        modules.append(nn.Sequential(
            ResidulBlockWithSpectralNorm_2(channels[-2], channels[-1]),
            ResidulBlockWithSpectralNorm_2(channels[-1], 1),
            nn.Sigmoid()
        ))
        super(Discriminator, self).__init__(*modules)

    def forward(self, inputs):
        prediction = super(Discriminator, self).forward(inputs)
        # return prediction.view(-1, 1).squeeze(1)
        return prediction


class MSDiscriminator(nn.Module):
    """Multi-scale discriminator: three discriminators at full, half and quarter resolution."""
    def __init__(self):
        super(MSDiscriminator, self).__init__()
        self.d1 = Discriminator((12, 64, 128, 256, 512))
        self.d2 = Discriminator((12, 128, 256, 512))
        self.d3 = Discriminator((12, 256, 512))

    def forward(self, inputs):
        l1 = self.d1(inputs)
        l2 = self.d2(F.interpolate(inputs, scale_factor=0.5))
        l3 = self.d3(F.interpolate(inputs, scale_factor=0.25))
        L = l1 + l2 + l3
        # return torch.mean(torch.stack((l1, l2, l3)))
        return L
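As a quick sanity check, here is a minimal forward-pass sketch (not from the original post). It assumes the code above is saved as model_MTS2ONet.py with its project-local dependencies (torchgan, ssim, vggloss_1, CBAM, lpips) importable, that a and c are 2-channel images and b is a 4-channel image (so that c - a matches Sub_Res_down(2, 64) and torch.cat([x_6, b], dim=1) yields the 8 channels the U-Net encoder expects), and that the 256x256 resolution and the 12-channel discriminator input are purely illustrative.

    # smoke_test.py -- minimal usage sketch (assumed shapes and filename)
    import torch
    from model_MTS2ONet import Gen, MSDiscriminator

    gen = Gen(in_channels=8, out_channels=4)
    disc = MSDiscriminator()

    a = torch.randn(1, 2, 256, 256)  # "before" image (2 channels, assumed)
    c = torch.randn(1, 2, 256, 256)  # "after" image (2 channels, assumed)
    b = torch.randn(1, 4, 256, 256)  # auxiliary 4-channel input (assumed)

    with torch.no_grad():
        x_6, output = gen(a, b, c)       # coarse 4-channel estimate and refined 4-channel output
        print(x_6.shape, output.shape)   # torch.Size([1, 4, 256, 256]) for both

        d_in = torch.randn(1, 12, 256, 256)  # 12-channel stack for the discriminators (composition assumed)
        score = disc(d_in)
        print(score.shape)               # per-patch scores summed over the three scales

ReconstructionLoss can be exercised the same way (e.g. ReconstructionLoss()(output, target)) once the project's vggloss_1 and ssim modules are in place; it combines VGG perceptual, cosine-similarity and MS-SSIM terms.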
Original article: https://blog.csdn.net/yyfhq/article/details/145191661