pytorch搭建DCGAN

article/2025/8/26 22:25:49

我们知道gan的过程是对生成分布拟合真实分布的一个过程,理想目标是让判别器无法识别输入数据到底是来源于生成器生成的数据还是真实的数据。

当然这是一个博弈的过程并且相互促进的过程,其真实的过程在于首先判别器会先拟合真实数据的分布,然后生成器拟合判别器的分布,从而达到生成器对真实数据分布的拟合。

请添加图片描述
图中蓝色部分为生成器,生成器的功能在于输入一个随机向量经过生成器一系列层的处理输出一个与真实数据尺寸一样的图片。 然后将生成器产生的图片与真实的图片信息一同的输入到判别器中,让判别器去区分该图片信息的源头,如果是判别器产生的图片则识别为fake,如果是生成器产生的图片,则判定为real,因此对于判别器的损失函数就为MSELoss(Pg,torch.zeros_like(Pg)) + MSELoss(Pr, torch.ones_like(Pr))
Pg表示生成器生成的数据,Pr表示真实数据)
而对于生成器来说它的目的在于生成的数据要欺骗判别器,也就是说让判别器都认为它产生的图片就是真实的图片数据(与真实图片无差别),所以生成器的损失函数就是
MSELoss(Pg, torch.ones_like(Pg))

DCGAN相对于普通的GAN只不过是在网络模型中采用了CNN模型
其中主要包含以下几点:
(1)使用指定步长的卷积层代替池化层

(2)生成器和判别器中都使用BatchNormlization

(3)移除全连接层

(4)生成器除去输出层采用Tanh外,全部使用ReLU作为激活函数

(5)判别器所有层都使用LeakyReLU作为激活函数
Generator网络结构

class Gernerator(nn.Module):def __init__(self, IMAGE_CHANNELS, NOISE_CHANNELS, feature_channels):super(Gernerator, self).__init__()self.features = nn.Sequential(self._Conv_block(in_channels=NOISE_CHANNELS, out_channels=feature_channels*4, stride=1, kernel_size=4,padding=0),self._Conv_block(in_channels=feature_channels*4, out_channels=feature_channels*8, stride=2, kernel_size=3,padding=1),self._Conv_block(in_channels=feature_channels*8, out_channels=feature_channels*4, stride=2, kernel_size=3,padding=1),nn.ConvTranspose2d(in_channels=feature_channels*4, out_channels=IMAGE_CHANNELS, stride=2, kernel_size=3,padding=1),nn.Tanh(),)def _Conv_block(self, in_channels, out_channels, kernel_size, stride, padding):feature = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels,out_channels=out_channels,stride=stride,padding=padding,kernel_size=kernel_size,bias=False),# nn.BatchNorm2d(num_features=out_channels),nn.ReLU())return featuredef forward(self, x):return self.features(x)```![Discriminator网络层](https://img-blog.csdnimg.cn/1c36ed6d62b24cfbb4d7b2c27aa822a3.webp?x-oss-process=image/watermark,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBA5byx5bCP55qE55qu5Y2h6b6Z,size_17,color_FFFFFF,t_70,g_se,x_16)```pythonclass Discriminator(nn.Module):def __init__(self, IMAGE_CHANNELS, feature_channels):super(Discriminator, self).__init__()self.features = nn.Sequential(self._Conv_block(in_channels=IMAGE_CHANNELS, out_channels=feature_channels, stride=2, kernel_size=3,padding=1),self._Conv_block(in_channels=feature_channels, out_channels=feature_channels * 2, stride=2, kernel_size=3,padding=1),self._Conv_block(in_channels=feature_channels * 2, out_channels=feature_channels * 4, stride=2, kernel_size=3,padding=1),self._Conv_block(in_channels=feature_channels * 4, out_channels=feature_channels * 2, stride=2, kernel_size=3,padding=1),nn.Conv2d(in_channels=feature_channels * 2, out_channels=IMAGE_CHANNELS, stride=2, kernel_size=3,padding=1),nn.Sigmoid())def _Conv_block(self, in_channels, out_channels, kernel_size, stride, padding):feature = nn.Sequential(nn.Conv2d(in_channels=in_channels,out_channels=out_channels,stride=stride,padding=padding,kernel_size=kernel_size,bias=False),# nn.BatchNorm2d(num_features=out_channels),nn.LeakyReLU(negative_slope=0.2))return featuredef forward(self, x):return torch.sigmoid(self.features(x))

