【模型压缩】(二)—— 剪枝

article/2025/9/15 13:24:14

一、概述

剪枝(Pruning)的一些概念:

  • 当提及神经网络的"参数"时,大多数情况指的是网络的学习型参数,也就是权重矩阵weights和偏置bias;
  • 现代网络的参数量大概在百万至数十亿之间,因此实际上在一个网络中也并不是所有权值都是重要的,剪枝的作用就是削减那些不重要权重矩阵的一种直接压缩模型的方式;
  • 对于一个已经训练好的模型,切断或删除某些连接,同时保证不对精度造成重大影响,这样得到的模型就是一个参数较少的剪枝模型;
  • 从生物学的角度来说,人类在成长过程中突触会减少,但思维能力反而更强了;
  • 和dropout的区别:dropout具有随机性,剪枝具有针对性;

下面看一下剪枝的实际操作图:

在这里插入图片描述

二、策略

剪枝主要有以下几种方法:

1、迭代式剪枝:训练权重——剪枝(根据阈值)——重新训练权重【最常用】

2、动态剪枝:剪枝和训练同时进行,在网络的优化目标中加入权重的稀疏正则项,使得网络训练时部分权重趋近于0;
3、对推理过程中单个目标剪枝;

总结:大多数的剪枝方法实际上是迭代的方式进行的,因为修剪后重新训练,可以让模型因修剪操作导致的精度下降恢复过来,然后在进行下一次修剪,直到达到精度下降的阈值,就不再修剪;

策略对比图:

在这里插入图片描述

从图中可以看出,单纯剪枝到50%精度就开始下降,剪枝后训练到80%精度才开始下降,迭代进行剪枝到90%精度才下降;

拓展:
实际上剪枝的大类分为几种:
1、非结构化剪枝:也就是上述介绍的将不重要的权重置为0;
2、结构化剪枝:将模型的一个完整结构剪除,比如channels、filters、layers;
3、自动化剪枝:NAS,需要大量的算力支持;

三、优缺点

优点:

  • 可以应用在训练期间或训练结束后;

  • 对于任意一个结构,可以自主控制推理时间/模型大小与准确率之间的平衡;

  • 可应用于卷积层和全连接层;

缺点:

  • 没有直接切换到一个更好的网络来的有效;

四、代码案例

首先需要明确,剪枝是需要对模型层做一定修改的;

本次代码是基于小模型LeNet进行剪枝实验;

1、对模型结构中的Liner层进行修改,添加mask这个变量(自定义MaskedLinear层)

class MaskedLinear(Module):def __init__(self, in_features, out_features, bias=True):super(MaskedLinear, self).__init__()self.in_features = in_featuresself.out_features = out_features# 将weight转换为可学习的变量self.weight = Parameter(torch.Tensor(out_features, in_features))# 初始化mask的值为1,并转换为可学习的变量self.mask = Parameter(torch.ones([out_features, in_features]), requires_grad=False)if bias:# 对bias进行初始化self.bias = Parameter(torch.Tensor(out_features))else:# 将bias设置为空self.register_parameter('bias', None)	self.reset_parameters()# 参数初始化   def reset_parameters(self):stdv = 1. / math.sqrt(self.weight.size(1))self.weight.data.uniform_(-stdv, stdv)if self.bias is not None:self.bias.data.uniform_(-stdv, stdv)# 前向传播(实际上也是使用标准的Liner层)def forward(self, input):# 其中的weight、mask都定义成可变的可学习变量return F.linear(input, self.weight * self.mask, self.bias)

LeNet的定义没有做任何修改,也就是几层全连接层,就不在这里进行代码展示了;

2、对模型每一层学习到的参数进行处理

for name, p in model.named_parameters():if 'mask' in name:continue# 模型参数tensor = p.data.cpu().numpy()# 梯度信息grad_tensor = p.grad.data.cpu().numpy()# 将参数的值为0的,梯度也更新为0grad_tensor = np.where(tensor == 0, 0, grad_tensor)p.grad.data = torch.from_numpy(grad_tensor).to(device)

