InfoGAN论文笔记+源码解析

article/2025/3/4 9:01:34

论文地址:InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets

源码地址:InfoGAN in TensorFlow

GAN,Generative Adversarial Network是目前非常火也是非常有潜力的一个发展方向,原始的GAN模型存在着无约束、不可控、噪声信号z很难解释等问题,近年来,在原始GAN模型的基础上衍生出了很多种模型,如:条件——CGAN、卷积——DCGAN等等,在本博客的前几篇博文里均进行了大致的解读,本篇博文将提到的InfoGAN也是GAN的一种改进成果,甚至被OPENAI称为去年的五大突破之一。今天我们就来看看,InfoGAN到底做出了什么样的改进,能达到一个什么样的效果呢。

2014年,Ian J. Goodfellow提出了生成对抗网络:Generative Adversarial Networks,通过generator和discriminator的对抗学习,最终可以得到一个与real data分布一致的fake data,但是由于generator的输入z是一个连续的噪声信号,并且没有任何约束,导致GAN无法利用这个z,并将z的具体维度与数据的语义特征对应起来,并不是一个Interpretable(可解释) Representation,而这正好是InfoGAN的出发点,它试图利用z,寻找一个可解释的表达,于是它将z进行了拆解,一是不可压缩的噪声z,二是可解释的 隐变量c,称作为latent code,而我们希望通过约束c与生成数据之间的关系,可以使得c里面包含有对数据的可解释的信息,如对MNIST数据,c可以分为categorical latent code代 来表数字种类信息(0~9),以及continuous latent code来表示倾斜度、笔画粗细等等。

为了引入c,作者利用互信息来对c进行约束,这是因为如果c对于生成数据G(z,c)具有可解释性,那么c和G(z,c)应该具有高度相关性,即互信息大,而如果是无约束的话,那么它们之间没有特定的关系,即互信息接近于0。因此我们希望c与G(z,c)的互信息I(c;G(z,c))越大越好,因此,模型的目标函数也变为:


但是在I(c;G(z,c))的计算中,真实的P(c|x)并不清楚,因此在具体的优化过程中,作者采用了变分推断的思想,引入了变分分布Q(c|x)来逼近P(c|x),它是基于最优互信息下界的轮流迭代实现最终的求解,于是InfoGAN的目标函数变为:


在具体的实现中,Q和D共用了所有的卷积层,并只在最后增加了一个全连接层来输出Q(c|x),因此InfoGAN并没有在原始的GAN上增加多少的计算量。

对于c,如果是categorical latent code,可以使用softmax的非线性输出来代表Q(c|x);如果是continuous latent code,可以使用高斯分布来表示。

在实验中,作者通过只改表c的某一个维度,来观察生成数据的变化,实验结果证明了,latent code确实学习到了一些可解释的信息,如在MNIST中的数字,倾斜度、笔画粗细等等。



下面我们看代码,在infogan\__init__.py中第212行:

if use_infogan:z_size = style_size + sum(categorical_cardinality) + num_continuous   # z_size=74sample_noise = create_infogan_noise_sample(categorical_cardinality,num_continuous,style_size)                                                                     # sample_noise.shape=[64 74]
其中style_size为62,categorical_cardinality为[10],num_continuous为2,看create_infogan_noise_sample,代表噪声信号的产生:

def create_infogan_noise_sample(categorical_cardinality, num_continuous, style_size):def sample(batch_size):return encode_infogan_noise(categorical_cardinality,create_categorical_noise(categorical_cardinality, size=batch_size),create_continuous_noise(num_continuous, style_size, size=batch_size))return sample
其中batch_size=64,看create_categorical_noise,代表categorical latent code的产生:

def create_categorical_noise(categorical_cardinality, size):noise = []for cardinality in categorical_cardinality:noise.append(np.random.randint(0, cardinality, size=size))return noise
其中np.random.randint(0, cardinality, size=size)表示生成[0 cardinality)半开半闭区间内的随机整数,在这里即0~9之间的整数,代表数字的种类。

看create_continuous_noise,代表continuous latent code以及不可压缩的噪声z的产生:

def create_continuous_noise(num_continuous, style_size, size):continuous = np.random.uniform(-1.0, 1.0, size=(size, num_continuous))style = np.random.standard_normal(size=(size, style_size))return np.hstack([continuous, style])
其中continuous latent code服从-1到1之间的均匀分布,style即噪声z服从标准正态分布,再将continuous latent code与style进行concat。
看encode_infogan_noise,将categorical latent code、continuous latent code及style合成:

def encode_infogan_noise(categorical_cardinality, categorical_samples, continuous_samples):noise = []for cardinality, sample in zip(categorical_cardinality, categorical_samples):noise.append(make_one_hot(sample, size=cardinality))noise.append(continuous_samples)return np.hstack(noise)
对于categorical latent code,将categorical进行one-hot编码,即生成长度为10的0-1向量。然后再将三者concat,就生成了噪声样本。

再看__init__.py的第33行:

def generator_forward(z,network_description,is_training,reuse=None,name="generator",use_batch_norm=True,debug=False):with tf.variable_scope(name, reuse=reuse):return run_network(z,network_description,is_training=is_training,use_batch_norm=use_batch_norm,debug=debug,strip_batchnorm_from_last_layer=True)
定义了生成器,其中network_description为"fc:1024,fc:7x7x128,reshape:7:7:128,deconv:4:2:64,deconv:4:2:1:sigmoid",输出为28*28的生成样本,即fake_image。

看第48行:

def discriminator_forward(img,network_description,is_training,reuse=None,name="discriminator",use_batch_norm=True,debug=False):with tf.variable_scope(name, reuse=reuse):out = run_network(img,network_description,is_training=is_training,use_batch_norm=use_batch_norm,debug=debug)out = layers.flatten(out)prob = layers.fully_connected(out,num_outputs=1,activation_fn=tf.nn.sigmoid,scope="prob_projection")return {"prob":prob, "hidden":out}
其中network_description为"conv:4:2:64:lrelu,conv:4:2:128:lrelu,fc:1024:lrelu",out的维度为[64 1024],prob的维度为[64 1]表示对输入样本关于real_image的预测概率。

第291行:

# discriminator should maximize:ll_believing_fake_images_are_fake = tf.log(1.0 - prob_fake + TINY)ll_true_images = tf.log(prob_true + TINY)discriminator_obj = (tf.reduce_mean(ll_believing_fake_images_are_fake) +tf.reduce_mean(ll_true_images))
定义了discriminator的目标函数,与原始GAN中的目标函数一致,其中TINY为很小的数,为了避免log里面的数等于0。

第299行:

# generator should maximize:ll_believing_fake_images_are_real = tf.reduce_mean(tf.log(prob_fake + TINY))generator_obj = ll_believing_fake_images_are_real
定义了generator的目标函数,与原始GAN中的目标函数一致。

看320行:

q_output = reconstruct_mutual_info(categorical_c_vectors,continuous_c_vector,categorical_lambda=args.categorical_lambda,continuous_lambda=args.continuous_lambda,fix_std=fix_std,hidden=discriminator_fake["hidden"],is_training=is_training_discriminator,name="mutual_info"
)
其中categorical_c_vectors对应了之前的categorical latent code,continuous_c_vector对应了continuous latent code,hidden为fake_image的discriminator输出,fix_std表示"Fix continuous var standard deviation to 1."。

再看reconstruct_mutual_info,第82行和第93行将fake_image的discriminator输出再输入到两个全连接中,最终的输出维度为[64 12]。看第101行:

ll_categorical = None
for true_categorical in true_categoricals:cardinality = true_categorical.get_shape()[1].valueprob_categorical = tf.nn.softmax(out[:, offset:offset + cardinality])ll_categorical_new = tf.reduce_sum(tf.log(prob_categorical + TINY) * true_categorical,reduction_indices=1)if ll_categorical is None:ll_categorical = ll_categorical_newelse:ll_categorical = ll_categorical + ll_categorical_new
关于categorical latent code的目标函数为G(z,c)对应于categorical的输出的softmax与categorical latent code的交叉熵。

第114行:

mean_contig = out[:, num_categorical:num_categorical + num_continuous]if fix_std:std_contig = tf.ones_like(mean_contig)
else:std_contig = tf.sqrt(tf.exp(out[:, num_categorical + num_continuous:num_categorical + num_continuous * 2]))epsilon = (true_continuous - mean_contig) / (std_contig + TINY)
ll_continuous = tf.reduce_sum(- 0.5 * np.log(2 * np.pi) - tf.log(std_contig + TINY) - 0.5 * tf.square(epsilon),reduction_indices=1,
)
关于continuous latent code的目标函数,将continuous latent code以均值为G(z,c)对应于continuous的输出,方差为1进行标准化,然后计算它以正态分布的概率密度作为目标函数。

mutual_info_lb = continuous_lambda * ll_continuous + categorical_lambda * ll_categorical
即为c与G(z,c)的互信息的目标函数。
再看训练,第421行:

# train discriminator
noise = sample_noise(batch_size)
_, summary_result1, disc_obj, infogan_obj = sess.run([train_discriminator, discriminator_obj_summary, discriminator_obj, neg_mutual_info_objective],feed_dict={true_images:batch,zc_vectors:noise,is_training_discriminator:True,is_training_generator:True}
)
以及第438行:

# train generator
noise = sample_noise(batch_size)
_, _, summary_result2, gen_obj, infogan_obj = sess.run([train_generator, train_mutual_info, generator_obj_summary, generator_obj, neg_mutual_info_objective],feed_dict={zc_vectors:noise,is_training_discriminator:True,is_training_generator:True}
)
看实验结果:



修改categorical变量,可以生成不同的数字图像;修改continuous变量,可以改变生成数字的倾斜度以及笔画的宽度。











http://chatgpt.dhexx.cn/article/cbg28CLX.shtml

相关文章

InfoGAN(基于信息最大化生成对抗网的可解释表征学习)

前言: 这篇博客为阅读论文后的总结与感受,方便日后翻阅、查缺补漏,侵删! 论文: InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets 解决的问题: In…

InfoGAN学习笔记

InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets Xi Chen, Yan Duan, Rein Houthooft, John Schulman, Ilya Sutskever, Pieter Abbeel https://arxiv.org/abs/1606.03657 一、从GAN到InfoGAN 1.GAN存在的问题 GAN…

生成对抗网络(十)----------infoGAN

一、infoGAN介绍 infoGAN采用的是无监督式学习的方式并尝试实现可解释特征。原始数据不包含任何标签信息,所有的特征都是通过网络以一种非监督的方式自动学习得到的。使用了信息论的原理,通过最大化输入噪声和观察值之间的互信息来对网络模型进行优化。…

InfoGAN详细介绍及特征解耦图像生成

InfoGAN详细介绍及特征解耦图像生成 一.InfoGAN框架理解特征耦合InfoGANInfoGAN论文实验结果 二.VAE-GAN框架理解VAE-GAN算法步骤 三.BiGAN框架理解四.InfoGAN论文复现使用MNIST数据集复现InfoGAN代码编写初始化判别器初始化生成器初始化分类器训练InfoGAN网络 总结参考文献及博…

InfoGAN介绍

今天给大家分享的是NIPS2016的InfoGAN。这篇paper所要达到的目标就是通过非监督学习得到可分解的特征表示。使用GAN加上最大化生成的图片和输入编码之间的互信息。最大的好处就是可以不需要监督学习,而且不需要大量额外的计算花销就能得到可解释的特征。 通常&#…

10.可视化、可视分析、探索式数据分析

前言:基于人大的《数据科学概论》第十章。主要内容为可视化的定义、可视化的意义、可视化的一般过程、科学可视化与信息可视化、可视化的原则、可视化实例、可视化的挑战和趋势、可视分析技术、探索式数据分析、可视化工具。 一、可视化的定义 可视化是数据的可视表…

国内外大数据可视化分析产品点评

一、KINETICA Kinetica 利用图像处理芯片提供计算支持,允许企业使用机器学习,商业智能分析和可视化技术更快速地分析海量数据; 点评:它的核心技术能力是类MapD的GPU数据库,功能性能较强大,非开源&#xf…

数据分析可视化之模型介绍

一 前言 “数据分析可视化”这条术语实现了成熟的科学可视化领域与较年轻的信息可视化领域的统一。 数据分析可视化:是指将大型数据集中的数据以图形图像形式表示,并利用数据分析和开发工具发现其中未知信息的处理过程。 数据可视化已经提出了许多方法&…

