【YOLO v4 相关理论】Normalization: BN、CBN、CmBN

article/2025/10/6 19:05:34

一、Batch Normalization

论文:https://arxiv.org/pdf/1502.03167.pdf
源码: link.

Batch Normalization是google团队在2015年论文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》提出的。

个人认为这时一篇可以排进深度学习前十的一篇神作,目前大部分的流行算法、模型都会用到BN,它可以加快模型的收敛速度,训练使用BN的模型甚至比不使用BN的模型快10倍。而且更重要的是在一定程度缓解了深层网络中“梯度弥散(特征分布较散)”的问题。在代码中我们基本上会默认使用Conv + BN + activation function的组合,但是bn真正是如何运作的很少有提及。

先来从直观上看下怎么使用Batch Normalization:
在这里插入图片描述

1、背景知识

什么是特征Normaliztion(Scaling 归一化、标准化)?
数据的归一(normalization)是将数据按比例缩放,使之落入一个小的特定区间。
可分为线性函数归一化(Min-Max Scaling)和零均值归一化(Zero-Score Normalization)两种

线性函数归一化(Min-Max Scaling)

公式:

X ∗ = x − x m e a n x m a x − x m i n X^*=\frac{x-x_{mean}}{x_{max}-x_{min}} X=xmaxxminxxmean

其中,X为原始数据 ,Xmean 为原始数据均值 ,Xmax为原始数据的最大值 ,Xmin为原始数据的最小值

线性函数归一化(Min-Max Scaling)

公式:

z = x − μ σ z=\frac{x-\mu}{\sigma} z=σxμ

其中,μ为原始特征的均值、σ为原始特征的标准差(方差)。
它会将原始数据映射到均值为0、标准差为1的分布上(高斯分布/正态分布)

feature map 为什么要用Normaliztion(归一化)?
1、方便训练、提高训练速度
2、防止模型梯度爆炸
3、提升模型的精度

方便训练、提高收敛速度

在这里插入图片描述
如上图是我从李宏毅老师的BN讲解视频.中截取的一张图。左边的图表示没有做Normalization的输入数据,所以两个数据的值是相差较大的,假设我们这里 x 2 > > x 1 x2 >> x1 x2>>x1,那么经过 w x + b wx+b wx+b 后再通过激活函数得到预测值 a a a,通过预测值和真实值得到损失函数 LossL。

因为 x 2 > > x 1 x2 >> x1 x2>>x1, 所以W2对LossL的影响非常大,而W1对LossL的影响很小,画出 损失函数和两个权重W1 和 W2的图像如左下图(椭圆形等高线)。在W2方向上grad很大, 在W1上grad较小。那么在训练的时候,如果需要改变较大的话,就需要给W2方向一个较小的learning_rate,给W1方向一个较小的learning_rate,这对我们的训练来说肯定大大的增加了难度的。

同理,如果对数据数据做过Normalization(使所以数据满足均值为0,方差为1的分布)的话,那么 x 1 x1 x1 x 2 x2 x2差不多大,W1和W2对Loss的损失差不多大,那么就会产生右下图(原形等高线)。在W2和W1方向上grad差不多大,那么我们就可以只给一个learning_rate进行训练,这就大大降低了我们的训练难度。

防止模型梯度爆炸

在这里插入图片描述
如上图是均值为0,方差为1的标准正态分布图,由上图可知,64%的概率x其值落在[-1,1]的范围内;95%的概率x其值落在了[-2,2]的范围内。那么这有什么意义呢?我们都知道输入值在经过加权(wx+b)后,会经过激活函数(sigmoid 、tanh、relu等)激活,假设非线性函数是sigmoid,那么看下sigmoid(x)函数及其导数图形:
在这里插入图片描述
在没有经过Normalization前,95%的值落在了[-8,4]之间,从sigmoid函数图可以看出,在[-8, -2] 和 [2, 4]这很明显是梯度饱和区(在这个区域梯度几乎消失,非常难以训练,训练起来速度特别的慢)。而经过BN后,目前大部分Activation的值落入非线性函数的线性区内,其对应的导数远离导数饱和区,这样来加速训练收敛过程,防止梯度爆炸。

提升模型的精度

每个维度的量纲其实已经等价了,每个维度都服从均值为0,、方差为1的正态分布,在计算距离的时候,每个维度都是去量纲化的,避免了不同量纲的选取对距离计算产生的巨大影响。

为什么要将数据Batch后再送进模型?
1、Batch之后可以将一个Batch的数据放到一个矩阵中,使用GPU进行矩阵运算,加速运算
2、Batch在处理时,我们要尽可能的大,用一个Batch的均值和方差作为对整个数据集均值和方差的估计

