对抗生成网络GAN系列——GAN原理及手写数字生成小案例

article/2025/9/24 17:47:36

 

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题

🍊往期回顾:目标检测系列——开山之作RCNN原理详解    目标检测系列——Fast R-CNN原理详解   目标检测系列——Faster R-CNN原理详解

🍊近期目标:拥有10000粉丝
🍊支持小苏:点赞👍🏼、收藏⭐、留言📩

文章目录

    • 对抗生成网络GAN系列——GAN原理及手写数字生成小案例
    • 写在前面
    • GAN简介
    • 生成对抗网络✨✨✨
      • GAN损失函数🧅🧅🧅
      • GAN流程🧅🧅🧅
      • GAN实现效果🧅🧅🧅
    • 使用GAN生成手写数字小demo✨✨✨
    • 论文下载地址
    • 参考链接

 

本节已录制视频教学,链接如下:生成对抗网络GAN原理详解篇🌱🌱🌱

对抗生成网络GAN系列——GAN原理及手写数字生成小案例

写在前面

  其实关于GAN的讲解我早就做过一期,点击☞☞☞了解详情🌱🌱🌱由于最近会用到GAN的一些知识,自己又对GAN进行了一些整理,有了一些新的认识,便写了这篇文章。那么这篇文章和早期的文章有什么区别呢?首先,早期的文章只是对GAN做了一个大概的认识,而这篇文章会贴合论文较为详细的讲解GAN网络;其次,这次我准备写一个GAN系列,介绍一些经典的GAN网络,所以这篇文章和后面打算写的文章关联性更强。【注:我觉得大家可以先去读一下我之前的文章,对于损失函数部分我通过一个例子来讲解,还是很好理解的,文章也很短,能让大家快速对GAN有一个感性的认识】

  准备好了嘛,下面就正式发车了。🚖🚖🚖

 

GAN简介

  这里先来简单的介绍一下GAN,其完整的名称为Generative Adversarial Nets (生成对抗网络) 。其实这个起名还有个小故事,我简要的说一下,大家随便听听,就当放松了。当时作者Goodfellow 对于这篇文章其实是有好几个备选名字的,后来一个中国人说GAN(干)在中国有一种对抗的意思,作者一听,直接拍案选择了这个名称。🍋🍋🍋

  接下来让我们看看论文中对GAN的解释,如下图所示:

image-20220713213916675

  我简单的来翻译一下,其大致意思是说:在我们提出的对抗生成网络中,有一个生成模型,也有一个对抗模型,它们互相对抗,互相促进。文中也举了个小例子,生成模型可以被认为是一个假币伪造团队,试图生产假币并使用,而判别器类似于警察,试图发现假币。这就是一个互相博弈的过程,生成模型不断的产生伪造水平高的假币,而判别器不断提高警察识别假币水平,直至两者达到一个平衡。这个平衡是指什么呢?即判别器对于生成模型产生的假币辨别的成功率大致为50%,即很难辨别真假。

 
 

生成对抗网络✨✨✨

GAN损失函数🧅🧅🧅

  这部分我们主要结合生成对抗网络的损失函数来介绍网络的整个流程,首先呢,我们需要对一些字母做一些解释。如下:

Z Z Z随机噪声
P z ( Z ) P_z(Z) Pz(Z)随机噪声Z服从的概率分布
G ( Z ) G(Z) G(Z)生成器:输入为噪声Z,输出为假图像
P g P_g Pg生成器生成的假图像服从的概率分布
P d a t a P_{data} Pdata真实数据服从的概率分布
D ( X ) D(X) D(X)判别器:输入为图像,输出为该图像为真实图像的概率,概率在[0,1]之间

​  对上述字母有一定的了解后,下面就可以给出生成对抗网络的损失函数了,如下图所示:

image-20220714095708540

图片来自B站同济子豪兄
 

​  乍一看这个公式你应该是懵逼的,下面就跟着我的思路来分解分解上述公式。首先这个公式应该有两部分,一部分为给定G,找到使V最大化的D;另一部分为给定D,找到使V最小化的G。

  我们先来看第一部分,即给定G,找到使V最大化的D。如下图所示:【注:我们为什么想要找到使V最大化的D,是因为使V最大化的D会使判别器的效果最好】