3、统计每一层参数的非零数量,可用于展示剪枝的效果

def print_nonzeros(model):nonzero = total = 0for name, p in model.named_parameters():if 'mask' in name:continuetensor = p.data.cpu().numpy()# 用numpy中的函数统计tensor中非0值的数量nz_count = np.count_nonzero(tensor)total_params = np.prod(tensor.shape)nonzero += nz_counttotal += total_paramsprint(f'{name:20} | nonzeros = {nz_count:7} / {total_params:7} ({100 * nz_count / total_params:6.2f}%) | total_pruned = {total_params - nz_count :7} | shape = {tensor.shape}')print(f'alive: {nonzero}, pruned : {total - nonzero}, total: {total}, Compression rate : {total/nonzero:10.2f}x  ({100 * (total-nonzero) / total:6.2f}% pruned)')

4、实现剪枝的具体操作

# 参数s控制剪枝的力度
def prune_by_std(self, s=0.25):for name, module in self.named_modules():if name in ['fc1', 'fc2', 'fc3']:# 取weight值得标准差乘以sthreshold = np.std(module.weight.data.cpu().numpy()) * s# 打印每一层计算标准差阈值后得结果print(f'Pruning with threshold : {threshold} for layer {name}')# 得到阈值后进行剪枝module.prune(threshold)# 具体实现剪枝的函数
def prune(self, threshold):weight_dev = self.weight.device# mask就是一开始传入的参数,全为1mask_dev = self.mask.device# Convert Tensors to numpy and calculatetensor = self.weight.data.cpu().numpy()mask = self.mask.data.cpu().numpy()# 更新mask(小于阈值的时候为0,不小于阈值的还是为1)new_mask = np.where(abs(tensor) < threshold, 0, mask)# weight和新的mask进行矩阵相乘self.weight.data = torch.from_numpy(tensor * new_mask).to(weight_dev)# 更新对应的maskself.mask.data = torch.from_numpy(new_mask).to(mask_dev)         

说明:这里进行剪枝后,模型的精度会有下降,需要进行重新训练;

重新训练直接用原来的优化器参数训练即可,此时置为0的weight也不再参与梯度优化;

五、结果展示

剪枝前,经过了100个epoch:

在这里插入图片描述

此时精度为95.23%,wight参数全部不为0;

经过剪枝后:

在这里插入图片描述

此时可以看出,精度下降到85.08%,但weight的数值缩小了接近22倍,大大减少了参数量;

剪枝后再重新训练100个epoch:

在这里插入图片描述

此时精度又回到了97%,甚至比剪枝前还高,并且压缩度也保持不变;

总结

剪枝的操作总结下来分为几步:

模型的训练 —— 修改要剪枝的层(添加同weight维度的mask) —— 进行剪枝后推理 —— 根据剪枝后的权重重新训练
下图给到了剪枝的一个建议:
在这里插入图片描述

个人理解:剪枝本质就是忽略那些低于阈值的参数,从而减少参数量,使得模型得到压缩;

实际上在每一种结构中都可以用到剪枝,弊端就是工作量较大,需要针对不同层进行修改,并且还要重新训练,如果剪枝的力度过大,可能导致和剪枝前精度相差过大;


http://chatgpt.dhexx.cn/article/qtxGLaGW.shtml

相关文章

环形队列的基本运算算法-数据结构教程

环形队列的基本概念 如图&#xff0c;其实它就是一个队列&#xff0c;就是有点难理解而已&#xff0c;它避免了普通队列的缺点&#xff0c;一样有队列头&#xff0c;队列尾&#xff0c;一样是先进先出的原则。我们采用顺时针的方式来对队列进行排序。 队列头(front) :允许进行删…

一道亚马逊算法面试题的情景分析

