Diffusion Model

article/2025/10/19 15:20:55

DDPM

codebase 为 https://github.com/lucidrains/denoising-diffusion-pytorch

训练和推理流程如下:

Train

diffusion() ---> forward() ---> self.p_losses() 完成一个扩散阶段(包括前向计算和 BP),每次前向和 BP 中用到的 t(batch size 个)都是从 {1, 2, 3, ..., T} 中均匀采样得到的。

loss = diffusion(training_images) # training_images 为当前 batch 的输入图像
    def forward(self, img, *args, **kwargs):b, c, h, w, device, img_size, = *img.shape, img.device, self.image_sizeassert h == img_size and w == img_size, f'height and width of image must be {img_size}'t = torch.randint(0, self.num_timesteps, (b,), device=device).long()img = normalize_to_neg_one_to_one(img)return self.p_losses(img, t, *args, **kwargs)
    def p_losses(self, x_start, t, noise = None):b, c, h, w = x_start.shapenoise = default(noise, lambda: torch.randn_like(x_start))x = self.q_sample(x_start = x_start, t = t, noise = noise)model_out = self.model(x, t)target = noiseloss = self.loss_fn(model_out, target, reduction = 'none')loss = reduce(loss, 'b ... -> b (...)', 'mean')loss = loss * extract(self.p2_loss_weight, t, loss.shape)return loss.mean()

self.loss_fn() 为 L2 损失,对应论文中的:

self.q_sample() 输出的是由原图像 x0 和时间 t 计算出当前扩散采样点的 xt:

    def q_sample(self, x_start, t, noise=None):noise = default(noise, lambda: torch.randn_like(x_start))return (extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

送入 U-Net ,之后再与 random noise z 计算损失,再之后梯度回传更新参数完成当前 iter 的训练。

Inference 

diffusion.sample() ---> p_sample_loop() ---> self.p_sample() 完成一次采样:

sampled_images = diffusion.sample(batch_size = 4)
    @torch.no_grad()def sample(self, batch_size = 16):image_size, channels = self.image_size, self.channelssample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_samplereturn sample_fn((batch_size, channels, image_size, image_size))@torch.no_grad()def p_sample_loop(self, shape):...img = torch.randn(8, 3, 128, 128)   # random noisefor t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step'):img = self.p_sample(img, t)return img@torch.no_grad()def p_sample(self, x, t: int, clip_denoised = True):b, *_, device = *x.shape, x.devicebatched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = batched_times, clip_denoised = clip_denoised)noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0return model_mean + (0.5 * model_log_variance).exp() * noise

最终 p_sample() 返回的是当前采样阶段得到的图像。

self.num_timesteps 是采样步数,DDPM 中推理时的采样步数与训练时的 T 保持一致。self.p_mean_variance() 预测当前采样步的均值和方差,DDPM 将方差设为超参数,故只需要预测均值:

    def p_mean_variance(self, x, t, clip_denoised: bool):preds = self.model_predictions(x, t)x_start = preds.pred_x_startif clip_denoised:x_start.clamp_(-1., 1.)model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)return model_mean, posterior_variance, posterior_log_variance

U-Net 需要预测公式中的:

 对应代码中的 model_predictions():

    def model_predictions(self, x, t):model_output = self.model(x, t)     # thetapred_noise = model_output                   #x_start = self.predict_start_from_noise(x, t, model_output) # jun zhireturn ModelPrediction(pred_noise, x_start)def predict_start_from_noise(self, x_t, t, noise):return (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise)

最终输出的均值对应代码:

    def q_posterior(self, x_start, x_t, t):posterior_mean = (extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +extract(self.posterior_mean_coef2, t, x_t.shape) * x_t)posterior_variance = extract(self.posterior_variance, t, x_t.shape)posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)return posterior_mean, posterior_variance, posterior_log_variance_clipped

DDIM

Train

diffusion() ---> forward() ---> self.p_losses()。除了 DDPM 中用到的损失外,在每个扩散阶段 DDIM 增加了一个损失,经过推导的 DDPM 优化目标的中间形式是最小化两个分布间的 KL 散度:

DDIM 的每个扩散阶段加入了这个 KL 散度损失。DDIM 的 self.p_losses() 为:

    def p_losses(self, x_start, t, noise = None, clip_denoised = False):noise = default(noise, lambda: torch.randn_like(x_start))x_t = self.q_sample(x_start = x_start, t = t, noise = noise)# model outputmodel_output = self.model(x_t, t)# calculating kl loss for learned variance (interpolation)true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_t, t = t)model_mean, _, model_log_variance = self.p_mean_variance(x = x_t, t = t, clip_denoised = clip_denoised, model_output = model_output)# kl loss with detached model predicted mean, for stability reasons as in paperdetached_model_mean = model_mean.detach()kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)kl = meanflat(kl) * NATdecoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)decoder_nll = meanflat(decoder_nll) * NAT# at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))vb_losses = torch.where(t == 0, decoder_nll, kl)# simple loss - predicting noise, x0, or x_prevpred_noise, _ = model_output.chunk(2, dim = 1)simple_losses = self.loss_fn(pred_noise, noise)return simple_losses + vb_losses.mean() * self.vb_loss_weight

Inference

diffusion.sample() ---> self.ddim_sample()。 

    @torch.no_grad()def ddim_sample(self, shape, clip_denoised = True):batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objectivetimes = torch.linspace(0., total_timesteps, steps = sampling_timesteps + 2)[:-1]times = list(reversed(times.int().tolist()))time_pairs = list(zip(times[:-1], times[1:]))img = torch.randn(shape, device = device)for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):alpha = self.alphas_cumprod_prev[time]alpha_next = self.alphas_cumprod_prev[time_next]time_cond = torch.full((batch,), time, device = device, dtype = torch.long)pred_noise, x_start, *_ = self.model_predictions(img, time_cond)if clip_denoised:x_start.clamp_(-1., 1.)sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()c = ((1 - alpha_next) - sigma ** 2).sqrt()noise = torch.randn_like(img) if time_next > 0 else 0.img = x_start * alpha_next.sqrt() + \c * pred_noise + \sigma * noiseimg = unnormalize_to_zero_to_one(img)return img

采样时用 Xt 计算出 Xt-1,self.model_predictions() 输出的 x_start 为下式中的 predicted x0 项。


http://chatgpt.dhexx.cn/article/qirchAWd.shtml

相关文章

使用Batch Normalization解决VAE训练中的后验坍塌(posterior collapse)问题

前言 在训练VAE模型时,当我们使用过于过于强大的decoder时,尤其是自回归式的decoder比如LSTM时,存在一个非常大的问题就是,decoder倾向于不从latent variable z中学习,而是独立地重构数据,这个时候&#x…

【论文模型讲解】Learning to Select Knowledge for Response Generation in Dialog Systems(PostKS模型)

文章目录 前言背景Posterior Knowledge Selection 模型(PostKS)1. 对话编码器&知识编码器(Utterance Encoder&Knowledge Encoder)2. 知识管理器(Knowledge Manager)3. 解码器4. 损失函数 Q & A1. 先验知识模块相关问题 前言 论文网址&#…

引导方法深度补全系列—晚期融合模型—1—《Dense depth posterior (ddp) from single image and sparse range》文章细读

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 目录 创新点 实施细节 对比sparse_RGBD tips 方法详解 损失函数 优缺点 总结 创新点 1.提出了基于贝叶斯理论的两步法网络做深度补全 文章概述 提出了两步法,实际…

扩散模型(Diffusion model)代码详细解读

扩散模型代码详细解读 代码地址:denoising-diffusion-pytorch/denoising_diffusion_pytorch.py at main lucidrains/denoising-diffusion-pytorch (github.com) 前向过程和后向过程的代码都在GaussianDiffusion​这个类中。​ 有问题可以一起讨论! …

DIFFUSION POSTERIOR SAMPLING FOR GENERALNOISY INVERSE PROBLEMS (Paper reading)

DIFFUSION POSTERIOR SAMPLING FOR GENERALNOISY INVERSE PROBLEMS Hyungjin Chung, Kim Jae Chul Graduate School of AI, ICLR 2023 spotlight, Cited:10, Code, Paper. 目录子 DIFFUSION POSTERIOR SAMPLING FOR GENERALNOISY INVERSE PROBLEMS1. 前言2. 整体思想3. 方法实…

Exploring and Distilling Posterior and Prior Knowledge for Radiology Report Generation

Exploring and Distilling Posterior and Prior Knowledge for Radiology Report Generation(探索和提炼后验和先验知识的放射学报告生成) 先验与后验目前的放射学报告生成的局限性Paper的贡献模型详解模型输入模型主要部分 先验与后验 在阅读这篇Paper…

国际会议poster: 海报制作流程 格式介绍

1 流程 word制作, 转pdf, 打印 2 模板 UCLhttps://wenku.baidu.com/view/034bcb7e4a7302768f99392a.html 3 CYBER2019格式要求 海报尺寸:A1尺寸23.4英寸(59.4厘米)宽,33.1英寸(84.1厘米)高。 请注意,以A4尺寸列印已递交的整张纸作为海报是不可…

