扩散模型(Diffusion model)代码详细解读

article/2025/10/19 15:09:36

扩散模型代码详细解读

代码地址:denoising-diffusion-pytorch/denoising_diffusion_pytorch.py at main · lucidrains/denoising-diffusion-pytorch (github.com)

前向过程和后向过程的代码都在GaussianDiffusion​这个类中。​

有问题可以一起讨论!

常见问题解决

Why self-conditioning? · Issue #94 · lucidrains/denoising-diffusion-pytorch (github.com)

"pred_x0" preforms better than "pred_noise" · Issue #58 · lucidrains/denoising-diffusion-pytorch (github.com)

What is objective=pred_x0 and how do you use it? · Issue #34 · lucidrains/denoising-diffusion-pytorch (github.com)

Conditional generation · Issue #7 · lucidrains/denoising-diffusion-pytorch (github.com)

Questions About DDPM · Issue #10 · lucidrains/denoising-diffusion-pytorch (github.com)
The difference between pred_x0, pred_v, pred_noise three objectives · Issue #153 · lucidrains/denoising-diffusion-pytorch (github.com)

前向训练过程

p_losses

首先是p_losses函数,这个是训练过程的主体部分。

def p_losses(self, x_start, t, noise = None):b, c, h, w = x_start.shape# 首先随机生成噪声noise = default(noise, lambda: torch.randn_like(x_start))# noise sample# 噪声采样,注意这个是一次性完成的x = self.q_sample(x_start = x_start, t = t, noise = noise)# if doing self-conditioning, 50% of the time, predict x_start from current set of times# and condition with unet with that# this technique will slow down training by 25%, but seems to lower FID significantly# 判断是否进行self-condition,就是利用前面步骤预测出的x0来辅助当前的预测x_self_cond = Noneif self.self_condition and random() < 0.5:with torch.no_grad():x_self_cond = self.model_predictions(x, t).pred_x_startx_self_cond.detach_()# predict and take gradient step# 将采样的x和self condition的x一起输入到model当中,这个model是UNet结构model_out = self.model(x, t, x_self_cond)# 模型预测的目标,分为三种if self.objective == 'pred_noise':target = noiseelif self.objective == 'pred_x0':target = x_startelif self.objective == 'pred_v':v = self.predict_v(x_start, t, noise)target = velse:raise ValueError(f'unknown objective {self.objective}')# 计算损失loss = self.loss_fn(model_out, target, reduction = 'none')loss = reduce(loss, 'b ... -> b (...)', 'mean')loss = loss * extract(self.p2_loss_weight, t, loss.shape)return loss.mean()

对其中的extract函数进行分析,extract函数实现如下:

def extract(a, t, x_shape):# Extract some coefficients at specified timesteps,# then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.b, *_ = t.shape# 使用了gather函数out = a.gather(-1, t)return out.reshape(b, *((1,) * (len(x_shape) - 1)))

q_sample

然后介绍p_losses函数中使用的其他函数,第一个是q_sample函数,它的作用是加上噪声,对应论文的公式:
在这里插入图片描述

其中self.sqrt_alphas_cumprod​和self.sqrt_one_minus_alphas_cumprod​分别是alpha的累乘值和1-alpha的累乘值,x_start相当于x0,noise相当于z。

def q_sample(self, x_start, t, noise=None):noise = default(noise, lambda: torch.randn_like(x_start))return (extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

model_predictions

然后是model_predictions函数,它的实现如下:

def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False):# 输入到UNet结构中获得输出model_output = self.model(x, t, x_self_cond)maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity# 暂不明确它的作用if self.objective == 'pred_noise':pred_noise = model_outputx_start = self.predict_start_from_noise(x, t, pred_noise)x_start = maybe_clip(x_start)elif self.objective == 'pred_x0':x_start = model_outputx_start = maybe_clip(x_start)pred_noise = self.predict_noise_from_start(x, t, x_start)elif self.objective == 'pred_v':v = model_outputx_start = self.predict_start_from_v(x, t, v)x_start = maybe_clip(x_start)pred_noise = self.predict_noise_from_start(x, t, x_start)# 返回得到的噪声和return ModelPrediction(pred_noise, x_start)

几种objective

model_predictions函数中有一个难点,就是其中的self.objective,它有三种形式:

  • pred_noise:这个相当于是预测噪声,此时UNet模型的输出是噪声
  • pred_x0:这个相当于是预测最开始的x,此时UNet模型的输出是去噪的图像
  • pred_v:这个相当于是预测速度v,它在这篇文章中提出。然后根据速度求出最开始的x,最后预测出噪声。

