【Keras-CGAN】MNIST / CIFAR-10

article/2025/11/5 19:21:43

在这里插入图片描述

本博客是 One Day One GAN [DAY 3] 的 learning notes!用 CGAN 来做 MNIST 图片的生成!

参考 【Keras-MLP-GAN】MNIST


文章目录

  • 1 CGAN(Conditional Generative Adversarial Nets)
  • 2 CGAN for MNIST
    • 2.1 导入必要的库
    • 2.2 搭建 generator
    • 2.3 搭建 discriminator
    • 2.4 compile 模型,对学习过程进行配置
    • 2.5 保存生成的图片
    • 2.6 训练
    • 2.7 结果展示

1 CGAN(Conditional Generative Adversarial Nets)

condition的意思是就是条件
原始 GAN
在这里插入图片描述
如果我们已知输入的 ground truth 的 label 信息,那么我们便可以在这个基础上结合条件概率的公式得到 CGAN 的目标函数:
在这里插入图片描述

如下图所示
在这里插入图片描述

2 CGAN for MNIST

2.1 导入必要的库

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adamimport matplotlib.pyplot as pltimport numpy as np

2.2 搭建 generator

通过,embedding,把 label 嵌入到 100 维,然后和噪声 z multiply,做为模型的输入,原来的 GAN 只是把 z 作为输入

# build_generator
model = Sequential()model.add(Dense(256, input_dim=100))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod((28,28,1)), activation='tanh'))
model.add(Reshape((28,28,1)))model.summary()noise = Input(shape=(100,)) # input 100,这里写成100不加逗号不行哟
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(10, 100)(label))#class, z dimensionmodel_input = multiply([noise, label_embedding]) # 把 label 和 noise embedding 在一起,作为 model 的输入
print(model_input.shape)img = model(model_input) # output (28,28,1)generator = Model([noise,label], img)

和原始 GAN 不同的地方是,多了 label_embedding,也即把 noise 和 label 信息 embedding,关于 embedding,可以参考 深度学习中Keras中的Embedding层的理解与使用。

  • input_dim 为 classes,也即 10
  • output_dim 为要嵌入的向量空间的大小,这里是 100
  • input_length 这里为 1 (也即 0-9)中的一种

output

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 256)               25856     
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 256)               1024      
_________________________________________________________________
dense_2 (Dense)              (None, 512)               131584    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 512)               2048      
_________________________________________________________________
dense_3 (Dense)              (None, 1024)              525312    
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 1024)              4096      
_________________________________________________________________
dense_4 (Dense)              (None, 784)               803600    
_________________________________________________________________
reshape_1 (Reshape)          (None, 28, 28, 1)         0         
=================================================================
Total params: 1,493,520
Trainable params: 1,489,936
Non-trainable params: 3,584
_________________________________________________________________
(?, 100)

2.3 搭建 discriminator

通过 embedding,把 label 嵌入到 28281 维,然后和图片 multiply,做为模型的输入,原来的 GAN 只是把图片作为输入

# build_discriminator
model = Sequential()model.add(Flatten(input_shape=(28,28,1)))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.4))model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.4))model.add(Dense(1, activation='sigmoid'))
model.summary()img = Input(shape=(28,28,1)) # 输入 (28,28,1)
label = Input(shape=(1,), dtype='int32')label_embedding = Flatten()(Embedding(10, np.prod((28,28,1)))(label))
flat_img = Flatten()(img)
model_input = multiply([flat_img, label_embedding])validity = model(model_input) # 把 label 和 G(z) embedding 在一起,作为 model 的输入
discriminator = Model([img, label], validity)

output

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_2 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_5 (Dense)              (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_6 (Dense)              (None, 512)               262656    
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_7 (Dense)              (None, 512)               262656    
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_8 (Dense)              (None, 1)                 513       
=================================================================
Total params: 927,745
Trainable params: 927,745
Non-trainable params: 0
_________________________________________________________________

2.4 compile 模型,对学习过程进行配置

这里训练 GAN 分为两个过程

  • 训练 discriminator,图片由固定 generator 产生
  • 训练 generator,联合 discriminator 和 generator,但是 discriminator 的梯度不更新,所以 discriminator 固定住了
optimizer = Adam(0.0002, 0.5)# discriminator
discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])# The combined model  (stacked generator and discriminator)
noise = Input(shape=(100,))
label = Input(shape=(1,))
img = generator([noise,label])# For the combined model we will only train the generator
validity = discriminator([img,label])
discriminator.trainable = False# Trains the generator to fool the discriminator
combined = Model([noise,label], validity)
combined.summary()
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

