Batch Normalization详解以及pytorch实验

article/2025/10/6 20:18:35

Batch Normalization是google团队在2015年论文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》提出的。通过该方法能够加速网络的收敛并提升准确率。在网上虽然已经有很多相关文章,但基本都是摆上论文中的公式泛泛而谈,bn真正是如何运作的很少有提及。本文主要分为以下几个部分:

(1)BN的原理

(2)使用pytorch验证本文的观点

(3)使用BN需要注意的地方(BN没用好就是个坑)

1.Batch Normalization原理

我们在图像预处理过程中通常会对图像进行标准化处理,这样能够加速网络的收敛,如下图所示,对于Conv1来说输入的就是满足某一分布的特征矩阵,但对于Conv2而言输入的feature map就不一定满足某一分布规律了(注意这里所说满足某一分布规律并不是指某一个feature map的数据要满足分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律)。而我们Batch Normalization的目的就是使我们的feature map满足均值为0,方差为1的分布规律。

看到这里应该还是蒙的,不要慌,喝口水,慢慢来。下面是从原论文中截取的原话,注意标黄的部分:

“对于一个拥有d维的输入x,我们将对它的每一个维度进行标准化处理。”  假设我们输入的x是RGB三通道的彩色图像,那么这里的d就是输入图像的channels即d=3,x=(x^{(1)}, x^{(2)}, x^{(3)}),其中x^{(1)}就代表我们的R通道所对应的特征矩阵,依此类推。标准化处理也就是分别对我们的R通道,G通道,B通道进行处理。上面的公式不用看,原文提供了更加详细的计算公式:

我们刚刚有说让feature map满足某一分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律,也就是说要计算出整个训练集的feature map然后在进行标准化处理,对于一个大型的数据集明显是不可能的,所以论文中说的是Batch Normalization,也就是我们计算一个Batch数据的feature map然后在进行标准化(batch越大越接近整个数据集的分布,效果越好)。我们根据上图的公式可以知道\mu _{\ss }代表着我们计算的feature map每个维度(channel)的均值,注意\mu _{\ss }是一个向量不是一个值\mu _{\ss }向量的每一个元素代表着一个维度(channel)的均值。\sigma_{\ss }^{2}代表着我们计算的feature map每个维度(channel)的方差,注意\sigma_{\ss }^{2}是一个向量不是一个值\sigma_{\ss }^{2}向量的每一个元素代表着一个维度(channel)的方差,然后根据\mu _{\ss }\sigma_{\ss }^{2}计算标准化处理后得到的值。下图给出了一个计算均值\mu _{\ss }和方差\sigma_{\ss }^{2}的示例:

上图展示了一个batch size为2(两张图片)的Batch Normalization的计算过程,假设feature1、feature2分别是由image1、image2经过一系列卷积池化后得到的特征矩阵,feature的channel为2,那么x^{(1)}代表该batch的所有feature的channel1的数据,同理x^{^{(2)}}代表该batch的所有feature的channel2的数据。然后分别计算x^{(1)}x^{^{(2)}}的均值与方差,得到我们的\mu _{\ss }\sigma_{\ss }^{2}两个向量。然后在根据标准差计算公式分别计算每个channel的值(公式中的\epsilon是一个很小的常量,防止分母为零的情况)。在我们训练网络的过程中,我们是通过一个batch一个batch的数据进行训练的,但是我们在预测过程中通常都是输入一张图片进行预测,此时batch size为1,如果在通过上述方法计算均值和方差就没有意义了。所以我们在训练过程中要去不断的计算每个batch的均值和方差,并使用移动平均(moving average)的方法记录统计的均值和方差,在训练完后我们可以近似认为所统计的均值和方差就等于整个训练集的均值和方差。然后在我们验证以及预测过程中,就使用统计得到的均值和方差进行标准化处理

细心的同学会发现,在原论文公式中不是还有\gamma\beta两个参数吗?是的,\gamma是用来调整数值分布的方差大小,\beta是用来调节数值均值的位置。这两个参数是在反向传播过程中学习得到的,\gamma的默认值是1,\beta的默认值是0。

