自学内容网 自学内容网

【Datawhale组队学习】模型减肥秘籍:模型压缩技术2——模型剪枝

概念理解

模型剪枝就像是给一个庞大的模型瘦身,让它变得更轻、更快、更高效。剪枝可以让模型变小,计算量减少,从而提高运行速度。这就好比一辆车,越轻跑得越快,消耗的燃料也越少。对于一些需要实时响应的应用,比如手机上的语音助手或自动驾驶,模型运行得越快,效果就越好。小的模型用更少的计算资源,意味着它会更省电。对于一些依靠电池供电的设备来说,模型剪枝可以延长电池的使用时间,这对用户体验非常重要。

关于剪枝,MIT 6.5940课程里有个很形象的比喻:
在这里插入图片描述
刚出生的婴儿每个神经元大约有2500个突触。婴儿的大脑正在快速发育并建立基本的神经连接,以适应新世界中的各种感官输入和体验。而在2-4岁阶段,突触数量急剧增加,达到每个神经元大约15000个突触的峰值。这个阶段是儿童大脑发育最迅速的时期,大量的突触被创建,以支持孩子学习各种技能和理解复杂的外界环境。这种大量突触的形成让孩子能够迅速吸收新知识,就像一个“学习海绵”一样。但随着个体进入青春期和成年阶段,大脑开始经历“突触修剪”(synaptic pruning)。不常用的突触被移除,而常用的突触被强化,使得神经网络更高效。到成年时,突触数量减少到大约7000个每个神经元,但这种精简后的神经连接更加精准和有用。
这与深度学习中的模型剪枝很相似。大脑在早期会创建大量连接,然后通过修剪过程来优化这些连接,使得思维和学习更高效。同样,深度学习模型也可以通过剪枝去除不重要的部分,从而提升运行效率并保持性能。

工业界应用

在这里插入图片描述
剪枝技术在工业界非常实用,利用硬件来支持稀疏性可以显著提高深度神经网络的效率。
EIE(Efficient Inference Engine) 和 ESE(Efficient Speech Recognition Engine):这些硬件架构旨在加速剪枝后的稀疏神经网络的推理。它们通过压缩神经网络模型并优化存储和计算路径来减少延迟和能耗。这些方法使用稀疏格式对权重进行编码,并在计算过程中减少不必要的操作。
剪枝过程从Dense Neural Network(稠密神经网络)开始,模型最早包含完整着的参数和连接。通过剪枝逐步移除不重要的权重或神经元,降低模型的复杂性。然后通过剪枝和微调,恢复或优化模型性能,可以在保留模型准确性的同时显著减少计算资源的消耗。
剪枝可以将模型的复杂性降低5倍到50倍,同时对准确性影响最小。这种压缩和加速对于实际应用非常关键,尤其是在需要高效推理的场景中,如嵌入式设备、物联网、和边缘计算等。

代码实践心得

使用 torch.nn.utils.prune 模块进行模型剪枝,并解释了几种不同的剪枝策略,包括局部剪枝、全局剪枝和自定义剪枝。

  • 随机结构化剪枝:
import torch.nn.utils.prune as prune

module = model.conv1  
prune.random_structured(module, name="weight", amount=2, dim=0)

module = model.conv1:选择第一个卷积层 conv1 进行剪枝。
prune.random_structured():
name=“weight”:指明要剪枝的是 conv1 的 weight 参数,而非 bias。
amount=2:剪掉两个通道的权重。
dim=0:在通道维度进行剪枝,剪掉整个卷积核。

  • 范数结构化剪枝
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

prune.ln_structured():
amount=0.5:剪掉 50% 的通道。
n=2:使用 L2 范数来衡量每个通道的权重大小,并剪掉范数较小的通道。
dim=0:在通道维度上进行剪枝。

  • 随机非结构化剪枝
prune.random_unstructured(module, name="bias", amount=1)

prune.random_unstructured():对 bias 参数进行随机非结构化剪枝。
amount=1:随机剪掉一个偏置参数。

  • 永久化剪枝
prune.remove(module, 'weight')

prune.remove():将剪枝操作永久化。原始权重 weight_orig 被替换成 weight,掩码也被移除,剪枝变成永久效果。

参考文献

  1. https://www.dropbox.com/scl/fi/2oxmtvoeccyuw47yfambb/lec03.pdf?rlkey=3ykm0g21ibsoqn7xnw43v7aaw&e=1&dl=0
  2. https://www.datawhale.cn/learn/content/68/960

原文地址:https://blog.csdn.net/weixin_46319888/article/details/143809696

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!