1. CGAN从无监督GAN改进成有监督的GAN
GAN的基本原理输入是随机噪声,无法控制输出和输入之间的对应关系,也无法控制输出的模式,CGAN全称是条件GAN(Conditional GAN)改进基本的GAN解决了这个问题,CGAN和基本的GAN不同的地方是:
参考下面的链接
https://www.jianshu.com/p/39c57e9a6630
这里面介绍了实现CGAN有三种形式,从网络实现上的三种形式,没有讲解怎样优化目标函数
CGAN的一个问题是输入的有监督标签是离散型输入,如果输入中还有连续型输入,也就是C这个条件是个连续型的,那么将要继续参考InfoGAN
2. InfoGAN
参考下面的链接,非常详细的讲解了InfoGAN的原理、网络结构的实现、损失函数怎样求解
https://www.jianshu.com/p/fa892c81df60
InfoGAN的Info部分和判别器D共用了前面的网络,那么PyTorch怎么实现共用网咯呢?
2.1 参考下面的PyTorch实现
https://mp.weixin.qq.com/s?__biz=MzI3MzkyMzE5Mw==&mid=2247485031&idx=1&sn=e6ccbc33639462d59ee56923c59173b6&chksm=eb1aab71dc6d2267cc52bf769106067c53c867ad6a02063791674937857fb86da36ecd6cbfd9&token=1864035800&lang=zh_CN#rd
原来PyTorch定义判别器类的时候可以分成三个网络,分别是主网络、D网络和C网络、L网络,D网络和C网络和L网络公用主网络,这个例子中的InfoGAN得输入有随机噪声、离散输入(C部分)、连续输入(L部分),forward中先用主网络处理x,之后返回D网络、C网络和L网络
不得不说,这种写法很有趣啊
3. PyTorch实现Debug记录
我自己实现了InfoGAN网络,运行程序后接二连三出现了很多错误
Bug(1):
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace
参考链接:
https://blog.csdn.net/qq_32953463/article/details/115728762
出现这个错误的原因是Pytorch的版本问题,我的Pytorch是1.11.0版本,如果Pytorch版本低于1.4不会有这个问题,链接中提供了一种不需要重新安装Pytorch的办法,backward()放在一起,step()放在一起,zero_grad()不需要放在一起,如下截图所示,
不得不说,这么神奇,真的解决了
real_out = netD(real_img).mean()fake_out = netD(fake_img).mean()d_loss = 1 - real_out + fake_outnetD.zero_grad()g_loss = generator_criterion(fake_out, fake_img, real_img)netG.zero_grad()d_loss.backward(retain_graph=True)g_loss.backward()optimizerD.step()optimizerG.step()fake_img = netG(z)fake_out = netD(fake_img).mean()
Bug(2):RuntimeError: Trying to backward through the graph a second time but the buffers have already been f
或者说是pytorch中的retain_graph=True的作用
参考链接:https://blog.csdn.net/qq_39861441/article/details/104129368
总的来说进行一次backward之后,各个节点的值会清除,这样进行第二次backward会报错,如果加上retain_graph==True后,可以再来一次backward。
上面的示例代码中前两个网络D和G在backward的时候使用了retain_graph=True的参数,最后一个网络没有使用此参数,此参数的默认值是False
如果想了解底层的原理,建议阅读下面的链接,里面的图解非常的有趣
https://blog.csdn.net/SY_qqq/article/details/107384161
Bug(3):
RuntimeError: Found dtype Long but expected Float
这个错误来源于torch需要float类型,但是数据中是int类型或者long类型,解决方法是debug一个一个看变量中哪里出现了int或者long类型,假设variable是int或者long类型的变量,将它转换成float类型
variable = variable.to(torch.float32)
4. CGAN/ACGAN/InfoGAN都是什么?
参考:https://zhuanlan.zhihu.com/p/91592775
首先回顾一下原始GAN的损失函数,没有类别信息
CGAN提出使用类别标签作为辅助信息,进而指导数据生成过程。**从实现的层面来看,标签信息和噪声数据进行拼接送入生成器,标签信息和真实图片拼接送入判别器,**由此提出了改进版的CGAN损失函数,其中y表示类别标签:
这里划重点⭐️,CGAN标签信息只是和真实数据拼接之后输入到判别器之后就没有其它操作了,ACGAN还把判别器分成了两个部分,一个部分和基本的GAN的判别器功能相同,判别生成数据是否真实,另一个部分判断生成数据的类别,这两个部分的判别器共用一部分网络结构,ACGAN的判别起的损失函数也分成了2个部分
参考链接中介绍了ACGAN是在CGAN基础上更近一步的改进,将判别器的功能扩展为判别真假以及类别区分,可以认为ACGAN的判别器多出一个分类的功能(pytorch实现部分2.1讨论如何改进判别器结构,从而实现既能判别又能分类的功能)。由此,ACGAN的损失函数也分为了判别损失和分类损失两个部分,其中判别损失和CGAN并没有区别,形式如下:
这里再划重点⭐️,ACGAN的输入只有类别标签,是个离散值,损失也只有类别损失,InfoGAN的输入有离散值和连续值,可以是离散值也可以是连续值,也可以是离散值和连续值的拼接,InfoGAN的离散值不一定是类别标签,它可以是表示样本某种离散隐含特征的数值,这时我们再回头看第2部分InfoGAN的基本结构就更清晰了。
第一个图:CGAN的基本结构
第二个图:InfoGAN的基本结构