迁移学习

article/2025/10/13 1:08:53

390a07bf2ada9a5907403646a4bdfa21.png

简介

好的机器学习模型需要大量数据和许多GPU或TPU进行训练。大多数时候,他们只能执行特定的任务。

大学和大公司有时会发布他们的模型。但很可能你希望开发一个机器学习应用程序,但没有适合你的任务的可用模型。

但别担心,你不必收集大量数据,也不必花费大量资金来开发自己的模型。你可以用迁移学习代替。这减少了训练时间,并且可以用更少的数据获得良好的性能。

什么是迁移学习?

在迁移学习中,我们使用模型在特定任务上收集的知识来解决不同但相关的任务。模型可以从上一个任务中学到的东西中获益,从而更快地学习新任务。

让我们在这里举个例子,假设你想在图像上检测狗。在互联网上,你可以找到一种可以检测猫的模型。由于这是一项非常类似的任务,你需要拍摄几张你的狗的照片,并重新训练模型以检测狗。

也许模型已经学会了通过猫的皮毛或它们有眼睛的事实来识别猫,这对识别狗也会很有帮助。

实际上有两种类型的迁移学习,特征提取和微调。

通常,这两种方法遵循相同的程序:

  • 初始化预训练的模型(我们要学习的模型)

  • 重塑最终层的形状,使其输出数量与新数据集中分类的数量相同

  • 定义要更新的层

  • 训练新数据集

特征提取

让我们考虑一个卷积神经网络结构,滤波器是一个密集层和一个输出神经元。

f283746064b598b965b53e1b4e8465e2.png

该网络经过训练,可以预测图像上有猫的概率。我们需要一个大数据集(有猫和没有猫的图像),而且训练时间很长。此步骤称为“预训练”。

bb03963feaf5605228abaaa3a2b380f3.png

然后是有趣的部分。我们再次训练网络,但这次是用一个包含狗的小图像数据集。

在训练过程中,除输出层外的所有层都被“冻结”。这意味着我们不会在训练期间更新它们。

训练后,网络输出狗在图像上可见的概率。此训练程序所需时间将少于之前的预训练。

28c9c65e1adf900ca291ad40db19b5aa.png

我们还可以选择“解冻”最后两层,即输出层和密集层。这取决于我们拥有的数据量。如果我们有更少的数据,我们可以考虑只训练最后一层。

微调

在微调中,我们从预训练的模型开始,但更新所有权重。

0e1466d2cdc8a95fdec6e2820d544c3d.png

pytorch中的迁移学习示例

将使用kaggle的猫与狗数据集。数据集可以在这里找到。你始终可以使用不同的数据集。

https://www.microsoft.com/en-us/download/details.aspx?id=54765

这里的任务与上面的示例略有不同。该模型用于识别哪些图像上有狗,哪些图像上有猫。要使代码正常工作,你必须按以下结构组织数据:

c30b2ac266a8692a8090bf2bfd9bacdd.png

你可以在这里找到更详细的猫与狗的对比。

https://medium.com/predict/using-pytorch-for-kaggles-famous-dogs-vs-cats-challenge-part-1-preprocessing-and-training-407017e1a10c

安装程序

我们首先导入所需的库。

from __future__ import print_functionfrom __future__ import divisionimport torchimport torch.nn as nnimport torch.optim as optimimport numpy as npimport torchvisionfrom torchvision import datasets, models, transformsimport matplotlib.pyplot as pltimport timeimport osimport copyprint("PyTorch Version: ",torch.__version__)  # PyTorch Version:  1.7.1print("Torchvision Version: ",torchvision.__version__) # Torchvision Version:  0.8.0a0

我们检查是否有与CUDA兼容的CPU,否则将使用该CPU。

2b1e0df0ff909d8daa12725ad8ddcb10.png

然后我们从torch vision加载预训练的ResNet50。

model_conv = torchvision.models.resnet50(pretrained=True)

数据扩充是通过对图像应用不同的变换来完成的,从而防止过拟合。

# 训练数据的增强和标准化# 只是为了验证而进行标准化data_transforms = {'train': transforms.Compose([transforms.RandomRotation(5),transforms.RandomHorizontalFlip(),transforms.RandomResizedCrop(224, scale=(0.96, 1.0), ratio=(0.95, 1.05)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),}

我们创建数据加载器,它将从内存中加载图像。

data_dir = 'data' #数据集目录image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'val']}dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'val']}dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}class_names = image_datasets['train'].classesprint(class_names) # => ['cats', 'dogs']print(f'Train image size: {dataset_sizes["train"]}')print(f'Validation image size: {dataset_sizes["val"]}')

创建学习率调度器,调度器将在训练期间修改学习率。或者,你可以使用ADAM优化器,它可以自动调整学习速率,并且不需要调度器。

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

