第三章 CGAN

article/2025/11/5 20:29:05

写在前面:最近看了《GAN实战》,由于本人忘性大,所以仅是笔记而已,方便回忆,如果能帮助大家就更好了。

目录

代价函数

训练过程

生成器和鉴别器

混淆矩阵

CGAN生成手写数字

导入声明

模型输入维度

生成器

鉴别器

构建并编译GAN

GAN训练循环

显示合成图像

运行模型

结果​​​​​​​


​​​​​​​

目标识别模型学习图像中的模式以识别图像的内容

生成器学习合成这些模式(与目标识别模型过程相反)

 代价函数

J^{G}表示生成器的代价函数,J^{D}代表鉴别器的代价函数

训练参数(权重与偏置):\theta^{G}代表生成器,\theta^{D}代表鉴别器

因为GAN与传统的神经网络不同,它由两个网络构成,其代价函数依赖于两个网络的参数。也就是说生成器代价函数是J^{G}(\theta^{G},\theta^{D}),而鉴别器代价函数是 ​​​​​​​ J^{D}(\theta^{G},\theta^{D})。在训练过程中,每个网络只能调整自己的参数:生成器只能调整\theta^{G},鉴别器只能调整\theta^{D}

 训练过程

GAN训练过程可以用一个博弈过程来描述,而非优化。(博弈双方是GAN的两个网络)

生成器训练参数\theta^{G},使得 代价函数J^{G}(\theta^{G},\theta^{D})最小化。

同时,对应该网络参数\theta^{D}下的鉴别器的代价函数 J^{D}(\theta^{G},\theta^{D})最小化

生成器和鉴别器

生成器G接受随机噪声向量z并生成一个伪样本x*,G(z)=x*

鉴别器D的输入要么是真实样本x,要么是伪样本x*,输出一个介于0和1之间的值(输入是真实样本的概率)

混淆矩阵

鉴别器的分类可以用混淆矩阵来表示,分类结果如下:

(1)真阳性true positive——真实样本正确,分类为真D(x)\approx 1

(2)假阴性false negative——真实样本错误,分类为假D(x)\approx 0

(3)真阴性true negative——伪样本正确,分类为假D(x*)\approx 0

(4)假阳性false positive——伪样本错误,分类为真D(x*)\approx 1

 鉴别器试图最大化真养性和真阴性分类

生成器目标是最大化鉴别器假阳性分类(不关心对真实样本分类效果只关心伪样本分类)

CGAN生成手写数字

导入声明

import matplotlib.pyplot as plt
import numpy as npfrom keras.datasets import mnist
from keras.layers import Dense, Flatten, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential
from keras.optimizers import Adam

模型输入维度

img_rows = 28
img_cols = 28
channels = 1# Input image dimensions
img_shape = (img_rows, img_cols, channels)# Size of the noise vector, used as input to the Generator
z_dim = 100

生成器

def build_generator(img_shape, z_dim):model = Sequential()# Fully connected layermodel.add(Dense(128, input_dim=z_dim))# Leaky ReLU activationmodel.add(LeakyReLU(alpha=0.01))# Output layer with tanh activationmodel.add(Dense(28 * 28 * 1, activation='tanh'))# Reshape the Generator output to image dimensionsmodel.add(Reshape(img_shape))return model

鉴别器

def build_discriminator(img_shape):model = Sequential()# Flatten the input imagemodel.add(Flatten(input_shape=img_shape))# Fully connected layermodel.add(Dense(128))# Leaky ReLU activationmodel.add(LeakyReLU(alpha=0.01))# Output layer with sigmoid activationmodel.add(Dense(1, activation='sigmoid'))return model

构建并编译GAN

def build_gan(generator, discriminator):model = Sequential()# Combined Generator -> Discriminator modelmodel.add(generator)model.add(discriminator)return model# Build and compile the Discriminator
discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy',optimizer=Adam(),metrics=['accuracy'])# Build the Generator
generator = build_generator(img_shape, z_dim)# Keep Discriminator’s parameters constant for Generator training
discriminator.trainable = False# Build and compile GAN model with fixed Discriminator to train the Generator
gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam())

