模型压缩一-知识蒸馏

article/2025/9/2 16:11:43

一、知识蒸馏简介

        知识蒸馏是模型压缩方法中的一个大类,是一种基于“教师-学生网络(teacher-student-network)思想”的训练方法, 其主要思想是拟合教师模型(teacher-model)的泛化性等(如输出概率、中间层特征、激活边界等),而不是一个简简单单的0-1类别标签。

       这一技术的理论来自于2015年Hinton发表的一篇论文: Distilling the Knowledge in a Neural Network。知识蒸馏,英文名为Knowledge Distillation, 简称KD,顾名思义,就是将已经训练好的模型包含的知识(“Knowledge”),蒸馏(“Distill”)提取到另一个模型里面去(通常是简单的模型、学生模型)。

        知识蒸馏也可以看成是迁移学习的特例,工业界应用得比较广泛的是将BERT模型蒸馏到较少层的transformer, 或者LSTM、CNN等普通模型。BERT模型由于其强大的特征抽取能力,在很多NLP任务上能够达到soft-state-art的效果。 尽管如此,BERT还是有着超参数量大、占用空间大、占用计算资源大、推理时间长等缺点,即便是大公司等也不能随心所欲地使用。 因此,一个简单的想法便是通过BERT等获取一个简单、但性能更好地轻量级算法模型,知识蒸馏无疑是一种有效的方法。 巨大的BERT在很多业务场景下的线上inference都存在很大的性能瓶颈,于是就有了知识蒸馏的用武之地。

        常见的使用方式是离线fintune BERT模型,训练一个离线指标明显优于小模型的模型, 然后用fintue好的BERT模型作为指导蒸馏一个小的模型,也可以看做是一个muti-task的训练任务。 最后上线用小模型即可,从而获得性能和效果双赢的局面。

 

 

二、知识蒸馏开山之作

        知识蒸馏的开山之作是Hinton于2015年提出的论文: Distilling the Knowledge in a Neural Network,旨在把一个大模型或者多个模型ensemble学到的知识迁移到另一个轻量级单模型上,方便部署。简单的说就是用新的小模型去学习大模型的预测结果,改变一下目标函数。

        为什么蒸馏可以work?好模型的目标不是拟合训练数据,而是学习如何泛化到新的数据。所以蒸馏的目标是让student学习到teacher的泛化能力,理论上得到的结果会比单纯拟合训练数据的student要好。另外,对于分类任务,如果soft targets的熵比hard targets高,那显然student会学习到更多的信息。

        其soft targets学习(重点是温度)的主要公式(即softmax前除以一个整数T【一般在1-100间】)是:

                           

                          

                          

        注意:      

                1. 很简单的方法,即预测标签概率 除以 T,真实标签概率 除以 T,最终loss结果需要 乘以 T的平方。

                2. 自己实验中,更适用于半监督学习,以及均衡的、简单的分类问题等。

 

三、知识蒸馏思维导图

        模型蒸馏的开山之作,其基本思路是让student-model学习teacher-model的输出概率分布,其主要是对输出目标的学习(如分类类别的概率分布)。

       之后,陆续出现从teacher-model的中间层学习其特征feature的论文,主要是图像吧,这样可以学习到更加丰富的信息。

       接着,出现了从teachea-model的激活边界进行学习的论文。

       再然后,出现了student-model自学的论文,主要还是对数据集的操作,batch-size内数据的综合,数据增强等。

       最后一个,是利用对抗神经网络的经验,将student模型看成生成模型,teachea-model看成判断模型......

       借用别人的一张思维导图说明:

 

四、BERT模型蒸馏到简单网络、Tranformers

        BERT蒸馏一般有两种方式,一种是BERT蒸馏到较少网络层的transformers,另外一种是BERT蒸馏到TextCNN、TextRNN等简单网络。

        这里我们介绍一篇BERT蒸馏到BILSTM的论文:

                Distilling Task-Specific Knowledge from BERT into Simple Neural Networks

                论文的关键点在于:

                       1. 适用于少量标注、大批量无标签的数据集,也就是冷启动问题;

                       2. 同知识蒸馏的开山之作,该论文还是两阶段蒸馏;

                       3. 损失函数使用的是常规loss + 均方误差loss;

                       4. 使用[MASK]、N-GRAM等数据增强技术等。

          再介绍一篇BERT蒸馏到少层的Transformer的论文:

                 TINYBERT: DISTILLING BERT FOR NATURAL LANGUAGE UNDERSTANDING

                 论文的关键点在于:

                         1、还是两阶段模型蒸馏,teacher-model到student-model层数有一个映射;

                         2、基于注意力的蒸馏和基于隐状态的蒸馏、基于Embedding的蒸馏;

                         3、MSE-loss损失函数,3个loss相加

            自己实验效果:

                    1. 某任务二分类(有提升):

                            最终 (21 epoch) : mean-P    R       F

                            原始text-cnn    0.961     0.960   0.96

                            蒸馏text-cnn    0.97      0.97    0.97

                            原始Bert-base   0.98      0.98    0.98

                      2. 某任务11分类(长尾问题,效果不好)

                           最终 (21 epoch) :mean-P    R       F

                           原始text-cnn   0.901     0.872    0.879

                           蒸馏text-cnn   0.713     0.766    0.738

                      3. 某任务90分类(长尾问题,效果不好)

                           最终(21 epoch) mean-P    R       F

                           原始text-cnn   0.819     0.8      0.809

                           蒸馏text-cnn   0.696     0.593    0.64  (样本最多的类别不能拟合)

                           原生Bert-base 0.85      0.842    0.846

                     总结:BERT蒸馏到Text-CNN,Bi-LSTM等简单模型,并不适合完全有监督分类的任务,

                                最好还是BERT蒸馏到较少层数的transforerm比较好。

 

