YOLO11改进|注意力机制篇|引入轴向注意力Axial Attention
目录
一、【Axial Attention】注意力机制
1.1【Axial Attention】注意力介绍
下图是【Axial Attention】的结构图,让我们简单分析一下运行过程和优势
处理过程:
- 按通道进行处理:
- 图片展示了对图像不同颜色通道(如绿色 G 和蓝色 B)的分离处理。这一步表示每个通道单独执行处理流程,确保每个通道的特征得到充分学习。
- 按行和按列采样:
- 处理过程中,首先针对每一行执行一次采样操作,接着对每一列进行采样。采样操作的目的是从较大的图像中提取关键的特征位置,从而减少计算量,同时保留重要的空间特征。
- 移位与填充:
- 采样后,图像块会被向下移动一个像素,并在顶部进行填充操作。这样做的目的是通过移位对局部特征进行整合,使得每个像素的上下文信息得到更广泛的捕捉。
- 最终拼接输出:
- 经过逐行、逐列和位置采样处理后,图像块会被重组,最后得到输出。这种处理方式不仅增强了图像的特征表示,还大幅降低了计算成本。
优势: - 计算效率高:
- 通过分块采样和局部处理,这种方法显著减少了每次操作需要处理的像素数量,从而加快了计算速度。特别是在处理大图像时,这种方法能够更好地利用计算资源。
- 分辨率增强:
- 移位与填充操作使得图像的局部上下文信息被有效捕获,有助于增强图像细节处理。每个像素点不仅能够考虑自身,还能够结合其上下文信息进行处理,从而提高图像的分辨率表现。
- 灵活的通道处理:
- 不同颜色通道的分离处理允许模型对每个通道的特征进行独立学习。这种方法使得模型能够更加专注于各个通道的特定信息,从而提升模型对颜色和纹理细节的捕捉能力。
- 自适应性强:
- 通过对不同图像区域的逐步采样和移位操作,这种方法具有较强的自适应性,适合处理不同分辨率和大小的图像。此外,采样过程中能够根据实际需求动态调整采样频率,进一步提升处理的精度与效率。
1.2【Axial Attention】核心代码
import torch
from torch import nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states
class Deterministic(nn.Module):
def __init__(self, net):
super().__init__()
self.net = net
self.cpu_state = None
self.cuda_in_fwd = None
self.gpu_devices = None
self.gpu_states = None
def record_rng(self, *args):
self.cpu_state = torch.get_rng_state()
if torch.cuda._initialized:
self.cuda_in_fwd = True
self.gpu_devices, self.gpu_states = get_device_states(*args)
def forward(self, *args, record_rng=False, set_rng=False, **kwargs):
if record_rng:
self.record_rng(*args)
if not set_rng:
return self.net(*args, **kwargs)
rng_devices = []
if self.cuda_in_fwd:
rng_devices = self.gpu_devices
with torch.random.fork_rng(devices=rng_devices, enabled=True):
torch.set_rng_state(self.cpu_state)
if self.cuda_in_fwd:
set_device_states(self.gpu_devices, self.gpu_states)
return self.net(*args, **kwargs)
# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is confirmed working, refactor and send PR back to source
class ReversibleBlock(nn.Module):
def __init__(self, f, g):
super().__init__()
self.f = Deterministic(f)
self.g = Deterministic(g)
def forward(self, x, f_args={}, g_args={}):
x1, x2 = torch.chunk(x, 2, dim=1)
y1, y2 = None, None
with torch.no_grad():
y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
return torch.cat([y1, y2], dim=1)
def backward_pass(self, y, dy, f_args={}, g_args={}):
y1, y2 = torch.chunk(y, 2, dim=1)
del y
dy1, dy2 = torch.chunk(dy, 2, dim=1)
del dy
with torch.enable_grad():
y1.requires_grad = True
gy1 = self.g(y1, set_rng=True, **g_args)
torch.autograd.backward(gy1, dy2)
with torch.no_grad():
x2 = y2 - gy1
del y2, gy1
dx1 = dy1 + y1.grad
del dy1
y1.grad = None
with torch.enable_grad():
x2.requires_grad = True
fx2 = self.f(x2, set_rng=True, **f_args)
torch.autograd.backward(fx2, dx1, retain_graph=True)
with torch.no_grad():
x1 = y1 - fx2
del y1, fx2
dx2 = dy2 + x2.grad
del dy2
x2.grad = None
x = torch.cat([x1, x2.detach()], dim=1)
dx = torch.cat([dx1, dx2], dim=1)
return x, dx
class IrreversibleBlock(nn.Module):
def __init__(self, f, g):
super().__init__()
self.f = f
self.g = g
def forward(self, x, f_args, g_args):
x1, x2 = torch.chunk(x, 2, dim=1)
y1 = x1 + self.f(x2, **f_args)
y2 = x2 + self.g(y1, **g_args)
return torch.cat([y1, y2], dim=1)
class _ReversibleFunction(Function):
@staticmethod
def forward(ctx, x, blocks, kwargs):
ctx.kwargs = kwargs
for block in blocks:
x = block(x, **kwargs)
ctx.y = x.detach()
ctx.blocks = blocks
return x
@staticmethod
def backward(ctx, dy):
y = ctx.y
kwargs = ctx.kwargs
for block in ctx.blocks[::-1]:
y, dy = block.backward_pass(y, dy, **kwargs)
return dy, None, None
class ReversibleSequence(nn.Module):
def __init__(
self,
blocks,
):
super().__init__()
self.blocks = nn.ModuleList([ReversibleBlock(f, g) for (f, g) in blocks])
def forward(self, x, arg_route=(True, True), **kwargs):
f_args, g_args = map(lambda route: kwargs if route else {}, arg_route)
block_kwargs = {"f_args": f_args, "g_args": g_args}
x = torch.cat((x, x), dim=1)
x = _ReversibleFunction.apply(x, self.blocks, block_kwargs)
return torch.stack(x.chunk(2, dim=1)).mean(dim=0)
# helper functions
def exists(val):
return val is not None
def map_el_ind(arr, ind):
return list(map(itemgetter(ind), arr))
def sort_and_return_indices(arr):
indices = [ind for ind in range(len(arr))]
arr = zip(arr, indices)
arr = sorted(arr)
return map_el_ind(arr, 0), map_el_ind(arr, 1)
# calculates the permutation to bring the input tensor to something attend-able
# also calculates the inverse permutation to bring the tensor back to its original shape
def calculate_permutations(num_dimensions, emb_dim):
total_dimensions = num_dimensions + 2
emb_dim = emb_dim if emb_dim > 0 else (emb_dim + total_dimensions)
axial_dims = [ind for ind in range(1, total_dimensions) if ind != emb_dim]
permutations = []
for axial_dim in axial_dims:
last_two_dims = [axial_dim, emb_dim]
dims_rest = set(range(0, total_dimensions)) - set(last_two_dims)
permutation = [*dims_rest, *last_two_dims]
permutations.append(permutation)
return permutations
# helper classes
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()
mean = torch.mean(x, dim=1, keepdim=True)
return (x - mean) / (std + self.eps) * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x):
x = self.norm(x)
return self.fn(x)
class Sequential(nn.Module):
def __init__(self, blocks):
super().__init__()
self.blocks = blocks
def forward(self, x):
for f, g in self.blocks:
x = x + f(x)
x = x + g(x)
return x
class PermuteToFrom(nn.Module):
def __init__(self, permutation, fn):
super().__init__()
self.fn = fn
_, inv_permutation = sort_and_return_indices(permutation)
self.permutation = permutation
self.inv_permutation = inv_permutation
def forward(self, x, **kwargs):
axial = x.permute(*self.permutation).contiguous()
shape = axial.shape
*_, t, d = shape
# merge all but axial dimension
axial = axial.reshape(-1, t, d)
# attention
axial = self.fn(axial, **kwargs)
# restore to original shape and permutation
axial = axial.reshape(*shape)
axial = axial.permute(*self.inv_permutation).contiguous()
return axial
# axial pos emb
class AxialPositionalEmbedding(nn.Module):
def __init__(self, dim, shape, emb_dim_index=1):
super().__init__()
parameters = []
total_dimensions = len(shape) + 2
ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]
self.num_axials = len(shape)
for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
shape = [1] * total_dimensions
shape[emb_dim_index] = dim
shape[axial_dim_index] = axial_dim
parameter = nn.Parameter(torch.randn(*shape))
setattr(self, f"param_{i}", parameter)
def forward(self, x):
for i in range(self.num_axials):
x = x + getattr(self, f"param_{i}")
return x
# attention
class SelfAttention(nn.Module):
def __init__(self, dim, heads, dim_heads=None):
super().__init__()
self.dim_heads = (dim // heads) if dim_heads is None else dim_heads
dim_hidden = self.dim_heads * heads
self.heads = heads
self.to_q = nn.Linear(dim, dim_hidden, bias=False)
self.to_kv = nn.Linear(dim, 2 * dim_hidden, bias=False)
self.to_out = nn.Linear(dim_hidden, dim)
def forward(self, x, kv=None):
kv = x if kv is None else kv
q, k, v = (self.to_q(x), *self.to_kv(kv).chunk(2, dim=-1))
b, t, d, h, e = *q.shape, self.heads, self.dim_heads
merge_heads = (
lambda x: x.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e)
)
q, k, v = map(merge_heads, (q, k, v))
dots = torch.einsum("bie,bje->bij", q, k) * (e**-0.5)
dots = dots.softmax(dim=-1)
out = torch.einsum("bij,bje->bie", dots, v)
out = out.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d)
out = self.to_out(out)
return out
# axial attention class
class AxialAttention(nn.Module):
def __init__(
self,
dim,
num_dimensions=2,
heads=8,
dim_heads=None,
dim_index=-1,
sum_axial_out=True,
):
assert (
dim % heads
) == 0, "hidden dimension must be divisible by number of heads"
super().__init__()
self.dim = dim
self.total_dimensions = num_dimensions + 2
self.dim_index = (
dim_index if dim_index > 0 else (dim_index + self.total_dimensions)
)
attentions = []
for permutation in calculate_permutations(num_dimensions, dim_index):
attentions.append(
PermuteToFrom(permutation, SelfAttention(dim, heads, dim_heads))
)
self.axial_attentions = nn.ModuleList(attentions)
self.sum_axial_out = sum_axial_out
def forward(self, x):
assert (
len(x.shape) == self.total_dimensions
), "input tensor does not have the correct number of dimensions"
assert (
x.shape[self.dim_index] == self.dim
), "input tensor does not have the correct input dimension"
if self.sum_axial_out:
return sum(map(lambda axial_attn: axial_attn(x), self.axial_attentions))
out = x
for axial_attn in self.axial_attentions:
out = axial_attn(out)
return out
# axial image transformer
class AxialImageTransformer(nn.Module):
def __init__(
self,
dim,
depth,
heads=8,
dim_heads=None,
dim_index=1,
reversible=True,
axial_pos_emb_shape=None,
):
super().__init__()
permutations = calculate_permutations(2, dim_index)
get_ff = lambda: nn.Sequential(
ChanLayerNorm(dim),
nn.Conv2d(dim, dim * 4, 3, padding=1),
nn.LeakyReLU(inplace=True),
nn.Conv2d(dim * 4, dim, 3, padding=1),
)
self.pos_emb = (
AxialPositionalEmbedding(dim, axial_pos_emb_shape, dim_index)
if exists(axial_pos_emb_shape)
else nn.Identity()
)
layers = nn.ModuleList([])
for _ in range(depth):
attn_functions = nn.ModuleList(
[
PermuteToFrom(
permutation, PreNorm(dim, SelfAttention(dim, heads, dim_heads))
)
for permutation in permutations
]
)
conv_functions = nn.ModuleList([get_ff(), get_ff()])
layers.append(attn_functions)
layers.append(conv_functions)
execute_type = ReversibleSequence if reversible else Sequential
self.layers = execute_type(layers)
def forward(self, x):
x = self.pos_emb(x)
return self.layers(x)
if __name__ == "__main__":
input = torch.rand(3, 64, 32, 32).cuda()
model = AxialImageTransformer(dim=64, depth=12, reversible=True).cuda()
output = model(input)
print(input.size(), output.size())
二、添加【Axial Attention】注意力机制
2.1STEP1
首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个Axial_Attention.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示
2.2STEP2
在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示
2.3STEP3
找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加
2.4STEP4
定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】
三、yaml文件与运行
3.1yaml文件
以下是添加【Axial Attention】注意力机制在小目标检测层中中的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
# YOLO11n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128,3,2]] # 1-P2/4
- [-1, 2, C3k2, [256, False, 0.25]]
- [-1, 1, Conv, [256,3,2]] # 3-P3/8
- [-1, 2, C3k2, [512, False, 0.25]]
- [-1, 1, Conv, [512,3,2]] # 5-P4/16
- [-1, 2, C3k2, [512, True]]
- [-1, 1, Conv, [1024,3,2]] # 7-P5/32
- [-1, 2, C3k2, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 2, C2PSA, [1024]] # 10
# YOLO11n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 2, C3k2, [512, False]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
- [-1,1,AxialImageTransformer,[1]]
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
- [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)
以上添加位置仅供参考,具体添加位置以及模块效果以自己的数据集结果为准
3.2运行成功截图
OK 以上就是添加【Axial Attention】注意力机制的全部过程了,后续将持续更新尽情期待
原文地址:https://blog.csdn.net/A1983Z/article/details/142844140
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!