深度学习《InfoGAN模型》

article/2025/3/3 20:01:05

一:网络介绍
普通的GAN网络的特点是无约束,对网络输入的噪声也不好解释,CGAN中我们通过给噪声合并一些类别数据,改变了输出形式,可以训练出我们指定类别的数据,这一点也是某种程度的解释,但是解释性不强。
InfoGAN 主要特点是对GAN进行了一些改动,成功地让网络学到了可解释的特征,网络训练完成之后,我们可以通过设定输入生成器的隐含编码来控制生成数据的特征。
InfoGAN将输入生成器的随机噪声分成了两部分:一部分是随机噪声Z, 另一部分是由若干隐变量拼接而成的latent code c。其中,c会有先验的概率分布,可以是离散数据,也可以是连续数据,用来代表生成数据的不同特征。例如:对于MNIST数据集,c既包含离散部分也包含了连续部分,离散部分取值为0~9的离散随机变量(表示数字的类别),连续部分有两个连续型随机变量(分别表示倾斜度和粗细度)。其网络结构如下图:
在这里插入图片描述

其中,真实数据Real_data只是用来跟生成的Fake_data混合在一起进行真假判断,并根据判断的结果更新生成器和判别器,从而使生成的数据与真实数据接近。生成数据既要参与真假判断,还需要和隐变量C_vector求互信息,并根据互信息更新生成器和判别器,从而使得生成图像中保留了更多隐变量C_vector的信息。
InfoGAN网络结构还可以看成是如下形式:
在这里插入图片描述

G网络相当于是encoder,Q网络相当于是decoder,整个红色框框就是一个编码器结构,生成数据Fake_data相当于对输入隐变量C_vector的编码,只不过将编码还要输出给D网络去判别。其中和关键的一点是,判别器D和Q共用所有卷积层,只是最后的全连接层不同。

二:详细分析各个网络:

G网络:除了噪声z,还需要增加latent code(有离散数据和连续数据)。
D网络:正常输入,和Q共享卷积层,输出有1维的向量,判断是fake or true,
Q网络:也就是D网络,只不过输出经过两个不同的FC层,维度和latent code维度一致。

这里直接用pytorch代码过程来分析了。

1:对D来说:
判别器D的输入为:(batch_size, channel, img_ size, img_size),判别器D的输出为:(batch_size, 1)

优化过程是:

optimizer_D.zero_grad()  # 梯度清零# Loss for real images
d_real_pred, _, _ = discriminator(real_imgs)# Loss for fake images
gen_imgs = generator(z_noise, label_input, code_input).detach()
d_fake_pred, _, _ = discriminator(gen_imgs)# Total discriminator loss
d_loss = discriminator_loss(d_real_pred, d_fake_pred)  # 判别器的 loss
d_loss.backward()
optimizer_D.step()

其中discriminator_loss是:

def discriminator_loss(logits_real, logits_fake):  # 判别器的 losssize = logits_real.shape[0]true_labels = Variable(torch.ones(size, 1)).float()  # 和1作对比size = logits_fake.shape[0]false_labels = Variable(torch.zeros(size, 1)).float()  # 和0作对比loss = validity_loss(logits_real, true_labels) + validity_loss(logits_fake, false_labels)return loss

2:对G来说:
生成器G的输入为:(batch_size, noise_dim + discrete_dim + continuous_dim),其中noise_dim为输入噪声的维度,discrete_dim为离散隐变量的维度,continuous_dim为连续隐变量的维度。生成器G的输出为(batch_size, channel, img_size, img_size)

优化过程是:

optimizer_G.zero_grad()  # 梯度清零# 假的图片去欺骗D,让D误认为是真的。
gen_imgs = generator(z_noise, label_input, code_input)
g_real_pred, _, _ = discriminator(gen_imgs)
g_loss = generator_loss(g_real_pred)  # 生成网络的 loss
g_loss.backward()
optimizer_G.step()

其中generator_loss是:

def generator_loss(logits_fake):  # 生成器的 loss size = logits_fake.shape[0]true_labels = Variable(torch.ones(size, 1)).float()  #和1作对比loss = validity_loss(logits_fake, true_labels)return loss

3:对Q来说:
判别器Q的输入为:(batch_size, channel, img_size, img_size),Q的输出为:(batch_size, discrete_dim + continuous_dim)

optimizer_Q.zero_grad()gen_imgs = generator(z_noise, label_input, code_input)
_, pred_label, pred_code = discriminator(gen_imgs)info_loss = discrete_loss(pred_label, label_input) + continuous_loss(pred_code, code_input)
info_loss.backward()
optimizer_Q.step()

其中 optimizer_Q 是:

optimizer_Q = torch.optim.Adam(itertools.chain(generator.parameters(), discriminator.parameters()), lr=opt.lr, betas=(opt.beta_1, opt.beta_2)
)  # Q 就是多出来的那两个个FC网络,D和Q共用所有卷积层,只是最后的全连接层不同。

三:完整实例
种类还是用 MNIST数据集做测试,每一步骤都是有清晰的注释说明。

import argparse
import os
import numpy as np
import math
import itertoolsimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision.datasets import MNIST# step =========================   初始化参数   ===========
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=32, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--beta_1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--beta_2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")parser.add_argument("--noise_dim", type=int, default=62, help="dimensionality of the latent space")  # 原始噪声的维度
parser.add_argument("--code_discrete_dim", type=int, default=10, help="number of classes for dataset")  # 离散变量维度,这里是使用数字的类别
parser.add_argument("--code_continuous_dim", type=int, default=2, help="latent code")  # 连续变量的维度,假定是2维parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)# step =========================   加载MNIST数据   ===========
train_set = MNIST('./data', train=True, transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]))
train_data = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True)def deprocess_img(img):out = 0.5 * (img + 1)out = out.clamp(0, 1)out = out.view(-1, 1, 28, 28)return out# step =========================   定义模型   ===========
# 初始化参数的函数
def weights_init_normal(m):class_name = m.__class__.__name__if class_name.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif class_name.find("BatchNorm") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()input_dim = opt.noise_dim + opt.code_continuous_dim + opt.code_discrete_dimself.init_size = opt.img_size // 4  # Initial size before upsamplingself.l1 = nn.Sequential(nn.Linear(input_dim, 128 * self.init_size ** 2))self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, noise, labels, code):z = np.concatenate((noise, labels, code), axis=1)z = Variable(torch.from_numpy(z).float())out = self.l1(z)out = out.view(out.shape[0], 128, self.init_size, self.init_size)img = self.conv_blocks(out)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()def discriminator_block(in_channels, out_channels, bn=True):"""Returns layers of each discriminator block"""block = [nn.Conv2d(in_channels, out_channels, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_channels, 0.8))return block# 共享卷积层self.conv_blocks = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# Output layer,最后输出的FC 层是不同的。最后一层FCself.valid_fc_layer = nn.Sequential(nn.Linear(512, 1))self.discrete_fc_layer = nn.Sequential(nn.Linear(512, opt.code_discrete_dim), nn.Softmax())self.continuous_fc_layer = nn.Sequential(nn.Linear(512, opt.code_continuous_dim))def forward(self, img):# 共享 Conv 层out = self.conv_blocks(img)out = out.view(out.shape[0], -1)# FC 层,输入都是共享 Conv 层validity_val = self.valid_fc_layer(out)  # fake image? : 0   / real image? :  1discrete_val = self.discrete_fc_layer(out)  # 离散的输出continuous_val = self.continuous_fc_layer(out)  # 连续的输出return validity_val, discrete_val, continuous_val# 实例化 generator and discriminator
generator = Generator()
discriminator = Discriminator()# 初始化各自模型的参数权重
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# step =========================   定义损失函数和优化器   ===========
# Loss functions
validity_loss = torch.nn.MSELoss()  # real or fake
discrete_loss = torch.nn.BCELoss()  # 离散输入的输出的损失函数
continuous_loss = torch.nn.MSELoss()  # 连续输入的输出的损失函数def discriminator_loss(logits_real, logits_fake):  # 判别器的 losssize = logits_real.shape[0]true_labels = Variable(torch.ones(size, 1)).float()size = logits_fake.shape[0]false_labels = Variable(torch.zeros(size, 1)).float()loss = (validity_loss(logits_real, true_labels) + validity_loss(logits_fake, false_labels)) / 2return lossdef generator_loss(logits_fake):  # 生成器的 losssize = logits_fake.shape[0]true_labels = Variable(torch.ones(size, 1)).float()loss = validity_loss(logits_fake, true_labels)return loss# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.beta_1, opt.beta_2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.beta_1, opt.beta_2))
optimizer_Q = torch.optim.Adam(itertools.chain(generator.parameters(), discriminator.parameters()), lr=opt.lr, betas=(opt.beta_1, opt.beta_2)
)  # Q 就是多出来的那两个个FC网络,D和Q共用所有卷积层,只是最后的全连接层不同。# step =========================   开始训练   ===========# 得到 one-hot 向量的函数
def get_onehot_vector(label, label_dim):labels_onehot = np.zeros((label.shape[0], label_dim))labels_onehot[np.arange(label.shape[0]), label.numpy()] = 1return Variable(torch.FloatTensor(labels_onehot))iter_count = 0
show_every = 50# those is for test
os.makedirs("D:/software/Anaconda3/doc/3D_Img/inforgan/", exist_ok=True)
batch_size = 10
test_z_noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (batch_size, opt.noise_dim))))
test_label_input = get_onehot_vector(torch.from_numpy(np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])), opt.code_discrete_dim)
test_code_input = Variable(torch.FloatTensor(np.zeros((batch_size, opt.code_continuous_dim))))for epoch in range(opt.n_epochs):for i, (real_imgs, labels) in enumerate(train_data):# ---------------------------------------------------------------#  prepare data# ---------------------------------------------------------------batch_size = real_imgs.shape[0]  # 获取 batch_size# 生成随机噪声数据,正态分布随机采样z_noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (batch_size, opt.noise_dim))))# 得到当前离散数据,用数字的类别作为离散数据输入label_input = get_onehot_vector(labels, opt.code_discrete_dim)# 离散数据输入,均值采样code_input = Variable(torch.FloatTensor(np.random.uniform(-1, 1, (batch_size, opt.code_continuous_dim))))# ---------------------------------------------------------------#  Train Discriminator# ---------------------------------------------------------------optimizer_D.zero_grad()  # 梯度清零# Loss for real imagesd_real_pred, _, _ = discriminator(real_imgs)# Loss for fake imagesgen_imgs = generator(z_noise, label_input, code_input).detach()d_fake_pred, _, _ = discriminator(gen_imgs)# Total discriminator lossd_loss = discriminator_loss(d_real_pred, d_fake_pred)  # 判别器的 lossd_loss.backward()optimizer_D.step()if i % 2 == 0 :# ---------------------------------------------------------------#  Train Generator# ---------------------------------------------------------------optimizer_G.zero_grad()  # 梯度清零# 假的图片去欺骗D,让D误认为是真的。gen_imgs = generator(z_noise, label_input, code_input)g_real_pred, _, _ = discriminator(gen_imgs)g_loss = generator_loss(g_real_pred)  # 生成网络的 lossg_loss.backward()optimizer_G.step()# ---------------------------------------------------------------# Information Loss# ---------------------------------------------------------------optimizer_Q.zero_grad()gen_imgs = generator(z_noise, label_input, code_input)_, pred_label, pred_code = discriminator(gen_imgs)info_loss = discrete_loss(pred_label, label_input) + 0.2 * continuous_loss(pred_code, code_input)info_loss.backward()optimizer_Q.step()# ---------------------------------------------------------------# test to output some images.# To do another procession.# ---------------------------------------------------------------print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [info loss: %f]"% (epoch, opt.n_epochs, i, len(train_data), d_loss.item(), g_loss.item(), info_loss.item()))if (iter_count % show_every == 0):fake_img = generator(test_z_noise, test_label_input, test_code_input)  # 将向量放入生成网络G生成一张图片#real_images = deprocess_img(fake_img.data)save_image(fake_img.data, 'D:/software/Anaconda3/doc/3D_Img/inforgan/test_%d.png' % (iter_count))iter_count += 1