image-20220714103147254

  首先看第①部分,因为判别器此时的输入为 X X X,是真实数据, E X ∼ P d a t a [ l o g D ( X ) ] E_{X \sim P_{data}}[logD(X)] EXPdata[logD(X)] 值越大表示判别器认为输入X为真实数据的概率越大,也即表示判别器能力越强,因此这项的输出越大对判别器来说越好。接着来看第②部分,注意此时判别器的输入为 G ( Z ) G(Z) G(Z),即输入为假图像,那么此时对于 D ( G ( Z ) ) D(G(Z)) D(G(Z))来说这个值越小,表示判别器判定假图像为真实数据的概率越小,同样表示判别器能力越强。需要注意的是第二项为 l o g ( 1 − D ( G ( Z ) ) ) log(1-D(G(Z))) log(1D(G(Z))) 的期望,当判别器越强时, D ( G ( Z ) ) D(G(Z)) D(G(Z)) 值越小,而 E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( Z ) ) ) ] E_{z \sim p_z(z)}[log(1-D(G(Z)))] Ezpz(z)[log(1D(G(Z)))] 越大。【注:部分①和部分②要想使给定G时,判别器的效果最好,都需要最大化V,即给定G,找到最大化V的D会使判别器的效果最好。】为方便大家理解,画出 l o g ( 1 − D ( G ( Z ) ) ) log(1-D(G(Z))) log(1D(G(Z))) 的函数图像如下:

image-20220714110115933

  接着我们来看第二部分,即给定D,找到使V最小化的G。如下图所示:【注:我们为什么想要找到使V最小化的G,是因为使V最小化的G会使生成器的效果最好】

image-20220714110850880

  同样的,先来看第①部分,由于这次我们是固定了D,而①只和D有关,因此这部分是常量,其对最小化V是没有任何影响的,可以舍去。那么我们就来看看第②部分,此时判别器的输入同样是 G ( Z ) G(Z) G(Z),为假图像。不同的是现在我们期待的是生成器的效果好,即尽可能的瞒过判别器,也即期望 D ( G ( Z ) ) D(G(Z)) D(G(Z)) 尽可能大。 D ( G ( Z ) ) D(G(Z)) D(G(Z))越大就表示判别器判定假图像为真实数据的概率越大,也就表示生成器生成的图像效果好,可以很成功的骗过判别器。同样的 D ( G ( Z ) ) D(G(Z)) D(G(Z)) 值越大, E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( Z ) ) ) ] E_{z \sim p_z(z)}[log(1-D(G(Z)))] Ezpz(z)[log(1D(G(Z)))] 就越小,因此给定D,找到最小化V的G会使生成器的效果最好。

 

GAN流程🧅🧅🧅

  论文中在给出损失函数后,又给了一个图例来解释GAN的过程,用原文的话来说就是一个不怎么正式,却更具教学意义的解释。(See Figure 1 for a less formal, more pedagogical explanation of the approach )

图片来自B站同济子豪兄
 

​  其实上图中的文字标注已经将这个过程解释的相当详细了,我再来简单的复述一遍。首先图中黑点表示真实图像的分布,绿点表示生成图像的概率分布,蓝点表示判别器预测 X X X为真实数据的概率。在(a)时,黑点和绿点的分布相差较大,判别器能大致辨别真实图像和生成图像,但分辨效果不好。【在黑点集中区域蓝点的值普遍较高,表示预测黑点为真实图像的概率较大;同理,在绿点集中区域蓝点的值普遍较低,表示预测绿点为真实图像的概率较小。但蓝点存在一定的波动,效果不是很好。】 从(a)到(b)经过了判别器的训练,这会导致什么结果呢,从图(b)中可以发现,此时蓝点表现的更加稳定,在黑点集中处预测概率大,在绿点集中处预测概率小,也就是说此时的判别器已经能很好的分辨什么是真实图像,什么是生成的假图像了。接下来从(b)到(c)经过了生成器的训练,这会导致什么结果呢,从图(c)中可以发现,此时绿点逐渐像黑点靠近,即生成的图像更加真实了,而此时蓝点没有变化,这就会导致现在判别器对真实图像和生成图像的辨别难度变大了。这样不断的训练判别器和生成器,最后变成图(d),即真实图像分部和生成器生成图像分布完全一致,判别器预测概率恒为0.5,也即此时判别器完全无法区分真实图像很生成图像了。🌾🌾🌾

