生成对抗网络——CGAN

article/2025/9/24 17:09:54

1.生成模型原理

1)CGAN的原理

传统的GAN或者其他的GAN都是通过一堆的训练数据,最后训练出了G网络,随机输入噪声最后产生的数据是这些训练数据类别中之一,我们提前无法预测是那哪一个?

因此,我们有的时候需要定向指定生成某些数据,比如我们想让G生成飞机,数字9,等等的图片数据。

假设现在要做一个项目:输入一段文字,输出一张图片,要让这张图片足够清晰并且符合这段文字的描述。我们搭建一个传统的NeuralNetwork(下称NN)去训练。
在这里插入图片描述
考虑我们输入的文字是“train”,希望NN能输出清晰的火车照片,那在数据集中,下面左图是正面的火车,它们统统都是正确的火车图片;下面右图是侧面的火车,它们也统统都是正确的火车。
在这里插入图片描述在这里插入图片描述
那在训练这个NN的时候,network会觉得说,火车既要长得像左边的图片,也要长得像右边的图片,那最终network的output就会变成这一大堆images的平均,可想而知那会是一张非常模糊并且错误的照片。

我们需要引入GANs技术来保证NN产生清晰准确的照片。

我们把原始的NN叫做G(Generator),现在它吃两个输入,一个是条件word:c,另外一个是从原始图片中sample出的分布z,它的输出是一个image:x,它希望这个x尽可能地符合条件c的描述,同时足够清晰,如下图。
在这里插入图片描述
在GANs中为了保证输出image的质量会引入一个D(Discriminator),这个D用来判断输入的x是真实图片还是伪造图片,如下图。
在这里插入图片描述
但是传统GANs只能保证让x尽可能地像真实图片,它忽略了让x符合条件描述c的要求。于是,为了解决这一问题,CGAN便被提出了。

我们的目的是,既要让输出的图片真实,也要让输出的图片符合条件c的描述。Discriminator输入便被改成了同时输入c和x,输出要做两件事情,一个是判断x是否是真实图片,另一个是x和c是否是匹配的。
在这里插入图片描述
比如说,在下面这个情况中,条件c是train,图片x也是一张清晰的火车照片,那么D的输出就会是1。

在这里插入图片描述
而在下面两个情况中,左边虽然输出图片清晰,但不符合条件c;右边输出图片不真实。因此两种情况中D的输出都会是0。

在这里插入图片描述
那CGAN的基本思路就是这样,下面我们具体看一下CGAN的算法实现。
在这里插入图片描述
因为CGAN是supervised学习,采样的每一项都是文字和图片的pair。CGAN的核心就是判断什么样的pair给高分,什么样的pair给低分。

2)做法

1:就是给网络的输入噪声数据增加一些类别上的信息,就是说给定某些类别条件下,生成指定的数据,所以输入数据会有一些变化;

2:然后在损失函数那里,我们目标不再是输出1/0,也就是不再是简单的输出真实和构造。当判定是真实数据的时候,还需要判定出是哪一类别的图片。一般使用one-hot表示。
在这里插入图片描述
上图表示,改变输入噪声数据,给z增加类别y信息,怎么增加呢,就是简单的维度拼接,y可以是一个one-hot向量,或者其他表达形式(此处采用词向量来表示,详情见代码部分)。对于真实数据x不做变化,只用y来获取D的输出结果。

判别器D最后也应该输出是哪个类别,并且按照类别最小化来训练,也就是希望D(X)尽可能接近y。

2.生成模型训练参考代码

train.py