我们刚刚有说让feature map满足某一分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律,也就是说要计算出整个训练集的feature map然后在进行标准化处理,对于一个大型的数据集明显是不可能的,所以论文中说的是Batch Normalization,也就是我们计算一个Batch数据的feature map然后在进行标准化(batch越大越接近整个数据集的分布,效果越好)。

什么是Internal Covariate Shift(内部协变量偏移)?
内部协变量偏移指的是当前面的一些层(参数)发生很小的变化,会对后面的层造成很大的影响。后面的层需要不断的适应前面层的变化,导致非常难以训练。
这个问题使用BN后可以得以改善

为什么要对模型的每一层的输出都使用Normalization?
我们在上面讲了Normalization(数据要满足分布规律)的好处,虽然对于输出数据进行了Normalization,但是对于Conv2而言输入的feature map就不一定满足某一分布规律了。所以我们这里会对每一层在进行加权计算之后,都进行Normalization,最后再送入激活函数。
注意这里所说满足某一分布规律并不是指某一个feature map的数据要满足分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律
而且这样也可以减轻上面的Internal Covariate Shift(内部协变量偏移)现象。

2、训练和推理

在这里插入图片描述
BN可以作为神经网络的一层,放在激活函数(如Relu)之前。
上图是原论文截取的一张图,描述的是训练的步骤(对每一个mini-batch):

  1. 求出一个mini-batch的均值mean
  2. 求出一个mini-batch的方差/标准差 variance
  3. 使用求得的均值和方差对该批次的训练数据做归一化,获得0-1分布。其中ε是为了避免除数为0时所使用的微小正数
  4. 尺度变换和偏移:将 x i x_i xi乘以 γ \gamma γ调整数值大小,再加上β增加偏移后得到 y i y_i yi,这里的 γ \gamma γ是尺度因子,β是平移因子。这一步是BN的精髓,由于归一化后的 x i x_i xi 基本会被限制在正态分布下,使得网络的表达能力下降。为解决该问题,我们引入两个新的参数: γ \gamma γ 和 β。 γ \gamma γ和β是在训练时网络自己学习得到的。

那么测试时又该怎么用BN呢?

测试阶段不需要每一步都计算出均值和方差,我们会选出训练时具有代表性的均值和方差带入公式。而这个代表性的就是指的是训练集中计算出的所有均值和方差的平均,因为我们在每一个mini-batch计算均值和方差的时候都会保存好相应的均值和方差的,所以可以很方便的计算出。之后计算BN还是和训练时的公式一样,这里不再赘述。

3、计算示例

在这里插入图片描述

  1. 这里的 u 1 u_1 u1是对整个batch的channel1的所有数据而言的,同理也可计算出整个batch的channel2的所有数据的均值 u 2 u_2 u2,再组合成 u u u
  2. 利用均值和方差公式计算出方差 σ 2 \sigma^2 σ2
  3. 对mini-batch的每一个channel的每一个元素,利用计算的均值 u u u和方差 σ 2 \sigma^2 σ2带入BN公式,就可求出对应位置的值。

4、代码实现

import randomimport torch.nn as nn
import torchdef BN(feature, mean, var):feature_shape = feature.shape   # (2, 2, 2, 2) = (batch_size, C, H, W)for i in range(feature_shape[1]):   # feature_shape[1] = 2 = C: channel# [batch, channel, height, width]feature_t = feature[:, i, :, :]mean_t = feature_t.mean()  # 求出整个channel的mean# 训练:总体标准差std_t1 = feature_t.std()   # 求出整个channel的std# 测试:样本标准差std_t2 = feature_t.std(ddof=1)# bn   对第i个channel的每一个元素  进行norm  初始伽马=1 贝塔=0feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / std_t1# update calculating mean and var  记录下mean和var用于测试集用# 训练时使用总体标准差   测试时使用样本标准差# 0.1为momentummean[i] = mean[i] * (1-0.1) + mean_t * 0.1var[i] = var[i] * (1-0.1) + (std_t2 ** 2) * 0.1return feature, mean, varif __name__ == '__main__':random.seed(1)# 随机生成一个batch为2,channel为2,height=width=2的特征向量# [batch, channel, height, width]feature = torch.randn(2, 2, 2, 2)print("=============feature================")print(feature)# 初始化统计均值和方差mean = [0.0, 0.0]variance = [1.0, 1.0]# print(feature1.numpy())# # 注意要使用copy()深拷贝feature_bn, mean_bn, variance_bn = BN(feature.numpy().copy(), mean, variance)print("================feature_bn_myself================")print(feature_bn)print("================mean================")print(mean_bn)print("================variance================")print(variance_bn)#bn = nn.BatchNorm2d(2)output = bn(feature)print("================feature_bn_pytorch================")print(output)

输入:
在这里插入图片描述
计算的均值和方差:
在这里插入图片描述
自己写的BN输出:
在这里插入图片描述
调用官方的BN输出:
在这里插入图片描述

