PyTorch 手把手教你实现 MNIST 数据集

article/2025/9/22 18:41:47

PyTorch MNIST 实现

  • 概述
  • 获取数据
  • 网络模型
  • train 函数
  • test 函数
  • main 函数
  • 完整代码

概述

MNIST 包含 0~9 的手写数字, 共有 60000 个训练集和 10000 个测试集. 数据的格式为单通道 28*28 的灰度图.

在这里插入图片描述

获取数据

在这里插入图片描述

def get_data():"""获取数据"""# 获取测试集train = torchvision.datasets.MNIST(root="./data", train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),  # 转换成张量torchvision.transforms.Normalize((0.1307,), (0.3081,))  # 标准化]))train_loader = DataLoader(train, batch_size=batch_size)  # 分割测试集# 获取测试集test = torchvision.datasets.MNIST(root="./data", train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),  # 转换成张量torchvision.transforms.Normalize((0.1307,), (0.3081,))  # 标准化]))test_loader = DataLoader(test, batch_size=batch_size)  # 分割训练# 返回分割好的训练集和测试集return train_loader, test_loader

网络模型

class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()# 卷积层self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))# Dropout层self.dropout1 = torch.nn.Dropout(0.25)self.dropout2 = torch.nn.Dropout(0.5)# 全连接层self.fc1 = torch.nn.Linear(9216, 128)self.fc2 = torch.nn.Linear(128, 10)def forward(self, x):"""前向传播"""# [b, 1, 28, 28] => [b, 32, 26, 26]out = self.conv1(x)out = F.relu(out)# [b, 32, 26, 26] => [b, 64, 24, 24]out = self.conv2(out)out = F.relu(out)# [b, 64, 24, 24] => [b, 64, 12, 12]out = F.max_pool2d(out, 2)out = self.dropout1(out)# [b, 64, 12, 12] => [b, 64 * 12 * 12] => [b, 9216]out = torch.flatten(out, 1)# [b, 9216] => [b, 128]out = self.fc1(out)out = F.relu(out)# [b, 128] => [b, 10]out = self.dropout2(out)out = self.fc2(out)output = F.log_softmax(out, dim=1)return output

train 函数

def train(model, epoch, train_loader):"""训练"""# 训练模式model.train()# 迭代for step, (x, y) in enumerate(train_loader):# 加速if use_cuda:model = model.cuda()x, y = x.cuda(), y.cuda()# 梯度清零optimizer.zero_grad()output = model(x)# 计算损失loss = F.nll_loss(output, y)# 反向传播loss.backward()# 更新梯度optimizer.step()# 打印损失if step % 50 == 0:print('Epoch: {}, Step {}, Loss: {}'.format(epoch, step, loss))

test 函数

def test(model, test_loader):"""测试"""# 测试模式model.eval()# 存放正确个数correct = 0with torch.no_grad():for x, y in test_loader:# 加速if use_cuda:model = model.cuda()x, y = x.cuda(), y.cuda()# 获取结果output = model(x)# 预测结果pred = output.argmax(dim=1, keepdim=True)# 计算准确个数correct += pred.eq(y.view_as(pred)).sum().item()# 计算准确率accuracy = correct / len(test_loader.dataset) * 100# 输出准确print("Test Accuracy: {}%".format(accuracy))

main 函数

def main():# 获取数据train_loader, test_loader = get_data()# 迭代for epoch in range(iteration_num):print("\n================ epoch: {} ================".format(epoch))train(network, epoch, train_loader)test(network, test_loader)

完整代码

在这里插入图片描述
完整代码:

