【深度学习2】基于Pytorch的WGAN理论和代码解析

article/2025/10/6 14:55:40

目录

1 原始GAN存在问题

2 WGAN原理

3 代码理解

GitHub源码


参考文章:令人拍案叫绝的Wasserstein GAN - 知乎 (zhihu.com)

1 原始GAN存在问题

实际训练中,GAN存在着训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性等问题。这与GAN的机制有关。GAN最终达到对抗的纳什均衡只是一个理想状态,而现实情况中得到的结果都是中间状态(伪平衡)。

大部分的情况是,随着训练的次数越多判别器D的效果越好,会导致一直可以将生成器G的输出与真实样本区分开。这是因为生成器G是从低维空间向高维空间(复杂的样本空间)映射,其生成的样本分布空间Pg难以充满整个真实样本的分布空间Pr。即两个分布完全没有重叠的部分,或者它们重叠的部分可以忽略,这样就使得判别器D总会将它们分开。

在原始GAN的训练中,判别器训练得太好,会使生成器梯度消失,生成器loss降不下去;判别器训练得不好,会使生成器梯度不准,四处乱跑。只有判别器训练到中间状态最佳,但是这个尺度很难把握,没有一个收敛判断的依据。甚至在同一轮训练的前后不同阶段,这个状态出现的时段都不一样,是个完全不可控的情况。

引入Kullback–Leibler divergence(简称KL散度)和Jensen-Shannon divergence(简称JS散度)这两个重要的相似度衡量指标,后面的主角之一Wasserstein距离,就是要来吊打它们两个的。所以接下来介绍这两个重要的配角——KL散度和JS散度:

根据原始GAN定义的判别器loss,我们可以得到最优判别器的形式;而在最优判别器的下,我们可以把原始GAN定义的生成器loss等价变换为最小化真实分布与生成分布之间的JS散度。我们越训练判别器,它就越接近最优,最小化生成器的loss也就会越近似于最小化和之间的JS散度。

第一种loss形式无论跟是远在天边,还是近在眼前,只要它们俩没有一点重叠或者重叠部分可忽略,JS散度就固定是常数,而这对于梯度下降方法意味着——梯度为0!此时对于最优判别器来说,生成器肯定是得不到一丁点梯度信息的;即使对于接近最优的判别器来说,生成器也有很大机会面临梯度消失的问题。

而第二种loss形式存在两个严重的问题。第一是它同时要最小化生成分布与真实分布的KL散度,却又要最大化两者的JS散度,一个要拉近,一个却要推远!这在直观上非常荒谬,在数值上则会导致梯度不稳定,这是后面那个JS散度项的毛病。

第一部分小结:在原始GAN的(近似)最优判别器下,第一种生成器loss面临梯度消失问题,第二种生成器loss面临优化目标荒谬、梯度不稳定、对多样性与准确性惩罚不平衡导致mode collapse这几个问题。

原始GAN问题的根源可以归结为两点,一是等价优化的距离衡量(KL散度、JS散度)不合理,二是生成器随机初始化后的生成分布很难与真实分布有不可忽略的重叠。

2 WGAN原理

WGan(Wasserstein Gan),Wasserstein是指Wasserstein距离,又叫Earth-Mover(EM)推土机距离,定义如下:

Wasserstein距离相比KL散度、JS散度的优越性在于,即便两个分布没有重叠,Wasserstein距离仍然能够反映它们的远近。WGAN本作通过简单的例子展示了这一点。考虑如下二维空间中的两个分布和,在线段AB上均匀分布,在线段CD上均匀分布,通过控制参数可以控制着两个分布的距离远近。

WGan的思想是将生成的模拟样本分布Pg与原始样本分布Pr组合起来,当成所有可能的联合分布的集合。然后可以从中采样得到真实样本与模拟样本,并能够计算二者的距离,还可以算出距离的期望值。

KL散度和JS散度是突变的,要么最大要么最小,Wasserstein距离却是平滑的,如果我们要用梯度下降法优化这个参数,前两者根本提供不了梯度,Wasserstein距离却可以。类似地,在高维空间中如果两个分布不重叠或者重叠部分可忽略,则KL和JS既反映不了远近,也提供不了梯度,但是Wasserstein却可以提供有意义的梯度

使用W-GAN网络进行图像生成时,网络将整个图像视为一种属性,其目的就是学习图像整个属性的数据分布,因而将生成图像分布Pg拟合为真实图像分布Pr是合理可行的。若期望的生成分布Pg不是当前的真实图像分布Pr,那么网络具体的收敛方向将会不可控,会出现训练失败的情况。

这样就可以通过训练,让网络在所有可能的联合分布中对这个期望值取下界的方向优化,也就是将两个分布的集合拉到一起。这样原来的判别式就不再是判别真伪的功能了,而是计算两个分布集合距离的功能。所以将其称为评论器更加合适,同样,最后一层的sigmoid也需要去掉了。

