MAE 代码实战详解

article/2025/11/10 21:20:12

MAE 代码实战详解

  • if__name__=="__main__"
    • model.forward
      • model.forward.encorder
      • model.forward.decorder
      • model.forward.loss
            • 大小排序索引-有点神奇
            • torch.gather

if__name__==“main

  • MAE 模型选择
def mae_vit_base_patch16_dec512d8b(**kwargs):model = MaskedAutoencoderViT(patch_size=16, embed_dim=768, depth=12, num_heads=12,decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)return model
  • debug 调试
if__name__=="__main__":model = mae_vit_base_patch16_dec512d8b()input = torch.rand(1,3,224,224)output = model(input) # debug

model.forward

    def forward(self, imgs, mask_ratio=0.75):latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]loss = self.forward_loss(imgs, pred, mask)return loss, pred, mask

model.forward.encorder

latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)

  • x = self.patch_embed(x)
    PatchEmbed理解

    x.shape:[B,C,H,W]->[B,H*W,C]

	PatchEmbed((proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))(norm): Identity())
    def forward(self, x):B, C, H, W = x.shape_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")x = self.proj(x)#Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))if self.flatten:x = x.flatten(2).transpose(1, 2)  # BCHW -> B H*W Cx = self.norm(x)#self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()return x

LayerNorm与BatchNorm区别

pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):"""grid_size: int of the grid height and widthreturn:pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)"""grid_h = np.arange(grid_size, dtype=np.float32)grid_w = np.arange(grid_size, dtype=np.float32)grid = np.meshgrid(grid_w, grid_h)  # here w goes first    #X, Y = np.meshgrid(x, y) 代表的是将x中每一个数据和y中每一个数据组合生成很多点,然后将这些点的x坐标放入到X中,y坐标放入Y中,并且相应位置是对应的      x中的元素先移动,(x1,y1),(x2,y1)  ...  (x1,y2),(x2,y2)... grid = np.stack(grid, axis=0)grid = grid.reshape([2, 1, grid_size, grid_size])pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)if cls_token:pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)return pos_embed

np.meshgrid
no.stack 填充

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):assert embed_dim % 2 == 0# use half of dimensions to encode grid_hemb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):"""embed_dim: output dimension for each positionpos: a list of positions to be encoded: size (M,)out: (M, D)"""assert embed_dim % 2 == 0omega = np.arange(embed_dim // 2, dtype=np.float)omega /= embed_dim / 2.omega = 1. / 10000**omega  # (D/2,)pos = pos.reshape(-1)  # (M,)out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer productemb_sin = np.sin(out) # (M, D/2)emb_cos = np.cos(out) # (M, D/2)emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)return emb

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
Transformer学习笔记一:Positional Encoding(位置编码)
如何理解和使用NumPy.einsum?

model.forward.decorder

model.forward.loss

大小排序索引-有点神奇
        # sort noise for each sampleids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove       ’’’只在sequence length 维度进行排序,torch.argsort返回排序后的值所对应原a的下标,即torch.sort()返回的indices’’’ids_restore = torch.argsort(ids_shuffle, dim=1)  # 之前从小到大的数的索引
torch.gather
torch.gather(input, dim, index, out=None) → TensorGathers values along an axis specified by dim.For a 3-D tensor the output is specified by:out[i][j][k] = input[index[i][j][k]][j][k] # dim=0out[i][j][k] = input[i][index[i][j][k]][k] # dim=1out[i][j][k] = input[i][j][index[i][j][k]] # dim=2Parameters: input (Tensor) – The source tensordim (int) – The axis along which to indexindex (LongTensor) – The indices of elements to gatherout (Tensor, optional) – Destination tensorExample:>>> t = torch.Tensor([[1,2],[3,4]])>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))1 14 3[torch.FloatTensor of size 2x2]
 For a 2-D tensor the output is specified by:out[i][j] = input[    index[i][j]   ][j] # dim=0out[i][j] = input[i][    index[i][j][k]   ][k] # dim=1

Example:

 >>> t = torch.Tensor([[1,2],[3,4]])>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))1 14 3
output[i][j]   =   input[i][   index[i][j]   ]#行对应
 >>> t = torch.Tensor([[1,2],[3,4]])>>> torch.gather(t, 0, torch.LongTensor([[0,0],[1,0]]))1 23 2output[i][j]   =   input[   index[i][j]   ][j]#列对应

在这里插入图片描述
参考1


http://chatgpt.dhexx.cn/article/9ZdtdPqr.shtml

相关文章

MAE(Masked Autoencoders) 详解

MAE详解 0. 引言1. 网络结构1.1 Mask 策略1.2 Encoder1.3 Decoder 2. 关键问题解答2.1 进行分类任务怎么来做?2.2 非对称的编码器和解码器机制的介绍2.3 损失函数是怎么计算的?2.4 bert把mask放在编码端,为什么MAE加在解码端? 3. …

MAE-DET学习笔记

MAE-DET学习笔记 MAE-DET: Revisiting Maximum Entropy Principle in Zero-Shot NAS for Efficient Object Detection Abstract 在对象检测中,检测主干消耗了整个推理成本的一半以上。最近的研究试图通过借助神经架构搜索(NAS)优化主干架构…

MAE论文解读

文章目录 创新点算法原理MaskingMAE encoderMAE decoder重构目标 实验Baseline: ViT-Large.消融实验Mask token自监督方法比较迁移至目标检测任务及语义分割任务 结论 论文: 《Masked Autoencoders Are Scalable Vision Learners》 代码: https://github.com/facebookresearc…

