自学内容网 自学内容网

【论文笔记】Compact Language Models via Pruning and Knowledge Distillation

Abstract

目前,不同规模和尺寸的语言模型(LLM)通常是通过从头开始训练每个变体来生成的,这在计算上非常密集。
本文探讨了是否可以通过修剪现有的语言模型,然后仅使用原始训练数据的一小部分(❤️%)进行重新训练,作为替代全面重新训练的可行方案。

本文开发了一套实用有效的LLM压缩最佳实践,结合了深度、宽度、注意力和多层感知机(MLP)修剪与基于知识蒸馏的重新训练;这些最佳实践通过对每个维度的修剪策略、维度组合方法、蒸馏策略和搜索技术的详细实证探索得出。
使用这一方法将Nemotron-4系列LLM的规模压缩了2-4倍,并将其在多种语言建模任务中的表现与相似规模的模型进行了比较。通过本文的方法,从已预训练的15B模型派生出8B和4B模型所需的训练数据量比从头开始训练少了最多40倍;这使得整个模型系列(15B、8B和4B)的计算成本节省了1.8倍。Minitron模型在MMLU分数上较从头训练提高了最多16%,与Mistral 7B、Gemma 7B和Llama-3 8B等其他社区模型表现相当,并且在与文献中的最先进压缩技术对比中表现优异。

1 Introduction

大语言模型(Large Language Models,LLMs)现在主导了现实世界的自然语言处理,并在理解困难语境方面表现出了出色的熟练程度。
但从头开始训练多个百亿参数模型是极其消费时间、数据和资源。
本文提出问题:我们可否训练一个大的模型,并通过权重剪枝和再训练的组合,在只使用原始训练数据的一小部分的情况下,从中获得相对于从无到有的训练来说更小、更准确的模型?

权重剪枝(Weight Pruning)是一种强大的、众所周知的减少模型大小的技术。本文聚焦于结构化剪枝,其中一次性从模型权重中移除一块非零元素;结构化剪枝技术的例子包括神经元剪枝、注意力头剪枝、卷积滤波器剪枝和深度剪枝。
尽管文献中有许多关于结构化剪枝的研究,但对于最终用户来说,通常不清楚应该选择哪种技术、何时使用以及如何将它们结合起来以始终获得较好的剪枝模型。剪枝通常还伴随着一定程度的再训练以恢复准确度。在现代LLMs中,这一阶段的成本非常高,通常需要大量精心整理的数据。
目前没有关于结构化剪枝的研究探讨如何利用数据高效的再训练技术,如知识蒸馏,来减少再训练的成本。

2 Pruning Methodology

![[Pasted image 20241208134112.png]]

图2:迭代剪枝与蒸馏方法概览:本文在一个预训练的大型语言模型(LLM)上1)评估神经元的重要性,对其排序,2)剪除最不重要的神经元,3)将原始模型的知识蒸馏到剪枝后的模型中。在下一轮压缩中,原始模型会被蒸馏后的模型取代,从而逐步训练出一系列更小的LLM。

如图2所示,首先通过计算每一层、每个神经元、注意力头和嵌入维度的重要性来启动剪枝过程,然后对这些重要性分数进行排序以生成相应的重要性排名。本节将详细说明如何为每个剪枝维度计算排名,并如何利用这些排名获得一个剪枝后的模型。

2.1 Background and Notation

先从一些正式定义开始。多层感知机(MLP)层包含两个线性层,中间插入一个非线性激活函数:
MLP ( X ) = δ ( X ⋅ W 1 T ) ⋅ W 2 \text{MLP}(X)=\delta(X\cdot W_1^T)\cdot W_2 MLP(X)=δ(XW1T)W2
其中 X X X表示输入, W 1 , W 2 W_1,W_2 W1,W2表示MLP层中两个相关联的权重矩阵, W 1 , W 2 ∈ R d hidden × d model W_1,W_2\in\mathbb{R}^{d_\text{hidden}\times d_\text{model}} W1,W2Rdhidden×dmodel,其中 d model , d hidden d_\text{model},d_\text{hidden} dmodel,dhidden分别是嵌入维度和MLP隐藏维度。 δ ( ⋅ ) \delta(\cdot) δ()表示非线性的激活函数。

多头注意力(Multi-Head Attention, MHA)操作针对输入 X X X的定义如下:
MHA ( X ) = Concat ( head 1 , ⋯   , head L ) ⋅ W O \text{MHA}(X)=\text{Concat}(\text{head}_1,\cdots,\text{head}_L)\cdot W^O MHA(X)=Concat(head1,,headL)WO
head i = Attn ( X W Q , i , X W K , i , X W V , i ) \text{head}_i=\text{Attn}(XW^{Q,i},XW^{K,i},XW^{V,i}) headi=Attn(XWQ,i,XWK,i,XWV,i)
其中 W Q , i , W K , i , W V , i ∈ R d head × d model W^{Q,i},W^{K,i},W^{V,i}\in\mathbb{R}^{d_\text{head}\times d_\text{model}} WQ,i,WK,i,WV,iRdhead×dmodel W O ∈ R L × d head × d models W^O\in\mathbb{R}^{L\times d_\text{head}\times d_\text{models}} WORL×dhead×dmodels d head d_\text{head} dhead表示单头注意力的大小, L L L是头的数量。

