自学内容网 自学内容网

模型减肥秘籍:模型压缩技术 知识蒸馏

教程链接:模型减肥秘籍:模型压缩技术-课程详情 | Datawhale

知识蒸馏:让AI模型更轻更快

在人工智能快速发展的今天,我们经常需要在资源受限的设备(如手机、IoT设备)上运行AI模型。但这些设备的计算能力和内存都很有限,无法直接运行庞大的AI模型。这就带来了一个重要问题:如何将大模型的能力迁移到小设备上?知识蒸馏(Knowledge Distillation)就是解决这个问题的重要技术之一。

什么是知识蒸馏?

知识蒸馏可以形象地理解为"教师教学生"的过程。大模型(教师模型)将自己学到的"知识"传授给小模型(学生模型),帮助小模型在保持较小体积的同时,获得接近大模型的性能。

这里的"知识"主要包括:

  • 模型的输出概率分布(软标签)
  • 模型中间层的特征
  • 注意力图等信息

知识蒸馏的核心概念

1. 软标签与硬标签

  • 硬标签:传统的分类标签,比如[0,1,0]表示第二类
  • 软标签:模型输出的概率分布,比如[0.1,0.8,0.1],包含更丰富的信息

2. 温度参数

温度参数用于调节概率分布的"软硬程度":

  • 温度越高,分布越平滑
  • 温度越低,分布越接近硬标签
  • 合适的温度可以帮助学生模型更好地学习

下面是一个例子:当输入一张马的图片时,对于未调整温度(默认为1)的 Softmax 输出,正标签的概率接近 1,而负标签的概率接近 0。这种尖锐的分布对学生模型不够友好,因为它只提供了关于正确答案的信息,而忽略了错误答案的信息。即驴比汽车更像马,识别为驴的概率应该大于识别为汽车的概率。而通过温度调整后, 最后得到一个相对平滑的概率分布, 称为 “软标签” (Soft Label)。

知识蒸馏的不同方式

1. 基于输出的蒸馏

直接匹配教师模型和学生模型的输出概率分布。

2. 基于中间层特征的蒸馏

匹配模型中间层的特征,让学生模型学习教师模型的"思考过程"。

3. 基于中间层注意力图的蒸馏

传递模型的注意力机制,帮助学生模型知道"该关注什么"。

4.基于中间层权重的蒸馏

5.基于中间层稀疏模式的蒸馏

6.基于中间相关信息的蒸馏

创新的蒸馏方法

1. 自蒸馏

模型自己当老师,通过多次迭代提升性能,不需要额外的教师模型。

2. 在线蒸馏

教师模型和学生模型同时训练,相互学习,提高效率。

3.结合在线蒸馏和自蒸馏

实际应用场景

知识蒸馏在多个领域都有成功应用:

1. 目标检测

不仅传递分类知识,还包括物体定位信息。

2. 语义分割

通过像素级、成对和整体三个层面的蒸馏提升性能。

3. 生成对抗网络(GAN)

结合蒸馏、重构和对抗性损失实现模型压缩。

4. 自然语言处理

特别强调注意力机制的传递,提升文本处理能力。

网络增强:另一种思路

除了传统的知识蒸馏,网络增强(NetAug)提供了一个新视角:

  • 不是简化大模型,而是增强小模型
  • 将小模型嵌入到大模型中学习
  • 通过多重监督提升性能

代码实践

主要包含:

KD知识蒸馏        DKD解耦知识蒸馏

其区别主要集中在损失函数的不同。

现有的知识蒸馏方法主要关注于中间层的深度特征蒸馏,而对logit蒸馏的重要性认识不足。[DKD]()重新定义了传统的知识蒸馏损失函数,将其分解为目标类知识蒸馏(TCKD)和非目标类知识蒸馏(NCKD)。

- 目标类知识蒸馏(TCKD):关注于目标类的知识传递。

- 非目标类知识蒸馏(NCKD):关注于非目标类之间的知识传递。

# kd_loss
def loss(logits_student, logits_teacher, temperature):
    log_pred_student = F.log_softmax(logits_student / temperature, dim=1)
    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
    loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean()
    loss_kd *= temperature**2
    return loss_kd
import torch
import torch.nn as nn
import torch.nn.functional as F


def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):
    # 使用 _get_gt_mask 和 _get_other_mask 函数创建掩码,分别用于标识真实标签和其他类别。这使得损失计算可以选择性地关注特定类别。
    gt_mask = _get_gt_mask(logits_student, target)
    other_mask = _get_other_mask(logits_student, target)
    pred_student = F.softmax(logits_student / temperature, dim=1)
    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
    # 使用 cat_mask 函数将掩码应用于学生和教师的预测,得到只关注特定类别的输出。
    pred_student = cat_mask(pred_student, gt_mask, other_mask)
    pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
    log_pred_student = torch.log(pred_student)
    # 计算针对真实标签的 KL 散度损失(tckd_loss),并进行温度缩放
    tckd_loss = (
        F.kl_div(log_pred_student, pred_teacher, size_average=False)
        * (temperature**2)
        / target.shape[0]
    )
    # 计算针对其他类别的 KL 散度损失(nckd_loss),通过从 logits 中减去一个大的值(1000.0)来忽略真实标签的影响。
    pred_teacher_part2 = F.softmax(
        logits_teacher / temperature - 1000.0 * gt_mask, dim=1
    )
    log_pred_student_part2 = F.log_softmax(
        logits_student / temperature - 1000.0 * gt_mask, dim=1
    )
    nckd_loss = (
        F.kl_div(log_pred_student_part2, pred_teacher_part2, size_average=False)
        * (temperature**2)
        / target.shape[0]
    )
    # 原论文中这里加入了一个 WarmUP
    return alpha * tckd_loss + beta * nckd_loss


def _get_gt_mask(logits, target):
    # 生成一个与 logits 形状相同的全零张量,并在真实标签对应的位置设置为 1,最终返回一个布尔掩码。这个掩码用于在损失计算中关注真实类别。
    target = target.reshape(-1)
    mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
    return mask


def _get_other_mask(logits, target):
    # 生成一个与 logits 形状相同的全一张量,并在真实标签对应的位置设置为 0,最终返回一个布尔掩码。这个掩码用于在损失计算中关注其他类别。
    target = target.reshape(-1)
    mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
    return mask


def cat_mask(t, mask1, mask2):
    # 将输入张量 t 与两个掩码结合,计算出只关注特定类别的输出。
    # 由于 mask1 只保留真实类别的概率,因此这个求和操作给出了每个样本的真实类别的总概率。
    t1 = (t * mask1).sum(dim=1, keepdims=True)
    t2 = (t * mask2).sum(1, keepdims=True)
    rt = torch.cat([t1, t2], dim=1)
    return rt

完整代码:


原文地址:https://blog.csdn.net/TianxiaZhu824/article/details/144016028

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!