ADMM算法在神经网络模型剪枝方面的应用

article/2025/9/22 11:13:21

文章目录

    • 前言
    • 1. 交替方向乘子法
    • 2. 论文中的表述
    • 3. 对论文中的公式进行推导
    • 4. 代码流程
    • 5. 主要函数实现
    • 6. dense vs. prune(finetune)
    • 结束语

前言

  本篇博客记录一下自己根据对论文 GRIM: A General, Real-Time Deep Learning Inference Framework for Mobile Devices based on Fine-Grained Structured Weight Sparsity 中提到的ADMM算法的理解,给出了ADMM算法的推导过程,并在文章的末尾提供了实现的代码。

1. 交替方向乘子法

  交替方向乘子法(Alternating Direction Method of Multipliers, ADMM)作为一种求解优化问题的计算框架,适用于求解凸优化问题。ADMM算法的思想根源可以追溯到20世纪50年代,在20世纪八九十年代中期存在大量的文章分析这种方法的性质,但是当时ADMM主要用于解决偏微分方程问题。1970年由 R. GlowinskiD. Gabay 等提出的一种适用于可分离凸优化的简单有效方法,并在统计机器学习、数据挖掘和计算机视觉等领域中得到了广泛应用。ADMM算法主要解决带有等式约束的关于两个变量的目标函数的最小化问题,可以看作在增广拉朗格朗日算法基础上发展的算法,混合了对偶上升算法(Dual Ascent)的可分解性和乘子法(Method of Multipliers)的算法优越的收敛性。相对于乘子法,ADMM算法最大的优势在于其能够充分利用目标函数的可分解性,对目标函数中的多变量进行交替优化。在解决大规模问题上,利用ADMM算法可以将原问题的目标函数等价地分解成若干个可求解的子问题,然后并行求解每一个子问题,最后协调子问题的解得到原问题的全局解。1

  优化问题