​  接下来论文中给出了训练GAN网络的伪代码,如下图所示:

image-20220714160200976

图片来自B站同济子豪兄
 

  如果我前文的描述你都听懂了的话,其实这个过程就没什么好说的了,就是对判别器和生成器不断的迭代更新。需要注意的有两点,第一是在训练过程中,我们是训练K次判别器,训练一次生成器;第二是在训练生成器过程中,我们的损失函数没有了 1 m ∑ i = 1 m log ⁡ D ( x ( i ) ) \frac{1}{m}\sum\limits_{i = 1}^m {\log D({x^{(i)}})} m1i=1mlogD(x(i)) 这一项,这个我在GAN损失函数这节有提到,因为训练生成器G时固定了判别器D,该项是定值,可省略。【注:这里的 1 m ∑ i = 1 m log ⁡ D ( x ( i ) ) \frac{1}{m}\sum\limits_{i = 1}^m {\log D({x^{(i)}})} m1i=1mlogD(x(i)) E X ∼ P d a t a [ l o g D ( X ) ] E_{X \sim P_{data}}[logD(X)] EXPdata[logD(X)] 完全一样,只是一个是用均值表示,一个用期望表示。】

 

GAN实现效果🧅🧅🧅

  论文中给出了GAN的一些实现效果的图片,如下图所示:

image-20220714162630505

​  上面四个图中,注意黄框框住的并不是GAN生成的图片,它们表示与GAN生成图片最相似的原始真实图片。而GAN生成的图片为黄框左侧第一张图片,可以看出,GAN生成的效果还是挺好的。

 

使用GAN生成手写数字小demo✨✨✨

  上文算是把原理讲述清楚了,若你还不明白,慢慢的阅读每句话,加入自己的思考,或许会有不一样的收获。那么这节我讲来讲讲通过GAN网络生成手写数字的小demo,通过这部分你会了解搭建GAN网络的基本流程。下面就让我们一起来学学吧!!!🌻🌻🌻【注:其实大致的流程和一般分类网络的搭建是类似的,相关分类网络的搭建流程可参考我的这篇博文】

  首先训练一个模型肯定少不了数据集,我们通过一下代码获取torch自带的MNIST数据集,代码如下:

#MNIST数据集获取
dataset = torchvision.datasets.MNIST("mnist_data", train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.Resize(28),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.5], [0.5]),]))

  之后我们通过DataLoader方法加载数据集,代码如下:

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

  这样数据就准备好了,下面就来构建我们的模型,分为生成器(Generator)和判别器(Discriminator)。【注:由于这期算是入门GAN,所以模型搭建只采用了全连接层】

生成器模型搭建:

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(latent_dim, 128),torch.nn.BatchNorm1d(128),torch.nn.GELU(),nn.Linear(128, 256),torch.nn.BatchNorm1d(256),torch.nn.GELU(),nn.Linear(256, 512),torch.nn.BatchNorm1d(512),torch.nn.GELU(),nn.Linear(512, 1024),torch.nn.BatchNorm1d(1024),torch.nn.GELU(),nn.Linear(1024, np.prod(image_size, dtype=np.int32)),nn.Sigmoid(),)def forward(self, z):# shape of z: [batchsize, latent_dim]output = self.model(z)image = output.reshape(z.shape[0], *image_size)return image

判别器模型搭建:

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(np.prod(image_size, dtype=np.int32), 512),torch.nn.GELU(),nn.Linear(512, 256),torch.nn.GELU(),nn.Linear(256, 128),torch.nn.GELU(),nn.Linear(128, 64),torch.nn.GELU(),nn.Linear(64, 32),torch.nn.GELU(),nn.Linear(32, 1),nn.Sigmoid(),)def forward(self, image):# shape of image: [batchsize, 1, 28, 28]prob = self.model(image.reshape(image.shape[0], -1))return prob

  模型搭建好后,我们会对损失函数、优化器等参数进行设置:

