基于深度学习的花卉识别

article/2025/9/12 23:45:24

1、数据集

       春天来了,我在公园的小道漫步,看着公园遍野的花朵,看起来真让人心旷神怡,一周工作带来的疲惫感顿时一扫而光。难得一个糙汉子有闲情逸致俯身欣赏这些花朵儿,然而令人尴尬的是,我一朵都也不认识。那时我在想,如果有一个花卉识别软件,可以用手机拍一下就知道这是一种什么花朵儿,那就再好不过了。我不知道市场上是否有这样一种软件,但是作为一个从事深度学习的工程师,我马上知道了怎么做,最关键的不是怎么做,而是数据采集。住所附近就是大型公园,一年司机繁花似锦,得益于此,我可以在闲暇时间里采集到大量的花卉数据。本数据集由本人亲自使用手机进行拍摄采集,原始数据集包含了27万张图片,图片的尺寸为1024x1024,为了方便储存和传输,把原图缩小为224x224。采集数据是一个漫长的过程,因此数据集的发布采用分批发布的形式,也就是每采集够16种花卉,就发布一次数据集。每种花卉的图片数量约为2000张,每次发布的数据集的图片数量约为32000张,每次发布的数据集包含的花卉种类都不一样。目前花卉的种类只有48种,分为三批发布,不过随着时间的推移,采集到的花卉越来越多。这里就把数据集分享出来,供各位人工智能算法研究者使用。以下是花卉数据集的简要介绍和下载地址。
(1)花卉数据集01(数据集+训练代码下载地址)
       花卉数据集01,采集自2022年,一共16种花卉,数据集大小为32000张,图片大小为224x224的彩色图像。数据集包含的花卉名称为:一年蓬,三叶草,三角梅,两色金鸡菊,全叶马兰,全缘金光菊,剑叶金鸡菊,婆婆纳,油菜花,滨菊,石龙芮,绣球小冠花,蒲公英,蓝蓟,诸葛菜,鬼针草。数据集的缩略图如下:
图1 花卉数据集01
(2)花卉数据集02(数据集+训练源码下载地址)
       花卉数据集02,采集与2023年,一共16种花卉,每种花卉约2000张,总共32000,图片大小为224x224。数据集包含的花卉有:千屈菜,射干,旋覆花,曼陀罗,桔梗,棣棠,狗尾草,狼尾草,石竹,秋英,粉黛乱子草,红花酢浆草,芒草,蒲苇,马鞭草,黄金菊。数据集缩略图如下:
图2 花卉数据集02
(3)花卉数据集3(数据集+训练源码下载地址)
       花卉数据集03,采集与2023年,一共16种花卉,每种花卉约2000张,总共32000,图片大小为224x224。数据集包含的花卉有:北香花介,大花耧斗菜,小果蔷薇,小苜蓿,小蜡,泽珍珠菜,玫瑰,粉花绣线菊,线叶蓟,美丽月见草,美丽芍药,草甸鼠尾草,蓝花鼠尾草,蛇莓,长柔毛野豌豆,高羊茅。数据集缩略图如下:
图3 花卉数据集3

2、图片分类模型

       为了研究不同图片分类模型对于花朵的分类效果,以及图片分类模型在不同硬件平台的推理速度,这里分别使用目前主流的22种图片分类模型进行训练,并在cpu平台和GPU平台进行部署测试。这些模型是如下:

  • resnet系列:resnet18、resnet34、resnet50、resnet101、resnet152。
  • vgg系列:vgg11、vgg13、vgg6、vgg19。
  • squeezenet系列:squeezenet_v1、squeezenet_v2、squeezenet_v3。
  • mobilenet系列:mobilenet_v1、mobilenet_v2。
  • inception系列:inception_v1、inception_v2、inception_v3。
  • 其他系列:alexnet、lenet、mnist、tsl16、zfnet。

以上模型的训练代码基于pytorch架构,内置集成22了种模型,可进行傻瓜式训练。以下的代码块为训练代码的主脚本,完整的训练代码以及数据集请在此链接下载:源码下载链接。