m i n i m i z e f ( x ) + g ( z ) s u b j e c t t o A x + B z = c minimize\ f(x)+g(z) \\ subject\ to\ Ax+Bz=c minimize f(x)+g(z)subject to Ax+Bz=c  其中, x ∈ R n , z ∈ R m , A ∈ R p × n , B ∈ R p × m , c ∈ R p x \in R^n,z \in R^m,A \in R^{p \times n},B \in R^{p \times m},c \in R^p xRn,zRm,ARp×n,BRp×m,cRp,构造拉格朗日函数为
L p ( x , z , λ ) = f ( x ) + g ( z ) + λ T ( A x + B z − c ) L_p(x,z,\lambda )=f(x)+g(z)+\lambda ^{T}(Ax+Bz-c) Lp(x,z,λ)=f(x)+g(z)+λT(Ax+Bzc)  其增广拉格朗日函数(augmented Lagrangian function)
L p ( x , z , λ ) = f ( x ) + g ( z ) + λ T ( A x + B z − c ) + ρ 2 ∣ ∣ A x + B z − c ∣ ∣ 2 L_p(x,z,\lambda )=f(x)+g(z)+\lambda ^{T}(Ax+Bz-c)+ \frac {\rho} {2}||Ax+Bz-c||^{2} Lp(x,z,λ)=f(x)+g(z)+λT(Ax+Bzc)+2ρAx+Bzc2  对偶上升法迭代更新
( x k + 1 , z k + 1 ) = a r g m i n x , z L p ( x , z , λ k ) λ k + 1 = λ k + ρ ( A x k + 1 + B z k + 1 − c ) (x^{k+1},z^{k+1})=\underset {x,z} {argmin\ } L_p(x,z,\lambda ^k) \\ \lambda ^{k+1}=\lambda ^k+\rho (Ax^{k+1}+Bz^{k+1}-c) (xk+1,zk+1)=x,zargmin Lp(x,z,λk)λk+1=λk+ρ(Axk+1+Bzk+1c)  交替方向乘子法则是在 ( x , z ) (x,z) (x,z)一起迭代的基础上将 x , z x,z x,z分别固定单独交替迭代,即
x k + 1 = a r g m i n x L p ( x , z k , λ k ) z k + 1 = a r g m i n z L p ( x k + 1 , z , λ k ) λ k + 1 = λ k + ρ ( A x k + 1 + B z k + 1 − c ) x^{k+1}=\underset {x} {argmin\ }L_p(x,z^k,\lambda ^k) \\ z^{k+1}=\underset {z} {argmin\ }L_p(x^{k+1},z,\lambda ^k) \\ \lambda ^{k+1}=\lambda ^k+\rho (Ax^{k+1}+Bz^{k+1}-c) xk+1=xargmin Lp(x,zk,λk)zk+1=zargmin Lp(xk+1,z,λk)λk+1=λk+ρ(Axk+1+Bzk+1c)  交替方向乘子的另一种等价形式,将残差定义为 r k = A x k + B z k − c r^k=Ax^k+Bz^k-c rk=Axk+Bzkc,同时定义 u k = 1 ρ λ k u^k=\frac {1} {\rho} \lambda ^k uk=ρ1λk作为缩放的对偶变量(dual variable),有
( λ k ) T r k + ρ 2 ∣ ∣ r k ∣ ∣ 2 = ρ 2 ∣ ∣ r k + u k ∣ ∣ 2 − ρ 2 ∣ ∣ u k ∣ ∣ 2 (\lambda ^k)^Tr^k+\frac {\rho} {2} ||r^k||^2=\frac {\rho} {2}||r^k+u^k||^2-\frac {\rho} {2}||u^k||^2 (λk)Trk+2ρrk2=2ρrk+uk22ρuk2  改写 ADMM 的迭代过程
x k + 1 = a r g m i n x { f ( x ) + ρ 2 ∣ ∣ A x + B z k − c + u k ∣ ∣ 2 } z k + 1 = a r g m i n z { g ( z ) + ρ 2 ∣ ∣ A x k + 1 + B z − c + u k ∣ ∣ 2 } u k + 1 = u k + A x k + 1 + B z k + 1 − c x^{k+1} =\underset {x} {argmin\ }\bigg\{f(x)+\frac {\rho} {2}||Ax+Bz^k-c+u^k||^2\bigg\} \\[5pt] z^{k+1}=\underset {z} {argmin\ }\bigg\{g(z)+\frac {\rho} {2}||Ax^{k+1}+Bz-c+u^k||^2\bigg\} \\[5pt] u^{k+1}=u^k+Ax^{k+1}+Bz^{k+1}-c xk+1=xargmin {f(x)+2ρAx+Bzkc+uk2}zk+1=zargmin {g(z)+2ρAxk+1+Bzc+uk2}uk+1=uk+Axk+1+Bzk+1c

2. 论文中的表述

在这里插入图片描述
在这里插入图片描述

3. 对论文中的公式进行推导

  为便于推导公式,将论文中的进行简化,参数W和b简记为W,此时的优化问题变为
