DCGAN理论讲解及代码实现

article/2025/8/26 22:28:07

目录

DCGAN理论讲解

DCGAN的改进:

 DCGAN的设计技巧

DCGAN纯代码实现 

导入库

导入数据和归一化 

 定义生成器

定义鉴别器 

 初始化和 模型训练

 运行结果


DCGAN理论讲解

DCGAN也叫深度卷积生成对抗网络,DCGAN就是将CNN与GAN结合在一起,生成模型和判别模型都运用了深度卷积神经网络的生成对抗网络。

DCGAN将GAN与CNN相结合,奠定了之后几乎所有GAN的基本网络架构。DCGAN极大地提升了原始GAN训练的稳定性以及生成结果的质量

DCGAN主要是在网络架构上改进了原始的GAN,DCGAN的生成器与判别器都利用CNN架构替换了原始GAN的全连接网络,主要改进之处有如下几个方面,

DCGAN的改进:

(1)DCGAN的生成器和判别器都舍弃了CNN的池化层,判别器保留CNN的整体架构,生成器则是将卷积层替换成了反卷积层。

(2)在判别器和生成器中使用了BatchNormalization(BN)层,这里有助于处理初始化不良导致的训练问题,加速模型训练提升训练的稳定性。要注意,在生成器的输出层和判别器的输入层不使用BN层。

(3)在生成器中除输出层使用Tanh()激活函数,其余层全部使用Relu激活函数,在判别器中,除输出层外所有层都使用LeakyRelu激活函数,防止梯度稀疏

自己画的,凑合着看吧/(*/ω\*)捂脸/ 

 DCGAN的设计技巧

一,取消所有pooling层,G网络中使用转置卷积进行上采样,D网络中加入stride的卷积(为防止梯度稀疏)代替pooling

二,去掉FC层(全连接),使网络变成全卷积网络

三,G网络中使用Relu作为激活函数,最后一层用Tanh

四,D网络中使用LeakyRelu激活函数

五,在generator和discriminator上都使用batchnorm,解决初始化差的问题,帮助梯度传播到每一层,防止generator把所有的样本都收敛到同一点。直接将BN应用到所有层会导致样本震荡和模型不稳定,因此在生成器的输出层和判别器的输入层不使用BN层,可以防止这种现象。

六,使用Adam优化器

七,参数设置参考:LeakyRelu的斜率是0.2;Learing rate = 0.0002;batch size是128.

DCGAN纯代码实现 

导入库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim #优化
import numpy as np
import matplotlib.pyplot as plt #绘图
import torchvision #加载图片
from torchvision import transforms #图片变换

导入数据和归一化 

#对数据做归一化(-1,1)
transform=transforms.Compose([#将shanpe为(H,W,C)的数组或img转为shape为(C,H,W)的tensortransforms.ToTensor(), #转为张量并归一化到【0,1】;数据只是范围变了,并没有改变分布transforms.Normalize(mean=0.5,std=0.5)#数据归一化处理,将数据整理到[-1,1]之间;可让数据呈正态分布
])
#下载数据到指定的文件夹
train_ds = torchvision.datasets.MNIST('data/',train=True,transform=transform,download=True)
#数据的输入部分
train_dl=torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)

 定义生成器

使用长度为100的noise作为输入,也可以使用torch.randn(batchsize,100,1,1)

 

class Generator(nn.Module):def __init__(self):super(Generator,self).__init__()self.linear1 = nn.Linear(100,256*7*7)self.bn1=nn.BatchNorm1d(256*7*7)self.deconv1 = nn.ConvTranspose2d(256,128,kernel_size=(3,3),stride=1,padding=1)  #生成(128,7,7)的二维图像self.bn2=nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128,64,kernel_size=(4,4),stride=2,padding=1)  #生成(64,14,14)的二维图像self.bn3=nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64,1,kernel_size=(4,4),stride=2,padding=1)  #生成(1,28,28)的二维图像def forward(self,x):x=F.relu(self.linear1(x))x=self.bn1(x)x=x.view(-1,256,7,7)x=F.relu(self.deconv1(x))x=self.bn2(x)x=F.relu(self.deconv2(x))x=self.bn3(x)x=torch.tanh(self.deconv3(x))return x

定义鉴别器 

class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.conv1 = nn.Conv2d(1,64,kernel_size=3,stride=2)self.conv2 = nn.Conv2d(64,128,kernel_size=3,stride=2)self.bn = nn.BatchNorm2d(128)self.fc = nn.Linear(128*6*6,1)def forward(self,x):x= F.dropout2d(F.leaky_relu(self.conv1(x)))x= F.dropout2d(F.leaky_relu(self.conv2(x)) )  #(batch,128,6,6)x = self.bn(x)x = x.view(-1,128*6*6) #展平x = torch.sigmoid(self.fc(x))return x

 初始化和 模型训练