import torch
import torch.nn as nn
import torch.optim as optim
from utils.dataloader import CustomImageDataset
from torch.utils.data import DataLoader
from utils.build_model import build_model
import argparse
import time
import osif __name__ == '__main__':parser =argparse.ArgumentParser(description='图片分类模型训练')parser.add_argument('-input_shape', type=tuple,default=(3,224,224),help='模型输入的通道数、高度、宽度')parser.add_argument('-train_imgs_dir', type=str,default='dataset/train',help='训练集目录')parser.add_argument('-test_imgs_dir', type=str,default='dataset/test',help='测试集目录')parser.add_argument('-classes_file', type=str,default='dataset/classes.txt',help='类别文件')parser.add_argument('-epochs', type=int,default=50,help='迭代次数')parser.add_argument('-batch_size', type=int,default=64,help='批大小,根据显存大小调整')parser.add_argument('-init_weights', type=str,default="init_weights/squeezenet_v1.pth",help='用于初始化的权重,请确保初始化的权重和训练的模型相匹配')parser.add_argument('-optim', type=str,default="adam",help='优化器选择,可选sgd或者adam. sgd优化器训练效果较好,但参数比较难调节,不好收敛')parser.add_argument('-lr', type=float,default=0.0001,help='初始学习率,此参数对模型训练影响较大,如果选择不合适,模型甚至不收敛.\如果遇到模型训练不收敛(损失函数不下降,准确度很低),可以尝试调整学习率.\resnet系列推荐优化器选择sgd,学习率设0.001;vgg系列优化器推荐adam,学习率为0.0001,其他模型优化器选择adam,推荐学习率为0.0002') parser.add_argument('-model_name', type=str,default='squeezenet_v1',help='模型名称,可选resnet18/resnet34/resnet50/resnet101/resnet152\/alexnet/lenet/zfnet/tsl16/mnist\vgg11/vgg13/vgg16,vgg19\squeezenet_v1/squeezenet_v2/squeezenet_v3\inception_v1/inception_v2/inception_v3\mobilenet_v1/mobilenet_v2/\')parser.add_argument('-argument', type=bool,default=True,help='是否在训练时开启数据增强模式')args = parser.parse_args()
print("模型:",args.model_name)
device = ("cuda"if torch.cuda.is_available()else "mps"if torch.backends.mps.is_available()else "cpu"
)classes=[]
try:with open(args.classes_file,"rt",encoding="ANSI")as f:for line in f:classes.append(line.strip())
except:with open(args.classes_file,"rt",encoding="UTF-8")as f:for line in f:classes.append(line.strip())num_class=len(classes)
model=build_model(args.model_name,args.input_shape,num_class)
if os.path.exists(args.init_weights):try:model.load_state_dict(args.init_weights)except:model.weights_init()print("参数初始化失败!请确保初始化参数与模型相一致.")
else:model.weights_init()print("没有找到名称为%s的权重文件,模型将跳过参数初始化"%(args.init_weights))
model=model.to(device)
# Create data loaders.training_data=CustomImageDataset(args.train_imgs_dir,classes,args.argument)
test_data=CustomImageDataset(args.test_imgs_dir,classes,False)train_dataloader = DataLoader(training_data, batch_size=args.batch_size,shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=args.batch_size)loss_fn=nn.CrossEntropyLoss()
if args.optim=="adam":optimizer=optim.Adam(model.parameters(), lr=args.lr)
else:optimizer=optim.SGD(model.parameters(), lr=args.lr,momentum=0.9,weight_decay=0.0005)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,args.epochs)
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()start=time.time()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction erroroptimizer.zero_grad()pred = model(X)loss = loss_fn(pred, y)# Backpropagationloss.backward()optimizer.step()scheduler.step()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)end=time.time()print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]  time: {end-start:>3f}s")def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():start=time.time()for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()end=time.time()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} , Time: {end-start:>3f}s\n")return correctbest_accuracy=0
for t in range(args.epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)current_accuracy=test(test_dataloader, model, loss_fn)if current_accuracy>best_accuracy:best_accuracy=current_accuracytorch.save(model,"weights/%s_best_accuracy.pth"%(args.model_name))
torch.save(model,"weights/%s_last_accuracy_%.2f.pth"%(args.model_name,current_accuracy))
print("Done!")

