Yolov8 目标检测剪枝学习记录
最近在进行YOLOv8系列的轻量化,目前在网络结构方面的优化已经接近极限了,所以想要学习一下模型剪枝是否能够进一步优化模型的性能
这里主要参考了torch-pruning的基本使用,v8模型剪枝,Jetson nano部署剪枝YOLOv8
下面只是记录一个简单流程,用于后续使用在自己的任务和网络中,数据不作为参考
首先训练一个base模型用于参考
- 环境:Ultralytics YOLOv8.2.18 🚀 Python-3.10.14 torch-2.4.0 CUDA:0 (NVIDIA H100 PCIe, 81008MiB)
- 训练代码
参考网上或者自己写一个能训练即可,为了方便我将通用的记录下来,实测可用来自代码来源
from ultralytics import YOLO
import os
root = os.getcwd()
## 配置文件路径
name_yaml = os.path.join(root, "ultralytics/datasets/VOC.yaml")
name_pretrain = os.path.join(root, "yolov8s.pt")
## 原始训练路径
path_train = os.path.join(root, "runs/detect/VOC")
name_train = os.path.join(path_train, "weights/last.pt")
## 约束训练路径、剪枝模型文件
path_constraint_train = os.path.join(root, "runs/detect/VOC_Constraint")
name_prune_before = os.path.join(path_constraint_train, "weights/last.pt")
name_prune_after = os.path.join(path_constraint_train, "weights/last_prune.pt")
## 微调路径
path_fineturn = os.path.join(root, "runs/detect/VOC_finetune")
def else_api():
path_data = ""
path_result = ""
model = YOLO(name_pretrain)
metrics = model.val() # evaluate model performance on the validation set
model.export(format='onnx', opset=11, simplify=True, dynamic=False, imgsz=640)
model.predict(path_data, device="0", save=True, show=False, save_txt=True, imgsz=[288,480], save_conf=True, name=path_result, iou=0.5) # 这里的imgsz为高宽
def step1_train():
model = YOLO(name_pretrain)
model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, workers=16, save_period=1, name=path_train) # train the model
## 2024.3.4添加【amp=False】
def step2_Constraint_train():
model = YOLO(name_train)
model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, amp=False, workers=16, save_period=1,name=path_constraint_train) # train the model
def step3_pruning():
from LL_pruning import do_pruning
do_pruning(name_prune_before, name_prune_after)
def step4_finetune():
model = YOLO(name_prune_after) # load a pretrained model (recommended for training)
model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, workers=16, save_period=1, name=path_fineturn) # train the model
step1_train()
# step2_Constraint_train()
# step3_pruning()
# step4_finetune()
第一步,step1_train()
- 即训练一个base模型,用于最后性能好坏的重要参考
第二步,step2_Constraint_train()
训练之前在ultralytics\engine\trainer.py添加bn的L1正则,使得bn参数在训练时变得稀疏
- 通过对参数的绝对值进行惩罚,使得一些不重要的权重变为零,从而实现模型的稀疏化和简化
# Backward
self.scaler.scale(self.loss).backward()
## add new code=============================duj
## add l1 regulation for step2_Constraint_train
l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():
if isinstance(m, nn.BatchNorm2d):
m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))
# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
if ni - last_opt_step >= self.accumulate:
self.optimizer_step()
last_opt_step = ni
- 个人理解的稀疏化作用
- 通过对 gamma 和 beta 添加 L1 正则化,可以促使某些通道的 BN 权重变得非常小,甚至为零。这意味着在剪枝时,可以将这些通道从模型中移除
- 通过稀疏化 BN 层并剪除不重要的通道,剩下的通道会更有效地利用计算资源,减少无用计算。
第三步,step3_pruning()剪枝操作
LL_pruning.py
from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
import os
class PRUNE():
def __init__(self) -> None:
self.threshold = None
def get_threshold(self, model, factor=0.8):
ws = []
bs = []
for name, m in model.named_modules():
if isinstance(m, torch.nn.BatchNorm2d):
w = m.weight.abs().detach()
b = m.bias.abs().detach()
ws.append(w)
bs.append(b)
print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
print()
# keep
ws = torch.cat(ws)
self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
def prune_conv(self, conv1: Conv, conv2: Conv):
## a. 根据BN中的参数,获取需要保留的index================
gamma = conv1.bn.weight.data.detach()
beta = conv1.bn.bias.data.detach()
keep_idxs = []
local_threshold = self.threshold
while len(keep_idxs) < 8: ## 若剩余卷积核<8, 则降低阈值重新筛选
keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
local_threshold = local_threshold * 0.5
n = len(keep_idxs)
# n = max(int(len(idxs) * 0.8), p)
print(n / len(gamma) * 100)
# scale = len(idxs) / n
## b. 利用index对BN进行剪枝============================
conv1.bn.weight.data = gamma[keep_idxs]
conv1.bn.bias.data = beta[keep_idxs]
conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
conv1.bn.num_features = n
conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
conv1.conv.out_channels = n
## c. 利用index对conv1进行剪枝=========================
if conv1.conv.bias is not None:
conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]
## d. 利用index对conv2进行剪枝=========================
if not isinstance(conv2, list):
conv2 = [conv2]
for item in conv2:
if item is None: continue
if isinstance(item, Conv):
conv = item.conv
else:
conv = item
conv.in_channels = n
conv.weight.data = conv.weight.data[:, keep_idxs]
def prune(self, m1, m2):
if isinstance(m1, C2f): # C2f as a top conv
m1 = m1.cv2
if not isinstance(m2, list): # m2 is just one module
m2 = [m2]
for i, item in enumerate(m2):
if isinstance(item, C2f) or isinstance(item, SPPF):
m2[i] = item.cv1
self.prune_conv(m1, m2)
def do_pruning(modelpath, savepath):
pruning = PRUNE()
### 0. 加载模型
yolo = YOLO(modelpath) # build a new model from scratch
pruning.get_threshold(yolo.model, 0.8) # 获取剪枝时bn参数的阈值,这里的0.8为剪枝率。
### 1. 剪枝c2f 中的Bottleneck
for name, m in yolo.model.named_modules():
if isinstance(m, Bottleneck):
pruning.prune_conv(m.cv1, m.cv2)
### 2. 指定剪枝不同模块之间的卷积核
seq = yolo.model.model
for i in [3,5,7,8]:
pruning.prune(seq[i], seq[i+1])
### 3. 对检测头进行剪枝
# 在P3层: seq[15]之后的网络节点与其相连的有 seq[16]、detect.cv2[0] (box分支)、detect.cv3[0] (class分支)
# 在P4层: seq[18]之后的网络节点与其相连的有 seq[19]、detect.cv2[1] 、detect.cv3[1]
# 在P5层: seq[21]之后的网络节点与其相连的有 detect.cv2[2] 、detect.cv3[2]
detect:Detect = seq[-1]
last_inputs = [seq[15], seq[18], seq[21]]
colasts = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
pruning.prune(last_input, [colast, cv2[0], cv3[0]])
pruning.prune(cv2[0], cv2[1])
pruning.prune(cv2[1], cv2[2])
pruning.prune(cv3[0], cv3[1])
pruning.prune(cv3[1], cv3[2])
### 4. 模型梯度设置与保存
for name, p in yolo.model.named_parameters():
p.requires_grad = True
yolo.val()
torch.save(yolo.ckpt, savepath)
yolo.model.pt_path = yolo.model.pt_path.replace("last.pt", os.path.basename(savepath))
yolo.export(format="onnx")
## 重新load模型,修改保存命名,用以比较剪枝前后的onnx的大小
yolo = YOLO(modelpath) # build a new model from scratch
yolo.export(format="onnx")
if __name__ == "__main__":
modelpath = "runs/detect1/14_Constraint/weights/last.pt"
savepath = "runs/detect1/14_Constraint/weights/last_prune.pt"
do_pruning(modelpath, savepath)
- 如下图可用看到剪枝前后还是有区别的,参数量减少很多,网络性能将不可用,需要微调恢复精度
- 查看剪枝前后模型大小
du -sh ./runs/detect/VOC_Constraint/weights/last*
(yolov8n模型)
微调
该部分内容我也存在一些疑问,例如很多博主让ultralytics\engine\trainer.py添加加载模型代码,经过我8.2版本测试代码添加是完全失效的,因为setup_model在执行if isinstance(self.model, torch.nn.Module)就已经return了。
def setup_model(self):
"""Load/create/download model for any task."""
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
- 例如ultralytics\engine\trainer.py
- v8…x添加代码:548行 参考这里
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)
# duj add code to finetune
self.model = weights
return ckpt
- 如果是v8.0.x 参考这里
在看到这篇中的修改1启发
- v8.2.x上面我不确定是哪个版本需要添加的,但是我实测都不起作用
- 我尝试在
ultralytics\engine\model.py
添加如下代码加载模型成功
self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
if not args.get("resume"): # manually set model only if not resuming
# self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
# self.model = self.trainer.model
# dujiang edit
self.trainer.model = self.model.train()
if SETTINGS["hub"] is True and not self.session:
- 这里就是确保自己加载的是剪枝后的模型,但是不同版本好像不同,后续在探究原因。。。
- 这里有个小插曲,我在使用自己模型稀疏训练后剪枝发现(步骤2)发现BN层全没了,这里后面我将别人的稀疏训练的v8s模型拿来进行剪枝就没问题
- 可能是v8n的问题,也可能是我训练的问题,这里先不做深究继续查看剪枝是否成功且微调加载成功后能否恢复精度
- 此时多次尝试我基本确定微调加载的是我剪枝后的模型,接下来就是等待训练结果是否参数量正确了。
总结
总的来说跑通整个流程了,接下来尝试在自己的任务和数据上面进行剪枝,看看更换了模型结构又会有哪些坑等着我
原文地址:https://blog.csdn.net/qq_40938217/article/details/145166612
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!