特征提取优化器

这里只计算最后一层的梯度,因此只训练最后一层。

for param in model_conv.parameters():param.requires_grad = False# 新构造模块的参数默认为require_grad =Truenum_ftrs = model_conv.fc.in_featuresmodel_conv.fc = nn.Linear(num_ftrs, 2)model_conv = model_conv.to(device)criterion = nn.CrossEntropyLoss()# 观察到只有最后一层的参数被优化optimizer_feature_extraction = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)# 衰减因子

用于微调的优化器

在这里,将对所有层进行训练。

818f301272af42ea7c4312945991fef6.png

训练

让我们定义训练循环。

def train_model(model, criterion, optimizer, scheduler, num_epochs=2, checkpoint = None):model.train()  # 将model设置为训练模式for i, (inputs, labels) in enumerate(dataloaders['train']):inputs = inputs.to(device)labels = labels.to(device)# 传送到设备optimizer.zero_grad()outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)#  只有在训练阶段才backward + optimizeloss.backward()optimizer.step()# 统计数据scheduler.step()return model

最后,我们可以训练我们的模型。

使用特征提取:

1f68c7afbbe1deb2321e5895b6db486a.png

或使用微调:

trained_model = train_model(model_conv, criterion, optimizer_fine_tuning, exp_lr_scheduler )

迁移学习的优点

当我向人们推荐他们可以在ML项目中使用迁移学习时,他们有时会拒绝,宁愿自己训练一个模型,也不愿使用迁移学习。但是迁移学习也有很多优点:

  • 训练神经网络使用能源,从而增加全球碳排放。迁移学习通过减少训练时间拯救了世界。

  • 当训练数据不足时,迁移学习可能是让模型表现良好的唯一选择。在计算机视觉中,常常缺少训练数据。

结论

迁移学习对于现代数据科学家来说是一个方便的工具。为了节省时间、计算机资源和减少训练所需的数据量,你可以使用其他人预训练过的模型并对其执行迁移学习。

数据集集合

猫与狗数据集

数据集可在下找到。

https://www.tensorflow.org/datasets/catalog/cats_vs_dogs

感谢阅读!

☆ END ☆

如果看到这里,说明你喜欢这篇文章,请转发、点赞。微信搜索「uncle_pn」,欢迎添加小编微信「 woshicver」,每日朋友圈更新一篇高质量博文。

扫描二维码添加小编↓

6a86168fd79db062c747d847ad8491d1.png


http://chatgpt.dhexx.cn/article/1WZ3Evdb.shtml

相关文章

迁移学习(Transfer)

