生成对抗网络(二)CGAN

article/2025/11/5 19:48:09

一、简介

        之前介绍了生成式对抗网络(GAN),关于GAN的变种比较多,我打算将几种常见的GAN做一个总结,也算是激励自己学习,分享自己的一些看法和见解。

        之前提到的GAN是最基本的模型,我们的输入是随机噪声,输出的是对应的图像,但是我们没法控制生成图像的类型。比如,我要生成一张数字0的图片,但是GAN生成的图片却是数字0-9的图片,针对这个问题,Conditional Generative Adversarial Nets被提了出来,在原有GAN的基础上,添加了类别信息以便让模型生成特定的图片。这里的条件(conditional),就是这个额外的类别信息。

二、原理

         由于在GAN的生成器和判别器中都加入了额外的类别信息,模型的目标优化函数也发生了变化。

         生成器的输入变为噪音变量P_z(z)和类别信息y, 判别器的输入为图片数据x和类别信息y, 目标函数如下:

               \underset{G}{min}\, \underset{D}{max}\, V(D,G)=E_{x\sim p_{data} (x)}[\log D(x|y)] +E_{z\sim P_{z}(z)}[\log (1-D(G(z|y)))]

         就是在GAN的目标函数上添加了y这一类别变量,x变为了条件分布。

         模型的结构图如下,

         

        GAN的结构与这个类似,生成器部分和判别器部分是分开的两个子网络,单独进行训练。类别信息y是通过embedding层嵌入的。

        具体的实现可以看看代码:

        生成器:

    def build_generator(self):model = Sequential()model.add(Dense(256, input_dim=self.latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(self.img_shape), activation='tanh'))model.add(Reshape(self.img_shape))model.summary()noise = Input(shape=(self.latent_dim,))label = Input(shape=(1,), dtype='int32')label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))model_input = multiply([noise, label_embedding])img = model(model_input)return Model([noise, label], img)

          标签是通过嵌入层实现的,embedding层可以将类别标签转换为对应的向量表示,在此生成器中,类别有10个(0-9),对应embedding中的input_dim, 输出维度和噪音数据是相同的,之后,再利用multiply层将两者逐项做乘积,这便是生成器的输入。

           判别器:

    def build_discriminator(self):model = Sequential()model.add(Dense(512, input_dim=np.prod(self.img_shape)))model.add(LeakyReLU(alpha=0.2))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.4))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.4))model.add(Dense(1, activation='sigmoid'))model.summary()img = Input(shape=self.img_shape)label = Input(shape=(1,), dtype='int32')label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))flat_img = Flatten()(img)model_input = multiply([flat_img, label_embedding])validity = model(model_input)return Model([img, label], validity)

            判别器的输入和生成器是一样的,输出是对应的图片的类别。

            训练:

            训练采用的mnist数据集,训练时需要将图片数据和对应的标签输入模型。

            生成器和判决器作为一个整体进行训练的时候,判别器是不训练的,这时只训练生成器;当判决器作为一个单独的模型时,判决器会得到训练。二者的训练是交替进行的。

            具体的代码可以参考github

三、效果

       最后跑出来的效果还是很不错的,我在台式机上跑的,用的是1050ti的显卡,训练速度还比较快,一共20000轮,大概10分钟左右跑完。

       这是最后的训练效果:

       

       可以与前一篇博客里面的内容进行比较,与原始的GAN相比,效果要好一些,但是还是不是很清晰。一方面,mnist提供的图片像素较低,另一方面,我们采用的是全连接神经网络,对于图片的处理效果并不是很好。

       要生成更加清晰地图片,可以利用DCGAN,这也是我接下来要做的工作。


http://chatgpt.dhexx.cn/article/FOkAEe5U.shtml

相关文章

读CGAN文章

提出CGAN是因为非条件的生成模型中,对生成的内容控制,实际上只要保证真实性就可以了;采用CGAN的话,我们会增加一些额外的信息去控制数据生成的过程,例如一些类别的标签,例如数字图片数据集中,可…

CGAN论文解读:Conditional Generative Adversarial Nets

