InfoGAN详解与实现(采用tensorflow2.x实现)

article/2025/3/4 9:04:38

InfoGAN详解与实现(采用tensorflow2.x实现)

    • InfoGAN原理
    • InfoGAN实现
      • 导入必要库
      • 生成器
      • 鉴别器
      • 模型构建
      • 模型训练
      • 效果展示

InfoGAN原理

最初的GAN能够产生有意义的输出,但是缺点是它的属性无法控制。例如,无法明确向生成器提出生成女性名人的脸,该女性名人是黑发,白皙的肤色,棕色的眼睛,微笑着。这样做的根本原因是因为使用的100-dim噪声矢量合并了生成器输出的所有显着属性。
如果能够修改原始GAN,从而将表示形式分为合并和分离可解释的潜在编码向量,则可以告诉生成器要合成什么。
合并和分离编码可以表示如下:
合并编码与分离编码对比具有分离表示的GAN也可以以与普通GAN相同的方式进行优化。生成器的输出可以表示为:
G ( z , c ) = G ( z ) G(z,c)=G(z) G(z,c)=G(z)
编码 z = ( z , c ) z = (z,c) z=(z,c)包含两个元素, z z z表示合并表示, c = c 1 , c 2 , . . . , c L c=c_1,c_2,...,c_L c=c1,c2,...,cL表示分离的编码表示。
为了强制编码的解耦,InfoGAN提出了一种针对原始损失函数的正则化函数,该函数将潜在编码 c c c G ( z , c ) G(z,c) G(z,c)之间的互信息最大化:
I ( c ; G ( z , c ) ) = I G ( c ; z ) I(c;G(z,c))=IG(c;z) I(c;G(z,c))=IG(c;z)
正则化器强制生成器考虑潜在编码。在信息论领域,潜在编码 c c c G ( z , c ) G(z,c) G(z,c)之间的互信息定义为:
I ( G ( c ; z ) = H ( c ) − H ( c ∣ G ( z , c ) ) I(G(c;z)=H(c)-H(c|G(z,c)) I(G(c;z)=H(c)H(cG(z,c))
其中 H ( c ) H(c) H(c)是潜在编码 c c c的熵,而 H ( c ∣ G ( z , c ) ) H(c|G(z,c)) H(cG(z,c))是得到生成器的输出 G ( z , c ) G(z,c) G(z,c)后c的条件熵。
最大化互信息意味着在生成得到生成的输出时将 H ( c ∣ G ( z , c ) ) H(c|G(z,c)) H(cG(z,c))最小化或减小潜在编码中的不确定性。
但是由于估计 H ( c ∣ G ( z , c ) ) H(c|G(z,c)) H(cG(z,c))需要后验分布 p ( c ∣ G ( z , c ) ) = p ( c ∣ x ) p(c|G(z,c))=p(c|x) p(cG(z,c))=p(cx),因此难以估算 H ( c ∣ G ( z , c ) ) H(c|G(z,c)) H(cG(z,c))
解决方法是通过使用辅助分布 Q ( c ∣ x ) Q(c|x) Q(cx)估计后验概率来估计互信息的下限,估计相互信息的下限为:
I ( c ; G ( z , c ) ) ≥ L I ( G , Q ) = E c ∼ p ( c ) , x ∼ G ( z , c ) [ l o g Q ( c ∣ x ) ] + H ( c ) I(c;G(z,c)) \ge L_I(G,Q)=E_{c \sim p(c),x \sim G(z,c)}[logQ(c|x)]+H(c) I(c;G(z,c))LI(G,Q)=Ecp(c),xG(z,c)[logQ(cx)]+H(c)
在InfoGAN中,假设 H ( c ) H(c) H(c)为常数。因此,使互信息最大化是使期望最大化的问题。生成器必须确信已生成具有特定属性的输出。此期望的最大值为零。因此,互信息的下限的最大值为 H ( c ) H(c) H(c)。在InfoGAN中,离散潜在编码 Q ( c ∣ x ) Q(c|x) Q(cx)的可以用softmax表示。期望是tf.keras中的负categorical_crossentropy损失。
对于一维连续编码,期望是 c c c x x x上的二重积分,这是由于期望样本同时来自分离编码分布和生成器分布。估计期望值的一种方法是通过假设样本是连续数据的良好度量。因此,损失估计为 c l o g Q ( c ∣ x ) clogQ(c|x) clogQ(cx)
为了完成InfoGAN的网络,应该有一个 l o g Q ( c ∣ x ) logQ(c|x) logQ(cx)的实现。为简单起见,网络Q是附加到鉴别器的辅助网络。
InfoGAN网络架构鉴别器损失函数
L ( D ) = − E x ∼ p d a t a l o g D ( x ) − E z , c l o g [ 1 − D ( G ( z , c ) ) ] − λ I ( c ; G ( z , c ) ) \mathcal L^{(D)} = -\mathbb E_{x\sim p_{data}}logD(x)-\mathbb E_{z,c}log[1 − D(G(z,c))]-\lambda I(c;G(z,c)) L(D)=ExpdatalogD(x)Ez,clog[1D(G(z,c))]λI(c;G(z,c))
生成器损失函数:
L ( G ) = − E z , c l o g D ( G ( z , c ) ) − λ I ( c ; G ( z , c ) ) \mathcal L^{(G)} = -\mathbb E_{z,c}logD(G(z,c))-\lambda I(c;G(z,c)) L(G)=Ez,clogD(G(z,c))λI(c;G(z,c))
其中 λ \lambda λ是正的常数

InfoGAN实现

如果将其应用于MNIST数据集,InfoGAN可以学习分离的离散编码和连续编码,以修改生成器输出属性。 例如,像CGAN和ACGAN一样,将使用10维独热标签形式的离散编码来指定要生成的数字。但是,可以添加两个连续的编码,一个用于控制书写样式的角度,另一个用于调整笔划宽度。保留较小尺寸的编码以表示所有其他属性:

MNIST数据集编码形式

导入必要库

import tensorflow as tf
import numpy as np
from tensorflow import keras
import os
from matplotlib import pyplot as plt
import math
from PIL import Image
from tensorflow.keras import backend as K

生成器

def generator(inputs,image_size,activation='sigmoid',labels=None,codes=None):"""generator modelArguments:inputs (layer): input layer of generatorimage_size (int): Target size of one sideactivation (string): name of output activation layerlabels (tensor): input labelscodes (list): 2-dim disentangled codes for infoGANreturns:model: generator model"""image_resize = image_size // 4kernel_size = 5layer_filters = [128,64,32,1]inputs = [inputs,labels] + codesx = keras.layers.concatenate(inputs,axis=1)x = keras.layers.Dense(image_resize*image_resize*layer_filters[0])(x)x = keras.layers.Reshape((image_resize,image_resize,layer_filters[0]))(x)for filters in layer_filters:if filters > layer_filters[-2]:strides = 2else:strides = 1x = keras.layers.BatchNormalization()(x)x = keras.layers.Activation('relu')(x)x = keras.layers.Conv2DTranspose(filters=filters,kernel_size=kernel_size,strides=strides,padding='same')(x)if activation is not None:x = keras.layers.Activation(activation)(x)return keras.Model(inputs,x,name='generator')

鉴别器

def discriminator(inputs,activation='sigmoid',num_labels=None,num_codes=None):"""discriminator modelArguments:inputs (Layer): input layer of the discriminatoractivation (string): name of output activation layernum_labels (int): dimension of one-hot labels for ACGAN & InfoGANnum_codes (int): num_codes-dim 2 Q network if InfoGANReturns:Model: Discriminator model"""kernel_size = 5layer_filters = [32,64,128,256]x = inputsfor filters in layer_filters:if filters == layer_filters[-1]:strides = 1else:strides = 2x = keras.layers.LeakyReLU(0.2)(x)x = keras.layers.Conv2D(filters=filters,kernel_size=kernel_size,strides=strides,padding='same')(x)x = keras.layers.Flatten()(x)outputs = keras.layers.Dense(1)(x)if activation is not None:print(activation)outputs = keras.layers.Activation(activation)(outputs)if num_labels:layer = keras.layers.Dense(layer_filters[-2])(x)labels = keras.layers.Dense(num_labels)(layer)labels = keras.layers.Activation('softmax',name='label')(labels)# 1-dim continous Q of 1st c given xcode1 = keras.layers.Dense(1)(layer)code1 = keras.layers.Activation('sigmoid',name='code1')(code1)# 1-dim continous Q of 2nd c given xcode2 = keras.layers.Dense(1)(layer)code2 = keras.layers.Activation('sigmoid',name='code2')(code2)outputs = [outputs,labels,code1,code2]return keras.Model(inputs,outputs,name='discriminator')

模型构建

#mi_loss
def mi_loss(c,q_of_c_give_x):"""mi_loss = -c * log(Q(c|x))"""return K.mean(-K.sum(K.log(q_of_c_give_x + K.epsilon()) * c,axis=1))def build_and_train_models(latent_size=100):"""Load the dataset, build InfoGAN models,Call the InfoGAN train routine."""(x_train,y_train),_ = keras.datasets.mnist.load_data()image_size = x_train.shape[1]x_train = np.reshape(x_train,[-1,image_size,image_size,1])x_train = x_train.astype('float32') / 255.num_labels = len(np.unique(y_train))y_train = keras.utils.to_categorical(y_train)#超参数model_name = 'infogan_mnist'batch_size = 64train_steps = 40000lr = 2e-4decay = 6e-8input_shape = (image_size,image_size,1)label_shape = (num_labels,)code_shape = (1,)#discriminator modelinputs = keras.layers.Input(shape=input_shape,name='discriminator_input')#discriminator with 4 outputsdiscriminator_model = discriminator(inputs,num_labels=num_labels,num_codes=2)optimizer = keras.optimizers.RMSprop(lr=lr,decay=decay)loss = ['binary_crossentropy','categorical_crossentropy',mi_loss,mi_loss]loss_weights = [1.0,1.0,0.5,0.5]discriminator_model.compile(loss=loss,loss_weights=loss_weights,optimizer=optimizer,metrics=['acc'])discriminator_model.summary()input_shape = (latent_size,)inputs = keras.layers.Input(shape=input_shape,name='z_input')labels = keras.layers.Input(shape=label_shape,name='labels')code1 = keras.layers.Input(shape=code_shape,name='code1')code2 = keras.layers.Input(shape=code_shape,name='code2')generator_model = generator(inputs,image_size,labels=labels,codes=[code1,code2])generator_model.summary()optimizer = keras.optimizers.RMSprop(lr=lr*0.5,decay=decay*0.5)discriminator_model.trainable = Falseinputs = [inputs,labels,code1,code2]adversarial_model = keras.Model(inputs,discriminator_model(generator_model(inputs)),name=model_name)adversarial_model.compile(loss=loss,loss_weights=loss_weights,optimizer=optimizer,metrics=['acc'])adversarial_model.summary()models = (generator_model,discriminator_model,adversarial_model)data = (x_train,y_train)params = (batch_size,latent_size,train_steps,num_labels,model_name)train(models,data,params)

模型训练

def train(models,data,params):"""Train the network#Argumentsmodels (Models): generator,discriminator,adversarial modeldata (tuple): x_train,y_train dataparams (tuple): Network params"""generator,discriminator,adversarial = modelsx_train,y_train = databatch_size,latent_size,train_steps,num_labels,model_name = paramssave_interval = 500code_std = 0.5noise_input = np.random.uniform(-1.0,1.,size=[16,latent_size])noise_label = np.eye(num_labels)[np.arange(0,16) % num_labels]noise_code1 = np.random.normal(scale=code_std,size=[16,1])noise_code2 = np.random.normal(scale=code_std,size=[16,1])train_size = x_train.shape[0]print(model_name,"Labels for generated images: ",np.argmax(noise_label, axis=1))for i in range(train_steps):rand_indexes = np.random.randint(0,train_size,size=batch_size)real_images = x_train[rand_indexes]real_labels = y_train[rand_indexes]#random codes for real imagesreal_code1 = np.random.normal(scale=code_std,size=[batch_size,1])real_code2 = np.random.normal(scale=code_std,size=[batch_size,1])#生成假图片,标签和编码noise = np.random.uniform(-1.,1.,size=[batch_size,latent_size])fake_labels = np.eye(num_labels)[np.random.choice(num_labels,batch_size)]fake_code1 = np.random.normal(scale=code_std,size=[batch_size,1])fake_code2 = np.random.normal(scale=code_std,size=[batch_size,1])inputs = [noise,fake_labels,fake_code1,fake_code2]fake_images = generator.predict(inputs)x = np.concatenate((real_images,fake_images))labels = np.concatenate((real_labels,fake_labels))codes1 = np.concatenate((real_code1,fake_code1))codes2 = np.concatenate((real_code2,fake_code2))y = np.ones([2 * batch_size,1])y[batch_size:,:] = 0#train discriminator networkoutputs = [y,labels,codes1,codes2]# metrics = ['loss', 'activation_1_loss', 'label_loss',# 'code1_loss', 'code2_loss', 'activation_1_acc',# 'label_acc', 'code1_acc', 'code2_acc']metrics = discriminator.train_on_batch(x, outputs)fmt = "%d: [dis: %f, bce: %f, ce: %f, mi: %f, mi:%f, acc: %f]"log = fmt % (i, metrics[0], metrics[1], metrics[2], metrics[3], metrics[4], metrics[6])#train the adversarial networknoise = np.random.uniform(-1.,1.,size=[batch_size,latent_size])fake_labels = np.eye(num_labels)[np.random.choice(num_labels,batch_size)]fake_code1 = np.random.normal(scale=code_std,size=[batch_size,1])fake_code2 = np.random.normal(scale=code_std,size=[batch_size,1])y = np.ones([batch_size,1])inputs = [noise,fake_labels,fake_code1,fake_code2]outputs = [y,fake_labels,fake_code1,fake_code2]metrics = adversarial.train_on_batch(inputs,outputs)fmt = "%s [adv: %f, bce: %f, ce: %f, mi: %f, mi:%f, acc: %f]"log = fmt % (log, metrics[0], metrics[1], metrics[2], metrics[3], metrics[4], metrics[6])print(log)if (i + 1) % save_interval == 0:# plot generator images on a periodic basisplot_images(generator,noise_input=noise_input,noise_label=noise_label,noise_codes=[noise_code1, noise_code2],show=False,step=(i + 1),model_name=model_name)# save the modelif (i + 1) % (2 * save_interval) == 0:generator.save(model_name + ".h5")

效果展示

#绘制生成图片
def plot_images(generator,noise_input,noise_label=None,noise_codes=None,show=False,step=0,model_name="gan"):"""Generate fake images and plot themFor visualization purposes, generate fake imagesthen plot them in a square grid# Argumentsgenerator (Model): The Generator Model for fake images generationnoise_input (ndarray): Array of z-vectorsshow (bool): Whether to show plot or notstep (int): Appended to filename of the save imagesmodel_name (string): Model name"""os.makedirs(model_name, exist_ok=True)filename = os.path.join(model_name, "%05d.png" % step)rows = int(math.sqrt(noise_input.shape[0]))if noise_label is not None:noise_input = [noise_input, noise_label]if noise_codes is not None:noise_input += noise_codesimages = generator.predict(noise_input)plt.figure(figsize=(2.2, 2.2))num_images = images.shape[0]image_size = images.shape[1]for i in range(num_images):plt.subplot(rows, rows, i + 1)image = np.reshape(images[i], [image_size, image_size])plt.imshow(image, cmap='gray')plt.axis('off')plt.savefig(filename)if show:plt.show()else:plt.close('all')
#模型训练
build_and_train_models(latent_size=62)
steps = 500

steps = 500

steps = 16000

steps = 16000

修改书写角度的分离编码

修改书写角度的分离编码


http://chatgpt.dhexx.cn/article/jQMmg1Jq.shtml

相关文章

InfoGAN论文笔记+源码解析

论文地址:InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets 源码地址:InfoGAN in TensorFlow GAN,Generative Adversarial Network是目前非常火也是非常有潜力的一个发展方向&#…

InfoGAN(基于信息最大化生成对抗网的可解释表征学习)

前言: 这篇博客为阅读论文后的总结与感受,方便日后翻阅、查缺补漏,侵删! 论文: InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets 解决的问题: In…

InfoGAN学习笔记

InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets Xi Chen, Yan Duan, Rein Houthooft, John Schulman, Ilya Sutskever, Pieter Abbeel https://arxiv.org/abs/1606.03657 一、从GAN到InfoGAN 1.GAN存在的问题 GAN…

生成对抗网络(十)----------infoGAN

一、infoGAN介绍 infoGAN采用的是无监督式学习的方式并尝试实现可解释特征。原始数据不包含任何标签信息,所有的特征都是通过网络以一种非监督的方式自动学习得到的。使用了信息论的原理,通过最大化输入噪声和观察值之间的互信息来对网络模型进行优化。…

InfoGAN详细介绍及特征解耦图像生成

InfoGAN详细介绍及特征解耦图像生成 一.InfoGAN框架理解特征耦合InfoGANInfoGAN论文实验结果 二.VAE-GAN框架理解VAE-GAN算法步骤 三.BiGAN框架理解四.InfoGAN论文复现使用MNIST数据集复现InfoGAN代码编写初始化判别器初始化生成器初始化分类器训练InfoGAN网络 总结参考文献及博…

InfoGAN介绍

今天给大家分享的是NIPS2016的InfoGAN。这篇paper所要达到的目标就是通过非监督学习得到可分解的特征表示。使用GAN加上最大化生成的图片和输入编码之间的互信息。最大的好处就是可以不需要监督学习,而且不需要大量额外的计算花销就能得到可解释的特征。 通常&#…

10.可视化、可视分析、探索式数据分析

前言:基于人大的《数据科学概论》第十章。主要内容为可视化的定义、可视化的意义、可视化的一般过程、科学可视化与信息可视化、可视化的原则、可视化实例、可视化的挑战和趋势、可视分析技术、探索式数据分析、可视化工具。 一、可视化的定义 可视化是数据的可视表…

国内外大数据可视化分析产品点评

一、KINETICA Kinetica 利用图像处理芯片提供计算支持,允许企业使用机器学习,商业智能分析和可视化技术更快速地分析海量数据; 点评:它的核心技术能力是类MapD的GPU数据库,功能性能较强大,非开源&#xf…

数据分析可视化之模型介绍

一 前言 “数据分析可视化”这条术语实现了成熟的科学可视化领域与较年轻的信息可视化领域的统一。 数据分析可视化:是指将大型数据集中的数据以图形图像形式表示,并利用数据分析和开发工具发现其中未知信息的处理过程。 数据可视化已经提出了许多方法&…

5个最受欢迎的大数据可视化软件

数据可视化的本质是视觉对话,数据可视化将数据分析技术与图形技术结合,清晰有效地将分析结果信息进行解读和传达。 大数据可视化是进行各种大数据分析解决的最重要组成部分之一。 一旦原始数据流被以图像形式表示时,以此做决策就变得容易多了…

软件架构-可视化

软件架构-可视化 当我们在讨论系统时,往往都会说这个系统的架构是什么样的,在你口述的同时,如果能借助某些图表,效果会更好,传统的uml建模比较复杂,目前的软件工程大家更关注效率(这里我不谈敏捷…

48 款数据可视化分析工具大集合

作者:DC君 来源:DataCastle数据城堡 本篇专门推荐48款数据可视化工具,全到你无法想象。 1、Excel 作为一个入门级工具,是快速分析数据的理想工具,也能创建供内部使用的数据图,但是Excel在颜色、线条和样式上…

爬取某小说榜单爬虫及可视化分析

爬取某小说榜单爬虫及可视化分析(仅用于学习) gitee代码链接:https://gitee.com/huang_jia_son/duoduo.git 介绍 GUI界面python爬虫数据清洗与处理pyecharts可视化展示软件架构 (1)通过tkinter制作GUI界面&#xf…

咖啡PowerBI可视化实例

目录 一、导入数据 二、构建指标 1、构架表之间的关系 ​ 2、完善表 3、构建指标 三、可视化 1、整体分析 2、省份业务 3、产品维度 4、地区维度 5、客户维度 6、价格分析 7、利润分析 8、其它分析 一、导入数据 二、构建指标 1、构架表之间的关系 软件会…

CodeScene - 软件质量可视化工具

CodeScene - 软件质量可视化工具 CodeScene https://codescene.com/ https://codescene.io/ The powerful visualization tool using Predictive Analytics to find hidden risks and social patterns in your code. 使用 Predictive Analytics 的功能强大的可视化工具&#x…

使用excel、python、tableau对招聘数据进行数据处理及可视化分析

招聘数据数据分析及可视化 数据来源前言一、观察数据删除重复值数据加工 二、利用python进行数据分析和可视化1.引入库2.读入数据观察描述统计,了解数据大概信息 3.数据预处理3.1数据清洗3.1.1 删除重复值3.1.2缺失值处理 3.2数据加工 4.数据可视化4.1城市岗位数量4…

可视化工具软件排行榜

市面上的数据可视化工具软件如此之多,有哪些可视化软件工具居于排行榜单的前列呢?你用的软件上榜了吗? 1、FineBI 来自帆软公司,虽作一个BI工具,但是可视化效果不错,可制作Dashboard。优势在于一旦准备好…

2020年六十款数据分析的可视化工具推荐

今天小编将为大家盘点六十款数据分析的可视化工具,让你妥妥的成为会议室乃至全公司最亮的崽~ 1、ChartBlocks ChartBlocks是一款网页版的可视化图表生成工具,在线使用。通过导入电子表格或者数据库来构建可视化图表。整个过程可以在图表的向导指示下完成。它的图表在HTML…

值得推荐的13款可视化软件,快收藏!

数据可视化力求用图表结合的方式把所有的数据整合在某一图像上,这样呈现在观众眼前的画面不仅仅是美观,且比以往长篇大论或是密密麻麻的数据表格更直观易懂,更便于观察分析。到今年上半年为止,国内外已经有了很多发展的较好的数据…

深入分析ArrayMap

前面我们分析了Android为了节省内存提供的一个HahMap<Integer, ?>的替代品SparseArray。SparseArray只能替代key的类型为int的Map。Android也提供了一个key不用局限于int的Map的实现&#xff0c;ArrayMap。老规矩我们通过调试来深入的分析一下ArrayMap&#xff08;看本文…