5、BN的优点总结

  1. 调参简单多了,对于权重初始化要求没那么高
  2. 起到了正则化的效果,可以不再使用Dropout,也可以不再使用L2正则化
  3. 可以使用大的学习率而没有任何副作用,大大的加速了训练
  4. 一定程度缓解了深层网络中“梯度弥散(特征分布较散)”的问题
  5. 改善了Internal Covariate Shift(内部协变量偏移)现象
  6. 甚至可以提升模型精度。

总而言之,经过这么简单的变换,带来的好处多得很,这也是为何现在BN这么快流行起来的原因。

6、使用BN的注意事项

  1. 训练时要将traning参数设置为True,在验证时将trainning参数设置为False。在pytorch中可通过创建模型的model.train()和model.eval()方法控制。
  2. batch size尽可能设置大点,设置小后表现可能很糟糕,设置的越大求的均值和方差越接近整个训练集的均值和方差。
  3. 建议将bn层放在卷积层(Conv)和激活层(例如Relu)之间,且卷积层不要使用偏置bias,因为没有用,具体推理过程如下
    在这里插入图片描述

二、CBN

论文:https://arxiv.org/abs/2002.05712.
源码:CBN.py.

2.1、背景

从上节BN的学习我们可以知道BN有很多很多的优点,比如:

  1. 对权重初始化的要求没那么高了
  2. 可以使用更大的学习率进行训练,加大了训练的速度
  3. 一定程度上缓解了梯度消失的问题
  4. 解决了内部协变量偏移的现象
  5. 还具有一定的正则化的作用,可以不再使用DropOut
  6. 甚至还可以提升模型的精度

但是,BN有一个致命的缺陷,那就是我们在设计BN的时候有一个前提条件就是当batch_size足够大的时候,用mini-batch算出的BN参数( μ \mu μ σ \sigma σ)来近似等于整个数据集的BN参数。但是当batch_size较小的时候,BN的效果会很差。如下图1的BN线,随着batch_size的减小,BN的表现骤减。
在这里插入图片描述
针对这个问题,很多学者从空间角度做了很多的尝试,比如LN、IN、GN等,但是这些方法都是针对不同的任务的,不具备一定的普适性。所以CBN就改变了思路,希望从时间维度尝试解决这个问题:batch_size太小,本质上还是数据太少不足以近似整个训练集的BN参数,那就通过计算前几个iteration计算好的BN参数( μ \mu μ σ \sigma σ),一起来计算这次iter的BN参数。

问题1:这种用前几个iteration计算好的BN参数( μ \mu μ σ \sigma σ)来计算这次iter的BN参数的方法会有一个问题:过去的BN参数是由过去的网络参数计算出来feature map再计算得到的,而本轮迭代时计算BN参数时我的参数其实已经过时了,如图1的 Native CBN,直接用以前的网络参数来计算以前的BN参数效果并不好?

为了解决这个问题,我们引入了泰勒公式。因为由于梯度下降的机制,模型再训练过程中相近的iteration所对应的模型参数的变化是平滑的,所有我们可以用泰勒公式来估算以前的网络参数。

2.2、泰勒公式表达

回忆BN:

x ^ t , i ( θ t ) = x t , i ( θ t ) − μ t ( θ t ) σ ( θ t ) 2 + ε ( 1 ) \hat{x}_{t,i}(\theta_t)=\frac{x_{t,i}(\theta_t)-\mu_t(\theta_t)}{\sqrt{\sigma(\theta_t)^2+\varepsilon}}\qquad (1) x^t,i(θt)=σ(θt)2+ε xt,i(θt)μt(θt)(1) μ t ( θ t ) = 1 m ∑ i = 1 m x t , i ( θ t ) ( 2 ) \mu_t(\theta_t)=\frac{1}{m}\sum_{i=1}^m {x}_{t,i}(\theta_t)\qquad (2) μt(θt)=m1i=1mxt,i(θt)(2) σ ( θ t ) = 1 m ∑ i = 1 m ( x t , i ( θ t ) − μ t ( θ t ) ) 2 = ν t ( θ t ) − μ t ( θ t ) 2 ( 3 ) \sigma(\theta_t)=\sqrt{\frac{1}{m}\sum_{i=1}^m({x}_{t,i}(\theta_t)-\mu_t(\theta_t))^2} = \sqrt{\nu_t(\theta_t)-\mu_t(\theta_t)^2}\qquad (3) σ(θt)=m1i=1m(xt,i(θt)μt(θt))2 =νt(θt)μt(θt)2 (3)
ν t ( θ t ) = 1 m ∑ i = 1 m x t , i ( θ t ) 2 ( 4 ) \nu_t(\theta_t)=\frac{1}{m}\sum_{i=1}^m x_{t,i}(\theta_t)^2 \qquad (4) νt(θt)=m1i=1mxt,i(θt)2(4)
y t , i ( θ t ) = γ x ^ t , i ( θ t ) + β ( 5 ) y_{t,i}(\theta_t)=\gamma \hat{x}_{t,i}(\theta_t)+\beta \qquad (5) yt,i(θt)=γx^t,i(θt)+β(5)
其中:
θ t \theta_t θt表示第 t t t 个mini-batch 的网络参数;
x t , i ( θ t ) x_{t,i}(\theta_t) xt,i(θt)表示第 t t t个mini-batch中第i个样本经网络得到的feature map;
x ^ t , i ( θ t ) \hat{x}_{t,i}(\theta_t) x^t,i(θt)表示feature map中第i个样本经BN后得到的新样本的feature map(均值为0, 方差为1);
μ t ( θ t ) \mu_t(\theta_t) μt(θt) σ ( θ t ) \sigma(\theta_t) σ(θt)表示当前mini-batch计算出来的均值和方差 ε \varepsilon ε为防0系数;
γ \gamma γ β \beta β是BN需要学习的参数;
m表示mini-batch中有m个样本

