CGAN实现过程

article/2025/11/5 20:40:59

本文目录

    • 一、原理
    • 二、参数初始化
      • 1. G的输入
      • 2. D的输入
      • 3. 模型参数初始化
      • 4. 测试噪声
    • 三、执行过程
    • 四、测试

本文用MNIST数据集进行训练,并用图解的方法展示了CGAN与GAN中输入的区别,帮助理解CGAN的运行过程

一、原理

如下图所示,我们在输入噪声z时,额外加上一个限制条件conditionz和c通过生成器G得到生成的图片

二、参数初始化

有了上面的原理解释,我们就可以来初始化我们的参数了,大致可以看出我们有如下几个参数:噪声z,条件c,真实图片x,生成器和判别器的初始化参数

  • G的输入:z_y_vec_
  • D的输入:xy_fill_
  • 模型参数的初始化
  • 测试时用的噪声sample_z_以及对应的标签sample_y_

这里输入的单个噪声维度为z_dim=62,当然这里还有很多其他的初始化,比如optimizer等,因为本文主要介绍模型的的具体执行过程,所以只对变量得初始化做介绍

1. G的输入

  • 输入噪声z:z_: (64, 62)
  • 输入条件c:y_vec_:(64, 10)

最终G的输入:横向拼接z+c (64, 72)

G:
torch.Size([64, 72])
tensor([[0.8920, 0.9742, 0.6876,  ..., 0.0000, 0.0000, 0.0000],[0.5271, 0.6423, 0.7480,  ..., 0.0000, 1.0000, 0.0000],[0.9545, 0.6324, 0.9603,  ..., 0.0000, 0.0000, 0.0000],...,[0.1931, 0.7773, 0.8154,  ..., 0.0000, 0.0000, 0.0000],[0.0049, 0.7129, 0.3272,  ..., 0.0000, 0.0000, 0.0000],[0.2902, 0.1194, 0.0020,  ..., 0.0000, 1.0000, 0.0000]])

在这里插入图片描述

2. D的输入

  • 输入真实数据:x: (64, 1, 28, 28)
  • 输入生成数据:G(z):(64, 1, 28, 28)
  • 输入条件:c:y_fill_:(64, 10, 28, 28)

最终D的输入:横向拼接x+c (64, 11, 28, 28),也就是说取batch中的一个值,维度为(1,28, 28),将其作为(11, 28, 28)的第一维,剩下的十维如果标签为0则第二维为全1,剩下的为全0,如果标签为1则第三维为全1,剩下的为全0,以此类推

