PyTorch笔记 - MAE(Masked Autoencoders) PyTorch源码

article/2025/11/10 20:34:34

欢迎关注我的CSDN:https://blog.csdn.net/caroline_wendy
本文地址:https://blog.csdn.net/caroline_wendy/article/details/128382935

Paper:MAE - Masked Autoencoders Are Scalable Vision Learners

  • 掩码的自编码器是可扩展的视觉学习器

  • Kaiming He,FAIR

Code:https://github.com/facebookresearch/mae

MAE结构:

image-20221219114816384

ViT的不同类型:An Image is worth 16X16 words: Transformers for Image recognition at scale

  • 一张图片相当于 16X16 个字: 用于大规模图像识别的Transformers,Google Research

image-20221219113705397

源码框架

data preprocesss:

  1. Image2tensor: 读取图像
    1. RGB 3channels
    2. PIL.Image.open + convert("RGB"), or torchvision.datasets.ImageFolder
      1. ImageFolder,需要固定的文件夹格式
    3. shape: (C, H, W), dtype: uint8
      1. unsigned integer 8bit, binary digit
      2. 挑选颜色,#000000(黑色) / #FFFFFF(白色),16进制,256=16x16
      3. min-max归一化,正态分布归一化
  2. augment: Crop/Resize/Flip,500x500,随机截取100x100 -> 224x224
  3. convert:
    1. torchvision.transforms.ToPILImage
    2. torchvision.transforms.PILToTensor,uint8转换为0~1之间的浮点数
    3. [0, 1]
  4. normalize:
    1. image=(image-mean)/std, global-level,均值为0,方差为1,mv归一化
    2. imagenet1k: mean [0.485,0.456,0.406], std [0.229, 0.224, 0.225]

model:

  1. encoder: 特征提取/分类任务,标准的ViT
    1. Image2patch2embedding
    2. position embedding,正余弦固定的embedding
    3. random masking (shuffle),随机mask,划分之后随机打散,取前面25%
    4. class token,添加class token
    5. Transformer Blocks (ViT-base / ViT-large / ViT-huge),MLP size是Hidden size的4倍,layers=blocks
  2. decoder:
    1. projection_layer
    2. unshuffle,还原到mask当中
    3. position embedding,解码器与编码器不同
    4. Transformer Blocks (shallow),轻量级的ViT
    5. regression layer,回归层MLP,把每个patch投影到像素空间
    6. mse loss function (norm pixel),重建相同图像
  3. forward functions
    1. forward encoder,编码器需要单细使用
    2. forward decoder,接收encoder的输入
    3. forward loss,为整体前向推理服务

traning:

  1. dataset,构建数据集,类似于torchvision.datasets.ImageFolder,xy元组的生成器
  2. data_loader,取sample拼成mini_batch
  3. model,模型
  4. optimizer,优化器
  5. load_model,加载模型
    1. model.state_dict()
    2. optimizer.state_dict(),parameter的辅助变量
    3. epoch,影响学习率
  6. train_one_epoch,核心的训练函数
  7. save_model
    1. model.state_dict()
    2. optimizer.state_dict()
    3. epoch/loss
    4. config

finetuning: 只需要用到编码器encoder,不需要使用解码器decoder

  1. strong augmentation,微调需要增加强增强
  2. bulder encoder + BN + MLP classifier head,注意额外增加BN
  3. interpolate(差值) position embedding,预训练和实际的patch数量不同
  4. load pre-trained model (strict=False),A模型加载B模型的参数,只要有相同的层,就会加载,不同的层会提示,使用随机。
  5. update all parameters
  6. AdamW optimizer
  7. label smoothing cross-entropy loss

linear probing:

  1. weak augmentation,增强较弱
  2. bulder encoder + BN(no affine仿射,去掉w和b) + MLP classifier head
  3. interpolate(差值) position embedding
  4. load pre-trained model (strict=False)
  5. only update parameter of MLP classifier head
  6. LARS(Layer-wise Adaptive Rate Scaling) optimizer
    • 各个层更新参数所使用的学习率,根据当前情况有所调整,而不是所有层使用相同的学习率,也就是每层有自己的local lr。
  7. cross-entropy loss

