深度学习《CGAN模型》

article/2025/11/5 18:35:26

一:介绍
CGAN全程是Conditional Generative Adversarial Network,回想一下,传统的GAN或者其他的GAN都是通过一堆的训练数据,最后训练出了G网络,随机输入噪声最后产生的数据是这些训练数据类别中之一,我们提前无法预测是那哪一个?

因此,我们有的时候需要定向指定生成某些数据,比如我们想让G生成飞机,数字9,等等的图片数据。

怎么做呢:
1:就是给网络的输入噪声数据增加一些类别上的信息,就是说给定某些类别条件下,生成指定的数据,所以输入数据会有一些变化;

2:然后在损失函数那里,我们目标不再是输出1/0,也就是不再是简单的输出真实和构造。当判定是真实数据的时候,还需要判定出是哪一类别的图片。一般使用one-hot表示。
在这里插入图片描述

上图表示,改变输入噪声数据,给z增加类别y信息,怎么增加呢,就是简单的维度拼接,y可以是一个one-hot向量,或者其他表达形式。对于真实数据x不做变化,只用y来获取D的输出结果。

判别器D最后也应该输出是哪个类别,并且按照类别最小化来训练,也就是希望D(X)尽可能接近y。

二:实例操作
拿MNIST数据练手

网络的结构什么的都没有改变,唯一变化的就是,生成的噪声z拼接上了数据的类别标签,D的输出是数据的类别的one-hot向量,而不仅仅是0/1.
详细代码如下:

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
import pickle
import copyimport matplotlib.gridspec as gridspec
from torchvision.utils import save_image
import os# 定义展示图片的函数
def show_images(images):  # 定义画图工具print('images: ', images.shape)images = np.reshape(images, [images.shape[0], -1])sqrtn = int(np.ceil(np.sqrt(images.shape[0])))sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))fig = plt.figure(figsize=(sqrtn, sqrtn))gs = gridspec.GridSpec(sqrtn, sqrtn)gs.update(wspace=0.05, hspace=0.05)for i, img in enumerate(images):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(img.reshape([sqrtimg, sqrtimg]))returndef deprocess_img(img):out = 0.5 * (img + 1)out = out.clamp(0, 1)out = out.view(-1, 1, 28, 28)return out# step 1: ===========================================加载数据
batch_size = 128
noise_dim = 100  # 噪声维度,还是选择100维度
label_dim = 10  # 标签维度,10个数字,10个维度
z_dimension = noise_dim + label_dim  # z dimension = 100 noise dim + 10 one-hot dimtransform_img = transforms.Compose([transforms.ToTensor()])
trainset = MNIST('./data', train=True, transform=transform_img, download=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)# step 2: ===========================================定义模型
class discriminator(nn.Module):def __init__(self):super(discriminator, self).__init__()self.dis = nn.Sequential(nn.Conv2d(1, 32, 5, stride=1, padding=2),nn.LeakyReLU(0.2, True),nn.MaxPool2d((2, 2)),nn.Conv2d(32, 64, 5, stride=1, padding=2),nn.LeakyReLU(0.2, True),nn.MaxPool2d((2, 2)))self.fc = nn.Sequential(nn.Linear(7 * 7 * 64, 1024),nn.LeakyReLU(0.2, True),nn.Linear(1024, 10),nn.Sigmoid())def forward(self, x):  # x: [batch_size, 1, 28, 28]x = self.dis(x)x = x.view(x.size(0), -1)x = self.fc(x)return x  # [batch_size, 10]class generator(nn.Module):def __init__(self, input_size, num_feature):super(generator, self).__init__()self.fc = nn.Linear(input_size, num_feature)  # 1*56*56self.gen = nn.Sequential(nn.BatchNorm2d(1),nn.ReLU(True),nn.Conv2d(1, 50, 3, stride=1, padding=1),nn.BatchNorm2d(50),nn.ReLU(True),nn.Conv2d(50, 25, 3, stride=1, padding=1),nn.BatchNorm2d(25),nn.ReLU(True),nn.Conv2d(25, 1, 2, stride=2),nn.Tanh())def forward(self, x):  # x: [batch_size, 110]x = self.fc(x)x = x.view(x.size(0), 1, 56, 56)x = self.gen(x)return x  # [batch_size, 1, 28, 28]# 实例化模型
D_Net = discriminator()
G_Net = generator(z_dimension, 3136)  # 1*56*56# step 3: ===========================================定义优化器和损失函数
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D_Net.parameters(), lr=0.0003)
g_optimizer = optim.Adam(G_Net.parameters(), lr=0.0003)# step 4: ===========================================开始训练
if __name__ == "__main__":iter_count = 0show_every = 100epoch = 100gepoch = 1for i in range(epoch):for (img, label) in trainloader:img = Variable(img)print(img.shape)# 生成 lable 的 one-hot 向量,且设置对应类别位置是 1labels_onehot = np.zeros((img.shape[0], label_dim))labels_onehot[np.arange(img.shape[0]), label.numpy()] = 1# 生成随机向量,也就是噪声z,带有标签信息z = Variable(torch.randn(img.shape[0], noise_dim))z = np.concatenate((z.numpy(), labels_onehot), axis=1)z = Variable(torch.from_numpy(z).float())# 真实数据标签和虚假数据标签,real_label = Variable(torch.from_numpy(labels_onehot).float())  # 真实label对应类别是为1fake_label = Variable(torch.zeros(img.shape[0], label_dim))  # 假的label全是为0# compute loss of real_imgreal_out = D_Net(img)  # 真实图片送入判别器D输出0~1d_loss_real = criterion(real_out, real_label)  # 得到loss# compute loss of fake_imgfake_img = G_Net(z)  # 将向量放入生成网络G生成一张图片fake_out = D_Net(fake_img)  # 判别器判断假的图片d_loss_fake = criterion(fake_out, fake_label)  # 假的图片的loss# D bp and optimized_loss = d_loss_real + d_loss_faked_optimizer.zero_grad()  # 判别器D的梯度归零d_loss.backward()  # 反向传播d_optimizer.step()  # 更新判别器D参数# 生成器G的训练compute loss of fake_imgfor j in range(gepoch):fake_img = G_Net(z)  # 将向量放入生成网络G生成一张图片output = D_Net(fake_img)  # 经过判别器得到结果g_loss = criterion(output, real_label)  # 得到假的图片与真实标签的loss# bp and optimizeg_optimizer.zero_grad()  # 生成器G的梯度归零g_loss.backward()  # 反向传播g_optimizer.step()  # 更新生成器G参数print("G")# 利用模型进行测试,指定按照顺序生成0~9的数字if (iter_count % show_every == 0):test_batch_size = 10test_label = torch.from_numpy(np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))labels_onehot = np.zeros((test_batch_size, label_dim))labels_onehot[np.arange(test_batch_size), test_label.numpy()] = 1# 生成随机向量,也就是噪声z,带有标签信息test_z = Variable(torch.randn(test_batch_size, noise_dim))test_z = np.concatenate((test_z.numpy(), labels_onehot), axis=1)test_z = Variable(torch.from_numpy(test_z).float())fake_img = G_Net(test_z)  # 将向量放入生成网络G生成一张图片# imgs_numpy = deprocess_img(fake_img.data.cpu().numpy())# show_images(imgs_numpy)# plt.show()real_images = deprocess_img(fake_img.data)save_image(real_images, 'D:/software/Anaconda3/doc/3D_Img/cgan/test_%d.png' % (iter_count))iter_count += 1print('iter_count: ', iter_count)