使用泰勒公式近似之前iter的均值和方差:

假设现在是第 t t t 次迭代,假如要算之前的第 ( t − τ ) (t-\tau) (tτ) 次迭代的均值和方差
但是之前迭代计算的均值和方差都是用之前的网络参数( θ t − τ \theta_{t-\tau} θtτ)计算得到的 => μ t ( θ t ) \mu_t(\theta_t) μt(θt) ν t ( θ t ) \nu_t(\theta_t) νt(θt)
因为我们又发现连续几次迭代的网络参数的变化是平滑的,所以根据泰勒公式展开式可以估算上述两个参数
μ t − τ ( θ t ) = μ t − τ ( θ t − τ ) + ∂ μ t − τ ( θ t − τ ) ∂ θ t − τ ( θ t − θ t − τ ) + O ( ∣ ∣ θ t − θ t − τ ∣ ∣ 2 ) ( 6 ) \mu_{t-\tau(\theta_t)}= \mu_{t-\tau}(\theta_{t-\tau})+ \frac{\partial \mu_{t-\tau}(\theta_{t-\tau})}{\partial \theta_{t-\tau}} (\theta_t-\theta_{t-\tau}) +O(|| \theta_t-\theta_{t-\tau} ||^2) \qquad (6) μtτ(θt)=μtτ(θtτ)+θtτμtτ(θtτ)(θtθtτ)+O(θtθtτ2)(6)
ν t − τ ( θ t ) = ν t − τ ( θ t − τ ) + ∂ ν t − τ ( θ t − τ ) ∂ θ t − τ ( θ t − θ t − τ ) + O ( ∣ ∣ θ t − θ t − τ ∣ ∣ 2 ) ( 7 ) \nu_{t-\tau}(\theta_t)= \nu_{t-\tau}(\theta_{t-\tau})+\frac{\partial \nu_{t-\tau}(\theta_{t-\tau})}{\partial \theta_{t-\tau}} (\theta_t-\theta_{t-\tau}) +O(|| \theta_t-\theta_{t-\tau} ||^2) \qquad (7) νtτ(θt)=νtτ(θtτ)+θtτνtτ(θtτ)(θtθtτ)+O(θtθtτ2)(7)
其中 ∂ μ t − τ ( θ t − τ ) ∂ θ t − τ \frac{\partial \mu_{t-\tau}(\theta_{t-\tau})}{\partial \theta_{t-\tau}} θtτμtτ(θtτ) ∂ ν t − τ ( θ t − τ ) ∂ θ t − τ \frac{\partial \nu_{t-\tau}(\theta_{t-\tau})}{\partial \theta_{t-\tau}} θtτνtτ(θtτ)为第 ( t − τ ) (t-\tau) (tτ)次迭代的BN参数对第 ( t − τ ) (t-\tau) (tτ)次迭代的网络参数的偏导数
O ( ∣ ∣ θ t − θ t − τ ∣ ∣ 2 ) O(|| \theta_t-\theta_{t-\tau} ||^2) O(θtθtτ2)表示泰勒展开式的高阶项,当 ( θ t − θ t − τ ) (\theta_t-\theta_{t-\tau}) θtθtτ较小时,高阶项可以忽略不计
但是要精确计算出 ∂ μ − τ ( θ t − τ ) ∂ θ t − τ \frac{\partial \mu- \tau(\theta_{t-\tau})}{\partial \theta_{t-\tau}} θtτμτ(θtτ) ∂ ν − τ ( θ t − τ ) ∂ θ t − τ \frac{\partial \nu- \tau(\theta_{t-\tau})}{\partial \theta_{t-\tau}} θtτντ(θtτ)的计算量会很大,因为 μ t − τ l ( θ t − τ ) \mu^l_{t-\tau}(\theta_{t-\tau}) μtτl(θtτ) ν t − τ l ( θ t − τ ) \nu^l_{t-\tau}(\theta_{t-\tau}) νtτl(θtτ)会依赖之前所有层的网络权重(要算l层就要先算l层之前的所有层)
实际上,我们通过实验发现,当 r < = l r<=l r<=l ∂ μ t l ( θ t ) θ t r \frac{\partial \mu^l_t(\theta_{t})}{ \theta^r_{t}} θtrμtl(θt) ∂ ν t l ( θ t ) θ t r \frac{\partial \nu^l_t(\theta_{t})}{ \theta^r_{t}} θtrνtl(θt) 会减少的很快
在这里插入图片描述
所以,我们在求 ∂ μ t l ( θ t ) θ t r \frac{\partial \mu^l_t(\theta_{t})}{ \theta^r_{t}} θtrμtl(θt) ∂ ν t l ( θ t ) θ t r \frac{\partial \nu^l_t(\theta_{t})}{ \theta^r_{t}} θtrνtl(θt)时,我们直接忽略 l 层之前的层对 l 层的影响
最终,上面泰勒公式可以近似为:
μ t − τ l ( θ t ) ≈ μ t − τ ( θ t − τ ) l + ∂ μ t − τ l ( θ t − τ ) ∂ θ t − τ l ( θ t l − θ t − τ l ) ( 8 ) {\mu^l_{t-\tau}(\theta_t) \approx \mu^l_{t-\tau(\theta_{t-\tau})}+ \frac{\partial \mu^l_{t-\tau}(\theta_{t-\tau})}{\partial \theta^l_{t-\tau}} (\theta^l_t-\theta^l_{t-\tau}) \qquad (8)} μtτl(θt)μtτ(θtτ)l+θtτlμtτl(θtτ)(θtlθtτl)(8)
ν t − τ l ( θ t ) ≈ ν t − τ l ( θ t − τ ) + ∂ ν t − τ l ( θ t − τ ) ∂ θ t − τ l ( θ t l − θ t − τ l ) ( 9 ) { \nu^l_{t-\tau}(\theta_t) \approx \nu^l_{t-\tau}(\theta_{t-\tau})+\frac{\partial \nu^l_{t-\tau}(\theta_{t-\tau})}{\partial \theta^l_{t-\tau}} (\theta^l_t-\theta^l_{t-\tau}) \qquad (9) } νtτl(θt)νtτl(θtτ)+θtτlνtτl(θtτ)(θtlθtτl)(9)

