GAN,CGAN,DCGAN

article/2025/11/5 20:16:42

GAN对抗生成网络

训练流程在这里插入图片描述

图片以及训练过程来源
训练这样的两个模型的大方法就是单独交替迭代训练

  • 我们人为的定义真假样本集的标签,因为我们希望真样本集的输出尽可能为1,假样本集为0,我们就已经默认真样本集所有的类标签都为1,而假样本集的所有类标签都为0。这样单就判别网络来说,此时问题就变成了一个再简单不过的有监督的二分类问题了。

  • 如果我们把刚才的判别网络串接在生成网络的后面,这样我们就知道真假了,也就有了误差了,所以对于生成网络的训练其实是对生成-判别网络串接的训练。我们要把这些假样本的标签都设置为1,也就是认为这些假样本在生成网络训练的时候是真样本,这样才能起到迷惑判别器的目的,也才能使得生成的假样本逐渐逼近为正样本。现在对于生成网络的训练,我们有了样本集(只有假样本集,没有真样本集),有了对应的label(全为1),是不是就可以训练了?有人会问,这样只有一类样本,训练啥呀?谁说一类样本就不能训练了?只要有误差就行。

  • 训练这个串接的网络的时候,一个很重要的操作就是不要判别网络的参数发生变化,也就是不让它参数发生更新,只是把误差一直传,传到生成网络那块后更新生成网络的参数。这样就完成了生成网络的训练了。

  • 在完成生成网络训练后,我们可以根据目前新的生成网络再对先前的那些噪声Z生成新的假样本了,并且训练后的假样本更真了。然后又有了新的真假样本集(其实是新的假样本集),这样又可以重复上述过程了。

代码实现

github—GAN源码
代码分析参考csdn博客

os.makedirs("images", exist_ok=True)

创建子文件夹images,exist_ok取值为Ture时,如果已存在该文件夹,也不会报错。

# 初始化参数,如rpoch次数,batch_size的大小等,sample_interval表示后续间隔保存CGAN图片的保存间隔。
parser = argparse.ArgumentParser() # 声明一个parser
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training") # 添加如下参数,并设定默认值
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches") # 后面的help是添加的描述
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args() # 读取命令行参数
print(opt) # 调用这些参数
img_shape = (opt.channels, opt.img_size, opt.img_size)

这些参数opt.channels, opt.img_size是需要去上一部分设定的参数的位置去找的, 图像的通道数为1,尺寸大小为28*28

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(opt.latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())def forward(self, z):img = self.model(z)img = img.view(img.size(0), *img_shape)return img

搭建生成器神经网络

inplace=True的意思是进行原地操作,例如x=x+5对x就是一个原地操作,虽然y=x+5,x=y完成了同样的功能但不是原地操作。

forward中的z是在程序后面的定义的高斯噪声信号,形状为64*100

img.size(0)为64,也就是一批次训练的数目batch_size的值。

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid())def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity
adversarial_loss = torch.nn.BCELoss()

定义了损失函数nn.BCELoss(),输入(X,Y), X 需要经过sigmoid, Y元素的值只能是0或1float

dataloader = torch.utils.data.DataLoader(datasets.MNIST('../../data/mnist', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),])),batch_size=opt.batch_size, shuffle=True)

PyTorch中数据读取的一个重要接口是torch.utils.data.DataLoader,该接口定义在dataloader.py脚本中,该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入,因此该接口有点承上启下的作用。简单来说就是你训练的数据集不是一股脑的全部丢进来,而是分成了一批一批的,这个接口函数就是将数据集分批并转化成可以处理的Tensor类型。

transforms.Compose([transforms.ToTensor(),transforms.Normalize(std=(0.5,0.5,0.5),mean=(0.5,0.5,0.5))])
其作用就是先将输入归一化到(0,1),再使用公式”(x-mean)/std”,将每个元素分布到(-1,1)

