对抗生成网络原理和作用

article/2025/9/24 17:54:38

我们通过一个demo(gan.py )来讲解对抗生成网络的原理和作用

1、创建真实数据
2、使用GAN训练噪声数据
3、通过1200次的训练使得生成的数据的分布跟真实数据的分布差不多
4、通过debug方式一步步的讲解

二、原理:

1、G(x)是生成的数据,放到判别D网络中,希望D网络输出 0;x是真实的输入,希望D网络输出 1
这里写图片描述
2、x输入G网络通过一系列的参数生成G(x)
这里写图片描述
3、对于D网络希望他的判别标准要高些,这样生成的数据才更能接近真实数据,这就需要D_pre网络进行预先的判断
这里写图片描述

三、代码实现的主要步骤:

1、构造判别网络模型 3–14
2、构造生成网络模型 15–32
3、构造损失函数 33–35
4、训练对抗生成网络

import argparse #1、参数解析的包
import numpy as np
from scipy.stats import norm
import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib import animation
import seaborn as sns #2、可视化的库sns.set(color_codes=True)  seed = 42
np.random.seed(seed)
tf.set_random_seed(seed)class DataDistribution(object):def __init__(self):self.mu = 4self.sigma = 0.5#44、def sample(self, N):samples = np.random.normal(self.mu, self.sigma, N)samples.sort()return samples#6、随机初始化分布,作为噪音点
class GeneratorDistribution(object):def __init__(self, range):self.range = rangedef sample(self, N):return np.linspace(-self.range, self.range, N) + \np.random.random(N) * 0.01#16、
def linear(input, output_dim, scope=None, stddev=1.0):#17、定义一个随机的初始化norm = tf.random_normal_initializer(stddev=stddev)#18、初始化常量为0const = tf.constant_initializer(0.0)with tf.variable_scope(scope or 'linear'):#19、w进行高斯处理话w = tf.get_variable('w', [input.get_shape()[1], output_dim], initializer=norm)#20、b进行常量初始化b = tf.get_variable('b', [output_dim], initializer=const)return tf.matmul(input, w) + b#29、生成网络只要两层就可以产生最终的输出结果
def generator(input, h_dim):h0 = tf.nn.softplus(linear(input, h_dim, 'g0'))h1 = linear(h0, 1, 'g1')return h1# 15、h0~h3 是分层的
def discriminator(input, h_dim):#h0是第一层的输出,h_dim * 2 隐层的数据h0 = tf.tanh(linear(input, h_dim * 2, 'd0'))h1 = tf.tanh(linear(h0, h_dim * 2, 'd1'))   h2 = tf.tanh(linear(h1, h_dim * 2, scope='d2'))#21、h3我们网络最总的输出结果h3 = tf.sigmoid(linear(h2, 1, scope='d3'))return h3
#24、优化器,学习率不断衰减的策略
def optimizer(loss, var_list, initial_learning_rate):decay = 0.95num_decay_steps = 150batch = tf.Variable(0)#25、学习率不断衰减的学习方式learning_rate = tf.train.exponential_decay(initial_learning_rate,batch,num_decay_steps,decay,staircase=True)#26、通过梯度下降定义求解器optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=batch,var_list=var_list)return optimizerclass GAN(object):#9、def __init__(self, data, gen, num_steps, batch_size, log_every):self.data = dataself.gen = genself.num_steps = num_stepsself.batch_size = batch_sizeself.log_every = log_everyself.mlp_hidden_size = 4self.learning_rate = 0.03#10、self._create_model()def _create_model(self):#11、构建D网络的骨架with tf.variable_scope('D_pre'):#12、输入,注意shape的参数self.pre_input = tf.placeholder(tf.float32, shape=(self.batch_size, 1))#13、labelself.pre_labels = tf.placeholder(tf.float32, shape=(self.batch_size, 1))#14、初始化操作D_pre = discriminator(self.pre_input, self.mlp_hidden_size)#22、预测值与真实值的差异D_pre和pre_labels的差异self.pre_loss = tf.reduce_mean(tf.square(D_pre - self.pre_labels))#23、self.pre_opt = optimizer(self.pre_loss, None, self.learning_rate)# This defines the generator network - it takes samples from a noise# distribution as input, and passes them through an MLP.with tf.variable_scope('Gen'):#27、噪音的输入self.z = tf.placeholder(tf.float32, shape=(self.batch_size, 1))#28、G网络用于数据的生成self.G = generator(self.z, self.mlp_hidden_size)# The discriminator tries to tell the difference between samples from the# true data distribution (self.x) and the generated samples (self.z).## Here we create two copies of the discriminator network (that share parameters),# as you cannot use the same network with different inputs in TensorFlow.with tf.variable_scope('Disc') as scope:#30、D网络用户判别功能self.x = tf.placeholder(tf.float32, shape=(self.batch_size, 1))#31、self.x 是真实的数据self.D1 = discriminator(self.x, self.mlp_hidden_size)scope.reuse_variables()#32、self.G是生成的数据self.D2 = discriminator(self.G, self.mlp_hidden_size)# Define the loss for discriminator and generator networks (see the original# paper for details), and create optimizers for both#33、判别网络的损失函数,希望D1趋近于1,希望D2趋近于0self.loss_d = tf.reduce_mean(-tf.log(self.D1) - tf.log(1 - self.D2))#34、生成网络(希望骗过判别网络)的损失函数,希望loss_g趋近于1self.loss_g = tf.reduce_mean(-tf.log(self.D2))self.d_pre_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D_pre')self.d_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Disc')self.g_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Gen')#35、通过优化器不断地优化loss_d和loss_gself.opt_d = optimizer(self.loss_d, self.d_params, self.learning_rate)self.opt_g = optimizer(self.loss_g, self.g_params, self.learning_rate)#36、开始训练def train(self):with tf.Session() as session:tf.global_variables_initializer().run()# pretraining discriminatornum_pretrain_steps = 1000#37、先训练D-profor step in range(num_pretrain_steps):#38、d = (np.random.random(self.batch_size) - 0.5) * 10.0#39、labels = norm.pdf(d, loc=self.data.mu, scale=self.data.sigma)#40、迭代pretrain_loss, _ = session.run([self.pre_loss, self.pre_opt], {self.pre_input: np.reshape(d, (self.batch_size, 1)),self.pre_labels: np.reshape(labels, (self.batch_size, 1))})#41、self.weightsD = session.run(self.d_pre_params)# 42、copy weights from pre-training over to new D networkfor i, v in enumerate(self.d_params):session.run(v.assign(self.weightsD[i]))for step in range(self.num_steps):# 43、update discriminatorx = self.data.sample(self.batch_size)z = self.gen.sample(self.batch_size)loss_d, _ = session.run([self.loss_d, self.opt_d], {self.x: np.reshape(x, (self.batch_size, 1)),self.z: np.reshape(z, (self.batch_size, 1))})# 45、迭代优化两个网络 update generatorz = self.gen.sample(self.batch_size)loss_g, _ = session.run([self.loss_g, self.opt_g], {self.z: np.reshape(z, (self.batch_size, 1))})if step % self.log_every == 0:print('{}: {}\t{}'.format(step, loss_d, loss_g))                if step % 100 == 0 or step==0 or step == self.num_steps -1 :self._plot_distributions(session)def _samples(self, session, num_points=10000, num_bins=100):xs = np.linspace(-self.gen.range, self.gen.range, num_points)bins = np.linspace(-self.gen.range, self.gen.range, num_bins)# data distributiond = self.data.sample(num_points)pd, _ = np.histogram(d, bins=bins, density=True)# generated sampleszs = np.linspace(-self.gen.range, self.gen.range, num_points)g = np.zeros((num_points, 1))for i in range(num_points // self.batch_size):g[self.batch_size * i:self.batch_size * (i + 1)] = session.run(self.G, {self.z: np.reshape(zs[self.batch_size * i:self.batch_size * (i + 1)],(self.batch_size, 1))})pg, _ = np.histogram(g, bins=bins, density=True)return pd, pgdef _plot_distributions(self, session):pd, pg = self._samples(session)p_x = np.linspace(-self.gen.range, self.gen.range, len(pd))f, ax = plt.subplots(1)ax.set_ylim(0, 1)plt.plot(p_x, pd, label='real data')plt.plot(p_x, pg, label='generated data')plt.title('1D Generative Adversarial Network')plt.xlabel('Data values')plt.ylabel('Probability density')plt.legend()plt.show()
def main(args): #3、够造一个modelmodel = GAN(#4、参数DataDistribution(),#5、GeneratorDistribution(range=8),#7、定义参数args.num_steps,args.batch_size,#8、隔多长时间args.log_every,)model.train()def parse_args():parser = argparse.ArgumentParser()parser.add_argument('--num-steps', type=int, default=1200,help='the number of training steps to take')parser.add_argument('--batch-size', type=int, default=12,help='the batch size')parser.add_argument('--log-every', type=int, default=10,help='print loss after this many steps')return parser.parse_args()if __name__ == '__main__':main(parse_args())