论文链接:Conditional Generative Adversarial Nets 代码解读:Keras-CGAN_MNIST 代码解读 目录 一、前言 二、相关工作 三、网络结构 CGAN NETS 四、实验结果 4.1 单模态 (mnist实验) 4.2 多模态(自动为图片打…

第三章 CGAN

写在前面:最近看了《GAN实战》,由于本人忘性大,所以仅是笔记而已,方便回忆,如果能帮助大家就更好了。 目录 代价函数 训练过程 生成器和鉴别器 混淆矩阵 CGAN生成手写数字 导入声明 模型输入维度 生成器 鉴别…

【pytorch】CGAN编程实现

CGAN介绍 由于原始GAN生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标类别不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于…

GAN,CGAN,DCGAN

GAN对抗生成网络 训练流程 图片以及训练过程来源 训练这样的两个模型的大方法就是单独交替迭代训练。 我们人为的定义真假样本集的标签,因为我们希望真样本集的输出尽可能为1,假样本集为0,我们就已经默认真样本集所有的类标签都为1&#xf…

GAN论文阅读——CGAN

论文标题:Conditional Generative Adversarial Nets 论文链接:https://arxiv.org/pdf/1411.1784.pdf 参考资料:http://blog.csdn.net/solomon1558/article/details/52555083 一、CGAN的思想 在原始GAN学习笔记中,我们提到过&am…

PyTorch随笔 - 生成对抗网络的改进cGAN和LSGAN

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/129939225 本文介绍GAN的两个常见改进,cGAN和LSGAN,两者一般结合使用。 cGAN: Conditional Generative Adversa…

CGAN实现过程

本文目录 一、原理二、参数初始化1. G的输入2. D的输入3. 模型参数初始化4. 测试噪声 三、执行过程四、测试 本文用MNIST数据集进行训练,并用图解的方法展示了CGAN与GAN中输入的区别,帮助理解CGAN的运行过程 一、原理 如下图所示,我们在输入…

CGAN

CGAN 生成符合特定描述的输出, 如图:给定描述,生成相应内容图片 CGAN中的negetive情况包含两种,(正确的描述,不清晰的图片)和(不正确的描述,清晰的图片)&a…

CGAN 简介与代码实战

1.介绍 原始GAN(GAN 简介与代码实战_天竺街潜水的八角的博客-CSDN博客)在理论上可以完全逼近真实数据,但它的可控性不强(生成小图片还行,生成的大图片可能是不合逻辑的),因此需要对gan加一些约…

CGAN原理分析

1、CGAN原理分析 1.1 网络结构 CGAN是在GAN基础上做的一种改进,通过给原始GAN的生成器Generator(下文简记为G)和判别器Discriminator(下文简记为D)添加额外的条件信息y,实现条件生成模型。CGAN原文中作者…

CGAN理论讲解及代码实现

目录 1.原始GAN的缺点 2.CGAN中心思想 3.原始GAN和CGAN的区别 4.CGAN代码实现 5.运行结果 6.CGAN缺陷 1.原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。 针对原始…

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信…

解决关于Navicat破解安装过程中出现“rsa public key not find”

解决关于Navicat破解安装过程中出现“rsa public key not find” 问题描述解决办法 问题描述 出现“rsa public key not find”的输出框 解决办法 首先先安装Navicat。安装后先不要运行 打开 然后点击 如果出现 则是对的 如果出现这个: 那就请你找到在本地的na…

navicat安装与激活

原文网址:https://www.jianshu.com/p/5f693b4c9468?mTypeGroup 一、Navicat Premium 12下载 Navicat Premium 12是一套数据库开发管理工具,支持连接 MySQL、Oracle等多种数据库,可以快速轻松地创建、管理和维护数据库。 Navicat Premium 12简…

Mac上安装 Navicat

1.下载安装包 Mac版 Navicat Premium 12 v12.0.23.0 官网下载地址: 英文64位 http://download.navicat.com/download/navicat120_premium_en.dmg 中文简体64位 http://download.navicat.com/download/navicat120_premium_cs.dmg 中文简体安装包:链接:h…

Navicat Premium 12.1.21 最新版激活工具及方法

At The Beginning ****** Sincerely regards to the author of the original work ******* 本帖持续更新 Last updated at 21st Aug 2019 Steps navicat_premium原版安装包 官网下载地址:https://www.navicat.com.cn/download/navicat-premium 注册工具下载 git…

Navicat Premium 12.1.16.0安装与激活

一、Navicat Premium 12下载 Navicat Premium 12简体中文下载; 提取码:cgv4 二、Navicat Premium 12安装 双击安装,点击下一步: 同意协议,点击下一步: 选择安装位置(可默认)&…

Navicat Premium安装和激活

前言 Navicat Premium这个软件是非常的好用,这个软件中包含mysql,SQL Server等等的数据库,受到广大编程爱好者的欢迎,废话不多说,下面就直接进入主题,马上就是Navicat Premium安装和激活的环节。 (1&…

Navicat Premium 12破解激活

下载Navicat Premium 12并安装; 蓝奏云下载:Navicat Premium 12注册机 重要提示:该注册机来源于DeltaFoX。一般来说,由于注册机会修改.exe文件或.dll文件,加壳并且没有数字签名,所以杀毒软件会报毒。如需…