详解 Pytorch 实现 MNIST

article/2025/9/22 19:53:43

MNIST虽然很简单,但是值得我们学习的东西还是有很多的。

项目虽然简单,但是个人建议还是将各个模块分开创建,特别是对于新人而言,模块化的创建会让读者更加清晰、易懂。

  • CNN模块:卷积神经网络的组成;
  • train模块:利用CNN模型 对 MNIST数据集 进行训练并保存模型
  • test模块:加载训练好的模型对测试集数据进行测试
  • cnn.pt : train 的CNN模型

注意!
有GPU的小伙伴尽量使用GPU训练,GPU的训练速度比CPU的训练速度高许多倍,可以节约大量训练时间

在这里插入图片描述

文章目录

    • 1、CNN 模块
      • CNN 模块分析
    • 2、train 模块
    • 3、test 模块
    • 4、MNIST 数据集


1、CNN 模块

MNIST的识别算法有很多,在此提供的是 卷积神经网络CNN ,其他算法也同样可以取得很好的识别效果,有兴趣的小伙伴可以自己尝试下。

在此就不得不提 Pytorch的优势了,都知道 Pytorch 是动态计算模型。但是何为动态计算模型呢?

  • 在此对比 Tensorflow。在流行的神经网络架构中, Tensorflow 就是最典型的静态计算架构。使用 Tensorflow 就必须先搭建好这样一个计算系统, 一旦搭建好了, 就不能改动了 (也有例外), 所有的计算都会在这种图中流动, 当然很多情况下这样就够了, 我们不需要改动什么结构。
  • 不动结构当然可以提高效率. 但是一旦计算流程不是静态的, 计算图要变动. 最典型的例子就是 RNN, 有时候 RNN 的 time step 不会一样, 或者在 training 和 testing 的时候, batch_size 和 time_step 也不一样, 这时, Tensorflow 就头疼了。
  • 如果用一个动态计算图的 Pytorch, 我们就好理解多了, 写起来也简单多了. PyTorch 支持在运行过程中根据运行参数动态改变应用模型。可以简单理解为:一种是先定义后使用,另一种是边使用边定义。动态计算图模式是 PyTorch 的天然优势之一,Google 2019年 3 月份发布的 TensorFlow 2.0 Alpha 版本中的 Eager Execution,被认为是在动态计算图模式上追赶 PyTorch 的举措。

如果暂时看不懂的小伙伴,可以先不管,先往后学习,等将来需要的时候再回头思考这段话。


CNN 模块分析

CNN 模块主要分为两个部分,一个是定义CNN模块,另一个是将各个模块组成前向传播通道

  • super() 函数: 是用于调用父类(超类)的一个方法。
    用来解决多重继承问题的,直接用类名调用父类方法在使用单继承的时候没问题,但是如果使用多继承,会涉及到查找顺序(MRO)、重复调用(钻石继承)等种种问题。
    super(SimpleCNN, self) 首先找到 SimpleCNN 的父类(就是类 nn.Module ),然后把类 SimpleCNN 的对象转换为类 nn.Module 的对象

  • nn.Sequential(): 是一个有顺序的容器,将神经网络模块 按照传入构造器的顺序依次被添加到计算图中执行。由于每一个神经网络模块都继承于nn.Module,通过索引的方式利用add_module函数将 nn.Sequential()模块 添加到现有模块中。

  • forward(): 是前向传播函数,将之前定义好的每层神经网络模块串联起来,同时也定义了模型的输入参数

  • x.view() & x.reshape(): 其实两者的作用并没有太大区别,作用都是调整张量的类型大小,view() 出现的更早些,而 reshape() 则是为了与 Numpy对齐,在 Pytorch 0.3版本之后添加的,两者作用没有太大区别;

