PaddlePaddle2.0利用ResNet101预训练模型实现蝴蝶分类

article/2025/10/6 13:46:29

PaddlePaddle2.0利用ResNet101预训练模型实现蝴蝶分类

该项目来自于AI Studio上的公开项目,这里记录我的学习笔记,把一些细节的地方具体说明一下,并且提供完整的程序代码和分步说明,供大家本地PaddlePaddle2.0环境或者AI Studio上面复现。
用到的数据集来自AI Studio中公开的数据集,但是不知道为什么找不到那个数据集项目了,所以就提供我下载好的数据集的网盘链接:
链接:https://pan.baidu.com/s/19Fqsg_rUAQi9nvf3vLhI9w
提取码:qdkx
数据中对蝴蝶种类的描述是这样的,将蝴蝶分为9个属,20个物种,也就是说在本数据集中有的属中包含几个物种的蝴蝶,有的属中就包含单一一种蝴蝶,图片命名例子如下:
在这里插入图片描述

一、数据集中图片显示

import matplotlib.pyplot as plt
import PIL.Image as Imagepath='/home/aistudio/data/Butterfly20/001.Atrophaneura_horishanus/006.jpg'
img = Image.open(path)
plt.imshow(img)          #根据数组绘制图像
plt.show()               #显示图像
print(img.size)

在这里插入图片描述

二、数据准备

分为两个阶段:
(1)建立训练集中图片和标签之间的对应关系。
(2)确立好对应关系之后对数据进行预处理,即归一化、数据增强等操作。
先看第一部分:
这一部分的功能为:训练集中的图片和species.txt中的标签对应,标签的形式为0-19,共20个数字。

#以下代码用于建立样本数据读取路径与样本标签之间的关系
import os
import randomdata_list = [] #用个列表保存每个样本的读取路径、标签#由于属种名称本身是字符串,而输入模型的是数字。需要构造一个字典,把某个数字代表该属种名称。键是属种名称,值是整数。
label_list=[]
with open("/home/aistudio/data/species.txt") as f:for line in f:a,b = line.strip("\n").split(" ")label_list.append([b, int(a)-1])
label_dic = dict(label_list)#获取Butterfly20目录下的所有子目录名称,保存进一个列表之中
class_list = os.listdir("/home/aistudio/data/Butterfly20")
class_list.remove('.DS_Store') #删掉列表中名为.DS_Store的元素,因为.DS_Store并没有样本。for each in class_list:for f in os.listdir("/home/aistudio/data/Butterfly20/"+each):data_list.append(["/home/aistudio/data/Butterfly20/"+each+'/'+f,label_dic[each]])#按文件顺序读取,可能造成很多属种图片存在序列相关,用random.shuffle方法把样本顺序彻底打乱。
random.shuffle(data_list)#打印前十个,可以看出data_list列表中的每个元素是[样本读取路径, 样本标签]。
print(data_list[0:10])#打印样本数量,一共有1866个样本。
print("样本数量是:{}".format(len(data_list)))

在这里插入图片描述
注意经过了标签乱序的操作,所以输出的10条信息并不是连续的,如果不经过乱序标签操作,结果为:
在这里插入图片描述
再看第二部分:
这一部分功能为:数据预处理,处理好后设置数据读取器,划分好训练集和测试集。