http://chatgpt.dhexx.cn/article/QDh2FVZN.shtml

相关文章

InfoGAN原理PyTorch实现Debug记录

1. CGAN从无监督GAN改进成有监督的GAN GAN的基本原理输入是随机噪声,无法控制输出和输入之间的对应关系,也无法控制输出的模式,CGAN全称是条件GAN(Conditional GAN)改进基本的GAN解决了这个问题,CGAN和基本…

CGAN和InfoGAN理解

在一些比较经典的GAN模型中(像WGAN、LSGAN、DCGAN等等),往往都是从样本空间里随机采样得到输入噪声,生成的图像究竟属于哪一个类别也是随机的。通过这些模型,我们无法生成指定类别的数据。 举个不恰当的例子&#xff…

InfoGAN

InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets Xi Chen, Yan Duan, Rein Houthooft, John Schulman, Ilya Sutskever, Pieter Abbeel https://arxiv.org/abs/1606.03657 一、从GAN到InfoGAN 1.GAN存在的问题 GAN通…

InfoGAN 生成时序序列

InfoGAN 生成时序序列 简介 完整代码:https://github.com/SongDark/timeseries_infogan 本文介绍用InfoGAN生成多维时序序列。 数据 数据集下载地址 NameClassDimensionTrain SizeTest SizeTruncatedCharacterTrajectories20314221436182 样本介绍 CharacterTr…

InfoGAN详解与实现(采用tensorflow2.x实现)

InfoGAN详解与实现(采用tensorflow2.x实现) InfoGAN原理InfoGAN实现导入必要库生成器鉴别器模型构建模型训练效果展示 InfoGAN原理 最初的GAN能够产生有意义的输出,但是缺点是它的属性无法控制。例如,无法明确向生成器提出生成女…