for i, (imgs, _) in enumerate(dataloader):# Adversarial ground truthsvalid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)# Configure input  将真实的图片转化为神经网络可以处理的变量。real_imgs = Variable(imgs.type(Tensor))

这部分定义的相当于是一个标准,vaild可以看成是64行1列的向量,为了在后面计算损失时和1比较;fake也是一样是全为0的向量,用法和1的用法相同。

        # -----------------#  Train Generator# -----------------optimizer_G.zero_grad()# Sample noise as generator inputz = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))# Generate a batch of imagesgen_imgs = generator(z)# Loss measures generator's ability to fool the discriminatorg_loss = adversarial_loss(discriminator(gen_imgs), valid)g_loss.backward()optimizer_G.step()

np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)的意思就是输入从0到1之间,形状为imgs.shape[0], opt.latent_dim的随机高斯数据。

        # ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Measure discriminator's ability to classify real from generated samplesreal_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()

GAN交替训练的过程:

  • 重新计算假样本(假样本每次是需要更新的,产生越来越像的样本)
  • 训练D网络,一个二分类的神经网络
  • 训练G网络,一个串联起来的长网络,也是一个二分类的神经网络(不过只有假样本来训练),同时D部分参数在下一次的时候不能变了。

CGAN

  • CGAN目的就是能指定生成什么样的数据。核心就是通过给原始的GAN生成器G和判别器D添加额外的条件信息,实现条件生成模型。额外信息可以是类别标签或者是其他的辅助信息,最直接的就是使用类别标签信息y
  • 原始的GAN生成器的输入信息是一固定长度的噪声信息,那么CGAN中则是将噪声信息结合标签信息组合起来
  • 作为输入,标签信息一般是采用one-hot编码构成。
  • 原始的GAN判别器输入是图像数据(真实的训练样本和生成器生成的数据),在CGAN中则是将类别标签和图像数据进行组合作为判别器的输入。

代码实现CGAN

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()## n_classes是dataset的类别数self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(opt.latent_dim + opt.n_classes, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())def forward(self, noise, labels):# Concatenate label embedding and image to produce inputgen_input = torch.cat((self.label_emb(labels), noise), -1)img = self.model(gen_input)img = img.view(img.size(0), *img_shape)return img
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)self.model = nn.Sequential(nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 512),nn.Dropout(0.4),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 512),nn.Dropout(0.4),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 1),)def forward(self, img, labels):# Concatenate label embedding and image to produce inputd_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)validity = self.model(d_in)return validity
def sample_image(n_row, batches_done):"""Saves a grid of generated digits ranging from 0 to n_classes"""# Sample noisez = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))# Get labels ranging from 0 to n_classes for n rowslabels = np.array([num for _ in range(n_row) for num in range(n_row)])labels = Variable(LongTensor(labels))gen_imgs = generator(z, labels)save_image(gen_imgs.data, "images02/%d.png" % batches_done, nrow=n_row, normalize=True)

DCGAN