2.使用pytorch进行试验

你以为你都懂了?不一定哦。刚刚说了在我们训练过程中,均值\mu _{\ss }和方差\sigma_{\ss }^{2}是通过计算当前批次数据得到的记为为\mu _{now}\sigma _{now}^{2},而我们的验证以及预测过程中所使用的均值方差是一个统计量记为\mu _{statistic}\sigma _{statistic}^{2}\mu _{statistic}\sigma _{statistic}^{2}的具体更新策略如下,其中momentum默认取0.1:

\large \mu _{statistic+1}=(1-momentum)*\mu _{statistic}+momentum*\mu _{now}

\large \sigma _{statistic+1}^{2}=(1-momentum)*\sigma _{statistic}^{2}+momentum*\sigma _{now}^{2}

这里要注意一下,在pytorch中对当前批次feature进行bn处理时所使用的\large \sigma _{now}^{2}总体标准差,计算公式如下:

\bg_white \large \sigma _{now}^{2}=\frac{1}{m}\sum_{i=1}^{m}(x_{i}-\mu _{now})^{2}

在更新统计量\large \sigma _{statistic}^{2}时采用的\large \sigma _{now}^{2}样本标准差,计算公式如下:

\bg_white \large \sigma _{now}^{2}=\frac{1}{m-1}\sum_{i=1}^{m}(x_{i}-\mu _{now})^{2}

下面是我使用pytorch做的测试,代码如下:

(1)bn_process函数是自定义的bn处理方法验证是否和使用官方bn处理方法结果一致。在bn_process中计算输入batch数据的每个维度(这里的维度是channel维度)的均值和标准差(标准差等于方差开平方),然后通过计算得到的均值和总体标准差对feature每个维度进行标准化,然后使用均值和样本标准差更新统计均值和标准差。

(2)初始化统计均值是一个元素为0的向量,元素个数等于channel深度;初始化统计方差是一个元素为1的向量,元素个数等于channel深度,初始化\gamma=1,\beta=0。

import numpy as np
import torch.nn as nn
import torchdef bn_process(feature, mean, var):feature_shape = feature.shapefor i in range(feature_shape[1]):# [batch, channel, height, width]feature_t = feature[:, i, :, :]mean_t = feature_t.mean()# 总体标准差std_t1 = feature_t.std()# 样本标准差std_t2 = feature_t.std(ddof=1)# bn process# 这里记得加上eps和pytorch保持一致feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / np.sqrt(std_t1 ** 2 + 1e-5)# update calculating mean and varmean[i] = mean[i] * 0.9 + mean_t * 0.1var[i] = var[i] * 0.9 + (std_t2 ** 2) * 0.1print(feature)# 随机生成一个batch为2,channel为2,height=width=2的特征向量
# [batch, channel, height, width]
feature1 = torch.randn(2, 2, 2, 2)
# 初始化统计均值和方差
calculate_mean = [0.0, 0.0]
calculate_var = [1.0, 1.0]
# print(feature1.numpy())# 注意要使用copy()深拷贝
bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)bn = nn.BatchNorm2d(2, eps=1e-5)
output = bn(feature1)
print(output)

首先我在最后设置了一个断点进行调试,查看下官方bn对feature处理后得到的统计均值和方差。我们可以发现官方提供的bn的running_mean和running_var和我们自己计算的calculate_mean和calculate_var是一模一样的(只是精度不同)。

然后我们打印出通过自定义bn_process函数得到的输出以及使用官方bn处理得到输出,明显结果是一样的(只是精度不同):

3.使用BN时需要注意的问题

(1)训练时要将traning参数设置为True,在验证时将trainning参数设置为False。在pytorch中可通过创建模型的model.train()和model.eval()方法控制。

(2)batch size尽可能设置大点,设置小后表现可能很糟糕,设置的越大求的均值和方差越接近整个训练集的均值和方差。

(3)建议将bn层放在卷积层(Conv)和激活层(例如Relu)之间,且卷积层不要使用偏置bias,因为没有用,参考下图推理,即使使用了偏置bias求出的结果也是一样的\bg_white \large y_{i}^{b}=y_{i}