import torch
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoaderclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()# 卷积层self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))# Dropout层self.dropout1 = torch.nn.Dropout(0.25)self.dropout2 = torch.nn.Dropout(0.5)# 全连接层self.fc1 = torch.nn.Linear(9216, 128)self.fc2 = torch.nn.Linear(128, 10)def forward(self, x):"""前向传播"""# [b, 1, 28, 28] => [b, 32, 26, 26]out = self.conv1(x)out = F.relu(out)# [b, 32, 26, 26] => [b, 64, 24, 24]out = self.conv2(out)out = F.relu(out)# [b, 64, 24, 24] => [b, 64, 12, 12]out = F.max_pool2d(out, 2)out = self.dropout1(out)# [b, 64, 12, 12] => [b, 64 * 12 * 12] => [b, 9216]out = torch.flatten(out, 1)# [b, 9216] => [b, 128]out = self.fc1(out)out = F.relu(out)# [b, 128] => [b, 10]out = self.dropout2(out)out = self.fc2(out)output = F.log_softmax(out, dim=1)return output# 定义超参数
batch_size = 64  # 一次训练的样本数目
learning_rate = 0.0001  # 学习率
iteration_num = 5  # 迭代次数
network = Model()  # 实例化网络
print(network)  # 调试输出网络结构
optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)  # 优化器# GPU 加速
use_cuda = torch.cuda.is_available()
print("是否使用 GPU 加速:", use_cuda)def get_data():"""获取数据"""# 获取测试集train = torchvision.datasets.MNIST(root="./data", train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),  # 转换成张量torchvision.transforms.Normalize((0.1307,), (0.3081,))  # 标准化]))train_loader = DataLoader(train, batch_size=batch_size)  # 分割测试集# 获取测试集test = torchvision.datasets.MNIST(root="./data", train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),  # 转换成张量torchvision.transforms.Normalize((0.1307,), (0.3081,))  # 标准化]))test_loader = DataLoader(test, batch_size=batch_size)  # 分割训练# 返回分割好的训练集和测试集return train_loader, test_loaderdef train(model, epoch, train_loader):"""训练"""# 训练模式model.train()# 迭代for step, (x, y) in enumerate(train_loader):# 加速if use_cuda:model = model.cuda()x, y = x.cuda(), y.cuda()# 梯度清零optimizer.zero_grad()output = model(x)# 计算损失loss = F.nll_loss(output, y)# 反向传播loss.backward()# 更新梯度optimizer.step()# 打印损失if step % 50 == 0:print('Epoch: {}, Step {}, Loss: {}'.format(epoch, step, loss))def test(model, test_loader):"""测试"""# 测试模式model.eval()# 存放正确个数correct = 0with torch.no_grad():for x, y in test_loader:# 加速if use_cuda:model = model.cuda()x, y = x.cuda(), y.cuda()# 获取结果output = model(x)# 预测结果pred = output.argmax(dim=1, keepdim=True)# 计算准确个数correct += pred.eq(y.view_as(pred)).sum().item()# 计算准确率accuracy = correct / len(test_loader.dataset) * 100# 输出准确print("Test Accuracy: {}%".format(accuracy))def main():# 获取数据train_loader, test_loader = get_data()# 迭代for epoch in range(iteration_num):print("\n================ epoch: {} ================".format(epoch))train(network, epoch, train_loader)test(network, test_loader)if __name__ == "__main__":main()

输出结果:

Model((conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))(dropout1): Dropout(p=0.25, inplace=False)(dropout2): Dropout(p=0.5, inplace=False)(fc1): Linear(in_features=9216, out_features=128, bias=True)(fc2): Linear(in_features=128, out_features=10, bias=True)
)
是否使用 GPU 加速: True================ epoch: 0 ================
Epoch: 0, Step 0, Loss: 2.3131277561187744
Epoch: 0, Step 50, Loss: 1.0419045686721802
Epoch: 0, Step 100, Loss: 0.6259541511535645
Epoch: 0, Step 150, Loss: 0.7194482684135437
Epoch: 0, Step 200, Loss: 0.4020516574382782
Epoch: 0, Step 250, Loss: 0.6890509128570557
Epoch: 0, Step 300, Loss: 0.28660136461257935
Epoch: 0, Step 350, Loss: 0.3277580738067627
Epoch: 0, Step 400, Loss: 0.2750288248062134
Epoch: 0, Step 450, Loss: 0.28428223729133606
Epoch: 0, Step 500, Loss: 0.3514065444469452
Epoch: 0, Step 550, Loss: 0.23386947810649872
Epoch: 0, Step 600, Loss: 0.25338059663772583
Epoch: 0, Step 650, Loss: 0.1743898093700409
Epoch: 0, Step 700, Loss: 0.35752204060554504
Epoch: 0, Step 750, Loss: 0.17575909197330475
Epoch: 0, Step 800, Loss: 0.20604261755943298
Epoch: 0, Step 850, Loss: 0.17389622330665588
Epoch: 0, Step 900, Loss: 0.3188241124153137
Test Accuracy: 96.56%================ epoch: 1 ================
Epoch: 1, Step 0, Loss: 0.23558208346366882
Epoch: 1, Step 50, Loss: 0.13511177897453308
Epoch: 1, Step 100, Loss: 0.18823786079883575
Epoch: 1, Step 150, Loss: 0.2644936144351959
Epoch: 1, Step 200, Loss: 0.145077645778656
Epoch: 1, Step 250, Loss: 0.30574971437454224
Epoch: 1, Step 300, Loss: 0.2386859953403473
Epoch: 1, Step 350, Loss: 0.08346735686063766
Epoch: 1, Step 400, Loss: 0.10480977594852448
Epoch: 1, Step 450, Loss: 0.07280707359313965
Epoch: 1, Step 500, Loss: 0.20928426086902618
Epoch: 1, Step 550, Loss: 0.20455852150917053
Epoch: 1, Step 600, Loss: 0.10085935145616531
Epoch: 1, Step 650, Loss: 0.13476189970970154
Epoch: 1, Step 700, Loss: 0.19087043404579163
Epoch: 1, Step 750, Loss: 0.0981522724032402
Epoch: 1, Step 800, Loss: 0.1961515098810196
Epoch: 1, Step 850, Loss: 0.041140712797641754
Epoch: 1, Step 900, Loss: 0.250461220741272
Test Accuracy: 98.03%================ epoch: 2 ================
Epoch: 2, Step 0, Loss: 0.09572553634643555
Epoch: 2, Step 50, Loss: 0.10370486229658127
Epoch: 2, Step 100, Loss: 0.17737184464931488
Epoch: 2, Step 150, Loss: 0.1570713371038437
Epoch: 2, Step 200, Loss: 0.07462178170681
Epoch: 2, Step 250, Loss: 0.18744900822639465
Epoch: 2, Step 300, Loss: 0.09910508990287781
Epoch: 2, Step 350, Loss: 0.08929706364870071
Epoch: 2, Step 400, Loss: 0.07703761011362076
Epoch: 2, Step 450, Loss: 0.10133732110261917
Epoch: 2, Step 500, Loss: 0.1314031481742859
Epoch: 2, Step 550, Loss: 0.10394387692213058
Epoch: 2, Step 600, Loss: 0.11612939089536667
Epoch: 2, Step 650, Loss: 0.17494803667068481
Epoch: 2, Step 700, Loss: 0.11065669357776642
Epoch: 2, Step 750, Loss: 0.061209067702293396
Epoch: 2, Step 800, Loss: 0.14715790748596191
Epoch: 2, Step 850, Loss: 0.03930797800421715
Epoch: 2, Step 900, Loss: 0.18030673265457153
Test Accuracy: 98.46000000000001%================ epoch: 3 ================
Epoch: 3, Step 0, Loss: 0.09266342222690582
Epoch: 3, Step 50, Loss: 0.0414913073182106
Epoch: 3, Step 100, Loss: 0.2152961939573288
Epoch: 3, Step 150, Loss: 0.12287424504756927
Epoch: 3, Step 200, Loss: 0.13468700647354126
Epoch: 3, Step 250, Loss: 0.11967387050390244
Epoch: 3, Step 300, Loss: 0.11301510035991669
Epoch: 3, Step 350, Loss: 0.037447575479745865
Epoch: 3, Step 400, Loss: 0.04699449613690376
Epoch: 3, Step 450, Loss: 0.05472381412982941
Epoch: 3, Step 500, Loss: 0.09839300811290741
Epoch: 3, Step 550, Loss: 0.07964356243610382
Epoch: 3, Step 600, Loss: 0.08182843774557114
Epoch: 3, Step 650, Loss: 0.05514759197831154
Epoch: 3, Step 700, Loss: 0.13785190880298615
Epoch: 3, Step 750, Loss: 0.062480345368385315
Epoch: 3, Step 800, Loss: 0.120387002825737
Epoch: 3, Step 850, Loss: 0.04458726942539215
Epoch: 3, Step 900, Loss: 0.17119190096855164
Test Accuracy: 98.55000000000001%================ epoch: 4 ================
Epoch: 4, Step 0, Loss: 0.08094145357608795
Epoch: 4, Step 50, Loss: 0.05615215748548508
Epoch: 4, Step 100, Loss: 0.07766406238079071
Epoch: 4, Step 150, Loss: 0.07915271818637848
Epoch: 4, Step 200, Loss: 0.1301635503768921
Epoch: 4, Step 250, Loss: 0.12118984013795853
Epoch: 4, Step 300, Loss: 0.073218435049057
Epoch: 4, Step 350, Loss: 0.04517696052789688
Epoch: 4, Step 400, Loss: 0.08493026345968246
Epoch: 4, Step 450, Loss: 0.03904269263148308
Epoch: 4, Step 500, Loss: 0.09386837482452393
Epoch: 4, Step 550, Loss: 0.12583576142787933
Epoch: 4, Step 600, Loss: 0.09053893387317657
Epoch: 4, Step 650, Loss: 0.06912104040384293
Epoch: 4, Step 700, Loss: 0.1502612829208374
Epoch: 4, Step 750, Loss: 0.07162325084209442
Epoch: 4, Step 800, Loss: 0.10512275993824005
Epoch: 4, Step 850, Loss: 0.028180215507745743
Epoch: 4, Step 900, Loss: 0.08492615073919296
Test Accuracy: 98.69%

