鸢尾花数据集分类--神经网络

article/2025/11/10 20:52:42

1.1 鸢尾花数据集介绍

iris数据集是用来给莺尾花做分类的数据集,每个样本包含了花萼长度、花萼宽度、花瓣长度、花瓣宽度四个特征,我们需要建立一个分类器,该分类器可通过样本的四个特征来来判断样本属于山鸢尾(Setosa)、变色鸢尾(Versicolour)还是维吉尼亚鸢尾(Virginica)中的哪一个,选择神经网络进行分类。

1.2 思路流程

  • 导入鸢尾花数据集
  • 对数据集进行切分,分为训练集和测试集
  • 搭建网络模型
  • 训练网络
  • 将所训练出的模型进行保存(准确率大于90%)

1.3 网络模型

在这里插入图片描述

采用sigmoid等函数,算激活函数时(指数运算),计算量大,反向传播求误差梯度时,求导涉及除法,计算量相对大,而采用Relu激活函数,整个过程的计算量节省很多,故采用Relu作为激活函数

1.4 实现代码

导入所需要的的模块

import torch
import torch.nn as nn
from sklearn import datasets
from sklearn.model_selection import train_test_split

神经网络类

class Net(nn.Module):def __init__(self,in_num,out_num,hid_num):super(Net,self).__init__()self.network = nn.Sequential(nn.Linear(in_num,hid_num),nn.ReLU(),nn.Linear(hid_num,out_num))self.optimizer = torch.optim.SGD(self.parameters(), lr=0.05)self.loss_func = torch.nn.CrossEntropyLoss()def forward(self,x):return self.network(x)def train(self,x,y):out = self.forward(x)loss = self.loss_func(out,y)self.optimizer.zero_grad()loss.backward()self.optimizer.step()print('loss = %.4f' % loss.item())def test(self,x):return self.forward()

引入数据集,并按照8:2切分训练集和测试集

dataset = datasets.load_iris()
input = torch.FloatTensor(dataset['data'])
label = torch.LongTensor(dataset['target'])
x_train, x_test, y_train, y_test = train_test_split(input, label, test_size=0.2)

如果存在已有训练好的网络则导入,并在总体数据集上测试其准确性

try:print("iris_model exist and have been loaded")mynet = torch.load('iris_model.pkl')output = mynet(input)pred_y = torch.max(output, 1)[1].numpy()sum = 0for i in range(len(label)):if pred_y[i] == label[i]:sum = sum + 1accuracy = float(sum / len(label))print('model accuracy = %d%% (testing on the whole dataset)' % (accuracy * 100))

若不存在训练好的网络则进行训练,直到准确性大于90%后将其保存

except:mynet = Net(4,10,3)accuracy = 0.0while accuracy < 0.9:for i in range (10000):mynet.train(x_train,y_train)output = mynet(x_test)pred_y = torch.max(output, 1)[1].numpy()sum=0for i in range(len(y_test)):if pred_y[i] == y_test[i]:sum=sum+1accuracy = float(sum / len(y_test))torch.save(mynet, 'iris_model.pkl')print(mynet)print("The net have been saved")print('accuracy = %d%%' % (accuracy*100))

鸢尾花识别完整代码

import torch
import torch.nn as nn
from sklearn import datasets
from sklearn.model_selection import train_test_split
class Net(nn.Module):def __init__(self,in_num,out_num,hid_num):super(Net,self).__init__()self.network = nn.Sequential(nn.Linear(in_num,hid_num),nn.ReLU(),nn.Linear(hid_num,out_num))self.optimizer = torch.optim.SGD(self.parameters(), lr=0.05)self.loss_func = torch.nn.CrossEntropyLoss()def forward(self,x):return self.network(x)def train(self,x,y):out = self.forward(x)loss = self.loss_func(out,y)self.optimizer.zero_grad()loss.backward()self.optimizer.step()print('loss = %.4f' % loss.item())def test(self,x):return self.forward()if __name__ == '__main__':dataset = datasets.load_iris()input = torch.FloatTensor(dataset['data'])label = torch.LongTensor(dataset['target'])x_train, x_test, y_train, y_test = train_test_split(input, label, test_size=0.2)try:print("iris_model exist and have been loaded")mynet = torch.load('iris_model.pkl')output = mynet(input)pred_y = torch.max(output, 1)[1].numpy()sum = 0for i in range(len(label)):if pred_y[i] == label[i]:sum = sum + 1accuracy = float(sum / len(label))print('model accuracy = %d%% (testing on the whole dataset)' % (accuracy * 100))except:mynet = Net(4,10,3)accuracy = 0.0while accuracy < 0.9:for i in range (10000):mynet.train(x_train,y_train)output = mynet(x_test)pred_y = torch.max(output, 1)[1].numpy()sum=0for i in range(len(y_test)):if pred_y[i] == y_test[i]:sum=sum+1accuracy = float(sum / len(y_test))torch.save(mynet, 'iris_model.pkl')print(mynet)print("The net have been saved")print('accuracy = %d%%' % (accuracy*100))