m i n i m i z e f ( W i ) + ∑ i = 1 N g ( Z i ) s u b j e c t t o W i = Z i , i = 1 , 2 , . . . , N minimize\ f(W_i)+\sum_{i=1}^{N} g(Z_i) \\[4pt] subject\ to\ W_i=Z_i, i=1,2,...,N minimize f(Wi)+i=1Ng(Zi)subject to Wi=Zi,i=1,2,...,N  构造拉格朗日函数为
L p ( w , z , λ ) = f ( w ) + ∑ g ( z ) + λ T ( w − z ) L_p(w,z,\lambda )=f(w)+\sum g(z)+\lambda ^{T}(w-z) Lp(w,z,λ)=f(w)+g(z)+λT(wz)  其增广拉格朗日函数为
L p ( w , z , λ ) = f ( w ) + ∑ g ( z ) + λ T ( w − z ) + ∑ ρ 2 ∣ ∣ w − z ∣ ∣ 2 L_p(w,z,\lambda )=f(w)+\sum g(z)+\lambda ^{T}(w-z)+ \sum \frac {\rho} {2}||w-z||^{2} Lp(w,z,λ)=f(w)+g(z)+λT(wz)+2ρwz2  交替方向乘子法:在(x, z)一起迭代的基础上将 x, z 分别固定,单独交替迭代,即
w k + 1 = a r g m i n w L p ( w , z k , λ k ) z k + 1 = a r g m i n z L p ( w k + 1 , z , λ k ) λ k + 1 = λ k + ∑ ρ ( w − z ) w^{k+1}=\underset {w} {argmin\ }L_p(w,z^k,\lambda ^k) \\[4pt] z^{k+1}=\underset {z} {argmin\ }L_p(w^{k+1},z,\lambda ^k) \\[4pt] \lambda ^{k+1}=\lambda ^k+\sum \rho (w-z) wk+1=wargmin Lp(w,zk,λk)zk+1=zargmin Lp(wk+1,z,λk)λk+1=λk+ρ(wz)  定义一个对偶变量
u k = 1 ρ λ k u^k=\frac {1} {\rho} \lambda ^k uk=ρ1λk  改写ADMM的迭代过程
w k + 1 = a r g m i n w { f ( w ) + ∑ ρ 2 ∣ ∣ w − z k + u k ∣ ∣ 2 } z k + 1 = a r g m i n z { ∑ g ( z ) + ∑ ρ 2 ∣ ∣ w k + 1 − z + u k ∣ ∣ 2 } u k + 1 = u k + w k + 1 − z k + 1 w^{k+1} =\underset {w} {argmin\ }\bigg\{f(w)+\sum \frac {\rho} {2}||w-z^k+u^k||^2\bigg\} \\[5pt] z^{k+1}=\underset {z} {argmin\ }\bigg\{\sum g(z)+\sum \frac {\rho} {2}||w^{k+1}-z+u^k||^2\bigg\} \\[5pt] u^{k+1}=u^k+w^{k+1}-z^{k+1} wk+1=wargmin {f(w)+2ρwzk+uk2}zk+1=zargmin {g(z)+2ρwk+1z+uk2}uk+1=uk+wk+1zk+1

4. 代码流程

# 初始化参数Z和U
Z, U = initialize_Z_and_U(model)# 训练model,并更新X,Z,U,损失函数为admm loss
for epoch in range(epochs):for data, target in train_loader:optimizer.zero_grad()output = model(data)loss = admm_loss(model, Z, U, output, target)loss.backward()optimizer.step()W = update_W(model)Z = update_Z(W, U, percent)U = update_U(U, W, Z)# 对weight进行剪枝,返回 mask
mask = apply_prune(model, percent)# 对剪枝后的model进行finetune
finetune(model, mask, train_loader, test_loader, optimizer)

5. 主要函数实现

def admm_loss(args, device, model, Z, U, output, target):idx = 0loss = F.nll_loss(output, target)for name, param in model.named_parameters():if name.split('.')[-1] == "weight":u = U[idx].to(device)z = Z[idx].to(device)# 这里就是推导出来的admm的表达式loss += args.rho / 2 * (param - z + u).norm()return lossdef update_W(model):W = ()for name, param in model.named_parameters():if name.split('.')[-1] == "weight":W += (param.detach().cpu().clone(),)return Wdef update_Z(W, U, args):new_Z = ()idx = 0for w, u in zip(W, U):z = w + upcen = np.percentile(abs(z), 100*args.percent[idx])under_threshold = abs(z) < pcen# percent剪枝率,小于percent分位数的置为0z.data[under_threshold] = 0new_Z += (z,)idx += 1return new_Zdef update_U(U, W, Z):new_U = ()for u, w, z in zip(U, W, Z):new_u = u + w - znew_U += (new_u,)return new_Udef prune_weight(weight, device, percent):# to work with admm, we calculate percentile based on all elements instead of nonzero elements.weight_numpy = weight.detach().cpu().numpy()pcen = np.percentile(abs(weight_numpy), 100*percent)under_threshold = abs(weight_numpy) < pcen# 非结构化剪枝weight_numpy[under_threshold] = 0mask = torch.Tensor(abs(weight_numpy) >= pcen).to(device)return mask

6. dense vs. prune(finetune)

在这里插入图片描述

结束语

  对论文中算法的推导仅限于自己的理解,可能还存在一些问题,欢迎来评论区交流哦^_^

参考教程


  1. 《分布式机器学习:交替方向乘子法在机器学习中的应用》---- 雷大江著 ↩︎