2.3、CBN细节

Cross-Iteration Batch Normalization细节:

上面利用之前的参数估计出当前参数下 l 层在 ( t − τ ) (t-\tau) (tτ)次迭代的参数值,利用这些估计值可以计算出当前迭代时的BN参数( μ \mu μ ν \nu ν):
μ ˉ t , k l ( θ t ) = 1 k ∑ τ = 0 k − 1 μ t − τ l ( θ t ) ( 10 ) {\bar{\mu}^l_{t,k} (\theta_t) = \frac{1}{k}\sum_{\tau=0}^{k-1}\mu^l_{t-\tau}(\theta_t) } \qquad (10) μˉt,kl(θt)=k1τ=0k1μtτl(θt)(10)
ν ˉ t , k l ( θ t ) = 1 k ∑ τ = 0 k − 1 m a x [ ν t − τ l ( θ t ) , μ t − τ l ( θ t ) 2 ] ( 11 ) {\bar \nu^l_{t,k}(\theta_t) = \frac{1}{k}\sum_{\tau=0}^{k-1}max[\nu^l_{t-\tau}(\theta_t), \mu^l_{t-\tau}(\theta_t)^2]}\qquad (11) νˉt,kl(θt)=k1τ=0k1max[νtτl(θt),μtτl(θt)2](11)
σ ˉ t , k l ( θ t ) = ν ˉ t , k l ( θ t ) − μ ˉ t , k l ( θ t ) 2 ( 12 ) \bar\sigma^l_{t,k}(\theta_t)= \sqrt{\bar\nu^l_{t,k}(\theta_t)-\bar\mu^l_{t,k}(\theta_t)^2} \qquad (12) σˉt,kl(θt)=νˉt,kl(θt)μˉt,kl(θt)2 (12)

