CGAN理论讲解及代码实现

article/2025/11/5 21:10:31

目录

1.原始GAN的缺点

2.CGAN中心思想

3.原始GAN和CGAN的区别

4.CGAN代码实现 

5.运行结果

6.CGAN缺陷


1.原始GAN的缺点

生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。

针对原始GAN不能生成具有特定属性的图片的问题,Mehdi Mirza等人提出了CGAN,其核心在于将属性信息y融入生成器和判别器中,属性y可以是任何标签的信息,例如图像的类别,人脸图像的面部表情等。

2.CGAN中心思想

CGAN的中心思想是希望可以控制GAN生成的图片,而不是单纯的随机生成图片。具体来说,Conditional  GAN在生成器和判别器的输入中添加了额外的条件信息,生成器生成的图片只有足够真实且条件相符,才能够通过判别器。

3.原始GAN和CGAN的区别

从公式上来看,CGAN相当于在原始GAN的基础上对生成器部分和判别器部分都加了一个条件。

从模型上来看,如下图所示

为了实现条件GAN的目的,生成网络和判别网络的原理和训练方式均要有所改变。模型部分,在判别器和生成器中都添加了额外信息y,y可以是类别标签或者是其他类型的数据,可以将y作为一个额外的输入层丢入判别器和生成器。

4.CGAN代码实现 

#导入库
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision #加载图片
from torchvision import transforms #图片变换import numpy as np
import matplotlib.pyplot as plt #绘图
import os
import glob
from PIL import Image#独热编码
def one_hot(x,class_count=10):return torch.eye(class_count)[x,:]transform = transforms.Compose([transforms.ToTensor(), #取值范围会被归一化到(0,1)之间transforms.Normalize(mean=0.5,std=0.5) #设置均值和方差均为0.5
])#加载数据集
dataset = torchvision.datasets.MNIST('data',train=True,transform=transform,target_transform = one_hot,download = True)
dl = torch.utils.data.DataLoader(dataset,batch_size=64,shuffle = True)#定义生成器
class Generator(nn.Module):def __init__(self):super(Generator,self).__init__()self.linear1 = nn.Linear(100,128*7*7)self.bn1=nn.BatchNorm1d(128*7*7)self.linear2 = nn.Linear(10,128*7*7)self.bn2=nn.BatchNorm1d(128*7*7)self.deconv1 = nn.ConvTranspose2d(256,128,kernel_size=(3,3),stride=1,padding=1)  #生成(128,7,7)的二维图像self.bn3=nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128,64,kernel_size=(4,4),stride=2,padding=1)  #生成(64,14,14)的二维图像self.bn4=nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64,1,kernel_size=(4,4),stride=2,padding=1)  #生成(1,28,28)的二维图像def forward(self,x1,x2):x1=F.relu(self.linear1(x1))x1=self.bn1(x1)x1=x1.view(-1,128,7,7)x2=F.relu(self.linear2(x2))x2=self.bn2(x2)x2=x2.view(-1,128,7,7)x=torch.cat([x1,x2],axis=1)  #batch, 256, 7, 7x=x.view(-1,256,7,7)x=F.relu(self.deconv1(x))x=self.bn3(x)x=F.relu(self.deconv2(x))x=self.bn4(x)x=torch.tanh(self.deconv3(x))return x#定义判别器
#输入:1,28,28图片和长度为10的condition
class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.linear = nn.Linear(10,1*28*28)self.conv1 = nn.Conv2d(2,64,kernel_size=3,stride=2)self.conv2 = nn.Conv2d(64,128,kernel_size=3,stride=2)self.bn = nn.BatchNorm2d(128)self.fc = nn.Linear(128*6*6,1)def forward(self,x1,x2): #x1代表label,x2代表imagex1=F.leaky_relu(self.linear(x1))x1=x1.view(-1,1,28,28)x=torch.cat([x1,x2],axis=1)  #shape:batch,2,28,28                x= F.dropout2d(F.leaky_relu(self.conv1(x)))x= F.dropout2d(F.leaky_relu(self.conv2(x)) )  #(batch,128,6,6)x = self.bn(x)x = x.view(-1,128*6*6) #展平x = torch.sigmoid(self.fc(x))return x#模型训练
#设备的配置
device='cuda' if torch.cuda.is_available() else 'cpu'
#初化生成器和判别器把他们放到相应的设备上
gen = Generator().to(device)
dis = Discriminator().to(device)
#交叉熵损失函数
loss_fn = torch.nn.BCELoss()
#训练器的优化器
d_optimizer = torch.optim.Adam(dis.parameters(),lr=1e-5)
#训练生成器的优化器
g_optimizer = torch.optim.Adam(gen.parameters(),lr=1e-4)
#定义可视化函数
def generate_and_save_images(model,epoch,label_input,noise_input):prediction = np.squeeze(model(noise_input,label_input).cpu().numpy())fig = plt.figure(figsize=(4,4))for i in range(prediction.shape[0]):plt.subplot(4,4,i+1)plt.imshow((prediction[i]+1)/2,cmap='gray')plt.axis('off')plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))plt.show()
#设置生成绘图图片的随机张量,这里可视化16张图片
#生成16个长度为100的随机正态分布张量
noise_seed = torch.randn(16,100,device=device)
label_seed = torch.randint(0,10,size=(16,))
label_seed_onehot = one_hot(label_seed).to(device)D_loss = [] #记录训练过程中判别器的损失
G_loss = [] #记录训练过程中生成器的损失
#训练循环
for epoch in range(10):#初始化损失值D_epoch_loss = 0G_epoch_loss = 0count = len(dl.dataset) #返回批次数#对数据集进行迭代for step,(img,label) in enumerate(dl):img =img.to(device) #把数据放到设备上label = label.to(device)size = img.shape[0] #img的第一位是size,获取批次的大小random_seed = torch.randn(size,100,device=device)#判别器训练(真实图片的损失和生成图片的损失),损失的构建和优化d_optimizer.zero_grad()#梯度归零#判别器对于真实图片产生的损失real_output = dis(label,img) #判别器输入真实的图片,real_output对真实图片的预测结果d_real_loss = loss_fn(real_output,torch.ones_like(real_output,device=device))d_real_loss.backward()#计算梯度#在生成器上去计算生成器的损失,优化目标是判别器上的参数generated_img = gen(random_seed,label) #得到生成的图片#因为优化目标是判别器,所以对生成器上的优化目标进行截断fake_output = dis(label,generated_img.detach()) #判别器输入生成的图片,fake_output对生成图片的预测;detach会截断梯度,梯度就不会再传递到gen模型中了#判别器在生成图像上产生的损失d_fake_loss = loss_fn(fake_output,torch.zeros_like(fake_output,device=device))d_fake_loss.backward()#判别器损失disc_loss = d_real_loss + d_fake_loss#判别器优化d_optimizer.step()#生成器上损失的构建和优化g_optimizer.zero_grad() #先将生成器上的梯度置零fake_output = dis(label,generated_img)gen_loss = loss_fn(fake_output,torch.ones_like(fake_output,device=device))  #生成器损失gen_loss.backward()g_optimizer.step()#累计每一个批次的losswith torch.no_grad():D_epoch_loss +=disc_lossG_epoch_loss +=gen_loss#求平均损失with torch.no_grad():D_epoch_loss /=countG_epoch_loss /=countD_loss.append(D_epoch_loss)G_loss.append(G_epoch_loss)#训练完一个Epoch,打印提示并绘制生成的图片print("Epoch:",epoch)print(label_seed)generate_and_save_images(gen,epoch,label_seed_onehot,noise_seed)

