PyTorch框架——基于深度学习LYT-Net神经网络AI低光图像增强系统源码
第一步:LYT-Net介绍
本文介绍了LYT-Net,即轻量级YUV Transformer 网络,作为一种新的低光图像增强方法。所提出的架构与传统的基于Retinex的模型不同,它利用YUV颜色空间对亮度(Y)和色度(U和V)的自然分离,简化了在图像中分离光和颜色信息的复杂任务。通过利用 Transformer 捕捉长距离依赖关系的优势,LYT-Net在保持降低模型复杂性的同时,确保了对图像的全面上下文理解。通过采用一种新颖的混合损失函数,LYT-Net在低光图像增强数据集上取得了最先进的结果,同时其体积比其他方法小得多。
LYT-Net采用了YUV色彩空间,这对LLIE来说尤其有利,因为它能将亮度(Y)和色度(U和V)明确分离。通过使用这个色彩空间,作者可以专门针对能在低光条件下提高图像可见性和细节的增强,而不会对颜色信息产生不利影响。由于人眼对亮度的变化更为敏感,因此专注于Y通道可以带来更自然、感知上更吸引人的增强效果。
作者的工作主要贡献可以概括为:
LYT-Net,一个轻量级模型,采用YUV颜色空间进行针对性增强。它在去噪后的亮度层和色度层上使用多头自注意力机制,旨在在处理过程的最后阶段实现更好的融合。
设计了一个混合损失函数,它在模型的高效训练中扮演了关键角色,并对模型的增强能力有显著贡献。
通过定量和定性的实验,LYT-Net在LOL数据集上与现有技术水平(SOTA)方法相比,已显示出强大的性能。
第二步:LYT-Net网络结构
作者展示了LYT-Net的整体架构。如图所示,该模型首先通过YUV分解将色度与亮度分离,之后是若干层以及可拆分的功能模块,包括多头自注意力(MHSA)块、多阶段挤压与激励融合(MSEF)块和逐通道去噪(CWD)块。作者采用双路径方法,将色度和亮度视为独立实体,以帮助模型更好地区分光照调整与退化恢复这两类任务。
该模型接收RGB格式的输入图像并将其转换为YUV。每个通道都通过一系列卷积层、池化操作以及MHSA机制单独增强。亮度通道经过卷积和池化提取特征,之后通过MHSA模块进行增强。色度通道(U和V)则通过CWD块处理,以在降低噪声的同时保留细节。增强后的色度通道被重新组合并通过MSEF块处理。最终,色度与亮度被连接起来,并通过最后一组卷积层生成输出,得到高质量的增强图像。
第三步:模型代码展示
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
class LayerNormalization(nn.Module):
    """Layer normalization across the channel axis of a channels-first tensor.

    ``nn.LayerNorm(dim)`` normalizes the trailing axis, so the input is
    permuted to channels-last for the normalization and permuted back
    afterwards.
    """

    def __init__(self, dim):
        """Create the wrapped ``nn.LayerNorm`` for ``dim`` channels."""
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Normalize a ``(B, C, H, W)`` tensor over ``C``; the shape is kept."""
        channels_last = x.permute(0, 2, 3, 1)  # (B, H, W, C)
        normalized = self.norm(channels_last)
        return normalized.permute(0, 3, 1, 2)  # back to (B, C, H, W)
class SEBlock(nn.Module):
    """Squeeze-and-excitation style channel gate.

    The input is average-pooled to one value per channel, passed through a
    two-layer bottleneck (ReLU, then tanh) and the resulting per-channel
    factors are multiplied back onto the input.

    Args:
        input_channels: Number of channels ``C`` of the ``(B, C, H, W)`` input.
        reduction_ratio: Divisor used to size the bottleneck layer.
    """

    def __init__(self, input_channels, reduction_ratio=16):
        super().__init__()
        # Clamp to at least one hidden unit: with fewer channels than the
        # reduction ratio the integer division yields 0, which would build
        # an empty bottleneck whose gate is tanh(0) = 0 and zero the output.
        hidden_channels = max(1, input_channels // reduction_ratio)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(input_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, input_channels)
        self._init_weights()

    def forward(self, x):
        """Return ``x`` scaled per channel by the learned gate (same shape)."""
        batch_size, num_channels, _, _ = x.size()
        y = self.pool(x).reshape(batch_size, num_channels)
        y = F.relu(self.fc1(y))
        y = torch.tanh(self.fc2(y))
        # Broadcast the (B, C) gate over the spatial dimensions.
        y = y.reshape(batch_size, num_channels, 1, 1)
        return x * y

    def _init_weights(self):
        """Kaiming-uniform weights and zero biases for both linear layers."""
        init.kaiming_uniform_(self.fc1.weight, a=0, mode='fan_in', nonlinearity='relu')
        init.kaiming_uniform_(self.fc2.weight, a=0, mode='fan_in', nonlinearity='relu')
        init.constant_(self.fc1.bias, 0)
        init.constant_(self.fc2.bias, 0)
class MSEFBlock(nn.Module):
    """Fusion block combining a depthwise convolution with an SE gate.

    The input is layer-normalized, sent through a depthwise 3x3 convolution
    and through an ``SEBlock`` in parallel; the two results are multiplied
    and added back onto the un-normalized input.

    Args:
        filters: Channel count of the input and output feature maps.
    """

    def __init__(self, filters):
        super().__init__()
        self.layer_norm = LayerNormalization(filters)
        # groups=filters makes the convolution depthwise (one filter per channel).
        self.depthwise_conv = nn.Conv2d(
            filters, filters, kernel_size=3, padding=1, groups=filters
        )
        self.se_attn = SEBlock(filters)
        self._init_weights()

    def forward(self, x):
        """Return ``x`` plus the product of the two normalized branches."""
        normed = self.layer_norm(x)
        spatial_branch = self.depthwise_conv(normed)
        channel_branch = self.se_attn(normed)
        return spatial_branch * channel_branch + x

    def _init_weights(self):
        """Kaiming-uniform weight and zero bias for the depthwise convolution."""
        conv = self.depthwise_conv
        init.kaiming_uniform_(conv.weight, a=0, mode='fan_in', nonlinearity='relu')
        init.constant_(conv.bias, 0)
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention over the pixels of a feature map.

    Every spatial position of a ``(B, C, H, W)`` input is one token whose
    embedding is that position's ``C`` channel values; the output has the
    same shape as the input.

    Args:
        embed_size: Channel count ``C`` of the input (the token width).
        num_heads: Number of attention heads; must divide ``embed_size``.
    """

    def __init__(self, embed_size, num_heads):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        assert embed_size % num_heads == 0, "embed_size must be divisible by num_heads"
        self.head_dim = embed_size // num_heads
        self.query_dense = nn.Linear(embed_size, embed_size)
        self.key_dense = nn.Linear(embed_size, embed_size)
        self.value_dense = nn.Linear(embed_size, embed_size)
        self.combine_heads = nn.Linear(embed_size, embed_size)
        self._init_weights()

    def split_heads(self, x, batch_size):
        """Reshape ``(B, N, C)`` into ``(B, num_heads, N, head_dim)``."""
        x = x.reshape(batch_size, -1, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)

    def forward(self, x):
        """Apply self-attention across all ``H * W`` positions of ``x``."""
        batch_size, _, height, width = x.size()
        # Move channels last BEFORE flattening so that each token holds the
        # channel vector of one pixel. A plain reshape of the channels-first
        # tensor would slice its memory into chunks that mix channels and
        # positions, and would not match the channels-last un-flattening of
        # the output below.
        # NOTE(review): weights trained with a different token layout will
        # not reproduce their results with this one - re-validate checkpoints.
        tokens = x.flatten(2).transpose(1, 2)  # (B, H*W, C)

        query = self.split_heads(self.query_dense(tokens), batch_size)
        key = self.split_heads(self.key_dense(tokens), batch_size)
        value = self.split_heads(self.value_dense(tokens), batch_size)

        # Scaled dot-product attention, computed independently per head.
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
        attention = torch.matmul(attention_weights, value)

        # Merge the heads back into one embedding per token.
        attention = attention.permute(0, 2, 1, 3).contiguous()
        attention = attention.reshape(batch_size, -1, self.embed_size)
        output = self.combine_heads(attention)

        # Tokens back to a channels-first feature map.
        output = output.reshape(batch_size, height, width, self.embed_size)
        return output.permute(0, 3, 1, 2)

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for all four projections."""
        projections = (
            self.query_dense,
            self.key_dense,
            self.value_dense,
            self.combine_heads,
        )
        for layer in projections:
            init.xavier_uniform_(layer.weight)
            init.constant_(layer.bias, 0)
class Denoiser(nn.Module):
    """Small U-shaped denoiser for a single-channel map.

    The input is encoded by one stride-1 and three stride-2 convolutions,
    passed through a 4-head self-attention bottleneck, and decoded by three
    nearest-neighbour 2x upsamplings with additive skip connections.

    The skip additions require matching shapes, so the spatial size of the
    input has to be a multiple of 8.

    Args:
        num_filters: Channel width of the internal feature maps; must be
            divisible by 4 because the bottleneck uses 4 attention heads.
        kernel_size: Kernel size of every convolution. Use an odd value so
            that the ``kernel_size // 2`` padding preserves the shapes.
        activation: Name of a function in ``torch.nn.functional`` applied
            after each encoder convolution.
    """

    def __init__(self, num_filters, kernel_size=3, activation='relu'):
        super().__init__()
        # Derive the padding from the kernel size; a fixed padding of 1 only
        # keeps the feature-map sizes aligned for 3x3 kernels.
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(1, num_filters, kernel_size=kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size=kernel_size, stride=2, padding=padding)
        self.conv3 = nn.Conv2d(num_filters, num_filters, kernel_size=kernel_size, stride=2, padding=padding)
        self.conv4 = nn.Conv2d(num_filters, num_filters, kernel_size=kernel_size, stride=2, padding=padding)
        self.bottleneck = MultiHeadSelfAttention(embed_size=num_filters, num_heads=4)
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up3 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up4 = nn.Upsample(scale_factor=2, mode='nearest')
        self.output_layer = nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding)
        self.res_layer = nn.Conv2d(num_filters, 1, kernel_size=kernel_size, padding=padding)
        self.activation = getattr(F, activation)
        self._init_weights()

    def forward(self, x):
        """Return a ``(B, 1, H, W)`` map in ``(-1, 1)`` for a ``(B, 1, H, W)`` input."""
        # Encoder: full resolution, then 1/2, 1/4 and 1/8.
        x1 = self.activation(self.conv1(x))
        x2 = self.activation(self.conv2(x1))
        x3 = self.activation(self.conv3(x2))
        x4 = self.activation(self.conv4(x3))

        x = self.bottleneck(x4)

        # Decoder: upsample and add the encoder feature of the same scale.
        x = self.up4(x)
        x = self.up3(x + x3)
        x = self.up2(x + x2)
        x = x + x1

        x = self.res_layer(x)
        # NOTE(review): `x + x` doubles the res_layer output; a skip from the
        # block's input may have been intended. Left unchanged because
        # trained weights depend on it - confirm before altering.
        return torch.tanh(self.output_layer(x + x))

    def _init_weights(self):
        """Kaiming-uniform weights and zero biases for every convolution."""
        conv_layers = (
            self.conv1,
            self.conv2,
            self.conv3,
            self.conv4,
            self.output_layer,
            self.res_layer,
        )
        for layer in conv_layers:
            init.kaiming_uniform_(layer.weight, a=0, mode='fan_in', nonlinearity='relu')
            if layer.bias is not None:
                init.constant_(layer.bias, 0)
class LYT(nn.Module):
    """Low-light enhancement network working on a luma/chroma decomposition.

    An RGB batch is converted to one luma and two chroma planes. The chroma
    planes are denoised, the luma features are refined with pooled
    self-attention, and both paths are fused by an ``MSEFBlock`` before the
    final convolutions produce a 3-channel image squashed by a sigmoid.

    Args:
        filters: Channel width of the feature maps. Must be divisible by 8,
            since the chroma denoisers run 4-head attention on ``filters // 2``
            channels.
    """

    # The luma path pools/upsamples by 8 and each chroma denoiser halves the
    # resolution three times, so the network needs sizes divisible by 8.
    _SIZE_MULTIPLE = 8

    def __init__(self, filters=32):
        super().__init__()
        self.process_y = self._create_processing_layers(filters)
        self.process_cb = self._create_processing_layers(filters)
        self.process_cr = self._create_processing_layers(filters)
        self.denoiser_cb = Denoiser(filters // 2)
        self.denoiser_cr = Denoiser(filters // 2)
        self.lum_pool = nn.MaxPool2d(8)
        self.lum_mhsa = MultiHeadSelfAttention(embed_size=filters, num_heads=4)
        self.lum_up = nn.Upsample(scale_factor=8, mode='nearest')
        self.lum_conv = nn.Conv2d(filters, filters, kernel_size=1, padding=0)
        self.ref_conv = nn.Conv2d(filters * 2, filters, kernel_size=1, padding=0)
        self.msef = MSEFBlock(filters)
        self.recombine = nn.Conv2d(filters * 2, filters, kernel_size=3, padding=1)
        self.final_adjustments = nn.Conv2d(filters, 3, kernel_size=3, padding=1)
        self._init_weights()

    def _create_processing_layers(self, filters):
        """Build the 3x3 conv + ReLU stem that lifts one plane to ``filters`` channels."""
        return nn.Sequential(
            nn.Conv2d(1, filters, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def _rgb_to_ycbcr(self, image):
        """Convert a ``(B, 3, H, W)`` RGB batch into stacked luma/chroma planes.

        Despite the method name, the coefficients are YUV-style weights; both
        chroma planes are shifted by +0.5.
        """
        r, g, b = image[:, 0, :, :], image[:, 1, :, :], image[:, 2, :, :]
        y = 0.299 * r + 0.587 * g + 0.114 * b
        u = -0.14713 * r - 0.28886 * g + 0.436 * b + 0.5
        v = 0.615 * r - 0.51499 * g - 0.10001 * b + 0.5
        return torch.stack((y, u, v), dim=1)

    def forward(self, inputs):
        """Enhance a ``(B, 3, H, W)`` RGB batch; the output has the same shape.

        Inputs whose height or width is not a multiple of 8 are
        replicate-padded at the bottom/right before processing and the result
        is cropped back, so arbitrary image sizes are accepted.
        """
        height, width = inputs.shape[-2:]
        pad_h = -height % self._SIZE_MULTIPLE
        pad_w = -width % self._SIZE_MULTIPLE
        needs_padding = bool(pad_h or pad_w)
        if needs_padding:
            # Replicate mode has no minimum-size requirement, unlike reflect.
            inputs = F.pad(inputs, (0, pad_w, 0, pad_h), mode='replicate')

        ycbcr = self._rgb_to_ycbcr(inputs)
        y, cb, cr = torch.split(ycbcr, 1, dim=1)

        # Chroma path: residual denoising of each plane.
        cb = self.denoiser_cb(cb) + cb
        cr = self.denoiser_cr(cr) + cr

        y_processed = self.process_y(y)
        cb_processed = self.process_cb(cb)
        cr_processed = self.process_cr(cr)
        ref = torch.cat([cb_processed, cr_processed], dim=1)

        # Luma path: attention at 1/8 resolution, added back as a residual.
        lum = y_processed
        lum_1 = self.lum_pool(lum)
        lum_1 = self.lum_mhsa(lum_1)
        lum_1 = self.lum_up(lum_1)
        lum = lum + lum_1

        # Fuse chroma features with a scaled luma hint, keeping a shortcut.
        ref = self.ref_conv(ref)
        shortcut = ref
        ref = ref + 0.2 * self.lum_conv(lum)
        ref = self.msef(ref)
        ref = ref + shortcut

        recombined = self.recombine(torch.cat([ref, lum], dim=1))
        output = torch.sigmoid(self.final_adjustments(recombined))
        if needs_padding:
            output = output[..., :height, :width]
        return output

    def _init_weights(self):
        """Kaiming-uniform init for the Conv2d/Linear layers that are direct children.

        Only direct children are visited; layers nested inside sub-modules
        keep whatever initialization those modules applied.
        """
        for module in self.children():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.kaiming_uniform_(module.weight, a=0, mode='fan_in', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
第四步:运行
第五步:整个工程的内容
项目完整文件下载请见演示与介绍视频的简介处给出:➷➷➷
PyTorch框架——基于深度学习LYT-Net神经网络AI低光图像增强系统源码_哔哩哔哩_bilibili
原文地址:https://blog.csdn.net/m0_59023219/article/details/144783085
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!