一、infoGAN介绍
infoGAN采用的是无监督式学习的方式并尝试实现可解释特征。原始数据不包含任何标签信息,所有的特征都是通过网络以一种非监督的方式自动学习得到的。使用了信息论的原理,通过最大化输入噪声和观察值之间的互信息来对网络模型进行优化。inforGAN能够适应于各类复杂的数据集,可以同时实现离散特征和连续特征,较传统的GAN训练时间更短。
infoGAN在输入端把随机输入分为两部分:第一部分为z,代表随机噪声;第二部分为c,代表隐含编码,目标是希望在每个维度上都具备可解释特征。infoGAN与之前的GAN区别在于真实训练数据并不带有标签信息,而输入数据为隐含编码和随机噪声的组合,最后通过判别器一端和最大化互信息的方式还原隐含编码的信息。判别器D最终同时具备还原隐含编码和辨别真伪的能力。前者是为了生成图像能够很好地具备编码中的特性,也就是说隐含编码可以对生成网络产生相对显著的效果。后者是要求生成模型在还原信息的同时保证生成的数据与真实数据非常逼近。关于infoGAN的具体理论网上有很多。我觉得如果有人想研究。最重要的是看论文,结合一些博客的介绍最终弄明白这个模型。而我学习这些模型的目的是为了了解。我了解了模型大概的思想与做法。这就是我的目的。如果有一天我需要用到infoGAN。我会在去找论文。仔细研究理论。
二、infoGAN代码实现
1. 导包
from __future__ import print_function, divisionfrom keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import to_categoricalimport keras.backend as K
import matplotlib.pyplot as plt
import numpy as npimport tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
2. 初始化
class INFOGAN():def __init__(self):self.img_rows = 28self.img_cols = 28self.channels = 1self.num_classes = 10self.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = 72optimizer = Adam(0.0002, 0.5)losses = ['binary_crossentropy', self.mutual_info_loss]# 建造判别器和识别网络self.discriminator, self.auxilliary = self.build_disk_and_q_net()# 编译判别网络self.discriminator.compile(loss=['binary_crossentropy'],optimizer=optimizer,metrics=['accuracy'])# 编译识别网络self.auxilliary.compile(loss=[self.mutual_info_loss],optimizer=optimizer,metrics=['accuracy'])# 建造生成器self.generator = self.build_generator()# 生成器的输入为噪声和标签gen_input = Input(shape=(self.latent_dim,))img = self.generator(gen_input)# 固定判别器,只训练生成器self.discriminator.trainable = False# 判别器结果valid = self.discriminator(img)# 识别网络结果target_label = self.auxilliary(img)# 生成器和判别器的组合模型self.combined = Model(gen_input, [valid, target_label])self.combined.compile(loss=losses,optimizer=optimizer)
3. 损失函数
def mutual_info_loss(self, c, c_given_x):eps = 1e-8conditional_entropy = K.mean(-K.sum(K.log(c_given_x + eps) * c, axis=1))entropy = K.mean(-K.sum(K.log(c + eps) * c, axis=1))return conditional_entropy + entropy
4. 生成器
def build_generator(self):model = Sequential()model.add(Dense(128 * 7 * 7, activation='relu', input_dim=self.latent_dim))model.add(Reshape((7, 7, 128)))model.add(BatchNormalization(momentum=0.8))model.add(UpSampling2D())model.add(Conv2D(128, kernel_size=3, padding='same'))model.add(Activation('relu'))model.add(BatchNormalization(momentum=0.8))model.add(UpSampling2D())model.add(Conv2D(64,kernel_size=3, padding='same'))model.add(Activation('relu'))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(self.channels, kernel_size=3, padding='same'))model.add(Activation('tanh'))gen_input = Input(shape=(self.latent_dim,))img = model(gen_input)model.summary()return Model(gen_input, img)
5. 判别器和分类器
def build_disk_and_q_net(self):img = Input(shape=self.img_shape)# 判别器和识别网络的共享层model = Sequential()model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.img_shape))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(256, kernel_size=3, strides=2, padding='same'))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Conv2D(512, kernel_size=3, strides=2, padding='same'))model.add(LeakyReLU(alpha=0.2))model.add(Dropout(0.25))model.add(BatchNormalization(momentum=0.8))model.add(Flatten())img_embedding = model(img)# 判别器validity = Dense(1, activation='sigmoid')(img_embedding)# 识别q_net = Dense(128, activation='relu')(img_embedding)label = Dense(self.num_classes, activation='softmax')(q_net)return Model(img, validity), Model(img, label)
6. 生成样本输入
def sample_generator_input(self, batch_size):# 生成器输入sampled_noise = np.random.normal(0, 1, (batch_size, 62))sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)sampled_labels = to_categorical(sampled_labels, num_classes=self.num_classes)return sampled_noise, sampled_labels
7. 训练
def train(self, epochs, batch_size=128, sample_interval=50):(X_train, y_train), (_, _) = mnist.load_data()X_train = (X_train.astype(np.float32) - 127.5) / 127.5X_train = np.expand_dims(X_train, axis=3)y_train = y_train.reshape(-1, 1)# 真实数据valid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):'''训练判别器'''idx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]# 噪声样本和分类标签sampled_noise, sampled_labels = self.sample_generator_input(batch_size)gen_input = np.concatenate((sampled_noise, sampled_labels), axis=1)# 生成新图像的一半梯度gen_imgs = self.generator.predict(gen_input)# 在真实数据和生成数据上训练d_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)# 平均损失d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)'''训练生成器'''g_loss = self.combined.train_on_batch(gen_input, [valid, sampled_labels])print("%d [D loss: %.2f, acc.: %.2f%%] [Q loss: %.2f] [G loss: %.2f" % (epoch, d_loss[0], 100 * d_loss[1], g_loss[1], g_loss[2]))if epoch % 50 == 0:self.save_model()self.sample_images(epoch)
8. 显示样本
def sample_images(self, epoch):r, c = 10, 10fig, axs = plt.subplots(r, c)for i in range(r):sample_noise, _ = self.sample_generator_input(c)label = to_categorical(np.full(fill_value=i, shape=(r, 1)), num_classes=self.num_classes)gen_input = np.concatenate((sample_noise, label), axis=1)gen_imgs = self.generator.predict(gen_input)gen_imgs = 0.5 * gen_imgs + 0.5for j in range(r):axs[j, i].imshow(gen_imgs[j, :, :, 0], cmap='gray')axs[j, i].axis('off')fig.savefig("images/%d.png" % epoch)plt.close()
9. 保存模型
def save_model(self):def save(model, model_name):model_path = "saved_model/%s.json" % model_nameweights_path = "saved_model/%s_weights.hdf5" % model_nameoptions = {"file_arch": model_path,"file_weight": weights_path}json_string = model.to_json()open(options['file_arch'], 'w').write(json_string)model.save_weights(options['file_weight'])save(self.generator, "generator")save(self.discriminator, "discriminator")
10. 训练
if __name__ == '__main__':infogan = INFOGAN()infogan.train(epochs=50000, batch_size=128, sample_interval=50)
结果:1600代的结果。还没有训练到最好的结果。