使用Dl4j训练的一个手写数字识别软件

article/2025/7/19 7:07:47

DL4J使用之手写数字识别

最近一直在学习深度学习,由于我是Java程序员出身,就选择了一个面向Java的深度学习库—DL4J。为了更加熟练的掌握这个库的使用,我使用该库,以MNIST(http://yann.lecun.com/exdb/mnist/)手写数字数据集作为基础,训练了一个模型,来识别手写字体。下面我们从以下几个方面讲解该项目的实现:

  • DL4J使用之手写数字识别
    • DL4J简介
    • 模型的训练
      • 训练数据集(MNIST)
      • 模型架构
    • 模型性能
    • 模型的保存与加载
    • 结果展示
    • 总结 与展望

DL4J简介

Deeplearning4j是国外创业公司Skymind的产品。目前最新的版本更新到了0.7.2。源码全部公开并托管在github上(https://github.com/deeplearning4j/deeplearning4j)。从这个库的名字上可以看出,它就是转为Java程序员写的Deep Learning库。其实这个库吸引人的地方不仅仅在于它支持Java,更为重要的是它可以支持Spark。由于Deep Learning模型的训练需要大量的内存,而且原始数据的存储有时候也需要很大的外存空间,所以如果可以利用集群来处理便是最好不过了。当然,除了Deeplearning4j以外,还有一些Deep Learning的库可以支持Spark,比如yahoo/CaffeOnSpark,AMPLab/SparkNet以及Intel最近开源的BigDL。这些库我自己都没怎么用过,所以就不多说了,这里重点说说Deeplearning4j的使用。
从项目管理角度,DL4J官方给的例子中,推荐使用Maven构建项目,但是目前在学习阶段,我是直接从官网扣下来了需要的Jar包导入项目,这样有一个好处,在项目迁移到别的计算机上运行的时候不需要等待Maven下载jar包的时间。当然,工作中还是推荐使用Maven。不说了,下面是我提出来的Jar包:
这里写图片描述
看着还是挺庞大的,其实也难怪,毕竟深度学习需要大量的工作才能形成一个库。这些我已经上传到CSDN可以点击下方链接下载(https://download.csdn.net/download/yushengpeng/10286975)

模型的训练

训练数据集(MNIST)

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。它有60000个训练样本集和10000个测试样本集。MNIST算是深度学习入门的一个数据集吧,也是一个比较优秀的手写数字数据集,可以用于半监督学习,并且取得了非常不错的成绩。下面是该数据集的部分截图:
这里写图片描述
关于如何将该数据集转换成DL4J能识别的格式,请学习DL4J的官方文档。我也上传了Dl4J的官方文档到了CSND,如果你有需求请前去下载(https://download.csdn.net/download/yushengpeng/10287018)。

模型架构

当我们正确读取数据后,我们需要定义具体的神经网络结构,这里我用的是Lenet,该网络是一个5层的神经网络(在深度学习中,我们约定俗成的认为输入层是第0层不参与层数统计),该网络各层情况如下:

第0层: nput layer: 输入数据为原始训练图像
第1层: Conv1:6个5*5的卷积核,步长Stride为1
第2层:Pooling1:卷积核size为2*2,步长Stride为2
第3层:Conv2:12个5*5的卷积核,步长Stride为1
第4层:Pooling2:卷积核size为2*2,步长Stride为2
第5层:Output layer:输出为10维向量

网络层级结构示意图如下:
这里写图片描述

Deeplearning4j的实现参考了官网(https://github.com/deeplearning4j/dl4j-examples)的例子。具体代码如下:

public class CNN_MNIST {private static Logger log = LoggerFactory.getLogger(CNN_MNIST.class);public static void main(String[] args) throws IOException {int nChannels = 1;int outputNum = 10; // The number of possible outcomesint batchSize = 64; // Test batch sizeint nEpochs = 2; // Number of training epochsint iterations = 1; // Number of training iterationsint seed = 123; //DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(iterations).regularization(true).l2(0.0005).learningRate(.01).weightInit(WeightInit.XAVIER).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.NESTEROVS).momentum(0.9).list().layer(0, new ConvolutionLayer.Builder(5, 5)// nIn and nOut specify depth. nIn here is the nChannels and// nOut is the number of filters to be applied.nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build()).layer(1,new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build()).layer(2, new ConvolutionLayer.Builder(5, 5)// Note that nIn need not be specified in later layers.stride(1, 1).nOut(50).activation(Activation.IDENTITY).build()).layer(3,new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build()).layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()).layer(5,new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)) // See note below.backprop(true).pretrain(false).build();MultiLayerNetwork model = new MultiLayerNetwork(conf);model.init();model.setListeners(new ScoreIterationListener(1));for (int i = 0; i < nEpochs; i++) {model.fit(mnistTrain);log.info("*** Completed epoch {} ***", i);log.info("Evaluate model....");Evaluation eval = new Evaluation(outputNum);while (mnistTest.hasNext()) {DataSet ds = mnistTest.next();INDArray output = model.output(ds.getFeatureMatrix(), false);eval.eval(ds.getLabels(), output);}log.info(eval.stats());mnistTest.reset();log.info("****************Example finished********************");log.info("******SAVE TRAINED MODEL******");// Details// Where to save modelFile locationToSave = new File("trained_mnist_model.zip");// boolean save Updaterboolean saveUpdater = false;// ModelSerializer needs modelname, saveUpdater, LocationModelSerializer.writeModel(model, locationToSave, saveUpdater);}}
}

可以发现,神经网络需要定义很多的超参数,学习率、正则化系数、卷积核的大小、激励函数等都是需要人为设定的。不同的超参数,对结果的影响很大,其实后来发现,很多时间都花在数据处理和调参方面。毕竟自己设计网络的能力有限,一般都是参考大牛的论文,然后自己照葫芦画瓢地实现。这里实现的Lenet的结构是:卷积–>下采样–>卷积–>下采样–>全连接。和原论文的结构基本一致。卷积核的大小也是参考的原论文。具体细节可参考之前发的论文链接。这里我们设置了一个Score的监听事件,主要是可以在训练的时候获取每一次权重更新后损失函数的收敛情况,如下面所示:
这里写图片描述

模型性能

of classes:10
Accuracy0.9918
Precision0.9917
Recall0.9917
F1 Score0.9917

模型性能还是不错的,在10000个手写数字测试集上的准确率能达到99.17%。当然,模型的好坏跟神经网络的架构,超参的设置都有关系,关于到底选用什么样的模型架构需要更多的经验,知识。一般具体问题具体分析。

模型的保存与加载

当我们训练好了一个模型的时候,我们需要将训练好的模型持久化到本地磁盘,或者其他存储介质。因为训练模型是一个非常耗时的工作,模型的大小,数据集的大小,训练一个模型需要一天,一周,一个月,甚至是更长的时间。我们不可能每次在实际的项目中,需要的时候再去训练出一个模型。DL4J也为我们实现了模型的持久化功能,具体代码如下:

File locationToSave = new File("trained_mnist_model.zip");//保存路径,存储位置
boolean saveUpdater = false;
ModelSerializer.writeModel(model, locationToSave, saveUpdater);

当然持久化模型是为了再次加载模型,使用模型。DL4J也为我们实现了模型的的加载功能,具体代码如下:

NativeImageLoader loader = new NativeImageLoader(28, 28, 1);
INDArray image = loader.asMatrix(new File("XXX://test.jpg"));//从本地磁盘中加载文件
DataNormalization scaler = new ImagePreProcessingScaler(0,1);
scaler.transform(image);
INDArray output = model.output(image);//对图片进行分类预测

结果展示

这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述

这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述

总结 与展望

通过这个小项目,我参照官网手册,初步实现了LENET网络。并取得了不错的成果。当然,我也是学习了充足的理论之后,再来学习DL4J这个深度学习框架的。关于这个项目的源码,你可以去我的GutHub上下载:https://github.com/ShengPengYu/writtingRecoginition。
。该项目还有不足之处,比方说可以边测试边学习,我们在发现我们书写的测试数据分类不准确的时候,可以加入到训练数据库,在线对模型实时训练,因为每个用户的书写风格不一样,可能对分类结果有一定的影响。边测试边训练,可以训练出符合用户个人情况的模型。还有一种情况是。当然我也有一定的思考,比方说,如果我对目前模型进一步改进,做一个汉字识别项目,那么最后一层该使用什么架构,中国汉字那么多,如果使用one-hot模式,会不会维度太大,在时间复杂度和空间复杂度上是一个非常严峻的问题。


http://chatgpt.dhexx.cn/article/mwZ3VPUM.shtml

相关文章

识别数字的软件有哪些?这几款识别数字工具安利给你

嘿&#xff0c;朋友们&#xff0c;你们有没有遇到过需要处理大量数字的情况&#xff0c;要是一个一个手动输入感觉十分麻烦&#xff0c;还会耗费大量时间和精力&#xff1f;别着急&#xff0c;现在数字识别的软件已经非常发达了&#xff0c;只需要一款好用的数字识别软件&#…

这款数字识别软件你知道吗

识别数字技术是指通过计算机自动识别数字的能力&#xff0c;通常采用数字图像处理和模式识别等技术进行实现。你别看这个技术好像很高大上&#xff0c;实际上现在已经有很多软件可以做到识别数字了&#xff0c;你知道识别数字的软件有哪些吗&#xff1f;今天我就为大家科普这项…

基于TensorFlow深度学习框架,运用python搭建LeNet-5卷积神经网络模型和mnist手写数字识别数据集,设计一个手写数字识别软件。

本软件是基于TensorFlow深度学习框架&#xff0c;运用LeNet-5卷积神经网络模型和mnist手写数字识别数据集所设计的手写数字识别软件。 具体实现如下&#xff1a; 1.读入数据&#xff1a;运用TensorFlow深度学习框架&#xff0c;下载并读入mnist手写数字识别数据集。 2.构建模型…

OCR手写数字识别什么软件好用?介绍一种

OCR是指用电子设备检查文本上的资料&#xff0c;然后对图像文件进行分析处理&#xff0c;从而获取文字及版面信息的过程。那OCR手写数字识别有好用的软件吗&#xff1f;当我们需要整理大量手写资料需要整理时&#xff0c;下面这两款软件就派上用场了。 软件一、我们可以使用识别…

识别数字的软件有哪些?自动识别数字的方法并不难

每个月月初时&#xff0c;作为销售助理的同事经常要整理一大堆数据&#xff0c;密密麻麻的数字看得他头晕眼花&#xff0c;特别是有些图片里的数字&#xff0c;一不小心就容易出错&#xff0c;酿成严重的数据错误。像平时我也会处理到一些数据图片&#xff0c;为了准确及时的整…

Unity 渲染YUV数据 ---- 以Unity渲染Android Camera数据为例子

1 背景 一般Unity都是RGB直接渲染的&#xff0c;但是总有特殊情况下&#xff0c;需要渲染YUV数据。比如&#xff0c;Unity读取Android的Camera YUV数据&#xff0c;并渲染。本文就基于这种情况&#xff0c;来展开讨论。 Unity读取Android的byte数组&#xff0c;本身就耗时&am…

图形学之Unity渲染管线流程分析

文章来源&#xff1a; 学习通http://www.bdgxy.com/ 普学网http://www.boxinghulanban.cn/ 智学网http://www.jaxp.net/ 表格制作excel教程http://www.tpyjn.cn/ 学习通http://www.tsgmyy.cn/ 下图是《Unity Shader 入门精要》一书中的渲染流程图&#xff1b; ApplicationS…

Unity渲染(二):Shader着色器基础入门之渲染Image图片

Unity渲染(二):图片渲染 通过这里&#xff0c;你会学习到怎么将一张图片渲染到UI的Image组件或者SpriteRenderer上&#xff0c;以及透明物体的渲染。 上一章:Unity渲染(一):着色器基础入门之纯色Shader 开发环境&#xff1a;Unity5.0或者更高 透明与不透明的最终效果 概述 1…

unity 渲染性能分析工具

目标 既然要优化&#xff0c;肯定要有个目标&#xff1a; pc上一般要求&#xff1a;一秒渲染60帧 移动端&#xff1a;一秒渲染30帧 这应该是最低的要求&#xff0c;如果游戏运行时&#xff0c;游戏帧率有变化&#xff0c;人眼能够明显的感觉到帧率下降。 优化的首要规则是找到…

unity 渲染环境设置

环境光分为两种&#xff0c;一种是环境光漫反射SH&#xff08;球谐光照&#xff09;&#xff0c;另一种是环境光的镜面反射IBL&#xff08;基于图像的渲染&#xff09;。 光照的配置位置可以在 窗口 -> 渲染 -> 光照 打开。 环境照明对应的就是环境漫反射&#xff0c;环…

【流程向】模型复原与Unity渲染

项目简述 简单记录下学校里的一个项目&#xff0c;涉及到对/何家村遗宝/的模型复原&#xff0c;记录一下模型制作的全流程&#xff0c;同时涉及到Unity中一些优化画面的技术点。项目中渲染效果优先&#xff0c;没有怎么考虑性能。 流程&#xff1a;Blender高低模与展UV ->…

Unity中的物体渲染顺序

big seven 文章目录 前言 一、摄像机渲染 二、划分渲染队列 三、不透明物体的渲染 四、透明物体的渲染 五、UGUI元素的渲染 总结 前言 Unity中物体的渲染顺序 提示&#xff1a;以下是本篇文章正文内容&#xff0c;下面案例可供参考 一、摄像机渲染 Unity中的渲染顺序首先是…

Unity渲染流程概述

本篇的任务是回答&#xff1a;在Untiy的渲染流程中CPU和GPU分别做了什么。 渲染到设备屏幕显示的每一帧的画面&#xff0c;都经历几个阶段的加工过程&#xff1a; 应用程序阶段&#xff08;CPU&#xff09;&#xff1a;识别出潜在可视的网格实例&#xff0c;并把他们及其材质…

Unity_渲染_灯光_前向渲染

前向渲染路径 前向渲染的作用和意义场景内有多个灯光,如何渲染每个灯光对物体的影响 前向渲染的作用和意义 前向渲染的作用:处理多光源的渲染,多光源渲染在unity 有2中渲染方式 前向渲染和延时渲染 .延时渲染主要用于主机,PC平台,不在本次讨论范围.主要来研究前向渲染前向渲染…

【Unity渲染】前向渲染和延迟渲染的区别及切换

前向渲染和延迟渲染通道的区别&#xff0c;主要在对于光源的处理上。 Unity默认是前向渲染通道&#xff0c;如果光源特别多&#xff0c;可以使用延迟渲染。 前向渲染 使用前向渲染路径时&#xff0c;被照亮的对象将在单独的通道中进行渲染。根据场景中的光源数量以及它们是否…

从FrameDebugger看Unity渲染

从FrameDebugger看Unity渲染(一) Unity如何渲染一个3D2D的游戏画面&#xff0c;今天通过FrameDebugger来看下Unity内置渲染管线的渲染策略, 后续再出一些URP渲染管线相关的文章。 对啦&#xff01;这里有个游戏开发交流小组里面聚集了一帮热爱学习游戏的零基础小白&#xff0c…

UnityShader入门精要——Unity中的渲染优化技术(二)

减少DrawCall数目 最常见的优化技术——批处理。实现原理为减少渲染每一帧所需的drawcall数目。使用同一个材质的物体可以一起处理。 优点缺点动态批处理切处理都是Unity 自动完成的&#xff0c;不需要我们自己做任何操作&#xff0c;而且物体是可以移动的限制很多&#xff0c…

Unity渲染顺序(2)

Camera 除了Screen Space - Overlay(屏幕空间覆盖模式)下的Canvas,场景中的其他物体需要渲染到屏幕中&#xff0c;都需要在指定的相机的绘制下。场景中可以创建多个相机&#xff0c;每个相机所拍摄的内容可能并不相同&#xff0c;在场景中有多相机的情况&#xff0c;不同的相机…

Unity渲染顺序(1)

添加排序层级 在Unity编辑器的右上角选择Layers 按钮&#xff0c;在下拉菜单中点击Edit Layers…选项&#xff0c;将显示当前Unity的Tags, Sorting Layers&#xff0c;和Layers 编辑选项。 Sorting Layers是Unity中对排序的层级的定义块&#xff0c;在面板中越靠后的排序层级越…

Unity的渲染流程

Unity中坐标空间的转换&#xff1a; Unity的渲染流程&#xff1a; 渲染到设备屏幕的每一帧画面都要经历如下几个阶段&#xff1a; 应用程序阶段&#xff08;CPU&#xff09;&#xff1a;将材质和模型数据发送给GPU 几何阶段&#xff08;GPU&#xff09;&#xff1a;进行顶点…