5.运行结果

因篇幅有限,只展示一部分运行结果

 

 

 

 

 

 

 

6.CGAN缺陷

CGAN生成的图像虽然有很多缺陷,譬如图像边缘模糊,生成的图像分辨率太低,但是它为后面的pix2pixGAN和CycleGAN开拓了道路,这两个模型转换图像网络时对属性特征的处理方法均受到CGAN启发。


希望我的文章能对你有所帮助。欢迎👍点赞 ,📝评论,🌟关注,⭐️收藏

                                                                    

 


http://chatgpt.dhexx.cn/article/yK9JjHfC.shtml

相关文章

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信…

解决关于Navicat破解安装过程中出现“rsa public key not find”

解决关于Navicat破解安装过程中出现“rsa public key not find” 问题描述解决办法 问题描述 出现“rsa public key not find”的输出框 解决办法 首先先安装Navicat。安装后先不要运行 打开 然后点击 如果出现 则是对的 如果出现这个: 那就请你找到在本地的na…

navicat安装与激活

原文网址:https://www.jianshu.com/p/5f693b4c9468?mTypeGroup 一、Navicat Premium 12下载 Navicat Premium 12是一套数据库开发管理工具,支持连接 MySQL、Oracle等多种数据库,可以快速轻松地创建、管理和维护数据库。 Navicat Premium 12简…

Mac上安装 Navicat

1.下载安装包 Mac版 Navicat Premium 12 v12.0.23.0 官网下载地址: 英文64位 http://download.navicat.com/download/navicat120_premium_en.dmg 中文简体64位 http://download.navicat.com/download/navicat120_premium_cs.dmg 中文简体安装包:链接:h…

Navicat Premium 12.1.21 最新版激活工具及方法

At The Beginning ****** Sincerely regards to the author of the original work ******* 本帖持续更新 Last updated at 21st Aug 2019 Steps navicat_premium原版安装包 官网下载地址:https://www.navicat.com.cn/download/navicat-premium 注册工具下载 git…

