MNIST数据集手写数字识别(CNN)

article/2025/9/22 20:02:58

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎

📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃

🎁欢迎各位→点赞👍 + 收藏⭐️ + 留言📝​

📣系列专栏 - 机器学习【ML】 自然语言处理【NLP】  深度学习【DL】

 🖍foreword

✔说明⇢本人讲解主要包括Python、机器学习(ML)、深度学习(DL)、自然语言处理(NLP)等内容。

如果你对这个系列感兴趣的话,可以关注订阅哟👋

1 数据集介绍

MNIST 包括6万张28x28的训练样本,1万张测试样本,很多教程都会对它”下手”几乎成为一个 “典范”,可以说它就是计算机视觉里面的Hello World。所以我们这里也会使用MNIST来进行实战。

前面在介绍卷积神经网络的时候说到过LeNet-5,LeNet-5之所以强大就是因为在当时的环境下将MNIST数据的识别率提高到了99%,这里我们也自己从头搭建一个卷积神经网络,也达到99%的准确率

2 手写数字识别

首先,我们定义一些超参数

BATCH_SIZE=512 #大概需要2G的显存
EPOCHS=20 # 总共训练批次
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 让torch判断是否使用GPU,建议使用GPU环境,因为会快很多

因为Pytorch里面包含了MNIST的数据集,所以我们这里直接使用即可。 如果第一次执行会生成data文件夹,并且需要一些时间下载,如果以前下载过就不会再次下载了

由于官方已经实现了dataset,所以这里可以直接使用DataLoader来对数据进行读取

train_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=BATCH_SIZE, shuffle=True)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!

测试集

test_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=BATCH_SIZE, shuffle=True)
 

下面我们定义一个网络,网络包含两个卷积层,conv1和conv2,然后紧接着两个线性层作为输出,最后输出10个维度,这10个维度我们作为0-9的标识来确定识别出的是那个数字

在这里建议大家将每一层的输入和输出维度都作为注释标注出来,这样后面阅读代码的会方便很多

class ConvNet(nn.Module):def __init__(self):super().__init__()# batch*1*28*28(每次会送入batch个样本,输入通道数1(黑白图像),图像分辨率是28x28)# 下面的卷积层Conv2d的第一个参数指输入通道数,第二个参数指输出通道数,第三个参数指卷积核的大小self.conv1 = nn.Conv2d(1, 10, 5) # 输入通道数1,输出通道数10,核的大小5self.conv2 = nn.Conv2d(10, 20, 3) # 输入通道数10,输出通道数20,核的大小3# 下面的全连接层Linear的第一个参数指输入通道数,第二个参数指输出通道数self.fc1 = nn.Linear(20*10*10, 500) # 输入通道数是2000,输出通道数是500self.fc2 = nn.Linear(500, 10) # 输入通道数是500,输出通道数是10,即10分类def forward(self,x):in_size = x.size(0) # 在本例中in_size=512,也就是BATCH_SIZE的值。输入的x可以看成是512*1*28*28的张量。out = self.conv1(x) # batch*1*28*28 -> batch*10*24*24(28x28的图像经过一次核为5x5的卷积,输出变为24x24)out = F.relu(out) # batch*10*24*24(激活函数ReLU不改变形状))out = F.max_pool2d(out, 2, 2) # batch*10*24*24 -> batch*10*12*12(2*2的池化层会减半)out = self.conv2(out) # batch*10*12*12 -> batch*20*10*10(再卷积一次,核的大小是3)out = F.relu(out) # batch*20*10*10out = out.view(in_size, -1) # batch*20*10*10 -> batch*2000(out的第二维是-1,说明是自动推算,本例中第二维是20*10*10)out = self.fc1(out) # batch*2000 -> batch*500out = F.relu(out) # batch*500out = self.fc2(out) # batch*500 -> batch*10out = F.log_softmax(out, dim=1) # 计算log(softmax(x))return out

我们实例化一个网络,实例化后使用.to方法将网络移动到GPU

优化器我们也直接选择简单暴力的Adam

model = ConvNet().to(DEVICE)
optimizer = optim.Adam(model.parameters())

下面定义一下训练的函数,我们将训练的所有操作都封装到这个函数中

def train(model, device, train_loader, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if(batch_idx+1)%30 == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))

测试的操作也一样封装成一个函数

def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += F.nll_loss(output, target, reduction='sum').item() # 将一批的损失相加pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))
 