如图所示:​
在这里插入图片描述

在上面的三种objective中,还涉及到了几种预测方法的实现,具体如下:

(1)predict_start_from_noise:这个函数的作用是根据噪声noise预测最开始的x,也就是去噪的图像。

其中self.sqrt_recip_alphas_cumprod​和self.sqrt_recipm1_alphas_cumprod​来自在这里插入图片描述
公式,它们分别为:在这里插入图片描述
在这里插入图片描述

公式来源文章:DDPM

def predict_start_from_noise(self, x_t, t, noise):return (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise)

它对应论文中的公式如下:
在这里插入图片描述

(2)predict_noise_from_start:这个函数的作用是根据图像预测噪声,也就是加噪声。

def predict_noise_from_start(self, x_t, t, x0):return ((extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape))

它对应论文中的公式如下:
在这里插入图片描述
需要注意它是反推过来的,过程如下:

(3)predict_v:预测速度v

 def predict_v(self, x_start, t, noise):return (extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start)

它对应论文中的公式:在这里插入图片描述

(4)predict_start_from_v:根据速度v预测最初的x,也就是图像

def predict_start_from_v(self, x_t, t, v):return (extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v)

它对应论文中的公式如下:在这里插入图片描述其中zt相当于xt。

后向采样过程

sample函数

@torch.no_grad()
def sample(self, batch_size = 16, return_all_timesteps = False):image_size, channels = self.image_size, self.channels# 采样的函数sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample# 调用该函数return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)

该函数的作用是获取采样的函数然后进行调用,采样函数分成两种:p_sample_loop和ddim_sample。

p_sample_loop函数

 @torch.no_grad()def p_sample_loop(self, shape, return_all_timesteps = False):batch, device = shape[0], self.betas.device# 随机生成噪声图像img = torch.randn(shape, device = device)imgs = [img]x_start = None# 遍历所有的tfor t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):# 判断是否使用self-conditionself_cond = x_start if self.self_condition else None# 进行采样,得到去噪的图像img, x_start = self.p_sample(img, t, self_cond)imgs.append(img)# 判断是否返回每个步骤的img还是最后一步的imgret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)# 归一化ret = self.unnormalize(ret)return ret

其中涉及到归一化函数self.unnormalize​,含有两种

# normalization functions
def normalize_to_neg_one_to_one(img):return img * 2 - 1
def unnormalize_to_zero_to_one(t):return (t + 1) * 0.5

p_sample函数

@torch.no_grad()
def p_sample(self, x, t: int, x_self_cond = None):b, *_, device = *x.shape, x.devicebatched_times = torch.full((b,), t, device = x.device, dtype = torch.long)# 获得平均值,方差和x0model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)# 随机生成一个噪声	  noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0# 得到预测的图像,img = 平均值 + exp(0.5 * 方差) * noisepred_img = model_mean + (0.5 * model_log_variance).exp() * noisereturn pred_img, x_start

p_mean_variance函数

其中含有p_mean_variance​函数,代码实现如下:

def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):# 输入到UNet网络进行预测preds = self.model_predictions(x, t, x_self_cond)# 得到预测的x0x_start = preds.pred_x_start# 压缩x0中值的范围至[-1,1]if clip_denoised:x_start.clamp_(-1., 1.)# 得到x0后根据xt和t得到分布的平均值和方差model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)return model_mean, posterior_variance, posterior_log_variance, x_start

q_posterior函数

其中q_posterior​函数的实现如下:

def q_posterior(self, x_start, x_t, t):# 计算平均值posterior_mean = (extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +extract(self.posterior_mean_coef2, t, x_t.shape) * x_t)# 计算方差posterior_variance = extract(self.posterior_variance, t, x_t.shape)# 获得一个压缩范围的方差,且取对数posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)return posterior_mean, posterior_variance, posterior_log_variance_clipped

平均值和方差对应的公式如下:

在这里插入图片描述

其中self.posterior_mean_coef1​对应的是x0前面的系数,self.posterior_mean_coef2​对应的是xt前面的系数。

self.posterior_variance​对应的beta那部分的系数。

ddim_sample函数