D:
torch.Size([64, 11, 28, 28])
tensor([[[[ 0.1099, -0.5590,  0.9668,  ...,  3.0843,  0.6788, -0.4171],[ 0.8949, -0.3523, -0.4086,  ..., -0.8257, -2.1445,  1.0512],[ 1.5333, -0.0918, -1.1146,  ..., -1.1746, -0.4689,  0.3702],[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],...,[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

在这里插入图片描述

3. 模型参数初始化

def initialize_weights(net):for m in net.modules():if isinstance(m, nn.Conv2d):m.weight.data.normal_(0, 0.02)m.bias.data.zero_()elif isinstance(m, nn.ConvTranspose2d):m.weight.data.normal_(0, 0.02)m.bias.data.zero_()elif isinstance(m, nn.Linear):m.weight.data.normal_(0, 0.02)m.bias.data.zero_()

4. 测试噪声

在测试时我们只需要设置G的输入就可以了,也就是说我们需要:

  • 输入噪声z:z_: (100, 62)
  • 输入条件c:y_vec_:(100, 10)

最终G的输入:横向拼接z+c (100, 72)

下面给出代码和输出

# fixed noise
sample_z_ = torch.randn((100, 62))
for i in range(10):sample_z_[i*10] = torch.rand(1, 62)for j in range(1, 10):sample_z_[i*10 + j] = sample_z_[i*10]
print(sample_z_)
"""
sample_z_:(100, 62)0-9:    same value10-19:  same value...90-99:  same value
"""
temp = torch.zeros((10, 1))     # (10,1)---> 0,0,0,0,0,0,0,0,0,0
for i in range(10):temp[i, 0] = i                     # (10, 1) ---> 0,1,2,3,4,5,6,7,8,9
# print("temp:      ", temp)temp_y = torch.zeros((100, 1))  #(100,1)---> 0,0,0,0,...,0,0,0,0
for i in range(10):             #(100,1)---> 0,1,2,3,...,6,7,8,9temp_y[i*10: (i+1)*10] = temp
# print("temp_y:    ", temp_y)           
sample_y_ = torch.zeros((100, 10)).scatter_(1, temp_y.type(torch.LongTensor), 1)
print(sample_y_)                       #(100,10)
'''
tensor([[0.3944, 0.9880, 0.4956,  ..., 0.0602, 0.9869, 0.5094],[0.3944, 0.9880, 0.4956,  ..., 0.0602, 0.9869, 0.5094],[0.3944, 0.9880, 0.4956,  ..., 0.0602, 0.9869, 0.5094],...,[0.2845, 0.7694, 0.9878,  ..., 0.3211, 0.0242, 0.0332],[0.2845, 0.7694, 0.9878,  ..., 0.3211, 0.0242, 0.0332],[0.2845, 0.7694, 0.9878,  ..., 0.3211, 0.0242, 0.0332]])
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
'''

下面给出详细的解释,我们知道G的输入有噪声以及条件,这里我们有100组噪声,每10组噪声的组内取值是完全相同的,但是组内的10个噪声每个噪声的条件是不同的,分别代表了数字0-9

也就是说我们希望用相同的噪声生成0-9一共十个数字,生成十组

三、执行过程

图中的红线代表一个执行流程,绿线代表一个执行流程,红色的方框为这一步反向传播的网络。因为判别器与生成器是分开训练的,用两个图来表示,左边是第一步训练判别器,右边是第二步训练生成器

  • step1:首先将样本进行输入,用BCE_loss来评估得到D_real_loss,然后将G生成的数据进行输入,同理评估得到D_fake_loss,将二者相加进行反向传播优化D。注意这一步不要优化G
  • step2:直接将G生成的数据进行输入,评估得到G_loss,反向传播优化G。注意这一步虽然是G生成的数据,但是通过D以后要与real进行求损失

在这里插入图片描述

四、测试

训练完后直接进行测试即可,最后测试生成的图片如下:


http://chatgpt.dhexx.cn/article/tHglOsQo.shtml

相关文章

CGAN

CGAN 生成符合特定描述的输出, 如图:给定描述,生成相应内容图片 CGAN中的negetive情况包含两种,(正确的描述,不清晰的图片)和(不正确的描述,清晰的图片)&a…

CGAN 简介与代码实战

1.介绍 原始GAN(GAN 简介与代码实战_天竺街潜水的八角的博客-CSDN博客)在理论上可以完全逼近真实数据,但它的可控性不强(生成小图片还行,生成的大图片可能是不合逻辑的),因此需要对gan加一些约…

CGAN原理分析

1、CGAN原理分析 1.1 网络结构 CGAN是在GAN基础上做的一种改进,通过给原始GAN的生成器Generator(下文简记为G)和判别器Discriminator(下文简记为D)添加额外的条件信息y,实现条件生成模型。CGAN原文中作者…

CGAN理论讲解及代码实现

目录 1.原始GAN的缺点 2.CGAN中心思想 3.原始GAN和CGAN的区别 4.CGAN代码实现 5.运行结果 6.CGAN缺陷 1.原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。 针对原始…

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信…

解决关于Navicat破解安装过程中出现“rsa public key not find”

解决关于Navicat破解安装过程中出现“rsa public key not find” 问题描述解决办法 问题描述 出现“rsa public key not find”的输出框 解决办法 首先先安装Navicat。安装后先不要运行 打开 然后点击 如果出现 则是对的 如果出现这个: 那就请你找到在本地的na…

navicat安装与激活

原文网址:https://www.jianshu.com/p/5f693b4c9468?mTypeGroup 一、Navicat Premium 12下载 Navicat Premium 12是一套数据库开发管理工具,支持连接 MySQL、Oracle等多种数据库,可以快速轻松地创建、管理和维护数据库。 Navicat Premium 12简…

Mac上安装 Navicat

1.下载安装包 Mac版 Navicat Premium 12 v12.0.23.0 官网下载地址: 英文64位 http://download.navicat.com/download/navicat120_premium_en.dmg 中文简体64位 http://download.navicat.com/download/navicat120_premium_cs.dmg 中文简体安装包:链接:h…

Navicat Premium 12.1.21 最新版激活工具及方法

At The Beginning ****** Sincerely regards to the author of the original work ******* 本帖持续更新 Last updated at 21st Aug 2019 Steps navicat_premium原版安装包 官网下载地址:https://www.navicat.com.cn/download/navicat-premium 注册工具下载 git…

Navicat Premium 12.1.16.0安装与激活

一、Navicat Premium 12下载 Navicat Premium 12简体中文下载; 提取码:cgv4 二、Navicat Premium 12安装 双击安装,点击下一步: 同意协议,点击下一步: 选择安装位置(可默认)&…

Navicat Premium安装和激活

前言 Navicat Premium这个软件是非常的好用,这个软件中包含mysql,SQL Server等等的数据库,受到广大编程爱好者的欢迎,废话不多说,下面就直接进入主题,马上就是Navicat Premium安装和激活的环节。 (1&…

Navicat Premium 12破解激活

下载Navicat Premium 12并安装; 蓝奏云下载:Navicat Premium 12注册机 重要提示:该注册机来源于DeltaFoX。一般来说,由于注册机会修改.exe文件或.dll文件,加壳并且没有数字签名,所以杀毒软件会报毒。如需…

Navicat v15

特别注意: 1.断网,否则在安装过程中会失败2.关闭防火墙及杀毒软件   3.选择对应版本:mysql版就选择mysql 出现如下情况: 就卸载,删除注册表,重新安装,出现rsa public key not find的错误 以及 generate first a serial 错误都也是如此删除注册表的办法打开文件&am…

关于Navicat 数据库一直激活不成功的解决方法

首先激活时一直出现 rsa public key not found,说明获取不到激活码,此时就需要检查 - 在Patch的时候是不是没成功 使用破解软件如果出现说已经patch过了的时候赶紧卸载重装!! - 在激活的时候是不是没有断开网络 解决办法&#…

Navicat安装激活

有条件的同学麻烦不要使用下面的激活步骤,仅供个人学习使用 。。。。。。 。。。。。。 。。。。。。 。。。。。。 一、去官网下载最新Navicat软件https://www.navicat.com.cn/download/navicat-premium 二、去下载激活脚本https://github.com/DoubleLabyrinth/nav…

Navicat 12.1 Macos 激活指南

Navicat 12.1 Navicat从版本11开始使用,一直在macos上表现稳定,速度还快,操作也简单,比Mysql workbench好用多了, workbench总是会发生程序崩溃,修改数据还要点Apply键。对开发来说很不好用. 以下是整个过…

激活navicat提示rsa public key not find的问题

操作顺序先不打开Navicat,注机patch,然后再开Navicat注册 卸载原来的navicat重新安装再次点击patch选择路径就行了 还不行就记得,右键激活工具以管理员权限打开激活再次patch选择navicat的安装好的navicat.exe文件即可

navicat premiun 12激活

注册机: https://download.csdn.net/download/qq_31967985/10545930 步骤: 以管理员身份运行此注册机: 运行注册机 打开注册机后,1) Patch勾选Backup、Host和Navicat v12,然后点击Patch按钮: 默认勾选 …

Navicat安装激活时提示激活失败: 激活次数达到上限,90010003

Navicat安装激活时提示激活失败: 激活次数达到上限,90010003 问题: 概述: 激活失败。原因可能是由于已达到激活次数上限。请检查你是否已在卸载或重新安装 Navicat前取消激活许可证密钥。 90010003解决方法: 通过查看C:\WINDOW…

解决Navicat激活、注册时候出现No All Pattern Found的问题

用Navicat Keygen Patch v5.6.0.exe注册激活Navicat15时,出现No All Pattern Found的错误,具体原因是navicat注册表问题,或navicat之前已经安装过了,所以在注册时候,会出现这个错误。 解决办法: 1)删除注册…