CGAN 简介与代码实战

article/2025/11/5 21:03:28

1.介绍
  原始GAN(GAN 简介与代码实战_天竺街潜水的八角的博客-CSDN博客)在理论上可以完全逼近真实数据,但它的可控性不强(生成小图片还行,生成的大图片可能是不合逻辑的),因此需要对gan加一些约束,能生成我们想要的图片,这个时候,CGAN就横空出世了,更加详细的介绍参考论文:Conditional Generative Adversarial Nets
 

2.模型结构

 公式1是原始GAN的损失函数,公式2相对于公式1多了一个条件y,这个y可以是标签和图片中需要修复的部分(比如动物)等

 如果只看公式2,很难想象到,怎样才能把y当作条件来融入网络。看下图之后,我们很容易想到,条件y和待判别的图像被拼接(concat)起来就可以达到这个效果。 

 

3.模型特点

使用额外信息y对模型增加条件,可以指导数据生成过程

4.代码实现 keras

class CGAN():def __init__(self):# Input shapeself.img_rows = 28self.img_cols = 28self.channels = 1self.img_shape = (self.img_rows, self.img_cols, self.channels)self.num_classes = 10self.latent_dim = 100optimizer = Adam(0.0002, 0.5)# Build and compile the discriminatorself.discriminator = self.build_discriminator()self.discriminator.compile(loss=['binary_crossentropy'],optimizer=optimizer,metrics=['accuracy'])# Build the generatorself.generator = self.build_generator()# The generator takes noise and the target label as input# and generates the corresponding digit of that labelnoise = Input(shape=(self.latent_dim,))label = Input(shape=(1,))img = self.generator([noise, label])# For the combined model we will only train the generatorself.discriminator.trainable = False# The discriminator takes generated image as input and determines validity# and the label of that imagevalid = self.discriminator([img, label])# The combined model  (stacked generator and discriminator)# Trains generator to fool discriminatorself.combined = Model([noise, label], valid)self.combined.compile(loss=['binary_crossentropy'],optimizer=optimizer)def build_generator(self):model = Sequential()model.add(Dense(256, input_dim=self.latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(self.img_shape), activation='tanh'))model.add(Reshape(self.img_shape))model.summary()noise = Input(shape=(self.latent_dim,))label = Input(shape=(1,), dtype='int32')label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))model_input = multiply([noise, label_embedding])img = model(model_input)return Model([noise, label], img)def build_discriminator(self):model = Sequential()model.add(Dense(512, input_dim=np.prod(self.img_shape)))model.add(LeakyReLU(alpha=0.2))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.4))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.4))model.add(Dense(1, activation='sigmoid'))model.summary()img = Input(shape=self.img_shape)label = Input(shape=(1,), dtype='int32')label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))flat_img = Flatten()(img)model_input = multiply([flat_img, label_embedding])validity = model(model_input)return Model([img, label], validity)def train(self, epochs, batch_size=128, sample_interval=50):# Load the dataset(X_train, y_train), (_, _) = mnist.load_data()# Configure inputX_train = (X_train.astype(np.float32) - 127.5) / 127.5X_train = np.expand_dims(X_train, axis=3)y_train = y_train.reshape(-1, 1)# Adversarial ground truthsvalid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# ---------------------#  Train Discriminator# ---------------------# Select a random half batch of imagesidx = np.random.randint(0, X_train.shape[0], batch_size)imgs, labels = X_train[idx], y_train[idx]# Sample noise as generator inputnoise = np.random.normal(0, 1, (batch_size, 100))# Generate a half batch of new imagesgen_imgs = self.generator.predict([noise, labels])# Train the discriminatord_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# ---------------------#  Train Generator# ---------------------# Condition on labelssampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)# Train the generatorg_loss = self.combined.train_on_batch([noise, sampled_labels], valid)# Plot the progressprint ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))# If at save interval => save generated image samplesif epoch % sample_interval == 0:self.sample_images(epoch)def sample_images(self, epoch):r, c = 2, 5noise = np.random.normal(0, 1, (r * c, 100))sampled_labels = np.arange(0, 10).reshape(-1, 1)gen_imgs = self.generator.predict([noise, sampled_labels])# Rescale images 0 - 1gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray')axs[i,j].set_title("Digit: %d" % sampled_labels[cnt])axs[i,j].axis('off')cnt += 1fig.savefig("images/%d.png" % epoch)plt.close()


http://chatgpt.dhexx.cn/article/e5t7lcCG.shtml

相关文章

CGAN原理分析

1、CGAN原理分析 1.1 网络结构 CGAN是在GAN基础上做的一种改进,通过给原始GAN的生成器Generator(下文简记为G)和判别器Discriminator(下文简记为D)添加额外的条件信息y,实现条件生成模型。CGAN原文中作者…

CGAN理论讲解及代码实现

目录 1.原始GAN的缺点 2.CGAN中心思想 3.原始GAN和CGAN的区别 4.CGAN代码实现 5.运行结果 6.CGAN缺陷 1.原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。 针对原始…

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信…

解决关于Navicat破解安装过程中出现“rsa public key not find”