3、图片分类模型评估

       分别训练了22种模型,图片为224x224的RGB图像。将约96000张图片划分为训练集和测试集,其中测试集占10%,一共9600张,训练集90%,一共86400张。训练充分后,对各种模型的top1、top2、top3、top4、top5准确度进行评估,并分别在cpu平台(intel i9)和gpu平台(RTX 3090)进行推理速度的测试。模型性能评估以及推理速度测试结果如表1所示。

表1 模型性能评估以及推理速度测试结果

模型参数量 [M]计算量 [G]GPU速度[FPS]CPU速度[FPS]Top1准确度[%]Top2准确度[%]Top3准确度[%]Top4准确度[%]Top5准确度[%]
resnet1811.193.64557.30204.0598.8499.7099.8699.9399.96
resnet3421.307.35347.39104.3598.6499.7499.8699.9399.96
resnet5034.9410.31295.2468.1998.6799.6699.8699.9499.97
resnet10153.9017.77171.8441.2798.6199.6499.8399.9099.94
resnet15268.4123.46130.2131.2098.4499.5899.8699.9299.96
vgg11128.9615.25462.2430.0092.8897.3598.6999.2699.48
vgg13129.1522.67411.5522.2795.1898.2299.2199.5599.73
vgg16134.4630.99340.3420.1495.3598.4899.2999.5099.60
vgg19139.7739.33292.5116.7194.8998.2599.0199.3999.62
mobilenet_v13.251.16942.09506.4297.4599.4499.7199.8399.90
mobilenet_v24.030.91489.69386.2095.9998.9899.5299.7599.81
inception_v16.023.20343.56203.4395.8098.7999.4499.7499.83
inception_v27.853.34291.10165.4998.3099.5499.8099.8599.90
inception_v321.877.65136.2571.8999.0599.8199.9299.9599.97
squeezenet_v10.761.61758.27362.2297.4499.3699.6999.8199.86
squeezenet_v20.761.61704.60360.7597.2799.2399.6799.8099.85
squeezenet_v31.102.37658.07267.2898.2899.5499.7799.8599.89
mnist_net214.4551.37189.8110.1389.4796.2098.1498.9199.28
AlexNet17.692.35858.92211.7396.2098.7999.4999.6999.76
LeNet78.450.941041.7576.1284.8693.7596.5597.9498.70
TSL16116.9523.56381.6324.0995.6198.4099.1599.5399.69
ZF_Net72.092.68351.8782.6896.5899.0399.4799.6899.80

       从表一展示的结果来看,面对48种花卉的分类任务:如果只关心Top5分类准确度,那么这些模型均能达到98%以上的分类准确度,大部分模型的Top5准确度都能达到99%以上,对于实际应用而言,花卉分类程序通常会给出5个备选项,这样的话,只要5个备选项里边存在一个正确选项,就可以认为花卉分类是成功的。当然,如果追求单一选项的准确性,resnet系列模型、inception系列、squeezenet系列模型,在Top1分类准确度上表现不俗,可以达到97%以上的准确度。通常来说,可以工程化的图片分类模型,不仅仅要求其具备良好的分类准确度,还对其推理速度有一定的要求。表1的推理速度测试数据,分别在Intel i9 CPU平台和英伟达RTX 3090 GPU平台进行测试,推理用的软件接口是onnx推理架构,测试的策略是逐一对1000张224x224的彩色图片输送到模型中进行推理,统计其总的推理时间,然后计算平均推理帧率。从表1的数据可以得知,GPU平台的推理速度要比CPU平台的推理的速度快很多,而且在GPU平台推理帧率高的模型,在CPU平台的推理帧率未必高,反之亦然,也就是说,模型推理帧率的排名,跟硬件平台是有关的。在GPU平台上推理帧率比较靠前的模型是squeezenet系列、mobilenet系列,以及alexnet和lenet;在CPU平台上推理帧率比较靠前的模型是squeezenet系列、mobilenet系列、alexnet、inception_v1、resnet18。综合来说,squeezenet系列、mobilenet系列得分是最高的,因为他们在分类准确度上表现优秀,并且在GPU和CPU平台上的推理帧率都变现不错,而且模型的参数量很小,适合在线部署和嵌入式部署,所以这些模型应当优先选择。另外,从评估和测试的结果来看,还可以得到以下几个结论:

  • 对于可以并行计算的硬件平台来说,比如GPU、NPU以及一些具有批处理能力的CPU,模型的推理帧率跟模型的参数量和计算量没有绝对的关联性,更多的是跟模型的结构有关,如果模型适合于并行计算,那么即使模型具有较大的计算量,其推理速度也可以很快;
  • 对于串行执行的计算硬件来说,比如常规指令的CPU,模型的推理速度跟模型的计算量的是线性相关的,也就是说计算量越大,推理帧率越低;
  • 适合GPU平台部署的模型未必适合在CPU平台上部署,所以模型的选择要根据最终的部署平台而定。