其中式10:计算 i t e r a t i o n [ t − τ , t ] iteration[t-\tau,t] iteration[tτt]轮迭代均值的平均;
式11:在有效统计中 ν t − τ l ( θ t ) ≥ μ t − τ l ( θ t ) 2 \nu^l_{t-\tau}(\theta_t) \geq \mu^l_{t-\tau}(\theta_t)^2 νtτl(θt)μtτl(θt)2是一直满足的,但是利用泰勒展开式估算就不一定满足了,不过在代码中是默认过滤掉不满足的情况的,论文中称这样可以获取信息更有意义。
最后,CBN更新featute map方法同CN:
x ^ t , i l ( θ t ) = x t , i l ( θ t ) − u ˉ t , k l ( θ t ) σ ˉ t , k l ( θ t ) 2 + ϵ ( 13 ) \hat{x}^l_{t,i}(\theta_t)=\frac{x^l_{t,i}(\theta_t)-\bar{u}^l_{t,k}(\theta_t)}{\sqrt{\bar{\sigma}^l_{t,k}(\theta_t)^2 + \epsilon}} \qquad (13) x^t,il(θt)=σˉt,kl(θt)2+ϵ xt,il(θt)uˉt,kl(θt)(13)
在这里插入图片描述

同时作者指出CBN操作不会引入比较大的内存开销,训练速度不会影响很多,会慢一点点。

2.4、代码实现

class CBatchNorm2d(nn.Module):def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,track_running_stats=True,buffer_num=0, rho=1.0,burnin=0, two_stage=True,FROZEN=False, out_p=False):super(CBatchNorm2d, self).__init__()self.num_features = num_featuresself.eps = epsself.momentum = momentumself.affine = affineself.track_running_stats = track_running_statsself.buffer_num = buffer_numself.max_buffer_num = buffer_numself.rho = rhoself.burnin = burninself.two_stage = two_stageself.FROZEN = FROZENself.out_p = out_pself.iter_count = 0self.pre_mu = []self.pre_meanx2 = []  # mean(x^2)self.pre_dmudw = []self.pre_dmeanx2dw = []self.pre_weight = []self.ones = torch.ones(self.num_features)if self.affine:self.weight = Parameter(torch.Tensor(num_features))self.bias = Parameter(torch.Tensor(num_features))else:self.register_parameter('weight', None)self.register_parameter('bias', None)if self.track_running_stats:self.register_buffer('running_mean', torch.zeros(num_features))self.register_buffer('running_var', torch.ones(num_features))else:self.register_parameter('running_mean', None)self.register_parameter('running_var', None)self.reset_parameters()def reset_parameters(self):if self.track_running_stats:self.running_mean.zero_()self.running_var.fill_(1)if self.affine:self.weight.data.uniform_()self.bias.data.zero_()def _check_input_dim(self, input):if input.dim() != 4:raise ValueError('expected 4D input (got {}D input)'.format(input.dim()))def _update_buffer_num(self):if self.two_stage:if self.iter_count > self.burnin:self.buffer_num = self.max_buffer_numelse:self.buffer_num = 0else:self.buffer_num = int(self.max_buffer_num * min(self.iter_count / self.burnin, 1.0))def forward(self, input, weight):# deal with wight and grad of self.pre_dxdw!self._check_input_dim(input)y = input.transpose(0, 1)return_shape = y.shapey = y.contiguous().view(input.size(1), -1)# burninif self.training and self.burnin > 0:self.iter_count += 1self._update_buffer_num()if self.buffer_num > 0 and self.training and input.requires_grad:  # some layers are frozen!# cal current batch mu and sigmacur_mu = y.mean(dim=1)cur_meanx2 = torch.pow(y, 2).mean(dim=1)cur_sigma2 = y.var(dim=1)# cal dmu/dw dsigma2/dwdmudw = torch.autograd.grad(cur_mu, weight, self.ones, retain_graph=True)[0]dmeanx2dw = torch.autograd.grad(cur_meanx2, weight, self.ones, retain_graph=True)[0]# update cur_mu and cur_sigma2 with presmu_all = torch.stack([cur_mu, ] + [tmp_mu + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) fortmp_mu, tmp_d, tmp_w in zip(self.pre_mu, self.pre_dmudw, self.pre_weight)])meanx2_all = torch.stack([cur_meanx2, ] + [tmp_meanx2 + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) fortmp_meanx2, tmp_d, tmp_w inzip(self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight)])sigma2_all = meanx2_all - torch.pow(mu_all, 2)# with considering countre_mu_all = mu_all.clone()re_meanx2_all = meanx2_all.clone()re_mu_all[sigma2_all < 0] = 0re_meanx2_all[sigma2_all < 0] = 0count = (sigma2_all >= 0).sum(dim=0).float()mu = re_mu_all.sum(dim=0) / countsigma2 = re_meanx2_all.sum(dim=0) / count - torch.pow(mu, 2)self.pre_mu = [cur_mu.detach(), ] + self.pre_mu[:(self.buffer_num - 1)]self.pre_meanx2 = [cur_meanx2.detach(), ] + self.pre_meanx2[:(self.buffer_num - 1)]self.pre_dmudw = [dmudw.detach(), ] + self.pre_dmudw[:(self.buffer_num - 1)]self.pre_dmeanx2dw = [dmeanx2dw.detach(), ] + self.pre_dmeanx2dw[:(self.buffer_num - 1)]tmp_weight = torch.zeros_like(weight.data)tmp_weight.copy_(weight.data)self.pre_weight = [tmp_weight.detach(), ] + self.pre_weight[:(self.buffer_num - 1)]else:x = ymu = x.mean(dim=1)cur_mu = musigma2 = x.var(dim=1)cur_sigma2 = sigma2if not self.training or self.FROZEN:y = y - self.running_mean.view(-1, 1)# TODO: outside **0.5?if self.out_p:y = y / (self.running_var.view(-1, 1) + self.eps) ** .5else:y = y / (self.running_var.view(-1, 1) ** .5 + self.eps)else:if self.track_running_stats is True:with torch.no_grad():self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * cur_muself.running_var = (1 - self.momentum) * self.running_var + self.momentum * cur_sigma2y = y - mu.view(-1, 1)# TODO: outside **0.5?if self.out_p:y = y / (sigma2.view(-1, 1) + self.eps) ** .5else:y = y / (sigma2.view(-1, 1) ** .5 + self.eps)y = self.weight.view(-1, 1) * y + self.bias.view(-1, 1)return y.view(return_shape).transpose(0, 1)def extra_repr(self):return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \'buffer={max_buffer_num}, burnin={burnin}, ' \'track_running_stats={track_running_stats}'.format(**self.__dict__)