下面开始训练,这里就体现出封装起来的好处了,只要写两行就可以了

for epoch in range(1, EPOCHS + 1):train(model, DEVICE, train_loader, optimizer, epoch)test(model, DEVICE, test_loader)
 
Train Epoch: 1 [14848/60000 (25%)]	Loss: 0.272529
Train Epoch: 1 [30208/60000 (50%)]	Loss: 0.235455
Train Epoch: 1 [45568/60000 (75%)]	Loss: 0.101858Test set: Average loss: 0.1018, Accuracy: 9695/10000 (97%)Train Epoch: 2 [14848/60000 (25%)]	Loss: 0.057989
Train Epoch: 2 [30208/60000 (50%)]	Loss: 0.083935
Train Epoch: 2 [45568/60000 (75%)]	Loss: 0.051921Test set: Average loss: 0.0523, Accuracy: 9825/10000 (98%)Train Epoch: 3 [14848/60000 (25%)]	Loss: 0.045383
Train Epoch: 3 [30208/60000 (50%)]	Loss: 0.049402
Train Epoch: 3 [45568/60000 (75%)]	Loss: 0.061366Test set: Average loss: 0.0408, Accuracy: 9866/10000 (99%)Train Epoch: 4 [14848/60000 (25%)]	Loss: 0.035253
Train Epoch: 4 [30208/60000 (50%)]	Loss: 0.038444
Train Epoch: 4 [45568/60000 (75%)]	Loss: 0.036877Test set: Average loss: 0.0433, Accuracy: 9859/10000 (99%)Train Epoch: 5 [14848/60000 (25%)]	Loss: 0.038996
Train Epoch: 5 [30208/60000 (50%)]	Loss: 0.020670
Train Epoch: 5 [45568/60000 (75%)]	Loss: 0.034658Test set: Average loss: 0.0339, Accuracy: 9885/10000 (99%)Train Epoch: 6 [14848/60000 (25%)]	Loss: 0.067320
Train Epoch: 6 [30208/60000 (50%)]	Loss: 0.016328
Train Epoch: 6 [45568/60000 (75%)]	Loss: 0.017037Test set: Average loss: 0.0348, Accuracy: 9881/10000 (99%)Train Epoch: 7 [14848/60000 (25%)]	Loss: 0.022150
Train Epoch: 7 [30208/60000 (50%)]	Loss: 0.009608
Train Epoch: 7 [45568/60000 (75%)]	Loss: 0.012742Test set: Average loss: 0.0346, Accuracy: 9895/10000 (99%)Train Epoch: 8 [14848/60000 (25%)]	Loss: 0.010173
Train Epoch: 8 [30208/60000 (50%)]	Loss: 0.019482
Train Epoch: 8 [45568/60000 (75%)]	Loss: 0.012159Test set: Average loss: 0.0323, Accuracy: 9886/10000 (99%)Train Epoch: 9 [14848/60000 (25%)]	Loss: 0.007792
Train Epoch: 9 [30208/60000 (50%)]	Loss: 0.006970
Train Epoch: 9 [45568/60000 (75%)]	Loss: 0.004989Test set: Average loss: 0.0294, Accuracy: 9909/10000 (99%)Train Epoch: 10 [14848/60000 (25%)]	Loss: 0.003764
Train Epoch: 10 [30208/60000 (50%)]	Loss: 0.005944
Train Epoch: 10 [45568/60000 (75%)]	Loss: 0.001866Test set: Average loss: 0.0361, Accuracy: 9902/10000 (99%)Train Epoch: 11 [14848/60000 (25%)]	Loss: 0.002737
Train Epoch: 11 [30208/60000 (50%)]	Loss: 0.014134
Train Epoch: 11 [45568/60000 (75%)]	Loss: 0.001365Test set: Average loss: 0.0309, Accuracy: 9905/10000 (99%)Train Epoch: 12 [14848/60000 (25%)]	Loss: 0.003344
Train Epoch: 12 [30208/60000 (50%)]	Loss: 0.003090
Train Epoch: 12 [45568/60000 (75%)]	Loss: 0.004847Test set: Average loss: 0.0318, Accuracy: 9902/10000 (99%)Train Epoch: 13 [14848/60000 (25%)]	Loss: 0.001278
Train Epoch: 13 [30208/60000 (50%)]	Loss: 0.003016
Train Epoch: 13 [45568/60000 (75%)]	Loss: 0.001328Test set: Average loss: 0.0358, Accuracy: 9906/10000 (99%)Train Epoch: 14 [14848/60000 (25%)]	Loss: 0.002219
Train Epoch: 14 [30208/60000 (50%)]	Loss: 0.003487
Train Epoch: 14 [45568/60000 (75%)]	Loss: 0.014429Test set: Average loss: 0.0376, Accuracy: 9896/10000 (99%)Train Epoch: 15 [14848/60000 (25%)]	Loss: 0.003042
Train Epoch: 15 [30208/60000 (50%)]	Loss: 0.002974
Train Epoch: 15 [45568/60000 (75%)]	Loss: 0.000871Test set: Average loss: 0.0346, Accuracy: 9909/10000 (99%)Train Epoch: 16 [14848/60000 (25%)]	Loss: 0.000618
Train Epoch: 16 [30208/60000 (50%)]	Loss: 0.003164
Train Epoch: 16 [45568/60000 (75%)]	Loss: 0.007245Test set: Average loss: 0.0357, Accuracy: 9905/10000 (99%)Train Epoch: 17 [14848/60000 (25%)]	Loss: 0.001874
Train Epoch: 17 [30208/60000 (50%)]	Loss: 0.013951
Train Epoch: 17 [45568/60000 (75%)]	Loss: 0.000729Test set: Average loss: 0.0322, Accuracy: 9922/10000 (99%)Train Epoch: 18 [14848/60000 (25%)]	Loss: 0.002581
Train Epoch: 18 [30208/60000 (50%)]	Loss: 0.001396
Train Epoch: 18 [45568/60000 (75%)]	Loss: 0.015521Test set: Average loss: 0.0389, Accuracy: 9914/10000 (99%)Train Epoch: 19 [14848/60000 (25%)]	Loss: 0.000283
Train Epoch: 19 [30208/60000 (50%)]	Loss: 0.001385
Train Epoch: 19 [45568/60000 (75%)]	Loss: 0.011184Test set: Average loss: 0.0383, Accuracy: 9901/10000 (99%)Train Epoch: 20 [14848/60000 (25%)]	Loss: 0.000472
Train Epoch: 20 [30208/60000 (50%)]	Loss: 0.003306
Train Epoch: 20 [45568/60000 (75%)]	Loss: 0.018017Test set: Average loss: 0.0393, Accuracy: 9899/10000 (99%)

