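The previous post covered the core idea and workflow of transfer learning. This post works through a concrete example to make those ideas stick.
Previous post: 迁移学习概述 (Overview of Transfer Learning)
Obtaining pretrained models
Both PyTorch and TensorFlow ship with many pretrained models.
In PyTorch they live in the torchvision.models module, which includes AlexNet, the VGG family, the ResNet family, SqueezeNet, DenseNet and others; passing pretrained=True loads the pretrained weights. In TensorFlow they are built into tensorflow.keras.applications, and models can also be downloaded from the TensorFlow Hub website.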
from tensorflow.keras.applications import vgg16,resnet
from torchvision.models import AlexNet,VGG,ResNet
from torchvision.models import SqueezeNet,DenseNet
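An example
The example below gives a hands-on feel for transfer learning. It uses resnet18 as the pretrained network and proceeds in eight steps.
Note: all code comes from Chapter 3 of 《深入浅出Embedding》.
1. Import modules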
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torchvision.datasets import ImageFolder
from datetime import datetime
2. Load the data
Load the CIFAR-10 dataset; download must be set to True on the first run. The data is also preprocessed, including normalization and image cropping.
# Training-time augmentation followed by normalization
trans_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Deterministic resize and center crop for validation
trans_valid = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Use a forward slash in the path; '.\data' contains an invalid escape sequence
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=trans_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=trans_valid)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
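Note: downloading through the code can be slow. You can instead download the dataset manually, place it under the root path, and set download=False when you run the code again.
3. Download the pretrained model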
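To make the normalization step concrete, here is a minimal sketch that applies the same per-channel formula, (x - mean) / std, to a single RGB pixel in plain Python. The helper name normalize_pixel is hypothetical and exists only for illustration; the mean and std values are the ones passed to transforms.Normalize in the code above.

```python
# Per-channel mean and std, copied from the Normalize call in this post
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Apply (x - mean) / std to each channel of one pixel.

    `rgb` holds values already scaled to [0, 1], which is what ToTensor
    produces for 8-bit images. Hypothetical helper for illustration only;
    transforms.Normalize does the same thing channel by channel on tensors.
    """
    return [(x - m) / s for x, m, s in zip(rgb, mean, std)]


# A pixel equal to the channel means maps to zero in every channel
print(normalize_pixel([0.485, 0.456, 0.406]))
# A pure white pixel maps to a positive value in every channel
print(normalize_pixel([1.0, 1.0, 1.0]))
```

Because the pretrained weights were learned on inputs normalized this way, the same statistics are reused here even though the images come from a different dataset.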
net = models.resnet18(pretrained=True)
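This step also takes a while, so be patient. If it fails, download the .pth weights file manually first, then run the statements below to load the model: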
pthfile = r'/workspace/resnet18-f37072fd.pth'
model = torch.load(pthfile)
net = models.resnet18(pretrained=False)
net.load_state_dict(model)
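4. Freeze the model parameters
Freeze all of the model's parameters so they are not updated during training.
for param in net.parameters():
    param.requires_grad = False
5. Replace the output classifier
Change the original 1000-class output layer to a 10-class one as follows: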
# "cuda:1" selects the second GPU; use "cuda:0" on a single-GPU machine
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
net.fc = nn.Linear(512,10)
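Steps 4 and 5 rely on the fact that a freshly constructed layer has requires_grad=True by default, so replacing the head after freezing leaves only the new head trainable. The sketch below shows this on a tiny stand-in network; TinyNet and its layer sizes are hypothetical and are not the resnet18 used in this post.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pretrained backbone with a 1000-class head."""

    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 4)
        self.fc = nn.Linear(4, 1000)

    def forward(self, x):
        return self.fc(torch.relu(self.features(x)))


tiny = TinyNet()

# Step 4: freeze every existing parameter
for param in tiny.parameters():
    param.requires_grad = False

# Step 5: swap in a new 10-class head; its parameters are trainable by default
tiny.fc = nn.Linear(4, 10)

# Collect the names of the parameters that will still receive gradients
trainable = [name for name, p in tiny.named_parameters() if p.requires_grad]
print(trainable)

# The output now has one score per new class
out = tiny(torch.randn(2, 8))
print(out.shape)
```

The order matters: if the head were replaced before the freezing loop ran, the new layer would be frozen as well and nothing would train.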
6. Compare parameter counts before and after freezing
total_params = sum(p.numel() for p in net.parameters())
print('Total parameters: {}'.format(total_params))
total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('Trainable parameters: {}'.format(total_trainable_params))
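As a sanity check on the two printed numbers, the parameter count of a fully connected layer is in_features * out_features + out_features. The sketch below does that arithmetic in plain Python. The figure 11,689,512 for the stock torchvision resnet18 is an assumption that is not taken from the source, and linear_params is a hypothetical helper; compare the result with what step 6 actually prints.

```python
def linear_params(in_features, out_features, bias=True):
    """Number of parameters in a fully connected layer (weights plus bias)."""
    return in_features * out_features + (out_features if bias else 0)


# Assumed total for the stock resnet18 with its 1000-class head
RESNET18_TOTAL = 11_689_512

old_head = linear_params(512, 1000)  # original 1000-class head
new_head = linear_params(512, 10)    # replacement 10-class head

# Swap the old head for the new one to get the new total
expected_total = RESNET18_TOTAL - old_head + new_head
# Everything except the new head is frozen
expected_trainable = new_head

print(old_head, new_head)
print(expected_total, expected_trainable)
```

Under that assumption, only a small fraction of a percent of the network's parameters are trained, which is why this kind of fine-tuning is cheap.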
7. Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.fc.parameters(),lr=1e-3,weight_decay=1e-3,momentum=0.9)
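Only net.fc.parameters() are passed to the optimizer, so only the new head is updated. To show what lr, weight_decay and momentum do, the sketch below performs the update on a single scalar weight in plain Python. It follows the rule used by torch.optim.SGD with dampening=0 and nesterov=False; sgd_step is a hypothetical helper, not part of PyTorch.

```python
def sgd_step(w, grad, velocity, lr=1e-3, momentum=0.9, weight_decay=1e-3):
    """One SGD update for a single scalar weight.

    `velocity` is the momentum buffer and should be 0.0 before the first step.
    Hypothetical helper mirroring SGD with momentum and L2 weight decay.
    """
    g = grad + weight_decay * w         # weight decay folded into the gradient
    velocity = momentum * velocity + g  # accumulate momentum
    w = w - lr * velocity               # take the step
    return w, velocity


# Three updates with a constant gradient: the step size grows as momentum builds
w, v = 1.0, 0.0
for step in range(3):
    w, v = sgd_step(w, grad=0.5, velocity=v)
    print(step, w, v)
```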
We also define the evaluation metric and the training function.
# Evaluation metric: fraction of correct predictions in a batch
def get_acc(output, label):
    total = output.shape[0]
    _, pred_label = output.max(1)
    num_correct = (pred_label == label).sum().item()
    return num_correct / total

def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            im = im.to(device)        # (bs, 3, h, w)
            label = label.to(device)  # (bs,)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_acc += get_acc(output, label)
        # Time spent on the training part of this epoch
        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            # No gradients are needed during validation
            with torch.no_grad():
                for im, label in valid_data:
                    im = im.to(device)        # (bs, 3, h, w)
                    label = label.to(device)  # (bs,)
                    output = net(im)
                    loss = criterion(output, label)
                    valid_loss += loss.item()
                    valid_acc += get_acc(output, label)
            epoch_str = (
                "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), valid_loss / len(valid_data),
                   valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)
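One detail of the training function is worth spelling out: it sums per-batch accuracies and divides by the number of batches. With a batch size of 64, the last batch of both CIFAR-10 splits holds only 16 images, so this average differs slightly from the true per-sample accuracy. The sketch below shows the effect on two small made-up batches in plain Python; batch_accuracy is a hypothetical list-based counterpart of get_acc.

```python
def batch_accuracy(logits, labels):
    """Fraction of rows whose arg-max matches the label (list-based get_acc)."""
    preds = [row.index(max(row)) for row in logits]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return correct / len(labels)


# Two made-up batches of unequal size: (logits, labels)
batch_a = ([[2.0, 1.0], [0.1, 0.9], [0.3, 0.7], [0.8, 0.2]], [0, 1, 0, 1])
batch_b = ([[0.2, 0.8]], [1])
batches = [batch_a, batch_b]

# What train() reports: the mean of the per-batch accuracies
mean_of_batches = sum(batch_accuracy(x, y) for x, y in batches) / len(batches)

# Per-sample accuracy: weight each batch by its size
n_samples = sum(len(y) for _, y in batches)
per_sample = sum(batch_accuracy(x, y) * len(y) for x, y in batches) / n_samples

print(mean_of_batches, per_sample)
```

On the real dataset the gap is small because only one batch out of several hundred is short, so the reported figure remains a reasonable estimate.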
8. Train and validate the model
Finally, move the model to the device and train it:
net=net.to(device)
train(net,trainloader,testloader,20,optimizer,criterion)
References:
《深入浅出Embedding》
https://www.ptorch.com/docs/1/models