#以下代码用于构造读取器与数据预处理
#首先需要导入相关的模块
import paddle
from paddle.vision.transforms import Compose, ColorJitter, Resize,Transpose, Normalize
import cv2
import numpy as np
from PIL import Image
from paddle.io import Dataset#自定义的数据预处理函数,输入原始图像,输出处理后的图像,可以借用paddle.vision.transforms的数据处理功能
def preprocess(img):transform = Compose([Resize(size=(224, 224)), #把数据长宽像素调成224*224Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], data_format='HWC'), #标准化Transpose(), #原始数据形状维度是HWC格式,经过Transpose,转换为CHW格式])img = transform(img).astype("float32")return img#自定义数据读取器
class Reader(Dataset):def __init__(self, data, is_val=False):super().__init__()#在初始化阶段,把数据集划分训练集和测试集。由于在读取前样本已经被打乱顺序,取20%的样本作为测试集,80%的样本作为训练集。self.samples = data[-int(len(data)*0.2):] if is_val else data[:-int(len(data)*0.2)]def __getitem__(self, idx):#处理图像img_path = self.samples[idx][0] #得到某样本的路径img = Image.open(img_path)if img.mode != 'RGB':img = img.convert('RGB')img = preprocess(img) #数据预处理--这里仅包括简单数据预处理,没有用到数据增强#处理标签label = self.samples[idx][1] #得到某样本的标签label = np.array([label], dtype="int64") #把标签数据类型转成int64return img, labeldef __len__(self):#返回每个Epoch中图片数量return len(self.samples)#生成训练数据集实例
train_dataset = Reader(data_list, is_val=False)#生成测试数据集实例
eval_dataset = Reader(data_list, is_val=True)#打印一个训练样本
print(len(train_dataset)) #1866*0.8=1492.8
#print(train_dataset[1136][0])
print(train_dataset[1136][0].shape)
print(train_dataset[1136][1])

在这里插入图片描述

三、建立模型

使用预训练模型。为了提升探索速度,建议首先选用比较成熟的基础模型,看看基础模型所能够达到的准确度。之后再试试模型融合,准确度是否有提升。最后可以试试自己独创模型。为简便,这里直接采用101层的残差网络ResNet,并且采用预训练模式。为什么要采用预训练模型呢?因为通常模型参数采用随机初始化,而预训练模型参数初始值是一个比较确定的值。这个参数初始值是经历了大量任务训练而得来的,比如用CIFAR图像识别任务来训练模型,得到的参数。虽然蝴蝶识别任务和CIFAR图像识别任务是不同的,但可能存在某些机器视觉上的共性。用预训练模型可能能够较快地得到比较好的准确度。
**注意:在PaddlePaddle2.0中,使用预训练模型只需要设定模型参数pretained=True。值得注意的是,预训练模型得出的结果类别是1000维度,要用个线性变换,把类别转化为20维度。**因为在我们的蝴蝶分类任务里面,一共有20个类别。

#定义模型
class MyNet(paddle.nn.Layer):def __init__(self):super(MyNet,self).__init__()self.layer=paddle.vision.models.resnet50(pretrained=True)self.fc = paddle.nn.Linear(1000, 20)#网络的前向计算过程def forward(self,x):x=self.layer(x)x=self.fc(x)return x

四、模型训练

训练的时候主要的步骤是:定义输入数据、模型封装、定义优化器、模型准备、开启训练。

#定义输入
input_define = paddle.static.InputSpec(shape=[-1,3,224,224], dtype="float32", name="img")
label_define = paddle.static.InputSpec(shape=[-1,1], dtype="int64", name="label")#实例化网络对象并定义优化器等训练逻辑
model = MyNet()
model = paddle.Model(model,inputs=input_define,labels=label_define) #用Paddle.Model()对模型进行封装
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
#上述优化器中的学习率(learning_rate)参数很重要。要是训练过程中得到的准确率呈震荡状态,忽大忽小,可以试试进一步把学习率调低。model.prepare(optimizer=optimizer, #指定优化器loss=paddle.nn.CrossEntropyLoss(), #指定损失函数metrics=paddle.metric.Accuracy()) #指定评估方法model.fit(train_data=train_dataset,     #训练数据集eval_data=eval_dataset,         #测试数据集batch_size=64,                  #一个批次的样本数量epochs=50,                      #迭代轮次save_dir="/home/aistudio/lup", #把模型参数、优化器参数保存至自定义的文件夹save_freq=20,                    #设定每隔多少个epoch保存模型参数及优化器参数log_freq=100                     #打印日志的频率
)

五、模型预测

模型预测主要步骤是:构建数据读取器(因为比赛时候官方提供的测试数据是没有标签的,因此读取数据的时候和训练的时候用到的数据读取器不一致,需要重新定义)、模型封装(这一步骤包含上一步定义好的数据读取器,所以要重新定义模型)、读取训练好的模型参数、模型准备、开始预测。

class InferDataset(Dataset):def __init__(self, img_path=None):"""数据读取Reader(推理):param img_path: 推理单张图片"""super().__init__()if img_path:self.img_paths = [img_path]else:raise Exception("请指定需要预测对应图片路径")def __getitem__(self, index):# 获取图像路径img_path = self.img_paths[index]# 使用Pillow来读取图像数据并转成Numpy格式img = Image.open(img_path)if img.mode != 'RGB': img = img.convert('RGB') img = preprocess(img) #数据预处理--这里仅包括简单数据预处理,没有用到数据增强return imgdef __len__(self):return len(self.img_paths)#实例化推理模型
model = paddle.Model(MyNet(),inputs=input_define)#读取刚刚训练好的参数
model.load('/home/aistudio/lup/final')#准备模型
model.prepare()#得到待预测数据集中每个图像的读取路径
infer_list=[]
with open("/home/aistudio/data/testpath.txt") as file_pred:for line in file_pred:infer_list.append("/home/aistudio/data/"+line.strip())#模型预测结果通常是个数,需要获得其对应的文字标签。这里需要建立一个字典。
def get_label_dict2():label_list2=[]with open("/home/aistudio/data/species.txt") as filess:for line in filess:a,b = line.strip("\n").split(" ")label_list2.append([int(a)-1, b])label_dic2 = dict(label_list2)return label_dic2label_dict2 = get_label_dict2()
#print(label_dict2)#利用训练好的模型进行预测
results=[]
for infer_path in infer_list:infer_data = InferDataset(infer_path)result = model.predict(test_data=infer_data)[0] #关键代码,实现预测功能result = paddle.to_tensor(result)result = np.argmax(result.numpy()) #获得最大值所在的序号results.append("{}".format(label_dict2[result])) #查找该序号所对应的标签名字#把结果保存起来
with open("work/result.txt", "w") as f:for r in results:f.write("{}\n".format(r))

注意:模型预测的时候不需要训练,只是根据训练好的网络提供测试集图片的的标签即可,因此类似于模型准备中,优化器的设置、损失函数定义等与训练有关的操作是不需要的。
在这里插入图片描述

六、查看网络结构

使用Paddle高级api查看网络的结构

model = paddle.Model(MyNet())
model.summary((1, 3, 224, 224))

在这里插入图片描述


http://chatgpt.dhexx.cn/article/Na36JN0H.shtml

相关文章

PyTorch实现的ResNet50、ResNet101和ResNet152

PyTorch实现的ResNet50、ResNet101和ResNet152 PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks import torch import torch.nn as nn import torchvision import numpy as npprint("PyTorch Version: ",torch.__version__) print("Torchvisio…

Pytorch实现ResNet50网络结构,包含ResNet18,ResNet34,ResNet50,ResNet101,ResNet152

创建各版本的ResNet模型,ResNet18,ResNet34,ResNet50,ResNet101,ResNet152 原文地址: https://arxiv.org/pdf/1512.03385.pdf 论文就不解读了,大部分解读都是翻译,看的似懂非懂,自己…

mindspore-ResNet101使用GPU进行训练时报错

multiprocessing.context.TimeoutError RuntimeError: mindspore/ccsrc/backend/session/kernel_build_client.h:109 Response] Response is empty 1、修改resnet101_imagenet2012_config.yaml中的训练集路径,更改类数量以适应新数据集 2、在models/official/cv/r…

官方代码 Deeplab v3+ resnet101 做backbone

大年初一我居然在更博客。今年过年由于病毒横行,没有串门没有聚餐,整个人闲的没事干。。。医生真是不容易,忙得团团转还有生命危险,新希望他们平安。 本篇不属于初级教程。如果完全看不懂请自行谷歌或搜索作者博客。 deeplab官方…

基于pytorch+Resnet101加GPT搭建AI玩王者荣耀

本源码模型主要用了SamLynnEvans Transformer 的源码的解码部分。以及pytorch自带的预训练模型"resnet101-5d3b4d8f.pth" 本资源整理自网络,源地址:https://github.com/FengQuanLi/ResnetGPT 注意运行本代码需要注意以下几点 注意!…

resnet101网络_网络标准101

resnet101网络 让我告诉你一个故事。 一旦我为我们的设计系统构建了另一个日期选择器组件。 它由文本输入和带有日历的弹出窗口组成,单击可显示日历。 然后,可以在外部单击或选择日期来关闭弹出窗口。 外部点击逻辑的大多数实现都是通过将实际点击侦听器…

基于ResNet101实现猴痘病毒识别任务

前言 大家好,我是阿光。 本专栏整理了《PyTorch深度学习项目实战100例》,内包含了各种不同的深度学习项目,包含项目原理以及源码,每一个项目实例都附带有完整的代码+数据集。 正在更新中~ ✨ 🚨 我的项目环境: 平台:Windows10语言环境:python3.7编译器:PyCharmPy…

使用ResNet101作为预训练模型训练Faster-RCNN-TensorFlow-Python3-master

使用VGG16作为预训练模型训练Faster-RCNN-TensorFlow-Python3-master的详细步骤→Windows10Faster-RCNN-TensorFlow-Python3-masterVOC2007数据集。 如果使用ResNet101作为预训练模型训练Faster-RCNN-TensorFlow-Python3-master,在之前使用VGG16作为预训练模型的训练…

TensorRT学习笔记--基于FCN-ResNet101推理引擎实现语义分割

目录 前言 1--Pytorch模型转换为Onnx模型 2--Onnx模型可视化及测试 2-1--可视化Onnx模型 2-2--测试Onnx模型 3--Onnx模型转换为Tensor RT推理模型 4--基于Tensor RT使用推理引擎实现语义分割 前言 基于Tensor RT的模型转换流程:Pytorch → Onnx → Tensor RT…

迁移学习之ResNet50和ResNet101(图像识别)

文章目录 1.实现的效果:2.主文件TransorResNet.py: 1.实现的效果: 实际的图片: (1)可以看到ResNet50预测的前三个结果中第一个结果为:whippet(小灵狗) (2)Re…

Mask-RCNN(2)Resnet101

1. 对应着图像中的CNN部分,其对输入进来的图片有尺寸要求,需要可以整除2的6次方。在进行特征提取后,利用长宽压缩了两次、三次、四次、五次的特征层来进行特征金字塔结构的构造。Mask-RCNN使用Resnet101作为主干特征提取网络 2.ResNet101有…

Pytorch-预训练网络

预训练网络 我们可以把预训练的神经网络看作一个接收输入并生成输出的程序,该程序的行为是由神经网络的结构以及它在训练过程中所看到的样本所决定的,即期望的输入-输出对,或者期望输出应该满足的特性。我们可以在Pytorch中加载和运行这些预…

基于ResNet-101深度学习网络的图像目标识别算法matlab仿真

目录 1.算法理论概述 1.1、ResNet-101的基本原理 1.2、基于深度学习框架的ResNet-101实现 1.3网络训练与测试 2.部分核心程序 3.算法运行软件版本 4.算法运行效果图预览 5.算法完整程序工程 1.算法理论概述 介绍ResNet-101的基本原理和数学模型,并解释其在图…

【深度学习】ResNet网络详解

文章目录 ResNet参考结构概况conv1与池化层残差结构Batch Normalization总结 ResNet 参考 ResNet论文: https://arxiv.org/abs/1512.03385 本文主要参考视频:https://www.bilibili.com/video/BV1T7411T7wa https://www.bilibili.com/video/BV14E411H7U…

【使用Pytorch实现ResNet网络模型:ResNet50、ResNet101和ResNet152】

使用Pytorch实现Resnet网络模型:ResNet50、ResNet101和ResNet152 介绍什么是 ResNet?ResNet 的架构使用Pytorch构建 ResNet网络 介绍 在深度学习和计算机视觉领域取得了一系列突破。尤其是随着非常深的卷积神经网络的引入,这些模型有助于在图…

使用PyTorch搭建ResNet101、ResNet152网络

ResNet18的搭建请移步:使用PyTorch搭建ResNet18网络并使用CIFAR10数据集训练测试 ResNet34的搭建请移步:使用PyTorch搭建ResNet34网络 ResNet34的搭建请移步:使用PyTorch搭建ResNet50网络 参照我的ResNet50的搭建,由于50层以上几…

Java中的数组

数组 1.什么是数组 数组就是存储相同数据类型的一组数据,且长度固定 基本数据类型4类8种:byte/char/short/int/long/float/double/boolean 数组,是由同一种数据类型按照一定的顺序排序的集合,给这个数组起一个名字。是一种数据类型&#…

java输出数组(java输出数组)

多维数组在Java里如何创建多维数组? 这从第四个例子可以看出,它向我们演示了用花括号收集多个new表达式的能力: Integer[][] a4 { { new Integer (1), new Integer (2)}, { new Integer (3), new Integer (4)}, { new Integer (5), new…

java怎么输出数组(Java怎么给数组赋值)

Java中数组输出的三种方式。第一种方式,传统的for循环方式,第二种方式,for each循环,  第三种方式,利用Array类中的toString方法. 定义一个int类型数组,用于输出 int[] array={1,2,3,4,5}; 第一种方式,传统的for循环方式 for(int i=0;i {System.out.println(a[i]); } 第…

数组的输入与输出

前言: 我们知道对一个字符数组进行输入与输出时会用到: 输入:scanf,getchar,gets 输出:printf,putchar,puts 然而可能还有很多的朋友对这些还不是很了解,今天让我们共同学习数组的输入与输出吧。 %c格式是用于输入…