#  !/usr/bin/env  python
#  -*- coding:utf-8 -*-
# @Time   :  2020.
# @Author :  绿色羽毛
# @Email  :  lvseyumao@foxmail.com
# @Blog   :  https://blog.csdn.net/ViatorSun
# @Note   :  from torch import nnclass SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.layer1 = nn.Sequential( nn.Conv2d(1,16,kernel_size=3) ,nn.BatchNorm2d(16) ,nn.ReLU(inplace=True))self.layer2 = nn.Sequential( nn.Conv2d(16,32,kernel_size=3) ,nn.BatchNorm2d(32) ,nn.ReLU(inplace=True) ,nn.MaxPool2d(kernel_size=2 , stride=2))self.layer3 = nn.Sequential( nn.Conv2d(32,64,kernel_size=3) ,nn.BatchNorm2d(64) ,nn.ReLU(inplace=True))self.layer4 = nn.Sequential( nn.Conv2d(64,128,kernel_size=3) ,nn.BatchNorm2d(128) ,nn.ReLU(inplace=True) ,nn.MaxPool2d(kernel_size=2 , stride=2))self.fc = nn.Sequential(nn.Linear(128*4*4,1024) ,nn.ReLU(inplace=True) ,nn.Linear(1024,128) ,nn.ReLU(inplace=True) ,nn.Linear(128,10) )def forward( self , x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)# x = x.view(x.size(0) , -1)x = x.reshape(x.size(0) , -1)fc_out = self.fc(x)return fc_out

2、train 模块

#  !/usr/bin/env  python
#  -*- coding:utf-8 -*-
# @Time   :  2020.
# @Author :  绿色羽毛
# @Email  :  lvseyumao@foxmail.com
# @Blog   :  https://blog.csdn.net/ViatorSun
# @Note   :  import torch
import CNN
from torch import nn , optim
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader# 定义超参数
learning_rate = 1e-2      # 学习率
batch_size    = 128       # 批的大小
epoches_num   = 20        # 遍历训练集的次数# 下载训练集 MNIST 手写数字训练集
train_dataset = datasets.MNIST( root='./data', train=True, transform=transforms.ToTensor(), download=True )
train_loader  = DataLoader( train_dataset, batch_size=batch_size, shuffle=True )# 定义model 、loss 、optimizer
model = CNN.SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD( model.parameters(), lr=learning_rate )if torch.cuda.is_available():print("CUDA is enable!")model = model.cuda()model.train()# 开始训练
for epoch in range(epoches_num):print('*' * 40)train_loss = 0.0train_acc  = 0.0# 训练for i, data in enumerate(train_loader, 1 ):img, label = data# 拥有GPU的小伙伴还是推荐使用GPU训练if torch.cuda.is_available():img   = Variable(img).cuda()label = Variable(label).cuda()else:img   = Variable(img)label = Variable(label)# 前向传播optimizer.zero_grad()out  = model(img)loss = criterion(out, label)# 反向传播loss.backward()optimizer.step()# 损失/准确率计算train_loss += loss.item() * label.size(0)_ , pred    = out.max(1)num_correct = pred.eq(label).sum()accuracy    = pred.eq(label).float().mean()train_acc  += num_correct.item()print('Finish  {}  Loss: {:.6f}, Acc: {:.6f}'.format( epoch+1 , train_loss / len(train_dataset), train_acc / len(train_dataset )))# 保存模型
torch.save(model, 'cnn.pt')

3、test 模块

在模型的使用过程中,有些子模块(如:丢弃层、批次归一化层等)有两种状态,即训练状态和预测状态,在不同时候 Pytorch模型 需要在两种状态中相互转换。

  • model.tran() 方法会将模型(包含所有子模块)中的参数转换成训练状态
  • model.eval() 方法会将模型(包含所有子模块)中的参数转换成预测状态

Pytorch 的模型在不同状态下的预测准确性会有差异,在训练模型的时候需要转换为训练状态,在预测的时候需要转化为预测状态,否则最后模型预测准确性可能会降低,甚至会得到错误的结果。