概率基础 · 联合概率 边缘概率 prior posterior likelihood

概率基础 联合概率 边缘概率 prior posterior likelihood 联合概率 (Joint Probability)边缘概率(margin probability)贝叶斯定理(Bayes Theorem)prior,posterior,likelihood:概率与似然的区别…

Stochastic Image Denoising By Sampling from the Posterior Distribution (Paper reading)

Stochastic Image Denoising By Sampling from the Posterior Distribution (Paper reading) Bahjat Kawar, Haifa(Israel), ICCV Workshop2021, Cited:22, Code:无, Paper. 目录子 Stochastic Image Denoising By Sampling from the Posterior Distribution (Paper reading)1…

GAN论文精读 P2GAN: Posterior Promoted GAN 用鉴别器产生的后验分布来提升生成器

《Posterior Promoted GAN with Distribution Discriminator for Unsupervised Image Synthesis》是大连理工学者发表的文章,被2021年CVPR收录。 文章地址:https://ieeexplore.ieee.org/document/9578672 本篇文章是阅读这篇论文的精读笔记。 一、原文…

先验、后验与似然

在学习SLAM 14讲第六章时,看到三个概念,有些不太了解,查阅资料后有了一些自己的理解。 三个概念存在于贝叶斯公式中 表示先验概率Prior,表示后验概率posterior,表示似然likelihood 上式可以写为 下面分别对三个概念进…

Prior 、Posterior 和 Likelihood 的理解与几种表达方式

Prior 、Posterior 和 Likelihood 的理解与几种表达方式 (下载图片可以看大图。)

Windows作为NTP同步时间的服务器时的设置

1.先关闭Windows系统自带的防火墙; 2. 在桌面上右击“计算机”, 选择“管理”, 然后选择“服务”。 具体如图所示 2. 选中“Windows Time”,设置为开启,这样就可以将“Windows Time”这一个服务打开。 3. “开始”--》“运…

NTP时钟服务器推荐-国内时间服务器顶尖设备

电子钟时间服务器在物联网应用中起到了关键的作用,它能够为各种智能设备提供准确的时间参考,确保设备之间的协同工作和数据的准确传输。无论是智能家居、智能工厂还是智慧城市,电子钟时间服务器都是不可或缺的一部分。 一、产品卖点 时间服…

NTP同步时间失败。Linux作为客户端,Windows作为NTP时钟源服务端。

使用windows作NTP时钟源,NTP同步时间失败 【关 键 词】:NTP,时钟源,windows时钟源,同步时间失败 【故障类型】:操作维护->其他 【适用版本】:Linux 【问题描述】:windows做时钟…

如何在windows10 搭建 NTP 时间服务器

windows本身是可以作为NTP时间同步服务器使用的,本文介绍一下如何在win10上配置NTP时间同步服务器。 如何在windows10 搭建 NTP 时间服务器 工具/原料 系统版本:win10版本 [10.0.17134.706] 方法/步骤 使用组合键WIN R 启动运行窗口,在…

Linux服务器NTP客户端时钟同步配置方法

前提说明:配置客户端NTP时候,必须要有一台时钟服务器,可以是服务器搭建的,也可以是购买的时钟设备。我这里使用临时的时钟服务器IP地址10.10.4.100 步骤如下: 1 首先在客户端服务器中ping一下时钟的IP地址是否网络可通…

NTP时钟服务器(PTP服务器)无法同步的排查方法

NTP时钟服务器(PTP服务器)无法同步的排查方法 NTP时钟服务器(PTP服务器)无法同步的排查方法 NTP系统是典型的C-S模型,一般将整个系统分为服务器,网络和客户端三个区域,因NTP时间服务器一般在出厂时已经测试,并设置为可使用&#…

NTP时间服务器同步时钟系统安装汇总分享

在现代科技发展的背景下,各种设备的时间同步变得越来越重要。同步时钟管理系统的应用可以让多个设备在时间上保持一致,提高工作效率和安全性,为各个行业的发展提供了重要的支持。 一、同步时钟系统介绍 同步时钟管理系统的应用范围非常广泛&…

关于NTP时间服务器

NTP(Network Time Protocol) 网络时间协议,工作在UDP的123端口上。是用来使计算机时间同步化的一种协议,它可以使计算机对其服务器或时钟源(如石英钟,GPS等等)做同步化,它可以提供高精准度的时间校正(局域网…