github文件链接


http://chatgpt.dhexx.cn/article/vh7om8lN.shtml

相关文章

机器学习鸢尾花数据集分析

目录 1 sklearn数据集的使用2 sklearn数据集返回值介绍3 查看数据分布4 数据集的划分5 总结 1 sklearn数据集的使用 鸢尾属&#xff08;拉丁学名&#xff1a;Iris L.&#xff09;是单子叶植物纲&#xff0c;鸢尾科多年生草本植物&#xff0c;有块茎或匍匐状根茎&#xff1b;叶…

决策树可视化:鸢尾花数据集分类(附代码数据集)

决策树 数据集实战可视化评价 决策树是什么&#xff1f;决策树(decision tree)是一种基本的分类与回归方法。举个通俗易懂的例子&#xff0c;流程图就是一种决策树。 有没有车&#xff0c;没车的话有没有房&#xff0c;没房的话有没有存款&#xff0c;没存款pass。这个流程就是…

机器学习算法:基于鸢尾花(iris)数据集的数据可视化 (200+收藏)

文章目录 基于鸢尾花(iris)数据集的数据可视化1、数据导入2、查看样本数据3、特征与标签组合的散点可视化3.1、 散点图3.2、 箱型图3.2、 三维散点图想要看更加舒服的排版、更加准时的推送 关注公众号“不太灵光的程序员” 干货推送,微信随时解答你的疑问 😃😃😃 基于…

鸢尾花数据集的可视化

#TensorFlow实战 鸢尾花数据集的可视化化展示 文章目录 前言一、介绍二、步骤1.引入库2.读入数据 前言 数据可视化展示能在实验中可视化展出实验结果&#xff0c;是基础部分 一、介绍 鸢尾花数据集是公开的数据集&#xff0c;可通过URL从TensorFlow的Keras连接下载。 二、步…

探索sklearn | 鸢尾花数据集

1 鸢尾花数据集背景 鸢尾花数据集是原则20世纪30年代的经典数据集。它是用统计进行分类的鼻祖。 sklearn包不仅囊括很多机器学习的算法&#xff0c;也自带了许多经典的数据集&#xff0c;鸢尾花数据集就是其中之一。 导入的方法很简单&#xff0c;不过我比较好奇它是如何来存…

线性回归实例-鸢尾花数据集

文章目录 一、具体实现步骤1. 导入Iris鸢尾花数据集2. 提取花瓣数据3. 拆分数据4. 训练模型 二、可视化结果展示1. 训练集2. 测试集 三、相关知识点讲解1. train_test_split()函数2. LinearRegression()函数3. 散点图与折线统计图的绘制 这篇文章中&#xff0c;我们要通过鸢尾花…

基于朴素贝叶斯的鸢尾花数据集分类

目录 1.作者介绍2.理论知识介绍2.1算法介绍2.2数据集介绍 3.实验代码及结果3.1 数据集下载3.2实验代码3.2实验结果 1.作者介绍 王炜鑫&#xff0c;男&#xff0c;西安工程大学电子信息学院&#xff0c;2021级研究生 研究方向&#xff1a;小型无人直升机模型辨识 电子邮件&…

鸢尾花数据集的数据可视化

鸢尾花数据集的数据显示 一、鸢尾花数据集介绍1.历史2.数据集 二、鸢尾花数据集可视化1.普通读取数据方法2.运行结果3.普通读取数据方法4.运行结果5.未使用mglearn库的代码6.运行结果7.使用mglearn库的代码8.运行结果 一、鸢尾花数据集介绍 1.历史 安德森鸢尾花卉数据集&#…

鸢尾花数据集分类

数据集介绍 共有数据150组&#xff0c;每组包括花萼长、花萼宽、花瓣长、花瓣宽4个输入特征。 同时给出了&#xff0c;这一组特征对应的鸢尾花类别。类别包括Setosa Iris&#xff08;狗尾草 鸢尾&#xff09;&#xff0c;Versicolour Iris&#xff08;杂色鸢尾&#xff09;&…

鸢尾花数据集分类-决策树