#  !/usr/bin/env  python
#  -*- coding:utf-8 -*-
# @Time   :  2020.
# @Author :  绿色羽毛
# @Email  :  lvseyumao@foxmail.com
# @Blog   :  https://blog.csdn.net/ViatorSun
# @Note   :  import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader# 定义超参数
batch_size  = 128       # 批的大小# 下载训练集 MNIST 手写数字测试集
test_dataset  = datasets.MNIST( root='./data', train=False, transform=transforms.ToTensor())
test_loader   = DataLoader(test_dataset , batch_size=batch_size, shuffle=False)# 加载 Train 模型
model = torch.load('cnn.pt')
criterion = nn.CrossEntropyLoss()
model.eval()
eval_acc  = 0
eval_loss = 0# 测试
for data in test_loader:img, label = dataif torch.cuda.is_available():img   = Variable(img  ).cuda()label = Variable(label).cuda()else:img   = Variable(img  )label = Variable(label)out  = model(img)loss = criterion(out, label)eval_loss += loss.item() * label.size(0)_ , pred = torch.max(out,1)num_correct = (pred==label).sum()eval_acc += num_correct.item()print('Test Loss: {:.6f}   ,   Acc: {:.6f}'.format( eval_loss/(len(test_dataset)), eval_acc/(len(test_dataset)) ))

4、MNIST 数据集

如果还没有MNIST数据集,可以通过以下方式从 torchvision 下载,下载路径为项目路径下的 ‘./data’ 文件夹下,可以看到 MNIST 的数据是 ubyte

from torchvision import datasetstrain_dataset = datasets.MNIST( root='./data', train=True, transform=transforms.ToTensor(), download=True )

在这里插入图片描述
通过上述 datasets.MNIST 命令将 MNIST 数据读取到内存中,并转换为 Tensor 格式保存在 train_dataset 变量中,通过Debug 我们可以看到 MNIST 的数据 是 【10000,28,28】的数据,每个 【28,28】的数据对应的标签是 targets

在这里插入图片描述

然后我们可视化其中一个 【28,28】数据可以看出,其就是一个 28x28 的单通道灰度图,每个值表示一个像素点,其值范围为 【0-255】,像素值并不能直接传入模型,需要通过 transforms.ToTensor() 将其转化为 Tensor格式。

在这里插入图片描述


http://chatgpt.dhexx.cn/article/jvyTRGlc.shtml

相关文章

十分钟搞懂Pytorch如何读取MNIST数据集