evaluation:

  1. with torch.no_grad() -> efficient
  2. model.eval() -> accurate BN/dropout,BN和Dropout都起到正确作用
  3. top_k: top_1 or top_5

源码:

main文件、engine文件、models文件

  • models_mae.py

  • models_vit.py

models_mae.py

models_mae.py: 默认使用ViT-large,hidden_size=1024mlp_ratio扩大4倍

  • img_size: 图像尺寸
  • patch_size: patch尺寸
  • self.cls_token: 可训练的类别token
  • self.pos_embed: 固定的位置编码,不可训练,requires_grad=False,正余弦编码,cls_token也需要1个位置编码
  • 编码器encoder的Transformer Block,即self.blocks
  • 编码器的embedding映射到解码器的embedding
  • self.mask_token表示被遮挡的patch
  • 解码器decoder也需要使用cls_token
  • 回归的像素值,patch_size**2 * in_chans,图像面积乘以通道
  • 初始化pos_embeddecoder_pos_embed,数据相同。
  • self.cls_tokenself.mask_token,初始化整体分布
  • def initialize_weights(self),初始化所以参数
  • patchify图像变成块、unpatchify块变成图像
  • def forward_encoder(self, x, mask_ratio):,前向处理encoder
  • def forward_decoder(self, x, ids_restore):,前向处理decoder
  • x, mask, ids_restore = self.random_masking(x, mask_ratio),随机掩码操作
  • 随机噪声矩阵:noise = torch.rand(N, L, device=x.device) # noise in [0, 1],噪声只是为了排序
  • ids_keep = ids_shuffle[:, :len_keep],序列长度需要保留的索引
  • mask = torch.gather(mask, dim=1, index=ids_restore),获取被掩码的位置
  • 位置编码:x = x + self.pos_embed[:, 1:, :]cls_token = self.cls_token + self.pos_embed[:, :1, :]
  • 使用重复的mask_tokenmask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
  • pos_embeddecoder_pos_embed,编码不同,维度不同,长度也不同。
  • L2 Loss: loss = (pred - target) ** 2,支持归一化target,target = (target - mean) / (var + 1.e-6)**.5
  • 只计算mask部分的loss,loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches

支持3种模式:

# set recommended archs
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks

获取random_mask的ids:

import torchx = torch.rand(5)
print(f"[Info] x: {x}")
idx_shuffle = torch.argsort(x)
print(f"[Info] idx_shuffle: {idx_shuffle}")
idx_restore = torch.argsort(idx_shuffle)
print(f"[Info] idx_restore: {idx_restore}")

main_pretrain.py

从命令行获取参数:

args = get_args_parser()
args = args.parse_args()
main(args)

固定随机种子:

seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)

数据增强:

transform_train = transforms.Compose([transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3),  # 3 is bicubictransforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
print(dataset_train)

DataLoader:

data_loader_train = torch.utils.data.DataLoader(dataset_train, sampler=sampler_train,batch_size=args.batch_size,num_workers=args.num_workers,pin_memory=args.pin_mem,drop_last=True,
)

逻辑:misc = miscellaneous混杂的

  • misc.init_distributed_mode(args),分布式模式的初始化
  • torch.distributed.barrier(),放在最后一行,等待所有进程初始化完成
  • device = torch.device(args.device),初始化设备
  • 有效的batch_size,eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size(),累计更新
  • 包裹DDP: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
  • 传入参数组:param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
  • 训练单轮:train_one_epoch
  • model.train(True)accum_iter累计更新步骤
  • with torch.cuda.amp.autocast():,支持自动混合精度的计算
  • loss_scaler,loss的更新逻辑

更新日志:

if args.output_dir and misc.is_main_process():if log_writer is not None:log_writer.flush()with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:f.write(json.dumps(log_stats) + "\n")

注释掉:model_mae.pyqk_scale参数,准备训练集,类似ImageFolder的组织形式,即可训练。

main_finetune.py

逻辑:

  • 构建训练集和测试集:dataset_train = build_dataset(is_train=True, args=args)dataset_val = build_dataset(is_train=False, args=args)
  • 使用build_transform,构建transform,训练(is_train=True)较多,验证(is_train=False)较少。
  • DataLoader: data_loader_traindata_loader_val
  • mixup_fn,用于CV的数据增强技巧。
  • models_vit.py中定制的VisionTransformer,参数格式一样,输出略有不同。
  • msg = model.load_state_dict(checkpoint_model, strict=False),导入预训练的模型。
  • 确保新增参数没有导入:assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
  • interpolate_pos_embed(model, checkpoint_model),位置编码进行差值。
  • Loss - MixUp: SoftTargetCrossEntropy(), smoothing: LabelSmoothingCrossEntropy(), plain: CrossEntropyLoss()
  • 加载模型: misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

load_model(加载模型)逻辑:

def load_model(args, model_without_ddp, optimizer, loss_scaler):if args.resume:if args.resume.startswith('https'):checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True)else:checkpoint = torch.load(args.resume, map_location='cpu')model_without_ddp.load_state_dict(checkpoint['model'])print("Resume checkpoint %s" % args.resume)if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):optimizer.load_state_dict(checkpoint['optimizer'])args.start_epoch = checkpoint['epoch'] + 1if 'scaler' in checkpoint:loss_scaler.load_state_dict(checkpoint['scaler'])print("With optim & sched!")