Navicat Premium 12.1.16.0安装与激活

一、Navicat Premium 12下载 Navicat Premium 12简体中文下载; 提取码:cgv4 二、Navicat Premium 12安装 双击安装,点击下一步: 同意协议,点击下一步: 选择安装位置(可默认)&…

Navicat Premium安装和激活

前言 Navicat Premium这个软件是非常的好用,这个软件中包含mysql,SQL Server等等的数据库,受到广大编程爱好者的欢迎,废话不多说,下面就直接进入主题,马上就是Navicat Premium安装和激活的环节。 (1&…

Navicat Premium 12破解激活

下载Navicat Premium 12并安装; 蓝奏云下载:Navicat Premium 12注册机 重要提示:该注册机来源于DeltaFoX。一般来说,由于注册机会修改.exe文件或.dll文件,加壳并且没有数字签名,所以杀毒软件会报毒。如需…

Navicat v15

特别注意: 1.断网,否则在安装过程中会失败2.关闭防火墙及杀毒软件   3.选择对应版本:mysql版就选择mysql 出现如下情况: 就卸载,删除注册表,重新安装,出现rsa public key not find的错误 以及 generate first a serial 错误都也是如此删除注册表的办法打开文件&am…

关于Navicat 数据库一直激活不成功的解决方法

首先激活时一直出现 rsa public key not found,说明获取不到激活码,此时就需要检查 - 在Patch的时候是不是没成功 使用破解软件如果出现说已经patch过了的时候赶紧卸载重装!! - 在激活的时候是不是没有断开网络 解决办法&#…

Navicat安装激活

有条件的同学麻烦不要使用下面的激活步骤,仅供个人学习使用 。。。。。。 。。。。。。 。。。。。。 。。。。。。 一、去官网下载最新Navicat软件https://www.navicat.com.cn/download/navicat-premium 二、去下载激活脚本https://github.com/DoubleLabyrinth/nav…

Navicat 12.1 Macos 激活指南

Navicat 12.1 Navicat从版本11开始使用,一直在macos上表现稳定,速度还快,操作也简单,比Mysql workbench好用多了, workbench总是会发生程序崩溃,修改数据还要点Apply键。对开发来说很不好用. 以下是整个过…

激活navicat提示rsa public key not find的问题

操作顺序先不打开Navicat,注机patch,然后再开Navicat注册 卸载原来的navicat重新安装再次点击patch选择路径就行了 还不行就记得,右键激活工具以管理员权限打开激活再次patch选择navicat的安装好的navicat.exe文件即可

navicat premiun 12激活

注册机: https://download.csdn.net/download/qq_31967985/10545930 步骤: 以管理员身份运行此注册机: 运行注册机 打开注册机后,1) Patch勾选Backup、Host和Navicat v12,然后点击Patch按钮: 默认勾选 …

Navicat安装激活时提示激活失败: 激活次数达到上限,90010003

Navicat安装激活时提示激活失败: 激活次数达到上限,90010003 问题: 概述: 激活失败。原因可能是由于已达到激活次数上限。请检查你是否已在卸载或重新安装 Navicat前取消激活许可证密钥。 90010003解决方法: 通过查看C:\WINDOW…

解决Navicat激活、注册时候出现No All Pattern Found的问题

用Navicat Keygen Patch v5.6.0.exe注册激活Navicat15时,出现No All Pattern Found的错误,具体原因是navicat注册表问题,或navicat之前已经安装过了,所以在注册时候,会出现这个错误。 解决办法: 1)删除注册…

navicat激活失败

WINR输入命令regedit打开注册表 以此展开定位计算机 \HKEY_CURRENT_USER\SOFTWARE\PremiumSoft,

Navicat Premium12 安装与激活

Navicat Premium 这个是第三方的客户端工具,比较轻便,可以远程登录数据库 安装以及破解教程 一、安装包下载安装: 链接: https://pan.baidu.com/s/1W47ECdPx8a2k5_2h2KYhuw 提取码: sfai 下一步即可; 二、破解 破解补丁下载…

Navicat premium 15激活教程及安装教程+报错解决办法

Navicat premium 15激活教程及安装教程报错解决办法 1、安装包和注册工具下载2、安装Navicate Premium 15,直接下一步安装即可,安装位置可以按照到D盘3、激活Navicate Premium 15打开安装包里面的Navicat Keygen Patch v5.6.0 DFoX.exe工具当点击path选择…

Navicat15 安装激活

** 安装激活注意事项: 1、必须断网! 2、注意先后顺序 ](https://blog.csdn.net/qq_42859450/article/details/126521267) ** 第一步:安装(一直下一步,中间会让你选择安装路径,也可不选择,默…