Decoupled Knowledge Distillation论文阅读+代码解析

article/2025/8/6 10:48:16

本文来自2022年CVPR的文章,论文地址点这里

一. 介绍

知识蒸馏(KD)的通过最小化师生预测对数之间的KL-Divergence来传递知识(下图a)。目前大部分的研究注意力都被吸引到从中间层的深层特征中提取知识。与基于logit的精馏方法相比,特征精馏在各种任务中都具有优越的性能,因此对logit精馏的研究很少。然而,基于特征的方法的训练成本并不令人满意,因为在训练期间,为了提取深度特征,引入了额外的计算和存储使用(例如,网络模块和复杂的操作)。
Logit蒸馏需要边际计算和存储成本,但性能较差。直观地说,logit蒸馏应该达到与特征蒸馏相当的性能,因为logit比深层特征具有更高的语义级别。我们认为,由于未知的原因限制了logit蒸馏的潜力,导致其性能不理想。为了重振基于逻辑的方法,我们从深入研究KD机制开始这项工作。首先,我们将分类预测分为两个层次:(1)对目标类和所有非目标类的二元预测和(2)对每个非目标类的多类预测。在此基础上,我们将经典的KD损失重新表述为两部分,如下图b所示。一种是针对目标类的二元logit蒸馏,另一种是针对非目标类的多类logit蒸馏。为简化起见,我们将其分别命名为目标分类知识蒸馏(TCKD)和非目标分类知识蒸馏(NCKD)。重新构建知识蒸馏的损失计算使我们能够独立地研究这两个部分的影响。
在这里插入图片描述

二. 方法

2.1 重新定义KD