g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
loss_fn = nn.BCELoss()

  需要注意,这里采用的是BCELOSS损失函数,这个函数其实就对应着我们GAN理论部分的损失函数,这里想了解更多的话可以参考这个视频:BCE损失函数说明 🌲🌲🌲

​  这些设置好后,我们就来训练我们的GAN网络了,相关代码如下:这一部分我还是建议大家看一下这个视频,讲解的比较清楚。【可直接空降到41分钟】🥗🥗🥗

num_epoch = 200
for epoch in range(num_epoch):for i, mini_batch in enumerate(dataloader):gt_images, _ = mini_batchz = torch.randn(batch_size, latent_dim)pred_images = generator(z)g_optimizer.zero_grad()g_loss = loss_fn(discriminator(pred_images), labels_one)g_loss.backward()g_optimizer.step()d_optimizer.zero_grad()real_loss = loss_fn(discriminator(gt_images), labels_one)fake_loss = loss_fn(discriminator(pred_images.detach()), labels_zero)d_loss = (real_loss + fake_loss)# 观察real_loss与fake_loss,同时下降同时达到最小值,并且差不多大,说明D已经稳定了d_loss.backward()d_optimizer.step()

  最后,我来展示一下训练结果吧!!!我是在服务器上进行训练的,所以还是比较快的。先来看一下初始的图,都是一些随机的噪声,如下图所示:

image-20220715112632561

​ 再来看训练一段时间的结果,发现效果还是蛮不错滴🏵🏵🏵

image-20220715112438713

 
 

论文下载地址

论文下载地址🌱🌱🌱

 
 

参考链接

生成对抗网络GAN开山之作论文精读 🍁🍁🍁

原始GAN论文详解 🍁🍁🍁

GAN原理讲解与PyTorch手写逐行讲解 🍁🍁🍁

 
 

如若文章对你有所帮助,那就🛴🛴🛴

咻咻咻咻~~duang~~点个赞呗


http://chatgpt.dhexx.cn/article/AviPaJQA.shtml

相关文章

GAN——对抗生成网络

GAN的基本思想 作为现在最火的深度学习模型之一,GAN全称对抗生成网络,顾名思义是生成模型的一种,而他的训练则是处于一种对抗博弈状态中的。它使用两个神经网络,将一个神经网络与另一个神经网络进行对抗。 基本思想:&…

一文读懂对抗生成网络的3种模型

https://www.toutiao.com/i6635851641293636109/ 2018-12-17 14:53:28 基于对抗生成网络技术的在线工具edges2cats, 可以为简笔画涂色 前言 在GAN系列课程中分别讲解了对抗生成网络的三种模型,从Goodfellow最初提出的原始的对抗生成网络,到…

对抗生成网络(GAN)详解

目录 前言 目标函数 原理 训练 给定生成器,训练判别器 给定判别器,训练生成器 总结 前言 之前的生成模型侧重于将分布函数构造出来,然后使用最大似然函数去更新这个分布函数的参数,从而优化分布函数,但是这种方法…

对抗生成网络(GAN)简介及生成数字实战

一、简介 生成对抗网络(Generative Adversarial Netword,简称GAN),是一种生成式机器学习模型,该方法由伊恩古德费洛等人于2014年提出,曾被称为“机器学习这二十年来最酷的想法”,可以用来创造虚…

对抗生成网络(Generative Adversarial Net)

好久没有更新博客了,但似乎我每次更新博客的时候都这么说(泪)。最近对生活有了一些新的体会,工作上面,新的环境总算是适应了,知道了如何摆正工作和生活之间的关系,如何能在有效率工作的同时还能…

【PaddleOCR-det-finetune】一:基于PPOCRv3的det检测模型finetune训练

文章目录 基本流程详细步骤打标签,构建自己的数据集下载PPOCRv3训练模型修改超参数,训练自己数据集启动训练导出模型 测试 相关参考手册在PaddleOCR项目工程中的位置: det模型训练和微调:PaddleOCR\doc\doc_ch\PPOCRv3_det_train.…

模型微调(Finetune)

参考:https://zhuanlan.zhihu.com/p/35890660 ppt下载地址:https://github.com/jiangzhubo/What-is-Fine-tuning 一.什么是模型微调 给定预训练模型(Pre_trained model),基于模型进行微调(Fine Tune)。相…