五、TextBrewer

5.1 概述:

        TextBrewer科大讯飞开源的一个基于PyTorch的、为实现NLP中的**知识蒸馏**任务而设计的工具包,融合并改进了NLP和CV中的多种知识蒸馏技术,提供便捷快速的知识蒸馏框架,用于以较低的性能损失压缩神经网络模型的大小,提升模型的推理速度,减少内存占用。

5.2 主要特点:

**TextBrewer** 为NLP中的知识蒸馏任务设计,融合了多种知识蒸馏技术,提供方便快捷的知识蒸馏框架,主要特点有:

* 模型无关:适用于多种模型结构(主要面向**Transfomer**结构)

* 方便灵活:可自由组合多种蒸馏方法;可方便增加自定义损失等模块

* 非侵入式:无需对教师与学生模型本身结构进行修改

* 支持典型的NLP任务:文本分类、阅读理解、序列标注等

5.3 支持的知识蒸馏技术:

* 软标签与硬标签混合训练

* 动态损失权重调整与蒸馏温度调整

* 多种蒸馏损失函数: hidden states MSE, attention-based loss, neuron selectivity transfer

* 任意构建中间层特征匹配方案

* 多教师知识蒸馏

 

5.4 自己实验:

某任务90分类数据集,ERNIE【一次推理0.1s】蒸馏到3层transofrmers【一次推理0.02s】

Ernie转Transformer(T3) 蒸馏实验(teacher-epochs-21, student-epochs-48)

lr    batch_size     loss     T    dropout  Mean-P    R       F

ernie:

5e-5    16         bce     1      0.1    0.856    0.871   0.863

蒸馏:

5e-5    16        0.9-kl    1     0.1     0.818    0.818    0.818  (3 - layer)

原生transformers-3:

5e-5    16         bce     1      0.1    0.741    0.768    0.754   (3 - layer)

原始text-cnn  

1e-3     16         bce                      0.819     0.8      0.809

 

总结:TextBrewer还可以,BERT蒸馏到较少层数的transforerm对于简单的全监督分类还是有一定效果的。

 

 

希望对你有所帮助!

 

 

 

 

 

 

 

 


http://chatgpt.dhexx.cn/article/24ICirw7.shtml

相关文章

推荐系统之GBDT+LR

