【MNIST】

article/2025/9/22 18:54:30

1. Normal Neural Network:

首先我用的是两层(input layer 和 output layer)的feed-forward的神经网络结构来训练数据, y = wx + b, 在输出层用的是softmax求概率,算loss用的是交叉熵的办法,选用梯度下降法来最小化loss更新参数w和b,通过对比output和input对应的label,计算出准确率,因为此网络结构较简单,只能达到91%的accuracy。

 

代码实现:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)sess = tf.InteractiveSession()# input x 和 label y_
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])# 初始化权值和偏置
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))# 初始化所有的变量
sess.run(tf.initialize_all_variables())# 求输出,用softmax:y = Wx + b
y = tf.nn.softmax(tf.matmul(x, W) + b)# 交叉熵作为cost function
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))# 用梯度下降法最小化交叉熵来更新参数
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)# 进行训练过程
for i in range(1000):batch = mnist.train.next_batch(100)optimizer.run(feed_dict={x: batch[0], y_: batch[1]})  # 引入x和y_的值# 判断y和y_是否equal
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))# 把判断y和y_的布尔值转化为数值
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))# 输出test数据集的accuracy
print(accuracy.eval(feed_dict={x: mnist.train.images, y_: mnist.train.labels}))
print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

2. CNN

用卷积神经网络结构进行了手写数字识别,主要用到的方法有:conv, max_pooling, ReLU, dropout, softmax. 计算loss用的是交叉熵,用Adam来最小化loss更新参数,准确率可以达到99.27%

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)sess = tf.InteractiveSession()x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])# 权重和偏差初始化
def weight_variable(shape):initial = tf.truncated_normal(shape, stddev=0.1)return tf.Variable(initial)def bias_variable(shape):initial = tf.constant(0.1, shape=shape)return tf.Variable(initial)# 建立卷积和池化的函数
def conv2d(x, W):return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')def max_pool_2x2(x):return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')# 第一层卷积
W_conv1 = weight_variable([5, 5, 1, 32])  # 32 个 5x5 的filter
b_conv1 = bias_variable([32])
x_image = tf.reshape(x, [-1, 28, 28, 1])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)# 第二层卷积
W_conv2 = weight_variable([5, 5, 32, 64])  # 64个5x5x32的filter
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)# fully connected layer
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)# dropout:keep_prob:概率;在训练过程中使用dropout,在测试过程中不用dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)# output layer:softmax
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)# 训练和评估模型
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.initialize_all_variables())
for i in range(20000):batch = mnist.train.next_batch(50)if i % 100 == 0:train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})print("step %d, training accuracy %g" % (i, train_accuracy))train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})print("test accuracy %g" % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

 


http://chatgpt.dhexx.cn/article/TlYunkoJ.shtml

相关文章

MNIST数据集使用详解

数据集下载网址:http://yann.lecun.com/exdb/mnist/ 下载后无需解压,将其放在一个文件夹下即可: 数据说明: 数据集常被分为2~3个部分 训练集(train set):用来学习的一组例子,用来适应分类器的参数[即权重]…

详解 Pytorch 实现 MNIST

MNIST虽然很简单,但是值得我们学习的东西还是有很多的。 项目虽然简单,但是个人建议还是将各个模块分开创建,特别是对于新人而言,模块化的创建会让读者更加清晰、易懂。 CNN模块:卷积神经网络的组成;trai…

十分钟搞懂Pytorch如何读取MNIST数据集

前言 本文用于记录使用pytorch读取minist数据集的过程,以及一些思考和疑惑吧… 正文 在阅读教程书籍《深度学习入门之Pytorch》时,文中是如此加载MNIST手写数字训练集的: train_dataset datasets.MNIST(root./MNIST,trainTrue,transform…

torchvision中datasets.MNIST介绍

用法介绍 torchvision中datasets中所有封装的数据集都是torch.utils.data.Dataset的子类,它们都实现了__getitem__和__len__方法。因此,它们都可以用torch.utils.data.DataLoader进行数据加载。以datasets.MNIST类为例,具体参数和用法如下所示…

万物皆用MNIST---MNIST数据集及创建自己的手写数字数据集

刚刚接触到人工智能的我们,必定会遇到一个非常非常非常熟悉的朋友------MNIST 这是一套流行的手写数字图片,常常被用来测试我们的思想和算法。这个数据集称为手写数字的MNIST数据库,从研究员Yann LeCun 的网站,可以得到这个…

Pytorch 之 MNIST 数据集实现

目录 1. 数据集介绍2. 代码2. 读代码(个人喜欢的顺序)2.1. 导入模块部分:2.2. Main 函数: 1. 数据集介绍 一般而言,MNIST 数据集测试就是机器学习和深度学习当中的"Hello World"工程。几乎是所…

MNIST数据集手写数字识别(CNN)

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

MNIST数据集详解及可视化处理(pytorch)

MNIST数据集详解及可视化处理(pytorch) MNIST 数据集已经是一个被”嚼烂”了的数据集, 作为机器学习在视觉领域的“hello world”,很多教程都会对它”下手”, 几乎成为一个 “典范”。 不过有些人可能对它还不是很了解, 下面来介绍一下。 MN…

Mnist数据集介绍

Mnist数据集已经是一个被"嚼烂"了的数据集了,很多关于神经网络的教程都会对它下手。因此在开始深度学习之前,先对这个数据集介绍一下。 Mnist数据集图片格式介绍 Mnist数据集分为两部分,分别含有60000张训练图片和10000张测试图片…

使用MNIST数据集

首先,必须向各位强调的是:该数据集名字叫MNIST,而非MINIST~ 我之前就一直弄错了! 哈哈~ 网上有很多使用MNIST数据集的教程,要么太麻烦,要么需要翻墙下载,很慢。 在这里分…

Fashion MNIST进行分类

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

MNIST数据集简介与使用

MNIST数据集简介 MNIST数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。训练集(training set)由来自250个不同人手写的数字构成,其中50%是高中学生,50%来自人口普查局(t…

详解 MNIST 数据集

MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范”. 不过有些人可能对它还不是很了解, 下面来介绍一下. MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分: Training set images: train-images-idx3-…

Mnist数据集简介

MNIST是一个手写体数字的图片数据集,该数据集来由美国国家标准与技术研究所(National Institute of Standards and Technology (NIST))发起整理,一共统计了来自250个不同的人手写数字图片,其中50%是高中生,…

[转]MNIST机器学习入门

MNIST机器学习入门 转自:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html?plg_nld1&plg_uin1&plg_auth1&plg_nld1&plg_usr1&plg_vkey1&plg_dev1 这个教程的目标读者是对机器学习和TensorFlow都不太了解的新手。如…

从手写数字识别入门深度学习丨MNIST数据集详解

就像无数人从敲下“Hello World”开始代码之旅一样,许多研究员从“MNIST数据集”开启了人工智能的探索之路。 MNIST数据集(Mixed National Institute of Standards and Technology database)是一个用来训练各种图像处理系统的二进制图像数据…

Pytorch入门--详解Mnist手写字识别

1 什么是Mnist? Mnist是计算机视觉领域中最为基础的一个数据集。 MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10…

MNIST数据集

一、MNIST数据集介绍 MNIST数据集是NIST(National Institute of Standards and Technology,美国国家标准与技术研究所)数据集的一个子集,MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取,主要包括四个文件&…

面试官: 你知道 JWT、JWE、JWS 、JWK嘛?

想起了 之前做过的 很多 登录授权 的项目 它相比原先的session、cookie来说,更快更安全,跨域也不再是问题,更关键的是更加优雅 ,所以今天总结了一篇文章来介绍他 JWT 指JSON Web Token,如果在项目中通过 jjwt 来支持 J…

java jwe/jws_一篇文章带你分清楚JWT,JWS与JWE

随着移动互联网的兴起,传统基于session/cookie的web网站认证方式转变为了基于OAuth2等开放授权协议的单点登录模式(SSO),相应的基于服务器session浏览器cookie的Auth手段也发生了转变,Json Web Token出现成为了当前的热门的Token Auth机制。 …