unet20241110
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
from einops import rearrange


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a batch of diffusion time steps.

    Args:
        dim: Embedding width; must be >= 4. The output has ``2 * (dim // 2)``
            features (the sine half followed by the cosine half), so an odd
            ``dim`` yields ``dim - 1`` features.

    Raises:
        ValueError: if ``dim < 4``. The frequency scale divides by
            ``dim // 2 - 1``, which is zero for dim 2-3; smaller values give an
            empty embedding.
    """

    def __init__(self, dim):
        super().__init__()
        if dim < 4:
            raise ValueError(f"SinusoidalPosEmb needs dim >= 4, got {dim}")
        self.dim = dim

    def forward(self, x):
        """Embed time steps.

        Args:
            x: 1-D tensor of time steps (integer or float), presumably shape (B,).

        Returns:
            Float tensor of shape (B, 2 * (dim // 2)).
        """
        half_dim = self.dim // 2
        scale = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        args = x[:, None] * freqs[None, :]
        return torch.cat((args.sin(), args.cos()), dim=-1)


class single_conv(nn.Module):
    """5x5 convolution (padding 2, so spatial size is preserved) followed by ReLU."""

    def __init__(self, in_ch, out_ch):
        super(single_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 5, padding=2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class up(nn.Module):
    """Decoder step: 2x transposed-conv upsampling, skip concatenation, 3x3 fusion.

    The upsampled map has ``in_ch // 2`` channels. ``conv_concat`` expects
    ``in_ch`` channels after concatenation, so the skip tensor must carry
    ``in_ch - in_ch // 2`` channels (``in_ch // 2`` for even ``in_ch``).
    """

    def __init__(self, in_ch):
        super(up, self).__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, 2, stride=2)
        self.conv_concat = nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x1, x2):
        """Upsample ``x1`` and fuse it with the skip connection ``x2``.

        Args:
            x1: Coarser feature map (N, in_ch, H, W).
            x2: Skip feature map (N, in_ch // 2, H2, W2).

        Returns:
            Tensor of shape (N, in_ch // 2, H2, W2).
        """
        x1 = self.up(x1)
        # Inputs are NCHW. Pad the upsampled map so it matches the skip's size
        # (handles odd spatial sizes lost by the encoder's stride-2 convs).
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, (diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2))
        x = torch.cat((x1, x2), dim=1)
        return self.relu(self.conv_concat(x))


class EMA(nn.Module):
    """Grouped multi-scale attention gate.

    Channels are split into ``factor`` groups that are folded into the batch
    dimension and processed independently. For each group:

    * directional branch (``x1``): average pools along W and along H are
      stacked into an (H + W, 1) map, passed through parallel 5/3/7-kernel
      convs, fused, split back, and used as sigmoid gates on the input. The
      result is group-normalised with one group per channel.
    * local branch (``x2``): parallel 5/3/7-kernel convs of the input, fused.

    Each branch's channel-softmaxed global average is matrix-multiplied with
    the other branch's flattened features. The sum gives a per-pixel map whose
    sigmoid gates the input.

    Args:
        channels: Input/output channels; must be a positive multiple of ``factor``.
        c2: Unused (accepted but ignored).
        factor: Number of channel groups.

    Raises:
        ValueError: if ``channels`` is not a positive multiple of ``factor``.
    """

    def __init__(self, channels, c2=None, factor=32):
        super(EMA, self).__init__()
        self.groups = factor
        # forward() reshapes (B, C, H, W) -> (B * groups, C // groups, H, W),
        # which only works when C is an exact, non-zero multiple of groups.
        if channels // self.groups <= 0 or channels % self.groups != 0:
            raise ValueError(
                f"EMA: channels ({channels}) must be a positive multiple of factor ({factor})"
            )
        group_ch = channels // self.groups
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.gn = nn.GroupNorm(group_ch, group_ch)
        # NOTE: despite the "1x1"/"3x3" attribute names (kept so existing
        # checkpoints still load), the unsuffixed convs use 5x5 kernels.
        # Every conv pads by (k - 1) // 2, so spatial size is preserved.
        self.conv1x1 = nn.Conv2d(group_ch, group_ch, kernel_size=5, stride=1, padding=2)
        self.conv1x1_3 = nn.Conv2d(group_ch, group_ch, kernel_size=3, stride=1, padding=1)
        self.conv1x1_7 = nn.Conv2d(group_ch, group_ch, kernel_size=7, stride=1, padding=3)
        self.conv1x1_c = nn.Conv2d(group_ch * 3, group_ch, kernel_size=5, stride=1, padding=2)
        self.conv3x3 = nn.Conv2d(group_ch, group_ch, kernel_size=5, stride=1, padding=2)
        self.conv3x3_3 = nn.Conv2d(group_ch, group_ch, kernel_size=3, stride=1, padding=1)
        self.conv3x3_7 = nn.Conv2d(group_ch, group_ch, kernel_size=7, stride=1, padding=3)
        self.conv3x3_c = nn.Conv2d(group_ch * 3, group_ch, kernel_size=5, stride=1, padding=2)

    def forward(self, x):
        """Apply the attention gate.

        Args:
            x: Tensor of shape (B, C, H, W) with C == ``channels``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, h, w = x.size()
        group_x = x.reshape(b * self.groups, -1, h, w)  # (B*G, C//G, H, W)

        # Directional branch: stack the W-pooled and H-pooled profiles.
        x_h = self.pool_h(group_x)  # (B*G, C//G, H, 1)
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)  # (B*G, C//G, W, 1)
        hw_in = torch.cat([x_h, x_w], dim=2)  # built once, shared by the three convs
        hw = torch.cat(
            [self.conv1x1(hw_in), self.conv1x1_3(hw_in), self.conv1x1_7(hw_in)], dim=1
        )
        hw = self.conv1x1_c(hw)
        x_h, x_w = torch.split(hw, [h, w], dim=2)
        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())

        # Local multi-scale branch.
        x2 = torch.cat(
            [self.conv3x3(group_x), self.conv3x3_3(group_x), self.conv3x3_7(group_x)], dim=1
        )
        x2 = self.conv3x3_c(x2)

        # Cross-branch weighting: softmaxed global descriptor of one branch
        # times the flattened features of the other.
        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # (B*G, C//G, H*W)
        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # (B*G, C//G, H*W)
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(
            b * self.groups, 1, h, w
        )
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)


class down(nn.Module):
    """Stride-2 downsampling by parallel dilated 3x3 convolutions.

    Each dilation rate ``d`` gets its own stride-2 conv (padding ``d``), so
    every branch maps H -> ceil(H / 2). The branch outputs are concatenated and
    fused back to ``in_channels`` by a 3x3 conv + ReLU.

    Args:
        in_channels: Channels in and out.
        dilation_rates: Iterable of dilation rates, one branch per rate.
    """

    def __init__(self, in_channels, dilation_rates=(1, 2, 3)):
        super(down, self).__init__()
        # Materialise once: the rates are iterated for the branches and counted
        # for the fusion conv (so a generator argument also works).
        dilation_rates = tuple(dilation_rates)
        self.conv = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2,
                          padding=d, dilation=d, bias=False),
                nn.ReLU(inplace=True),
            )
            for d in dilation_rates
        ])
        # One in_channels-wide slice per branch, so size the fusion input from
        # the number of rates rather than assuming three.
        self.conv_2 = nn.Conv2d(in_channels * len(dilation_rates), in_channels, 3, 1, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = torch.cat([branch(x) for branch in self.conv], dim=1)
        return self.relu(self.conv_2(out))


class outconv(nn.Module):
    """1x1 convolution projecting ``in_ch`` feature channels to ``out_ch``."""

    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        return self.conv(x)


class adjust_net(nn.Module):
    """Predict per-channel (gamma, beta) modulation from a 2-channel input.

    Three conv + ReLU + AvgPool2d(2) stages, then a 1x1 conv to
    ``2 * out_channels`` channels and a global average pool. The result is
    split along channels into two halves.

    Args:
        out_channels: Channels of each returned tensor.
        middle_channels: Width of the first stage (doubled at each later stage).
    """

    def __init__(self, out_channels=64, middle_channels=32):
        super(adjust_net, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(2, middle_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(2),
            nn.Conv2d(middle_channels, middle_channels * 2, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(2),
            nn.Conv2d(middle_channels * 2, middle_channels * 4, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(2),
            nn.Conv2d(middle_channels * 4, out_channels * 2, 1, padding=0),
        )

    def forward(self, x):
        """Return ``(first_half, second_half)``, each of shape (N, out_channels, 1, 1)."""
        out = F.adaptive_avg_pool2d(self.model(x), (1, 1))
        half = out.shape[1] // 2
        return out[:, :half], out[:, half:]

# The U-Net architecture is based on "Toward Convolutional Blind Denoising of Real Photographs",
# official MATLAB implementation: https://github.com/GuoShi28/CBDNet.
# unofficial PyTorch implementation: https://github.com/IDKiro/CBDNet-pytorch/tree/master.
# We improved it by adding time step embedding and EMM module, while removing the noise estimation network.
class UNet(nn.Module):
    """Time-conditioned two-level U-Net.

    Encoder: ``inc`` (64 ch) -> ``down1``/``conv1`` (128 ch, 1/2 res) ->
    ``down2``/``conv2`` (256 ch, 1/4 res). Decoder: ``up1``/``conv3`` (128 ch)
    -> ``up2``/``conv4`` (64 ch) -> ``outc``. After every down/up step the
    feature map receives a projected time embedding, optionally scaled and
    shifted by an ``adjust_net`` branch. Every conv stage is followed by an
    ``EMA`` block.

    Args:
        in_channels: Channels of the input image tensor.
        out_channels: Channels of the output.
    """

    def __init__(self, in_channels=2, out_channels=1):
        super(UNet, self).__init__()
        dim = 32  # time-embedding width
        # NOTE: submodule creation order is kept as-is; it fixes both the
        # state_dict layout and the seeded parameter initialisation.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
        self.inc = nn.Sequential(
            single_conv(in_channels, 64),
            single_conv(64, 64),
        )
        self.down1 = down(64)
        self.mlp1 = nn.Sequential(nn.GELU(), nn.Linear(dim, 64))
        self.adjust1 = adjust_net(64)
        self.conv1 = nn.Sequential(
            single_conv(64, 128),
            single_conv(128, 128),
            single_conv(128, 128),
        )
        self.down2 = down(128)
        self.mlp2 = nn.Sequential(nn.GELU(), nn.Linear(dim, 128))
        self.adjust2 = adjust_net(128)
        self.conv2 = nn.Sequential(
            single_conv(128, 256),
            single_conv(256, 256),
            single_conv(256, 256),
            single_conv(256, 256),
            single_conv(256, 256),
            single_conv(256, 256),
        )
        self.up1 = up(256)
        self.mlp3 = nn.Sequential(nn.GELU(), nn.Linear(dim, 128))
        self.adjust3 = adjust_net(128)
        self.conv3 = nn.Sequential(
            single_conv(128, 128),
            single_conv(128, 128),
            single_conv(128, 128),
        )
        self.up2 = up(128)
        self.mlp4 = nn.Sequential(nn.GELU(), nn.Linear(dim, 64))
        self.adjust4 = adjust_net(64)
        self.conv4 = nn.Sequential(
            single_conv(64, 64),
            single_conv(64, 64),
        )
        self.outc = outconv(64, out_channels)
        self.ema1 = EMA(channels=128)
        self.ema2 = EMA(channels=256)
        self.ema3 = EMA(channels=128)
        self.ema4 = EMA(channels=64)

    @staticmethod
    def _modulate(feature, condition, adjust_module, x_adjust, adjust):
        """Inject a (B, C) condition into a (B, C, H, W) feature map.

        With ``adjust`` true, the condition is scaled by ``gamma`` and shifted
        by ``beta``, where ``(gamma, beta) = adjust_module(x_adjust)``.
        Otherwise it is simply added.

        Raises:
            ValueError: if ``condition`` is not 2-D.
        """
        if condition.dim() != 2:
            raise ValueError(
                f"condition must be 2-D (batch, channels), got shape {tuple(condition.shape)}"
            )
        condition = condition[:, :, None, None]  # (B, C) -> (B, C, 1, 1) for broadcasting
        if adjust:
            gamma, beta = adjust_module(x_adjust)
            return feature + gamma * condition + beta
        return feature + condition

    def forward(self, x, t, x_adjust, adjust):
        """Run the U-Net.

        Args:
            x: Input of shape (N, in_channels, H, W).
            t: Time steps, presumably shape (N,).
            x_adjust: 2-channel input for the ``adjust_net`` branches; only read
                when ``adjust`` is true.
            adjust: Whether to modulate each time condition with ``adjust_net``.

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        inx = self.inc(x)
        time_emb = self.time_mlp(t)

        # Encoder; each conv stage is refined by an EMA block.
        down1 = self._modulate(self.down1(inx), self.mlp1(time_emb), self.adjust1, x_adjust, adjust)
        conv1 = self.ema1(self.conv1(down1))
        down2 = self._modulate(self.down2(conv1), self.mlp2(time_emb), self.adjust2, x_adjust, adjust)
        conv2 = self.ema2(self.conv2(down2))

        # Decoder with skip connections.
        up1 = self._modulate(self.up1(conv2, conv1), self.mlp3(time_emb), self.adjust3, x_adjust, adjust)
        conv3 = self.ema3(self.conv3(up1))
        up2 = self._modulate(self.up2(conv3, inx), self.mlp4(time_emb), self.adjust4, x_adjust, adjust)
        conv4 = self.ema4(self.conv4(up2))

        return self.outc(conv4)


class Network(nn.Module):
    """Residual wrapper around ``UNet``: output = unet(...) + x_middle.

    Args:
        in_channels: Channels of ``x``.
        out_channels: Channels predicted by the U-Net.
        context: If True, the residual is channel 1 of ``x`` (the middle slice
            when ``in_channels == 3``); otherwise it is ``x`` itself.
    """

    def __init__(self, in_channels=3, out_channels=1, context=True):
        super(Network, self).__init__()
        self.unet = UNet(in_channels=in_channels, out_channels=out_channels)
        self.context = context

    def forward(self, x, t, y, x_end, adjust=True):
        """Predict the residual-corrected output.

        Args:
            x: Input of shape (N, in_channels, H, W).
            t: Time steps, presumably shape (N,).
            y, x_end: Concatenated along channels into the 2-channel
                ``adjust_net`` input (presumably one channel each); only used
                when ``adjust`` is true.
            adjust: Whether to modulate the time condition with ``adjust_net``.
        """
        x_middle = x[:, 1].unsqueeze(1) if self.context else x
        # The adjust input is only consumed on the adjust path; skip the concat otherwise.
        x_adjust = torch.cat((y, x_end), dim=1) if adjust else None
        return self.unet(x, t, x_adjust, adjust=adjust) + x_middle


# WeightNet of the one-shot learning framework
class WeightNet(nn.Module):
    """Learnable softmax-normalised weighted sum over dimension 1.

    Args:
        weight_num: Number of weights; ``x`` must broadcast against shape
            (1, weight_num, 1, 1). The weights start uniform.
    """

    def __init__(self, weight_num=10):
        super(WeightNet, self).__init__()
        init = torch.ones([1, weight_num, 1, 1]) / weight_num
        self.weights = nn.Parameter(init)

    def forward(self, x):
        """Return ``(weighted_sum, weights)``.

        The weighted sum keeps dim 1 with size 1; ``weights`` is the softmax
        over dim 1 of the raw parameter.
        """
        weights = F.softmax(self.weights, 1)
        out = (weights * x).sum(dim=1, keepdim=True)
        return out, weights


def main():
    """Smoke-test the model on random inputs and print the output shape."""
    torch.manual_seed(0)  # reproducible parameters and inputs
    model = Network(in_channels=2, out_channels=1, context=True)
    model.eval()  # only checking dimensions

    batch_size = 2
    num_channels = 2  # input channels for x
    height = 64
    width = 64

    x = torch.randn(batch_size, num_channels, height, width)
    t = torch.randint(0, 1000, (batch_size,))
    y = torch.randn(batch_size, 1, height, width)
    x_end = torch.randn(batch_size, 1, height, width)

    with torch.no_grad():
        output = model(x, t, y, x_end, adjust=True)

    print(f"Final output shape: {output.shape}")


if __name__ == "__main__":
    main()
原文地址:https://blog.csdn.net/yyfhq/article/details/143658285
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!