参考博客

  • DCGAN是将CNN 与 GAN结合,只是将G和D换成两个卷积神经网络(CNN),DCGAN将CNN做了一些改变,为提高样本质量和收敛速度。

  • strided convolution 替代确定性的pooling(从而可以让网络自己学习downsampling(下采样)

    • G网络中使用微步幅度卷积(fractionally strided convolution)代替 pooling 层
    • D网络中使用步幅卷积(strided convolution)代替 pooling 层。
  • 在 D 和 G 中均使用 batch normalization批量归一化 ,去掉 FC 层,使网络变为全卷积网络 ,G 网络中使用ReLU 激活函数,最后一层使用tanh激活函数 ,D 网络中所有层都使用 LeakyReLU 作为激活函数

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.init_size = opt.img_size // 4self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2)) #l1函数进行Linear变换。线性变换的两个参数是变换前的维度,和变换之后的维度self.conv_blocks = nn.Sequential(           #nn.sequential{}是一个组成模型的壳子,用来容纳不同的操作nn.BatchNorm2d(128),                    # BatchNorm2d的目的是使我们的一批(batch)feature map 满足均值0方差1,就是改变数据的量纲nn.Upsample(scale_factor=2),            #上采样,将图片放大两倍(这就是为啥class最先开始将图片的长宽除了4,下面还有一次放大2倍)nn.Conv2d(128, 128, 3, stride=1, padding=1), #二维卷积函数,(输入数据channel,输出的channel,步长,卷积核大小,padding的大小)nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),        #relu激活函数nn.Upsample(scale_factor=2),            #上采样nn.Conv2d(128, 64, 3, stride=1, padding=1),#二维卷积nn.BatchNorm2d(64, 0.8),                #BNnn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),                              #Tanh激活函数)def forward(self, z):out = self.l1(z)              #l1函数进行的是Linear变换 (第50行定义了)out = out.view(out.shape[0], 128, self.init_size, self.init_size)#view是维度变换函数,可以看到out数据变成了四维数据,第一个是batch_size(通过整个的代码,可明白),第二个是channel,第三,四是单张图片的长宽img = self.conv_blocks(out)return img
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()def discriminator_block(in_filters, out_filters, bn=True):block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]#Conv卷积,Relu激活,Dropout将部分神经元失活,进而防止过拟合if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))    #如果bn这个参数为True,那么就需要在block块里面添加上BatchNorm的归一化函数return blockself.model = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# The height and width of downsampled imageds_size = opt.img_size // 2 ** 4self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid()) #先进行线性变换,再进行激活函数激活#上一句中 128是指model中最后一个判别模块的最后一个参数决定的,ds_size由model模块对单张图片的卷积效果决定的,而2次方是整个模型是选取的长宽一致的图片def forward(self, img):out = self.model(img)out = out.view(out.shape[0], -1)    #将处理之后的数据维度变成batch * N的维度形式validity = self.adv_layer(out)      #第92行定义

http://chatgpt.dhexx.cn/article/PWC39Bht.shtml

相关文章

GAN论文阅读——CGAN

论文标题:Conditional Generative Adversarial Nets 论文链接:https://arxiv.org/pdf/1411.1784.pdf 参考资料:http://blog.csdn.net/solomon1558/article/details/52555083 一、CGAN的思想 在原始GAN学习笔记中,我们提到过&am…

PyTorch随笔 - 生成对抗网络的改进cGAN和LSGAN

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/129939225 本文介绍GAN的两个常见改进,cGAN和LSGAN,两者一般结合使用。 cGAN: Conditional Generative Adversa…

CGAN实现过程

本文目录 一、原理二、参数初始化1. G的输入2. D的输入3. 模型参数初始化4. 测试噪声 三、执行过程四、测试 本文用MNIST数据集进行训练,并用图解的方法展示了CGAN与GAN中输入的区别,帮助理解CGAN的运行过程 一、原理 如下图所示,我们在输入…

CGAN

CGAN 生成符合特定描述的输出, 如图:给定描述,生成相应内容图片 CGAN中的negetive情况包含两种,(正确的描述,不清晰的图片)和(不正确的描述,清晰的图片)&a…

CGAN 简介与代码实战

1.介绍 原始GAN(GAN 简介与代码实战_天竺街潜水的八角的博客-CSDN博客)在理论上可以完全逼近真实数据,但它的可控性不强(生成小图片还行,生成的大图片可能是不合逻辑的),因此需要对gan加一些约…

CGAN原理分析

1、CGAN原理分析 1.1 网络结构 CGAN是在GAN基础上做的一种改进,通过给原始GAN的生成器Generator(下文简记为G)和判别器Discriminator(下文简记为D)添加额外的条件信息y,实现条件生成模型。CGAN原文中作者…

CGAN理论讲解及代码实现

目录 1.原始GAN的缺点 2.CGAN中心思想 3.原始GAN和CGAN的区别 4.CGAN代码实现 5.运行结果 6.CGAN缺陷 1.原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。 针对原始…

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信…

