Pytorch 之 MNIST 数据集实现

article/2025/9/22 19:55:19

目录

  • 1. 数据集介绍
  • 2. 代码
  • 2. 读代码(个人喜欢的顺序)
    • 2.1. 导入模块部分:
    • 2.2. Main 函数:

1. 数据集介绍

一般而言,MNIST 数据集测试就是机器学习和深度学习当中的"Hello World"工程。几乎是所有的教程都会把它放在最开始的地方。这是因为,这个简单的工程包含了大致的机器学习流程,通过练习这个工程有助于读者加深理解机器学习或者是深度学习的大致流程。
MNIST(Mixed National Institute of Standards and Technology database)是一个计算机视觉数据集,它包含 70000 张手写数字的灰度图片,其中每一张图片包含 28 X 28 个像素点。可以用一个数字数组来表示这张图片。

2. 代码

代码来自 pytorch 实例代码,链接:https://github.com/pytorch/examples/blob/master/mnist/main.py

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transformsclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5, 1)self.conv2 = nn.Conv2d(20, 50, 5, 1)self.fc1 = nn.Linear(4 * 4 * 50, 500)self.fc2 = nn.Linear(500, 10)def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2)x = x.view(-1, 4 * 4 * 50)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)def train(args, model, device, train_loader, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if batch_idx % args.log_interval == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))def test(args, model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch losspred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probabilitycorrect += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))def main():parser = argparse.ArgumentParser(description='PyTorch MNIST Example')parser.add_argument('--batch-size', type=int, default=64, metavar='N',help='input batch size for training (default: 64)')parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',help='input batch size for testing (default: 1000)')parser.add_argument('--epochs', type=int, default=10, metavar='N',help='number of epochs to train (default: 10)')parser.add_argument('--lr', type=float, default=0.01, metavar='LR',help='learning rate (default: 0.01)')parser.add_argument('--momentum', type=float, default=0.5, metavar='M',help='SGD momentum (default: 0.5)')parser.add_argument('--no-cuda', action='store_true', default=False,help='disables CUDA training')parser.add_argument('--seed', type=int, default=1, metavar='S',help='random seed (default: 1)')parser.add_argument('--log-interval', type=int, default=10, metavar='N',help='how many batches to wait before logging training status')parser.add_argument('--save-model', action='store_true', default=False,help='For Saving the current Model')args = parser.parse_args()use_cuda = not args.no_cuda and torch.cuda.is_available()torch.manual_seed(args.seed)device = torch.device("cuda" if use_cuda else "cpu")kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=args.batch_size, shuffle=True, **kwargs)test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=args.test_batch_size, shuffle=True, **kwargs)model = Net().to(device)optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)for epoch in range(1, args.epochs + 1):train(args, model, device, train_loader, optimizer, epoch)test(args, model, device, test_loader)if (args.save_model):torch.save(model.state_dict(), "mnist_cnn.pt")

2. 读代码(个人喜欢的顺序)

2.1. 导入模块部分:

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

导入基本的 Python 库和 Python 包:

  1. argparse 是一个模块,非常好用,可以用来设置模型的默认参数,也可以允许使用者在命令行模式手动设置参数
  2. torch 和 torchvision 是 pytorch 基本的的库。torch 自然指的是 pytorch, torchvision 是独立于 pytorch 的关于图像操作的一些方便工具库。torchvision主要包括以下几个包:
    (1)vision.datasets : 几个常用视觉数据集,可以下载和加载,这里主要的高级用法就是可以看源码如何自己写自己的Dataset的子类
    (2)vision.models : 流行的模型,例如 AlexNet, VGG, ResNet 和 Densenet 以及 与训练好的参数。
    (3)vision.transforms : 常用的图像操作,例如:随机切割,旋转,数据类型转换,图像到tensor ,numpy 数组到tensor , tensor 到 图像等。
    (4)vision.utils : 用于把形似 (3 x H x W) 的张量保存到硬盘中,给一个mini-batch的图像可以产生一个图像格网。
  3. torch.nn 参考 pytorch 文档:https://ptorch.com/docs/1/torch-nn
  4. torch.nn.functional 参考 pytorch 文档:https://ptorch.com/docs/8/torch.nn.functional
  5. import torch.optim 主要包含常见的优化算法

2.2. Main 函数:

parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=False,help='For Saving the current Model')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

这里就是使用了 argparse 模块,配置了 batch-size, test-batch-size, epochs, lr, momentum, -no-cuda, seed, log-interval, save-model 等参数,参数值就是 default 里面的默认值,同时我们可以在命令行允许的时候修改。例如

python main.py --epochs 20 --lr 0.008