最后给出李宏毅老师关于batch normalization的视频讲解:

李宏毅深度学习(2017)_哔哩哔哩_bilibili


http://chatgpt.dhexx.cn/article/Z4wGDZk3.shtml

相关文章

Java参数校验validation和validator区别

Java参数校验validation和validator区别 1. 参数校验概述2. validation与validator区别3. validation注解说明4. validator注解说明5. 日期格式化说明6. 实现验证6.1 引入依赖6.2 代码实现6.3 实现验证 1. 参数校验概述 常见的业务开发中无可避免的会进行请求参数校验&#xf…

hibernate-validator

validator 简介各种注解好处 validator.validate方法业务逻辑代码中检查传入的参数时为传入的参数类型中各个属性添加注解NotNull、NotBlank、NotEmpty间的区别 简介 validator,翻译过来,就是“验证器”的意思。它是一种注解式参数校验,包名…

validator自定义校验注解及使用

validator自定义校验注解及使用 官方文档&#xff1a;https://docs.jboss.org/hibernate/validator/8.0/reference/en-US/html_single/#validator-customconstraints 用到依赖: <!--validator的依赖如果项目使用的springBoot的依赖可以不用再引入 hibernate-validator 因为…

spring之Validator

初步认识 spring数据验证核心类&#xff1a;①&#xff1a;Validator ②&#xff1a;Errors,两者之间的纽带是Validator中定义的validate方法。 public interface Validator {// 限定Validator的职责&#xff0c;不可能所有的校验全部交给一个Validator来做boolean supports(…

Hibernate-Validator的学习

Hibernate-Validator的学习 此教程基于黑马程序员Java品达通用权限项目&#xff0c;哔哩哔哩链接&#xff1a;https://www.bilibili.com/video/BV1tw411f79E?p49 1.hibernate-validator介绍 早期的网站&#xff0c;用户输入一个邮箱地址&#xff0c;需要将邮箱地址发送到服…

Hibernate Validator源码解析

一、引言 问题&#xff1a;在代码编写的过程中&#xff0c;数据值的校验在JavaEE三层架构&#xff08;展示层、业务层、数据访问层&#xff09;均有涉及&#xff0c;各层的校验需求又是不尽相同的&#xff0c;因此往往会伴随着代码冗余&#xff0c;重复的校验逻辑出现在三层代…

Hibernate Validator简介

亲爱的小伙伴们我来填坑啦&#xff0c;java中优雅的参数校验方法中的校验的实现原理。 1.前言 验证数据是发生在所有应用程序层&#xff08;从表示层到持久层&#xff09;的常见任务。通常在每一层中实现相同的验证逻辑&#xff0c;这既耗时又容易出错。为了避免重复这些验证&…

bootstrapValidator验证最大值最小值范围限制

录入该值的最大值与最小值 bootstrapValidator进行效验&#xff0c;使最小值不可大于最大值&#xff0c;最大值不可小于最小值 刚开始的验证还是没事的&#xff0c;符合正常的验证规则 再把不符合规则的最大值改变&#xff0c;现在最小值已经比最大值小了&#xff0c;但是最大…

class-validator中文教程

官方文档&#xff1a; https://www.npmjs.com/package/class-validator class-validator可以说是一个简化验证的依赖库 &#xff08;采用注释的方式进行校验&#xff09; 但是缺少中文文档和过程&#xff0c;以自己的理解和对官网文档的阅读进行整理输出。 它的好兄弟class-t…

Hibernate Validator 总结大全

背景 代码开发过程中&#xff0c;参数的有效性校验是一项很繁琐的工作&#xff0c; 如果参数简单&#xff0c;就那么几个参数&#xff0c;直接通过ifelse可以搞定&#xff0c;如果参数太多&#xff0c;比如一个大对象有100多个字段作为入参&#xff0c;你如何校验呢&#xff1…

java使用validator进行校验