文章目录 决策树数据集代码实验分析 决策树 决策树&#xff08;Decision Tree&#xff09;是一种基本的分类与回归方法&#xff0c;当决策树用于分类时称为分类树&#xff0c;用于回归时称为回归树。主要介绍分类树。 决策树由结点和有向边组成。结点有两种类型&#xff1a;内…

鸢尾花数据集的各种玩法

目录 鸢尾花数据集下载鸢尾花数据集iris csv文件下载数据集 Pandas访问csv数据集 Pandas库Pandas二维数据基本操作 读取csv数据集文件设置列标题names参数 访问数据显示统计信息DataFrame的常用属性&#xff1a;ndim、size、shape转化为NumPy数组 访问数组元素–索引和切片 鸢…

iris鸢尾花数据集最全数据分析

写在前面 在写这篇文章之前&#xff0c;首先安利下jupyter&#xff0c;简直是神作&#xff0c;既可以用来写文章&#xff0c;又可以用来写代码&#xff0c;文章和代码并存&#xff0c;简直就是写代码/文章/教程的利器。 安装很简单&#xff1a;pip install jupyter 使用很简单…

sklearn数据集——iris鸢尾花数据集

参考书籍&#xff1a;Python机器学习基础教程 1、初始数据 鸢尾花&#xff08;Iris&#xff09;数据集&#xff0c;是机器学习和统计学中一个经典的数据集。它包含在 scikit-learn 的 datasets 模块中。 我们可以调用 load_iris 函数来加载数据&#xff1a; from sklearn.da…

重拾Iris鸢尾花数据集分析

最近我又又又开始了我的机器学习道路&#xff0c;并且回过头来重新看了一遍Iris数据分析&#xff0c;作为机器学习里面最经典的案例之一&#xff0c;鸢尾花既是我入门机器学习到放弃的地方&#xff0c;又是再次细读之后给予我灵感的地方。 下面介绍一下这次灵感之旅&am…

Python-鸢尾花数据集Iris 数据可视化 :读取数据、显示数据、描述性统计、散点图、直方图、KDE图、箱线图

本博客运行环境为Jupyter Notebook、Python3。使用的数据集是鸢尾花数据集&#xff08;Iris&#xff09;。主要叙述的是数据可视化。 IRIS数据集以鸢尾花的特征作为数据来源&#xff0c;数据集包含150个数据集&#xff0c;有4维&#xff0c;分为3 类&#xff0c;每类50个数据&a…

《机器学习》分析鸢尾花数据集

转载地址&#xff1a;https://www.cnblogs.com/mandy-study/p/7941365.html 分析鸢尾花数据集 下面将结合Scikit-learn官网的逻辑回归模型分析鸢尾花示例&#xff0c;给大家进行详细讲解及拓展。由于该数据集分类标签划分为3类&#xff08;0类、1类、2类&#xff09;&#xff…

笔记篇二:鸢尾花数据集分类

目录 一、鸢尾花数据集 二、逻辑回归分析 三、逻辑回归实现鸢尾花数据集分类 四、散点图绘制 一、鸢尾花数据集 1、问题 Iris 鸢尾花数据集是一个经典数据集&#xff0c;在统计学习和机器学习领域都经常被用作示例。数据集内包含 3 类共 150 条记录&#xff0c;每类各 5…

IRIS鸢尾花数据集(多种格式)-下载地址

最近看的例子有用到IRIS数据集&#xff0c; 个人找了半天&#xff0c;才找到合适格式的数据集。 因此&#xff0c;将我找到的数据集分享给大家&#xff0c;以免大家像我一样找很久。 我这里有3种格式的数据集&#xff0c;分别是&#xff1a; 1. iris.csv 2. Iris.data 3.…

鸢尾花数据集基本用法

Iris鸢尾花数据集是一个经典的数据集。 包含3类共150条记录&#xff0c;每类各50项数据&#xff0c;每一条记录都有四个体征。 可以通过这四个特征来预测鸢尾花属于哪一个品种。 一.鸢尾花数据集 首先导入数据集&#xff0c;用pandas读入iris.csv数据集&#xff0c;读取后的…

鸢尾花(iris)数据集分析

原文链接&#xff1a;https://www.jianshu.com/p/52b86c774b0b Iris 鸢尾花数据集是一个经典数据集&#xff0c;在统计学习和机器学习领域都经常被用作示例。数据集内包含 3 类共 150 条记录&#xff0c;每类各 50 个数据&#xff0c;每条记录都有 4 项特征&#xff1a;花萼长度…