http://chatgpt.dhexx.cn/article/LgKtKilt.shtml

相关文章

ADMM算法及其放缩形式,在压缩快照成像重建的图像重建论文中的公式推导

刚接触图像恢复问题&#xff0c;简单记录一下最近看论文刚弄懂的一个小问题 ADMM算法 ADMM&#xff08;交替方向乘子法&#xff09;是一种解决变量可分离凸优化问题的简单算法&#xff0c;具有求解速度快&#xff0c;收敛性能好的特点。ADMM可以将原问题转换为几个子问题&…

开漏输出和推挽输出的区别?

《《《《正文》》》》》 《理解三极管的原理》 如下图&#xff0c;以NPN三极管为例&#xff1a; 它是一种电流控制型元器件&#xff0c;即基极B的输出输出电流可以实现对元器件的控制。所以可以进一步理解为基极B为控制端&#xff0c;集电极C为输入端&#xff0c;发射极E为输出…

推挽输出和开漏输出

推挽输出&#xff08;push-pull&#xff09;&#xff1a; 推挽输出&#xff08;push-pull&#xff09;&#xff1a; 推挽输出&#xff0c;正如字面上的意思&#xff0c;有“推”&#xff0c;也有“挽”&#xff0c;推挽输出电路运用两个MOS管构成&#xff0c;上面为P-MOS&…

推挽输出和开漏输出详解

序言&#xff1a; 平时&#xff0c;写程序的时候总遇IO口模式的端口配置。但是从来没有仔细研究过具体到底是什么含义。作为一名嵌入式工程师应该是不合格的&#xff0c;现在把端口定义重新梳理一下。 一、NPN和PNP区别 NPN 是用 B→E 的电流&#xff08;IB&#xff0…

STM32的推挽输出和开漏输出

文章目录 前言一、推挽输出二、开漏输出三、区别和适应场景总结前言 本篇文章将带大家了解STM32的推挽输出和开漏输出,并且学习这两个的区别,学习分别在什么时候使用这两个不同的输出方式。 在 STM32 微控制器中,GPIO(General Purpose Input/Output)模块是一个通用的输入…

什么是GPIO的推挽输出和开漏输出

数字芯片GPIO一般分为推挽输出和开漏输出 数字芯片GPIO一般是推挽输出&#xff08;PUSH-PULL&#xff09;&#xff0c;其内部结构如下&#xff1a; 当上面的MOS管导通时&#xff0c;GPIO输出高电平1&#xff0c;称为“推” 当下面MOS管导通时&#xff0c;GPIO输出低电平0&…

浅谈开漏输出和推挽输出的理解

理解电路元件特性 在理解这两种输出之前我们需要对三极管这种电路元器件进行理解&#xff0c;三极管都包括三个部分&#xff0c;基极&#xff08;base&#xff09;、集电极(Collector)以及发射极(Emitter)。他们负责不同的功能&#xff0c; 1.基极主要负责控制电流导通与否 2.…

开漏输出和推挽输出的差别

&#xff27;&#xff30;&#xff29;&#xff2f;内部仅有以上三种组合形式 而当上面任意两种形式组合时则 一、推挽输出 高低电平两两组合则形成了推挽输出的模式。 优点&#xff1a;能输出高低电平、且高低电平都有驱动能力 缺点&#xff1a;不能实现线与的功能&#xff…

终于搞清楚开漏输出和推挽输出这个鬼东西

先说下推挽输出&#xff0c;简单的说&#xff0c;就是想输出高电平&#xff0c;就输出高电平&#xff0c;想输出低电平就输出低电平。 推挽电路上面是NPN三极管&#xff0c;下面是PNP三极管&#xff0c;请注意输入端和输出端的波形。 下面是输入波形 当输入为正时&#xff0c;上…

推挽输出和开漏输出有什么不同?

推挽输出和开漏输出有什么不同&#xff1f; 推挽输出&#xff08;Push-Pull Output&#xff09;开漏输出&#xff08;Open Drain Output&#xff09;两者比较 首先介绍一下什么是推挽输出和开漏输出。 推挽输出&#xff08;Push-Pull Output&#xff09; 推挽输出结构是由两个…