解决关于Navicat破解安装过程中出现“rsa public key not find”

解决关于Navicat破解安装过程中出现“rsa public key not find” 问题描述解决办法 问题描述 出现“rsa public key not find”的输出框 解决办法 首先先安装Navicat。安装后先不要运行 打开 然后点击 如果出现 则是对的 如果出现这个: 那就请你找到在本地的na…

navicat安装与激活

原文网址:https://www.jianshu.com/p/5f693b4c9468?mTypeGroup 一、Navicat Premium 12下载 Navicat Premium 12是一套数据库开发管理工具,支持连接 MySQL、Oracle等多种数据库,可以快速轻松地创建、管理和维护数据库。 Navicat Premium 12简…

Mac上安装 Navicat

1.下载安装包 Mac版 Navicat Premium 12 v12.0.23.0 官网下载地址: 英文64位 http://download.navicat.com/download/navicat120_premium_en.dmg 中文简体64位 http://download.navicat.com/download/navicat120_premium_cs.dmg 中文简体安装包:链接:h…

Navicat Premium 12.1.21 最新版激活工具及方法

At The Beginning ****** Sincerely regards to the author of the original work ******* 本帖持续更新 Last updated at 21st Aug 2019 Steps navicat_premium原版安装包 官网下载地址:https://www.navicat.com.cn/download/navicat-premium 注册工具下载 git…

Navicat Premium 12.1.16.0安装与激活

一、Navicat Premium 12下载 Navicat Premium 12简体中文下载; 提取码:cgv4 二、Navicat Premium 12安装 双击安装,点击下一步: 同意协议,点击下一步: 选择安装位置(可默认)&…

Navicat Premium安装和激活

前言 Navicat Premium这个软件是非常的好用,这个软件中包含mysql,SQL Server等等的数据库,受到广大编程爱好者的欢迎,废话不多说,下面就直接进入主题,马上就是Navicat Premium安装和激活的环节。 (1&…

Navicat Premium 12破解激活

下载Navicat Premium 12并安装; 蓝奏云下载:Navicat Premium 12注册机 重要提示:该注册机来源于DeltaFoX。一般来说,由于注册机会修改.exe文件或.dll文件,加壳并且没有数字签名,所以杀毒软件会报毒。如需…

Navicat v15

特别注意: 1.断网,否则在安装过程中会失败2.关闭防火墙及杀毒软件   3.选择对应版本:mysql版就选择mysql 出现如下情况: 就卸载,删除注册表,重新安装,出现rsa public key not find的错误 以及 generate first a serial 错误都也是如此删除注册表的办法打开文件&am…

关于Navicat 数据库一直激活不成功的解决方法

首先激活时一直出现 rsa public key not found,说明获取不到激活码,此时就需要检查 - 在Patch的时候是不是没成功 使用破解软件如果出现说已经patch过了的时候赶紧卸载重装!! - 在激活的时候是不是没有断开网络 解决办法&#…

Navicat安装激活

有条件的同学麻烦不要使用下面的激活步骤,仅供个人学习使用 。。。。。。 。。。。。。 。。。。。。 。。。。。。 一、去官网下载最新Navicat软件https://www.navicat.com.cn/download/navicat-premium 二、去下载激活脚本https://github.com/DoubleLabyrinth/nav…

Navicat 12.1 Macos 激活指南

Navicat 12.1 Navicat从版本11开始使用,一直在macos上表现稳定,速度还快,操作也简单,比Mysql workbench好用多了, workbench总是会发生程序崩溃,修改数据还要点Apply键。对开发来说很不好用. 以下是整个过…

激活navicat提示rsa public key not find的问题

操作顺序先不打开Navicat,注机patch,然后再开Navicat注册 卸载原来的navicat重新安装再次点击patch选择路径就行了 还不行就记得,右键激活工具以管理员权限打开激活再次patch选择navicat的安装好的navicat.exe文件即可