这样就可以修改 epochs 为 20, 学习率为 0.008。

  1. batch-size:我们一次训练如果只训练一张图片,显然效率太低了,很多 GPU, CPU 资源都没有利用到,因此可以一次训练多张图片。可以提高效率,但是如果显存太小,图片太大,可能导致显存不够用的问题,因此这个值在我们 MNIST 小数据集上设置可以随意一点,大点的数据集需要好好考量。
  2. test-batch-size:测试时的 batch-size
  3. epochs:训练的 epochs 数量
  4. lr:学习率
  5. momentum:SGD 算法里面的 momentum(动量)
  6. seed:随机种子
  7. log-interval:多少个 batch 打印一次训练状态,设置为10,即 10*64,一个batch size是 64 ,即每 640 张图片打印一次训练结果。
  8. save-model:是否保存模型
  9. cuda:GPU 是否可用
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")

设置随机种子,以及将 GPU 命名为 device

 kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=args.batch_size, shuffle=True, **kwargs)test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=args.test_batch_size, shuffle=True, **kwargs)

加载训练集和测试集,pytorch 自带有 MNIST 数据集。 并且转为我们需要的格式。

  1. train_loader 里面就是 6000 个训练集,包含有图片和图片的标签。
  2. test_loader 里面就是 1000 个训练集,包含有图片和图片的标签。
 kwargs = {'num_workers': 1, 'pin_memory': True} 

num_workers 是多进程的加载数,pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些。

model = Net().to(device)

构建我们的模型,并送到 GPU 中加速。
模型为:

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5, 1)self.conv2 = nn.Conv2d(20, 50, 5, 1)self.fc1 = nn.Linear(4 * 4 * 50, 500)self.fc2 = nn.Linear(500, 10)def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2)x = x.view(-1, 4 * 4 * 50)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)

两个卷积层,两个线性层:构建我们只看 init 部分就可以了,forward 部分是将数据送到模型,模型如何处理的过程。

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

优化方法是 SGD, 第一个参数是模型的参数,第二个是学习率,第三个是动量。

    for epoch in range(1, args.epochs + 1):train(args, model, device, train_loader, optimizer, epoch)test(args, model, device, test_loader)

一共训练 10 次,每次训练完都测试一次。

    if (args.save_model):torch.save(model.state_dict(), "mnist_cnn.pt")

如果配置为要保持模型,就保存模型到 mnist_cnn.pt 文件里面,如果没有会新建。

def train(args, model, device, train_loader, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if batch_idx % args.log_interval == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))

训练过程,接收配置 args, 模型model, GPU device, 训练数据train_loader,优化器optimizer和当前训练周期epoch

   model.train()

模型进入训练模式

for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)

加载训练数据集合,这里一次 64 张图片以及对应的类别(0 - 9),data 是图片 shape 为[64, 1, 28, 28],64张图片,灰度图片,像素大小为28*28。 target 是 64 个图片的类别,每个用 0-9 的图片表示。

 optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()

优化过程开始,然后 data 送到模型,经过模型的 forward 步骤,最后一步 softmax 为这张图片属于 10 个类的概率,比如[0.7, 0.1, 0.1, 0.1,0, 0, 0, 0, 0, 0] 就是这张图片属于 0 的概率为 0.7, 属于1. 2. 3的概率为 0.1。损失函数为 NLLLoss 。
NLLLoss 的 输入 是一个对数概率向量和一个目标标签. 它不会为我们计算对数概率. 适合网络的最后一层是log_softmax. 损失函数 nn.CrossEntropyLoss() 与 NLLLoss() 相同, 唯一的不同是它为我们去做 softmax. 我们视为分类问题,最常见的损失函数就是交叉熵,我们这里本质也是交叉熵。因为模型最后一步 forward 是 softmax。

CrossEntropyLoss()=log_softmax() + NLLLoss()

然后根据 loss 用 SGD 算法优化参数。

 if batch_idx % args.log_interval == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))

每 10 * 64 张图片打印训练状态:
在这里插入图片描述