@torch.no_grad()
def ddim_sample(self, shape, return_all_timesteps = False):batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objectivetimes = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timestepstimes = list(reversed(times.int().tolist()))time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]img = torch.randn(shape, device = device)imgs = [img]x_start = Nonefor time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):time_cond = torch.full((batch,), time, device = device, dtype = torch.long)self_cond = x_start if self.self_condition else Nonepred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True)imgs.append(img)if time_next < 0:img = x_startcontinuealpha = self.alphas_cumprod[time]alpha_next = self.alphas_cumprod[time_next]sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()c = (1 - alpha_next - sigma ** 2).sqrt()noise = torch.randn_like(img)img = x_start * alpha_next.sqrt() + \c * pred_noise + \sigma * noiseret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)ret = self.unnormalize(ret)return ret

上面部分依据的公式为:(文章)
在这里插入图片描述
在这里插入图片描述

训练的模型(UNet)

后续会继续更新!
对您有帮助请点赞收藏哦!


http://chatgpt.dhexx.cn/article/U7RJJAM9.shtml

相关文章

DIFFUSION POSTERIOR SAMPLING FOR GENERALNOISY INVERSE PROBLEMS (Paper reading)

DIFFUSION POSTERIOR SAMPLING FOR GENERALNOISY INVERSE PROBLEMS Hyungjin Chung, Kim Jae Chul Graduate School of AI, ICLR 2023 spotlight, Cited:10, Code, Paper. 目录子 DIFFUSION POSTERIOR SAMPLING FOR GENERALNOISY INVERSE PROBLEMS1. 前言2. 整体思想3. 方法实…

Exploring and Distilling Posterior and Prior Knowledge for Radiology Report Generation

Exploring and Distilling Posterior and Prior Knowledge for Radiology Report Generation&#xff08;探索和提炼后验和先验知识的放射学报告生成&#xff09; 先验与后验目前的放射学报告生成的局限性Paper的贡献模型详解模型输入模型主要部分 先验与后验 在阅读这篇Paper…

国际会议poster: 海报制作流程 格式介绍

1 流程 word制作&#xff0c; 转pdf, 打印 2 模板 UCLhttps://wenku.baidu.com/view/034bcb7e4a7302768f99392a.html 3 CYBER2019格式要求 海报尺寸:A1尺寸23.4英寸(59.4厘米)宽&#xff0c;33.1英寸(84.1厘米)高。 请注意&#xff0c;以A4尺寸列印已递交的整张纸作为海报是不可…

概率基础 · 联合概率 边缘概率 prior posterior likelihood

概率基础 联合概率 边缘概率 prior posterior likelihood 联合概率 (Joint Probability)边缘概率&#xff08;margin probability&#xff09;贝叶斯定理&#xff08;Bayes Theorem&#xff09;prior&#xff0c;posterior&#xff0c;likelihood&#xff1a;概率与似然的区别…

Stochastic Image Denoising By Sampling from the Posterior Distribution (Paper reading)

Stochastic Image Denoising By Sampling from the Posterior Distribution (Paper reading) Bahjat Kawar, Haifa(Israel), ICCV Workshop2021, Cited:22, Code:无, Paper. 目录子 Stochastic Image Denoising By Sampling from the Posterior Distribution (Paper reading)1…

GAN论文精读 P2GAN: Posterior Promoted GAN 用鉴别器产生的后验分布来提升生成器

《Posterior Promoted GAN with Distribution Discriminator for Unsupervised Image Synthesis》是大连理工学者发表的文章&#xff0c;被2021年CVPR收录。 文章地址&#xff1a;https://ieeexplore.ieee.org/document/9578672 本篇文章是阅读这篇论文的精读笔记。 一、原文…

先验、后验与似然

在学习SLAM 14讲第六章时&#xff0c;看到三个概念&#xff0c;有些不太了解&#xff0c;查阅资料后有了一些自己的理解。 三个概念存在于贝叶斯公式中 表示先验概率Prior&#xff0c;表示后验概率posterior&#xff0c;表示似然likelihood 上式可以写为 下面分别对三个概念进…

Prior 、Posterior 和 Likelihood 的理解与几种表达方式

Prior 、Posterior 和 Likelihood 的理解与几种表达方式 &#xff08;下载图片可以看大图。&#xff09;

Windows作为NTP同步时间的服务器时的设置

1.先关闭Windows系统自带的防火墙; 2. 在桌面上右击“计算机”&#xff0c; 选择“管理”&#xff0c; 然后选择“服务”。 具体如图所示 2. 选中“Windows Time”&#xff0c;设置为开启&#xff0c;这样就可以将“Windows Time”这一个服务打开。 3. “开始”--》“运…

NTP时钟服务器推荐-国内时间服务器顶尖设备