最后,层归一化(LayerNorm)操作对输入 X X X的定义如下:
L N ( X ) = X − μ σ 2 + ϵ ⊙ γ + β LN(X)=\frac{X-\mu}{\sqrt{\sigma^2+\epsilon}}\odot\gamma+\beta LN(X)=σ2+ϵ Xμγ+β
其中 μ \mu μ σ 2 \sigma^2 σ2代表嵌入维度的均值和方差, ϵ \epsilon ϵ是保持数值稳定的很小的值, γ , β \gamma,\beta γ,β是可学习的参数。

2.2 Importance Analysis

估计神经网络各组成部分的重要性或敏感性(例如神经元、注意力头和层)是一个研究充分的领域。
在LLM的背景下,近期研究指出传统指标(如权重大小)在估计重要性方面效果不佳;相反,近期研究针对LLMs的结构化剪枝研究聚焦于梯度/泰勒展开、余弦相似度以及校准数据集上的困惑度等指标。

由于现代LLMs体积庞大,计算其梯度信息的内存和计算成本高,因此一项主要目标是在获取重要性信息时避免这一昂贵的步骤。
本文提出了一种完全基于激活的重要性估计策略,利用小型校准数据集(1024个样本)并仅通过前向传播,同时计算我们考虑的所有维度(深度、神经元、注意力头和嵌入通道)的敏感性信息。
下面描述如何针对每个维度实现这一策略。

Width:通过检查MHA、MLP和LayerNorm层产生的激活,分别计算每个注意力头、神经元和嵌入通道的重要性。为此,我们使用一个小型校准数据集D。正式地,计算通过激活计算注意力头、神经元和嵌入通道的重要性分数,具体如下:
F head ( i ) = ∑ B , S ∣ ∣ Attn ( X W Q , i , X W K , i , X W V , i ) ∣ ∣ 2 F_\text{head}^{(i)}=\sum_{B,S} ||\text{Attn}(XW^{Q,i},XW^{K,i},XW^{V,i})||_2 Fhead(i)=B,S∣∣Attn(XWQ,i,XWK,i,XWV,i)2
F neuron ( i ) = ∑ B , S X ( W 1 i ) T ,   F emb ( i ) = ∑ B , S L N ( X ) i F_\text{neuron}^{(i)}=\sum_{B,S} X(W_1^i)^T,\ F_\text{emb}^{(i)}=\sum_{B,S} LN(X)_i Fneuron(i)=B,SX(W1i)T, Femb(i)=B,SLN(X)i
其中, W 1 i W_1^i W1i代表权重矩阵 W 1 W_1 W1的第 i i i行。 ∑ B , S \sum_{B,S} B,S代表沿批次和序列维度的聚合。
对于聚合方法,实验表明直接求和不是最优,对聚合的方法进行了实验,见表13。

Depth (Layers):对于深度剪枝,使用两种指标评估每一层的重要性:

  • 困惑度(PPL)
  • 块重要性(BI)
    对于基于PPL的排名,我们简单地移除一个层,并计算其对该剪枝模型困惑度的影响;这作为该层的“重要性”或敏感性。BI通过计算层的输入和输出之间的余弦距离来估计层的敏感性。第 i i i的BI得分计算如下:
    BI i = 1 − E X , t X i , t T X i + 1 , t ∣ ∣ X i , t ∣ ∣ 2 ∣ ∣ X i + 1 , t ∣ ∣ 2 \text{BI}_i=1-\mathbb{E}_{X,t}\frac{X_{i,t}^T X_{i+1,t}}{||X_{i,t}||_2||X_{i+1,t}||_2} BIi=1EX,t∣∣Xi,t2∣∣Xi+1,t2Xi,tTXi+1,t
    其中 X i X_i Xi表示第 i i i层的输入, X i , t X_{i,t} Xi,t表示 X i X_i Xi的第 t t t层输入。所有层的BI都可以仅使用前向传播得到,这使其在速度上相较于基于PPL的重要性具有显著优势。此外,按照Gromov等人的做法,可以扩展BI,以同时估计多个连续层的重要性。

Iterative Importance:在这种设置下,本文在给定的维度或维度组合上,迭代地交替进行剪枝和重要性估计。
给定迭代次数 T T T,来源与目标的维度(层、头等) d s , d t d_s,d_t ds,dt,迭代地计算 d s − i ⋅ ( d s − d t T ) d_s-i\cdot (\frac{d_s-d_t}{T}) dsi(Tdsdt)维上的重要性,并将其剪枝到 d s − ( i + 1 ) ⋅ ( d s − d t T ) d_s-(i+1)\cdot(\frac{d_s-d_t}{T}) ds(i+1)(Tdsdt)维。

2.3 Obtaining a Pruned Model