4、总结

       花卉数据集共包括96000张图片,囊括了48种花卉的类别,其中10%为测试集,90%为训练集。图片的大小为224x224,通道数为3。一共使用了22种模型进行训练,通过模型评估和硬件平台部署测试得出结论:squeezenet_v1、squeezenet_v2、squeezenet_v3、mobilenet_v1、mobilenet_v2几个模型,具有参数量小,计算量小,分类准确度高的优点,并且在GPU平台和CPU平台上推理速度较快,适合在各种平台上部署,特别是适合移动端和嵌入式的部署。


http://chatgpt.dhexx.cn/article/jNw5IgBO.shtml

相关文章

机器学习-花卉识别系统

介绍 机器学习,人工智能,模式识别课题项目,基于tensorflow机器学习库使用CNN算法通过对四种花卉数据集进行训练,得出训练模型。同时基于Django框架开发可视化系统,实现上传图片预测是否为玫瑰,蒲公英&…

基于Python机器学习实现的花卉识别

目录 问题分析 3问题求解 3 2.1. 数据预处理 3 2.1.1. 预处理流程 3 2.1.2. 预处理实现 4 2.2. 降维可视化 4 2.2.1. 降维流程分析 4 2.2.2. PCA 方法降维 4 从图中给出的结果得到各个阶段的用时 6 2.2.3. t-SNE 方法求解 7随机产生初始解,得到在低维空间中的映射样…

如何扫一扫识别花草树木?教你高效识别花草的小妙招

如何扫一扫识别花草树木?如今金秋九月,很快桂花、菊花、满天星等这些花竞相开放,秋高气爽,毫无疑问是出门游玩的好时节。除了一些常见的花草之外,我们还可能遇到许多不认识的花草,那么这个时候我们应该怎么…

【01】花卉识别-基于tensorflow2.3实现

------------------------------------------------2021年6月18日重大更新-------------------------------------------------------------- 目前已经退出bug修复之后的tensorflow2.3物体分类代码,大家可以训练自己的数据集,快来试试吧 csdn教程链接&…

花卉识别--五个类别的检测

花卉识别–五个类别的检测 文章目录 花卉识别--五个类别的检测一、数据集的观察与查看二、将数据集分为data_train(训练集)和data_test(测试集)三、明确网络流程、建立网络结构四、定义损失函数、学习率、是否使用正则化五、存储模型、已经调用现有的训练好的模型进行测试六、画…

花卉识别(tensorflow)

参考教材:人工智能导论(第4版) 王万良 高等教育出版社 实验环境:Python3.6 Tensor flow 1.12 人工智能导论实验导航 实验一:斑马问题 https://blog.csdn.net/weixin_46291251/article/details/122246347 实验二:图像恢复 http…

常见花卉11种集锦及识别

一、月季花(蔷薇科植物) 矮小直立灌木;小枝有粗壮而略带钩状的皮刺,有时无刺。羽状复叶,小叶3-5,少数7,宽卵形或卵状矩圆形,长2-6厘米,宽1-3厘米,先端渐尖&am…

三分钟让你学会怎么识别花卉品种

有一次,我和朋友一起去旅游,来到了一个生态公园。在公园里,看到了很多美丽的花卉,但是有一种花都不认识它是什么品种,我们非常想知道它的名字和背后的故事。于是我便开始在网上搜索一些可以识别花卉的软件,…

