Pix2Pix原理解析

article/2025/8/27 1:29:00

1.网络搭建

class UnetGenerator(nn.Module):"""Create a Unet-based generator"""def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):"""Construct a Unet generatorParameters:input_nc (int)  -- the number of channels in input imagesoutput_nc (int) -- the number of channels in output imagesnum_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,image of size 128x128 will become of size 1x1 # at the bottleneckngf (int)       -- the number of filters in the last conv layernorm_layer      -- normalization layerWe construct the U-Net from the innermost layer to the outermost layer.It is a recursive process."""super(UnetGenerator, self).__init__()# construct unet structureunet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layerfor i in range(num_downs - 5):          # add intermediate layers with ngf * 8 filtersunet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)# gradually reduce the number of filters from ngf * 8 to ngfunet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layerdef forward(self, input):"""Standard forward"""return self.model(input)

Unet的模型结构如下图示,因此是从最内层开始搭建:

经过第一行后,网络结构如下,也就是最内层的下采样->上采样。

之后有一个循环,经过第一次循环后,在上一层的外围再次搭建了下采样和上采样:

经过第二次循环:

经过第三次循环:

可以看到每次反卷积的输入特征图的channel是1024,是因为它除了要接受上一层反卷积的输出(512维度),还要接受与其特征图大小相同的下采样层的输出(512维度),因此是1024的维度数。

循环完毕后,再次添加四次外部的降采样和反卷积,最终的网络结构如下:

UnetGenerator((model): UnetSkipConnectionBlock((model): Sequential((0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(1): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(3): UnetSkipConnectionBlock((model): Sequential((0): LeakyReLU(negative_slope=0.2, inplace=True)(1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(2): ReLU(inplace=True)(3): ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(4): ReLU(inplace=True)(5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): Dropout(p=0.5, inplace=False)))(4): ReLU(inplace=True)(5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): Dropout(p=0.5, inplace=False)))(4): ReLU(inplace=True)(5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): Dropout(p=0.5, inplace=False)))(4): ReLU(inplace=True)(5): ConvTranspose2d(1024, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(4): ReLU(inplace=True)(5): ConvTranspose2d(512, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(4): ReLU(inplace=True)(5): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(2): ReLU(inplace=True)(3): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))(4): Tanh()))
)

2.反向传播过程

我们这里假定pix2pix是风格A2B,风格A就是左边的图,风格B是右边的图。

反向传播的代码如下,整个是先更新D再更新G。

(1)首先向前传播,输入A,经过G,得到fakeB;

(2)开始更新D,进入backward_D函数:

  • 将A和fakeB cat起来,cat的整体相当于下图中的negative img,送入D,得到pred_fake;
  • 计算pred_fake的GAN损失,标签为0;
  • 将A与real B cat起来,cat的整体相当于positive img,送入D,得到real_fake;
  • 计算pred_real的GAN损失,标签为1;
  • fake和real的GAN相加,得到总的判别器GAN损失。

(3)开始更新G,进入backward_G函数:

  • 将A和fakeB cat起来,cat的整体相当于下图中的negative img,送入D,得到pred_fake;
  • 计算pred_fake的GAN损失,标签为1;
  • 计算real B和fake B的逐像素损失L1;
  • 将GAN损失和逐像素损失L1相加,得到总损失。

下图就可视化了上述的过程。

    def backward_D(self):"""Calculate GAN loss for the discriminator"""# Fake; stop backprop to the generator by detaching fake_Bfake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminatorpred_fake = self.netD(fake_AB.detach())self.loss_D_fake = self.criterionGAN(pred_fake, False)# Realreal_AB = torch.cat((self.real_A, self.real_B), 1)pred_real = self.netD(real_AB)self.loss_D_real = self.criterionGAN(pred_real, True)# combine loss and calculate gradientsself.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5self.loss_D.backward()def backward_G(self):"""Calculate GAN and L1 loss for the generator"""# First, G(A) should fake the discriminatorfake_AB = torch.cat((self.real_A, self.fake_B), 1)pred_fake = self.netD(fake_AB)self.loss_G_GAN = self.criterionGAN(pred_fake, True)# Second, G(A) = Bself.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1# combine loss and calculate gradientsself.loss_G = self.loss_G_GAN + self.loss_G_L1self.loss_G.backward()def optimize_parameters(self):self.forward()                   # compute fake images: G(A)# update Dself.set_requires_grad(self.netD, True)  # enable backprop for Dself.optimizer_D.zero_grad()     # set D's gradients to zeroself.backward_D()                # calculate gradients for Dself.optimizer_D.step()          # update D's weights# update Gself.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing Gself.optimizer_G.zero_grad()        # set G's gradients to zeroself.backward_G()                   # calculate graidents for Gself.optimizer_G.step()             # udpate G's weights

3.PatchGAN

