BigGAN

article/2025/11/9 23:02:11

1、BIGGAN 解读

1.1、作者

Andrew Brock、Jeff Donahue、Karen Simonyan

1.2、摘要

尽管最近在生成图像建模方面取得了进展,但从 ImageNet 等复杂数据集中 成功生成高分辨率、多样化的样本仍然是一个难以实现的目标。为此,我们以迄 今为止最大的规模训练生成对抗网络,并研究该规模特有的不稳定性。我们发现, 对生成器应用正交正则化使其易于使用简单的“截断技巧”,通过减少生成器输 入的方差,可以精细控制样本保真度和品种之间的权衡。我们的修改导致模型在 类条件图像合成中设置了新的技术状态。当以 128×128 分辨率在 ImageNet 上训 练时,BigGANs 的 IS 分数为 166.5,FID 分数为 7.4,比之前最好的 IS 为 52.52 和 FID 为 18.65 有所改进。

1.3、模型

GResidualBlock块代码如下:

class GResidualBlock(nn.Module):''' Implements a residual block in BigGAN's generator '''def __init__(self,c_dim: int,in_channels: int,out_channels: int,):super().__init__()self.conv1 = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))self.conv2 = nn.utils.spectral_norm(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))self.bn1 = ClassConditionalBatchNorm2d(c_dim, in_channels)self.bn2 = ClassConditionalBatchNorm2d(c_dim, out_channels)self.activation = nn.ReLU()self.upsample_fn = nn.Upsample(scale_factor=2)     # upsample occurs in every gblockself.mixin = (in_channels != out_channels)if self.mixin:self.conv_mixin = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0))def forward(self, x, y):# x,y输入给BatchNormh = self.bn1(x, y) # BatchNormh = self.activation(h)# ReLUh = self.upsample_fn(h) # Upsampleh = self.conv1(h)# 3x3Conv# x卷积后成h,y输入给BatchNormh = self.bn2(h, y) # BatchNormh = self.activation(h)# ReLUh = self.conv2(h)# 3x3Conv# x输入给Upsamplex = self.upsample_fn(x)# Upsampleif self.mixin:x = self.conv_mixin(x)# 1x1Conv# 1x1卷积后的x + 经过两次3x3卷积后的xreturn h + x # add

Non-Local Block的代码如下:

# Self-Attention module == Non-Local block
class AttentionBlock(nn.Module):''' Implements a self-attention block from SA-GAN '''def __init__(self, channels: int):super().__init__()self.channels = channelsself.theta = nn.utils.spectral_norm(nn.Conv2d(channels, channels // 8, kernel_size=1, padding=0, bias=False))self.phi = nn.utils.spectral_norm(nn.Conv2d(channels, channels // 8, kernel_size=1, padding=0, bias=False))self.g = nn.utils.spectral_norm(nn.Conv2d(channels, channels // 2, kernel_size=1, padding=0, bias=False))self.o = nn.utils.spectral_norm(nn.Conv2d(channels // 2, channels, kernel_size=1, padding=0, bias=False))self.gamma = nn.Parameter(torch.tensor(0.), requires_grad=True)def forward(self, x):spatial_size = x.shape[2] * x.shape[3]# apply convolutions to get query (theta), key (phi), and value (g) transformstheta = self.theta(x)phi = F.max_pool2d(self.phi(x), kernel_size=2)g = F.max_pool2d(self.g(x), kernel_size=2)# reshape spatial size for self-attentiontheta = theta.view(-1, self.channels // 8, spatial_size)phi = phi.view(-1, self.channels // 8, spatial_size // 4)g = g.view(-1, self.channels // 2, spatial_size // 4)# compute dot product attention with query (theta) and key (phi) matricesbeta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), dim=-1)# compute scaled dot product attention with value (g) and attention (beta) matriceso = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.channels // 2, x.shape[2], x.shape[3]))# apply gain and residualreturn self.gamma * o + x

BigGAN的Generation结构如图所示:

根据上图代码如下:

class Generator(nn.Module):''' Implements the BigGAN generator '''def __init__(self,base_channels: int = 96,bottom_width: int = 4,# yml里面是2z_dim: int = 120,shared_dim: int = 128,n_classes: int = 1000,):super().__init__()n_chunks = 6    # 5 (generator blocks) + 1 (generator input)self.z_chunk_size = z_dim // n_chunks # 120//6 == 20self.z_dim = z_dimself.shared_dim = shared_dimself.bottom_width = bottom_widthself.n_classes = n_classes# no spectral normalization on embeddings, which authors observe to cripple the generatorself.shared_emb = nn.Embedding(n_classes, shared_dim)# Linear层 Linear(20,16*96*2**2)self.proj_z = nn.Linear(self.z_chunk_size, 16 * base_channels * bottom_width ** 2)# 不能用一个大nn。连续的,因为我们在每个块上添加class+noiseself.g_blocks = nn.ModuleList([# ResBlock up 16ch → 16chGResidualBlock(shared_dim + self.z_chunk_size, 16 * base_channels, 16 * base_channels),# ResBlock up 16ch → 8chGResidualBlock(shared_dim + self.z_chunk_size, 16 * base_channels, 8 * base_channels),# ResBlock up 8ch → 4chGResidualBlock(shared_dim + self.z_chunk_size, 8 * base_channels, 4 * base_channels),# ResBlock up 4ch → 2chGResidualBlock(shared_dim + self.z_chunk_size, 4 * base_channels, 2 * base_channels),# Non-Local Block (64 × 64)AttentionBlock(2 * base_channels),# ResBlock up 2ch → chGResidualBlock(shared_dim + self.z_chunk_size, 2 * base_channels, base_channels),])self.proj_o = nn.Sequential(# BN, ReLU, 3 × 3 Conv ch → 3, Tanhnn.BatchNorm2d(base_channels),nn.ReLU(inplace=True),nn.utils.spectral_norm(nn.Conv2d(base_channels, 3, kernel_size=1, padding=0)),nn.Tanh(),)def forward(self, z, y):'''z: random noise with size self.z_dimy: one-hot class embeddings with size self.shared_dim'''y = self.shared_emb(y)# class# 块z并连接到共享类嵌入zs = torch.split(z, self.z_chunk_size, dim=1)z = zs[0]ys = [torch.cat([y, z], dim=1) for z in zs[1:]] # Split的结果+Class# project noise and reshape to feed through generator blocksh = self.proj_z(z)# Linear层h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)# feed through generator blocksidx = 0for g_block in self.g_blocks:if isinstance(g_block, AttentionBlock):h = g_block(h)else:h = g_block(h, ys[idx])idx += 1# project to 3 RGB channels with tanh to map values to [-1, 1]h = self.proj_o(h)return h

1.4、试验

1.4.1、不同 Batch size 对性能的影响

作者发现简单地将 Batch size 增大就可以实现性能上较好的提升,文章做 了实验验证。在 Batch size 增大到原来 8 倍的时候,生成性能上的 IS 提高 了 46%。文章推测这可能是每批次覆盖更多模式的结果,为生成和判别两个网 络提供更好的梯度。增大 Batch size 还会带来在更少的时间训练出更好性能的 模型,但增大 Batch size 也会使得模型在训练上稳定性下降,后续再分析如何 提高稳定性。

在实验上,单单提高 Batch size 还受到限制,文章在每层的通道数也做了 相应的增加,当通道增加 50%,大约两倍于两个模型中的参数数量。这会导致 IS 进一步提高 21%。文章认为这是由于模型的容量相对于数据集的复杂性而增加。

1.4.2、选择先验分布

z 通过实验对比了 N(0,1)、Bernoulli{0,1}、Censored Normal max(N(0,1), 0),根据参考训练速度、模型性能,文章最终选择了 z∼ N(0,I)。

1.4.3、选择阈值

所谓的“截断技巧”就是通过对从先验分布 z 采样,通过设置阈值的方式 来截断 z 的采样,其中超出范围的值被重新采样以落入该范围内。这个阈值可 以根据生成质量指标 IS 和 FID 决定。 通过实验可以知道通过对阈值的设定,随着阈值的下降生成的质量会越来越 好,但是由于阈值的下降、采样的范围变窄,就会造成生成上取向单一化,造成 生成的多样性不足的问题。往往 IS 可以反应图像的生成质量,FID 则会更假注 重生成的多样性。

1.4.4、尝试控制 G

在探索模型的稳定性上,文章在训练期间监测一系列权重、梯度和损失统计 数据,以寻找可能预示训练崩溃开始的指标。实验发现每个权重矩阵的前三个奇 异值 σ0,σ1,σ2 是最有用的,它们可以使用 Alrnoldi 迭代方法进行有效计 算。

对于奇异值 σ0,大多数 G 层具有良好的光谱规范,但有些层(通常是 G 中 的第一层而非卷积)则表现不佳,光谱规范在整个训练过程中增长,在崩溃时爆 炸。

一顿操作后,文章得出了调节 G 可以改善模型的稳定性,但是无法确保一 直稳定,从而文章转向对 D 的控制。

1.4.5、尝试控制 D

考虑 D 网络的光谱,试图寻找额外的约束来寻求稳定的训练。使用正交正 则化,DropOut 和 L2 的各种正则思想重复该实验,揭示了这些正则化策略的都 有类似行为:对 D 的惩罚足够高,可以实现训练稳定性但是性能成本很高,但 是在图像生成性能上也是下降的,而且降的有点多。

实验还发现 D 在训练期间的损失接近于零,但在崩溃时经历了急剧的向上 跳跃,这种行为的一种可能解释是 D 过度拟合训练集,记忆训练样本而不是学 习真实图像和生成图像之间的一些有意义的边界。

为了评估这一猜测,文章在 ImageNet 训练和验证集上评估判别器,并测量 样本分类为真实或生成的百分比。虽然在训练集下精度始终高于 98%,但验证 准确度在 50-55% 的范围内,这并不比随机猜测更好(无论正则化策略如何)。

这证实了 D 确实记住了训练集,也符合 D 的角色:不断提炼训练数据并为 G 提 供有用的学习信号。 可以通过约束 D 来强制执行稳定性,但这样做会导致性能上的巨大成本。 使用现有技术,通过放松这种调节并允许在训练的后期阶段发生崩溃(人为把握 训练实际),可以实现更好的最终性能,此时模型被充分训练以获得良好的结果。

1.4.6、用分辨率评估模型

在 ImageNet 数据集下做评估,实验在 ImageNet ILSVRC 2012(大家都在 用的 ImageNet 的数据集)上 128×128,256×256 和 512×512 分辨率评估模 型。

1.4.7、验证 G 网络并非是记住训练集

为了进一步说明 G 网络并非是记住训练集,在固定 z 下通过调节条件标签 c 做插值生成,通过下图的实验结果可以发现,整个插值过程是流畅的,也能说 明 G 并非是记住训练集,而是真正做到了图像生成。

1.5、与 GAN 的对比

BigGAN 的主要改进有一下三部分:

(1)通过大规模 GAN 的应用,BigGAN 实现了生成上的巨大突破,参数量 扩大两到四倍,batchsize 扩大八倍;

(2)采用先验分布 z 的“截断技巧”,允许对样本多样性和保真度进行精 细控制;

(3)在大规模 GAN 的实现上不断克服模型训练问题,采用技巧减小训练的 不稳定,但完全的稳定性只能以极高的性能成本实现。


http://chatgpt.dhexx.cn/article/59FqMtrp.shtml

相关文章

BigGAN(2019)

论文引入 我们来看一下由 BigGAN 生成的图像: 随着 GAN、VAE 等一众生成模型的发展,图像生成在这几年是突飞猛进,14 年还在生成手写数字集,到 18 年已经将 ImageNet 生成的如此逼真了。 这中间最大的贡献者应该就是 GAN 了&…

BigGAN_用于高保真自然图像合成的大规模 GAN 训练

【飞桨】【Paddlepaddle】【论文复现】BigGAN 用于高保真自然图像合成的大规模 GAN 训练LARGE SCALE GAN TRAINING FOR HIGH FIDELITY NATURAL IMAGE SYNTHESIS1、BiagGAN的贡献2.1背景2.2具体措施与改变2.2.1规模(scaling up)2.2.2截断技巧(…

环形链表之快慢指针

环形链表 前言一、案例1、环形链表2、环形链表II 二、题解1、环形链表2、环形链表II3、源码4、寻找入环点的数学解法 总结参考文献 前言 对于环形链表,通过快慢指针,如果存在环,这这两个指针一定会相遇,这是一种经典的判断环或是应…

快慢指针判断链表中是否存在环以及查找环的起始位置

判断链表中是否有环? 使用快慢指针, 慢指针一次走一步, 快指针一次走两步, 当快慢指针相遇时,说明链表存在环 为什么快指针每次走两步而慢指针每次走一步呢? 因为slow指针和fast指针都会进入环内, 就像在环形跑道内不同位置的两个人;slow指针在后面, fast指针在前面, 但…

链表-快慢指针(C++)

一、链表 链表是由一组在内存中不必相连(不必相连:可以连续也可以不连续)的内存结构Node,按特定的顺序链接在一起的抽象数据类型。 我们常见的链表结构有单链表和双向链表。 单链表,保存了下一个结点的指针&#xf…

面试题 02.08. 环路检测-快慢指针+如何找到环的入口?(证明)Java

1.题目 2.思路 方法一——哈希表记录节点 思路很简单,记录一下每个节点出现的次数,如果某个节点出现了两次,代表此时有环,并且是环的入口,直接返回即可。 时间复杂度O(N) 空间复杂度O(N) public class Solution {…

链表中快慢指针的应用

目录 一、链表的中间结点 二、回文链表 三、链表中倒数第K个结点 四、删除链表的倒数第n个结点 五、环形链表 六、环形链表Ⅱ 一、链表的中间结点 给定一个头结点为 head 的非空单链表,返回链表的中间结点。 如果有两个中间结点,则返回第二个中间…

快慢指针思想

快慢指针思想 在做题当中经常会用到快慢指针,快慢指针就是定义两根指针,移动的速度一快一慢,从而创造出自己想要指针的差值。这个差值可以让我们找到链表上相应的节点。 参考链接:https://www.jianshu.com/p/21b4b8d7d31b 应用 …

指针的运用——快慢指针

快慢指针是指针的一种类型,在这里我们来了解下快慢指针 让我们来看一道题 一、题目 876. 链表的中间结点 首先我们对这道题进行分析,最容易让人想到的方法是直接使用n/2找到中点,如果我们不对链表进行遍历,我们该怎么做呢&…

浅谈快慢指针

快慢指针 快慢指针 快慢指针1.快慢指针的概念:2.快慢指针的应用:1. 判断单链表是否为循环链表2. 在有序链表中寻找中位数3.链表中倒数第k个节点 1.快慢指针的概念: 快慢指针就是存在两个指针,一个快指针,一个慢指针&a…

快慢指针应用总结

快慢指针 快慢指针中的快慢指的是移动的步长,即每次向前移动速度的快慢。例如可以让快指针每次沿链表向前移动2,慢指针每次向前移动1次。 快慢指针的应用 (1)判断单链表是否存在环 如果链表存在环,就好像操场的跑道是…

十大常用经典排序算法总结!!!

爆肝整理!堪称全网最详细的十大常用经典排序算法总结!!! 写在开头,本文经过参考多方资料整理而成,全部参考目录会附在文章末尾。很多略有争议性的细节都是在不断查阅相关资料后总结的,具有一定…

经典五大算法思想-------入门浅析

算法:求解具体问题的步骤描述,代码上表现出来是解决特定问题的一组有限的指令序列。 1、分治: 算法思想:规模为n的原问题的解无法直接求出,进行问题规模缩减,划分子问题(这里子问题相互独立而且…

算法设计经典算法

一、贪婪算法 1、概述 贪婪法又称贪心算法,是当追求的目标是一个问题的最优解时,设法把对整个问题的求解工作分成若干步骤来完成,是寻找最优解问题的常用方法。 贪婪法的特点是一般可以快速得到满意的解,因为它省去了为找最优解…

算法之经典图算法

图介绍表示图的数据结构图的两种搜索方式DFS可以处理问题BFS可以处理问题有向图最小生成树最短路径 图介绍 图:是一个顶点集合加上一个连接不同顶点对的边的集合组成。定义规定不允许出现重复边(平行边)、连接到顶点自身的边(自环…

计算机10大经典算法

算法一:快速排序法 快速排序是由东尼霍尔所发展的一种排序算法。在平均状况下,排序 n 个项目要Ο(n log n)次比较。在最坏状况下则需要Ο(n2)次比较,但这种状况并不常见。事实上,快速排序通常明显比其…

算法设计——五大算法总结

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 算法设计总结 一、【分治法】二、【动态规划法】三、【贪心算法】四、【回溯法】五、【分支限界法】 一、【分治法】 在计算机科学中,分治法是一种很重要的算法。…

十大经典算法总结

正文 排序算法说明 &#xff08;1&#xff09;排序的定义&#xff1a;对一序列对象根据某个关键字进行排序&#xff1b; 输入&#xff1a;n个数&#xff1a;a1,a2,a3,...,an 输出&#xff1a;n个数的排列:a1,a2,a3,...,an&#xff0c;使得a1<a2<a3<...<an。 再…

九大经典算法

1. 冒泡排序&#xff08;Bubble Sort&#xff09; 两个数比较大小&#xff0c;通过两两交换&#xff0c;像水中的泡泡一样&#xff0c;较大的数下沉&#xff0c;较小的数冒起来。 算法描述&#xff1a; 1.比较相邻的元素。如果第一个比第二个大&#xff0c;就交换它们两个&a…

最常用的五大算法

一、贪心算法 贪心算法&#xff08;又称贪婪算法&#xff09;是指&#xff0c;在对问题求解时&#xff0c;总是做出在当前看来是最好的选择。也就是说&#xff0c;不从整体最优上加以考虑&#xff0c;他所做出的仅是在某种意义上的局部最优解。贪心算法不是对所有问题都能得到整…