【pytorch】CGAN编程实现

article/2025/11/5 20:29:04

CGAN介绍

由于原始GAN生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标类别不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信息y 融入生成器G和判别器D中,属性y可以是任何标签信息, 例如图像的类别、人脸图像的面部表情等。

cGAN的中心思想是希望 可以控制 GAN 生成的图片,而不 是单纯的随机生成图片。 具体来说,Conditional GAN 在生成器和判别器的输入中 增加了额外的 条件信息,生成器生成的图片只有足够真实 且与条件相符,才能够通过判别器。

实际上 , 在无条件约束的生成模型中 , 没法控制数据生成的模式。然而,通过额外的信息对模型进行约束,有可能指导数据生成的过程。条件约束可以是类标签 , 可以是图像修补的部分数据, 甚至是来自不同模态的数据

在这里插入图片描述
如果将上图绿色部分的y去掉,就是GAN的原理图。

为了实现条件GAN的目的,生成网络和判别网络的原理和 训练方式均要有所改变。

模型部分,在判别器和生成器中都添加了额外信息 y,y 可 以是类别标签或者是其他类型的数据,可以将 y 作为一个 额外的输入层丢入判别器和生成器。

在生成器中,作者将输入噪声 z 和 y 连在一起隐含表示, 带条件约束这个简单直接的改进被证明非常有效,并广泛用 于后续的相关工作中。论文是在MNIST数据集上以类别标 签为条件变量,生成指定类别的图像。作者还探索了CGAN 在用于图像自动标注的多模态学习上的应用,在MIR Flickr25000数据集上,以图像特征为条件变量,生成该图像的tag的词向量。

缺陷

cGAN生成的图像虽有很多缺陷,譬如图像边缘模糊,生成的图像分辨率太低等,但是它为后面的pix2pixGAN和CycleGAN开拓了道路,这两个模型转换图像风格时对属性特征的 处理方法均受cGAN启发。

Code

详细解释都在代码里标注了~

