自学内容网 自学内容网

【知识蒸馏】deeplabv3 logit-based 知识蒸馏实战,对剪枝的模型进行蒸馏训练

本文将对【模型剪枝】基于DepGraph(依赖图)完成复杂模型的一键剪枝 文章中剪枝的模型进行蒸馏训练

一、逻辑蒸馏步骤

  • 加载教师模型
  • 定义蒸馏loss
  • 计算蒸馏loss
  • 正常训练

二、代码

1、加载教师模型

教师模型使用未进行剪枝,并且已经训练好的原始模型。

teacher_model = torch.load('./logs/before_prune.pth', map_location=device)

2、定义蒸馏loss

分割和分类的loss,都是用的softmax。

import torch.nn.functional as F
import torch.nn as nn
# 蒸馏温度
Tempature = 2
def KD_loss(teacher_pred, student_pred):
    t_p = F.softmax(teacher_pred / Tempature, dim=1)
    s_p = F.log_softmax(student_pred / Tempature, dim=1)
    return nn.KLDivLoss(reduction='mean')(s_p, t_p) * (Tempature ** 2)

3、 计算蒸馏loss

teacher_outputs = t_model(imgs)
# 蒸馏loss
soft_loss = KD_loss(teacher_outputs, outputs)
# 总loss = 蒸馏loss*alpha + 原学生模型loss*(1-alpha)
alpha = 0.9
all_loss = loss * (1 - alpha) + soft_loss * alpha

4、正常训练

all_loss.backward()

用剪枝前训练好的模型对剪枝后模型进行蒸馏训练,训练后测试效果如下:
在这里插入图片描述


原文地址:https://blog.csdn.net/m0_51579041/article/details/139095067

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!