解决关于Navicat破解安装过程中出现“rsa public key not find” 问题描述解决办法 问题描述 出现“rsa public key not find”的输出框 解决办法 首先先安装Navicat。安装后先不要运行 打开 然后点击 如果出现 则是对的 如果出现这个: 那就请你找到在本地的na…

navicat安装与激活

原文网址:https://www.jianshu.com/p/5f693b4c9468?mTypeGroup 一、Navicat Premium 12下载 Navicat Premium 12是一套数据库开发管理工具,支持连接 MySQL、Oracle等多种数据库,可以快速轻松地创建、管理和维护数据库。 Navicat Premium 12简…

Mac上安装 Navicat

1.下载安装包 Mac版 Navicat Premium 12 v12.0.23.0 官网下载地址: 英文64位 http://download.navicat.com/download/navicat120_premium_en.dmg 中文简体64位 http://download.navicat.com/download/navicat120_premium_cs.dmg 中文简体安装包:链接:h…

Navicat Premium 12.1.21 最新版激活工具及方法

At The Beginning ****** Sincerely regards to the author of the original work ******* 本帖持续更新 Last updated at 21st Aug 2019 Steps navicat_premium原版安装包 官网下载地址:https://www.navicat.com.cn/download/navicat-premium 注册工具下载 git…

Navicat Premium 12.1.16.0安装与激活

一、Navicat Premium 12下载 Navicat Premium 12简体中文下载; 提取码:cgv4 二、Navicat Premium 12安装 双击安装,点击下一步: 同意协议,点击下一步: 选择安装位置(可默认)&…

Navicat Premium安装和激活

前言 Navicat Premium这个软件是非常的好用,这个软件中包含mysql,SQL Server等等的数据库,受到广大编程爱好者的欢迎,废话不多说,下面就直接进入主题,马上就是Navicat Premium安装和激活的环节。 (1&…

Navicat Premium 12破解激活

下载Navicat Premium 12并安装; 蓝奏云下载:Navicat Premium 12注册机 重要提示:该注册机来源于DeltaFoX。一般来说,由于注册机会修改.exe文件或.dll文件,加壳并且没有数字签名,所以杀毒软件会报毒。如需…

Navicat v15

特别注意: 1.断网,否则在安装过程中会失败2.关闭防火墙及杀毒软件   3.选择对应版本:mysql版就选择mysql 出现如下情况: 就卸载,删除注册表,重新安装,出现rsa public key not find的错误 以及 generate first a serial 错误都也是如此删除注册表的办法打开文件&am…

关于Navicat 数据库一直激活不成功的解决方法

首先激活时一直出现 rsa public key not found,说明获取不到激活码,此时就需要检查 - 在Patch的时候是不是没成功 使用破解软件如果出现说已经patch过了的时候赶紧卸载重装!! - 在激活的时候是不是没有断开网络 解决办法&#…

Navicat安装激活

有条件的同学麻烦不要使用下面的激活步骤,仅供个人学习使用 。。。。。。 。。。。。。 。。。。。。 。。。。。。 一、去官网下载最新Navicat软件https://www.navicat.com.cn/download/navicat-premium 二、去下载激活脚本https://github.com/DoubleLabyrinth/nav…

Navicat 12.1 Macos 激活指南

Navicat 12.1 Navicat从版本11开始使用,一直在macos上表现稳定,速度还快,操作也简单,比Mysql workbench好用多了, workbench总是会发生程序崩溃,修改数据还要点Apply键。对开发来说很不好用. 以下是整个过…

激活navicat提示rsa public key not find的问题

操作顺序先不打开Navicat,注机patch,然后再开Navicat注册 卸载原来的navicat重新安装再次点击patch选择路径就行了 还不行就记得,右键激活工具以管理员权限打开激活再次patch选择navicat的安装好的navicat.exe文件即可

navicat premiun 12激活

注册机: https://download.csdn.net/download/qq_31967985/10545930 步骤: 以管理员身份运行此注册机: 运行注册机 打开注册机后,1) Patch勾选Backup、Host和Navicat v12,然后点击Patch按钮: 默认勾选 …

Navicat安装激活时提示激活失败: 激活次数达到上限,90010003

Navicat安装激活时提示激活失败: 激活次数达到上限,90010003 问题: 概述: 激活失败。原因可能是由于已达到激活次数上限。请检查你是否已在卸载或重新安装 Navicat前取消激活许可证密钥。 90010003解决方法: 通过查看C:\WINDOW…

解决Navicat激活、注册时候出现No All Pattern Found的问题

用Navicat Keygen Patch v5.6.0.exe注册激活Navicat15时,出现No All Pattern Found的错误,具体原因是navicat注册表问题,或navicat之前已经安装过了,所以在注册时候,会出现这个错误。 解决办法: 1)删除注册…

navicat激活失败

WINR输入命令regedit打开注册表 以此展开定位计算机 \HKEY_CURRENT_USER\SOFTWARE\PremiumSoft,

Navicat Premium12 安装与激活

Navicat Premium 这个是第三方的客户端工具,比较轻便,可以远程登录数据库 安装以及破解教程 一、安装包下载安装: 链接: https://pan.baidu.com/s/1W47ECdPx8a2k5_2h2KYhuw 提取码: sfai 下一步即可; 二、破解 破解补丁下载…