目录
1.原始GAN的缺点
2.CGAN中心思想
3.原始GAN和CGAN的区别
4.CGAN代码实现
5.运行结果
6.CGAN缺陷
1.原始GAN的缺点
生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。
针对原始GAN不能生成具有特定属性的图片的问题,Mehdi Mirza等人提出了CGAN,其核心在于将属性信息y融入生成器和判别器中,属性y可以是任何标签的信息,例如图像的类别,人脸图像的面部表情等。
2.CGAN中心思想
CGAN的中心思想是希望可以控制GAN生成的图片,而不是单纯的随机生成图片。具体来说,Conditional GAN在生成器和判别器的输入中添加了额外的条件信息,生成器生成的图片只有足够真实且条件相符,才能够通过判别器。
3.原始GAN和CGAN的区别
从公式上来看,CGAN相当于在原始GAN的基础上对生成器部分和判别器部分都加了一个条件。

从模型上来看,如下图所示

为了实现条件GAN的目的,生成网络和判别网络的原理和训练方式均要有所改变。模型部分,在判别器和生成器中都添加了额外信息y,y可以是类别标签或者是其他类型的数据,可以将y作为一个额外的输入层丢入判别器和生成器。
4.CGAN代码实现
#导入库
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision #加载图片
from torchvision import transforms #图片变换import numpy as np
import matplotlib.pyplot as plt #绘图
import os
import glob
from PIL import Image#独热编码
def one_hot(x,class_count=10):return torch.eye(class_count)[x,:]transform = transforms.Compose([transforms.ToTensor(), #取值范围会被归一化到(0,1)之间transforms.Normalize(mean=0.5,std=0.5) #设置均值和方差均为0.5
])#加载数据集
dataset = torchvision.datasets.MNIST('data',train=True,transform=transform,target_transform = one_hot,download = True)
dl = torch.utils.data.DataLoader(dataset,batch_size=64,shuffle = True)#定义生成器
class Generator(nn.Module):def __init__(self):super(Generator,self).__init__()self.linear1 = nn.Linear(100,128*7*7)self.bn1=nn.BatchNorm1d(128*7*7)self.linear2 = nn.Linear(10,128*7*7)self.bn2=nn.BatchNorm1d(128*7*7)self.deconv1 = nn.ConvTranspose2d(256,128,kernel_size=(3,3),stride=1,padding=1) #生成(128,7,7)的二维图像self.bn3=nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128,64,kernel_size=(4,4),stride=2,padding=1) #生成(64,14,14)的二维图像self.bn4=nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64,1,kernel_size=(4,4),stride=2,padding=1) #生成(1,28,28)的二维图像def forward(self,x1,x2):x1=F.relu(self.linear1(x1))x1=self.bn1(x1)x1=x1.view(-1,128,7,7)x2=F.relu(self.linear2(x2))x2=self.bn2(x2)x2=x2.view(-1,128,7,7)x=torch.cat([x1,x2],axis=1) #batch, 256, 7, 7x=x.view(-1,256,7,7)x=F.relu(self.deconv1(x))x=self.bn3(x)x=F.relu(self.deconv2(x))x=self.bn4(x)x=torch.tanh(self.deconv3(x))return x#定义判别器
#输入:1,28,28图片和长度为10的condition
class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.linear = nn.Linear(10,1*28*28)self.conv1 = nn.Conv2d(2,64,kernel_size=3,stride=2)self.conv2 = nn.Conv2d(64,128,kernel_size=3,stride=2)self.bn = nn.BatchNorm2d(128)self.fc = nn.Linear(128*6*6,1)def forward(self,x1,x2): #x1代表label,x2代表imagex1=F.leaky_relu(self.linear(x1))x1=x1.view(-1,1,28,28)x=torch.cat([x1,x2],axis=1) #shape:batch,2,28,28 x= F.dropout2d(F.leaky_relu(self.conv1(x)))x= F.dropout2d(F.leaky_relu(self.conv2(x)) ) #(batch,128,6,6)x = self.bn(x)x = x.view(-1,128*6*6) #展平x = torch.sigmoid(self.fc(x))return x#模型训练
#设备的配置
device='cuda' if torch.cuda.is_available() else 'cpu'
#初化生成器和判别器把他们放到相应的设备上
gen = Generator().to(device)
dis = Discriminator().to(device)
#交叉熵损失函数
loss_fn = torch.nn.BCELoss()
#训练器的优化器
d_optimizer = torch.optim.Adam(dis.parameters(),lr=1e-5)
#训练生成器的优化器
g_optimizer = torch.optim.Adam(gen.parameters(),lr=1e-4)
#定义可视化函数
def generate_and_save_images(model,epoch,label_input,noise_input):prediction = np.squeeze(model(noise_input,label_input).cpu().numpy())fig = plt.figure(figsize=(4,4))for i in range(prediction.shape[0]):plt.subplot(4,4,i+1)plt.imshow((prediction[i]+1)/2,cmap='gray')plt.axis('off')plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))plt.show()
#设置生成绘图图片的随机张量,这里可视化16张图片
#生成16个长度为100的随机正态分布张量
noise_seed = torch.randn(16,100,device=device)
label_seed = torch.randint(0,10,size=(16,))
label_seed_onehot = one_hot(label_seed).to(device)D_loss = [] #记录训练过程中判别器的损失
G_loss = [] #记录训练过程中生成器的损失
#训练循环
for epoch in range(10):#初始化损失值D_epoch_loss = 0G_epoch_loss = 0count = len(dl.dataset) #返回批次数#对数据集进行迭代for step,(img,label) in enumerate(dl):img =img.to(device) #把数据放到设备上label = label.to(device)size = img.shape[0] #img的第一位是size,获取批次的大小random_seed = torch.randn(size,100,device=device)#判别器训练(真实图片的损失和生成图片的损失),损失的构建和优化d_optimizer.zero_grad()#梯度归零#判别器对于真实图片产生的损失real_output = dis(label,img) #判别器输入真实的图片,real_output对真实图片的预测结果d_real_loss = loss_fn(real_output,torch.ones_like(real_output,device=device))d_real_loss.backward()#计算梯度#在生成器上去计算生成器的损失,优化目标是判别器上的参数generated_img = gen(random_seed,label) #得到生成的图片#因为优化目标是判别器,所以对生成器上的优化目标进行截断fake_output = dis(label,generated_img.detach()) #判别器输入生成的图片,fake_output对生成图片的预测;detach会截断梯度,梯度就不会再传递到gen模型中了#判别器在生成图像上产生的损失d_fake_loss = loss_fn(fake_output,torch.zeros_like(fake_output,device=device))d_fake_loss.backward()#判别器损失disc_loss = d_real_loss + d_fake_loss#判别器优化d_optimizer.step()#生成器上损失的构建和优化g_optimizer.zero_grad() #先将生成器上的梯度置零fake_output = dis(label,generated_img)gen_loss = loss_fn(fake_output,torch.ones_like(fake_output,device=device)) #生成器损失gen_loss.backward()g_optimizer.step()#累计每一个批次的losswith torch.no_grad():D_epoch_loss +=disc_lossG_epoch_loss +=gen_loss#求平均损失with torch.no_grad():D_epoch_loss /=countG_epoch_loss /=countD_loss.append(D_epoch_loss)G_loss.append(G_epoch_loss)#训练完一个Epoch,打印提示并绘制生成的图片print("Epoch:",epoch)print(label_seed)generate_and_save_images(gen,epoch,label_seed_onehot,noise_seed)
5.运行结果
因篇幅有限,只展示一部分运行结果





6.CGAN缺陷
CGAN生成的图像虽然有很多缺陷,譬如图像边缘模糊,生成的图像分辨率太低,但是它为后面的pix2pixGAN和CycleGAN开拓了道路,这两个模型转换图像网络时对属性特征的处理方法均受到CGAN启发。
希望我的文章能对你有所帮助。欢迎👍点赞 ,📝评论,🌟关注,⭐️收藏


