最后按照顺序生成0~9的图像效果还是很不错的。
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述


http://chatgpt.dhexx.cn/article/elitHqLA.shtml

相关文章

【pytorch】基于mnist数据集的cgan手写数字生成实现

(左边是数据集中的真图,右边是生成器生成的假图) 文章目录 0. 特别提示1. 学习目标2. 环境配置2.1. Python2.2. Pytorch2.3. Jupyter notebook2.4. Matplotlib 3. 具体实现3.1. 导入模块3.2. 设置随机种子3.3. 超参数配置3.4. 数据集3.5. 数据…

TensorFlow实现CGAN

条件GAN就是在GAN的基础上加入了一个条件y,在生成器和判别器中加入条件参与训练,这样训练出来的模型可以根据设置的条件生成想到的图,一般条件可以为label。CGAN的论文为:《Conditional Generative Adversarial Nets》。CGAN的结构…

【Keras-CGAN】MNIST / CIFAR-10

本博客是 One Day One GAN [DAY 3] 的 learning notes!用 CGAN 来做 MNIST 图片的生成! 参考 【Keras-MLP-GAN】MNIST 文章目录 1 CGAN(Conditional Generative Adversarial Nets)2 CGAN for MNIST2.1 导入必要的库2.2 搭建 gene…

CGAN及代码实现

前言 本文主要介绍CGAN及其代码实现阅读本文之前,建议先阅读GAN(生成对抗网络)本文基于一次课程实验,代码仅上传了需要补充部分 CGAN 全称: C o n d i t i o n a l G e n e r a t i v e A d v e r s a r i a l N e t w o r k Conditional …

