迁移学习实例

article/2025/10/12 23:34:32

上一篇我们介绍了迁移学习的核心思想和流程,我们介绍一个实例来加深理解。

传送门:迁移学习概述

获取预训练模型

pytorch和tensorflow都封装了很多预训练模型。

pytorch通过工具包torchvision.models模块获取,主要包括AlexNet、VGG系列、

ResNet系列、SqueezeNet和DenseNet等,通过设置参数pretrained=True即可获取。而Tensorflow内置在keras.application里面,当然,也可以通过TensorFlowHub网站自行下载。

from tensorflow.keras.applications import vgg16,resnet
from torchvision.models import AlexNet,VGG,ResNet
from torchvision.models import SqueezeNet,DenseNet

一个实例

下面通过一个例子对迁移学习有个感性的认识。预训练模型采用retnet18网络,一共分为八大步骤。

注:代码均来源于《深入浅出Embedding》第三章

1.导入模块

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torchvision.datasets import ImageFolder
from datetime import datetime

2.加载数据

加载相关数据集,首次下载需要将download设置为True,此外,还对数据做了一些预处理,标准化、图片裁剪等。

trans_train = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])]
)trans_valid = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])]
)trainset = torchvision.datasets.CIFAR10(root='.\data',train=True,download=True,transform=trans_train)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=64,shuffle=True,num_workers=2)testset = torchvision.datasets.CIFAR10(root='.\data',train=False,download=True,transform=trans_valid)
testloader = torch.utils.data.DataLoader(testset,batch_size=64,shuffle=False,num_workers=2)classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

下载过程如下:

a5968dd19d2d731b12e868dca05f8275.png

注:代码直接下载比较慢,可以点击链接直接手动下载,再导入相关路径,再次运行代码download设置为False即可

3.下载预训练模型

net = models.resnet18(pretrained=True)

这一步也需要时间,耐心等待.....如果这一步出错,先手动下载pth模型文件,再执行下面语句,可加载模型:

pthfile = r'/workspace/resnet18-f37072fd.pth'
model = torch.load(pthfile)
net = models.resnet18(pretrained=False)
net.load_state_dict(model)

4.冻结模型参数

将模型参数冻结

for param in net.parameters():param.requires_grad = False

5.修改输出类别器

将原来输出的1000类改为只有10类,做以下操作:

device = torch.device("cuda:1" if torch.cuda.is_avaliable() else "cpu")
net.fc = nn.Linear(512,10)

6.查看冻结前后参数情况

toatl_params = sum(p.numel() for p in net.parameters())
print('原参数个数:{}'.format(toatl_params))
toatl_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('可训练参数个数:{}'.format(toatl_trainable_params))

501a88540311cea059ba831493ab38d0.png

7.定义损失函数及优化器

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.fc.parameters(),lr=1e-3,weight_decay=1e-3,momentum=0.9)

还有评估指标和训练函数

#定义评估指标
def get_acc(output, label):total = output.shape[0]_, pred_label = output.max(1)num_correct = (pred_label == label).sum().item()return num_correct / total
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):prev_time = datetime.now()for epoch in range(num_epochs):train_loss = 0train_acc = 0net = net.train()for im, label in train_data:im = im.to(device)  # (bs, 3, h, w)label = label.to(device) # (bs, h, w)# forwardoutput = net(im)loss = criterion(output, label)# backwardoptimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()train_acc += get_acc(output, label)cur_time = datetime.now()h, remainder = divmod((cur_time - prev_time).seconds, 3600)m, s = divmod(remainder, 60)time_str = "Time %02d:%02d:%02d" % (h, m, s)if valid_data is not None:valid_loss = 0valid_acc = 0net = net.eval()for im, label in valid_data:im = im.to(device)  # (bs, 3, h, w)label = label.to(device) # (bs, h, w)output = net(im)loss = criterion(output, label)valid_loss += loss.item()valid_acc += get_acc(output, label)epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "% (epoch, train_loss / len(train_data),train_acc / len(train_data), valid_loss / len(valid_data),valid_acc / len(valid_data)))else:epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %(epoch, train_loss / len(train_data),train_acc / len(train_data)))prev_time = cur_timeprint(epoch_str + time_str)

8.训练及验证模型

最后,进行模型训练即可

net=net.to(device)
train(net,trainloader,testloader,20,optimizer,criterion)

74954a751a6feba5717959a5af8100e2.png

参考资料:

《深入浅出Embedding》

https://www.ptorch.com/docs/1/models


http://chatgpt.dhexx.cn/article/YpQo6dmo.shtml

相关文章

迁移学习与微调的区别

一、迁移学习: 1、从字面意义上理解是知识转移的学习方法,指一种学习方法;类比机器学习、深度学习等等概念; 2、把已训练好的模型参数迁移到新的模型来帮助新模型训练二、微调: 1、从字面意义上理解是小小的调整&…

迁移学习---举一反三

1.概念 迁移学习是指充分考虑数据、任务、或者模型的相似性,将在旧领域学习到的模型,应用到新的领域的一种学习过程。 通俗的讲就是把已经学习训练好的模型参数迁移到新的模型进行训练。考虑到大部分数据或任务是存在相关性的,所以通过迁移…

迁移学习基础

人类具有跨任务传输知识的固有能力。我们在学习一项任务的过程中获得的知识,可以用来解决相关的任务。任务相关程度越高,我们就越容易迁移或交叉利用知识。到目前为止所讨论的机器学习和深度学习算法,通常都是被设计用于单独运作的。这些算法…

学习迁移学习