MSE与MAE

均方误差 均方误差(MSE)是最常用的回归损失函数,计算方法是求预测值与真实值之间距离的平方和,公式如图。 下图是MSE函数的图像,其中目标值是100,预测值的范围从-10000到10000,Y轴代表的MSE取值范围是从0到正无穷&…

论文阅读|MAE

Masked Autoencoders Are Scalable Vision Learners 参考资料 Self-Supervised Learning 超详细解读 (六):MAE:通向 CV 大模型 - 知乎 (zhihu.com) Self-Supervised Learning 超详细解读 (目录) - 知乎 (zhihu.com)、 1. 有监督(Supervise…

MAE论文笔记

MAE论文笔记 Masked Autoencoders Are Scalable Vision Learners MAE模型和其他的结构的关系,可以认为是在ViT的基础上实现类似于BERT的通过完型填空获取图片的理解 标题和作者 Masked Autoencoders Are Scalable Vision Learners 其中的Autoencoders 中的auto是…

MAE

背景 作者开门见山说明了深度学习结构拥有越来越大的学习容量和性能的发展趋势,在一百万的图像数据上都很容易过拟合,所以常常需要获取几百万的标签数据用于训练,而这些数据公众通常是难以获取的。MAE的灵感来源是DAE(denosing autoencoder)…

RMSE(均方根误差)、MSE(均方误差)、MAE(平均绝对误差)、SD(标准差)

RMSE(Root Mean Square Error)均方根误差 衡量观测值与真实值之间的偏差。 常用来作为机器学习模型预测结果衡量的标准。 MSE(Mean Square Error)均方误差 MSE是真实值与预测值的差值的平方然后求和平均。 通过平方的形式便于…

【深度学习】详解 MAE

目录 摘要 一、引言 二、相关工作 三、方法 四、ImageNet 实验 4.1 主要属性 4.2 与先前结果的对比 4.3 部分微调 五、迁移学习实验 六、讨论与结论 七、核心代码 Title:Masked Autoencoders Are Scalable Vision LearnersPaper:https://arx…

MAE模型介绍

目录 介绍 模型 ​编辑 实验过程 结论 介绍 Masked Autoencoders Are Scalable Vision Learners Facebook Al的kaiming大神等人于2021年十一月提出了一种带自编码器(MAE),它基于(ViT)架构。他们的方法在imageNet上的表现要好于从零开始训练的VIT。 灵感来源&…

深度学习:MAE 和 RMSE 详解

平均绝对误差MAE(mean absolute error) 和均方根误差 RMSE(root mean squared error)是衡量变量精度的两个最常用的指标,同时也是机器学习中评价模型的两把重要标尺。 那两者之间的差异在哪里?它对我们的生活有什么启示…

RMSE、MAE等误差指标整理

1 MAE Mean Absolute Error ,平均绝对误差是绝对误差的平均值 for x, y in data_iter:ymodel(x)d np.abs(y - y_pred)mae d.tolist()#maesigma(|pred(x)-y|)/m MAE np.array(mae).mean() MAE/RMSE需要结合真实值的量纲才能判断差异。 下图是指,假如g…

MAE详解

目录 一、介绍 二、网络结构 1. encoder 2. decoder 3. LOSS 三、实验 全文参考:论文阅读笔记:Masked Autoencoders Are Scalable Vision Learners_塔_Tass的博客-CSDN博客 masked autoencoders(MAE)是hekaiming大佬又一新作,其做法很…

crontab用法详解

crontab命令用于设置周期性被执行的命令,适用于日志备份,清理缓存,健康状态检测等场合。 crontab的配置文件:/etc/crontab

linux的crontab用法与实例

linux的crontab用法与实例 crontab的适用场景 在Linux系统的实际使用中,可能会经常让系统在某个特定时间执行某些任务的情况,比如定时采集服务器的状态信息、负载状况;定时执行某些任务/脚本来对远端进行数据采集或者备份等操作。 首先通过…

定时任务 crontab 命令安装和用法整理

Crontab 概念 crontab命令常见于Unix和类Unix的操作系统之中,用于设置周期性被执行的指令,类似于闹钟,可以定时执行任务。该命令从标准输入设备读取指令,并将其存放于“crontab”文件中(是“cron table”的简写&#…

crontab用法与实例

crontab用法与实例 本文基于 ubuntu 18.04 在Linux系统的实际使用中,可能会经常碰到让系统在某个特定时间执行某些任务的情况,比如定时采集服务器的状态信息、负载状况;定时执行某些任务/脚本来对远端进行数据采集等。这里将介绍下crontab的配…

crontab的基本用法

1、 crontab -l 查看所有的定时任务 2、 crontab -e 编辑定时任务。 i 进入编辑模式 。esc退出编辑模式。:wq! 保存并退出。 报错信息: “/tmp/crontab.4qE940”:1: bad month errors in crontab file, can’t install. 说明定时任务编辑失败,文件中有错…

linux中crontab的用法

一:crontab 简介 crontab是linux下用来周期性的执行某种任务或等待处理某些事件的一个守护进程,与windows下的计划任务类似,当安装完成操作系统后,默认会安装此服务工具,并且会自动启动crond进程,crond进程…

1.4 - 操作系统 - Linux计划任务,CronTab用法详解

「作者简介」:CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者 「订阅专栏」:此文章已录入专栏《网络安全入门到精通》 CronTab计划任务 一、服务二、查看计划任务三、编辑计划任务四、删除计划任务五、配置文件Linux系统使用CronTab命令来操作计划任务。…