main_linprobe.py的不同:把编码器固定,只训练分类层,同时数据增强较弱。

for _, p in model.named_parameters():p.requires_grad = False
for _, p in model.head.named_parameters():p.requires_grad = True

That’s all.


http://chatgpt.dhexx.cn/article/Ekob53km.shtml

相关文章

何凯明新作MAE 学习笔记

【MAE与之前AI和CV领域最新工作的关系】 学习MAE视频【李沐】 He, K., Chen, X., Xie, S., Li, Y., Dollr, P., & Girshick, R. (2021). Masked autoencoders are scalable vision learners. arXiv preprint arXiv:2111.06377. 【Transformer】 Transforme纯注意力&…

MAE 代码实战详解

MAE 代码实战详解 if__name__"__main__"model.forwardmodel.forward.encordermodel.forward.decordermodel.forward.loss大小排序索引-有点神奇torch.gather if__name__“main” MAE 模型选择 def mae_vit_base_patch16_dec512d8b(**kwargs):model MaskedAutoenco…

MAE(Masked Autoencoders) 详解

MAE详解 0. 引言1. 网络结构1.1 Mask 策略1.2 Encoder1.3 Decoder 2. 关键问题解答2.1 进行分类任务怎么来做?2.2 非对称的编码器和解码器机制的介绍2.3 损失函数是怎么计算的?2.4 bert把mask放在编码端,为什么MAE加在解码端? 3. …

MAE-DET学习笔记

MAE-DET学习笔记 MAE-DET: Revisiting Maximum Entropy Principle in Zero-Shot NAS for Efficient Object Detection Abstract 在对象检测中,检测主干消耗了整个推理成本的一半以上。最近的研究试图通过借助神经架构搜索(NAS)优化主干架构…

MAE论文解读

文章目录 创新点算法原理MaskingMAE encoderMAE decoder重构目标 实验Baseline: ViT-Large.消融实验Mask token自监督方法比较迁移至目标检测任务及语义分割任务 结论 论文: 《Masked Autoencoders Are Scalable Vision Learners》 代码: https://github.com/facebookresearc…

MSE与MAE

均方误差 均方误差(MSE)是最常用的回归损失函数,计算方法是求预测值与真实值之间距离的平方和,公式如图。 下图是MSE函数的图像,其中目标值是100,预测值的范围从-10000到10000,Y轴代表的MSE取值范围是从0到正无穷&…

论文阅读|MAE