阅读博客的朋友可以观看视频&#xff1a; http://study.163.com/course/courseMain.htm?courseId1002942008 我们聚焦于一道亚马逊的算法面试题&#xff0c;通过分析该题&#xff0c;复盘它的解题情景&#xff0c;我们可以初步体会到算法面试的应对步骤&#xff0c;并从中窥…

LeetCode刷题笔记 标准模板库巧解算法题 优先队列

优先队列简介 ​ 优先队列&#xff08;priority queue&#xff09;可以在 O(1) 时间内获得最大值&#xff0c;并且可以在 O(log n) 时间内取出最大值或插入任意值。 ​ 优先队列常常用堆&#xff08;heap&#xff09;来实现。堆是一个完全二叉树&#xff0c;其每个节点的值总…

Python数据结构与算法(3.4)——队列相关应用与习题

Python数据结构与算法(3.4)——队列相关应用与习题 0. 学习目标1. 使用两个栈实现一个队列2. 使用两个队列实现一个栈3. 栈中元素连续性判断4. 重新排列队列中元素顺序5. 反转队列中前 m 个元素的顺序相关链接0. 学习目标 我们已经学习了队列的相关概念以及其实现,同时也了…

第十七章 优先队列优化Dijkstra算法

第十七章 优先队列优化Dijkstra算法 一、普通dijkstra算法的缺陷1、选出最小距离的过程&#xff1a;2、松弛所有点的过程&#xff1a; 二、如何优化1、代码模板&#xff08;1&#xff09;问题&#xff1a;&#xff08;2&#xff09;模板&#xff1a; 2、详细解读 三、优化分析1…

【自顶向下模块化编程】C语言实现多级反馈队列调度算法

自顶向下-多级反馈队列 多级反馈队列算法算法原理算法描述题目摘要 自顶向下模块化设计整体框架具体实现GeneratorSchedulerExecutor 整体代码实现 总结及心得总结心得 多级反馈队列算法 多级反馈队列调度算法是一种CPU处理机调度算法&#xff0c;UNIX操作系统采取的便是这种调…

[算法] 栈和队列

欢迎来到老胡的算法解题思路&#xff0c;本文章主要使用的语言为java&#xff0c;使用的题型为力扣算法题&#xff0c;基于这一篇文章&#xff0c;我将为你介绍栈和队列的基础知识和栈和队列的题型&#xff0c;喜欢的朋友可以关注一下&#xff0c;下次更新不迷路&#xff01; 目…

十道经典面试算法真题详解

前言 分享一下 腾讯常考的十道算法题&#xff08;真题&#xff09;。在金三银四&#xff0c;希望对大家有帮助呀。 重排链表 最长递增子序列 环形链表 反转链表 最长回文子串 全排列 LRU 缓存 合并K个升序链表 无重复字符的最长子串 删除链表的倒数第 N 个结点 1. …

队列相关习题

1.已知循环队列存储在一维数组A0…n-1]中&#xff0c;且队列非空时front和rear分别指向队头元素和队尾元素。若初始时队列为空&#xff0c;且要求第一个进入队列的元素存储在A[0]处&#xff0c;则初始时front和rear的值分别是( )。 A.0&#xff0c;0 B. 0&#xff0c;n-1 C. n-…

java算法面试题_Java算法面试题汇总

原标题&#xff1a;Java算法面试题汇总 1. 字符串 如果IDE没有代码自动补全功能&#xff0c;所以你应该记住下面的这些方法。 toCharArray() // 获得字符串对应的char数组 Arrays.sort() // 数组排序 Arrays.toString(char[] a) // 数组转成字符串 charAt(int x) // 获得某个索…

详解单调队列算法

前言 嘿!彩蛋!感觉有帮助就三连呗! 如果你对这篇文章可感兴趣,可以点击「【访客必读 - 指引页】一文囊括主页内所有高质量博客」,查看完整博客分类与对应链接。 在上一篇文章中,我们介绍了「单调栈」这一最常考察的线性数据结构。而今天我们将继续沿着这个思路,介绍另…