import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch import optim
import os
import numpy as np# 设置超参数
batch_size = 100
learning_rate = 0.0002
epochsize = 90
sample_dir = "images3"# 创建生成图像的目录
if not os.path.exists(sample_dir):os.makedirs(sample_dir)# 生成器结构
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.label_emb = nn.Embedding(10, 10)self.model = nn.Sequential(nn.Linear(110, 128),nn.LeakyReLU(0.2, inplace=True),nn.Linear(128, 256),nn.BatchNorm1d(256, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 512),nn.BatchNorm1d(512, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 1024),nn.BatchNorm1d(1024, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(1024, 784),nn.Tanh())def forward(self, noise, label):out = torch.cat((noise, self.label_emb(label)), -1)img = self.model(out)     # torch.Size([64, 784])img = img.view(img.size(0), 1, 28, 28)     # torch.Size([64, 1, 32, 32])return img# 鉴别器结构
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.label_emb = nn.Embedding(10, 10)self.model = nn.Sequential(nn.Linear(794, 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 512),nn.Dropout(0.4),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 512),nn.Dropout(0.4),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 1),nn.Sigmoid())def forward(self, img, label):img = img.view(img.size(0), -1)     # torch.Size([100, 784])x = torch.cat((img, self.label_emb(label)), -1)     # torch.Size([100, 794])x = self.model(x)   # torch.Size([100, 1])return x# 训练集下载
mnist_traindata = datasets.MNIST('E:/学习/机器学习/数据集/MNIST', train=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5])
]), download=False)
mnist_train = DataLoader(mnist_traindata, batch_size=batch_size, shuffle=True, pin_memory=True)# GPU加速
# device = torch.device('cuda')
# torch.cuda.set_device(0)G = Generator()
D = Discriminator()# 导入之前的训练模型
G.load_state_dict(torch.load('G_plus.ckpt'))
D.load_state_dict(torch.load('D_plus.ckpt'))# 设置优化器与损失函数,二分类的时候使用BCELoss较好,BCEWithLogitsLoss是自带一层Sigmoid
# criteon = nn.BCEWithLogitsLoss()
criteon = nn.BCELoss()
G_optimizer = optim.Adam(G.parameters(), lr=learning_rate)
D_optimizer = optim.Adam(D.parameters(), lr=learning_rate)# 开始训练
print("start training")
for epoch in range(epochsize):D_loss_total = 0G_loss_total = 0total_num = 0# 这里的RealImageLabel是没有用上的for batchidx, (realimage, realimage_label) in enumerate(mnist_train):# realimage = realimage.to(device)realscore = torch.ones(realimage.size(0), 1)   # value:1 torch.Size([128, 1])fakescore = torch.zeros(realimage.size(0), 1)   # value:0 torch.Size([128, 1])# 随机sample出噪声与标签,生成假图像z = torch.randn(realimage.size(0), 100)fakeimage_label = torch.LongTensor(np.random.randint(0, 10, realimage.size(0)))fakeimage = G(z, fakeimage_label)# 训练鉴别器————总的损失为两者相加d_realimage_loss = criteon(D(realimage, realimage_label), realscore)d_fakeimage_loss = criteon(D(fakeimage, fakeimage_label), fakescore)D_loss = d_realimage_loss + d_fakeimage_loss# 参数训练三个步骤D_optimizer.zero_grad()D_loss.backward()D_optimizer.step()# 计算一次epoch的总损失D_loss_total += D_loss# 训练生成器————损失只有一个# 上一次的梯度信息以消除,重新生成假图像fakeimage = G(z, fakeimage_label)G_loss = criteon(D(fakeimage, fakeimage_label), realscore)# 参数训练三个步骤G_optimizer.zero_grad()G_loss.backward()G_optimizer.step()# 计算一次epoch的总损失G_loss_total += G_loss# 打印相关的loss值if batchidx % 200 == 0:print("batchidx:{}/{}, D_loss:{}, G_loss:{}".format(batchidx, len(mnist_train), D_loss, G_loss))# 打印一次训练的loss值print('Epoch:{}/{}, D_loss:{}, G_loss:{}'.format(epoch, epochsize, D_loss_total / len(mnist_train),G_loss_total / len(mnist_train)))# 保存生成图像z = torch.randn(batch_size, 100)label = torch.LongTensor(np.array([num for _ in range(10) for num in range(10)]))save_image(G(z, label).data, os.path.join(sample_dir, 'images-{}.png'.format(epoch + 61)), nrow=10, normalize=True)# 保存网络结构torch.save(G.state_dict(), 'G_plus.ckpt')torch.save(D.state_dict(), 'D_plus.ckpt')

test.py