Masked Autoencoders Are Scalable Vision Learners 参考资料 Self-Supervised Learning 超详细解读 (六):MAE:通向 CV 大模型 - 知乎 (zhihu.com) Self-Supervised Learning 超详细解读 (目录) - 知乎 (zhihu.com)、 1. 有监督(Supervise…

MAE论文笔记

MAE论文笔记 Masked Autoencoders Are Scalable Vision Learners MAE模型和其他的结构的关系,可以认为是在ViT的基础上实现类似于BERT的通过完型填空获取图片的理解 标题和作者 Masked Autoencoders Are Scalable Vision Learners 其中的Autoencoders 中的auto是…

MAE

背景 作者开门见山说明了深度学习结构拥有越来越大的学习容量和性能的发展趋势,在一百万的图像数据上都很容易过拟合,所以常常需要获取几百万的标签数据用于训练,而这些数据公众通常是难以获取的。MAE的灵感来源是DAE(denosing autoencoder)…

RMSE(均方根误差)、MSE(均方误差)、MAE(平均绝对误差)、SD(标准差)

RMSE(Root Mean Square Error)均方根误差 衡量观测值与真实值之间的偏差。 常用来作为机器学习模型预测结果衡量的标准。 MSE(Mean Square Error)均方误差 MSE是真实值与预测值的差值的平方然后求和平均。 通过平方的形式便于…

【深度学习】详解 MAE

目录 摘要 一、引言 二、相关工作 三、方法 四、ImageNet 实验 4.1 主要属性 4.2 与先前结果的对比 4.3 部分微调 五、迁移学习实验 六、讨论与结论 七、核心代码 Title:Masked Autoencoders Are Scalable Vision LearnersPaper:https://arx…

MAE模型介绍

目录 介绍 模型 ​编辑 实验过程 结论 介绍 Masked Autoencoders Are Scalable Vision Learners Facebook Al的kaiming大神等人于2021年十一月提出了一种带自编码器(MAE),它基于(ViT)架构。他们的方法在imageNet上的表现要好于从零开始训练的VIT。 灵感来源&…

深度学习:MAE 和 RMSE 详解

平均绝对误差MAE(mean absolute error) 和均方根误差 RMSE(root mean squared error)是衡量变量精度的两个最常用的指标,同时也是机器学习中评价模型的两把重要标尺。 那两者之间的差异在哪里?它对我们的生活有什么启示…

RMSE、MAE等误差指标整理

1 MAE Mean Absolute Error ,平均绝对误差是绝对误差的平均值 for x, y in data_iter:ymodel(x)d np.abs(y - y_pred)mae d.tolist()#maesigma(|pred(x)-y|)/m MAE np.array(mae).mean() MAE/RMSE需要结合真实值的量纲才能判断差异。 下图是指,假如g…

MAE详解

目录 一、介绍 二、网络结构 1. encoder 2. decoder 3. LOSS 三、实验 全文参考:论文阅读笔记:Masked Autoencoders Are Scalable Vision Learners_塔_Tass的博客-CSDN博客 masked autoencoders(MAE)是hekaiming大佬又一新作,其做法很…

crontab用法详解

crontab命令用于设置周期性被执行的命令,适用于日志备份,清理缓存,健康状态检测等场合。 crontab的配置文件:/etc/crontab

linux的crontab用法与实例

linux的crontab用法与实例 crontab的适用场景 在Linux系统的实际使用中,可能会经常让系统在某个特定时间执行某些任务的情况,比如定时采集服务器的状态信息、负载状况;定时执行某些任务/脚本来对远端进行数据采集或者备份等操作。 首先通过…

定时任务 crontab 命令安装和用法整理

Crontab 概念 crontab命令常见于Unix和类Unix的操作系统之中,用于设置周期性被执行的指令,类似于闹钟,可以定时执行任务。该命令从标准输入设备读取指令,并将其存放于“crontab”文件中(是“cron table”的简写&#…

crontab用法与实例

crontab用法与实例 本文基于 ubuntu 18.04 在Linux系统的实际使用中,可能会经常碰到让系统在某个特定时间执行某些任务的情况,比如定时采集服务器的状态信息、负载状况;定时执行某些任务/脚本来对远端进行数据采集等。这里将介绍下crontab的配…

crontab的基本用法

1、 crontab -l 查看所有的定时任务 2、 crontab -e 编辑定时任务。 i 进入编辑模式 。esc退出编辑模式。:wq! 保存并退出。 报错信息: “/tmp/crontab.4qE940”:1: bad month errors in crontab file, can’t install. 说明定时任务编辑失败,文件中有错…