前言 前面讲过的FM与FFM模型虽然增强了模型的交叉能力,但是不管怎样都只能做二阶的交叉,如果想要继续加大特征交叉的维度,那就会出大计算爆炸的情况。所以Facebook提出了梯度提升树(GBDT)逻辑回归(LR&…

使用Keras进行单模型多标签分类

原文:https://www.pyimagesearch.com/2018/05/07/multi-label-classification-with-keras/ 作者:Adrian Rosebrock 时间:2018年5月7日 源码:https://pan.baidu.com/s/1x7waggprAHQDjalkA-ctvg (wa61) 译者&…

LR模型常见问题小议

 LR模型常见问题小议 标签: LR机器学习 2016-01-10 23:33 671人阅读 评论(0) 收藏 举报 本文章已收录于: 分类: 机器学习(10) 作者同类文章 X 版权声明:本文为博主原创文章&…

信用评分卡(A卡) 基于LR模型的数据处理及建模过程

数据来自:魔镜杯风控算法大赛(拍拍贷)。有关数据的具体描述可以看比赛页面。 0. 数据集的关键字段及描述: Master:每一行代表一个样本(一笔成功成交借款),每个样本包含200多个各类…

机器分类---LR分类+模型评估

文章目录 数据集ROC曲线与AUC理论知识曲线理解实例计算 代码 更详细的数据集介绍(有图形分析,应该比较好理解) https://blog.csdn.net/weixin_42567027/article/details/107416002 数据集 数据集有三个类别,每个类别有50个样本。…

python机器学习算法(赵志勇)学习笔记( Logistic Regression,LR模型)

Logistic Regression(逻辑回归) 分类算法是典型的监督学习,分类算法通过对训练样本的学习,得到从样本特征到样本的标签之间的映射关系,也被称为假设函数,之后可利用该假设函数对新数据进行分类。 通过训练数据中的正负样本,学习样本特征到样本标签之间的假设函数,Log…

推荐系统实战中LR模型训练(二)

背景: 上一篇推荐系统实战中LR模型训练(一) 中完成了LR模型训练的代码部分。本文中将详细讲解数据准备部分,即将文本数据数值化为稀疏矩阵的形式。 文本数据: 稀疏矩阵: 实现过程: 文本数据格…

机器学习 | LR逻辑回归模型

逻辑回归(Logistic Regression,简称LR)名为“回归”却是用来分类工作、在线性数据上表现优异的分类器。 视频教程:第07讲:逻辑回归是线性分类器的佼佼者 LR是数据挖掘领域常用的一种分类模型,常用于解决二分类问题,例如垃圾邮件判定、经济预测、疾病诊断(通过年龄、性…

推荐系统实战中LR模型训练(一)

背景: 在“批量导入数据到Redis” 中已经介绍了将得到的itema item1:score1,item2:score2…批量导入到Redis数据库中。本文的工作是运用机器学习LR技术,抽取相应的特征,进行点击率的估计。 点击率(Click-Through-Rate, CTR&#…

Prometheus TSDB存储原理

Python微信订餐小程序课程视频 https://blog.csdn.net/m0_56069948/article/details/122285951 Python实战量化交易理财系统 https://blog.csdn.net/m0_56069948/article/details/122285941 Prometheus 包含一个存储在本地磁盘的时间序列数据库,同时也支持与远程…

数据库必知必会:TiDB(8)TiDB 数据库 SQL 执行流程

数据库必知必会:TiDB(8)TiDB 数据库 SQL 执行流程 数据库 SQL 执行流程DML语句读流程概述SQL的Parse与CompileSQL的Execute DML语句写流程概述执行 DDL语句流程概要执行 知识点回顾 数据库 SQL 执行流程 在TiDB中三个重要组件: …

时不我待,TSDB崛起正当时

近期有小伙伴问Jesse,为什么你们要在现在这个时点做TSDB,这是个好时点吗?我认为这是个挺好的问题,因为再强的个人也比不上一个团队,再牛的团队也需要顺势而为。我们其实一直在深度思考“Why Now”的问题,因…

时间序列数据库TSDB排名

DB-Engines 中时序列数据库排名 我们先来看一下DB-Engines中关于时序列数据库的排名,这是当前(2016年2月的)排名情况: 下面,我们就按照这个排名的顺序,简单介绍一下这些时序列数据库中的一些。下面要介绍的…

TiDB Server

目录 TiDB Server架构 Online DDL GC 缓存管理 热点小表缓存 例题 TiDB Server架构 Protocol Layer:负责处理客户端的连接 Parse,Compile:负责SQL语句的解析与编译,并生成执行计划 Executor,DistSQL&#xff0…

Prometheus TSDB

TSDB 概述: Head: 数据库的内存部分 Block: 磁盘上持久块,是不变的 WAL: 预写日志系统 M-map: 磁盘及内存映射 粉红色框是传入的样品,样品先进入Head中存留一会,然后到磁盘、内存映射中(蓝色框)。然后当内…

TiDB体系结构之TiDB Server

TiDB体系结构之TiDB Server TiDB ServerTiDB Server主要组成模块SQL语句的解析和编译行数据与KV的转化SQL读写相关模块在线DDL相关模块TiDB的垃圾回收TiDB Server的缓存 TiDB Server TiDB Server的主要作用如下: 处理客户端连接SQL语句的解析和编译关系型数据与KV…

TSDB助力风电监控

各位小伙伴大家好,本期Jesse想再来跟大家聊聊TSDB的应用场景,在此也感谢尹晨所著的《时序数据库在风电监控系统中的应用》一文,其为我们探究TSDB在风电系统中的应用提供了重要的帮助。 本文仅代表个人观点,如有偏颇之处&#xff…

dbt-tidb 1.2.0 尝鲜

作者: shiyuhang0 原文来源: https://tidb.net/blog/1f56ab48 本文假设你对 dbt 有一定了解。如果是第一次接触 dbt,建议先阅读 官方文档 或 当 TiDB 遇见 dbt 本文中的示例基于官方维护的 jaffle_shop 项目。关于此项目的细节介绍&a…

为啥用 时序数据库 TSDB

前言 其实我之前是不太了解时序数据库以及它相关的机制的,只是大概知晓它的用途。但因为公司的业务需求,我意外参与并主导了公司内部开源时序数据库influxdb的引擎改造,所以我也就顺理成章的成为时序数据库“从业者”。 造飞机的人需要时刻…

Prometheus 学习之——本地存储 TSDB

Prometheus 学习之——本地存储 TSDB 文章目录 Prometheus 学习之——本地存储 TSDB前言一、TSDB 核心概念二、详细介绍1.block1)chunks2)index3)tombstone4)meta.json 2.WAL 总结 前言 Prometheus 是 CNCF 收录的第二个项目&…