使用Batch Normalization解决VAE训练中的后验坍塌(posterior collapse)问题

article/2025/10/19 15:10:20

前言

在训练VAE模型时,当我们使用过于过于强大的decoder时,尤其是自回归式的decoder比如LSTM时,存在一个非常大的问题就是,decoder倾向于不从latent variable z中学习,而是独立地重构数据,这个时候,同时伴随着的就是KL(p(z|x)‖q(z))倾向于0。先不考虑说,从损失函数的角度来说,这种情况不是模型的全局最优解这个问题(可以将其理解为一个局部最优解)。单单从VAE模型的意义上来说,这种情况也是我们不愿意看到的。VAE模型的最重要的一点就是其通过无监督的方法构建数据的编码向量(即隐变量z)的能力。而如果出现posterior collapse的情况,就意味着后验概率退化为与先验概率一致,即N(0,1)。此时encoder的输出近似于一个常数向量,不再能充分利用输入x的信息。decoder则变为了一个普通的language model,尽管它依然很强。

因此,不管从哪个方面来看,解决它都是必须要面临的课题,事实上,从2016年开始就有很多文章提出了不同的解决方案,这里重点介绍一下使用Batch Normalization来解决这个问题的思路。这篇文章全名A Batch Normalized Inference Network Keeps the KL Vanishing Away发表于2020年,还算是一篇比较新的文章。下面我们开始。

方法介绍

Expectation of the KL’s Distribution

首先基于隐变量空间为高维高斯分布的假设,对于一个mini-batch的数据来说,我们可以计算KL divergence的表达式如下:
在这里插入图片描述
其中b代表的是mini-batch的样本个数,n代表的隐变量z的维度。同时作者还假设对于每个不同的维度,其都遵循某个特定的分布,各个维度可以不同。

假设我们认为上述的样本均值可以近似等于总体期望,那么我们可以将上述的样本均值用期望来代替,又因为我们有如下基本等式
在这里插入图片描述
最终我们可以得到KL divergence的期望表达式如下。
在这里插入图片描述
上述不等式是因为e^x-x>=1恒成立。那么这么一来,我们就有关于KL divergence的一个lower bound。这个lower bound只与隐变量的维度n和μi的分布有关。

Normalizing Parameters of the Posterior

接下来,我们要考虑的问题就是如何来构建每个μi的分布,使其保证这个lower bound的值恒为正,也就间接保证了KL divergence不会变为0。这里用到的方法就是Batch Normalization。

我们熟知的Batch Normalization往往用在神经网络模型中,通过控制每个隐藏层的数据的分布使得训练更加平稳。

但是在这里我们使用它来转换μi的分布,将其控制在一个合理的范围内,从而保证lower bound的值为正。具体如下
在这里插入图片描述
其中μBi 和 σBi 分别表示通过mini-batch计算的 μi的均值和标准差。γ 和 β分别是scale和shift参数。通过合理地控制这两个参数,我们可以将lower bound近似地转换为如下式子。
在这里插入图片描述
下面是完整的算法流程。
在这里插入图片描述
在原文中还有涉及到对参数设置的进一步拓展,大家可以参考苏剑林老师的这篇博客

Torch 实现

在苏剑林老师的博客中,他用keras实现了文章中的关键内容,在这里,我用torch实现了一下,供大家参考。

import torch
import torch.nn as nn# reference paper:https://arxiv.org/abs/2004.12585
class BN_Layer(nn.Module):def __init__(self,dim_z,tau,mu=True):super(BN_Layer,self).__init__()self.dim_z=dim_zself.tau=torch.tensor(tau) # tau : float in range (0,1)self.theta=torch.tensor(0.5,requires_grad=True)self.gamma1=torch.sqrt(self.tau+(1-self.tau)*torch.sigmoid(self.theta)) # for muself.gamma2=torch.sqrt((1-self.tau)*torch.sigmoid((-1)*self.theta)) # for varself.bn=nn.BatchNorm1d(dim_z)self.bn.bias.requires_grad=Falseself.bn.weight.requires_grad=Trueif mu:with torch.no_grad():self.bn.weight.fill_(self.gamma1)else:with torch.no_grad():self.bn.weight.fill_(self.gamma2)def forward(self,x): # x:(batch_size,dim_z)x=self.bn(x)return x