定义。 对于一个第 t t t个类别的训练样本,分类的概率可以表示为 p = \mathbf{p}= p= [ p 1 , p 2 , … , p t , … , p C ] ∈ R 1 × C \left[p_1, p_2, \ldots, p_t, \ldots, p_C\right] \in \mathbb{R}^{1 \times C} [p1,p2,,pt,,pC]R1×C,其中 p i p_i pi表示为第 i i i个类, C C C表示为所有类别的数量。每一个 p \mathbf{p} p可以使用softmax函数进行计算:
p i = exp ⁡ ( z i ) ∑ j = 1 C exp ⁡ ( z j ) , (1) p_i=\frac{\exp \left(z_i\right)}{\sum_{j=1}^C \exp \left(z_j\right)}, \tag1 pi=j=1Cexp(zj)exp(zi),(1)
其中 z i z_i zi表示第 i i i个类的逻辑输出。
为了区分与目标类相关和不相关的部分,我们定义接下来的部分。 b = [ p t , p \ t ] ∈ R 1 × 2 \mathbf{b}=\left[p_t, p_{\backslash t}\right] \in \mathbb{R}^{1 \times 2} b=[pt,p\t]R1×2,计算过程如下:
p t = exp ⁡ ( z t ) ∑ j = 1 C exp ⁡ ( z j ) , p \ t = ∑ k = 1 , k ≠ t C exp ⁡ ( z k ) ∑ j = 1 C exp ⁡ ( z j ) . p_t=\frac{\exp \left(z_t\right)}{\sum_{j=1}^C \exp \left(z_j\right)}, p_{\backslash t}=\frac{\sum_{k=1, k \neq t}^C \exp \left(z_k\right)}{\sum_{j=1}^C \exp \left(z_j\right)} . pt=j=1Cexp(zj)exp(zt),p\t=j=1Cexp(zj)k=1,k=tCexp(zk).
同时,我们定义 p ^ = [ p ^ 1 , … , p ^ t − 1 , p ^ t + 1 , … , p ^ C ] ∈ \hat{\mathbf{p}}=\left[\hat{p}_1, \ldots, \hat{p}_{t-1}, \hat{p}_{t+1}, \ldots, \hat{p}_C\right] \in p^=[p^1,,p^t1,p^t+1,,p^C] R 1 × ( C − 1 ) \mathbb{R}^{1 \times(C-1)} R1×(C1) 表示为非目标类的概率分,其中对于每一个元素计算如下:
p ^ i = exp ⁡ ( z i ) ∑ j = 1 , j ≠ t C exp ⁡ ( z j ) . (2) \hat{p}_i=\frac{\exp \left(z_i\right)}{\sum_{j=1, j \neq t}^C \exp \left(z_j\right)} .\tag2 p^i=j=1,j=tCexp(zj)exp(zi).(2)
重新构建。 我们使用 T \mathcal{T} T以及 S \mathcal{S} S表示为教师和学生网络模型。那么,经典的知识蒸馏使用KL三度去计算损失如下:
K D = K L ( p T ∥ p S ) = p t T log ⁡ ( p t T p t S ) + ∑ i = 1 , i ≠ t C p i T log ⁡ ( p i T p i S ) . (3) \begin{aligned} \mathrm{KD} &=\mathrm{KL}\left(\mathbf{p}^{\mathcal{T}} \| \mathbf{p}^{\mathcal{S}}\right) \\ &=p_t^{\mathcal{T}} \log \left(\frac{p_t^{\mathcal{T}}}{p_t^S}\right)+\sum_{i=1, i \neq t}^C p_i^{\mathcal{T}} \log \left(\frac{p_i^{\mathcal{T}}}{p_i^{\mathcal{S}}}\right) .\tag3 \end{aligned} KD=KL(pTpS)=ptTlog(ptSptT)+i=1,i=tCpiTlog(piSpiT).(3)
接下来,我们使用式子(1)(2)带入到式子(3):
K D = p t T log ⁡ ( p t T p t S ) + ∑ i = 1 , i ≠ t C p \ t T p ^ i T log ⁡ ( p \ t T p ^ i T p \ \ S p ^ i S ) = p t T log ⁡ ( p t T p t S ) + ∑ i = 1 , i ≠ t C p \ t T p ^ i T ( log ⁡ ( p ^ i T p ^ i S ) + log ⁡ ( p \ t T p \ t S ) ) = p t T log ⁡ ( p t T p t S ) + ∑ i = 1 , i ≠ t C p \ t T p ^ i T log ⁡ ( p ^ i T p ^ i S ) + ∑ i = 1 , i ≠ t C p \ t T p ^ i T log ⁡ ( p \ t T p \ t S ) \begin{aligned} \mathrm{KD} &=p_t^{\mathcal{T}} \log \left(\frac{p_t^{\mathcal{T}}}{p_t^S}\right)+\sum_{i=1, i \neq t}^C p_{\backslash t}^{\mathcal{T}} \hat{p}_i^{\mathcal{T}} \log \left(\frac{p_{\backslash t}^{\mathcal{T}} \hat{p}_i^{\mathcal{T}}}{p_{\backslash \backslash}^{\mathcal{S}} \hat{p}_i^S}\right) \\ &=p_t^{\mathcal{T}} \log \left(\frac{p_t^{\mathcal{T}}}{p_t^S}\right)+\sum_{i=1, i \neq t}^C p_{\backslash t}^{\mathcal{T}} \hat{p}_i^{\mathcal{T}}\left(\log \left(\frac{\hat{p}_i^{\mathcal{T}}}{\hat{p}_i^{\mathcal{S}}}\right)+\log \left(\frac{p_{\backslash t}^{\mathcal{T}}}{p_{\backslash t}^S}\right)\right) \\ &=p_t^{\mathcal{T}} \log \left(\frac{p_t^{\mathcal{T}}}{p_t^S}\right)+\sum_{i=1, i \neq t}^C p_{\backslash t}^{\mathcal{T}} \hat{p}_i^{\mathcal{T}} \log \left(\frac{\hat{p}_i^{\mathcal{T}}}{\hat{p}_i^S}\right) \\ &+\sum_{i=1, i \neq t}^C p_{\backslash t}^{\mathcal{T}} \hat{p}_i^{\mathcal{T}} \log \left(\frac{p_{\backslash t}^{\mathcal{T}}}{p_{\backslash t}^S}\right) \end{aligned} KD=ptTlog(ptSptT)+i=1,i=tCp\tTp^iTlog(p\\Sp^iSp\tTp^iT)=ptTlog(ptSptT)+i=1,i=tCp\tTp^iT(log(p^iSp^iT)+log(p\tSp\tT))=ptTlog(ptSptT)+i=1,i=tCp\tTp^iTlog(p^iSp^iT)+i=1,i=tCp\tTp^iTlog(p\tSp\tT)
其中 p \ t T p^{\mathcal{T}}_{\backslash t} p\tT以及 p \ t S p^{\mathcal{S}}_{\backslash t} p\tS表示为类 i i i的不相关的部分,有:
∑ i = 1 , i ≠ t C p \ t T p ^ i T log ⁡ ( p \ t T p \ t S ) = p \ t T log ⁡ ( p \ t T p \ t S ) ∑ i = 1 , i ≠ t C p ^ i T = p \ t T log ⁡ ( p \ t T p \ t S ) \begin{aligned} \sum_{i=1, i \neq t}^C p_{\backslash t}^{\mathcal{T}} \hat{p}_i^{\mathcal{T}} \log \left(\frac{p_{\backslash t}^{\mathcal{T}}}{p_{\backslash t}^{\mathcal{S}}}\right) &=p_{\backslash t}^{\mathcal{T}} \log \left(\frac{p_{\backslash t}^{\mathcal{T}}}{p_{\backslash t}^{\mathcal{S}}}\right) \sum_{i=1, i \neq t}^C \hat{p}_i^{\mathcal{T}} \\ &=p_{\backslash t}^{\mathcal{T}} \log \left(\frac{p_{\backslash t}^{\mathcal{T}}}{p_{\backslash t}^{\mathcal{S}}}\right) \end{aligned} i=1,i=tCp\tTp^iTlog(p\tSp\tT)=p\tTlog(p\tSp\tT)i=1,i=tCp^iT=p\tTlog(p\tSp\tT)
因此,可以得到
K D = p t T log ⁡ ( p t T p t S ) + p \ t T ∑ i = 1 , i ≠ t C p ^ i T ( log ⁡ ( p ^ i T p ^ i S ) + log ⁡ ( p \ t T p \ t S ) ) = p t T log ⁡ ( p t T p t S ) + p \ t T log ⁡ ( p ⟨ t T p \ t S ) ⏟ K L ( b T ∥ b S ) + p \ t T ∑ i = 1 , i ≠ t C p ^ i T log ⁡ ( p ^ i T p ^ i S ) ⏟ K L ( p ^ T ∥ p ^ S ) . (4) \begin{aligned} \mathrm{KD} &=p_t^{\mathcal{T}} \log \left(\frac{p_t^{\mathcal{T}}}{p_t^S}\right)+p_{\backslash t}^{\mathcal{T}} \sum_{i=1, i \neq t}^C \hat{p}_i^{\mathcal{T}}\left(\log \left(\frac{\hat{p}_i^{\mathcal{T}}}{\hat{p}_i^S}\right)+\log \left(\frac{p_{\backslash t}^{\mathcal{T}}}{p_{\backslash t}^S}\right)\right) \\ &=\underbrace{p_t^{\mathcal{T}} \log \left(\frac{p_t^{\mathcal{T}}}{p_t^S}\right)+p_{\backslash t}^{\mathcal{T}} \log \left(\frac{p_{\langle t}^{\mathcal{T}}}{p_{\backslash t}^S}\right)}_{\mathrm{KL}\left(\mathbf{b}^{\mathcal{T}} \| \mathbf{b}^{\mathcal{S}}\right)}+p_{\backslash t}^{\mathcal{T}} \underbrace{\sum_{i=1, i \neq t}^C \hat{p}_i^{\mathcal{T}} \log \left(\frac{\hat{p}_i^{\mathcal{T}}}{\hat{p}_i^S}\right)}_{\mathrm{KL}\left(\hat{\mathbf{p}}^{\mathcal{T}} \| \hat{\mathbf{p}}^{\mathcal{S}}\right)} . \end{aligned} \tag{4} KD=ptTlog(ptSptT)+p\tTi=1,i=tCp^iT(log(p^iSp^iT)+log(p\tSp\tT))=KL(bTbS) ptTlog(ptSptT)+p\tTlog(p\tSptT)+p\tTKL(p^Tp^S) i=1,i=tCp^iTlog(p^iSp^iT).(4)
之后,我们将式子(4)改写为:
K D = K L ( b T ∥ b S ) + ( 1 − p t T ) K L ( p ^ T ∥ p ^ S ) (5) \mathrm{KD}=\mathrm{KL}\left(\mathbf{b}^{\mathcal{T}} \| \mathbf{b}^{\mathcal{S}}\right)+\left(1-p_t^{\mathcal{T}}\right) \mathrm{KL}\left(\hat{\mathbf{p}}^{\mathcal{T}} \| \hat{\mathbf{p}}^{\mathcal{S}}\right) \tag{5} KD=KL(bTbS)+(1ptT)KL(p^Tp^S)(5)
根据式子(5)我们可以或的两个部分: K L ( b T ∥ b S ) \mathrm{KL}\left(\mathbf{b}^{\mathcal{T}} \| \mathbf{b}^{\mathcal{S}}\right) KL(bTbS)表示为教师以及学生的目标类的相似程度。因此我们可以命名其为目标类的知识蒸馏(TCKD)。同时 K L ( p ^ T ∥ p ^ S ) \mathrm{KL}\left(\hat{\mathbf{p}}^{\mathcal{T}} \| \hat{\mathbf{p}}^{\mathcal{S}}\right) KL(p^Tp^S)表示为非目标类的学生模型和教师模型的相似程度。因此,我们可以进一步将式子(5)改写为:
K D = T C K D + ( 1 − p t T ) N C K D . (6) \mathrm{KD}=\mathrm{TCKD}+\left(1-p_t^{\mathcal{T}}\right) \mathrm{NCKD} . \tag6 KD=TCKD+(1ptT)NCKD.(6)