import torch
from torch import nn
from torchvision.utils import save_image
import os
import numpy as np# 设置超参数
batch_size = 100
# learning_rate = 0.0002
# epochsize = 80
sample_dir = "test_images"# 创建生成图像的目录
if not os.path.exists(sample_dir):os.makedirs(sample_dir)# 生成器结构
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.label_emb = nn.Embedding(10, 10)self.model = nn.Sequential(nn.Linear(110, 128),nn.LeakyReLU(0.2, inplace=True),nn.Linear(128, 256),nn.BatchNorm1d(256, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 512),nn.BatchNorm1d(512, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 1024),nn.BatchNorm1d(1024, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Linear(1024, 784),nn.Tanh())def forward(self, noise, label):out = torch.cat((noise, self.label_emb(label)), -1)img = self.model(out)     # torch.Size([64, 784])img = img.view(img.size(0), 1, 28, 28)     # torch.Size([64, 1, 32, 32])return img# 导入训练好的模型
G = Generator()
G.load_state_dict(torch.load('G_plus.ckpt'))# 保存图像
z = torch.randn(batch_size, 100)
# label = torch.LongTensor(np.array([num for _ in range(10) for num in range(10)]))
label = torch.tensor([7,8,1,3,4,2,6,5,9,0]*10)
# label = torch.full([100], 9)# label = []
# for i in range(10):
#     for j in range(10):
#         label.append(i)
#
# label = torch.tensor(label)
print(label)
print("label.shape:", label.size())save_image(G(z, label).data, os.path.join(sample_dir, 'images.png'), nrow=10, normalize=True)

3.生成模型结果展示

由于电脑配置不行,只能用cpu跑了,跑得比较慢,下面是结果展示:
在这里插入图片描述
epoch10 生成的图像
在这里插入图片描述
epoch50 生成的图像

在这里插入图片描述
epoch200 生成的图像

但是存一个问题,CGAN只能全部条件的图像,不能生成单一条件的图像。也就是无论数字的顺序如何排列,cgan都能准确的生成出来,如图所示:
在这里插入图片描述在这里插入图片描述
但是想让其生成单一的数字,比如全部生成数字“1”,就无法正常生成图像。如图所示
在这里插入图片描述

参考资料:

  1. 李宏毅老师的b站视频
  2. https://blog.csdn.net/a312863063/article/details/83573968
  3. https://blog.csdn.net/qq_29367075/article/details/109149211

http://chatgpt.dhexx.cn/article/S30U3KuI.shtml

相关文章

基于对抗生成网络的滚动轴承故障检测方法

人工智能技术与咨询 点击蓝字 关注我们 来源:《人工智能与机器人研究》 ,作者华丰 关键词: 不平衡工业时间序列;异常检测;生成对抗网络;滚动轴承数据 关注微信公众号:人工智能技术与咨询。了解更多咨询&…

深度学习 - 生成对抗网络

目录 1 GAN产生背景 2 GAN模型 3 CGAN 4 InfoGAN 5 Improved Techniques for Training GANs 6 DCGAN -- Deep convolutional generative adversarial networks 7 GAN应用 1 GAN产生背景 1. 机器学习方法 生成方法,所学到的模型称为生成式模型 生成方法通过观测…

对抗生成网络GAN系列——f-AnoGAN原理及缺陷检测实战

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题 🍊专栏推荐:深度学习网络原理与实战 🍊近期目标:写好专栏的每一篇文章 🍊支持小苏:点赞👍🏼、…

对抗生成网络原理和作用

我们通过一个demo(gan.py )来讲解对抗生成网络的原理和作用 1、创建真实数据 2、使用GAN训练噪声数据 3、通过1200次的训练使得生成的数据的分布跟真实数据的分布差不多 4、通过debug方式一步步的讲解 二、原理: 1、G(x&…

生成对抗网络

论文阅读笔记,论文链接 Generative Adversarial Network 生成对抗网络 GAN 理解gan的原理 网络思想 在GAN网络当中,有两个网络,一个是生成网络G,另外一个是判别网络D。生成网络G的目的是生成数据,这里的数据可以是图片…

对抗生成网络GAN系列——CycleGAN简介及图片春冬变换案例

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题 🍊往期回顾:对抗生成网络GAN系列——GAN原理及手写数字生成小案例    对抗生成网络GAN系列——DCGAN简介及人脸图像生成案例 🍊近期目标:写…

MATLAB代码:对于对抗生成网络GAN的风光场景生成算法 关键词:场景生成 GAN 对抗生成网络 风光场景

MATLAB代码:对于对抗生成网络GAN的风光场景生成算法 关键词:场景生成 GAN 对抗生成网络 风光场景 仿真平台: pythontensorflow 主要内容:代码主要做的是基于数据驱动的风光新能源场景生成模型,具体为,通过构建了一种对…

对抗生成网络GAN系列——GAN原理及手写数字生成小案例

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题 🍊往期回顾:目标检测系列——开山之作RCNN原理详解    目标检测系列——Fast R-CNN原理详解   目标检测系列——Faster R-CNN原理详解 🍊近期目标&a…

GAN——对抗生成网络

GAN的基本思想 作为现在最火的深度学习模型之一,GAN全称对抗生成网络,顾名思义是生成模型的一种,而他的训练则是处于一种对抗博弈状态中的。它使用两个神经网络,将一个神经网络与另一个神经网络进行对抗。 基本思想:&…

一文读懂对抗生成网络的3种模型

https://www.toutiao.com/i6635851641293636109/ 2018-12-17 14:53:28 基于对抗生成网络技术的在线工具edges2cats, 可以为简笔画涂色 前言 在GAN系列课程中分别讲解了对抗生成网络的三种模型,从Goodfellow最初提出的原始的对抗生成网络,到…

对抗生成网络(GAN)详解

目录 前言 目标函数 原理 训练 给定生成器,训练判别器 给定判别器,训练生成器 总结 前言 之前的生成模型侧重于将分布函数构造出来,然后使用最大似然函数去更新这个分布函数的参数,从而优化分布函数,但是这种方法…

对抗生成网络(GAN)简介及生成数字实战

一、简介 生成对抗网络(Generative Adversarial Netword,简称GAN),是一种生成式机器学习模型,该方法由伊恩古德费洛等人于2014年提出,曾被称为“机器学习这二十年来最酷的想法”,可以用来创造虚…

对抗生成网络(Generative Adversarial Net)

好久没有更新博客了,但似乎我每次更新博客的时候都这么说(泪)。最近对生活有了一些新的体会,工作上面,新的环境总算是适应了,知道了如何摆正工作和生活之间的关系,如何能在有效率工作的同时还能…

【PaddleOCR-det-finetune】一:基于PPOCRv3的det检测模型finetune训练

文章目录 基本流程详细步骤打标签,构建自己的数据集下载PPOCRv3训练模型修改超参数,训练自己数据集启动训练导出模型 测试 相关参考手册在PaddleOCR项目工程中的位置: det模型训练和微调:PaddleOCR\doc\doc_ch\PPOCRv3_det_train.…

模型微调(Finetune)

参考:https://zhuanlan.zhihu.com/p/35890660 ppt下载地址:https://github.com/jiangzhubo/What-is-Fine-tuning 一.什么是模型微调 给定预训练模型(Pre_trained model),基于模型进行微调(Fine Tune)。相…

fine-tuning

微调(fine-tuning) 在平时的训练中,我们通常很难拿到大量的数据,并且由于大量的数据,如果一旦有调整,重新训练网络是十分复杂的,而且参数不好调整,数量也不够,所以我们可…

大模型的三大法宝:Finetune, Prompt Engineering, Reward

编者按:基于基础通用模型构建领域或企业特有模型是目前趋势。本文简明介绍了最大化挖掘语言模型潜力的三大法宝——Finetune, Prompt Engineering和RLHF——的基本概念,并指出了大模型微调面临的工具层面的挑战。 以下是译文,Enjoy! 作者 | B…

RCNN网络源码解读(Ⅲ) --- finetune训练过程

目录 0.回顾 1.finetune二分类代码解释(finetune.py) 1.1 load_data(定义获取数据的方法) 1.2 CustomFineTuneDataset类 1.3 custom_batch_sampler类( custom_batch_sampler.py) 1.4 训练train_mod…

FinSH

finSH介绍 FinSH 是 RT-Thread 的命令行组件,提供一套供用户在命令行调用的操作接口,主要用于调试或查看系统信息。它可以使用串口 / 以太网 / USB 等与 PC 机进行通信。 命令执行过程 功能: 支持鉴权,可在系统配置中选择打开/关闭。(TODO…

从统一视角看各类高效finetune方法

每天给你送来NLP技术干货! 来自:圆圆的算法笔记 随着预训练模型参数量越来越大,迁移学习的成本越来越高,parameter-efficient tuning成为一个热点研究方向。在以前我们在下游任务使用预训练大模型,一般需要finetune模型…