[A Beginner's Guide to GANs] (10) InfoGAN: Interpretable Representation Learning by Information Maximizing GAN


This post is based on the original paper InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets (NIPS 2016) and Chapter 6 of 《生成对抗网络入门指南》 (A Beginner's Guide to Generative Adversarial Networks). The complete code, with a brief walkthrough, appears at the end of this post.


I. Why Use InfoGAN

InfoGAN learns in a fully unsupervised way and aims to produce interpretable features. Drawing on information theory, it optimizes the model by maximizing the mutual information (MI) between part of the generator's input (the latent code) and the observations. InfoGAN works on a variety of complex datasets and can model discrete and continuous latent features at the same time.


II. The Input Data

At the input side, InfoGAN splits the random input into two parts:

The first part, z, represents the noise;

The second part, c, represents the latent code;

The goal is for each dimension of the code to carry an interpretable feature.

If the noise z and the latent code c are simply fed in together, the generator is free to ignore the code, so that P_G(x\mid c)=P_G(x). To counter this, InfoGAN computes the mutual information I(c;\,G(z,c)) between the latent code c and the generator distribution G(z,c), and maximizes it.
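As a minimal sketch (my own illustration; the 62/10 split matches the Keras implementation in Section V below), the generator's input is just the concatenation of the two parts:

    import numpy as np

    batch_size = 16
    z = np.random.normal(0, 1, (batch_size, 62))           # noise part
    c = np.eye(10)[np.random.randint(0, 10, batch_size)]   # one-hot categorical latent code
    gen_input = np.concatenate([z, c], axis=1)             # shape (16, 72), fed to the generator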

 

III. The InfoGAN Architecture

InfoGAN differs from the GANs introduced earlier in that the real training data carry no label information; instead, the input is a combination of a latent code and random noise, and the latent code's information is recovered on the discriminator side by maximizing mutual information. In other words, the discriminator D must ultimately be able both to recover the latent code and to tell real from fake. The former ensures that generated images actually exhibit the properties carried by the code, i.e. that the latent code has a clearly visible effect on the generator's output; the latter requires that, while this information is being recovered, the generator still produces data that closely match the real data.

1. Mutual information

Mutual information measures the degree of dependence between two random variables. For random variables X and Y, the mutual information I(X;Y) can be expressed through the marginal entropies H(X) and H(Y) and the conditional entropies H(X\mid Y) and H(Y\mid X):

                                       I(X;Y) = H(X) - H(X\mid Y) = H(Y) - H(Y\mid X)
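As a quick sanity check (a toy example of mine, not from the book), this identity can be evaluated directly for a small discrete joint distribution:

    import numpy as np

    # Joint distribution of two dependent binary variables (rows: X, columns: Y)
    p_xy = np.array([[0.4, 0.1],
                     [0.1, 0.4]])
    p_x = p_xy.sum(axis=1)
    p_y = p_xy.sum(axis=0)

    def entropy(p):
        p = p[p > 0]
        return -np.sum(p * np.log2(p))

    h_x_given_y = entropy(p_xy.ravel()) - entropy(p_y)  # chain rule: H(X|Y) = H(X,Y) - H(Y)
    mi = entropy(p_x) - h_x_given_y                     # I(X;Y) = H(X) - H(X|Y)
    print(mi)  # ~0.278 bits: X and Y are dependent, so the mutual information is positive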

2. Structure

(Figure: the InfoGAN network structure)

3. Objective function

When X and Y are independent, the mutual information is 0. For any given input, we want the generator's posterior P_G(c\mid x) to have relatively low entropy; that is, the information carried by the latent code c should not be lost during generation. We therefore modify the objective function:
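From the original paper, the standard GAN value function V(D,G) is regularized by the mutual information term, weighted by a hyperparameter \lambda:

                                       \min_G \max_D V_I(D,G) = V(D,G) - \lambda I(c;\,G(z,c))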

Because the posterior P(c\mid x) is hard to obtain, the mutual information is difficult to maximize directly. For practical computation we can define an auxiliary distribution Q(c\mid x) that approximates it, and derive a lower bound on the mutual information. The derivation is as follows:
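Reconstructing the derivation from the original paper (it introduces the auxiliary distribution Q(c\mid x) and uses the non-negativity of the KL divergence):

                                       I(c;G(z,c)) = H(c) - H(c\mid G(z,c))
                                       = E_{x\sim G(z,c)}\big[E_{c'\sim P(c\mid x)}[\log P(c'\mid x)]\big] + H(c)
                                       = E_{x\sim G(z,c)}\big[D_{KL}\big(P(\cdot\mid x)\,\|\,Q(\cdot\mid x)\big) + E_{c'\sim P(c\mid x)}[\log Q(c'\mid x)]\big] + H(c)
                                       \geq E_{x\sim G(z,c)}\big[E_{c'\sim P(c\mid x)}[\log Q(c'\mid x)]\big] + H(c)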

This yields the following lower bound on the mutual information:

                                       I(c;G(z,c)) \geq E_{x\sim G(z,c)}\big[E_{c'\sim P(c\mid x)}[\log Q(c'\mid x)]\big] + H(c)


4. Derivation of the InfoGAN objective

We can rewrite the inequality above, approximating the expectation with Monte Carlo sampling to obtain the variational lower bound L_I(G,Q), which gives our final objective function.
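As stated in the original paper, the Monte Carlo-estimable bound L_I and the resulting minimax game are:

                                       L_I(G,Q) = E_{c\sim P(c),\,x\sim G(z,c)}[\log Q(c\mid x)] + H(c)

                                       \min_{G,Q}\max_D V_{InfoGAN}(D,G,Q) = V(D,G) - \lambda L_I(G,Q)

Here \lambda is a weighting hyperparameter; for discrete latent codes the paper simply sets it to 1.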


IV. Experimental Results

1. MNIST data

We find that the categorical code c_1 in the latent code controls which digit is generated, while the other (continuous) codes adjust properties such as the digits' tilt and stroke width.
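For example (a hypothetical snippet of mine, assuming an INFOGAN instance from Section V has already been trained), fixing the categorical code selects the digit class while the noise varies the style:

    import numpy as np
    from keras.utils import to_categorical

    # `infogan` is assumed to be a trained INFOGAN instance from Section V
    digit = 7                                                   # which digit class to generate
    z = np.random.normal(0, 1, (5, 62))                         # vary only the noise part
    c = to_categorical(np.full((5, 1), digit), num_classes=10)  # fix the categorical code
    imgs = infogan.generator.predict(np.concatenate([z, c], axis=1))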


2. 3D face data


3. Chairs dataset

4. House-number (SVHN) dataset


V. Experiment Code

1. Imports and hyperparameters

from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, concatenate
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D, Lambda
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import to_categorical
import keras.backend as K

import matplotlib.pyplot as plt
import numpy as np


class INFOGAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.num_classes = 10
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 72  # 62 noise dimensions + 10-dimensional categorical code

        optimizer = Adam(0.0002, 0.5)
        losses = ['binary_crossentropy', self.mutual_info_loss]

        # Build the discriminator and the recognition network
        self.discriminator, self.auxilliary = self.build_disk_and_q_net()

        self.discriminator.compile(loss=['binary_crossentropy'],
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build and compile the recognition network Q
        self.auxilliary.compile(loss=[self.mutual_info_loss],
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise and the target label as input
        # and generates the corresponding digit of that label
        gen_input = Input(shape=(self.latent_dim,))
        img = self.generator(gen_input)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        valid = self.discriminator(img)

        # The recognition network produces the label
        target_label = self.auxilliary(img)

        # The combined model (stacked generator and discriminator)
        self.combined = Model(gen_input, [valid, target_label])
        self.combined.compile(loss=losses, optimizer=optimizer)

 

2. Building the generator and discriminator

    def build_generator(self):
        model = Sequential()

        model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((7, 7, 128)))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(self.channels, kernel_size=3, padding='same'))
        model.add(Activation("tanh"))

        gen_input = Input(shape=(self.latent_dim,))
        img = model(gen_input)

        model.summary()

        return Model(gen_input, img)

    def build_disk_and_q_net(self):
        img = Input(shape=self.img_shape)

        # Shared layers between discriminator and recognition network
        model = Sequential()
        model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(256, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(512, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Flatten())

        img_embedding = model(img)

        # Discriminator
        validity = Dense(1, activation='sigmoid')(img_embedding)

        # Recognition
        q_net = Dense(128, activation='relu')(img_embedding)
        label = Dense(self.num_classes, activation='softmax')(q_net)

        # Return discriminator and recognition network
        return Model(img, validity), Model(img, label)

 

3. The mutual information loss

    def mutual_info_loss(self, c, c_given_x):
        """The mutual information metric we aim to minimize"""
        eps = 1e-8
        conditional_entropy = K.mean(- K.sum(K.log(c_given_x + eps) * c, axis=1))
        entropy = K.mean(- K.sum(K.log(c + eps) * c, axis=1))

        return conditional_entropy + entropy

    def sample_generator_input(self, batch_size):
        # Generator inputs
        sampled_noise = np.random.normal(0, 1, (batch_size, 62))
        sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)
        sampled_labels = to_categorical(sampled_labels, num_classes=self.num_classes)

        return sampled_noise, sampled_labels
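A note on how this loss realizes the bound from Section III: conditional_entropy is the cross-entropy between the sampled code c and the Q network's prediction c_given_x, so minimizing it maximizes the E[\log Q(c\mid x)] term of L_I; the entropy term does not depend on the network weights (for one-hot categorical codes it is in fact close to 0), so it merely shifts the reported loss value.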

 

4. Training

    def train(self, epochs, batch_size=128, sample_interval=50):
        # Load the dataset
        (X_train, y_train), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)
        y_train = y_train.reshape(-1, 1)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Sample noise and categorical labels
            sampled_noise, sampled_labels = self.sample_generator_input(batch_size)
            gen_input = np.concatenate((sampled_noise, sampled_labels), axis=1)

            # Generate a half batch of new images
            gen_imgs = self.generator.predict(gen_input)

            # Train on real and generated data
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)

            # Avg. loss
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator and Q-network
            # ---------------------

            g_loss = self.combined.train_on_batch(gen_input, [valid, sampled_labels])

            # Plot the progress
            # g_loss = [total, adversarial (G) loss, mutual-information (Q) loss]
            print("%d [D loss: %.2f, acc.: %.2f%%] [G loss: %.2f] [Q loss: %.2f]" % (
                epoch, d_loss[0], 100 * d_loss[1], g_loss[1], g_loss[2]))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

 

5. Visualization

    def sample_images(self, epoch):
        r, c = 10, 10
        fig, axs = plt.subplots(r, c)
        for i in range(c):
            sampled_noise, _ = self.sample_generator_input(c)
            label = to_categorical(np.full(fill_value=i, shape=(r, 1)), num_classes=self.num_classes)
            gen_input = np.concatenate((sampled_noise, label), axis=1)
            gen_imgs = self.generator.predict(gen_input)
            gen_imgs = 0.5 * gen_imgs + 0.5
            for j in range(r):
                axs[j, i].imshow(gen_imgs[j, :, :, 0], cmap='gray')
                axs[j, i].axis('off')
        fig.savefig("images/%d.png" % epoch)
        plt.close()

    def save_model(self):

        def save(model, model_name):
            model_path = "saved_model/%s.json" % model_name
            weights_path = "saved_model/%s_weights.hdf5" % model_name
            options = {"file_arch": model_path,
                       "file_weight": weights_path}
            json_string = model.to_json()
            open(options['file_arch'], 'w').write(json_string)
            model.save_weights(options['file_weight'])

        save(self.generator, "generator")
        save(self.discriminator, "discriminator")


if __name__ == '__main__':
    infogan = INFOGAN()
    infogan.train(epochs=50000, batch_size=128, sample_interval=50)

 

Experiment Results


Complete Code

The complete program is simply the five snippets above (Sections V.1 through V.5) assembled into a single file.


 