pix2pix还对判别器的结构做了一定的改动。之前都是对整张图像输出一个是否为真实的概率。pix2pix提出了PatchGan的概念。PatchGAN对图片中的每一个N×N的小块(patch)计算概率,然后再将这些概率求平均值作为整体的输出。

在上面的代码中pred_fake = self.netD(fake_AB.detach())的输出就不是一个概率值,而是30×30的特征图,相当于有30×30个patch。

下图表示标准的D网络结构(n_layers = 3),n_layers 为主要的特征卷积层数为3。如何理解?

  • 下面(0)(1)表示head conv层,不算在n_layers layer中;
  • (2)(3)(4)才算做是标准的一个n_layers层,因此2-4、5-7、8-10一共是3层。
  • 最后有一个卷积层,channel维度为1。

需要注意一下,patchgan channel维度最大为512。

DataParallel((module): NLayerDiscriminator((model): Sequential((0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))(1): LeakyReLU(negative_slope=0.2, inplace=True)(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(4): LeakyReLU(negative_slope=0.2, inplace=True)(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): LeakyReLU(negative_slope=0.2, inplace=True)(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(10): LeakyReLU(negative_slope=0.2, inplace=True)(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))))
)

具体代码如下。与我们前面所述的稍微有些不一样,按照前面所述for n in range(1, n_layers)中相当于构建n_layers个特征提取层。但是代码中实际上构建了n_layers-1个,最后一个标准的特征提取层放在了sequence +=[...]中。

但是理解上还是可以按照前面。在spade框架中,就重新了构建patchgan的过程,其中就把最后一个标准的特征提取层也通过for n in range(1, n_layers)构建了。见https://github.com/NVlabs/SPADE/blob/master/models/networks/discriminator.py

class NLayerDiscriminator(nn.Module):def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):"""Construct a PatchGAN discriminatorParameters:input_nc (int)  -- the number of channels in input imagesndf (int)       -- the number of filters in the last conv layern_layers (int)  -- the number of conv layers in the discriminatornorm_layer      -- normalization layer"""super(NLayerDiscriminator, self).__init__()if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parametersuse_bias = norm_layer.func == nn.InstanceNorm2delse:use_bias = norm_layer == nn.InstanceNorm2dkw = 4  #卷积核的大小padw = 1  #padingsequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]  #head convnf_mult = 1nf_mult_prev = 1for n in range(1, n_layers):  # gradually increase the number of filtersnf_mult_prev = nf_multnf_mult = min(2 ** n, 8)sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),norm_layer(ndf * nf_mult),nn.LeakyReLU(0.2, True)]nf_mult_prev = nf_multnf_mult = min(2 ** n_layers, 8)sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),norm_layer(ndf * nf_mult),nn.LeakyReLU(0.2, True)]sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # channel = 1self.model = nn.Sequential(*sequence)

4.与CGAN的不同之处

下面这张图是CGAN的示意图。可以看到

  • 在CGAN模型中,生成器的输入有两个,分别为一个噪声z,以及对应的条件y(在mnist训练中将图像和标签concat在一起),输出为符合该条件的图像G(z|y)
  • 判别器的输入同样也为两个,一个是条件,另一个满足该条件的真实图像x。

pix2pix模型与CGAN最大的不同在于,不再输入噪声z。因为实验中,即便给G输入一个噪声z,G也只学会将其忽略并生成图像,噪声z对输出结果的影响几乎微乎其微。因此为了简洁性,将z去掉了。

pix2pix模型中G的输入实际上等于CGAN模型的条件y


http://chatgpt.dhexx.cn/article/3WmlPjUW.shtml

相关文章

如何利用Pix2Pix将黑白图片自动变成彩色图片

实现黑白图片自动变成彩色图片 如果你有一幅黑白图片,你该如何上色让他变成彩色的呢?通常做法可能是使用PS工具来进行上色。那么,有没有什么办法进行自动上色呢?自动将黑白图片变成彩色图片?答案是有的,使用深度学习中的Pix2Pix网络就可以实现这一功能。 如图所示,我们…

Pix2Pix进一步了解

参考:Pix2Pix视频解读 一、Pix2Pix是输入图片矩阵而不是标签向量 1、生成器方面 Pix2Pix与CGAN之间的联系:CGAN生成器输入的是一个label,而我们现在要做的是把这个lable换成一个图片,如下所示。这个图片是一个建筑物的模…

CycleGAN与pix2pix训练自己的数据集-Pytorch

github:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 参考:https://blog.csdn.net/Gavinmiaoc/article/details/80585531 文章目录 CycleganDownload&Prerequisitesbefore your work数据集训练测试 pix2pix数据集训练测试 Cyclegan Do…

pix2pix学习系列(1):预训练模型测试pix2pix

pix2pix学习系列(1):预训练模型测试pix2pix 参考文献: [Pytorch系列-66]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型测试pix2pix模型 运行环境 win 10 1、代码下载 Gith…

pix2pix简要笔记