2.5 保存生成的图片

def sample_images(epoch):r, c = 2, 5noise = np.random.normal(0, 1, (r * c, 100))sampled_labels = np.arange(0, 10).reshape(-1, 1)gen_imgs = generator.predict([noise, sampled_labels])# Rescale images 0 - 1gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].set_title("Digit: %d" % sampled_labels[cnt])axs[i,j].axis('off')cnt += 1fig.savefig("images/%d.png" % epoch)plt.close()

2.6 训练

这里的 epoch 理解为 iteration

batch_size = 32
sample_interval = 200# Load the dataset
(X_train, y_train), (_, _) = mnist.load_data() # (60000,28,28)
# Rescale -1 to 1
X_train = X_train / 127.5 - 1. # tanh 的结果是 -1~1,所以这里 0-1 归一化后减1
X_train = np.expand_dims(X_train, axis=3)  # (60000,28,28,1)
# Adversarial ground truths
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))for epoch in range(20001):# ---------------------#  Train Discriminator# ---------------------# Select a random batch of imagesidx = np.random.randint(0, X_train.shape[0], batch_size) # 0-60000 中随机抽  #imgs = X_train[idx]imgs, labels = X_train[idx], y_train[idx]noise = np.random.normal(0, 1, (batch_size, 100))# 生成标准的高斯分布噪声# Generate a batch of new imagesgen_imgs = generator.predict([noise,labels])# Train the discriminatord_loss_real = discriminator.train_on_batch([imgs, labels], valid) #真实数据对应标签1d_loss_fake = discriminator.train_on_batch([gen_imgs,labels], fake) #生成的数据对应标签0d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# ---------------------#  Train Generator# ---------------------#noise = np.random.normal(0, 1, (batch_size, 100))sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)# Train the generator (to have the discriminator label samples as valid)g_loss = combined.train_on_batch([noise, sampled_labels], valid)# Plot the progressif epoch % 500==0:print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))# If at save interval => save generated image samplesif epoch % sample_interval == 0:sample_images(epoch)

output

0 [D loss: 0.692575, acc.: 32.81%] [G loss: 0.680311]
200 [D loss: 0.442047, acc.: 76.56%] [G loss: 5.113770]
400 [D loss: 0.332470, acc.: 85.94%] [G loss: 2.495651]
……
19600 [D loss: 0.644090, acc.: 57.81%] [G loss: 0.867414]
19800 [D loss: 0.682952, acc.: 54.69%] [G loss: 0.818742]
20000 [D loss: 0.662673, acc.: 60.94%] [G loss: 0.831777]

2.7 结果展示

0 iteration在这里插入图片描述
200 iteration
在这里插入图片描述
400 iteration
在这里插入图片描述

19600 iteration
在这里插入图片描述
19800 iteration
在这里插入图片描述
20000 iteration
在这里插入图片描述


http://chatgpt.dhexx.cn/article/jJWTwSf0.shtml

相关文章

CGAN及代码实现

前言 本文主要介绍CGAN及其代码实现阅读本文之前,建议先阅读GAN(生成对抗网络)本文基于一次课程实验,代码仅上传了需要补充部分 CGAN 全称: C o n d i t i o n a l G e n e r a t i v e A d v e r s a r i a l N e t w o r k Conditional …

生成对抗网络(二)CGAN

一、简介 之前介绍了生成式对抗网络(GAN),关于GAN的变种比较多,我打算将几种常见的GAN做一个总结,也算是激励自己学习,分享自己的一些看法和见解。 之前提到的GAN是最基本的模型,我们的输入是随机噪声,输出…

读CGAN文章

提出CGAN是因为非条件的生成模型中,对生成的内容控制,实际上只要保证真实性就可以了;采用CGAN的话,我们会增加一些额外的信息去控制数据生成的过程,例如一些类别的标签,例如数字图片数据集中,可…

CGAN论文解读:Conditional Generative Adversarial Nets