GAN训练循环

losses = []
accuracies = []
iteration_checkpoints = []def train(iterations, batch_size, sample_interval):# Load the MNIST dataset(X_train, _), (_, _) = mnist.load_data()# Rescale [0, 255] grayscale pixel values to [-1, 1]X_train = X_train / 127.5 - 1.0X_train = np.expand_dims(X_train, axis=3)# Labels for real images: all onesreal = np.ones((batch_size, 1))# Labels for fake images: all zerosfake = np.zeros((batch_size, 1))for iteration in range(iterations):# -------------------------#  Train the Discriminator# -------------------------# Get a random batch of real imagesidx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]# Generate a batch of fake imagesz = np.random.normal(0, 1, (batch_size, 100))gen_imgs = generator.predict(z)# Train Discriminatord_loss_real = discriminator.train_on_batch(imgs, real)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)d_loss, accuracy = 0.5 * np.add(d_loss_real, d_loss_fake)# ---------------------#  Train the Generator# ---------------------# Generate a batch of fake imagesz = np.random.normal(0, 1, (batch_size, 100))gen_imgs = generator.predict(z)# Train Generatorg_loss = gan.train_on_batch(z, real)if (iteration + 1) % sample_interval == 0:# Save losses and accuracies so they can be plotted after traininglosses.append((d_loss, g_loss))accuracies.append(100.0 * accuracy)iteration_checkpoints.append(iteration + 1)# Output training progressprint("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %(iteration + 1, d_loss, 100.0 * accuracy, g_loss))# Output a sample of generated imagesample_images(generator)

显示合成图像

def sample_images(generator, image_grid_rows=4, image_grid_columns=4):# Sample random noisez = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))# Generate images from random noisegen_imgs = generator.predict(z)# Rescale image pixel values to [0, 1]gen_imgs = 0.5 * gen_imgs + 0.5# Set image gridfig, axs = plt.subplots(image_grid_rows,image_grid_columns,figsize=(4, 4),sharey=True,sharex=True)cnt = 0for i in range(image_grid_rows):for j in range(image_grid_columns):# Output a grid of imagesaxs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')axs[i, j].axis('off')cnt += 1

运行模型

# Set hyperparameters
iterations = 20000
batch_size = 128
sample_interval = 1000# Train the GAN for the specified number of iterations
train(iterations, batch_size, sample_interval)

结果​​​​​​​

 

 虽然不是很完美,但是简单的双层生成器学会了生成逼真的数字


http://chatgpt.dhexx.cn/article/lm22imZm.shtml

相关文章

【pytorch】CGAN编程实现

CGAN介绍 由于原始GAN生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标类别不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于…

GAN,CGAN,DCGAN

GAN对抗生成网络 训练流程 图片以及训练过程来源 训练这样的两个模型的大方法就是单独交替迭代训练。 我们人为的定义真假样本集的标签,因为我们希望真样本集的输出尽可能为1,假样本集为0,我们就已经默认真样本集所有的类标签都为1&#xf…

GAN论文阅读——CGAN

论文标题:Conditional Generative Adversarial Nets 论文链接:https://arxiv.org/pdf/1411.1784.pdf 参考资料:http://blog.csdn.net/solomon1558/article/details/52555083 一、CGAN的思想 在原始GAN学习笔记中,我们提到过&am…

PyTorch随笔 - 生成对抗网络的改进cGAN和LSGAN

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/129939225 本文介绍GAN的两个常见改进,cGAN和LSGAN,两者一般结合使用。 cGAN: Conditional Generative Adversa…

CGAN实现过程

本文目录 一、原理二、参数初始化1. G的输入2. D的输入3. 模型参数初始化4. 测试噪声 三、执行过程四、测试 本文用MNIST数据集进行训练,并用图解的方法展示了CGAN与GAN中输入的区别,帮助理解CGAN的运行过程 一、原理 如下图所示,我们在输入…

CGAN

CGAN 生成符合特定描述的输出, 如图:给定描述,生成相应内容图片 CGAN中的negetive情况包含两种,(正确的描述,不清晰的图片)和(不正确的描述,清晰的图片)&a…