1. 什么是迁移学习 迁移学习(Transfer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中。迁移学习是通过从已学习的相关任务中转移知识来改进学习的新任务,虽然大多数机器学习算法都…

2021-11-29 拿到第一个badger

Coursera课程的前三节 What is data science Tool of data science Methodology of data science 感觉更像是阅读理解 今天拿到了第一个badger 准备花一周把剩下俩阅读理解拿了,开始搞实战部分 我走得很慢,但是从来不停下

Postgersql神器之pgbadger安装配置

文章目录 1.介绍2.安装pgbager3.参数调整3.1相关参数内容3.2 重启db4.配置apache5.安装libtext-csv-perl,6.手动产生报告 7.排程自动产生分析报告7.2脚本授权:7.3设定crontab7.4检视pgbadger日志分析报告 1.介绍 pgbadger是postgresql 三大神器之一:pg_…

BoltDB,RocksDB,PebbleDB,BadgerDb简介及测评

几个常用数据库性能分析 ​ 最近公司需要选型一款单机KV数据库来做业务承载,所以我对比了目前市面上比较流行的几个KV数据库并记录下来,包括boltdb,rocksdb,pebbledb,badgerdb四款,我将简单分析一下各数据库的特点,最后用自己的简…

智能优化算法-蜜獾算法Honey Badger Algorithm(附Matlab代码)

引言 提出了一种新的元启发式优化算法——蜜獾算法(Honey Badger Algorithm,HBA)。该算法受蜜獾智能觅食行为的启发,从数学上发展出一种求解优化问题的高效搜索策略。蜜獾挖掘和采蜜的动态搜索行为。于2021年发表在Mathematics and Computers in Simula…

惠普ipaq蓝牙键盘配对码_将旧的Compaq Ipaq从Ubuntu 5.10 Breezy Badger升级到8.10 Intrepid Ibex...

惠普ipaq蓝牙键盘配对码 Ive got an old Compaq Ipaq that I gave my then 10 year old niece to play with. I put Ubuntu 5.10 "Breezy Badger" on it. Fast forward to today, and she wants it refreshed. It hasnt been on the Internet for 3 years so it does…

BadgerDB 原理及分布式数据库的中应用与优化

Part 1 - BadgerDB 设计架构 Badger[1] 是基于论文:WiscKey: Separating Keys from Values inSSD-conscious Storage[2] 的思想利用 Go 语言进行设计实现的。 LSM-Tree 的优势在于将随机写转换为顺序写,将大块的内存连续地写入到磁盘,减少磁…

vue手风琴组件_Vue 2的Badger手风琴组件

vue手风琴组件 Vue-Badger手风琴 (vue-badger-accordion) Badger-Accordion Component for Vue 2.0. Vue 2.0的Badge-Accordion组件。 An accessible light weight, vanilla JavaScript accordion. 轻巧的香草JavaScript手风琴。 View demo 查看演示 Download Source 下载源 …

防追踪创安全网络环境 EFF推Beta版“Privacy Badger”扩展

自棱镜监控丑闻曝光之后在线安全和个人隐私成为网络上热门的话题,为此电子前沿基金会(EFF)今天推出了beta测试版本“Privacy Badger”扩展程序,为Chrome和Firefox浏览器用户打造安全的在线网络环境,阻挡一系列目前网络…

小米手机应用ICON角标Badger显示

项目有个新需求,像iOS一样给应用加个未读消息数量的角标,网上查了下这个开源框架还是不少的,这里介绍一个比较好用的,git地址:https://github.com/leolin310148/ShortcutBadger,集成和实现比较容易&#xf…

Honey Badger BFT共识协议详解

阅读建议 Honey Badger BFT应用了很多前人的研究,进行了巧妙的构造和优化,初次学习往往难以理解。在阅读时可以先大致了解各个构造块的基本作用,再了解总体的共识过程。之后回过头来深入研究各个构造块的原理,特别是BA算法&#…

badger和rocksDB性能对比

结论: 从最后一个表格来看,ssd只对batch_read和batch-write操作有优势,而且在多协程的情况下,这个优势也丢失了。从第二和第三个表格来看,badger的write操作比rocksDB慢了一个数量级,而batch_write操作badg…

Honey Badger BFT(异步共识算法)笔记

最近一直在看Honey Badger BFT共识协议,看了很多博客和一些相关的论文,但是发现有些博客存在着部分理解错误的地方,或者就是直接翻译2016年的那一篇论文,在经过半个多月的细读之后,打算整理出这篇博客,方便…

Badger、Leveldb

BadgerDB v2 介绍 2017年发行 来自DGraph实验室 开源 纯go语言编写 https://github.com/dgraph-io/badger https://godoc.org/github.com/dgraph-io/badger 内存模式 (所有数据存在内存,可能丢失数据)SSD优化键值分离 Key(00000*.sst) Valu…

badger 一个高性能的LSM K/V store

大家好,给大家介绍一下, 新晋的高性能的 K/V数据库: badger。 这是 dgraph.io开发的一款基于 log structured merge (LSM) tree 的 key-value 本地数据库, 使用 Go 开发。 事实上,市面上已经有一些知名的基于LSM tree的k/v数据库…

badger框架学习 (一)

1.badger是什么? badger是一种高性能的 K/V数据库。 这是 dgraph.io开发的一款基于 log structured merge (LSM) tree 的 key-value 本地数据库, 使用 Go 开发。 2.badger有什么优势? 事实上,市面上已经有一些知名的基于LSM tre…

No Free Lunch定理

Stanford大学Wolpert和Macready教授提出了NFL定理,它是优化领域中的一个重要理论研究成果,意义较为深远。现将其结论概括如下: 定理1 假设有A、B两种任意(随机或确定)算法,对于所有问题集,它们…

少样本学习原理快速入门,并翻译《Free Lunch for Few-Shot Learning: Distribution Calibration》

ICLR2021 Oral《Free Lunch for Few-Shot Learning: Distribution Calibration》 利用一个样本估计类别数据分布 9行代码提高少样本学习泛化能力 原论文:https://openreview.net/forum?idJWOiYxMG92s 源码:https://github.com/ShuoYang-1998/ICLR202…

Android lunch分析以及产品分支构建

转自:http://blog.csdn.net/generalizesoft/article/details/7253901 Android lunch分析以及产品分支构建 一、背景 随着Android应用范围越来越广泛,用户对Android的需求也越来越趋于复杂,在开发Android应用以及底层产品驱动时&#xff0…

Free Lunch is Over (免费午餐已经结束)

原文链接:The Free Lunch Is Over: A Fundamental Turn Toward Concurrency in Software 免费的午餐结束了 软件并行计算的基本转折点 继OO之后软件发展的又一重大变革——并行计算 你的免费午餐即将即将结束。我们能做什么?我们又将做什么&#xff1f…