生成对抗网络(十)----------infoGAN

article/2025/3/4 11:30:37

一、infoGAN介绍

infoGAN采用的是无监督式学习的方式并尝试实现可解释特征。原始数据不包含任何标签信息,所有的特征都是通过网络以一种非监督的方式自动学习得到的。使用了信息论的原理,通过最大化输入噪声和观察值之间的互信息来对网络模型进行优化。inforGAN能够适应于各类复杂的数据集,可以同时实现离散特征和连续特征,较传统的GAN训练时间更短。

infoGAN在输入端把随机输入分为两部分:第一部分为z,代表随机噪声;第二部分为c,代表隐含编码,目标是希望在每个维度上都具备可解释特征。infoGAN与之前的GAN区别在于真实训练数据并不带有标签信息,而输入数据为隐含编码和随机噪声的组合,最后通过判别器一端和最大化互信息的方式还原隐含编码的信息。判别器D最终同时具备还原隐含编码和辨别真伪的能力。前者是为了生成图像能够很好地具备编码中的特性,也就是说隐含编码可以对生成网络产生相对显著的效果。后者是要求生成模型在还原信息的同时保证生成的数据与真实数据非常逼近。关于infoGAN的具体理论网上有很多。我觉得如果有人想研究。最重要的是看论文,结合一些博客的介绍最终弄明白这个模型。而我学习这些模型的目的是为了了解。我了解了模型大概的思想与做法。这就是我的目的。如果有一天我需要用到infoGAN。我会在去找论文。仔细研究理论。

二、infoGAN代码实现

1. 导包

from __future__ import print_function, divisionfrom keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import to_categoricalimport keras.backend as K
import matplotlib.pyplot as plt
import numpy as npimport tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)

2. 初始化

class INFOGAN():def __init__(self):self.img_rows = 28self.img_cols = 28self.channels = 1self.num_classes = 10self.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = 72optimizer = Adam(0.0002, 0.5)losses = ['binary_crossentropy', self.mutual_info_loss]# 建造判别器和识别网络self.discriminator, self.auxilliary = self.build_disk_and_q_net()# 编译判别网络self.discriminator.compile(loss=['binary_crossentropy'],optimizer=optimizer,metrics=['accuracy'])# 编译识别网络self.auxilliary.compile(loss=[self.mutual_info_loss],optimizer=optimizer,metrics=['accuracy'])# 建造生成器self.generator = self.build_generator()# 生成器的输入为噪声和标签gen_input = Input(shape=(self.latent_dim,))img = self.generator(gen_input)# 固定判别器,只训练生成器self.discriminator.trainable = False# 判别器结果valid = self.discriminator(img)# 识别网络结果target_label = self.auxilliary(img)# 生成器和判别器的组合模型self.combined = Model(gen_input, [valid, target_label])self.combined.compile(loss=losses,optimizer=optimizer)

3. 损失函数

   def mutual_info_loss(self, c, c_given_x):eps = 1e-8conditional_entropy = K.mean(-K.sum(K.log(c_given_x + eps) * c, axis=1))entropy = K.mean(-K.sum(K.log(c + eps) * c, axis=1))return conditional_entropy + entropy

4. 生成器

    def build_generator(self):model = Sequential()model.add(Dense(128 * 7 * 7, activation='relu', input_dim=self.latent_dim))model.add(Reshape((7, 7, 128)))model.add(BatchNormalization(momentum=0.8))model.add(UpSampling2D())model.add(Conv2D(128, kernel_size=3, padding='same'))model.add(Activation('relu'))model.add(BatchNormalization(momentum=0.8))model.add(UpSampling2D())model.add(Conv2D(64,kernel_size=3, padding='same'))model.add(Activation('relu'))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(self.channels, kernel_size=3, padding='same'))model.add(Activation('tanh'))gen_input = Input(shape=(self.latent_dim,))img = model(gen_input)model.summary()return Model(gen_input, img)

5. 判别器和分类器

    def build_disk_and_q_net(self):img = Input(shape=self.img_shape)# 判别器和识别网络的共享层model = Sequential()model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.img_shape))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(256, kernel_size=3, strides=2, padding='same'))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(512, kernel_size=3, strides=2, padding='same'))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Flatten())img_embedding = model(img)# 判别器validity = Dense(1, activation='sigmoid')(img_embedding)# 识别q_net = Dense(128, activation='relu')(img_embedding)label = Dense(self.num_classes, activation='softmax')(q_net)return Model(img, validity), Model(img, label)

6. 生成样本输入

    def sample_generator_input(self, batch_size):# 生成器输入sampled_noise = np.random.normal(0, 1, (batch_size, 62))sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)sampled_labels = to_categorical(sampled_labels, num_classes=self.num_classes)return sampled_noise, sampled_labels