具体的情况需要具体设计相应的Generator和Discriminator

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as Transforms
from torchvision.datasets import ImageFolder
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import osclass Gernerator(nn.Module):def __init__(self, IMAGE_CHANNELS, NOISE_CHANNELS, feature_channels):super(Gernerator, self).__init__()self.features = nn.Sequential(self._Conv_block(in_channels=NOISE_CHANNELS, out_channels=feature_channels*4, stride=1, kernel_size=4,padding=0),self._Conv_block(in_channels=feature_channels*4, out_channels=feature_channels*8, stride=2, kernel_size=3,padding=1),self._Conv_block(in_channels=feature_channels*8, out_channels=feature_channels*4, stride=2, kernel_size=3,padding=1),nn.ConvTranspose2d(in_channels=feature_channels*4, out_channels=IMAGE_CHANNELS, stride=2, kernel_size=3,padding=1),nn.Tanh(),)def _Conv_block(self, in_channels, out_channels, kernel_size, stride, padding):feature = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels,out_channels=out_channels,stride=stride,padding=padding,kernel_size=kernel_size,bias=False),# nn.BatchNorm2d(num_features=out_channels),nn.ReLU())return featuredef forward(self, x):return self.features(x)class Discriminator(nn.Module):def __init__(self, IMAGE_CHANNELS, feature_channels):super(Discriminator, self).__init__()self.features = nn.Sequential(self._Conv_block(in_channels=IMAGE_CHANNELS, out_channels=feature_channels, stride=2, kernel_size=3,padding=1),self._Conv_block(in_channels=feature_channels, out_channels=feature_channels * 2, stride=2, kernel_size=3,padding=1),self._Conv_block(in_channels=feature_channels * 2, out_channels=feature_channels * 4, stride=2, kernel_size=3,padding=1),self._Conv_block(in_channels=feature_channels * 4, out_channels=feature_channels * 2, stride=2, kernel_size=3,padding=1),nn.Conv2d(in_channels=feature_channels * 2, out_channels=IMAGE_CHANNELS, stride=2, kernel_size=3,padding=1),nn.Sigmoid())def _Conv_block(self, in_channels, out_channels, kernel_size, stride, padding):feature = nn.Sequential(nn.Conv2d(in_channels=in_channels,out_channels=out_channels,stride=stride,padding=padding,kernel_size=kernel_size,bias=False),# nn.BatchNorm2d(num_features=out_channels),nn.LeakyReLU(negative_slope=0.2))return featuredef forward(self, x):return torch.sigmoid(self.features(x))def initialize_weights(model):# Initializes weights according to the DCGAN paperfor m in model.modules():if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):nn.init.normal_(m.weight.data, 0.0, 0.02)IMAGE_CHANNELS = 3
NOISE_CHANNELS = 100
FEATURE_CHANNELS = 32
BATCH_SIZE = 16
NUM_EPOCHS = 5
LEARN_RATE = 2e-4
IMAGE_SIZE = 64
D_PATH = 'logs/121_D.pth'
G_PATH = 'logs/41q_G.pth'mytransformers = Transforms.Compose([Transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),Transforms.ToTensor(),Transforms.Normalize(std=[0.6585589, 0.55756074, 0.54101795], mean=[0.28972548, 0.28038123, 0.26353073]),
])
trainset = ImageFolder(root=r'D:\QQPCmgr\Desktop\gan\A', transform=mytransformers)
trainloader = DataLoader(dataset=trainset,batch_size=BATCH_SIZE,shuffle=True
)writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0if __name__ == '__main__':device = torch.device("cpu" if torch.cuda.is_available() else "cpu")print("cuda:0" if torch.cuda.is_available() else "cpu")Dnet = Discriminator(IMAGE_CHANNELS, FEATURE_CHANNELS).to(device)initialize_weights(Dnet)# Dnet.load_state_dict(torch.load(D_PATH))Gnet = Gernerator(IMAGE_CHANNELS, NOISE_CHANNELS, FEATURE_CHANNELS).to(device)initialize_weights(Gnet)Dnet.train()Gnet.train()# Gnet.load_state_dict(torch.load(G_PATH))noise = torch.randn((BATCH_SIZE, NOISE_CHANNELS, 1, 1)).to(device)ceritionG = nn.BCELoss(reduction='mean')ceritionD = nn.BCELoss(reduction='mean')optimizerG = torch.optim.Adam(params=Gnet.parameters(), lr=0.0002, betas=(0.5, 0.999))optimizerD = torch.optim.Adam(params=Dnet.parameters(), lr=0.0002, betas=(0.5, 0.999))for epoch in range(1000):for i, data in enumerate(trainloader, 1):optimizerD.zero_grad()optimizerG.zero_grad()r_img, _ = datar_img = r_img.to(device)fake_img = Gnet.forward(noise)r_label = (torch.ones_like(Gnet.forward(r_img))).to(device)f_label = torch.ones_like(Gnet.forward(r_img)).to(device)lossG = ceritionD(Dnet.forward(Gnet.forward(noise)), r_label)lossD = ceritionD(Dnet.forward(r_img), r_label) / 2 + ceritionD(Dnet.forward(Gnet.forward(noise)), f_label) / 2lossG.backward()lossD.backward()optimizerD.step(retain_graph=True)optimizerG.step(retain_graph=True)print('[epoch:%d],[lossD:%f],[lossG:%f]...........%d/10000' % (epoch, lossD.item(), lossG.item(), i*BATCH_SIZE))if i  % 50 == 0:with torch.no_grad():img_grid_real = torchvision.utils.make_grid(r_img, normalize=True,)img_grid_fake = torchvision.utils.make_grid(fake_img, normalize=True)writer_fake.add_image("fake_img", img_grid_fake, global_step=step)writer_real.add_image("real_img", img_grid_real, global_step=step)step += 1