#设备的配置
device='cuda' if torch.cuda.is_available() else 'cpu'
#初化生成器和判别器把他们放到相应的设备上
gen = Generator().to(device)
dis = Discriminator().to(device)
#交叉熵损失函数
loss_fn = torch.nn.BCELoss()
#训练器的优化器
d_optimizer = torch.optim.Adam(dis.parameters(),lr=1e-5)
#训练生成器的优化器
g_optimizer = torch.optim.Adam(dis.parameters(),lr=1e-4)
def generate_and_save_images(model,epoch,test_input):prediction = np.squeeze(model(test_input).detach().cpu().numpy())fig = plt.figure(figsize=(4,4))for i in range(prediction.shape[0]):plt.subplot(4,4,i+1)plt.imshow((prediction[i]+1)/2,cmap='gray')plt.axis('off')plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))plt.show()
test_input = torch.randn(16,100 ,device=device) #16个长度为100的随机数
D_loss = []
G_loss = []
#训练循环
for epoch in range(30):#初始化损失值D_epoch_loss = 0G_epoch_loss = 0count = len(train_dl.dataset) #返回批次数#对数据集进行迭代for step,(img,_) in enumerate(train_dl):img =img.to(device) #把数据放到设备上size = img.shape[0] #img的第一位是size,获取批次的大小random_seed = torch.randn(size,100,device=device)#判别器训练(真实图片的损失和生成图片的损失),损失的构建和优化d_optimizer.zero_grad()#梯度归零#判别器对于真实图片产生的损失real_output = dis(img) #判别器输入真实的图片,real_output对真实图片的预测结果d_real_loss = loss_fn(real_output,torch.ones_like(real_output,device=device))d_real_loss.backward()#计算梯度#在生成器上去计算生成器的损失,优化目标是判别器上的参数generated_img = gen(random_seed) #得到生成的图片#因为优化目标是判别器,所以对生成器上的优化目标进行截断fake_output = dis(generated_img.detach()) #判别器输入生成的图片,fake_output对生成图片的预测;detach会截断梯度,梯度就不会再传递到gen模型中了#判别器在生成图像上产生的损失d_fake_loss = loss_fn(fake_output,torch.zeros_like(fake_output,device=device))d_fake_loss.backward()#判别器损失disc_loss = d_real_loss + d_fake_loss#判别器优化d_optimizer.step()#生成器上损失的构建和优化g_optimizer.zero_grad() #先将生成器上的梯度置零fake_output = dis(generated_img)gen_loss = loss_fn(fake_output,torch.ones_like(fake_output,device=device))  #生成器损失gen_loss.backward()g_optimizer.step()#累计每一个批次的losswith torch.no_grad():D_epoch_loss +=disc_lossG_epoch_loss +=gen_loss#求平均损失with torch.no_grad():D_epoch_loss /=countG_epoch_loss /=countD_loss.append(D_epoch_loss)G_loss.append(G_epoch_loss)generate_and_save_images(gen,epoch,test_input)print('Epoch:',epoch)

 运行结果

 因篇幅有限,这里展示第一张和最后一张,这里我训练了30个epoch,有条件的可以多训练几次,训练越多效果越明显哦

 


希望我的文章能对你有所帮助。欢迎👍点赞 ,📝评论,🌟关注,⭐️收藏

                                                                    


http://chatgpt.dhexx.cn/article/gydFh8r0.shtml

相关文章

torch学习 (三十七):DCGAN详解

文章目录 引入1 生成器2 鉴别器3 模型训练:生成器与鉴别器的交互4 参数设置5 数据载入6 完整代码7 部分输出图像示意7.1 真实图像7.2 训练200个批次7.2 训练400个批次7.2 训练600个批次 引入 论文详解:Unsupervised representation learning with deep c…

GANs系列:DCGAN原理简介与基础GAN的区别对比

本文长期不定时更新最新知识,防止迷路记得收藏哦! 还未了解基础GAN的,可以先看下面两篇文章: GNA笔记--GAN生成式对抗网络原理以及数学表达式解剖 入门GAN实战---生成MNIST手写数据集代码实现pytorch 背景介绍 2016年&#…

Pix2Pix和CycleGAN

GAN的局限性 即便如此,传统的GAN也不是万能的,它有下面两个不足: 1. 没有**用户控制(user control)**能力 在传统的GAN里,输入一个随机噪声,就会输出一幅随机图像。 但用户是有想法滴&#xff…

PyTorch 实现Image to Image (pix2pix)

目录 一、前言 二、数据集 三、网络结构 四、代码 (一)net (二)dataset (三)train (四)test 五、结果 (一)128*128 (二)256*256 …

pix2pix、pix2pixHD 通过损失日志进行训练可视化

目录 背景 代码 结果 总结 背景 pix2pix(HD)代码在训练时会自动保存一个损失变化的txt文件,通过该文件能够对训练过程进行一个简单的可视化,代码如下。 训练的损失文件如图,对其进行可视化。 代码 #coding:utf-8 ## #author: QQ&#x…