图2概述了如何获得剪枝模型。对于给定的架构配置,首先根据计算得到的重要性对每个维度的元素进行排名,并直接对相应的权重矩阵进行剪枝。对于神经元和注意力头剪枝,分别剪裁MLP和MHA层的权重。对于嵌入通道,剪裁MLP、MHA和LayerNorm层中权重矩阵的嵌入维度。
在剪枝注意力头时,将剪枝掉的头部的残差信息重新加入到剩余的头部中,目的是保留来自剪枝头部的相关知识提高了模型的准确性。

具体地,给定 L L L个原始注意力头 head 1 , head 2 , ⋯   , head L \text{head}_1,\text{head}_2,\cdots,\text{head}_L head1,head2,,headL,要被剪枝成 K K K个头,每个头都具有形式(对于第 i i i个头): head i + ( head i − head 2 K − i + 1 ) \text{head}_i+(\text{head}_i-\text{head}_{2K-i+1}) headi+(headihead2Ki+1) i ∈ [ K − ( L − K ) , K ] i\in[K-(L-K),K] i[K(LK),K]

Lightweight Neural Architecture Search:图3概述寻找最佳架构配置的搜索策略。
![[Pasted image 20241208172541.png]]

图3:神经架构搜索算法概述。本文在多个维度上进行搜索:层数、注意力头数量、MLP和嵌入维度,以得到一组符合参数预算的可行架构。RT指重训练。

给定一个搜索空间和参数预算(图的左侧),枚举了所有满足参数预算的可行架构。在这一阶段,虽然可以通过遗传搜索和/或贝叶斯优化等策略进一步缩小搜索空间大小,但我们发现,坚持常用的神经元、头和嵌入维度,以及合理窄的目标参数范围(不到10亿),足以获得易于处理的解集(候选人不足20人)。然后对可行候选者进行轻量级再训练(本工作中的1.8 B令牌)。在图9中显示,这个再训练阶段稳定了相对排名,并帮助我们找到一个更准确的候选人来进一步训练。注意到参数有效的微调技术,例如LoRA也可以应用在这个阶段;将对此类技术的探索留给未来的工作。

3 Retraining

本文使用"再训练"一词来指修剪后的精度恢复过程,探索了两种再训练策略:(1)常规训练,利用真实标签;(2)使用未剪枝模型(教师)的监督进行知识蒸馏。
未压缩和剪枝后的模型分别对应教师模型和学生模型。

给定token x i x_i xi,LLM的概率分布可以计算为:
p ( x i , τ ) = exp ⁡ ( x i τ ) ∑ j = 1 ∣ V ∣ exp ⁡ ( x j τ ) p(x_i,\tau)=\frac{\exp(\frac{x_i}{\tau})}{\sum_{j=1}^{|V|}\exp(\frac{x_j}{\tau})} p(xi,τ)=j=1Vexp(τxj)exp(τxi)
其中 τ \tau τ是softmax温度, ∣ V ∣ |V| V是词汇库(vocabulary)的大小。Logit-based KD loss可以表示为:
L logits = 1 l ∑ k = 1 l Loss ( p t k ( x , τ ) , p s k ( x , τ ) ) L_\text{logits}=\frac{1}{l}\sum_{k=1}^l \text{Loss}(p_t^k(x,\tau),p_s^k(x,\tau)) Llogits=l1k=1lLoss(ptk(x,τ),psk(x,τ))
其中 p t k ( x , τ ) p_t^k(x,\tau) ptk(x,τ) p s k ( x , τ ) p_s^k(x,\tau) psk(x,τ)表示第 k k k个token下教师和学生的分布, l l l代表序列长度。

![[Pasted image 20241208175604.png]]

图4:蒸馏概述。一个具有N层的学生模型通过蒸馏自一个具有M层的教师模型。学生通过最小化嵌入输出损失、logit损失以及映射到学生块S和教师块T之间的Transformer编码器特定损失的组合来学习。

在蒸馏过程中,探索了各种损失函数,并结合了Transformer模型中的多个中间状态和映射作为损失组件,以及它们各自的权衡。图4展示了这一过程。基于中间状态的KD损失,跨越一系列Transformer特定的隐藏状态表示为:
L is = 1 l ∑ k ∈ H ∑ i = 1 l Loss k ( h t k i , h s k i ) L_\text{is}=\frac{1}{l}\sum_{k\in H}\sum_{i=1}^l \text{Loss}_k(h_t^{ki},h_s^{ki}) Lis=l1kHi=1lLossk(htki,hski)
其中 h t k i h_t^{ki} htki h s k i h_s^{ki} hski代表第 i i i个token的第 k k k个隐藏层, l l l代表序列长度, H H H代表被选择的中间状态集合。
学生和教师隐藏状态之间的差异通过在蒸馏过程中学习一个共享的线性变换来处理,以将学生的隐藏状态提升到教师隐藏状态的维度。所使用的隐藏状态总是经过LayerNorm处理后的状态。

总损失为 L = L CLM + L logits + α × L is L=L_\text{CLM}+L_\text{logits}+\alpha\times L_\text{is} L=LCLM+Llogits+α×Lis,其中 L CLM L_\text{CLM} LCLM是与真值标签的交叉熵损失, α \alpha α是权重。


原文地址:https://blog.csdn.net/xhyu61/article/details/144330597

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!