7. 训练

    def train(self, epochs, batch_size=128, sample_interval=50):(X_train, y_train), (_, _) = mnist.load_data()X_train = (X_train.astype(np.float32) - 127.5) / 127.5X_train = np.expand_dims(X_train, axis=3)y_train = y_train.reshape(-1, 1)# 真实数据valid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):'''训练判别器'''idx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]# 噪声样本和分类标签sampled_noise, sampled_labels = self.sample_generator_input(batch_size)gen_input = np.concatenate((sampled_noise, sampled_labels), axis=1)# 生成新图像的一半梯度gen_imgs = self.generator.predict(gen_input)# 在真实数据和生成数据上训练d_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)# 平均损失d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)'''训练生成器'''g_loss = self.combined.train_on_batch(gen_input, [valid, sampled_labels])print("%d [D loss: %.2f, acc.: %.2f%%] [Q loss: %.2f] [G loss: %.2f" % (epoch, d_loss[0], 100 * d_loss[1], g_loss[1], g_loss[2]))if epoch % 50 == 0:self.save_model()self.sample_images(epoch)

8. 显示样本

    def sample_images(self, epoch):r, c = 10, 10fig, axs = plt.subplots(r, c)for i in range(r):sample_noise, _ = self.sample_generator_input(c)label = to_categorical(np.full(fill_value=i, shape=(r, 1)), num_classes=self.num_classes)gen_input = np.concatenate((sample_noise, label), axis=1)gen_imgs = self.generator.predict(gen_input)gen_imgs = 0.5 * gen_imgs + 0.5for j in range(r):axs[j, i].imshow(gen_imgs[j, :, :, 0], cmap='gray')axs[j, i].axis('off')fig.savefig("images/%d.png" % epoch)plt.close()

9. 保存模型

    def save_model(self):def save(model, model_name):model_path = "saved_model/%s.json" % model_nameweights_path = "saved_model/%s_weights.hdf5" % model_nameoptions = {"file_arch": model_path,"file_weight": weights_path}json_string = model.to_json()open(options['file_arch'], 'w').write(json_string)model.save_weights(options['file_weight'])save(self.generator, "generator")save(self.discriminator, "discriminator")

10. 训练

if __name__ == '__main__':infogan = INFOGAN()infogan.train(epochs=50000, batch_size=128, sample_interval=50)

结果:1600代的结果。还没有训练到最好的结果。

 


http://chatgpt.dhexx.cn/article/bgsEv6UC.shtml

相关文章

InfoGAN详细介绍及特征解耦图像生成

InfoGAN详细介绍及特征解耦图像生成 一.InfoGAN框架理解特征耦合InfoGANInfoGAN论文实验结果 二.VAE-GAN框架理解VAE-GAN算法步骤 三.BiGAN框架理解四.InfoGAN论文复现使用MNIST数据集复现InfoGAN代码编写初始化判别器初始化生成器初始化分类器训练InfoGAN网络 总结参考文献及博…

InfoGAN介绍

今天给大家分享的是NIPS2016的InfoGAN。这篇paper所要达到的目标就是通过非监督学习得到可分解的特征表示。使用GAN加上最大化生成的图片和输入编码之间的互信息。最大的好处就是可以不需要监督学习,而且不需要大量额外的计算花销就能得到可解释的特征。 通常&#…

10.可视化、可视分析、探索式数据分析

前言:基于人大的《数据科学概论》第十章。主要内容为可视化的定义、可视化的意义、可视化的一般过程、科学可视化与信息可视化、可视化的原则、可视化实例、可视化的挑战和趋势、可视分析技术、探索式数据分析、可视化工具。 一、可视化的定义 可视化是数据的可视表…

国内外大数据可视化分析产品点评

一、KINETICA Kinetica 利用图像处理芯片提供计算支持,允许企业使用机器学习,商业智能分析和可视化技术更快速地分析海量数据; 点评:它的核心技术能力是类MapD的GPU数据库,功能性能较强大,非开源&#xf…

数据分析可视化之模型介绍

一 前言 “数据分析可视化”这条术语实现了成熟的科学可视化领域与较年轻的信息可视化领域的统一。 数据分析可视化:是指将大型数据集中的数据以图形图像形式表示,并利用数据分析和开发工具发现其中未知信息的处理过程。 数据可视化已经提出了许多方法&…

5个最受欢迎的大数据可视化软件

数据可视化的本质是视觉对话,数据可视化将数据分析技术与图形技术结合,清晰有效地将分析结果信息进行解读和传达。 大数据可视化是进行各种大数据分析解决的最重要组成部分之一。 一旦原始数据流被以图像形式表示时,以此做决策就变得容易多了…

软件架构-可视化