参考

A Batch Normalized Inference Network Keeps the KL Vanishing Away
变分自编码器(五):VAE + BN = 更好的VAE


http://chatgpt.dhexx.cn/article/W5sStXTR.shtml

相关文章

【论文模型讲解】Learning to Select Knowledge for Response Generation in Dialog Systems(PostKS模型)

文章目录 前言背景Posterior Knowledge Selection 模型(PostKS)1. 对话编码器&知识编码器(Utterance Encoder&Knowledge Encoder)2. 知识管理器(Knowledge Manager)3. 解码器4. 损失函数 Q & A1. 先验知识模块相关问题 前言 论文网址&#…

引导方法深度补全系列—晚期融合模型—1—《Dense depth posterior (ddp) from single image and sparse range》文章细读

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 目录 创新点 实施细节 对比sparse_RGBD tips 方法详解 损失函数 优缺点 总结 创新点 1.提出了基于贝叶斯理论的两步法网络做深度补全 文章概述 提出了两步法,实际…

扩散模型(Diffusion model)代码详细解读

扩散模型代码详细解读 代码地址:denoising-diffusion-pytorch/denoising_diffusion_pytorch.py at main lucidrains/denoising-diffusion-pytorch (github.com) 前向过程和后向过程的代码都在GaussianDiffusion​这个类中。​ 有问题可以一起讨论! …

DIFFUSION POSTERIOR SAMPLING FOR GENERALNOISY INVERSE PROBLEMS (Paper reading)

DIFFUSION POSTERIOR SAMPLING FOR GENERALNOISY INVERSE PROBLEMS Hyungjin Chung, Kim Jae Chul Graduate School of AI, ICLR 2023 spotlight, Cited:10, Code, Paper. 目录子 DIFFUSION POSTERIOR SAMPLING FOR GENERALNOISY INVERSE PROBLEMS1. 前言2. 整体思想3. 方法实…

Exploring and Distilling Posterior and Prior Knowledge for Radiology Report Generation

Exploring and Distilling Posterior and Prior Knowledge for Radiology Report Generation(探索和提炼后验和先验知识的放射学报告生成) 先验与后验目前的放射学报告生成的局限性Paper的贡献模型详解模型输入模型主要部分 先验与后验 在阅读这篇Paper…

国际会议poster: 海报制作流程 格式介绍

1 流程 word制作, 转pdf, 打印 2 模板 UCLhttps://wenku.baidu.com/view/034bcb7e4a7302768f99392a.html 3 CYBER2019格式要求 海报尺寸:A1尺寸23.4英寸(59.4厘米)宽,33.1英寸(84.1厘米)高。 请注意,以A4尺寸列印已递交的整张纸作为海报是不可…

概率基础 · 联合概率 边缘概率 prior posterior likelihood

概率基础 联合概率 边缘概率 prior posterior likelihood 联合概率 (Joint Probability)边缘概率(margin probability)贝叶斯定理(Bayes Theorem)prior,posterior,likelihood:概率与似然的区别…

Stochastic Image Denoising By Sampling from the Posterior Distribution (Paper reading)

Stochastic Image Denoising By Sampling from the Posterior Distribution (Paper reading) Bahjat Kawar, Haifa(Israel), ICCV Workshop2021, Cited:22, Code:无, Paper. 目录子 Stochastic Image Denoising By Sampling from the Posterior Distribution (Paper reading)1…

GAN论文精读 P2GAN: Posterior Promoted GAN 用鉴别器产生的后验分布来提升生成器