栈和队列相关经典算法题总结(数据结构+C语言)

我们这里针对栈和队列的一些经典算法题做详细讲解: 1.括号匹配问题. 2.用队列实现栈. 3.用栈实现队列. 4.设计循环队列. 一.详细讲解如下: 1.括号匹配问题.(如下图) 给定一个只包括 (&#xff0c;)&#xff0c;{&#xff0c;}&#xff0c;[&#xff0c;] 的字符串 s &am…

qt使用消息队列服务器,qt代码实现消息队列通信

qt代码实现消息队列通信 内容精选 换一换 HBase 1.X版本在RPC流程中&#xff0c;多个数据通信线程会争抢同一个缓存Buffer队列&#xff0c;代码以lock重入锁实现线程安全&#xff0c;锁抢占严重&#xff0c;导致HBase不能充分发挥CPU多核的能力。HBase 1.X版本的RPC通信机制中B…

消息队列MQ常见面试题

面试官在面试候选人时&#xff0c;如果发现候选人的简历中写了在项目中使用了 MQ 技术&#xff08;如 Kafka、RabbitMQ、RocketMQ&#xff09;&#xff0c;基本都会抛出一个问题&#xff1a;在使用 MQ 的时候&#xff0c;怎么确保消息 100% 不丢失&#xff1f; 这个问题在实际…

RabbitMQ消息队列常见面试题总结

1、什么是消息队列&#xff1a; 1.1、消息队列的优点&#xff1a; &#xff08;1&#xff09;解耦&#xff1a;将系统按照不同的业务功能拆分出来&#xff0c;消息生产者只管把消息发布到 MQ 中而不用管谁来取&#xff0c;消息消费者只管从 MQ 中取消息而不管是谁发布的。消息…

【消息队列】面试题及答案整理

消息队列面试题 为什么要使用消息队列/消息队列的应用场景使用了消息队列会有什么缺点如何保证消息队列是高可用的RocketMQ是如何保证消息队列是高可用的 如何保证消息不被重复消费/如何保证消息消费的幂等性如何保证消费的可靠性传输RocketMQ如何保证消费的可靠性传输RabbitMQ…

JAVA——快速排序(详细)

JAVA快速排序的实现 快速排序由于排序效率在同为O(N*logN)的几种排序方法中效率较高&#xff0c;因此经常被采用&#xff0c;再加上快速排序思想----分治法也确实实用&#xff0c;因此很多软件公司的笔试面试&#xff0c;包括像腾讯&#xff0c;微软等知名IT公司都喜欢考这个&…

快速排序算法(java实现)

基本思想 快速排序是一种采用分治法解决问题的一个典型应用&#xff0c;也是冒泡排序的一种改进。它的基本思想是&#xff0c;通过一轮排序将待排记录分割成独立的两部分&#xff0c;其中一部分均比另一部分小&#xff0c;则可分别对这两部分继续进行排序&#xff0c;已达到整…

java快速排序(含快速排序代码)

目录 一&#xff1a;快速排序思想 二&#xff1a;快速排序代码&#xff08;pivot一定时先和arrays【r】先比较&#xff09; 三&#xff1a;结果 一&#xff1a;快速排序思想 假设我们现在对“6 1 2 7 9 3 4 5 10 8”这个10个数进行排序。首先在这个序列中随便找一个数作为基准…

快速排序 Java 实现

概念 快速排序&#xff08;Quicksort&#xff09;是对冒泡排序的一种改进。 参考: [数据结构与算法(Kotlin语言)]1.冒泡排序&#xff08;Bubble Sort&#xff09; 快速排序是C.R.A.Hoare于1962年提出的一种划分交换排序。它采用了一种分治的策略&#xff0c;通常称其为分治法(…