我们看一下结果,准确率99%,没问题

如果你的模型连MNIST都搞不定,那么你的模型没有任何的价值

即使你的模型搞定了MNIST,你的模型也可能没有任何的价值

MNIST是一个很简单的数据集,由于它的局限性只能作为研究用途,对实际应用带来的价值非常有限。但是通过这个例子,我们可以完全了解一个实际项目的工作流程

我们找到数据集,对数据做预处理,定义我们的模型,调整超参数,测试训练,再通过训练结果对超参数进行调整或者对模型进行调整。

并且通过这个实战我们已经有了一个很好的模板,以后的项目都可以以这个模板为样例


http://chatgpt.dhexx.cn/article/ajjAuhm5.shtml

相关文章

MNIST数据集详解及可视化处理(pytorch)

MNIST数据集详解及可视化处理(pytorch) MNIST 数据集已经是一个被”嚼烂”了的数据集, 作为机器学习在视觉领域的“hello world”,很多教程都会对它”下手”, 几乎成为一个 “典范”。 不过有些人可能对它还不是很了解, 下面来介绍一下。 MN…

Mnist数据集介绍

Mnist数据集已经是一个被"嚼烂"了的数据集了,很多关于神经网络的教程都会对它下手。因此在开始深度学习之前,先对这个数据集介绍一下。 Mnist数据集图片格式介绍 Mnist数据集分为两部分,分别含有60000张训练图片和10000张测试图片…

使用MNIST数据集

首先,必须向各位强调的是:该数据集名字叫MNIST,而非MINIST~ 我之前就一直弄错了! 哈哈~ 网上有很多使用MNIST数据集的教程,要么太麻烦,要么需要翻墙下载,很慢。 在这里分…

Fashion MNIST进行分类

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

MNIST数据集简介与使用