软件架构-可视化 当我们在讨论系统时,往往都会说这个系统的架构是什么样的,在你口述的同时,如果能借助某些图表,效果会更好,传统的uml建模比较复杂,目前的软件工程大家更关注效率(这里我不谈敏捷…

48 款数据可视化分析工具大集合

作者:DC君 来源:DataCastle数据城堡 本篇专门推荐48款数据可视化工具,全到你无法想象。 1、Excel 作为一个入门级工具,是快速分析数据的理想工具,也能创建供内部使用的数据图,但是Excel在颜色、线条和样式上…

爬取某小说榜单爬虫及可视化分析

爬取某小说榜单爬虫及可视化分析(仅用于学习) gitee代码链接:https://gitee.com/huang_jia_son/duoduo.git 介绍 GUI界面python爬虫数据清洗与处理pyecharts可视化展示软件架构 (1)通过tkinter制作GUI界面&#xf…

咖啡PowerBI可视化实例

目录 一、导入数据 二、构建指标 1、构架表之间的关系 ​ 2、完善表 3、构建指标 三、可视化 1、整体分析 2、省份业务 3、产品维度 4、地区维度 5、客户维度 6、价格分析 7、利润分析 8、其它分析 一、导入数据 二、构建指标 1、构架表之间的关系 软件会…

CodeScene - 软件质量可视化工具

CodeScene - 软件质量可视化工具 CodeScene https://codescene.com/ https://codescene.io/ The powerful visualization tool using Predictive Analytics to find hidden risks and social patterns in your code. 使用 Predictive Analytics 的功能强大的可视化工具&#x…

使用excel、python、tableau对招聘数据进行数据处理及可视化分析

招聘数据数据分析及可视化 数据来源前言一、观察数据删除重复值数据加工 二、利用python进行数据分析和可视化1.引入库2.读入数据观察描述统计,了解数据大概信息 3.数据预处理3.1数据清洗3.1.1 删除重复值3.1.2缺失值处理 3.2数据加工 4.数据可视化4.1城市岗位数量4…

可视化工具软件排行榜

市面上的数据可视化工具软件如此之多,有哪些可视化软件工具居于排行榜单的前列呢?你用的软件上榜了吗? 1、FineBI 来自帆软公司,虽作一个BI工具,但是可视化效果不错,可制作Dashboard。优势在于一旦准备好…

2020年六十款数据分析的可视化工具推荐

今天小编将为大家盘点六十款数据分析的可视化工具,让你妥妥的成为会议室乃至全公司最亮的崽~ 1、ChartBlocks ChartBlocks是一款网页版的可视化图表生成工具,在线使用。通过导入电子表格或者数据库来构建可视化图表。整个过程可以在图表的向导指示下完成。它的图表在HTML…

值得推荐的13款可视化软件,快收藏!

数据可视化力求用图表结合的方式把所有的数据整合在某一图像上,这样呈现在观众眼前的画面不仅仅是美观,且比以往长篇大论或是密密麻麻的数据表格更直观易懂,更便于观察分析。到今年上半年为止,国内外已经有了很多发展的较好的数据…

深入分析ArrayMap

前面我们分析了Android为了节省内存提供的一个HahMap<Integer, ?>的替代品SparseArray。SparseArray只能替代key的类型为int的Map。Android也提供了一个key不用局限于int的Map的实现&#xff0c;ArrayMap。老规矩我们通过调试来深入的分析一下ArrayMap&#xff08;看本文…

ArrayMAP介绍

它不是一个适应大数据的数据结构&#xff0c;相比传统的HashMap速度要慢&#xff0c;因为查找方法是二分法&#xff0c;并且当你删除或者添加数据时&#xff0c;会对空间重新调整&#xff0c;在使用大量数据时&#xff0c;效率并不明显&#xff0c;低于50%。 ArrayMap is a ge…

Android特别的数据结构(二)ArrayMap源码解析

1. 数据结构 public final class ArrayMap<K,V> implements Map<K,V> 由两个数组组成&#xff0c;一个int[] mHashes用来存放Key的hash值&#xff0c;一个Object[] mArrays用来连续存放成对的Key和ValuemHashes数组按非严格升序排列初始默认容量为0减容&#xff…

ArrayMap 源码的详细解析

最近在写framework层的系统服务&#xff0c;发现Android 12中用来去重注册监听的map都是用的ArrayMap&#xff0c;因此仔细研究了ArrayMap的原理。 目录 一. ArrayMap概述 二. ArrayMap源码解析 1.主要包含的成员变量 2.构造函数 3. public boolean containsKey(Object ke…

SparseArray和ArrayMap

首先我们来介绍一下HashMap&#xff0c;了解它的优缺点&#xff0c;然后再对比一下其他的数据结构以及为什么要替代它。 HashMap HashMap是由数组单向链表的方式组成的&#xff0c;初始大小是16&#xff08;2的4次方&#xff09;&#xff0c;首次put的时候&#xff0c;才会真…