十分钟搞懂Pytorch如何读取MNIST数据集

article/2025/9/22 19:55:02

前言

本文用于记录使用pytorch读取minist数据集的过程,以及一些思考和疑惑吧…

正文

在阅读教程书籍《深度学习入门之Pytorch》时,文中是如此加载MNIST手写数字训练集的:

train_dataset = datasets.MNIST(root='./MNIST',train=True,transform=data_tf,download=True)

解释一下参数

datasets.MNIST是Pytorch的内置函数torchvision.datasets.MNIST,通过这个可以导入数据集。

train=True 代表我们读入的数据作为训练集(如果为true则从training.pt创建数据集,否则从test.pt创建数据集)

transform则是读入我们自己定义的数据预处理操作

download=True则是当我们的根目录(root)下没有数据集时,便自动下载。

如果这时候我们通过联网自动下载方式download我们的数据后,它的文件路径是以下形式:

在这里插入图片描述

其中我们所需要的文件主要在raw文件夹下

train-images-idx3-ubyte.gz: training set images (9912422 bytes) 
train-labels-idx1-ubyte.gz: training set labels (28881 bytes) t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) 
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes) 

接下来,书中是如此加载数据集的

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=5,shuffle=True)

由于DataLoader为Pytorch内部封装好的函数,所以对于它的调用方法需要自行去查阅。

我在最开始疑惑的点:传入的根目录在下载好数据集后,为MNIST下两个文件夹,而processed和raw文件夹下还有诸多文件,所以到底是如何读入数据的呢?所以我决定将数据集下载后,通过读取本地的MINIST数据集并进行装载。

首先,自定义数据类来继承和重写Dataset抽象类

class DealDataset(Dataset):"""读取数据、初始化数据"""def __init__(self, folder, data_name, label_name,transform=None):(train_set, train_labels) = self.load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式self.train_set = train_setself.train_labels = train_labelsself.transform = transformdef __getitem__(self, index):img, target = self.train_set[index], int(self.train_labels[index])if self.transform is not None:img = self.transform(img)return img, targetdef __len__(self):return len(self.train_set)'''load_data也是我们自定义的函数,用途:读取数据集中的数据 ( 图片数据+标签label'''def load_data(self,data_folder, data_name, label_name):with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath:x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)return (x_train, y_train)

接下来,调用我们自定义的数据类来加载数据集

trainDataset = DealDataset('./MNIST/MNIST/raw', "train-images-idx3-ubyte.gz","train-labels-idx1-ubyte.gz",transform=transforms.ToTensor())# 训练数据和测试数据的装载
train_loader = torch.utils.data.DataLoader(dataset=trainDataset,batch_size=10, # 一个批次可以认为是一个包,每个包中含有10张图片shuffle=False,
)

通过这种方式便可以大概了解了读取数据集的过程。

接下来,我们来验证以下我们数据是否正确加载

# 实现单张图片可视化
images, labels = next(iter(train_loader))
img = torchvision.utils.make_grid(images)img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
plt.imshow(img)
plt.show()

p.s.:其实这里是用cv2.imshow来展示图片,但是我的代码是在jupyter notebook上写的,所以只能通过plt来代替加载。
在这里插入图片描述

数据加载成功~

深入探索

可以看到,在load_data函数中

y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

这个offset=8又是为啥呢?
我们进入MNIST数据集的官方页面进行查看
在这里插入图片描述

通过文档介绍,可以看到
offset的0000-0003是 magic number,所以跳过不读,
offset的0004-0007是items数目
接下来这些代表的就是标签

同理对于

x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train)

在这里插入图片描述

根据刚才的分析方法,也可以明白为什么offset=16了

完整代码

1.直接使用pytorch自带的mnist数据集加载

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2
import matplotlib.pyplot as pltdata_tf = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])]
)train_dataset = datasets.MNIST(root='./coding/learning/lrdata/MNIST',train=True,transform=data_tf,download=True)train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=5,shuffle=True)
# 实现单张图片可视化
images, labels = next(iter(train_loader))
img = torchvision.utils.make_grid(images)img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
plt.imshow(img)
plt.show()

p.s.:记得自己修改root根目录。

2.使用自定义的数据类加载本地MNIST数据集

import numpy as np
import torch
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
import gzip
import os
import torchvision
import cv2
import matplotlib.pyplot as pltclass DealDataset(Dataset):"""读取数据、初始化数据"""def __init__(self, folder, data_name, label_name,transform=None):(train_set, train_labels) = load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式self.train_set = train_setself.train_labels = train_labelsself.transform = transformdef __getitem__(self, index):img, target = self.train_set[index], int(self.train_labels[index])if self.transform is not None:img = self.transform(img)return img, targetdef __len__(self):return len(self.train_set)def load_data(data_folder, data_name, label_name):with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath:x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)return (x_train, y_train)trainDataset = DealDataset('./coding/learning/lrdata/MNIST/MNIST/raw', "train-images-idx3-ubyte.gz","train-labels-idx1-ubyte.gz",transform=transforms.ToTensor())# 训练数据和测试数据的装载
train_loader = torch.utils.data.DataLoader(dataset=trainDataset,batch_size=10, # 一个批次可以认为是一个包,每个包中含有10张图片shuffle=False,
)# 实现单张图片可视化
images, labels = next(iter(train_loader))
img = torchvision.utils.make_grid(images)img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
plt.imshow(img)
plt.show()

参考

1.《深度学习入门之Pytorch》- 廖星宇
2.使用Pytorch进行读取本地的MINIST数据集并进行装载
3.顺藤摸瓜-mnist数据集的补充