《Posterior Promoted GAN with Distribution Discriminator for Unsupervised Image Synthesis》是大连理工学者发表的文章,被2021年CVPR收录。 文章地址:https://ieeexplore.ieee.org/document/9578672 本篇文章是阅读这篇论文的精读笔记。 一、原文…

先验、后验与似然

在学习SLAM 14讲第六章时,看到三个概念,有些不太了解,查阅资料后有了一些自己的理解。 三个概念存在于贝叶斯公式中 表示先验概率Prior,表示后验概率posterior,表示似然likelihood 上式可以写为 下面分别对三个概念进…

Prior 、Posterior 和 Likelihood 的理解与几种表达方式

Prior 、Posterior 和 Likelihood 的理解与几种表达方式 (下载图片可以看大图。)

Windows作为NTP同步时间的服务器时的设置

1.先关闭Windows系统自带的防火墙; 2. 在桌面上右击“计算机”, 选择“管理”, 然后选择“服务”。 具体如图所示 2. 选中“Windows Time”,设置为开启,这样就可以将“Windows Time”这一个服务打开。 3. “开始”--》“运…

NTP时钟服务器推荐-国内时间服务器顶尖设备

电子钟时间服务器在物联网应用中起到了关键的作用,它能够为各种智能设备提供准确的时间参考,确保设备之间的协同工作和数据的准确传输。无论是智能家居、智能工厂还是智慧城市,电子钟时间服务器都是不可或缺的一部分。 一、产品卖点 时间服…

NTP同步时间失败。Linux作为客户端,Windows作为NTP时钟源服务端。

使用windows作NTP时钟源,NTP同步时间失败 【关 键 词】:NTP,时钟源,windows时钟源,同步时间失败 【故障类型】:操作维护->其他 【适用版本】:Linux 【问题描述】:windows做时钟…

如何在windows10 搭建 NTP 时间服务器

windows本身是可以作为NTP时间同步服务器使用的,本文介绍一下如何在win10上配置NTP时间同步服务器。 如何在windows10 搭建 NTP 时间服务器 工具/原料 系统版本:win10版本 [10.0.17134.706] 方法/步骤 使用组合键WIN R 启动运行窗口,在…

Linux服务器NTP客户端时钟同步配置方法

前提说明:配置客户端NTP时候,必须要有一台时钟服务器,可以是服务器搭建的,也可以是购买的时钟设备。我这里使用临时的时钟服务器IP地址10.10.4.100 步骤如下: 1 首先在客户端服务器中ping一下时钟的IP地址是否网络可通…

NTP时钟服务器(PTP服务器)无法同步的排查方法

NTP时钟服务器(PTP服务器)无法同步的排查方法 NTP时钟服务器(PTP服务器)无法同步的排查方法 NTP系统是典型的C-S模型,一般将整个系统分为服务器,网络和客户端三个区域,因NTP时间服务器一般在出厂时已经测试,并设置为可使用&#…

NTP时间服务器同步时钟系统安装汇总分享

在现代科技发展的背景下,各种设备的时间同步变得越来越重要。同步时钟管理系统的应用可以让多个设备在时间上保持一致,提高工作效率和安全性,为各个行业的发展提供了重要的支持。 一、同步时钟系统介绍 同步时钟管理系统的应用范围非常广泛&…

关于NTP时间服务器

NTP(Network Time Protocol) 网络时间协议,工作在UDP的123端口上。是用来使计算机时间同步化的一种协议,它可以使计算机对其服务器或时钟源(如石英钟,GPS等等)做同步化,它可以提供高精准度的时间校正(局域网…

R语言产生对角阵、次对角阵等矩阵及矩阵运算

R语言产生各种类型的矩阵及矩阵运算 R语言产生一般的矩阵R语言产生单位阵R语言产生次对角阵R语言矩阵的常见运算 R语言产生一般的矩阵 # 依行排列,产生3行5列的矩阵 A matrix(c(1:15),3,5,byrowT)R语言产生单位阵 #产生对角线元素为1的6x6的单位阵 A diag(6) #产…