DL4J实战之三:经典卷积实例(LeNet-5)

article/2025/4/30 5:54:26

欢迎访问我的GitHub

这里分类和汇总了欣宸的全部原创(含配套源码):https://github.com/zq2599/blog_demos

本篇概览

  • 作为《DL4J》实战的第三篇,目标是在DL4J框架下创建经典的LeNet-5卷积神经网络模型,对MNIST数据集进行训练和测试,本篇由以下内容构成:
  1. LeNet-5简介
  2. MNIST简介
  3. 数据集简介
  4. 关于版本和环境
  5. 编码
  6. 验证

LeNet-5简介

  • 是Yann LeCun于1998年设计的卷积神经网络,用于手写数字识别,例如当年美国很多银行用其识别支票上的手写数字,LeNet-5是早期卷积神经网络最有代表性的实验系统之一
  • LeNet-5网络结构如下图所示,一共七层:C1 -> S2 -> C3 -> S4 -> C5 -> F6 -> OUTPUT
    在这里插入图片描述
  • 这张图更加清晰明了(原图地址:https://cuijiahua.com/blog/2018/01/dl_3.html),能够很好的指导咱们在DL4J上的编码:
    在这里插入图片描述
  • 按照上图简单分析一下,用于指导接下来的开发:
  1. 每张图片都是28*28的单通道,矩阵应该是[1, 28,28]
  2. C1是卷积层,所用卷积核尺寸5*5,滑动步长1,卷积核数目20,所以尺寸变化是:28-5+1=24(想象为宽度为5的窗口在宽度为28的窗口内滑动,能滑多少次),输出矩阵是[20,24,24]
  3. S2是池化层,核尺寸2*2,步长2,类型是MAX,池化操作后尺寸减半,变成了[20,12,12]
  4. C3是卷积层,所用卷积核尺寸5*5,滑动步长1,卷积核数目50,所以尺寸变化是:12-5+1=8,输出矩阵[50,8,8]
  5. S4是池化层,核尺寸2*2,步长2,类型是MAX,池化操作后尺寸减半,变成了[50,4,4]
  6. C5是全连接层(FC),神经元数目500,接relu激活函数
  7. 最后是全连接层Output,共10个节点,代表数字0到9,激活函数是softmax

MNIST简介

  • MNIST是经典的计算机视觉数据集,来源是National Institute of Standards and Technology (NIST,美国国家标准与技术研究所),包含各种手写数字图片,其中训练集60,000张,测试集 10,000张,
  • MNIST来源于250 个不同人的手写,其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员.,测试集(test set) 也是同样比例的手写数字数据
  • MNIST官网:http://yann.lecun.com/exdb/mnist/

数据集简介

  • 从MNIST官网下载的原始数据并非图片文件,需要按官方给出的格式说明做解析处理才能转为一张张图片,这些事情显然不是本篇的主题,因此咱们可以直接使用DL4J为我们准备好的数据集(下载地址稍后给出),该数据集中是一张张独立的图片,这些图片所在目录的名字就是该图片具体的数字,如下图,目录0里面全是数字0的图片:
    在这里插入图片描述
  • 上述数据集的下载地址有两个:
  1. 可以在CSDN下载(0积分):https://download.csdn.net/download/boling_cavalry/19846603
  2. github:https://raw.githubusercontent.com/zq2599/blog_download_files/master/files/mnist_png.tar.gz
  • 下载之后解压开,是个名为mnist_png的文件夹,稍后的实战中咱们会用到它

关于DL4J版本

  • 《DL4J实战》系列的源码采用了maven的父子工程结构,DL4J的版本在父工程dlfj-tutorials中定义为1.0.0-beta7
  • 本篇的代码虽然还是dlfj-tutorials的子工程,但是DL4J版本却使用了更低的1.0.0-beta6,之所以这么做,是因为下一篇文章,咱们会把本篇的训练和测试工作交给GPU来完成,而对应的CUDA库只有1.0.0-beta6
  • 扯了这么多,可以开始编码了

源码下载

  • 本篇实战中的完整源码可在GitHub下载到,地址和链接信息如下表所示(https://github.com/zq2599/blog_demos):
名称链接备注
项目主页https://github.com/zq2599/blog_demos该项目在GitHub上的主页
git仓库地址(https)https://github.com/zq2599/blog_demos.git该项目源码的仓库地址,https协议
git仓库地址(ssh)git@github.com:zq2599/blog_demos.git该项目源码的仓库地址,ssh协议
  • 这个git项目中有多个文件夹,《DL4J实战》系列的源码在dl4j-tutorials文件夹下,如下图红框所示:
    在这里插入图片描述
  • dl4j-tutorials文件夹下有多个子工程,本次实战代码在simple-convolution目录下,如下图红框:
    在这里插入图片描述

编码

  • 在父工程 dl4j-tutorials下新建名为 simple-convolution的子工程,其pom.xml如下,可见这里的dl4j版本被指定为1.0.0-beta6
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"><parent><artifactId>dlfj-tutorials</artifactId><groupId>com.bolingcavalry</groupId><version>1.0-SNAPSHOT</version></parent><modelVersion>4.0.0</modelVersion><artifactId>simple-convolution</artifactId><properties><dl4j-master.version>1.0.0-beta6</dl4j-master.version></properties><dependencies><dependency><groupId>org.projectlombok</groupId><artifactId>lombok</artifactId></dependency><dependency><groupId>ch.qos.logback</groupId><artifactId>logback-classic</artifactId></dependency><dependency><groupId>org.deeplearning4j</groupId><artifactId>deeplearning4j-core</artifactId><version>${dl4j-master.version}</version></dependency><dependency><groupId>org.nd4j</groupId><artifactId>${nd4j.backend}</artifactId><version>${dl4j-master.version}</version></dependency></dependencies>
</project>
  • 接下来按照前面的分析实现代码,已经添加了详细注释,就不再赘述了:
package com.bolingcavalry.convolution;import lombok.extern.slf4j.Slf4j;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;@Slf4j
public class LeNetMNISTReLu {// 存放文件的地址,请酌情修改
//    private static final String BASE_PATH = System.getProperty("java.io.tmpdir") + "/mnist";private static final String BASE_PATH = "E:\\temp\\202106\\26";public static void main(String[] args) throws Exception {// 图片像素高int height = 28;// 图片像素宽int width = 28;// 因为是黑白图像,所以颜色通道只有一个int channels = 1;// 分类结果,0-9,共十种数字int outputNum = 10;// 批大小int batchSize = 54;// 循环次数int nEpochs = 1;// 初始化伪随机数的种子int seed = 1234;// 随机数工具Random randNumGen = new Random(seed);log.info("检查数据集文件夹是否存在:{}", BASE_PATH + "/mnist_png");if (!new File(BASE_PATH + "/mnist_png").exists()) {log.info("数据集文件不存在,请下载压缩包并解压到:{}", BASE_PATH);return;}// 标签生成器,将指定文件的父目录作为标签ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();// 归一化配置(像素值从0-255变为0-1)DataNormalization imageScaler = new ImagePreProcessingScaler();// 不论训练集还是测试集,初始化操作都是相同套路:// 1. 读取图片,数据格式为NCHW// 2. 根据批大小创建的迭代器// 3. 将归一化器作为预处理器log.info("训练集的矢量化操作...");// 初始化训练集File trainData = new File(BASE_PATH + "/mnist_png/training");FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);trainRR.initialize(trainSplit);DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);// 拟合数据(实现类中实际上什么也没做)imageScaler.fit(trainIter);trainIter.setPreProcessor(imageScaler);log.info("测试集的矢量化操作...");// 初始化测试集,与前面的训练集操作类似File testData = new File(BASE_PATH + "/mnist_png/testing");FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);testRR.initialize(testSplit);DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);testIter.setPreProcessor(imageScaler); // same normalization for better resultslog.info("配置神经网络");// 在训练中,将学习率配置为随着迭代阶梯性下降Map<Integer, Double> learningRateSchedule = new HashMap<>();learningRateSchedule.put(0, 0.06);learningRateSchedule.put(200, 0.05);learningRateSchedule.put(600, 0.028);learningRateSchedule.put(800, 0.0060);learningRateSchedule.put(1000, 0.001);// 超参数MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed)// L2正则化系数.l2(0.0005)// 梯度下降的学习率设置.updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)))// 权重初始化.weightInit(WeightInit.XAVIER)// 准备分层.list()// 卷积层.layer(new ConvolutionLayer.Builder(5, 5).nIn(channels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build())// 下采样,即池化.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build())// 卷积层.layer(new ConvolutionLayer.Builder(5, 5).stride(1, 1) // nIn need not specified in later layers.nOut(50).activation(Activation.IDENTITY).build())// 下采样,即池化.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build())// 稠密层,即全连接.layer(new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build())// 输出.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image.build();MultiLayerNetwork net = new MultiLayerNetwork(conf);net.init();// 每十个迭代打印一次损失函数值net.setListeners(new ScoreIterationListener(10));log.info("神经网络共[{}]个参数", net.numParams());long startTime = System.currentTimeMillis();// 循环操作for (int i = 0; i < nEpochs; i++) {log.info("第[{}]个循环", i);net.fit(trainIter);Evaluation eval = net.evaluate(testIter);log.info(eval.stats());trainIter.reset();testIter.reset();}log.info("完成训练和测试,耗时[{}]毫秒", System.currentTimeMillis()-startTime);// 保存模型File ministModelPath = new File(BASE_PATH + "/minist-model.zip");ModelSerializer.writeModel(net, ministModelPath, true);log.info("最新的MINIST模型保存在[{}]", ministModelPath.getPath());}
}
  • 执行上述代码,日志输出如下,训练和测试都顺利完成,准确率达到0.9886:
21:19:15.355 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 1110 is 0.18300625613640034
21:19:15.365 [main] DEBUG org.nd4j.linalg.dataset.AsyncDataSetIterator - Manually destroying ADSI workspace
21:19:16.632 [main] DEBUG org.nd4j.linalg.dataset.AsyncDataSetIterator - Manually destroying ADSI workspace
21:19:16.642 [main] INFO com.bolingcavalry.convolution.LeNetMNISTReLu - ========================Evaluation Metrics========================# of classes:    10Accuracy:        0.9886Precision:       0.9885Recall:          0.9886F1 Score:        0.9885
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)=========================Confusion Matrix=========================0    1    2    3    4    5    6    7    8    9
---------------------------------------------------972    0    0    0    0    0    2    2    2    2 | 0 = 00 1126    0    3    0    2    1    1    2    0 | 1 = 11    1 1019    2    0    0    0    6    3    0 | 2 = 20    0    1 1002    0    5    0    1    1    0 | 3 = 30    0    2    0  971    0    3    2    1    3 | 4 = 40    0    0    3    0  886    2    1    0    0 | 5 = 56    2    0    1    1    5  942    0    1    0 | 6 = 60    1    6    0    0    0    0 1015    1    5 | 7 = 71    0    1    1    0    2    0    2  962    5 | 8 = 81    2    1    3    5    3    0    2    1  991 | 9 = 9Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
21:19:16.643 [main] INFO com.bolingcavalry.convolution.LeNetMNISTReLu - 完成训练和测试,耗时[27467]毫秒
21:19:17.019 [main] INFO com.bolingcavalry.convolution.LeNetMNISTReLu - 最新的MINIST模型保存在[E:\temp\202106\26\minist-model.zip]Process finished with exit code 0

关于准确率

  • 前面的测试结果显示准确率为0.9886,这是1.0.0-beta6版本DL4J的训练结果,如果换成1.0.0-beta7,准确率可以达到0.99以上,您可以尝试一下;

  • 至此,DL4J框架下的经典卷积实战就完成了,截止目前,咱们的训练和测试工作都是CPU完成的,工作中CPU使用率的上升十分明显,下一篇文章,咱们把今天的工作交给GPU执行试试,看能否借助CUDA加速训练和测试工作;

你不孤单,欣宸原创一路相伴

  1. Java系列
  2. Spring系列
  3. Docker系列
  4. kubernetes系列
  5. 数据库+中间件系列
  6. DevOps系列

http://chatgpt.dhexx.cn/article/ItHHfSwJ.shtml

相关文章

DL4J实战之二:鸢尾花分类

欢迎访问我的GitHub 这里分类和汇总了欣宸的全部原创(含配套源码)&#xff1a;https://github.com/zq2599/blog_demos 本篇概览 本文是《DL4J》实战的第二篇&#xff0c;前面做好了准备工作&#xff0c;接下来进入正式实战&#xff0c;本篇内容是经典的入门例子&#xff1a;鸢…

DL4J实战之一:准备

欢迎访问我的GitHub 这里分类和汇总了欣宸的全部原创(含配套源码)&#xff1a;https://github.com/zq2599/blog_demos 关于DL4J DL4J是Deeplearning4j的简称&#xff0c;是基于Java虚拟机的深度学习框架&#xff0c;是用java和scala开发的&#xff0c;已开源&#xff0c;官网&…

【DL4J】基本操作_学习笔记(二)

DL4J基本操作 文章目录 DL4J基本操作1. 创建矩阵2. 矩阵元素读取3. 矩阵行元素读取4. 矩阵运算 导入依赖 <nd4j.version>1.0.0-beta2</nd4j.version><dependency><groupId>org.nd4j</groupId><artifactId>nd4j-native-platform</artifa…

【DL4J速成】Deeplearning4j图像分类从模型自定义到测试

文章首发于微信公众号《有三AI》 【DL4J速成】Deeplearning4j图像分类从模型自定义到测试 欢迎来到专栏《2小时玩转开源框架系列》&#xff0c;这是我们第九篇&#xff0c;前面已经说过了caffe&#xff0c;tensorflow&#xff0c;pytorch&#xff0c;mxnet&#xff0c;keras&…

深度学习框架DeepLearning4J(DL4J)的安装及配置

一、DeepLearning4J的简介和系统要求 1、DeepLearning4J简介 Deeplearning4J&#xff08;以下简称DL4J&#xff09;不是第一个开源的深度学习项目&#xff0c;但与此前的其他项目相比&#xff0c;DL4J在编程语言和宗旨两方面都独具特色。DL4J是基于JVM、聚焦行业应用且提供商…

适合中学生看的英文电影

怎样利用好丰富的资源来学习英语口语呢&#xff1f;其实其实看什么样的剧、如何看剧都是很讲究的。一起来解锁吧。 一、选剧要学会拆解自己学习目标&#xff0c;选定合适的类型&#xff0c;各取所需。 并不是所有类型的国外影视剧都适合作为学习的素材&#xff0c;主要依据自身…

springboot+mybatis实现简单的增、删、查、改

这篇文章主要针对java初学者&#xff0c;详细介绍怎么创建一个基本的springboot项目来对数据库进行crud操作。 目录 第一步&#xff1a;准备数据库 第二步&#xff1a;创建springboot项目 方法1&#xff1a;通过spring官网的spring initilizer创建springboot项目 方法2&am…

tk-mybatis使用介绍,springboot整合tk-mybatis、PageHelper实现分页查询

Mybatis-Plus极大简化了我们的开发&#xff0c;作为mybatis的增强版&#xff0c;Mybatis-Plus确实帮我们减少了很多SQL语句的编写&#xff0c;通过其提高的API&#xff0c;可以方便快捷第完成增删查改操作。但是&#xff0c;其实除了Mybatis-Plus以外&#xff0c;还有一个技术t…

SXSW 2022线下展回归,今年有哪些有趣的AR/VR内容?

如今海外的线下活动开始逐渐恢复&#xff0c;今年的SXSW活动也回归线下。与往年相比&#xff0c;这场艺术、音乐、电影的年度盛会在今年进一步融合新兴科技&#xff0c;比如将AR/VR与线下活动结合&#xff0c;带来了更多样化的娱乐应用场景。 那么今年活动上都有哪些看点&#…

UE4 Ultra Dynamic Sky 参数翻译及功能概述

Ultra Dynamic Sky的虚幻商城链接: Ultra Dynamic Sky Ultra_Dynamic_Sky翻译及功能概述 basic controls 基础控制 Refresh Settings 刷新设置 检查此布尔一次&#xff0c;以刷新所有设置&#xff1b; Time Of Day 一天中的时间 一天中天空模仿的时间&#xff0c;从0000到…

更新《鸿门宴传奇》黎明/冯绍峰/张涵予/刘亦菲1024x436 高清下载!1.25G 附加720P种子...

高清下载!1.25G 附加720P种子" title="更新《鸿门宴传奇》黎明/冯绍峰/张涵予/刘亦菲1024x436 高清下载!1.25G 附加720P种子"> 高清下载!1.25G 附加720P种子" title="更新《鸿门宴传奇》黎明/冯绍峰/张涵予/刘亦菲1024x436 高清下载!1.25G 附加…

2019年如何成为一名合格的数据分析师

我是CPDA数据分析师 我是CDA数据分析员 我从事数据分析相关工作 我是个数据分析的小白 我想转行做数据分析 今天我围绕如何成为合格的数据分析师跟大家分享三个小话题: 找到在数据分析领域的定位 数据分析思维的训练 数据分析领域发展方向 一、找准数据分析师的定位 …

感谢折磨你的人[三]

第38节 肯定自己才能看见成功 美国联合保险公司董事长克里蒙史东说&#xff1a;“真正的成功秘诀是‘肯定人生’四个字&#xff0c;如果你能以坚定而乐观的态度&#xff0c;去面对一切困难险阻&#xff0c;那么&#xff0c;你一定能从其中得到好处。” 不要抱怨周遭人、事、物对…

舒淇放下黎明战胜抑郁 自称没责任感且不会结婚q1h

舒淇入行至今&#xff0c;最令人印象深入的恋情&#xff0c;要数与黎明的7年情&#xff0c;有传二人当年因黎明父亲及影迷反对而分手&#xff0c;有一段时光&#xff0c;舒淇更患上抑郁症&#xff01;舒淇日前接收拜访时&#xff0c;被问到若心境愁闷会如何面对&#xff0c;她说…

PS4计算机模块试题,越玩越留恋的PS4独占大作,馋坏了PC玩家,纷纷加入主机行列...

PS4游戏主机是发烧级玩家必有的装备&#xff0c;它的游戏性能远远大于电脑&#xff0c;对于任何的游戏软件优化很强&#xff0c;几乎很少出现卡顿和缺陷&#xff0c;因为它就是为游戏而生&#xff0c;各种游戏也是为主机固件量身定做&#xff0c;那么这个强大的游戏平台&#x…

ES6 课堂笔记

ES6 第一章 ECMASript 相关介绍 1.1 什么是 ECMA ECMA&#xff08;European Computer Manufacturers Association&#xff09;中文名称为欧洲计算机制造商协会&#xff0c;这个组织的目标是评估、开发和认可电信和计算机标准。1994 年后该组织改名为 Ecma 国际。 1.2 什么是…

深度学习:智能时代的核心驱动力量

内容简介 科技巨头纷纷拥抱学习,自动驾驶、AI、语音识别、图像识别、智能翻译以及震惊世界的 AlphaGo,背后都是学习在发挥的作用。学习是人工智能从概念到繁荣得以实现的主流技术。经过学习训练的计算机,不再被动按照指令运转,而是像自然进化的生命那样,开始自主地从经验中…

林家栋这三十年:深获万梓良、刘德华赏识,靠配角成为影帝

https://www.toutiao.com/a6703796759279174155/ 文 | 王珍一 编辑 | 李小白 很少有演员在成为影帝之后&#xff0c;还能静心的做着配角&#xff0c;林家栋做到了。 从香港无线电视艺员训练班的艺员到成为影帝&#xff0c;林家栋用了30年。 在这漫长的30年里&#xff0c;他静…

新特效火爆抖音!各路神仙齐唱《蚂蚁呀嘿》,网友:短短几秒需一生来治愈

金磊 杨净 发自 凹非寺 量子位 报道 | 公众号 QbitAI 当互联网大佬们集体唱歌&#xff0c;会擦出怎样的火花&#xff1f; 现在&#xff0c;火爆抖音的AI特效&#xff0c;一键就可以实现梦幻联动。 瞧&#xff01;马云、马化腾、马斯克等大佬们&#xff0c;正在集体演唱神曲《蚂…

《猩球黎明》首曝海报

2019独角兽企业重金招聘Python工程师标准>>> 昨日刚刚宣布将档期提至2014年7月11日的《猩球黎明》(Dawn of the Planet of the Apes)&#xff0c;在今日发布了首批角色海报&#xff0c;四张各色的猩猩脸孔&#xff0c;像人类的军人一般在战前在脸上图油彩&#xff0…