DL4J实战之二:鸢尾花分类

article/2025/4/30 5:31:51

欢迎访问我的GitHub

这里分类和汇总了欣宸的全部原创(含配套源码):https://github.com/zq2599/blog_demos

本篇概览

  • 本文是《DL4J》实战的第二篇,前面做好了准备工作,接下来进入正式实战,本篇内容是经典的入门例子:鸢尾花分类
  • 下图是一朵鸢尾花,我们可以测量到它的四个特征:花瓣(petal)的宽和高,花萼(sepal)的 宽和高:
    在这里插入图片描述
  • 鸢尾花有三种:Setosa、Versicolor、Virginica
  • 今天的实战是用前馈神经网络Feed-Forward Neural Network (FFNN)就行鸢尾花分类的模型训练和评估,在拿到150条鸢尾花的特征和分类结果后,我们先训练出模型,再评估模型的效果:
    在这里插入图片描述

源码下载

  • 本篇实战中的完整源码可在GitHub下载到,地址和链接信息如下表所示(https://github.com/zq2599/blog_demos):
名称链接备注
项目主页https://github.com/zq2599/blog_demos该项目在GitHub上的主页
git仓库地址(https)https://github.com/zq2599/blog_demos.git该项目源码的仓库地址,https协议
git仓库地址(ssh)git@github.com:zq2599/blog_demos.git该项目源码的仓库地址,ssh协议
  • 这个git项目中有多个文件夹,《DL4J实战》系列的源码在dl4j-tutorials文件夹下,如下图红框所示:
    在这里插入图片描述
  • dl4j-tutorials文件夹下有多个子工程,本次实战代码在dl4j-tutorials目录下,如下图红框:
    在这里插入图片描述

编码

  • dl4j-tutorials工程下新建子工程classifier-iris,其pom.xml如下:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"><parent><artifactId>dlfj-tutorials</artifactId><groupId>com.bolingcavalry</groupId><version>1.0-SNAPSHOT</version></parent><modelVersion>4.0.0</modelVersion><artifactId>classifier-iris</artifactId><properties><maven.compiler.source>8</maven.compiler.source><maven.compiler.target>8</maven.compiler.target></properties><dependencies><dependency><groupId>com.bolingcavalry</groupId><artifactId>commons</artifactId><version>${project.version}</version></dependency><dependency><groupId>org.projectlombok</groupId><artifactId>lombok</artifactId></dependency><dependency><groupId>org.nd4j</groupId><artifactId>${nd4j.backend}</artifactId></dependency><dependency><groupId>ch.qos.logback</groupId><artifactId>logback-classic</artifactId></dependency></dependencies>
</project>
  • 上述pom.xml有一处需要注意的地方,就是${nd4j.backend}参数的值,该值在决定了后端线性代数计算是用CPU还是GPU,本篇为了简化操作选择了CPU(因为个人的显卡不同,代码里无法统一),对应的配置就是nd4j-native
  • 源码全部在Iris.java文件中,并且代码中已添加详细注释,就不再赘述了:
package com.bolingcavalry.classifier;import com.bolingcavalry.commons.utils.DownloaderUtility;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;/*** @author will (zq2599@gmail.com)* @version 1.0* @description: 鸢尾花训练* @date 2021/6/13 17:30*/
@SuppressWarnings("DuplicatedCode")
@Slf4j
public class Iris {public static void main(String[] args) throws  Exception {//第一阶段:准备// 跳过的行数,因为可能是表头int numLinesToSkip = 0;// 分隔符char delimiter = ',';// CSV读取工具RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);// 下载并解压后,得到文件的位置String dataPathLocal = DownloaderUtility.IRISDATA.Download();log.info("鸢尾花数据已下载并解压至 : {}", dataPathLocal);// 读取下载后的文件recordReader.initialize(new FileSplit(new File(dataPathLocal,"iris.txt")));// 每一行的内容大概是这样的:5.1,3.5,1.4,0.2,0// 一共五个字段,从零开始算的话,标签在第四个字段int labelIndex = 4;// 鸢尾花一共分为三类int numClasses = 3;// 一共150个样本int batchSize = 150;    //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)// 加载到数据集迭代器中DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);DataSet allData = iterator.next();// 洗牌(打乱顺序)allData.shuffle();// 设定比例,150个样本中,百分之六十五用于训练SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training// 训练用的数据集DataSet trainingData = testAndTrain.getTrain();// 验证用的数据集DataSet testData = testAndTrain.getTest();// 指定归一化器:独立地将每个特征值(和可选的标签值)归一化为0平均值和1的标准差。DataNormalization normalizer = new NormalizerStandardize();// 先拟合normalizer.fit(trainingData);// 对训练集做归一化normalizer.transform(trainingData);// 对测试集做归一化normalizer.transform(testData);// 每个鸢尾花有四个特征final int numInputs = 4;// 共有三种鸢尾花int outputNum = 3;// 随机数种子long seed = 6;//第二阶段:训练log.info("开始配置...");MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).activation(Activation.TANH)       // 激活函数选用标准的tanh(双曲正切).weightInit(WeightInit.XAVIER)     // 权重初始化选用XAVIER:均值 0, 方差为 2.0/(fanIn + fanOut)的高斯分布.updater(new Sgd(0.1))  // 更新器,设置SGD学习速率调度器.l2(1e-4)                          // L2正则化配置.list()                            // 配置多层网络.layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)  // 隐藏层.build()).layer(new DenseLayer.Builder().nIn(3).nOut(3)          // 隐藏层.build()).layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)   // 损失函数:负对数似然.activation(Activation.SOFTMAX)                     // 输出层指定激活函数为:SOFTMAX.nIn(3).nOut(outputNum).build()).build();// 模型配置MultiLayerNetwork model = new MultiLayerNetwork(conf);// 初始化model.init();// 每一百次迭代打印一次分数(损失函数的值)model.setListeners(new ScoreIterationListener(100));long startTime = System.currentTimeMillis();log.info("开始训练");// 训练for(int i=0; i<1000; i++ ) {model.fit(trainingData);}log.info("训练完成,耗时[{}]ms", System.currentTimeMillis()-startTime);// 第三阶段:评估// 在测试集上评估模型Evaluation eval = new Evaluation(numClasses);INDArray output = model.output(testData.getFeatures());eval.eval(testData.getLabels(), output);log.info("评估结果如下\n" + eval.stats());}
}
  • 编码完成后,运行main方法,可见顺利完成训练并输出了评估结果,还有混淆矩阵用于辅助分析:
    在这里插入图片描述
  • 至此,咱们的第一个实战就完成了,通过经典实例体验的DL4J训练和评估的常规步骤,对重要API也有了初步认识,接下来会继续实战,接触到更多的经典实例;

你不孤单,欣宸原创一路相伴

  1. Java系列
  2. Spring系列
  3. Docker系列
  4. kubernetes系列
  5. 数据库+中间件系列
  6. DevOps系列

http://chatgpt.dhexx.cn/article/E9tYCAjr.shtml

相关文章

DL4J实战之一:准备

欢迎访问我的GitHub 这里分类和汇总了欣宸的全部原创(含配套源码)&#xff1a;https://github.com/zq2599/blog_demos 关于DL4J DL4J是Deeplearning4j的简称&#xff0c;是基于Java虚拟机的深度学习框架&#xff0c;是用java和scala开发的&#xff0c;已开源&#xff0c;官网&…

【DL4J】基本操作_学习笔记(二)

DL4J基本操作 文章目录 DL4J基本操作1. 创建矩阵2. 矩阵元素读取3. 矩阵行元素读取4. 矩阵运算 导入依赖 <nd4j.version>1.0.0-beta2</nd4j.version><dependency><groupId>org.nd4j</groupId><artifactId>nd4j-native-platform</artifa…

【DL4J速成】Deeplearning4j图像分类从模型自定义到测试

文章首发于微信公众号《有三AI》 【DL4J速成】Deeplearning4j图像分类从模型自定义到测试 欢迎来到专栏《2小时玩转开源框架系列》&#xff0c;这是我们第九篇&#xff0c;前面已经说过了caffe&#xff0c;tensorflow&#xff0c;pytorch&#xff0c;mxnet&#xff0c;keras&…

深度学习框架DeepLearning4J(DL4J)的安装及配置

一、DeepLearning4J的简介和系统要求 1、DeepLearning4J简介 Deeplearning4J&#xff08;以下简称DL4J&#xff09;不是第一个开源的深度学习项目&#xff0c;但与此前的其他项目相比&#xff0c;DL4J在编程语言和宗旨两方面都独具特色。DL4J是基于JVM、聚焦行业应用且提供商…

适合中学生看的英文电影

怎样利用好丰富的资源来学习英语口语呢&#xff1f;其实其实看什么样的剧、如何看剧都是很讲究的。一起来解锁吧。 一、选剧要学会拆解自己学习目标&#xff0c;选定合适的类型&#xff0c;各取所需。 并不是所有类型的国外影视剧都适合作为学习的素材&#xff0c;主要依据自身…

springboot+mybatis实现简单的增、删、查、改

这篇文章主要针对java初学者&#xff0c;详细介绍怎么创建一个基本的springboot项目来对数据库进行crud操作。 目录 第一步&#xff1a;准备数据库 第二步&#xff1a;创建springboot项目 方法1&#xff1a;通过spring官网的spring initilizer创建springboot项目 方法2&am…

tk-mybatis使用介绍,springboot整合tk-mybatis、PageHelper实现分页查询

Mybatis-Plus极大简化了我们的开发&#xff0c;作为mybatis的增强版&#xff0c;Mybatis-Plus确实帮我们减少了很多SQL语句的编写&#xff0c;通过其提高的API&#xff0c;可以方便快捷第完成增删查改操作。但是&#xff0c;其实除了Mybatis-Plus以外&#xff0c;还有一个技术t…

SXSW 2022线下展回归,今年有哪些有趣的AR/VR内容?

如今海外的线下活动开始逐渐恢复&#xff0c;今年的SXSW活动也回归线下。与往年相比&#xff0c;这场艺术、音乐、电影的年度盛会在今年进一步融合新兴科技&#xff0c;比如将AR/VR与线下活动结合&#xff0c;带来了更多样化的娱乐应用场景。 那么今年活动上都有哪些看点&#…

UE4 Ultra Dynamic Sky 参数翻译及功能概述

Ultra Dynamic Sky的虚幻商城链接: Ultra Dynamic Sky Ultra_Dynamic_Sky翻译及功能概述 basic controls 基础控制 Refresh Settings 刷新设置 检查此布尔一次&#xff0c;以刷新所有设置&#xff1b; Time Of Day 一天中的时间 一天中天空模仿的时间&#xff0c;从0000到…

更新《鸿门宴传奇》黎明/冯绍峰/张涵予/刘亦菲1024x436 高清下载!1.25G 附加720P种子...

高清下载!1.25G 附加720P种子" title="更新《鸿门宴传奇》黎明/冯绍峰/张涵予/刘亦菲1024x436 高清下载!1.25G 附加720P种子"> 高清下载!1.25G 附加720P种子" title="更新《鸿门宴传奇》黎明/冯绍峰/张涵予/刘亦菲1024x436 高清下载!1.25G 附加…

2019年如何成为一名合格的数据分析师

我是CPDA数据分析师 我是CDA数据分析员 我从事数据分析相关工作 我是个数据分析的小白 我想转行做数据分析 今天我围绕如何成为合格的数据分析师跟大家分享三个小话题: 找到在数据分析领域的定位 数据分析思维的训练 数据分析领域发展方向 一、找准数据分析师的定位 …

感谢折磨你的人[三]

第38节 肯定自己才能看见成功 美国联合保险公司董事长克里蒙史东说&#xff1a;“真正的成功秘诀是‘肯定人生’四个字&#xff0c;如果你能以坚定而乐观的态度&#xff0c;去面对一切困难险阻&#xff0c;那么&#xff0c;你一定能从其中得到好处。” 不要抱怨周遭人、事、物对…

舒淇放下黎明战胜抑郁 自称没责任感且不会结婚q1h

舒淇入行至今&#xff0c;最令人印象深入的恋情&#xff0c;要数与黎明的7年情&#xff0c;有传二人当年因黎明父亲及影迷反对而分手&#xff0c;有一段时光&#xff0c;舒淇更患上抑郁症&#xff01;舒淇日前接收拜访时&#xff0c;被问到若心境愁闷会如何面对&#xff0c;她说…

PS4计算机模块试题,越玩越留恋的PS4独占大作,馋坏了PC玩家,纷纷加入主机行列...

PS4游戏主机是发烧级玩家必有的装备&#xff0c;它的游戏性能远远大于电脑&#xff0c;对于任何的游戏软件优化很强&#xff0c;几乎很少出现卡顿和缺陷&#xff0c;因为它就是为游戏而生&#xff0c;各种游戏也是为主机固件量身定做&#xff0c;那么这个强大的游戏平台&#x…

ES6 课堂笔记

ES6 第一章 ECMASript 相关介绍 1.1 什么是 ECMA ECMA&#xff08;European Computer Manufacturers Association&#xff09;中文名称为欧洲计算机制造商协会&#xff0c;这个组织的目标是评估、开发和认可电信和计算机标准。1994 年后该组织改名为 Ecma 国际。 1.2 什么是…

深度学习:智能时代的核心驱动力量

内容简介 科技巨头纷纷拥抱学习,自动驾驶、AI、语音识别、图像识别、智能翻译以及震惊世界的 AlphaGo,背后都是学习在发挥的作用。学习是人工智能从概念到繁荣得以实现的主流技术。经过学习训练的计算机,不再被动按照指令运转,而是像自然进化的生命那样,开始自主地从经验中…

林家栋这三十年:深获万梓良、刘德华赏识,靠配角成为影帝

https://www.toutiao.com/a6703796759279174155/ 文 | 王珍一 编辑 | 李小白 很少有演员在成为影帝之后&#xff0c;还能静心的做着配角&#xff0c;林家栋做到了。 从香港无线电视艺员训练班的艺员到成为影帝&#xff0c;林家栋用了30年。 在这漫长的30年里&#xff0c;他静…

新特效火爆抖音!各路神仙齐唱《蚂蚁呀嘿》,网友:短短几秒需一生来治愈

金磊 杨净 发自 凹非寺 量子位 报道 | 公众号 QbitAI 当互联网大佬们集体唱歌&#xff0c;会擦出怎样的火花&#xff1f; 现在&#xff0c;火爆抖音的AI特效&#xff0c;一键就可以实现梦幻联动。 瞧&#xff01;马云、马化腾、马斯克等大佬们&#xff0c;正在集体演唱神曲《蚂…

《猩球黎明》首曝海报

2019独角兽企业重金招聘Python工程师标准>>> 昨日刚刚宣布将档期提至2014年7月11日的《猩球黎明》(Dawn of the Planet of the Apes)&#xff0c;在今日发布了首批角色海报&#xff0c;四张各色的猩猩脸孔&#xff0c;像人类的军人一般在战前在脸上图油彩&#xff0…

知名演员从北大毕业!学位论文让网友直呼:请收下我的膝盖!

来源&#xff1a;广州日报 编辑&#xff1a;双一流高校 近日&#xff0c;49岁香港男艺人马浚伟发布微博称&#xff0c;自己已通过北京大学光华管理学院硕士研究生学位论文答辩&#xff0c;顺利毕业。 相关的一则话题达到了1200万的阅读量&#xff0c;超7000名网友参与讨论。 两…