电子钟时间服务器在物联网应用中起到了关键的作用&#xff0c;它能够为各种智能设备提供准确的时间参考&#xff0c;确保设备之间的协同工作和数据的准确传输。无论是智能家居、智能工厂还是智慧城市&#xff0c;电子钟时间服务器都是不可或缺的一部分。 一、产品卖点 时间服…

NTP同步时间失败。Linux作为客户端,Windows作为NTP时钟源服务端。

使用windows作NTP时钟源&#xff0c;NTP同步时间失败 【关 键 词】&#xff1a;NTP&#xff0c;时钟源&#xff0c;windows时钟源&#xff0c;同步时间失败 【故障类型】&#xff1a;操作维护->其他 【适用版本】&#xff1a;Linux 【问题描述】&#xff1a;windows做时钟…

如何在windows10 搭建 NTP 时间服务器

windows本身是可以作为NTP时间同步服务器使用的&#xff0c;本文介绍一下如何在win10上配置NTP时间同步服务器。 如何在windows10 搭建 NTP 时间服务器 工具/原料 系统版本&#xff1a;win10版本 [10.0.17134.706] 方法/步骤 使用组合键WIN R 启动运行窗口&#xff0c;在…

Linux服务器NTP客户端时钟同步配置方法

前提说明&#xff1a;配置客户端NTP时候&#xff0c;必须要有一台时钟服务器&#xff0c;可以是服务器搭建的&#xff0c;也可以是购买的时钟设备。我这里使用临时的时钟服务器IP地址10.10.4.100 步骤如下&#xff1a; 1 首先在客户端服务器中ping一下时钟的IP地址是否网络可通…

NTP时钟服务器(PTP服务器)无法同步的排查方法

NTP时钟服务器(PTP服务器)无法同步的排查方法 NTP时钟服务器(PTP服务器)无法同步的排查方法 NTP系统是典型的C-S模型&#xff0c;一般将整个系统分为服务器&#xff0c;网络和客户端三个区域&#xff0c;因NTP时间服务器一般在出厂时已经测试&#xff0c;并设置为可使用&#…

NTP时间服务器同步时钟系统安装汇总分享

在现代科技发展的背景下&#xff0c;各种设备的时间同步变得越来越重要。同步时钟管理系统的应用可以让多个设备在时间上保持一致&#xff0c;提高工作效率和安全性&#xff0c;为各个行业的发展提供了重要的支持。 一、同步时钟系统介绍 同步时钟管理系统的应用范围非常广泛&…

关于NTP时间服务器

NTP(Network Time Protocol) 网络时间协议&#xff0c;工作在UDP的123端口上。是用来使计算机时间同步化的一种协议&#xff0c;它可以使计算机对其服务器或时钟源&#xff08;如石英钟&#xff0c;GPS等等)做同步化&#xff0c;它可以提供高精准度的时间校正&#xff08;局域网…

R语言产生对角阵、次对角阵等矩阵及矩阵运算

R语言产生各种类型的矩阵及矩阵运算 R语言产生一般的矩阵R语言产生单位阵R语言产生次对角阵R语言矩阵的常见运算 R语言产生一般的矩阵 # 依行排列&#xff0c;产生3行5列的矩阵 A matrix(c(1:15),3,5,byrowT)R语言产生单位阵 #产生对角线元素为1的6x6的单位阵 A diag(6) #产…

python课程设计矩阵对角线之和_python对角矩阵

广告关闭 腾讯云11.11云上盛惠 &#xff0c;精选热门产品助力上云&#xff0c;云服务器首年88元起&#xff0c;买的越多返的越多&#xff0c;最高返5000元&#xff01; #生成一个3*3的0-10之间的随机整数矩阵&#xff0c;如果需要指定下界则可以多加一个参数data5mat(random.…

SimpleMind Pro(电脑版思维导图软件)官方中文版V1.30.0.6068下载 | 电脑版思维导图软件哪个好用?

​ Simplemind Pro 是一款优秀的跨平台电脑版思维导图软件领导者&#xff0c;全球超过1000万用户&#xff0c;可帮助用户组织想法、记住信息并产生新想法&#xff0c;允许用户将主题放置在自由格式布局中的任何位置&#xff0c;或者使用各种自动布局之一&#xff0c;非常…

免费的思维导图软件都有哪些?

思维导图时当下非常热门的软件&#xff0c;学生可以用它来梳理课程知识、帮助巩固记忆&#xff1b;职场打工人可以用它来整理思路、列举待办清单、展示方案等等。但是&#xff0c;现在大部分思维导图软件都需要收费&#xff0c;作为钱包紧紧的新时代人类&#xff0c;还剩下哪些…