学习迁移学习 一、相关背景 随着机器学习和数据挖掘不断发展,一个愈加明显的问题出现在人们面前:要想机器学习能够正常运转必须要保证训练集和测试集有相同的特征空间并且同分布。一旦分布改变,大多数模型往往要根据数据重建,这…

联邦迁移学习

本博客地址:https://security.blog.csdn.net/article/details/123573886 一、联邦学习的定义 横向联邦学习和纵向联邦学习要求所有的参与方具有相同的特征空间或样本空间,从而建立起一个有效的共享机器学习模型。然而,在更多的实际情况下&am…

迁移学习(二)

迁移学习综述(二)(学习笔记) A Comprehensive Survey on Transfer Learning 1.引言 迁移学习的目标是利用来自相关领域(称为源领域)的知识,以提高学习性能或最小化目标领域中需要的标记示例的数量。知识转移并不总是…

深度学习中的迁移学习介绍

迁移学习(Transfer Learning)的概念早在20世纪80年代就有相关的研究,这期间的研究有的称为归纳研究(inductive transfer)、知识迁移(knowledge transfer)、终身学习(life-long learning)以及累积学习(incremental learning)等。直到2009年,香港科技大学杨…

迁移学习综述

这是我根据北京邮电大学一位博士的讲解视频所归纳的笔记 视频地址:https://www.bilibili.com/video/BV1ct41167kV?spm_id_from333.337.search-card.all.click 正文 我们为什么需要迁移学习? 众所周知,AlphaGo是通过强化学习去训练&#x…

整理学习之深度迁移学习

迁移学习(Transfer Learning)通俗来讲就是学会举一反三的能力,通过运用已有的知识来学习新的知识,其核心是找到已有知识和新知识之间的相似性,通过这种相似性的迁移达到迁移学习的目的。世间万事万物皆有共性&#xff…

迁移学习简要

什么是迁移学习 迁移学习是一种机器学习方法,就是把任务为A的开发模型作为其的初始点,重新使用在任务为B的开发模型的过程中。迁移学习是通过从已学习的相关任务中转移知识来改进学习的新任务。虽然大多数机器学习的新 算法都是为了解决单个任务而设计的…

迁移学习(Transfer),面试看这些就够了!(附代码)

1. 什么是迁移学习 迁移学习(Transfer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中。迁移学习是通过从已学习的相关任务中转移知识来改进学习的新任务,虽然大多数机器学习算法都…

迁移学习

简介 好的机器学习模型需要大量数据和许多GPU或TPU进行训练。大多数时候,他们只能执行特定的任务。 大学和大公司有时会发布他们的模型。但很可能你希望开发一个机器学习应用程序,但没有适合你的任务的可用模型。 但别担心,你不必收集大量数据…

迁移学习(Transfer)

1. 什么是迁移学习 迁移学习(Transfer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中。迁移学习是通过从已学习的相关任务中转移知识来改进学习的新任务,虽然大多数机器学习算法都…

2021-11-29 拿到第一个badger

Coursera课程的前三节 What is data science Tool of data science Methodology of data science 感觉更像是阅读理解 今天拿到了第一个badger 准备花一周把剩下俩阅读理解拿了,开始搞实战部分 我走得很慢,但是从来不停下

Postgersql神器之pgbadger安装配置

文章目录 1.介绍2.安装pgbager3.参数调整3.1相关参数内容3.2 重启db4.配置apache5.安装libtext-csv-perl,6.手动产生报告 7.排程自动产生分析报告7.2脚本授权:7.3设定crontab7.4检视pgbadger日志分析报告 1.介绍 pgbadger是postgresql 三大神器之一:pg_…

BoltDB,RocksDB,PebbleDB,BadgerDb简介及测评

几个常用数据库性能分析 ​ 最近公司需要选型一款单机KV数据库来做业务承载,所以我对比了目前市面上比较流行的几个KV数据库并记录下来,包括boltdb,rocksdb,pebbledb,badgerdb四款,我将简单分析一下各数据库的特点,最后用自己的简…

智能优化算法-蜜獾算法Honey Badger Algorithm(附Matlab代码)

引言 提出了一种新的元启发式优化算法——蜜獾算法(Honey Badger Algorithm,HBA)。该算法受蜜獾智能觅食行为的启发,从数学上发展出一种求解优化问题的高效搜索策略。蜜獾挖掘和采蜜的动态搜索行为。于2021年发表在Mathematics and Computers in Simula…

惠普ipaq蓝牙键盘配对码_将旧的Compaq Ipaq从Ubuntu 5.10 Breezy Badger升级到8.10 Intrepid Ibex...

惠普ipaq蓝牙键盘配对码 Ive got an old Compaq Ipaq that I gave my then 10 year old niece to play with. I put Ubuntu 5.10 "Breezy Badger" on it. Fast forward to today, and she wants it refreshed. It hasnt been on the Internet for 3 years so it does…

BadgerDB 原理及分布式数据库的中应用与优化

Part 1 - BadgerDB 设计架构 Badger[1] 是基于论文:WiscKey: Separating Keys from Values inSSD-conscious Storage[2] 的思想利用 Go 语言进行设计实现的。 LSM-Tree 的优势在于将随机写转换为顺序写,将大块的内存连续地写入到磁盘,减少磁…

vue手风琴组件_Vue 2的Badger手风琴组件

vue手风琴组件 Vue-Badger手风琴 (vue-badger-accordion) Badger-Accordion Component for Vue 2.0. Vue 2.0的Badge-Accordion组件。 An accessible light weight, vanilla JavaScript accordion. 轻巧的香草JavaScript手风琴。 View demo 查看演示 Download Source 下载源 …