MNIST数据集简介 MNIST数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。训练集(training set)由来自250个不同人手写的数字构成,其中50%是高中学生,50%来自人口普查局(t…

详解 MNIST 数据集

MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范”. 不过有些人可能对它还不是很了解, 下面来介绍一下. MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分: Training set images: train-images-idx3-…

Mnist数据集简介

MNIST是一个手写体数字的图片数据集,该数据集来由美国国家标准与技术研究所(National Institute of Standards and Technology (NIST))发起整理,一共统计了来自250个不同的人手写数字图片,其中50%是高中生,…

[转]MNIST机器学习入门

MNIST机器学习入门 转自:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html?plg_nld1&plg_uin1&plg_auth1&plg_nld1&plg_usr1&plg_vkey1&plg_dev1 这个教程的目标读者是对机器学习和TensorFlow都不太了解的新手。如…

从手写数字识别入门深度学习丨MNIST数据集详解

就像无数人从敲下“Hello World”开始代码之旅一样,许多研究员从“MNIST数据集”开启了人工智能的探索之路。 MNIST数据集(Mixed National Institute of Standards and Technology database)是一个用来训练各种图像处理系统的二进制图像数据…

Pytorch入门--详解Mnist手写字识别

1 什么是Mnist? Mnist是计算机视觉领域中最为基础的一个数据集。 MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10…

MNIST数据集

一、MNIST数据集介绍 MNIST数据集是NIST(National Institute of Standards and Technology,美国国家标准与技术研究所)数据集的一个子集,MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取,主要包括四个文件&…

面试官: 你知道 JWT、JWE、JWS 、JWK嘛?

想起了 之前做过的 很多 登录授权 的项目 它相比原先的session、cookie来说,更快更安全,跨域也不再是问题,更关键的是更加优雅 ,所以今天总结了一篇文章来介绍他 JWT 指JSON Web Token,如果在项目中通过 jjwt 来支持 J…

java jwe/jws_一篇文章带你分清楚JWT,JWS与JWE

随着移动互联网的兴起,传统基于session/cookie的web网站认证方式转变为了基于OAuth2等开放授权协议的单点登录模式(SSO),相应的基于服务器session浏览器cookie的Auth手段也发生了转变,Json Web Token出现成为了当前的热门的Token Auth机制。 …

JWS实现WebService

WebService估计大家都有听过或者使用过。Java有几种常用的方式实现webservice,本文主要是讨论JWS实现。 什么是webservice 简单而言,webservice就是通过SOAP协议在Web上提供的服务,使用WSDL文件进行说明。其特点是走SOAP协议而不是http协议&…

WebService 理论详解、JWS(Java Web Service) 快速入门

目录 WebService (web服务)概述 WebService 平台技术 WebService 工作原理 WebService 开发流程 常见 Web Service 框架 JWS(Java Web Service) 概述 JWS(Java Web Service) 快速入门 WebService (web服务)概述 1、WebService(Web服务)是一种跨语…

一文理解 JWT、JWS、JWE、JWA、JWK、JOSE

原文收录 GitBook——统一接口认证解决方案 JsonWebToken 关于JsonWebToken的专业名词解释: unsecured JWT:默认头部{“alg”: “none”}的jwt令牌JWS(SignedJWT):已签名的jwt,包含标准jwt结构:header、payload、signatureJWE…

JWS入门

JWS简介 JWS主要用来通过网络部署你的应用程序,它具有安全、稳定、易维护、易使用的特点。用户访问用JWS部署应用程序的站点,下载发布的应用程序,既可以在 线运行,也可以通过JWS的客户端离线运行已下载的应用程序。对同一个应用程…

【C语言】判断一个数是否是完全平方数(两种解法)

题目: 判断一个数是否是完全平方数。 以下数字为完全平方数:42*2,93*3,14412*12,16913*13 有两个方法,可以求完全平方数: 方法一:输入一个数,遍历所有比这个数小的数,只要有其中一个数满足条件…

C语言 输入10个数,将其中最小的数与第一个数对换,将最大的数与最后一个数对换

#include <stdio.h> void input(int *number){ //定义输入10个数的函数int i;printf("请输入10个整数:\n");for(i0;i<10;i)scanf("%d",&number[i]); } void max_min_value(int *number){ //交换函数int *max,*min,*p,temp;maxminnumber; //开…

C语言判断一个数是奇数还是偶数

#include <stdio.h> void main() { int n; scanf("%d",&n); //运用scanf函数可以输入想要的数字 //也可以采用int n&#xff08;取一个数&#xff09;进行运算 if(n%20)//if函数注意&#xff0c;%是取余 printf("%d是一个偶…