三、CmBN(待完善)

在这里插入图片描述

  1. BN:对当前mini-batch进行归一化
  2. CBN: 对当前以及当前往前数3个mini-batch的结果进行归一化
  3. CmBN: CmBN 在整个批次中使用Cross min-batch Normalization 收集统计数据,而非在单独的mini-batch中收集统计数据

Reference

  1. BN1.
  2. BN2.
  3. BN3.
  4. BN4.
  5. CBN1.
  6. CBN2.

http://chatgpt.dhexx.cn/article/IE7SwaaS.shtml

相关文章

Betaflight BN880 GPS 简单测试

Betaflight BN880 GPS 简单测试 1. 源由2. 窗台对比测试3. 开阔区域测试3.1 GPS安装位置3.1.1 BN880 GPS 机尾打印支架 安装位置3.1.2 BN880 GPS 机头固定 安装位置3.1.3 M8N GPS 机尾打印支架 安装位置 3.2 M8N模块历史记录3.3 BN880模块第一次&#xff08;机尾安装&#xff0…

BN(Batch Normalization):批量归一化

现在的神经网络通常都特别深&#xff0c;在输出层像输入层传播导数的过程中&#xff0c;梯度很容易被激活函数或是权重以指数级的规模缩小或放大&#xff0c;从而产生“梯度消失”或“梯度爆炸”的现象&#xff0c;造成训练速度下降和效果不理想。 随着训练的进行&#xff0c;…

通俗理解BN(Batch Normalization)

1. 深度学习流程简介 1&#xff09;一次性设置&#xff08;One time setup&#xff09; - 激活函数&#xff08;Activation functions&#xff09; ​ - 数据预处理&#xff08;Data Preprocessing&#xff09; ​ - 权重初始化&#xff08;Weight Initialization&#xff0…

为什么BN?batch normalization的原理及特点

1 什么是BN&#xff1f; 数据归一化方法&#xff0c;往往用在深度神经网络中激活层之前。其作用可以加快模型训练时的收敛速度&#xff0c;使得模型训练过程更加稳定&#xff0c;避免梯度爆炸或者梯度消失。并且起到一定的正则化作用&#xff0c;几乎代替了Dropout 2 原理 B…

【深度学习基础知识 - 07】BN的原理和作用

Batch Normalization也是深度学习中的一个高频词汇&#xff0c;这篇文章将会对其做一个简单介绍。 目录 1. BN的原理2. BN的作用3. BN层的可学习参数4. infer时BN的处理5. BN的具体计算步骤以及公式6. BN和L2参数权重正则化的区别 1. BN的原理 BN就是在激活函数接收输入之前对…

什么是BN(Batch Normalization)

什么是BN(Batch Normalization)&#xff1f; 在之前看的深度学习的期刊里&#xff0c;讲到了BN&#xff0c;故对BN做一个详细的了解。在网上查阅了许多资料&#xff0c;终于有一丝明白。 什么是BN&#xff1f; 2015年的论文《Batch Normalization: Accelerating Deep Networ…

深度学习—BN的理解(一)

0、问题 机器学习领域有个很重要的假设&#xff1a;IID独立同分布假设&#xff0c;就是假设训练数据和测试数据是满足相同分布的&#xff0c;这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障。那BatchNorm的作用是什么呢&#xff1f;BatchNorm就是在深度神经…