2.2 TCKD以及NCKD的影响

这部分大概描写的是作者做的哪些实验去验证这两部分,这里我就不在去解释一次。简单来说,对于TCKD来说,它传递了样本难度的相关知识,也就是训练样本的难度越大,TCKD体现出来的效果越好。而NCKD则是逻辑蒸馏的主要挑战,可以发现当教师网络预测目标类越精准的时候,NCKD的系数反而越小,则导致其没有起到良好的训练作用,影响了良好的知识传递。

2.3 分解的知识蒸馏(DKD)

根据上面进行分析的,我们可以重新设置我们需要的知识蒸馏的超参数,如下:
D K D = α T C K D + β N C K D . (7) \mathrm{DKD}=\alpha \mathrm{TCKD}+\beta \mathrm{NCKD} .\tag7 DKD=αTCKD+βNCKD.(7)
具体的算法如下:
在这里插入图片描述

三. 代码解析

代码链接点这里

""" 
logits_student : 学生网络的逻辑输出
logits_teacher : 教师网络的逻辑输出
target :标签值
alpha、beta、temperature : 超参数
"""
def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):### 获得每个target值对应的掩码,从而获得p_tgt_mask = _get_gt_mask(logits_student, target)### 获得其他target对应的掩码,从而获得p_{\t}other_mask = _get_other_mask(logits_student, target)pred_student = F.softmax(logits_student / temperature, dim=1)pred_teacher = F.softmax(logits_teacher / temperature, dim=1)## 计算b^T以及b^Spred_student = cat_mask(pred_student, gt_mask, other_mask)pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)log_pred_student = torch.log(pred_student)tckd_loss = (F.kl_div(log_pred_student, pred_teacher, size_average=False)* (temperature**2)/ target.shape[0])pred_teacher_part2 = F.softmax(logits_teacher / temperature - 1000.0 * gt_mask, dim=1)log_pred_student_part2 = F.log_softmax(logits_student / temperature - 1000.0 * gt_mask, dim=1)nckd_loss = (F.kl_div(log_pred_student_part2, pred_teacher_part2, size_average=False)* (temperature**2)/ target.shape[0])return alpha * tckd_loss + beta * nckd_loss