fine-tuning

微调(fine-tuning) 在平时的训练中,我们通常很难拿到大量的数据,并且由于大量的数据,如果一旦有调整,重新训练网络是十分复杂的,而且参数不好调整,数量也不够,所以我们可…

大模型的三大法宝:Finetune, Prompt Engineering, Reward

编者按:基于基础通用模型构建领域或企业特有模型是目前趋势。本文简明介绍了最大化挖掘语言模型潜力的三大法宝——Finetune, Prompt Engineering和RLHF——的基本概念,并指出了大模型微调面临的工具层面的挑战。 以下是译文,Enjoy! 作者 | B…

RCNN网络源码解读(Ⅲ) --- finetune训练过程

目录 0.回顾 1.finetune二分类代码解释(finetune.py) 1.1 load_data(定义获取数据的方法) 1.2 CustomFineTuneDataset类 1.3 custom_batch_sampler类( custom_batch_sampler.py) 1.4 训练train_mod…

FinSH

finSH介绍 FinSH 是 RT-Thread 的命令行组件,提供一套供用户在命令行调用的操作接口,主要用于调试或查看系统信息。它可以使用串口 / 以太网 / USB 等与 PC 机进行通信。 命令执行过程 功能: 支持鉴权,可在系统配置中选择打开/关闭。(TODO…

从统一视角看各类高效finetune方法

每天给你送来NLP技术干货! 来自:圆圆的算法笔记 随着预训练模型参数量越来越大,迁移学习的成本越来越高,parameter-efficient tuning成为一个热点研究方向。在以前我们在下游任务使用预训练大模型,一般需要finetune模型…

finetune

finetune的含义是获取预训练好的网络的部分结构和权重,与自己新增的网络部分一起训练。下面介绍几种finetune的方法。 完整代码:https://github.com/toyow/learn_tensorflow/tree/master/finetune 一,如何恢复预训练的网络 方法一&#xf…

11.2 模型finetune

一、Transform Learning 与 Model Finetune 二、pytorch中的Finetune 一、Transfer Learning 与 Model Finetune 1. 什么是Transfer Learning? 迁移学习是机器学习的一个分支,主要研究源域的知识如何应用到目标域当中。迁移学习是一个很大的概念。 怎么理解源域…

飞桨深度学习学院零基础深度学习7日入门-CV疫情特辑学习笔记(四)DAY03 车牌识别

本课分为理论和实战两个部分 理论:卷积神经网络 1.思考全连接神经网络的问题 一般来收机器学习模型实践分为三个步骤,(1)建立模型 (2)选择损失函数 (3)参数调整学习 1.1 模型结构不…

unity sdk(android)-友盟推送SDK接入

注意:一开始想接友盟Unity的SDk,但是导入后缺少各种jar,所以最后还是接了android的,demo文档齐全 官方文档:开发者中心 按照官方文档对接即可, 接入流程 1、项目中com.android.tools.build:gradle配置&…

友盟推送学习

一、首次使用U_Push 1、首先注册友盟账号,进入工作台,选择产品U_Push。 2、创建应用 3、在自己的项目中自动集成SDK 开发环境要求: Android Studio 3.0以上 Android minSdkVersion: 14 Cradle: 4.4以上 在根目录build.gradle中添加mav…

Android 学习之如何集成友盟推送

我是利用Android studio 新建一个空的Android项目。 步骤一 导入第三方库 1.切换Android项目状态为Project状态 2.在main文件下新建 jniLibs文件夹(用来导入PushSDK项目下lib文件中的so文件) 3.在libs文件夹下添加友盟PuskSDK中的 jar 文件&#xff…

用PaddlePaddle(飞浆)实现车牌识别

项目描述:本次实践是一个多分类任务,需要将照片中的每个字符分别进行识别,完成车牌的识别 实践平台:百度AI实训平台-AI Studio、PaddlePaddle1.8.0 动态图 数据集介绍(自己去网上下载车牌识别数据集) 数据…

深度学习(五) CNN卷积神经网络

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 CNN卷积神经网络 前言一、CNN是什么?二、为什么要使用CNN?三、CNN的结构1.图片的结构2.卷积层1.感受野(Receptive Field)2.卷积…