没时间等,只迭代了一会


http://chatgpt.dhexx.cn/article/BseJOBBg.shtml

相关文章

tensorflow实现DCGAN

1、DCGAN的简单总结 【Paper】 : http://arxiv.org/abs/1511.06434 【github】 : https://github.com/Newmu/dcgan_code theano https://github.com/carpedm20/DCGAN-tensorflow tensorflow https://github.com/jacobgil/keras-dcgan keras https://github.c…

DCGAN TUTORIAL

Introduction 本教程将通过一个示例对DCGAN进行介绍。在向其展示许多真实名人的照片之后,我们将训练一个生成对抗网络(GAN)来产生新名人。此处的大多数代码来自pytorch / examples中的dcgan实现 ,并且本文档将对该实现进行详尽的…

DCGAN原文讲解

DCGAN的全称是Deep Convolution Generative Adversarial Networks(深度卷积生成对抗网络)。是2014年Ian J.Goodfellow 的那篇开创性的GAN论文之后一个新的提出将GAN和卷积网络结合起来,以解决GAN训练不稳定的问题的一篇paper. 关于基本的GAN的原理,可以…

DCGAN

转自:https://blog.csdn.net/liuxiao214/article/details/74502975 首先是各种参考博客、链接等,表示感谢。 1、参考博客1:地址 ——以下,开始正文。 2017/12/12 更新 解决训练不收敛的问题。 更新在最后面部分。 1、DCGAN的…

深度学习之DCGAN

这一此的博客我给大家介绍一下DCGAN的原理以及DCGAN的实战代码,今天我用最简单的语言给大家介绍DCGAN。 相信大家现在对深度学习有了一定的了解,对GAN也有了认识,如果不知道什么是GAN的可以去看我以前的博客,接下来我给大家介绍一下DCGAN的原理。 DCGAN DCGAN的全称是Deep Conv…

对抗神经网络(二)——DCGAN

一、DCGAN介绍 DCGAN即使用卷积网络的对抗网络,其原理和GAN一样,只是把CNN卷积技术用于GAN模式的网络里,G(生成器)网在生成数据时,使用反卷积的重构技术来重构原始图片。D(判别器)网…

对抗生成网络GAN系列——DCGAN简介及人脸图像生成案例

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题 🍊往期回顾:对抗生成网络GAN系列——GAN原理及手写数字生成小案例 🍊近期目标:写好专栏的每一篇文章 🍊支持小苏:点赞…

DCGAN理论讲解及代码实现