InfoGAN论文笔记+源码解析

论文地址:InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets 源码地址:InfoGAN in TensorFlow GAN,Generative Adversarial Network是目前非常火也是非常有潜力的一个发展方向&#…

InfoGAN(基于信息最大化生成对抗网的可解释表征学习)

前言: 这篇博客为阅读论文后的总结与感受,方便日后翻阅、查缺补漏,侵删! 论文: InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets 解决的问题: In…

InfoGAN学习笔记

InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets Xi Chen, Yan Duan, Rein Houthooft, John Schulman, Ilya Sutskever, Pieter Abbeel https://arxiv.org/abs/1606.03657 一、从GAN到InfoGAN 1.GAN存在的问题 GAN…

生成对抗网络(十)----------infoGAN

一、infoGAN介绍 infoGAN采用的是无监督式学习的方式并尝试实现可解释特征。原始数据不包含任何标签信息,所有的特征都是通过网络以一种非监督的方式自动学习得到的。使用了信息论的原理,通过最大化输入噪声和观察值之间的互信息来对网络模型进行优化。…

InfoGAN详细介绍及特征解耦图像生成

InfoGAN详细介绍及特征解耦图像生成 一.InfoGAN框架理解特征耦合InfoGANInfoGAN论文实验结果 二.VAE-GAN框架理解VAE-GAN算法步骤 三.BiGAN框架理解四.InfoGAN论文复现使用MNIST数据集复现InfoGAN代码编写初始化判别器初始化生成器初始化分类器训练InfoGAN网络 总结参考文献及博…

InfoGAN介绍

今天给大家分享的是NIPS2016的InfoGAN。这篇paper所要达到的目标就是通过非监督学习得到可分解的特征表示。使用GAN加上最大化生成的图片和输入编码之间的互信息。最大的好处就是可以不需要监督学习,而且不需要大量额外的计算花销就能得到可解释的特征。 通常&#…

10.可视化、可视分析、探索式数据分析

前言:基于人大的《数据科学概论》第十章。主要内容为可视化的定义、可视化的意义、可视化的一般过程、科学可视化与信息可视化、可视化的原则、可视化实例、可视化的挑战和趋势、可视分析技术、探索式数据分析、可视化工具。 一、可视化的定义 可视化是数据的可视表…

国内外大数据可视化分析产品点评

一、KINETICA Kinetica 利用图像处理芯片提供计算支持,允许企业使用机器学习,商业智能分析和可视化技术更快速地分析海量数据; 点评:它的核心技术能力是类MapD的GPU数据库,功能性能较强大,非开源&#xf…

数据分析可视化之模型介绍

一 前言 “数据分析可视化”这条术语实现了成熟的科学可视化领域与较年轻的信息可视化领域的统一。 数据分析可视化:是指将大型数据集中的数据以图形图像形式表示,并利用数据分析和开发工具发现其中未知信息的处理过程。 数据可视化已经提出了许多方法&…

5个最受欢迎的大数据可视化软件

数据可视化的本质是视觉对话,数据可视化将数据分析技术与图形技术结合,清晰有效地将分析结果信息进行解读和传达。 大数据可视化是进行各种大数据分析解决的最重要组成部分之一。 一旦原始数据流被以图像形式表示时,以此做决策就变得容易多了…

软件架构-可视化

软件架构-可视化 当我们在讨论系统时,往往都会说这个系统的架构是什么样的,在你口述的同时,如果能借助某些图表,效果会更好,传统的uml建模比较复杂,目前的软件工程大家更关注效率(这里我不谈敏捷…

48 款数据可视化分析工具大集合

作者:DC君 来源:DataCastle数据城堡 本篇专门推荐48款数据可视化工具,全到你无法想象。 1、Excel 作为一个入门级工具,是快速分析数据的理想工具,也能创建供内部使用的数据图,但是Excel在颜色、线条和样式上…

爬取某小说榜单爬虫及可视化分析

爬取某小说榜单爬虫及可视化分析(仅用于学习) gitee代码链接:https://gitee.com/huang_jia_son/duoduo.git 介绍 GUI界面python爬虫数据清洗与处理pyecharts可视化展示软件架构 (1)通过tkinter制作GUI界面&#xf…

咖啡PowerBI可视化实例

目录 一、导入数据 二、构建指标 1、构架表之间的关系 ​ 2、完善表 3、构建指标 三、可视化 1、整体分析 2、省份业务 3、产品维度 4、地区维度 5、客户维度 6、价格分析 7、利润分析 8、其它分析 一、导入数据 二、构建指标 1、构架表之间的关系 软件会…