四. 总结

其实本文的想法很简单,但却从数学的角度分析了逻辑知识蒸馏效果不如特征知识蒸馏的原因,并且设置了详细的实验去验证了分解后的知识蒸馏。


http://chatgpt.dhexx.cn/article/Q7ZvSOnb.shtml

相关文章

令牌桶算法

一 算法 令牌桶算法和漏桶算法不同的是,有时后端能够处理一定的突发情况,只是为了系统稳定,一般不会让请求超过正常情况的60%,给容灾留有余地。但漏桶算法中后端处理速度是固定的,对于短时的突发情况,后端…

动态分区分配算法(1、首次适应算法 2、最佳适应算法 3、最坏适应算法 4、邻近适应算法)

文章目录 前言知识总览1、首次适应算法2、最佳适应算法3、最坏适应算法4、邻近适应算法知识回顾与重要考点 前言 此篇文章是我在B站学习时所做的笔记,大部分图片都是课件老师的PPT,方便复习用。此篇文章仅供学习参考。 提示:以下是本篇文章…

《算法4》读书笔记(一)

写在前面:配套网站algs4.cs.princeton.edu,可以把这个网站作为编程的时候的参考资料。这本书比较实用(某瓣评分9.3),但没有动态规划部分,作为两三年没怎么碰过算法和数据结构的菜狗,看了《图解算…