http://chatgpt.dhexx.cn/article/PwW8xyTQ.shtml

相关文章

生成对抗网络

论文阅读笔记,论文链接 Generative Adversarial Network 生成对抗网络 GAN 理解gan的原理 网络思想 在GAN网络当中,有两个网络,一个是生成网络G,另外一个是判别网络D。生成网络G的目的是生成数据,这里的数据可以是图片…

对抗生成网络GAN系列——CycleGAN简介及图片春冬变换案例

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题 🍊往期回顾:对抗生成网络GAN系列——GAN原理及手写数字生成小案例    对抗生成网络GAN系列——DCGAN简介及人脸图像生成案例 🍊近期目标:写…

MATLAB代码:对于对抗生成网络GAN的风光场景生成算法 关键词:场景生成 GAN 对抗生成网络 风光场景

MATLAB代码:对于对抗生成网络GAN的风光场景生成算法 关键词:场景生成 GAN 对抗生成网络 风光场景 仿真平台: pythontensorflow 主要内容:代码主要做的是基于数据驱动的风光新能源场景生成模型,具体为,通过构建了一种对…

对抗生成网络GAN系列——GAN原理及手写数字生成小案例

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题 🍊往期回顾:目标检测系列——开山之作RCNN原理详解    目标检测系列——Fast R-CNN原理详解   目标检测系列——Faster R-CNN原理详解 🍊近期目标&a…