核心思想:原始GAN的D的loss都是真实样本和1作交叉熵,模拟样本和0作交叉熵;G的loss是模拟样本和1作交叉熵。WGan的loss就是将真实样本和模拟样本形成联合分布,采样后给二者作差,D的目的是二者越大越好,G的目的是二者越小越好。

尽可能取到最大,此时L就会近似真实分布与生成分布之间的Wasserstein距离(忽略常数倍数)。注意原始GAN的判别器做的是真假二分类任务,所以最后一层是sigmoid,但是现在WGAN中的判别器做的是近似拟合Wasserstein距离,属于回归任务,所以要把最后一层的sigmoid拿掉。

接下来生成器要近似地最小化Wasserstein距离,可以最小化L,由于Wasserstein距离的优良性质,我们不需要担心生成器梯度消失的问题。再考虑到的第一项与生成器无关,就得到了WGAN的两个loss。

公式15是公式17的反,可以指示训练进程,其数值越小,表示真实分布与生成分布的Wasserstein距离越小,GAN训练得越好。

3 代码理解

WGAN与原始GAN第一种形式相比,只改了四点:

  • 判别器最后一层去掉sigmoid
  • 生成器和判别器的loss不取log
  • 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
  • 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行

GAN代码的修改部分,具体区别可参照主页GAN代码进行对照:

  • 因为变成回归任务,所以去掉Sigmoid函数。

  •  删除Loss function。

  •  不要用基于动量的优化算法(包括momentum和Adam),优化器改为RMSprop

  • 保证fθ(x)fθ(x)满足K-Lipschitz条件,《Wasserstein GAN》做了一个简单地处理,因为判别器是由神经网络构成的,因此对每层的线性算子中参数做了一个夹逼,限制其取值范围,就可以实现。如上面代码的这个部分。clamp函数用于取上下限
# Clip weights of discriminator
for p in discriminator.parameters():p.data.clamp_(-opt.clip_value, opt.clip_value)
  • 修改loss的计算方法:
loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
loss_G = -torch.mean(discriminator(gen_imgs))

GitHub源码

import argparse
import os
import numpy as np
import math
import sysimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch.nn.functional as F
import torchos.makedirs("images", exist_ok=True)parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.00005, help="learning rate")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter")
parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)img_shape = (opt.channels, opt.img_size, opt.img_size)cuda = True if torch.cuda.is_available() else Falseclass Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(opt.latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())def forward(self, z):img = self.model(z)img = img.view(img.shape[0], *img_shape)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),)def forward(self, img):img_flat = img.view(img.shape[0], -1)validity = self.model(img_flat)return validity# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True,download=True,transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)# Optimizers
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=opt.lr)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=opt.lr)Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor# ----------
#  Training
# ----------batches_done = 0
for epoch in range(opt.n_epochs):for i, (imgs, _) in enumerate(dataloader):# Configure inputreal_imgs = Variable(imgs.type(Tensor))# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Sample noise as generator inputz = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))# Generate a batch of imagesfake_imgs = generator(z).detach()# Adversarial lossloss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))loss_D.backward()optimizer_D.step()# Clip weights of discriminatorfor p in discriminator.parameters():p.data.clamp_(-opt.clip_value, opt.clip_value)# Train the generator every n_critic iterationsif i % opt.n_critic == 0:# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()# Generate a batch of imagesgen_imgs = generator(z)# Adversarial lossloss_G = -torch.mean(discriminator(gen_imgs))loss_G.backward()optimizer_G.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, batches_done % len(dataloader), len(dataloader), loss_D.item(), loss_G.item()))if batches_done % opt.sample_interval == 0:save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)batches_done += 1


http://chatgpt.dhexx.cn/article/JgJfFcRt.shtml

相关文章

YOLOv3:Darknet代码解析(四)结构更改与训练