论文链接:Conditional Generative Adversarial Nets 代码解读:Keras-CGAN_MNIST 代码解读 目录 一、前言 二、相关工作 三、网络结构 CGAN NETS 四、实验结果 4.1 单模态 (mnist实验) 4.2 多模态(自动为图片打…

第三章 CGAN

写在前面:最近看了《GAN实战》,由于本人忘性大,所以仅是笔记而已,方便回忆,如果能帮助大家就更好了。 目录 代价函数 训练过程 生成器和鉴别器 混淆矩阵 CGAN生成手写数字 导入声明 模型输入维度 生成器 鉴别…

【pytorch】CGAN编程实现

CGAN介绍 由于原始GAN生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标类别不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于…

GAN,CGAN,DCGAN

GAN对抗生成网络 训练流程 图片以及训练过程来源 训练这样的两个模型的大方法就是单独交替迭代训练。 我们人为的定义真假样本集的标签,因为我们希望真样本集的输出尽可能为1,假样本集为0,我们就已经默认真样本集所有的类标签都为1&#xf…

GAN论文阅读——CGAN

论文标题:Conditional Generative Adversarial Nets 论文链接:https://arxiv.org/pdf/1411.1784.pdf 参考资料:http://blog.csdn.net/solomon1558/article/details/52555083 一、CGAN的思想 在原始GAN学习笔记中,我们提到过&am…

PyTorch随笔 - 生成对抗网络的改进cGAN和LSGAN

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/129939225 本文介绍GAN的两个常见改进,cGAN和LSGAN,两者一般结合使用。 cGAN: Conditional Generative Adversa…

CGAN实现过程

本文目录 一、原理二、参数初始化1. G的输入2. D的输入3. 模型参数初始化4. 测试噪声 三、执行过程四、测试 本文用MNIST数据集进行训练,并用图解的方法展示了CGAN与GAN中输入的区别,帮助理解CGAN的运行过程 一、原理 如下图所示,我们在输入…

CGAN

CGAN 生成符合特定描述的输出, 如图:给定描述,生成相应内容图片 CGAN中的negetive情况包含两种,(正确的描述,不清晰的图片)和(不正确的描述,清晰的图片)&a…

CGAN 简介与代码实战

1.介绍 原始GAN(GAN 简介与代码实战_天竺街潜水的八角的博客-CSDN博客)在理论上可以完全逼近真实数据,但它的可控性不强(生成小图片还行,生成的大图片可能是不合逻辑的),因此需要对gan加一些约…

CGAN原理分析

1、CGAN原理分析 1.1 网络结构 CGAN是在GAN基础上做的一种改进,通过给原始GAN的生成器Generator(下文简记为G)和判别器Discriminator(下文简记为D)添加额外的条件信息y,实现条件生成模型。CGAN原文中作者…

CGAN理论讲解及代码实现

目录 1.原始GAN的缺点 2.CGAN中心思想 3.原始GAN和CGAN的区别 4.CGAN代码实现 5.运行结果 6.CGAN缺陷 1.原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。 针对原始…

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信…

解决关于Navicat破解安装过程中出现“rsa public key not find”

解决关于Navicat破解安装过程中出现“rsa public key not find” 问题描述解决办法 问题描述 出现“rsa public key not find”的输出框 解决办法 首先先安装Navicat。安装后先不要运行 打开 然后点击 如果出现 则是对的 如果出现这个: 那就请你找到在本地的na…

navicat安装与激活

原文网址:https://www.jianshu.com/p/5f693b4c9468?mTypeGroup 一、Navicat Premium 12下载 Navicat Premium 12是一套数据库开发管理工具,支持连接 MySQL、Oracle等多种数据库,可以快速轻松地创建、管理和维护数据库。 Navicat Premium 12简…

Mac上安装 Navicat

1.下载安装包 Mac版 Navicat Premium 12 v12.0.23.0 官网下载地址: 英文64位 http://download.navicat.com/download/navicat120_premium_en.dmg 中文简体64位 http://download.navicat.com/download/navicat120_premium_cs.dmg 中文简体安装包:链接:h…

Navicat Premium 12.1.21 最新版激活工具及方法

At The Beginning ****** Sincerely regards to the author of the original work ******* 本帖持续更新 Last updated at 21st Aug 2019 Steps navicat_premium原版安装包 官网下载地址:https://www.navicat.com.cn/download/navicat-premium 注册工具下载 git…

Navicat Premium 12.1.16.0安装与激活

一、Navicat Premium 12下载 Navicat Premium 12简体中文下载; 提取码:cgv4 二、Navicat Premium 12安装 双击安装,点击下一步: 同意协议,点击下一步: 选择安装位置(可默认)&…