GAN——对抗生成网络

GAN的基本思想 作为现在最火的深度学习模型之一,GAN全称对抗生成网络,顾名思义是生成模型的一种,而他的训练则是处于一种对抗博弈状态中的。它使用两个神经网络,将一个神经网络与另一个神经网络进行对抗。 基本思想:&…

一文读懂对抗生成网络的3种模型

https://www.toutiao.com/i6635851641293636109/ 2018-12-17 14:53:28 基于对抗生成网络技术的在线工具edges2cats, 可以为简笔画涂色 前言 在GAN系列课程中分别讲解了对抗生成网络的三种模型,从Goodfellow最初提出的原始的对抗生成网络,到…

对抗生成网络(GAN)详解

目录 前言 目标函数 原理 训练 给定生成器,训练判别器 给定判别器,训练生成器 总结 前言 之前的生成模型侧重于将分布函数构造出来,然后使用最大似然函数去更新这个分布函数的参数,从而优化分布函数,但是这种方法…

对抗生成网络(GAN)简介及生成数字实战

一、简介 生成对抗网络(Generative Adversarial Netword,简称GAN),是一种生成式机器学习模型,该方法由伊恩古德费洛等人于2014年提出,曾被称为“机器学习这二十年来最酷的想法”,可以用来创造虚…