参考(40条消息) 全文翻译&杂记《Image-to-Image Translation with Conditional Adversarial NetWorks》_Maples丶丶的博客-CSDN博客_image-to-image translation 图像到图像通常有特定方法(没有通用),但本质是像素到像素的映射问题。本文…

简单理解Pix2Pix

论文名:Image-to-Image Translation with Conditional Adversarial Networks 论文地址:https://arxiv.org/abs/1611.07004 代码链接:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix Pix2Pix是做什么的 图像风格迁移,一…

Pix2pix网络的基本实现

Pix2pix Gan 主要用于图像之间的转换,又称图像翻译《Image-to-Image Translation with Conditional Adversarial Networks》 普通的GAN接受的G部分的输入是随机向量,输出的是图像。D部分接受的输入是图像(生成的或是真实的)&…

Pix2Pix(2017)+CycleGAN+Pix2PixHD

GAN 常规的深度学习任务如图像分类、目标检测以及语义分割或者实例分割,这些任务的结果都可以归结为预测。图像分类是预测单一的类别,目标检测是预测Bbox和类别,语义分割或者实例分割是预测每个像素的类别。而GAN是生成一个新的东西如一张图…

经典论文pix2pix详解

Image-to-Image Translation with Conditional Adversarial Networks https://phillipi.github.io/pix2pix/ https://arxiv.org/pdf/1611.07004.pdf https://github.com/phillipi/pix2pix https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 摘要:我们研…

pix2pix 学习笔记

论文: Image-to-Image Translation with Conditional Adversarial Networks https://arxiv.org/pdf/1611.07004v1.pdf 代码: 官方project:https://phillipi.github.io/pix2pix/ 官方torch代码:https://github.com/phillipi/pi…

生成对抗:Pix2Pix

cGAN : Pix2Pix 生成对抗网络还有一个有趣的应用就是,图像到图像的翻译。例如:草图到照片,黑白图像到RGB,谷歌地图到卫星视图,等等。Pix2Pix就是实现图像转换的生成对抗模型,但是Pix2Pix中的对抗网络又不同于普通的GAN…

Pix2Pix

1. 概述 很多的图像处理问题可以转换成图像到图像(Image-to-Image)的转换,即将一个输入图像翻译成另外一个对应的图像。通常直接学习这种转换,需要事先定义好损失函数,然而对于不同的转换任务,需要设计的损…

pix2pix的简介

概念: 给定一个输入数据和噪声数据生成目标图像,在pix2pix中判别器的输入是生成图像和源图像,而生成器的输入是源图像和随机噪声(使生成模型具有一定的随机性),pix2pix是通过在生成器的模型层加入Dropout来…

AI修图!pix2pix网络介绍

语言翻译是大家都知道的应用。但图像作为一种交流媒介,也有很多种表达方式,比如灰度图、彩色图、梯度图甚至人的各种标记等。在这些图像之间的转换称之为图像翻译,是一个图像生成任务。 多年来,这些任务都需要用不同的模型去生成…

pix2pix论文详解

pix2pix论文详解 – 潘登同学的对抗神经网络笔记 文章目录 pix2pix论文详解 -- 潘登同学的对抗神经网络笔记 pix2pix简介模型输入与GAN的区别Loss函数的选取conditional GAN的loss 生成器网络结构判别器网络结构训练过程生成器G的训练技巧将dropout用在预测 评估指标 艺术欣赏 …

对于pix2pix的介绍以及实现

最近读了pix2pix的相关文章,也是关于对抗生成的。它与之前接触的GAN有挺大的不同。比如从训练集来说,它是进行成对的训练(接下来会介绍),损失函数的不同比如加入了L1损失,以及生成器的输入,以及…

GAN系列之 pix2pixGAN 网络原理介绍以及论文解读

一、什么是pix2pix GAN 论文:《Image-to-Image Translation with Conditional Adversarial Networks》 pix2pix GAN主要用于图像之间的转换,又称图像翻译。图像处理的很多问题都是将一张输入的图片转变为一张对应的输出图片,端到端的训练。 …

pix2pix算法原理与实现

一、算法名称 Pix2pix算法(Image-to-Image Translation,图像翻译) 来源于论文:Image-to-Image Translation with Conditional Adversarial Networks 二、算法简要介绍、研究背景与意义 2.1介绍 图像处理、图形学和视觉中的许多问题都涉及到将输入图像转换为相应…

Java字符串按照字节数进行截取

本文为joshua317原创文章,转载请注明:转载自joshua317博客 Java字符串按照字节数进行截取 - joshua317的博客 一、问题 编写一个截取字符串的函数,输入为一个字符串和字节数,输出为按字节截取的字符串。但是要保证汉字不被截半个&#xff0…

JAVA中截取字符串中指定字符串

JAVA中截取指定字符串 举个例子,需要截取“abcdef”中的“cde”。 场景1:获取该字符串的下标。输出“cde”。 public static void main(String[] args) {// TODO Auto-generated method stubString data "abcdef";String out data.substri…