不管是html页面表单提交的对象数据还是和第三方公司进行接口对接&#xff0c;都需要对接收到的数据进行校验&#xff08;非空、长度、格式等等&#xff09;。如果使用if一个个进行校验&#xff08;字段非常多&#xff09;&#xff0c;这是让人崩溃的过程。幸好jdk或hibernate都…

java validator_Spring中校验器(Validator)的深入讲解

前言 Spring框架的 validator 组件,是个辅助组件,在进行数据的完整性和有效性非常有用,通过定义一个某个验证器,即可在其它需要的地方,使用即可,非常通用。 应用在执行业务逻辑之前,必须通过校验保证接受到的输入数据是合法正确的,但很多时候同样的校验出现了多次,在不…

springboot使用hibernate validator校验

目录 一、参数校验二、hibernate validator校验demo三、hibernate的校验模式 1、普通模式&#xff08;默认是这个模式&#xff09;2、快速失败返回模式四、hibernate的两种校验 1、请求参数校验2、GET参数校验(RequestParam参数校验)3、model校验4、对象级联校验5、分组校验五…

Validator 使用总结

介绍 首先说下大家常用的hibernate-validator&#xff0c;它是对JSR-303/JSR-349标准的实现&#xff0c;然后spring为了给开发者提供便捷集成了hibernate-validator&#xff0c;默认在springmvc模块。 依赖 本文所介绍皆在springboot应用的基础上&#xff0c;首先加上web模块…

浅谈 Android Tombstone(墓碑日志)分析步骤

最近项目产品刚刚出货&#xff0c;客户退机、死机事件频发。日常解决bug中&#xff0c;少不了和墓碑日志打交道&#xff0c;截止今天之前&#xff0c;见到墓碑日志都是一脸懵逼&#xff0c;不知道怎么分析。最近又有了两个日志&#xff0c;硬着头皮看吧。之所以称之为浅谈&…

Android tombstone文件是如何生成的

本节内容我们聚焦到androidQ上&#xff0c;分析android中一个用于debug的功能&#xff0c;那就是tombstone&#xff0c;俗称“墓碑”。现实生活中墓碑一般是给死人准备的&#xff0c;而在android系统中“墓碑”则是给进程准备的。 为何Android要设计出这样一个东西呢&#xff…

【Android NDK 开发】NDK C/C++ 代码崩溃调试 - Tombstone 报错信息日志文件分析 ( 获取 tombstone_0X 崩溃日志信息 )

文章目录 一、崩溃信息描述二、手机命令行操作三、电脑命令行操作四、Tombstone 内容 Tombstone 报错信息日志文件被保存在了 /data/tombstones/ 目录下 , 先 ROOT 再说 , 没有 ROOT 权限无法访问该目录中的信息 ; 使用 Pixel 2 手机进行调试 , 其它 ROOT 后的手机也可以使用 …

Android tombstone 分析案例

Android tombstone 分析案例 tombstone文件内容1. 体系结构2. 发生Crash线程3. 原因4. 寄存器状态4.1 处理器工作模式下的寄存器4.2 未分组寄存器r0 – r74.3 分组寄存器r8 – r144.4 程序计数器pc(r15)4.5 程序状态寄存器4.6 ARM参数规则 5. 回溯栈6. 程序栈7. 寄存器地址附近…

RocksDB Tombstone 详解

目录 为什么会有墓碑&#xff1f; 使用场景 原理 描述 分段 查询 优化点 总结 为什么会有墓碑&#xff1f; 我们知道 TP 数据库一般选择 KV 引擎作为存储引擎&#xff0c;数据库的元数据和数据通过一定的编码规则变成 KV 对存储在存储引擎中&#xff0c;比如 CockroachD…

Tombstone 文件分析

Tombstone 文件分析 /* * 下面信息是dropbox负责添加的 **/ isPrevious: true Build: Rock/odin/odin:7.1.1/NMF26F/1500868195:user/dev-keys Hardware: msm8953 Revision: 0 Bootloader: unknown Radio: unknown Kernel: Linux version 3.18.31-perf-g34cb3d1 (smartcmhardc…