def test(args, model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch losspred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probabilitycorrect += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

测试模型。

    model.eval()test_loss = 0correct = 0

模型进入测试模式,初始化计数变量。然后加载测试图片,data 送到模型,然后计算叠加损失 loss,pred 就是找出 softmax 之后数组里面最大值得索引,最大值就是预测最有可能的概率,然后索引就是预测的数字。例如上面 [0.7, 0.1, 0.1, 0.1,0, 0, 0, 0, 0, 0] , 最大概率值是 0.7,索引就是 0,我们预测这个数字是 0。然后叠加准确的个数,最后 loss 和 准确个数除以测试图片个数就是平均的 loss 和 准确率。


http://chatgpt.dhexx.cn/article/014y7TK7.shtml

相关文章

MNIST数据集手写数字识别(CNN)

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

MNIST数据集详解及可视化处理(pytorch)

MNIST数据集详解及可视化处理(pytorch) MNIST 数据集已经是一个被”嚼烂”了的数据集, 作为机器学习在视觉领域的“hello world”,很多教程都会对它”下手”, 几乎成为一个 “典范”。 不过有些人可能对它还不是很了解, 下面来介绍一下。 MN…

Mnist数据集介绍

Mnist数据集已经是一个被"嚼烂"了的数据集了,很多关于神经网络的教程都会对它下手。因此在开始深度学习之前,先对这个数据集介绍一下。 Mnist数据集图片格式介绍 Mnist数据集分为两部分,分别含有60000张训练图片和10000张测试图片…

使用MNIST数据集

首先,必须向各位强调的是:该数据集名字叫MNIST,而非MINIST~ 我之前就一直弄错了! 哈哈~ 网上有很多使用MNIST数据集的教程,要么太麻烦,要么需要翻墙下载,很慢。 在这里分…

Fashion MNIST进行分类

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

MNIST数据集简介与使用

MNIST数据集简介 MNIST数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。训练集(training set)由来自250个不同人手写的数字构成,其中50%是高中学生,50%来自人口普查局(t…

详解 MNIST 数据集

MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范”. 不过有些人可能对它还不是很了解, 下面来介绍一下. MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分: Training set images: train-images-idx3-…

Mnist数据集简介

MNIST是一个手写体数字的图片数据集,该数据集来由美国国家标准与技术研究所(National Institute of Standards and Technology (NIST))发起整理,一共统计了来自250个不同的人手写数字图片,其中50%是高中生,…

[转]MNIST机器学习入门

MNIST机器学习入门 转自:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html?plg_nld1&plg_uin1&plg_auth1&plg_nld1&plg_usr1&plg_vkey1&plg_dev1 这个教程的目标读者是对机器学习和TensorFlow都不太了解的新手。如…

从手写数字识别入门深度学习丨MNIST数据集详解

就像无数人从敲下“Hello World”开始代码之旅一样,许多研究员从“MNIST数据集”开启了人工智能的探索之路。 MNIST数据集(Mixed National Institute of Standards and Technology database)是一个用来训练各种图像处理系统的二进制图像数据…

Pytorch入门--详解Mnist手写字识别

1 什么是Mnist? Mnist是计算机视觉领域中最为基础的一个数据集。 MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10…

MNIST数据集

一、MNIST数据集介绍 MNIST数据集是NIST(National Institute of Standards and Technology,美国国家标准与技术研究所)数据集的一个子集,MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取,主要包括四个文件&…

面试官: 你知道 JWT、JWE、JWS 、JWK嘛?

想起了 之前做过的 很多 登录授权 的项目 它相比原先的session、cookie来说,更快更安全,跨域也不再是问题,更关键的是更加优雅 ,所以今天总结了一篇文章来介绍他 JWT 指JSON Web Token,如果在项目中通过 jjwt 来支持 J…

java jwe/jws_一篇文章带你分清楚JWT,JWS与JWE

随着移动互联网的兴起,传统基于session/cookie的web网站认证方式转变为了基于OAuth2等开放授权协议的单点登录模式(SSO),相应的基于服务器session浏览器cookie的Auth手段也发生了转变,Json Web Token出现成为了当前的热门的Token Auth机制。 …

JWS实现WebService

WebService估计大家都有听过或者使用过。Java有几种常用的方式实现webservice,本文主要是讨论JWS实现。 什么是webservice 简单而言,webservice就是通过SOAP协议在Web上提供的服务,使用WSDL文件进行说明。其特点是走SOAP协议而不是http协议&…

WebService 理论详解、JWS(Java Web Service) 快速入门

目录 WebService (web服务)概述 WebService 平台技术 WebService 工作原理 WebService 开发流程 常见 Web Service 框架 JWS(Java Web Service) 概述 JWS(Java Web Service) 快速入门 WebService (web服务)概述 1、WebService(Web服务)是一种跨语…

一文理解 JWT、JWS、JWE、JWA、JWK、JOSE

原文收录 GitBook——统一接口认证解决方案 JsonWebToken 关于JsonWebToken的专业名词解释: unsecured JWT:默认头部{“alg”: “none”}的jwt令牌JWS(SignedJWT):已签名的jwt,包含标准jwt结构:header、payload、signatureJWE…

JWS入门

JWS简介 JWS主要用来通过网络部署你的应用程序,它具有安全、稳定、易维护、易使用的特点。用户访问用JWS部署应用程序的站点,下载发布的应用程序,既可以在 线运行,也可以通过JWS的客户端离线运行已下载的应用程序。对同一个应用程…

【C语言】判断一个数是否是完全平方数(两种解法)

题目: 判断一个数是否是完全平方数。 以下数字为完全平方数:42*2,93*3,14412*12,16913*13 有两个方法,可以求完全平方数: 方法一:输入一个数,遍历所有比这个数小的数,只要有其中一个数满足条件…

C语言 输入10个数,将其中最小的数与第一个数对换,将最大的数与最后一个数对换

#include <stdio.h> void input(int *number){ //定义输入10个数的函数int i;printf("请输入10个整数:\n");for(i0;i<10;i)scanf("%d",&number[i]); } void max_min_value(int *number){ //交换函数int *max,*min,*p,temp;maxminnumber; //开…