import torch,torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import LabelBinarizer
import random,numpy.random#设置随机种子, numpy, pytorch, python随机种子
def seed_torch(seed=2021):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.backends.cudnn.deterministic = True
seed_torch()#rusume是否使用预训练模型继续训练,问号处输入模型的编号
resume = True   #是继续训练,否重新训练
datasets = 'cifar10'  #选择cifar10,  mnist, fashion_mnist,STL10,Animeif datasets == 'cifar10' or  datasets=='STL10'or datasets=='Anime':nc = 3  #图片的通道数
elif datasets == 'mnist' or datasets== 'fashion_mnist':nc = 1
else:print('数据集选择错误')#类别数
n_classes = 10#控制生成器生成指定标签的图片
target_label=4#训练批次数
batch_size = 32#噪声向量的维度
nz = 100 #判别器的深度
ndf = 64
#生成器的深度
ngf = 64#真实标签
real_label = 1.0
#假标签
fake_label = 0.0
start_epoch = 0#模型#生成器                             #(N,nz+n_classes, 1,1)
netG = nn.Sequential(nn.ConvTranspose2d(nz+n_classes, ngf*8,4, 1,0,   bias=False), nn.BatchNorm2d(ngf*8), nn.LeakyReLU(0.2,inplace=True),nn.ConvTranspose2d(ngf*8,ngf*4,4,2,1,  bias=False), nn.BatchNorm2d(ngf*4), nn.LeakyReLU(0.2,inplace=True),nn.ConvTranspose2d(ngf*4, ngf*4,4,2, 1,bias=False), nn.BatchNorm2d(ngf*4), nn.LeakyReLU(0.2,inplace=True),nn.ConvTranspose2d(ngf*4, ngf*2,4,2, 1,bias=False), nn.BatchNorm2d(ngf*2), nn.LeakyReLU(0.2,inplace=True),nn.ConvTranspose2d(ngf*2, ngf*2,4,2, 1,bias=False), nn.BatchNorm2d(ngf*2), nn.LeakyReLU(0.2,inplace=True),nn.ConvTranspose2d(ngf*2, nc,4,2,1,    bias=False), nn.Tanh()  #(N,nc, 128,128))#判别器             #(N,nc+n_classes, 128,128)
netD = nn.Sequential(nn.Conv2d(nc+n_classes,   ndf*2, 4,2,1, bias=False), nn.BatchNorm2d(ndf*2),nn.LeakyReLU(0.2,inplace=True),nn.Conv2d(ndf*2,ndf*2, 4,2,1, bias=False), nn.BatchNorm2d(ndf*2),nn.LeakyReLU(0.2,inplace=True),nn.Conv2d(ndf*2,  ndf*4,4,2,1,bias=False),nn.BatchNorm2d(ndf*4),nn.LeakyReLU(0.2,inplace=True),nn.Conv2d(ndf*4,ndf*4,4,2,1,  bias=False),  nn.BatchNorm2d(ndf*4),nn.LeakyReLU(0.2,inplace=True),nn.Conv2d(ndf*4,ndf*8,4,2,1,  bias=False),  nn.BatchNorm2d(ndf*8),nn.LeakyReLU(0.2,inplace=True),nn.Conv2d(ndf*8,1,  4,1,0,    bias=False),  #(N,1,1,1)nn.Flatten(),    #(N,1)nn.Sigmoid())# custom weights initialization called on netG and netD
def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:torch.nn.init.normal_(m.weight, 0.0, 0.02)elif classname.find('BatchNorm') != -1:torch.nn.init.normal_(m.weight, 1.0, 0.02)torch.nn.init.zeros_(m.bias)netD.apply(weights_init)
netG.apply(weights_init)#加载数据集
apply_transform1 = transforms.Compose([transforms.Resize((128,128)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])apply_transform2 = transforms.Compose([transforms.Resize((128,128)),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),])if datasets == 'cifar100':train_dataset = torchvision.datasets.CIFAR100(root='./Data', train=False, download=True,transform=apply_transform1)
elif datasets == 'cifar10':train_dataset = torchvision.datasets.CIFAR10(root='./Data', train=False, download=True,transform=apply_transform1)
elif datasets == 'STL10':train_dataset = torchvision.datasets.STL10(root='../data/STL10', split='train', download=True,transform=apply_transform1)
elif datasets == 'mnist':train_dataset = torchvision.datasets.MNIST(root='../data/mnist', train=False, download=True,transform=apply_transform2)
elif datasets == 'fashion_mnist':train_dataset = torchvision.datasets.FashionMNIST(root='../data/fashion_mnist', train=False, download=True,transform=apply_transform2)
elif datasets == 'Anime':train_dataset = torchvision.datasets.ImageFolder(root='../data/Anime',transform=apply_transform1)
else:print('数据集不存在')train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)#定义损失函数
criterion = torch.nn.BCELoss()
device = torch.device('cuda'if torch.cuda.is_available() else 'cpu')# setup optimizer
# optimizerD = torch.optim.Adam(netD.parameters(), lr=0.0002,betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=0.0002,betas=(0.5,0.999))
optimizerD = torch.optim.RMSprop(netD.parameters(),lr=0.0002,alpha=0.99,eps=1e-08,weight_decay=0,momentum=0,)#显示16张图片if datasets=='Anime':image,label = next(iter(train_loader))image = (image*0.5+0.5)[:16]
elif datasets=='mnist' or datasets=='fashion_mnist':image = next(iter(train_loader))[0]image = image[:16]*0.5+0.5elif datasets=='STL10' :image = torch.Tensor(train_dataset.data[:16]/255)
else:image = torch.Tensor(train_dataset.data[:16]/255).permute(0,3,1,2)
plt.imshow(torchvision.utils.make_grid(image,nrow=4).permute(1,2,0))lb = LabelBinarizer()
lb.fit(list(range(0,n_classes)))#将标签进行one-hot编码
def to_categrical(y: torch.FloatTensor):y_one_hot = lb.transform(y.cpu())floatTensor = torch.FloatTensor(y_one_hot)return floatTensor.to(device)#样本和one-hot标签进行连接,以此作为条件生成
def concanate_data_label(data, y):  #data (N,nc, 128,128)y_one_hot = to_categrical(y)  #(N,1)->(N,n_classes)con = torch.cat((data, y_one_hot), 1)return con#如果继续训练,就加载预训练模型
if resume:print('==> Resuming from checkpoint..')checkpoint = torch.load('./checkpoint/GAN_%s_best.pth'%datasets)netG.load_state_dict(checkpoint['net_G'])                  netD.load_state_dict(checkpoint['net_D'])start_epoch = checkpoint['start_epoch']
#print('netG:','\n',netG)
#print('netD:','\n',netD)print('training on:   ',device, '   start_epoch',start_epoch)netD, netG = netD.to(device), netG.to(device)#固定生成器,训练判别器
for epoch in range(start_epoch,500):for batch, (data, target) in enumerate(train_loader):data = data.to(device)target = target.to(device)#拼接真实数据和标签target1 = to_categrical(target).unsqueeze(2).unsqueeze(3).float()  #加到噪声上 torch.Size([N, n_classes, 1, 1])target2 = target1.repeat(1, 1, data.size(2), data.size(3))   #加到数据上(N,n_classes,128,128)data = torch.cat((data, target2),dim=1)  #将标签与数据拼接 (N,channels,128,128),(N,n_classes, 128,128)->(N,channels+nc_classes,128,128)label = torch.full((data.size(0),1), real_label).to(device) # 按照shape,创建一模一样的向量#(1)训练判别器 #training real datanetD.zero_grad()output = netD(data)loss_D1 = criterion(output, label)loss_D1.backward() # 真实数据被判别为1 #training fake data,拼接噪声和标签noise_z = torch.randn(data.size(0), nz, 1, 1).to(device) # (N,噪声向量维度100,1,1)noise_z = torch.cat((noise_z, target1),dim=1) #(N,nz100+n_classes,1,1)#拼接假数据和标签fake_data = netG(noise_z) # 假数据来自噪声fake_data = torch.cat((fake_data,target2),dim=1) #(N,nc+n_classes,128,128)label = torch.full((data.size(0),1), fake_label).to(device) # (N,1)output = netD(fake_data.detach()) # (N,1)loss_D2 = criterion(output, label)loss_D2.backward()#更新判别器optimizerD.step()#(2)训练生成器netG.zero_grad()label = torch.full((data.size(0),1), real_label).to(device) # 像真的数据靠近output = netD(fake_data.to(device))lossG = criterion(output, label)lossG.backward()#更新生成器optimizerG.step()if batch %10==0:print('epoch: %4d, batch: %4d, discriminator loss: %.4f, generator loss: %.4f'%(epoch, batch, loss_D1.item()+loss_D2.item(), lossG.item()))#每2个epoch保存图片if epoch%2==0 and batch==0:#生成指定target_label的图片noise_z1 = torch.randn(data.size(0), nz, 1, 1).to(device)target3 = to_categrical(torch.full((data.size(0),1), target_label)).unsqueeze(2).unsqueeze(3).float()  #加到噪声上,目的是生成类别4的图片noise_z = torch.cat((noise_z1, target3),dim=1) #(N,nz+n_classes,1,1)fake_data = netG(noise_z.to(device))#如果是单通道图片,那么就转成三通道进行保存if nc ==1:fake_data=torch.cat((fake_data,fake_data,fake_data),dim=1)   #fake_data(N,1,H,W)->(N,3,H,W)#保存图片data = fake_data.detach().cpu().permute(0,2,3,1) # 通道数放最后data = np.array(data)#保存单张图片,将数据还原data = (data*0.5+0.5) # 这一步缩放很重要,否则某些像素点可能会小于0plt.imsave('./generated_fake/%s/epoch_%d.png'%(datasets,epoch), data[0])torchvision.utils.save_image(fake_data[:16]*0.5+0.5,'./generated_fake/%s/epoch_%d_grid.png'%(datasets,epoch),nrow=4,normalize=True)#保存模型       state = {'net_G': netG.state_dict(),'net_D': netD.state_dict(),'start_epoch':epoch+1}torch.save(state, './checkpoint/GAN_%s_best.pth'%(datasets))torch.save(state, './checkpoint/GAN_%s_best_copy.pth'%(datasets))

效果如下:
在这里插入图片描述

此外,DCGAN请参阅:
https://blog.csdn.net/stay_zezo/article/details/115735276

参考:
https://blog.csdn.net/studyeboy/article/details/118724526
https://blog.csdn.net/m0_62128864/article/details/123972758


http://chatgpt.dhexx.cn/article/sjOAUFXB.shtml

相关文章

GAN,CGAN,DCGAN

GAN对抗生成网络 训练流程 图片以及训练过程来源 训练这样的两个模型的大方法就是单独交替迭代训练。 我们人为的定义真假样本集的标签,因为我们希望真样本集的输出尽可能为1,假样本集为0,我们就已经默认真样本集所有的类标签都为1&#xf…

GAN论文阅读——CGAN

论文标题:Conditional Generative Adversarial Nets 论文链接:https://arxiv.org/pdf/1411.1784.pdf 参考资料:http://blog.csdn.net/solomon1558/article/details/52555083 一、CGAN的思想 在原始GAN学习笔记中,我们提到过&am…

PyTorch随笔 - 生成对抗网络的改进cGAN和LSGAN

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/129939225 本文介绍GAN的两个常见改进,cGAN和LSGAN,两者一般结合使用。 cGAN: Conditional Generative Adversa…

CGAN实现过程

本文目录 一、原理二、参数初始化1. G的输入2. D的输入3. 模型参数初始化4. 测试噪声 三、执行过程四、测试 本文用MNIST数据集进行训练,并用图解的方法展示了CGAN与GAN中输入的区别,帮助理解CGAN的运行过程 一、原理 如下图所示,我们在输入…

CGAN

CGAN 生成符合特定描述的输出, 如图:给定描述,生成相应内容图片 CGAN中的negetive情况包含两种,(正确的描述,不清晰的图片)和(不正确的描述,清晰的图片)&a…

CGAN 简介与代码实战

1.介绍 原始GAN(GAN 简介与代码实战_天竺街潜水的八角的博客-CSDN博客)在理论上可以完全逼近真实数据,但它的可控性不强(生成小图片还行,生成的大图片可能是不合逻辑的),因此需要对gan加一些约…

CGAN原理分析

1、CGAN原理分析 1.1 网络结构 CGAN是在GAN基础上做的一种改进,通过给原始GAN的生成器Generator(下文简记为G)和判别器Discriminator(下文简记为D)添加额外的条件信息y,实现条件生成模型。CGAN原文中作者…

CGAN理论讲解及代码实现

目录 1.原始GAN的缺点 2.CGAN中心思想 3.原始GAN和CGAN的区别 4.CGAN代码实现 5.运行结果 6.CGAN缺陷 1.原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。 针对原始…

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信…

解决关于Navicat破解安装过程中出现“rsa public key not find”

解决关于Navicat破解安装过程中出现“rsa public key not find” 问题描述解决办法 问题描述 出现“rsa public key not find”的输出框 解决办法 首先先安装Navicat。安装后先不要运行 打开 然后点击 如果出现 则是对的 如果出现这个: 那就请你找到在本地的na…

navicat安装与激活

原文网址:https://www.jianshu.com/p/5f693b4c9468?mTypeGroup 一、Navicat Premium 12下载 Navicat Premium 12是一套数据库开发管理工具,支持连接 MySQL、Oracle等多种数据库,可以快速轻松地创建、管理和维护数据库。 Navicat Premium 12简…

Mac上安装 Navicat

1.下载安装包 Mac版 Navicat Premium 12 v12.0.23.0 官网下载地址: 英文64位 http://download.navicat.com/download/navicat120_premium_en.dmg 中文简体64位 http://download.navicat.com/download/navicat120_premium_cs.dmg 中文简体安装包:链接:h…

Navicat Premium 12.1.21 最新版激活工具及方法

At The Beginning ****** Sincerely regards to the author of the original work ******* 本帖持续更新 Last updated at 21st Aug 2019 Steps navicat_premium原版安装包 官网下载地址:https://www.navicat.com.cn/download/navicat-premium 注册工具下载 git…

Navicat Premium 12.1.16.0安装与激活

一、Navicat Premium 12下载 Navicat Premium 12简体中文下载; 提取码:cgv4 二、Navicat Premium 12安装 双击安装,点击下一步: 同意协议,点击下一步: 选择安装位置(可默认)&…

Navicat Premium安装和激活

前言 Navicat Premium这个软件是非常的好用,这个软件中包含mysql,SQL Server等等的数据库,受到广大编程爱好者的欢迎,废话不多说,下面就直接进入主题,马上就是Navicat Premium安装和激活的环节。 (1&…

Navicat Premium 12破解激活

下载Navicat Premium 12并安装; 蓝奏云下载:Navicat Premium 12注册机 重要提示:该注册机来源于DeltaFoX。一般来说,由于注册机会修改.exe文件或.dll文件,加壳并且没有数字签名,所以杀毒软件会报毒。如需…

Navicat v15

特别注意: 1.断网,否则在安装过程中会失败2.关闭防火墙及杀毒软件   3.选择对应版本:mysql版就选择mysql 出现如下情况: 就卸载,删除注册表,重新安装,出现rsa public key not find的错误 以及 generate first a serial 错误都也是如此删除注册表的办法打开文件&am…

关于Navicat 数据库一直激活不成功的解决方法

首先激活时一直出现 rsa public key not found,说明获取不到激活码,此时就需要检查 - 在Patch的时候是不是没成功 使用破解软件如果出现说已经patch过了的时候赶紧卸载重装!! - 在激活的时候是不是没有断开网络 解决办法&#…

Navicat安装激活

有条件的同学麻烦不要使用下面的激活步骤,仅供个人学习使用 。。。。。。 。。。。。。 。。。。。。 。。。。。。 一、去官网下载最新Navicat软件https://www.navicat.com.cn/download/navicat-premium 二、去下载激活脚本https://github.com/DoubleLabyrinth/nav…

Navicat 12.1 Macos 激活指南

Navicat 12.1 Navicat从版本11开始使用,一直在macos上表现稳定,速度还快,操作也简单,比Mysql workbench好用多了, workbench总是会发生程序崩溃,修改数据还要点Apply键。对开发来说很不好用. 以下是整个过…