基于TensorFlow的花卉识别

概要设计 数据分析 本次设计的主题是花卉识别,数据为TensorFlow的官方数据集flower_photos,包括5种花卉(雏菊、蒲公英、玫瑰、向日葵和郁金香)的图片,并有对应类别的标识(daisy、dandelion、roses、sunfl…

Spring Boot集成微信扫码登录(实测通过)

微信实现扫码登录 一:具体流程:1、先登录你的 [微信开放平台](https://open.weixin.qq.com)2、创建网站应用3、设置你的AppSecret和授权回调域(不用加http/https)4、开始编码实现 二:实现效果三:注意事项&a…

pc端实现微信扫码登录

pc端实现微信扫码登录 流程:使用vue-wxlogin组件当我们打开微信扫一扫,此时二维码组件会有变化,显示扫描成功 我们的手机就会弹出一个授权页面。记住让后端绑定一个微信公众,通过授权该公众就可以了 效果: 当点击同意…

Java实现微信扫码登录

微信扫码登录 1. 授权流程说明第一步:请求 code第二步:通过 code 获取 access_token第三步:通过 access_token 调用接口 2. 授权流程代码3. 用户登录和登出4. Spring AOP 校验用户有没有登录5. 拦截登录校验不通过抛出的异常 1. 授权流程说明…

vue 使用企业微信扫一扫

vue 使用企业微信扫一扫 vue 使用企业微信扫一扫 第一次调用企业微信功能,有点坑,折腾了好几天,终于好了,记录一下操作过程。 了解功能所需权限(config和agentConfig) 首先要确定使用的功能需要获取的权…

VUE实现微信扫码登录

获取access_token时序图&#xff1a; public中index.html引入 <script src"https://res.wx.qq.com/connect/zh_CN/htmledition/js/wxLogin.js"></script> 微信登录操作 new WxLogin({// 以下操作把请求到的二维码嵌入到id为"weixin"的标签中i…

微信扫码登录原理解析

&#xff08;尊重劳动成果&#xff0c;转载请注明出处&#xff1a;http://blog.csdn.net/qq_25827845/article/details/78823861冷血之心的博客&#xff09; 最近针对扫码登录机制做了一个调研&#xff0c;以下以微信网页扫码登录为例进行一个总结。 1、微信扫码登录过程&…

web微信扫码登录

微信web扫码登录的大致流程&#xff0c;最后有源码基本是够用了&#xff0c;后续登录这一块会继续完善&#xff0c;会加上shiro、redis&#xff0c;前端准备用react来做&#xff0c;搞个全套的 开始之前我们先来看几个问题&#xff0c;有兴趣的可以了解下欢迎发表评论提出意见…

使用码上登录实现微信扫一扫登录

微信扫一扫登录测试 码上登录开发和使用登录的时序图准备工作后台开发前端显示 码上登录 码上登录是一个小程序&#xff0c;对个体开发者提供了免费的微信扫一扫登录入口&#xff0c;因为微信开发者需要企业认证&#xff0c;没办法在个人网站上做测试。码上登录相当于一个桥接…

微信扫码登录的一种开发思路

微信扫码授权登录流程&#xff1a; 用户在显示二维码的页面用手机扫码授权页面跳转到指定地址&#xff0c;URL上带有参数code前端通过code向服务端请求用于权限认证的token前端后续请求在请求头带上token作为身份标识 需要解决的问题 按照上述的流程&#xff0c;前端最简单的…

java集成微信扫码登录

具体流程可以看微信官网的扫码登录文档 地址&#xff1a;https://open.weixin.qq.com/cgi-bin/showdocument?actiondir_list&tresource/res_list&verify1&idopen1419316505&token&langzh_CN 一、 前期准备 1、注册 微信开放平台&#xff1a;https://open…

企业微信扫码登录

企业微信扫码登录步骤&#xff1a; 1.首先在要放置二维码的页面提供一个盒子用于防止生成的二维码 2.在当前页面将企业微信提供的js进行引入 3.调用提供的方法实例&#xff0c; 4.要获得扫码成功之后的code和state值&#xff0c;调用服务&#xff0c;就能查到当前用户的token&…