Pix2Pix代码解析

参考链接:https://github.com/yenchenlin/pix2pix-tensorflow https://blog.csdn.net/stdcoutzyx/article/details/78820728 utils.py from __future__ import division import math import json import random import pprint import scipy.misc import numpy as…

pix2pix 与 pix2pixHD的大致分析

目录 pix2pix与pix2pixHD的生成器 判别器 PatchGAN(马尔科夫判别器) 1、pix2pix 简单粗暴的办法 如何解决模糊呢? 其他tricks 2、pix2pixHD 高分辨率图像生成 模型结构 Loss设计 使用Instance-map的图像进行训练 语义编辑 总结 …

Tensorflow2.0之Pix2pix

文章目录 Pix2pix介绍Pix2pix应用Pix2pix生成器及判别器网络结构代码实现1、导入需要的库2、下载数据包3、加载并展示数据包中的图片4、处理图片4.1 将图像调整为更大的高度和宽度4.2 随机裁剪到目标尺寸4.3 随机将图像做水平镜像处理4.4 图像归一化4.5 处理训练集图片4.6 处理…

pix2pix算法笔记

论文:Image-to-Image Translation with Conditional Adversarial Networks 论文链接:https://arxiv.org/abs/1611.07004 代码链接:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 这篇论文发表在CVPR2017,简称pix2pix,是将GAN应用于有监督的图像到图像翻译的经…

Pix2Pix原理解析以及代码流程

文章目录 1、网络搭建2、反向传播过程3、PatchGAN4.与CGAN的不同之处 1、网络搭建 class UnetGenerator(nn.Module):"""Create a Unet-based generator"""def __init__(self, input_nc, output_nc, num_downs, ngf64, norm_layernn.BatchNorm2d…

图像翻译网络模型Pix2Pix

Pix2pix算法(Image-to-Image Translation,图像翻译),它的核心技术有三点:基于条件GAN的损失函数,基于U-Net的生成器和基于PatchGAN的判别器。Pix2Pix能够在诸多图像翻译任务上取得令人惊艳的效果,但因为它的输入是图像对&#xff…

GAN系列之pix2pix、pix2pixHD

1. 摘要 图像处理的很多问题都是将一张输入的图片转变为一张对应的输出图片,比如灰度图、梯度图、彩色图之间的转换等。通常每一种问题都使用特定的算法(如:使用CNN来解决图像转换问题时,要根据每个问题设定一个特定的loss funct…

Pix2Pix原理解析

1.网络搭建 class UnetGenerator(nn.Module):"""Create a Unet-based generator"""def __init__(self, input_nc, output_nc, num_downs, ngf64, norm_layernn.BatchNorm2d, use_dropoutFalse):"""Construct a Unet generatorPa…

如何利用Pix2Pix将黑白图片自动变成彩色图片

实现黑白图片自动变成彩色图片 如果你有一幅黑白图片,你该如何上色让他变成彩色的呢?通常做法可能是使用PS工具来进行上色。那么,有没有什么办法进行自动上色呢?自动将黑白图片变成彩色图片?答案是有的,使用深度学习中的Pix2Pix网络就可以实现这一功能。 如图所示,我们…

Pix2Pix进一步了解

参考:Pix2Pix视频解读 一、Pix2Pix是输入图片矩阵而不是标签向量 1、生成器方面 Pix2Pix与CGAN之间的联系:CGAN生成器输入的是一个label,而我们现在要做的是把这个lable换成一个图片,如下所示。这个图片是一个建筑物的模…

CycleGAN与pix2pix训练自己的数据集-Pytorch

github:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 参考:https://blog.csdn.net/Gavinmiaoc/article/details/80585531 文章目录 CycleganDownload&Prerequisitesbefore your work数据集训练测试 pix2pix数据集训练测试 Cyclegan Do…

pix2pix学习系列(1):预训练模型测试pix2pix

pix2pix学习系列(1):预训练模型测试pix2pix 参考文献: [Pytorch系列-66]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型测试pix2pix模型 运行环境 win 10 1、代码下载 Gith…

pix2pix简要笔记

参考(40条消息) 全文翻译&杂记《Image-to-Image Translation with Conditional Adversarial NetWorks》_Maples丶丶的博客-CSDN博客_image-to-image translation 图像到图像通常有特定方法(没有通用),但本质是像素到像素的映射问题。本文…

简单理解Pix2Pix

论文名:Image-to-Image Translation with Conditional Adversarial Networks 论文地址:https://arxiv.org/abs/1611.07004 代码链接:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix Pix2Pix是做什么的 图像风格迁移,一…

Pix2pix网络的基本实现

Pix2pix Gan 主要用于图像之间的转换,又称图像翻译《Image-to-Image Translation with Conditional Adversarial Networks》 普通的GAN接受的G部分的输入是随机向量,输出的是图像。D部分接受的输入是图像(生成的或是真实的)&…