生成对抗网络(二)CGAN

一、简介 之前介绍了生成式对抗网络(GAN),关于GAN的变种比较多,我打算将几种常见的GAN做一个总结,也算是激励自己学习,分享自己的一些看法和见解。 之前提到的GAN是最基本的模型,我们的输入是随机噪声,输出…

读CGAN文章

提出CGAN是因为非条件的生成模型中,对生成的内容控制,实际上只要保证真实性就可以了;采用CGAN的话,我们会增加一些额外的信息去控制数据生成的过程,例如一些类别的标签,例如数字图片数据集中,可…

CGAN论文解读:Conditional Generative Adversarial Nets

论文链接:Conditional Generative Adversarial Nets 代码解读:Keras-CGAN_MNIST 代码解读 目录 一、前言 二、相关工作 三、网络结构 CGAN NETS 四、实验结果 4.1 单模态 (mnist实验) 4.2 多模态(自动为图片打…

第三章 CGAN

写在前面:最近看了《GAN实战》,由于本人忘性大,所以仅是笔记而已,方便回忆,如果能帮助大家就更好了。 目录 代价函数 训练过程 生成器和鉴别器 混淆矩阵 CGAN生成手写数字 导入声明 模型输入维度 生成器 鉴别…

【pytorch】CGAN编程实现

CGAN介绍 由于原始GAN生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标类别不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于…

GAN,CGAN,DCGAN

GAN对抗生成网络 训练流程 图片以及训练过程来源 训练这样的两个模型的大方法就是单独交替迭代训练。 我们人为的定义真假样本集的标签,因为我们希望真样本集的输出尽可能为1,假样本集为0,我们就已经默认真样本集所有的类标签都为1&#xf…

GAN论文阅读——CGAN

论文标题:Conditional Generative Adversarial Nets 论文链接:https://arxiv.org/pdf/1411.1784.pdf 参考资料:http://blog.csdn.net/solomon1558/article/details/52555083 一、CGAN的思想 在原始GAN学习笔记中,我们提到过&am…

PyTorch随笔 - 生成对抗网络的改进cGAN和LSGAN

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/129939225 本文介绍GAN的两个常见改进,cGAN和LSGAN,两者一般结合使用。 cGAN: Conditional Generative Adversa…

CGAN实现过程

本文目录 一、原理二、参数初始化1. G的输入2. D的输入3. 模型参数初始化4. 测试噪声 三、执行过程四、测试 本文用MNIST数据集进行训练,并用图解的方法展示了CGAN与GAN中输入的区别,帮助理解CGAN的运行过程 一、原理 如下图所示,我们在输入…

CGAN

CGAN 生成符合特定描述的输出, 如图:给定描述,生成相应内容图片 CGAN中的negetive情况包含两种,(正确的描述,不清晰的图片)和(不正确的描述,清晰的图片)&a…

CGAN 简介与代码实战

1.介绍 原始GAN(GAN 简介与代码实战_天竺街潜水的八角的博客-CSDN博客)在理论上可以完全逼近真实数据,但它的可控性不强(生成小图片还行,生成的大图片可能是不合逻辑的),因此需要对gan加一些约…

CGAN原理分析

1、CGAN原理分析 1.1 网络结构 CGAN是在GAN基础上做的一种改进,通过给原始GAN的生成器Generator(下文简记为G)和判别器Discriminator(下文简记为D)添加额外的条件信息y,实现条件生成模型。CGAN原文中作者…

CGAN理论讲解及代码实现

目录 1.原始GAN的缺点 2.CGAN中心思想 3.原始GAN和CGAN的区别 4.CGAN代码实现 5.运行结果 6.CGAN缺陷 1.原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。 针对原始…

GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

一、原始GAN的缺点 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强。针对原始GAN不能生成具有特定属性的图片的问题, Mehdi Mirza等人提出了cGAN,其核心在于将属性信…

解决关于Navicat破解安装过程中出现“rsa public key not find”

解决关于Navicat破解安装过程中出现“rsa public key not find” 问题描述解决办法 问题描述 出现“rsa public key not find”的输出框 解决办法 首先先安装Navicat。安装后先不要运行 打开 然后点击 如果出现 则是对的 如果出现这个: 那就请你找到在本地的na…

navicat安装与激活

原文网址:https://www.jianshu.com/p/5f693b4c9468?mTypeGroup 一、Navicat Premium 12下载 Navicat Premium 12是一套数据库开发管理工具,支持连接 MySQL、Oracle等多种数据库,可以快速轻松地创建、管理和维护数据库。 Navicat Premium 12简…