在这里插入图片描述


http://chatgpt.dhexx.cn/article/dcgAw6NQ.shtml

相关文章

torchvision中datasets.MNIST介绍

用法介绍 torchvision中datasets中所有封装的数据集都是torch.utils.data.Dataset的子类,它们都实现了__getitem__和__len__方法。因此,它们都可以用torch.utils.data.DataLoader进行数据加载。以datasets.MNIST类为例,具体参数和用法如下所示…

万物皆用MNIST---MNIST数据集及创建自己的手写数字数据集

刚刚接触到人工智能的我们,必定会遇到一个非常非常非常熟悉的朋友------MNIST 这是一套流行的手写数字图片,常常被用来测试我们的思想和算法。这个数据集称为手写数字的MNIST数据库,从研究员Yann LeCun 的网站,可以得到这个…

Pytorch 之 MNIST 数据集实现

目录 1. 数据集介绍2. 代码2. 读代码(个人喜欢的顺序)2.1. 导入模块部分:2.2. Main 函数: 1. 数据集介绍 一般而言,MNIST 数据集测试就是机器学习和深度学习当中的"Hello World"工程。几乎是所…

MNIST数据集手写数字识别(CNN)

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

MNIST数据集详解及可视化处理(pytorch)

MNIST数据集详解及可视化处理(pytorch) MNIST 数据集已经是一个被”嚼烂”了的数据集, 作为机器学习在视觉领域的“hello world”,很多教程都会对它”下手”, 几乎成为一个 “典范”。 不过有些人可能对它还不是很了解, 下面来介绍一下。 MN…

Mnist数据集介绍

Mnist数据集已经是一个被"嚼烂"了的数据集了,很多关于神经网络的教程都会对它下手。因此在开始深度学习之前,先对这个数据集介绍一下。 Mnist数据集图片格式介绍 Mnist数据集分为两部分,分别含有60000张训练图片和10000张测试图片…

使用MNIST数据集

首先,必须向各位强调的是:该数据集名字叫MNIST,而非MINIST~ 我之前就一直弄错了! 哈哈~ 网上有很多使用MNIST数据集的教程,要么太麻烦,要么需要翻墙下载,很慢。 在这里分…

Fashion MNIST进行分类

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

MNIST数据集简介与使用

MNIST数据集简介 MNIST数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。训练集(training set)由来自250个不同人手写的数字构成,其中50%是高中学生,50%来自人口普查局(t…

详解 MNIST 数据集

MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范”. 不过有些人可能对它还不是很了解, 下面来介绍一下. MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分: Training set images: train-images-idx3-…

Mnist数据集简介

MNIST是一个手写体数字的图片数据集,该数据集来由美国国家标准与技术研究所(National Institute of Standards and Technology (NIST))发起整理,一共统计了来自250个不同的人手写数字图片,其中50%是高中生,…

[转]MNIST机器学习入门

MNIST机器学习入门 转自:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html?plg_nld1&plg_uin1&plg_auth1&plg_nld1&plg_usr1&plg_vkey1&plg_dev1 这个教程的目标读者是对机器学习和TensorFlow都不太了解的新手。如…

从手写数字识别入门深度学习丨MNIST数据集详解

就像无数人从敲下“Hello World”开始代码之旅一样,许多研究员从“MNIST数据集”开启了人工智能的探索之路。 MNIST数据集(Mixed National Institute of Standards and Technology database)是一个用来训练各种图像处理系统的二进制图像数据…

Pytorch入门--详解Mnist手写字识别

1 什么是Mnist? Mnist是计算机视觉领域中最为基础的一个数据集。 MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10…

MNIST数据集

一、MNIST数据集介绍 MNIST数据集是NIST(National Institute of Standards and Technology,美国国家标准与技术研究所)数据集的一个子集,MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取,主要包括四个文件&…

面试官: 你知道 JWT、JWE、JWS 、JWK嘛?

想起了 之前做过的 很多 登录授权 的项目 它相比原先的session、cookie来说,更快更安全,跨域也不再是问题,更关键的是更加优雅 ,所以今天总结了一篇文章来介绍他 JWT 指JSON Web Token,如果在项目中通过 jjwt 来支持 J…

java jwe/jws_一篇文章带你分清楚JWT,JWS与JWE

随着移动互联网的兴起,传统基于session/cookie的web网站认证方式转变为了基于OAuth2等开放授权协议的单点登录模式(SSO),相应的基于服务器session浏览器cookie的Auth手段也发生了转变,Json Web Token出现成为了当前的热门的Token Auth机制。 …

JWS实现WebService

WebService估计大家都有听过或者使用过。Java有几种常用的方式实现webservice,本文主要是讨论JWS实现。 什么是webservice 简单而言,webservice就是通过SOAP协议在Web上提供的服务,使用WSDL文件进行说明。其特点是走SOAP协议而不是http协议&…

WebService 理论详解、JWS(Java Web Service) 快速入门

目录 WebService (web服务)概述 WebService 平台技术 WebService 工作原理 WebService 开发流程 常见 Web Service 框架 JWS(Java Web Service) 概述 JWS(Java Web Service) 快速入门 WebService (web服务)概述 1、WebService(Web服务)是一种跨语…

一文理解 JWT、JWS、JWE、JWA、JWK、JOSE

原文收录 GitBook——统一接口认证解决方案 JsonWebToken 关于JsonWebToken的专业名词解释: unsecured JWT:默认头部{“alg”: “none”}的jwt令牌JWS(SignedJWT):已签名的jwt,包含标准jwt结构:header、payload、signatureJWE…