背景:我们需要降低YOLOv2-tiny的参数量和存储量,以便硬件实现。 目的:更改YOLO结构,去掉后面的两层卷积层,降低参数量和运算量。 相关文章: YOLOv3:Darknet代码解析(一&#xff0…

StrongSORT(deepsort强化版)浅实战+代码解析

1.实战部分 1.1 具体操作 其实和之前的deepsort没差 到github上下载Yolov5_StrongSORT_OSNet下载对应的yolov5去替代原文件中yolov5下载yolov5权重(可以自动下载)和ReID权重(可能要科学上网)放到weight里面 ReID权重有点神秘&a…

对比学习 ——simsiam 代码解析。:

目录 1 : 事先准备 。 2 : 代码阅读。 2.1: 数据读取 2.2: 模型载入 3 训练过程: 4 测试过程: 5 :线性验证 6 : 用自己数据集进行对比学习。 第一: 改数据集 &#x…

AutoWare 代码解析

Auto Ware 代码解析系列(1) Auto Ware 是日本名古屋大学的开源无人车项目,下图为ros仿真环境下的各个节点的关系图: 代码库地址为:https://github.com/CPFL/Autoware 上面有较为详细的仿真环境配置信息,建…

FAST-LIO2代码解析(六)

0. 简介 上一节我们将while内部的IKD-Tree部分全部讲完,下面将是最后一部分,关于后端优化更新的部分。 1. 迭代更新 最后一块主要做的就是,拿当前帧与IKD-Tree建立的地图算出的残差,然后去计算更新自己的位置,并将更…

DeblurGAN-V2源代码解析

DeblurGAN-V2源代码解析(pytorch) DeblurGAN-V2是DeblurGAN的改进版,主要解决的是去图像运动模糊的问题,相比于DeblurGAN而言有速度更快,效果更好的优点。 论文:https://arxiv.org/pdf/1908.03826.pdf 代码…

mmsegmentation模型生成代码解析

前言 疫情在家办公,新Team这边习惯用MMLab开发网络,正好趁这段时间理解一下商汤大佬们的框架。我之前其实网络开发的比较少,主要是学习用的,而且开发网络基本是靠手写或者copy,用这种架构开发我是十分赞成的&#xff…

PX4代码解析(1)

前言 做pixhawk飞控有一段时间了,但在学习过程中遇到许多困难,目前网上找不到比较完整的PX4学习笔记,我打算结合自己理解,写写自己对PX4源码的理解,不一定对,只是希望与各位大佬交流交流,同时梳…

PX4代码解析(2)

前言 在大致了解PX4代码架构后,我们需要了解PX4的通信机制。在PX4代码架构中,每通信总线主要分为两个部分,一是内部通信总线uORB,即PX4内部进程通信采用的协议,例如PX4内部姿态控制需要获取飞行器姿态,而飞行器姿态是…

Teams Bot App 代码解析

上一篇文章我们讲了如何使用 teams toolkit 来快速弄一个 teams bot,可以看到 toolkit 给我们提供了极大的方便性,让开发人员可以更好的把重心放在 coding 上,而不是各种配置上。 那我们这篇文章主要接着上篇,来解析一下 teams b…

代码分析(一)

2021SCSDUSC 分析前言 对于APIJSON的代码分析首先就是,看一下该项目的作用以及如何进行,看一下原来不部署这个项目的正常流程: 再来看一下部署上APIJSON后项目的流程走向: 接下来开始按照这个流程对相应的代码进行分析。 Abst…

Linux命令之lsusb

一、lsusb命令用于显示本机的USB设备列表,以及USB设备的详细信息。 二、lsusb命令显示的USB设备信息来自“/proc/bus/usb”目录下的对应文件。 三、Linux从/var/lib/usbutils/usb.ids识别USB设备的详细信息。 语法格式 lsusb [参数] 常用参数: -v显…

Linux命令-磁盘管理-lsusb

1 需求 2 语法 C:\>adb shell lsusb --help Toybox 0.8.4-android multicall binary: https://landley.net/toybox (see toybox --help)usage: lsusbList USB hosts/devices. 3 示例 adb shell lsusb 4 参考资料

嵌入式debian没有lsusb命令解决

问题 -bash: lsusb: command not found 解决

linux之lsusb命令和cd -命令使用总结

1、lsusb命令介绍 使用 lsusb 来列出 USB 设备和它的属性,lsusb 会显示驱动和内部连接到你系统的设备。直接在控制台输入 lsusb 即可 2、lsusb简单使用 在控制台输入 lsusb 效果如下 系统中同时使用了 USB 2.0 root hub 驱动和 USB 3.0 root hub 驱动。 bus 002 指明设备…

LSB

知识点 LSB即为最低有效位(Least Significant Bit,lsb),这里百度了一下:图片中的图像像素一般是由RGB三原色(红绿蓝)组成,每一种颜色占用8位,取值范围为0x00~0xFF&#…

lsusb命令-在系统中显示有关USB设备信息

在 中我们使用lsusb 列出USB设备及其属性,lsusb用于显示系统中的USB总线及其连接的设备信息。下面介绍如何安装并使用。 系统环境 7 安装usbutils 默认Centos7系统中没有lsusb ,我们需要安装usbutils安装包,才能使用lsusb: […

LSF-bsub命令

文章目录 一、LSF(load sharing facility)二、bsub命令三、 常用命令3.1 bhosts3.2 bqueues3.3 bjobs3.4 bkill3.5 bhist3.6 busers 一、LSF(load sharing facility) 分布资源管理的工具,用来调度、监视、分析联网计算机的负载。 目的:通过集中监控和调…

Linux下的lsusb命令详解

lsusb命令详解 参考: 1、https://zhuanlan.zhihu.com/p/142403866 2、https://blog.csdn.net/phmatthaus/article/details/124198879 简介 ​USB,是英文Universal Serial Bus(通用串行总线)的缩写,是一个外部总线标…

详解 lsusb命令

USB设备检测的一般过程 USB设备检测也是通过/proc目录下的USB文件系统进行的。为了使一个USB设备能够正常工作,必须要现在系统中插入USB桥接器模块。在检测开始时,一般要先检测是否存在/proc/bus/usb目录,若不存在则尝试插入USB桥接模块。 现…