Batch Normalization详解以及pytorch实验

Batch Normalization是google团队在2015年论文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》提出的。通过该方法能够加速网络的收敛并提升准确率。在网上虽然已经有很多相关文章&#xff0c;但基本都是摆上论文中的公式泛…

Java参数校验validation和validator区别

Java参数校验validation和validator区别 1. 参数校验概述2. validation与validator区别3. validation注解说明4. validator注解说明5. 日期格式化说明6. 实现验证6.1 引入依赖6.2 代码实现6.3 实现验证 1. 参数校验概述 常见的业务开发中无可避免的会进行请求参数校验&#xf…

hibernate-validator

validator 简介各种注解好处 validator.validate方法业务逻辑代码中检查传入的参数时为传入的参数类型中各个属性添加注解NotNull、NotBlank、NotEmpty间的区别 简介 validator&#xff0c;翻译过来&#xff0c;就是“验证器”的意思。它是一种注解式参数校验&#xff0c;包名…

validator自定义校验注解及使用

validator自定义校验注解及使用 官方文档&#xff1a;https://docs.jboss.org/hibernate/validator/8.0/reference/en-US/html_single/#validator-customconstraints 用到依赖: <!--validator的依赖如果项目使用的springBoot的依赖可以不用再引入 hibernate-validator 因为…

spring之Validator

初步认识 spring数据验证核心类&#xff1a;①&#xff1a;Validator ②&#xff1a;Errors,两者之间的纽带是Validator中定义的validate方法。 public interface Validator {// 限定Validator的职责&#xff0c;不可能所有的校验全部交给一个Validator来做boolean supports(…

Hibernate-Validator的学习

Hibernate-Validator的学习 此教程基于黑马程序员Java品达通用权限项目&#xff0c;哔哩哔哩链接&#xff1a;https://www.bilibili.com/video/BV1tw411f79E?p49 1.hibernate-validator介绍 早期的网站&#xff0c;用户输入一个邮箱地址&#xff0c;需要将邮箱地址发送到服…

Hibernate Validator源码解析

一、引言 问题&#xff1a;在代码编写的过程中&#xff0c;数据值的校验在JavaEE三层架构&#xff08;展示层、业务层、数据访问层&#xff09;均有涉及&#xff0c;各层的校验需求又是不尽相同的&#xff0c;因此往往会伴随着代码冗余&#xff0c;重复的校验逻辑出现在三层代…

Hibernate Validator简介

亲爱的小伙伴们我来填坑啦&#xff0c;java中优雅的参数校验方法中的校验的实现原理。 1.前言 验证数据是发生在所有应用程序层&#xff08;从表示层到持久层&#xff09;的常见任务。通常在每一层中实现相同的验证逻辑&#xff0c;这既耗时又容易出错。为了避免重复这些验证&…

bootstrapValidator验证最大值最小值范围限制

录入该值的最大值与最小值 bootstrapValidator进行效验&#xff0c;使最小值不可大于最大值&#xff0c;最大值不可小于最小值 刚开始的验证还是没事的&#xff0c;符合正常的验证规则 再把不符合规则的最大值改变&#xff0c;现在最小值已经比最大值小了&#xff0c;但是最大…

class-validator中文教程

官方文档&#xff1a; https://www.npmjs.com/package/class-validator class-validator可以说是一个简化验证的依赖库 &#xff08;采用注释的方式进行校验&#xff09; 但是缺少中文文档和过程&#xff0c;以自己的理解和对官网文档的阅读进行整理输出。 它的好兄弟class-t…

Hibernate Validator 总结大全

背景 代码开发过程中&#xff0c;参数的有效性校验是一项很繁琐的工作&#xff0c; 如果参数简单&#xff0c;就那么几个参数&#xff0c;直接通过ifelse可以搞定&#xff0c;如果参数太多&#xff0c;比如一个大对象有100多个字段作为入参&#xff0c;你如何校验呢&#xff1…

java使用validator进行校验

不管是html页面表单提交的对象数据还是和第三方公司进行接口对接&#xff0c;都需要对接收到的数据进行校验&#xff08;非空、长度、格式等等&#xff09;。如果使用if一个个进行校验&#xff08;字段非常多&#xff09;&#xff0c;这是让人崩溃的过程。幸好jdk或hibernate都…

java validator_Spring中校验器(Validator)的深入讲解

前言 Spring框架的 validator 组件,是个辅助组件,在进行数据的完整性和有效性非常有用,通过定义一个某个验证器,即可在其它需要的地方,使用即可,非常通用。 应用在执行业务逻辑之前,必须通过校验保证接受到的输入数据是合法正确的,但很多时候同样的校验出现了多次,在不…