对抗生成网络(Generative Adversarial Net)

好久没有更新博客了,但似乎我每次更新博客的时候都这么说(泪)。最近对生活有了一些新的体会,工作上面,新的环境总算是适应了,知道了如何摆正工作和生活之间的关系,如何能在有效率工作的同时还能…

【PaddleOCR-det-finetune】一:基于PPOCRv3的det检测模型finetune训练

文章目录 基本流程详细步骤打标签,构建自己的数据集下载PPOCRv3训练模型修改超参数,训练自己数据集启动训练导出模型 测试 相关参考手册在PaddleOCR项目工程中的位置: det模型训练和微调:PaddleOCR\doc\doc_ch\PPOCRv3_det_train.…

模型微调(Finetune)

参考:https://zhuanlan.zhihu.com/p/35890660 ppt下载地址:https://github.com/jiangzhubo/What-is-Fine-tuning 一.什么是模型微调 给定预训练模型(Pre_trained model),基于模型进行微调(Fine Tune)。相…

fine-tuning

微调(fine-tuning) 在平时的训练中,我们通常很难拿到大量的数据,并且由于大量的数据,如果一旦有调整,重新训练网络是十分复杂的,而且参数不好调整,数量也不够,所以我们可…

大模型的三大法宝:Finetune, Prompt Engineering, Reward

编者按:基于基础通用模型构建领域或企业特有模型是目前趋势。本文简明介绍了最大化挖掘语言模型潜力的三大法宝——Finetune, Prompt Engineering和RLHF——的基本概念,并指出了大模型微调面临的工具层面的挑战。 以下是译文,Enjoy! 作者 | B…

RCNN网络源码解读(Ⅲ) --- finetune训练过程

目录 0.回顾 1.finetune二分类代码解释(finetune.py) 1.1 load_data(定义获取数据的方法) 1.2 CustomFineTuneDataset类 1.3 custom_batch_sampler类( custom_batch_sampler.py) 1.4 训练train_mod…

FinSH

finSH介绍 FinSH 是 RT-Thread 的命令行组件,提供一套供用户在命令行调用的操作接口,主要用于调试或查看系统信息。它可以使用串口 / 以太网 / USB 等与 PC 机进行通信。 命令执行过程 功能: 支持鉴权,可在系统配置中选择打开/关闭。(TODO…

从统一视角看各类高效finetune方法

每天给你送来NLP技术干货! 来自:圆圆的算法笔记 随着预训练模型参数量越来越大,迁移学习的成本越来越高,parameter-efficient tuning成为一个热点研究方向。在以前我们在下游任务使用预训练大模型,一般需要finetune模型…

finetune

finetune的含义是获取预训练好的网络的部分结构和权重,与自己新增的网络部分一起训练。下面介绍几种finetune的方法。 完整代码:https://github.com/toyow/learn_tensorflow/tree/master/finetune 一,如何恢复预训练的网络 方法一&#xf…

11.2 模型finetune

一、Transform Learning 与 Model Finetune 二、pytorch中的Finetune 一、Transfer Learning 与 Model Finetune 1. 什么是Transfer Learning? 迁移学习是机器学习的一个分支,主要研究源域的知识如何应用到目标域当中。迁移学习是一个很大的概念。 怎么理解源域…

飞桨深度学习学院零基础深度学习7日入门-CV疫情特辑学习笔记(四)DAY03 车牌识别

本课分为理论和实战两个部分 理论:卷积神经网络 1.思考全连接神经网络的问题 一般来收机器学习模型实践分为三个步骤,(1)建立模型 (2)选择损失函数 (3)参数调整学习 1.1 模型结构不…

unity sdk(android)-友盟推送SDK接入

注意:一开始想接友盟Unity的SDk,但是导入后缺少各种jar,所以最后还是接了android的,demo文档齐全 官方文档:开发者中心 按照官方文档对接即可, 接入流程 1、项目中com.android.tools.build:gradle配置&…