http://chatgpt.dhexx.cn/article/8QB3KBTS.shtml

相关文章

连续学习入门(三):Permuted MNIST/Split MNIST/Sequential MNIST 数据集

说明:本系列文章若无特别说明,则在技术上将 Continual Learning(连续学习)等同于 Incremental Learning(增量学习)、Lifelong Learning(终身学习),关于 Continual Learni…

导入mnist数据集

下载一个代码后,发现需要导入mnist数据集,首先新建一个py的文件,把代码复制过来,然后记得一定要改成这样的格式: from tensorflow.examples.tutorials.mnist import input_data mnist input_data.read_data_sets(MNI…

【MNIST】

1. Normal Neural Network: 首先我用的是两层(input layer 和 output layer)的feed-forward的神经网络结构来训练数据, y wx b, 在输出层用的是softmax求概率,算loss用的是交叉熵的办法,选用梯度下降法来最小化loss…

MNIST数据集使用详解

数据集下载网址:http://yann.lecun.com/exdb/mnist/ 下载后无需解压,将其放在一个文件夹下即可: 数据说明: 数据集常被分为2~3个部分 训练集(train set):用来学习的一组例子,用来适应分类器的参数[即权重]…

详解 Pytorch 实现 MNIST

MNIST虽然很简单,但是值得我们学习的东西还是有很多的。 项目虽然简单,但是个人建议还是将各个模块分开创建,特别是对于新人而言,模块化的创建会让读者更加清晰、易懂。 CNN模块:卷积神经网络的组成;trai…

十分钟搞懂Pytorch如何读取MNIST数据集

前言 本文用于记录使用pytorch读取minist数据集的过程,以及一些思考和疑惑吧… 正文 在阅读教程书籍《深度学习入门之Pytorch》时,文中是如此加载MNIST手写数字训练集的: train_dataset datasets.MNIST(root./MNIST,trainTrue,transform…

torchvision中datasets.MNIST介绍

用法介绍 torchvision中datasets中所有封装的数据集都是torch.utils.data.Dataset的子类,它们都实现了__getitem__和__len__方法。因此,它们都可以用torch.utils.data.DataLoader进行数据加载。以datasets.MNIST类为例,具体参数和用法如下所示…

万物皆用MNIST---MNIST数据集及创建自己的手写数字数据集

刚刚接触到人工智能的我们,必定会遇到一个非常非常非常熟悉的朋友------MNIST 这是一套流行的手写数字图片,常常被用来测试我们的思想和算法。这个数据集称为手写数字的MNIST数据库,从研究员Yann LeCun 的网站,可以得到这个…

Pytorch 之 MNIST 数据集实现

目录 1. 数据集介绍2. 代码2. 读代码(个人喜欢的顺序)2.1. 导入模块部分:2.2. Main 函数: 1. 数据集介绍 一般而言,MNIST 数据集测试就是机器学习和深度学习当中的"Hello World"工程。几乎是所…

MNIST数据集手写数字识别(CNN)

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

MNIST数据集详解及可视化处理(pytorch)

MNIST数据集详解及可视化处理(pytorch) MNIST 数据集已经是一个被”嚼烂”了的数据集, 作为机器学习在视觉领域的“hello world”,很多教程都会对它”下手”, 几乎成为一个 “典范”。 不过有些人可能对它还不是很了解, 下面来介绍一下。 MN…

Mnist数据集介绍

Mnist数据集已经是一个被"嚼烂"了的数据集了,很多关于神经网络的教程都会对它下手。因此在开始深度学习之前,先对这个数据集介绍一下。 Mnist数据集图片格式介绍 Mnist数据集分为两部分,分别含有60000张训练图片和10000张测试图片…

使用MNIST数据集

首先,必须向各位强调的是:该数据集名字叫MNIST,而非MINIST~ 我之前就一直弄错了! 哈哈~ 网上有很多使用MNIST数据集的教程,要么太麻烦,要么需要翻墙下载,很慢。 在这里分…

Fashion MNIST进行分类

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

MNIST数据集简介与使用

MNIST数据集简介 MNIST数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。训练集(training set)由来自250个不同人手写的数字构成,其中50%是高中学生,50%来自人口普查局(t…

详解 MNIST 数据集

MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范”. 不过有些人可能对它还不是很了解, 下面来介绍一下. MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分: Training set images: train-images-idx3-…

Mnist数据集简介

MNIST是一个手写体数字的图片数据集,该数据集来由美国国家标准与技术研究所(National Institute of Standards and Technology (NIST))发起整理,一共统计了来自250个不同的人手写数字图片,其中50%是高中生,…

[转]MNIST机器学习入门

MNIST机器学习入门 转自:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html?plg_nld1&plg_uin1&plg_auth1&plg_nld1&plg_usr1&plg_vkey1&plg_dev1 这个教程的目标读者是对机器学习和TensorFlow都不太了解的新手。如…

从手写数字识别入门深度学习丨MNIST数据集详解

就像无数人从敲下“Hello World”开始代码之旅一样,许多研究员从“MNIST数据集”开启了人工智能的探索之路。 MNIST数据集(Mixed National Institute of Standards and Technology database)是一个用来训练各种图像处理系统的二进制图像数据…

Pytorch入门--详解Mnist手写字识别

1 什么是Mnist? Mnist是计算机视觉领域中最为基础的一个数据集。 MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10…