目录 DCGAN理论讲解 DCGAN的改进: DCGAN的设计技巧 DCGAN纯代码实现 导入库 导入数据和归一化 定义生成器 定义鉴别器 初始化和 模型训练 运行结果 DCGAN理论讲解 DCGAN也叫深度卷积生成对抗网络,DCGAN就是将CNN与GAN结合在一起,生…

torch学习 (三十七):DCGAN详解

文章目录 引入1 生成器2 鉴别器3 模型训练:生成器与鉴别器的交互4 参数设置5 数据载入6 完整代码7 部分输出图像示意7.1 真实图像7.2 训练200个批次7.2 训练400个批次7.2 训练600个批次 引入 论文详解:Unsupervised representation learning with deep c…

GANs系列:DCGAN原理简介与基础GAN的区别对比

本文长期不定时更新最新知识,防止迷路记得收藏哦! 还未了解基础GAN的,可以先看下面两篇文章: GNA笔记--GAN生成式对抗网络原理以及数学表达式解剖 入门GAN实战---生成MNIST手写数据集代码实现pytorch 背景介绍 2016年&#…

Pix2Pix和CycleGAN

GAN的局限性 即便如此,传统的GAN也不是万能的,它有下面两个不足: 1. 没有**用户控制(user control)**能力 在传统的GAN里,输入一个随机噪声,就会输出一幅随机图像。 但用户是有想法滴&#xff…

PyTorch 实现Image to Image (pix2pix)

目录 一、前言 二、数据集 三、网络结构 四、代码 (一)net (二)dataset (三)train (四)test 五、结果 (一)128*128 (二)256*256 …

pix2pix、pix2pixHD 通过损失日志进行训练可视化

目录 背景 代码 结果 总结 背景 pix2pix(HD)代码在训练时会自动保存一个损失变化的txt文件,通过该文件能够对训练过程进行一个简单的可视化,代码如下。 训练的损失文件如图,对其进行可视化。 代码 #coding:utf-8 ## #author: QQ&#x…

Pix2Pix代码解析

参考链接:https://github.com/yenchenlin/pix2pix-tensorflow https://blog.csdn.net/stdcoutzyx/article/details/78820728 utils.py from __future__ import division import math import json import random import pprint import scipy.misc import numpy as…

pix2pix 与 pix2pixHD的大致分析

目录 pix2pix与pix2pixHD的生成器 判别器 PatchGAN(马尔科夫判别器) 1、pix2pix 简单粗暴的办法 如何解决模糊呢? 其他tricks 2、pix2pixHD 高分辨率图像生成 模型结构 Loss设计 使用Instance-map的图像进行训练 语义编辑 总结 …

Tensorflow2.0之Pix2pix

文章目录 Pix2pix介绍Pix2pix应用Pix2pix生成器及判别器网络结构代码实现1、导入需要的库2、下载数据包3、加载并展示数据包中的图片4、处理图片4.1 将图像调整为更大的高度和宽度4.2 随机裁剪到目标尺寸4.3 随机将图像做水平镜像处理4.4 图像归一化4.5 处理训练集图片4.6 处理…

pix2pix算法笔记

论文:Image-to-Image Translation with Conditional Adversarial Networks 论文链接:https://arxiv.org/abs/1611.07004 代码链接:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 这篇论文发表在CVPR2017,简称pix2pix,是将GAN应用于有监督的图像到图像翻译的经…

Pix2Pix原理解析以及代码流程

文章目录 1、网络搭建2、反向传播过程3、PatchGAN4.与CGAN的不同之处 1、网络搭建 class UnetGenerator(nn.Module):"""Create a Unet-based generator"""def __init__(self, input_nc, output_nc, num_downs, ngf64, norm_layernn.BatchNorm2d…

图像翻译网络模型Pix2Pix

Pix2pix算法(Image-to-Image Translation,图像翻译),它的核心技术有三点:基于条件GAN的损失函数,基于U-Net的生成器和基于PatchGAN的判别器。Pix2Pix能够在诸多图像翻译任务上取得令人惊艳的效果,但因为它的输入是图像对&#xff…

GAN系列之pix2pix、pix2pixHD

1. 摘要 图像处理的很多问题都是将一张输入的图片转变为一张对应的输出图片,比如灰度图、梯度图、彩色图之间的转换等。通常每一种问题都使用特定的算法(如:使用CNN来解决图像转换问题时,要根据每个问题设定一个特定的loss funct…