《算法4》深入理解红黑树

红黑树是一种性能非常优秀的数据结构,关键在于它能保证最坏的性能也是对数的,主要是因为它是一种平衡的树,所以也叫平衡查找树。要理解红黑树,最好先看看我的上一篇博客《算法4》符号表以及二叉查找树,了解二叉查找树以…

【算法4总结】第四章:图

目录备份 第四章:图 概述 图可以根据是否有向和带权分成以下四种: 无向图 (无向不带权)有向图 (有向不带权)加权无向图(无向带权)加权有向图(有向带权) …

算法4(一、递归学习)

每次用递归都感觉有点难,这个趁着恶补基础知识的时候,专门看了一遍递归,算法4的。 1.1 递归介绍 方法可以调用自己,例如:下面给出了bin_search的二分查找的一种实现。(算法4中使用的是Java,但…

【算法4总结】第一章:基础

目录备份 第一章:基础 我认为这一章主要介绍的是如何使用工具。 一共五节,前两节主要是对 Java 语法的回顾,第三节则是三个数据结构,背包,队列和栈的API讲解。 而第四节是讲解的是如何分析算法。第五节则是针对具体…

SQL修改语句

如果我们要修改数据库中表的数据&#xff0c;这个时候我们就要使用到UPDATE语句。 UPDATE语句的基本语法是&#xff1a; UPDATE <表名> SET 字段1值1, 字段2值2, ... WHERE ...; 例如&#xff0c;我们想更新employees表id100的记录的last_name和salary这两个字段&…

【数据库】SQL语句之修改语句(INSERT,UPDATE,DELETE)

1.INSERT INSERT INTO <表名> (字段1, 字段2, ...) VALUES (值1, 值2, ...); 例如&#xff1a; 一次插入一个 INSERT INTO students (class_id, name, gender, score) VALUES (2, 小明, M, 80);一次插入多条 INSERT INTO students (class_id, name, gender, score) VA…

SQL Server修改数据

本篇主要讲解的是SQL Server 中修改数据的几种语句&#xff1a; INSERT语句INSERT INTO SELECT语句UPDATE语句DELETE语句 一&#xff1a;INSERT语句 INSERT语句向表中添加新行&#xff0c;以下是INSERT语句的最基本形式&#xff1a; 首先&#xff1a;table_name指定要插入的…

使用SQL语句修改表数据

使用SQL语句修改表数据 文章目录 使用SQL语句修改表数据利用INSERT语句输入数据利用UPDATE语句更新表数据利用DELETE语句删除表中数据利用Truncate Table语句删除表中数据 利用INSERT语句输入数据 INSERT语句的基本语法格式如下&#xff1a; 上述格式主要参数说明如下&#xf…

初探POC编写

文章目录 前言什么是POC什么是 ExpPOC注意事项尝试编写第一个POCpikachu sql盲注poc参考 前言 想锻炼一下编程能力&#xff0c;师兄说以后很重要的&#xff0c;最好学好一点 但是我又想学习安全相关的&#xff0c;那就来练练poc吧 什么是POC PoC(全称: Proof of Concept), …

POC_MeterSphere-RCE

MeterSphere-RCE 漏洞详情影响范围指纹- fingerPOC-YAML《飞致云MeterSphere开源测试平台远程代码执行漏洞》 漏洞详情 MeterSphere一站式开源持续测试平台存在的远程代码执行漏洞。由于自定义插件功能处存在缺陷,未经身份验证的攻击者可利用该漏洞在目标系统上远程执行任意…

【POC---概念验证】

文章目录 前言一、Proof of Concept是什么&#xff1f;验证内容 PoC测试工作准备前提PoC测试工作参与者PoC测试工作准备文档PoC测试工作第一阶段 工作启动第二阶段 产品宣讲及现场集中测试第三阶段 技术测评第四阶段 间歇性测试工作 第五阶段 商务验证第六阶段 背书归档、分析总…

POC_Jenkins

简介 Jenkins是一个开源软件项目,是基于Java开发的一种持续集成工具,用于监控持续重复的工作,旨在提供一个开放易用的软件平台,使软件项目可以进行持续集成,Jenkins用Java语言编写,可在Tomcat等流行的servlet容器中运行,也可独立运行。通常与版本管理工具(SCM)、构建工…

网络安全中的POC、EXP、Payload、ShellCode_网络安全payload是什么意思

什么是 POC、EXP、Payload&#xff1f; POC&#xff1a;概念证明&#xff0c;即概念验证&#xff08;英语&#xff1a;Proof of concept&#xff0c;简称POC&#xff09;是对某些想法的一个较短而不完整的实现&#xff0c;以证明其可行性&#xff0c;示范其原理&#xff0c;其…

Zendstudio 9.0.2 安装Aptana3 并且配置 jQuery

原文地址:http://my.oschina.net/feek/blog/64517 aptana-javascript-jquery.ruble文件夹下载地址: http://dl.dbank.com/c04bfgbchz 一直在用zenstudio9&#xff0c;有时候又需要用到jquery等插件辅助制作前台效果&#xff0c;想安装个Aptana3插件&#xff0c;但是查了好多网…

elfk

1. 2. ​​​​​​​ 3. 4. 5.

c/c++读写文件(c方法篇)

读写文件操作 C语言方法 参考博客&#xff1a;函数基本介绍&#xff1a; https://www.cnblogs.com/lidabo/p/6813354.html 自己写了个demo测试一下fopen、feek、fwrite、fread函数的使用&#xff0c;代码如下&#xff1a; void OperationFile_C() {FILE* fp NULL;fp fopen(…

EFLFK

目录 zookeeper概述 zookeeper 定义 Zookeeper工作机制 Zookeeper特点 Zookeeper数据结构 Zookeeper 应用场景 Zookeeper选举机制 第一次启动选举机制 非第一次启动选举机制 部署 Zookeeper 集群 准备 3 台服务器做 Zookeeper 集群 安装前准备 关闭防火墙 安装 J…