CGAN 简介与代码实战

1.介绍 原始GAN(GAN 简介与代码实战_天竺街潜水的八角的博客-CSDN博客)在理论上可以完全逼近真实数据,但它的可控性不强(生成小图片还行,生成的大图片可能是不合逻辑的),因此需要对gan加一些约…

CGAN原理分析

1、CGAN原理分析 1.1 网络结构 CGAN是在GAN基础上做的一种改进,通过给原始GAN的生成器Generator(下文简记为G)和判别器Discriminator(下文简记为D)添加额外的条件信息y,实现条件生成模型。CGAN原文中作者…

CGAN理论讲解及代码实现

目录 1.原始GAN的缺点 2.CGAN中心思想 3.原始GAN和CGAN的区别 4.CGAN代码实现 5.运行结果 6.CGAN缺陷 1.原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。 针对原始…

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信…

解决关于Navicat破解安装过程中出现“rsa public key not find”

解决关于Navicat破解安装过程中出现“rsa public key not find” 问题描述解决办法 问题描述 出现“rsa public key not find”的输出框 解决办法 首先先安装Navicat。安装后先不要运行 打开 然后点击 如果出现 则是对的 如果出现这个: 那就请你找到在本地的na…

navicat安装与激活

原文网址:https://www.jianshu.com/p/5f693b4c9468?mTypeGroup 一、Navicat Premium 12下载 Navicat Premium 12是一套数据库开发管理工具,支持连接 MySQL、Oracle等多种数据库,可以快速轻松地创建、管理和维护数据库。 Navicat Premium 12简…

Mac上安装 Navicat

1.下载安装包 Mac版 Navicat Premium 12 v12.0.23.0 官网下载地址: 英文64位 http://download.navicat.com/download/navicat120_premium_en.dmg 中文简体64位 http://download.navicat.com/download/navicat120_premium_cs.dmg 中文简体安装包:链接:h…

Navicat Premium 12.1.21 最新版激活工具及方法

At The Beginning ****** Sincerely regards to the author of the original work ******* 本帖持续更新 Last updated at 21st Aug 2019 Steps navicat_premium原版安装包 官网下载地址:https://www.navicat.com.cn/download/navicat-premium 注册工具下载 git…

Navicat Premium 12.1.16.0安装与激活

一、Navicat Premium 12下载 Navicat Premium 12简体中文下载; 提取码:cgv4 二、Navicat Premium 12安装 双击安装,点击下一步: 同意协议,点击下一步: 选择安装位置(可默认)&…

Navicat Premium安装和激活

前言 Navicat Premium这个软件是非常的好用,这个软件中包含mysql,SQL Server等等的数据库,受到广大编程爱好者的欢迎,废话不多说,下面就直接进入主题,马上就是Navicat Premium安装和激活的环节。 (1&…

Navicat Premium 12破解激活

下载Navicat Premium 12并安装; 蓝奏云下载:Navicat Premium 12注册机 重要提示:该注册机来源于DeltaFoX。一般来说,由于注册机会修改.exe文件或.dll文件,加壳并且没有数字签名,所以杀毒软件会报毒。如需…

Navicat v15

特别注意: 1.断网,否则在安装过程中会失败2.关闭防火墙及杀毒软件   3.选择对应版本:mysql版就选择mysql 出现如下情况: 就卸载,删除注册表,重新安装,出现rsa public key not find的错误 以及 generate first a serial 错误都也是如此删除注册表的办法打开文件&am…

关于Navicat 数据库一直激活不成功的解决方法

首先激活时一直出现 rsa public key not found,说明获取不到激活码,此时就需要检查 - 在Patch的时候是不是没成功 使用破解软件如果出现说已经patch过了的时候赶紧卸载重装!! - 在激活的时候是不是没有断开网络 解决办法&#…

Navicat安装激活

有条件的同学麻烦不要使用下面的激活步骤,仅供个人学习使用 。。。。。。 。。。。。。 。。。。。。 。。。。。。 一、去官网下载最新Navicat软件https://www.navicat.com.cn/download/navicat-premium 二、去下载激活脚本https://github.com/DoubleLabyrinth/nav…