区分推挽输出和开漏输出

推挽输出:可以输出高,低电平,连接数字器件。 输出 0 时&#xff0c;N-MOS 导通&#xff0c;P-MOS 高阻&#xff0c;输出0。 输出 1 时&#xff0c;N-MOS 高阻&#xff0c;P-MOS 导通&#xff0c;输出1&#xff08;不需要外部上拉电路&#xff09;。 开漏输出:输出端相当于三…

GPIO之推挽输出和开漏输出

疑问 GPIO配置为输出时会有两种模式&#xff0c;一种叫推挽输出&#xff0c;一种叫开漏模式。那什么是推挽输出&#xff0c;什么又是开漏输出呢&#xff1f; 三种输出状态 如下图所示为将GPIO配置为输出时的内部示意图&#xff1a; 由上图可以看出&#xff0c;GPIO的输出状…

推挽输出和开漏输出区别

目录 1.推挽输出 2.开漏输出 1.推挽输出 当想输出高电平时&#xff0c;P-MOS导通&#xff0c;N-MOS截止&#xff0c;输出为电源电压VDD 当想输出低电平时&#xff0c;N-MOS导通&#xff0c;P-MOS截止&#xff0c;相当于引脚直接接地&#xff0c;输出低电平 2.开漏输出 开…

从硬件方面理解GPIO的开漏输出和推挽输出

最近在学STM32&#xff0c;看正点原子视频中对开漏输出和推挽输出的讲解视频时&#xff0c;发现原子哥对电路的讲解有一些错误&#xff0c;主要说关于MOS管的开关问题&#xff0c;查了一晚上资料&#xff0c;终于想明白了&#xff0c;特意发个文章分享一下。 这是STM32F4XX中文…

推挽输出与开漏输出

推挽输出 要理解推挽输出&#xff0c;首先要理解好三极管&#xff08;晶体管&#xff09;的原理。下面这种三极管有三个端口&#xff0c;分别是基极&#xff08;Base&#xff09;、集电极&#xff08;Collector&#xff09;和发射极&#xff08;Emitter&#xff09;。下图是NP…

开漏输出与推挽输出

一、开漏输出&#xff1a;集电极开路门(OC)与漏极开路门(OD)一般用于线与和电流驱动的场合&#xff0c;为开集(漏)输出结构。 1. 利用外部电路的驱动能力&#xff0c;减少IC内部的驱动。 2. 可以将多个开漏输出引脚连接在一起&#xff0c;通过一个上拉电阻上拉到VCC&#xff…

开漏输出和推挽输出总结(一看就懂)

推挽输出&#xff08;Push-Pull Output&#xff09; 推挽输出结构是由两个MOS或者三极管收到互补控制的信号控制&#xff0c;两个管子时钟一个在导通&#xff0c;一个在截止&#xff0c;如图1所示&#xff1a; 推挽输出的最大特点是可以真正能真正的输出高电平和低电平&…

开漏输出、推挽输出的区别

前言 background&#xff1a;测试相关设备引脚输出&#xff0c;使用示波器时发现部分引脚需外接上拉电阻至高电平才能在示波器观察到高阻态&#xff0c;为了深究其中原理&#xff0c;查阅了相关资料&#xff0c;发现知乎中有一篇对这两种输出描述得清晰易懂的文章&#xff0c;此…

开漏输出和推挽输出

开漏输出和推挽输出 概述模拟文件下载推挽输出线与开漏输出输出电压最后 概述 在STM32或者GD32中&#xff0c;普通的输出GPIO输出方式主要是开漏输出和推挽输出&#xff0c;下面我们开始讲解这2种模式的区别。需要样片的可以加群申请&#xff1a;615061293。 下图是GPIO内部的…

如何正确理解开漏输出和推挽输出

作者&#xff1a;知乎用户 链接&#xff1a;https://www.zhihu.com/question/28512432/answer/41217074 来源&#xff1a;知乎 著作权归作者所有。商业转载请联系作者获得授权&#xff0c;非商业转载请注明出处。 我觉得下面这个「网上资料」还是很不错的。 单片机I/O口推挽输出…