前言 本文用于记录使用pytorch读取minist数据集的过程,以及一些思考和疑惑吧… 正文 在阅读教程书籍《深度学习入门之Pytorch》时,文中是如此加载MNIST手写数字训练集的: train_dataset datasets.MNIST(root./MNIST,trainTrue,transform…

torchvision中datasets.MNIST介绍

用法介绍 torchvision中datasets中所有封装的数据集都是torch.utils.data.Dataset的子类,它们都实现了__getitem__和__len__方法。因此,它们都可以用torch.utils.data.DataLoader进行数据加载。以datasets.MNIST类为例,具体参数和用法如下所示…

万物皆用MNIST---MNIST数据集及创建自己的手写数字数据集

刚刚接触到人工智能的我们,必定会遇到一个非常非常非常熟悉的朋友------MNIST 这是一套流行的手写数字图片,常常被用来测试我们的思想和算法。这个数据集称为手写数字的MNIST数据库,从研究员Yann LeCun 的网站,可以得到这个…

Pytorch 之 MNIST 数据集实现

目录 1. 数据集介绍2. 代码2. 读代码(个人喜欢的顺序)2.1. 导入模块部分:2.2. Main 函数: 1. 数据集介绍 一般而言,MNIST 数据集测试就是机器学习和深度学习当中的"Hello World"工程。几乎是所…

MNIST数据集手写数字识别(CNN)

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

MNIST数据集详解及可视化处理(pytorch)

MNIST数据集详解及可视化处理(pytorch) MNIST 数据集已经是一个被”嚼烂”了的数据集, 作为机器学习在视觉领域的“hello world”,很多教程都会对它”下手”, 几乎成为一个 “典范”。 不过有些人可能对它还不是很了解, 下面来介绍一下。 MN…

Mnist数据集介绍

Mnist数据集已经是一个被"嚼烂"了的数据集了,很多关于神经网络的教程都会对它下手。因此在开始深度学习之前,先对这个数据集介绍一下。 Mnist数据集图片格式介绍 Mnist数据集分为两部分,分别含有60000张训练图片和10000张测试图片…

使用MNIST数据集

首先,必须向各位强调的是:该数据集名字叫MNIST,而非MINIST~ 我之前就一直弄错了! 哈哈~ 网上有很多使用MNIST数据集的教程,要么太麻烦,要么需要翻墙下载,很慢。 在这里分…

Fashion MNIST进行分类

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

MNIST数据集简介与使用

MNIST数据集简介 MNIST数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。训练集(training set)由来自250个不同人手写的数字构成,其中50%是高中学生,50%来自人口普查局(t…

详解 MNIST 数据集

MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范”. 不过有些人可能对它还不是很了解, 下面来介绍一下. MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分: Training set images: train-images-idx3-…

Mnist数据集简介

MNIST是一个手写体数字的图片数据集,该数据集来由美国国家标准与技术研究所(National Institute of Standards and Technology (NIST))发起整理,一共统计了来自250个不同的人手写数字图片,其中50%是高中生,…

[转]MNIST机器学习入门

MNIST机器学习入门 转自:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html?plg_nld1&plg_uin1&plg_auth1&plg_nld1&plg_usr1&plg_vkey1&plg_dev1 这个教程的目标读者是对机器学习和TensorFlow都不太了解的新手。如…

从手写数字识别入门深度学习丨MNIST数据集详解

就像无数人从敲下“Hello World”开始代码之旅一样,许多研究员从“MNIST数据集”开启了人工智能的探索之路。 MNIST数据集(Mixed National Institute of Standards and Technology database)是一个用来训练各种图像处理系统的二进制图像数据…

Pytorch入门--详解Mnist手写字识别

1 什么是Mnist? Mnist是计算机视觉领域中最为基础的一个数据集。 MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10…

MNIST数据集

一、MNIST数据集介绍 MNIST数据集是NIST(National Institute of Standards and Technology,美国国家标准与技术研究所)数据集的一个子集,MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取,主要包括四个文件&…

面试官: 你知道 JWT、JWE、JWS 、JWK嘛?

想起了 之前做过的 很多 登录授权 的项目 它相比原先的session、cookie来说,更快更安全,跨域也不再是问题,更关键的是更加优雅 ,所以今天总结了一篇文章来介绍他 JWT 指JSON Web Token,如果在项目中通过 jjwt 来支持 J…

java jwe/jws_一篇文章带你分清楚JWT,JWS与JWE

随着移动互联网的兴起,传统基于session/cookie的web网站认证方式转变为了基于OAuth2等开放授权协议的单点登录模式(SSO),相应的基于服务器session浏览器cookie的Auth手段也发生了转变,Json Web Token出现成为了当前的热门的Token Auth机制。 …

JWS实现WebService

WebService估计大家都有听过或者使用过。Java有几种常用的方式实现webservice,本文主要是讨论JWS实现。 什么是webservice 简单而言,webservice就是通过SOAP协议在Web上提供的服务,使用WSDL文件进行说明。其特点是走SOAP协议而不是http协议&…

WebService 理论详解、JWS(Java Web Service) 快速入门

目录 WebService (web服务)概述 WebService 平台技术 WebService 工作原理 WebService 开发流程 常见 Web Service 框架 JWS(Java Web Service) 概述 JWS(Java Web Service) 快速入门 WebService (web服务)概述 1、WebService(Web服务)是一种跨语…