5个最受欢迎的大数据可视化软件

数据可视化的本质是视觉对话,数据可视化将数据分析技术与图形技术结合,清晰有效地将分析结果信息进行解读和传达。 大数据可视化是进行各种大数据分析解决的最重要组成部分之一。 一旦原始数据流被以图像形式表示时,以此做决策就变得容易多了…

软件架构-可视化

软件架构-可视化 当我们在讨论系统时,往往都会说这个系统的架构是什么样的,在你口述的同时,如果能借助某些图表,效果会更好,传统的uml建模比较复杂,目前的软件工程大家更关注效率(这里我不谈敏捷…

48 款数据可视化分析工具大集合

作者:DC君 来源:DataCastle数据城堡 本篇专门推荐48款数据可视化工具,全到你无法想象。 1、Excel 作为一个入门级工具,是快速分析数据的理想工具,也能创建供内部使用的数据图,但是Excel在颜色、线条和样式上…

爬取某小说榜单爬虫及可视化分析

爬取某小说榜单爬虫及可视化分析(仅用于学习) gitee代码链接:https://gitee.com/huang_jia_son/duoduo.git 介绍 GUI界面python爬虫数据清洗与处理pyecharts可视化展示软件架构 (1)通过tkinter制作GUI界面&#xf…

咖啡PowerBI可视化实例

目录 一、导入数据 二、构建指标 1、构架表之间的关系 ​ 2、完善表 3、构建指标 三、可视化 1、整体分析 2、省份业务 3、产品维度 4、地区维度 5、客户维度 6、价格分析 7、利润分析 8、其它分析 一、导入数据 二、构建指标 1、构架表之间的关系 软件会…

CodeScene - 软件质量可视化工具

CodeScene - 软件质量可视化工具 CodeScene https://codescene.com/ https://codescene.io/ The powerful visualization tool using Predictive Analytics to find hidden risks and social patterns in your code. 使用 Predictive Analytics 的功能强大的可视化工具&#x…

使用excel、python、tableau对招聘数据进行数据处理及可视化分析

招聘数据数据分析及可视化 数据来源前言一、观察数据删除重复值数据加工 二、利用python进行数据分析和可视化1.引入库2.读入数据观察描述统计,了解数据大概信息 3.数据预处理3.1数据清洗3.1.1 删除重复值3.1.2缺失值处理 3.2数据加工 4.数据可视化4.1城市岗位数量4…

可视化工具软件排行榜

市面上的数据可视化工具软件如此之多,有哪些可视化软件工具居于排行榜单的前列呢?你用的软件上榜了吗? 1、FineBI 来自帆软公司,虽作一个BI工具,但是可视化效果不错,可制作Dashboard。优势在于一旦准备好…

2020年六十款数据分析的可视化工具推荐

今天小编将为大家盘点六十款数据分析的可视化工具,让你妥妥的成为会议室乃至全公司最亮的崽~ 1、ChartBlocks ChartBlocks是一款网页版的可视化图表生成工具,在线使用。通过导入电子表格或者数据库来构建可视化图表。整个过程可以在图表的向导指示下完成。它的图表在HTML…

值得推荐的13款可视化软件,快收藏!

数据可视化力求用图表结合的方式把所有的数据整合在某一图像上,这样呈现在观众眼前的画面不仅仅是美观,且比以往长篇大论或是密密麻麻的数据表格更直观易懂,更便于观察分析。到今年上半年为止,国内外已经有了很多发展的较好的数据…

深入分析ArrayMap

前面我们分析了Android为了节省内存提供的一个HahMap<Integer, ?>的替代品SparseArray。SparseArray只能替代key的类型为int的Map。Android也提供了一个key不用局限于int的Map的实现&#xff0c;ArrayMap。老规矩我们通过调试来深入的分析一下ArrayMap&#xff08;看本文…

ArrayMAP介绍

它不是一个适应大数据的数据结构&#xff0c;相比传统的HashMap速度要慢&#xff0c;因为查找方法是二分法&#xff0c;并且当你删除或者添加数据时&#xff0c;会对空间重新调整&#xff0c;在使用大量数据时&#xff0c;效率并不明显&#xff0c;低于50%。 ArrayMap is a ge…