MNIST数据集

article/2025/9/22 20:25:49

一、MNIST数据集介绍

MNIST数据集是NIST(National Institute of Standards and Technology,美国国家标准与技术研究所)数据集的一个子集,MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取,主要包括四个文件:
在这里插入图片描述
在上述文件中,训练集train一共包含了 60000 张图像和标签,而测试集一共包含了 10000 张图像和标签。

idx3表示3维,ubyte表示是以字节的形式进行存储的,t10k表示10000张测试图片(test10000)。

每张图片是一个28*28像素点的0 ~ 9的灰质手写数字图片,黑底白字,图像像素值为0 ~ 255,越大该点越白。

二、数据下载和读取

导入PyTorch的两个核心库torch和torchvision,这两个库基本包含了PyTorch会用到的许多方法和函数,其他库为下面所需要的一些辅助库。

import gzip
import osimport torch
import torchvision
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset

注:

  • import datasets是为了方便自动下载数据集,可以下载多种数据集,如MNIST、ImageNet、CIFAR10等。
  • import transforms是pytorch中的图像预处理库,一般用Compose把多个步骤整合到一起。 相关详情见:transforms.Compose()函数

1、使用Pytorch自带的库函数

导入MNIST数据集代码:

train_data = datasets.MNIST(root="./data/",train=True,transform=transforms.ToTensor(),download=True)test_data = datasets.MNIST(root="./data/",train=False,transform=transforms.ToTensor(),download=True)

其中:

  1. root 指定MNIST数据集存放的路径

  2. train 设置为True表示导入的是训练集合,否则为测试集合

  3. transform 指定导入数据集时需要进行何种变换操作

    此处ToTensor()shape(H, W, C)nump.ndarrayimg转为shape(C, H, W)tensor,其将每一个数值归一化到[0,1],其归一化方法比较简单,直接除以255即可

  4. download 设置为True表示当root参数指定的数据集存放的路径下没有数据时,则自动从网络上下载MNIST数据集,否则就不自动下载

注意: 返回值为一个二元组(data,target),一般是与torch.utils.data.DataLoader配合使用,也可自己对数据进行处理,见二、2

加载MNIST数据集代码:

train_data_loader = torch.utils.data.DataLoader(dataset=train_data,batch_size=64,shuffle=True,drop_last=True)test_data_loader = torch.utils.data.DataLoader(dataset=test_data,batch_size=64,shuffle=False,drop_last=False)

其中:

  1. dataset 指定欲装载的MNIST数据集
  2. batch_size 设置了每批次装载的数据图片为64个(自行设置)
  3. shuffle 设置为True表示在装载数据时随机乱序,常用于进行多批次的模型训练
  4. drop_last 设置为True表示当数据集size不能整除batch_size时,则删除最后一个batch_size,否则就不删除

在加载完成后,可以选取其中一个批次的数据进行预览:

********************
images, labels = next(iter(train_data_loader))	# images:Tensor(64,1,28,28)、labels:Tensor(64,)
********************
img = torchvision.utils.make_grid(images)	# 把64张图片拼接为1张图片# pytorch网络输入图像的格式为(C, H, W),而numpy中的图像的shape为(H,W,C)。故需要变换通道才能有效输出
img = img.numpy().transpose(1, 2, 0) 
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
plt.imshow(img)
plt.show()

下图展示的是一个batch数据集(64张图片)的显示:
在这里插入图片描述
对其中某一个图片进行像素化展示:

# images:Tensor(64,1,28,28)、labels:Tensor(64,)	
images, labels = next(iter(train_data_loader))  #(1,28,28)表示该图像的 height、width、color(颜色通道,即单通道)
images = images.reshape(64, 28, 28) 
img = images[0, :, :]	# 取batch_size中的第一张图像
np.savetxt('img.txt', img.cpu().numpy(), fmt="%f", encoding='UTF-8')	# 将像素值写入txt文件,以便查看
img = img.cpu().numpy()	#转为numpy类型,方便有效输出fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(111)
ax.imshow(img, cmap='gray')
width, height = img.shape
thresh = img.max()/2.5for x in range(width):for y in range(height):val = round(img[x][y], 2) if img[x][y] !=0 else 0ax.annotate(str(val), xy=(y, x),horizontalalignment='center',verticalalignment='center',color='white' if img[x][y] < thresh else 'black')
plt.show()

在这里插入图片描述

2、通过重构Dataset类读取特定的MNIST数据或者制作自己的MNIST数据集

定义一个子类,继承Dataset类, 重写 len()getitem() 方法。
getitem 是获取样本对,模型直接通过这一函数获得一对样本对{ x:y }
len 是指数据集长度

① 读取MNIST文件夹下processed文件中的training.pt、test.pt数据集

class Data_Loader(Dataset):def __init__(self, root, transform=None):self.data, self.targets = torch.load(root)	#采用torch.load进行读取,读取之后的结果为torch.Tensor形式self.transform = transformdef __getitem__(self, index):img, target = self.data[index], int(self.targets[index])img = Image.fromarray(img.numpy(), mode='L')if self.transform is not None:img = self.transform(img)img = transforms.ToTensor()(img)return img, targetdef __len__(self):return len(self.data)

接下来,调用我们自定义的Data_Loader类来读取数据集:

# root 为training.pt、test.pt文件所在的绝对路径
train_data = Data_Loader(root='./mnist/MNIST/processed/training.pt', transform= None)
test_data = Data_Loader(root='./mnist/MNIST/processed/test.pt', transform= None)

再使用torch.utils.data.DataLoadertrain_datatest_data进行加载,展示。

② 读取MNIST文件夹下raw文件中的数据集

class Data_Loader(Dataset):def __init__(self, folder, data_name, label_name, transform=None):(train_set, train_labels) = load_data(folder, data_name, label_name)self.train_set = train_setself.train_labels = train_labelsself.transform = transformdef __getitem__(self, index):img, target = self.train_set[index], int(self.train_labels[index])if self.transform is not None:img = self.transform(img)return img, targetdef __len__(self):return len(self.train_set)def load_data(data_folder, data_name, label_name):with gzip.open(os.path.join(data_folder, label_name), 'rb') as lbpath:  # rb表示的是读取二进制数据y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(os.path.join(data_folder, data_name), 'rb') as imgpath:x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)return x_train, y_train

接下来,调用我们自定义的Data_Loader类来读取数据集:

#	folder:MNIST数据集中raw文件的绝对路径# 读取MNIST数据集中的训练集
train_data = Data_Loader('./MNIST/MNIST/raw', "train-images-idx3-ubyte.gz","train-labels-idx1-ubyte.gz", transform=transforms.ToTensor())# 读取MNIST数据集中的测试集
test_data = Data_Loader('./MNIST/MNIST/raw', "t10k-images-idx3-ubyte.gz","t10k-labels-idx1-ubyte.gz", transform=transforms.ToTensor())

再使用torch.utils.data.DataLoadertrain_datatest_data进行加载,展示。

③ 直接读取MNIST数据集

def data_loader(path):# Train Datamnist_train = torchvision.datasets.MNIST(root=path, train=True, download=False)train_data = mnist_train.data.view([-1, 784]).float()/255.	# mnist_train.data 获取数字train_labels = mnist_train.targets	# mnist_train.targets 获取标签# Test Datamnist_test = torchvision.datasets.MNIST(root=path, train=False, download=False)test_data = mnist_test.data.view([-1, 784]).float()/255.test_labels = mnist_test.targets# translate the numeric label [0-9] into a vector label[0, ..., 1, ..0, 0]train_labelvec = torch.zeros([10, 60000]).int()train_labelvec[train_labels.cpu().numpy(), np.arange(60000)] = 1test_labelvec = torch.zeros([10, 10000]).int()test_labelvec[test_labels.cpu().numpy(), np.arange(10000)] = 1return train_data, train_labelvec, test_data, test_labelvec

通过调用该函数可以获取数据:

"""Load data and preprocessDataTrain:Tensor(60000,784) LabelsTrain:Tensor(10,60000)DataTest:Tensor(10000,784)  LabelsTest:Tensor(10,10000)"""DataTrain, LabelsTrain, DataTest, LabelsTest = data_loader(args.datapath)

http://chatgpt.dhexx.cn/article/aSCf0lcb.shtml

相关文章

面试官: 你知道 JWT、JWE、JWS 、JWK嘛?

想起了 之前做过的 很多 登录授权 的项目 它相比原先的session、cookie来说&#xff0c;更快更安全&#xff0c;跨域也不再是问题&#xff0c;更关键的是更加优雅 &#xff0c;所以今天总结了一篇文章来介绍他 JWT 指JSON Web Token&#xff0c;如果在项目中通过 jjwt 来支持 J…

java jwe/jws_一篇文章带你分清楚JWT,JWS与JWE

随着移动互联网的兴起&#xff0c;传统基于session/cookie的web网站认证方式转变为了基于OAuth2等开放授权协议的单点登录模式(SSO)&#xff0c;相应的基于服务器session浏览器cookie的Auth手段也发生了转变&#xff0c;Json Web Token出现成为了当前的热门的Token Auth机制。 …

JWS实现WebService

WebService估计大家都有听过或者使用过。Java有几种常用的方式实现webservice&#xff0c;本文主要是讨论JWS实现。 什么是webservice 简单而言&#xff0c;webservice就是通过SOAP协议在Web上提供的服务&#xff0c;使用WSDL文件进行说明。其特点是走SOAP协议而不是http协议&…

WebService 理论详解、JWS(Java Web Service) 快速入门

目录 WebService (web服务)概述 WebService 平台技术 WebService 工作原理 WebService 开发流程 常见 Web Service 框架 JWS(Java Web Service) 概述 JWS(Java Web Service) 快速入门 WebService (web服务)概述 1、WebService&#xff08;Web服务&#xff09;是一种跨语…

一文理解 JWT、JWS、JWE、JWA、JWK、JOSE

原文收录 GitBook——统一接口认证解决方案 JsonWebToken 关于JsonWebToken的专业名词解释&#xff1a; unsecured JWT&#xff1a;默认头部{“alg”: “none”}的jwt令牌JWS(SignedJWT)&#xff1a;已签名的jwt,包含标准jwt结构&#xff1a;header、payload、signatureJWE…

JWS入门

JWS简介 JWS主要用来通过网络部署你的应用程序&#xff0c;它具有安全、稳定、易维护、易使用的特点。用户访问用JWS部署应用程序的站点&#xff0c;下载发布的应用程序&#xff0c;既可以在 线运行&#xff0c;也可以通过JWS的客户端离线运行已下载的应用程序。对同一个应用程…

【C语言】判断一个数是否是完全平方数(两种解法)

题目&#xff1a; 判断一个数是否是完全平方数。 以下数字为完全平方数&#xff1a;42*2,93*3,14412*12,16913*13 有两个方法&#xff0c;可以求完全平方数&#xff1a; 方法一&#xff1a;输入一个数&#xff0c;遍历所有比这个数小的数&#xff0c;只要有其中一个数满足条件…

C语言 输入10个数,将其中最小的数与第一个数对换,将最大的数与最后一个数对换

#include <stdio.h> void input(int *number){ //定义输入10个数的函数int i;printf("请输入10个整数:\n");for(i0;i<10;i)scanf("%d",&number[i]); } void max_min_value(int *number){ //交换函数int *max,*min,*p,temp;maxminnumber; //开…

C语言判断一个数是奇数还是偶数

#include <stdio.h> void main() { int n; scanf("%d",&n); //运用scanf函数可以输入想要的数字 //也可以采用int n&#xff08;取一个数&#xff09;进行运算 if(n%20)//if函数注意&#xff0c;%是取余 printf("%d是一个偶…

python判断三位数水仙花数_Python如何判断一个数字是否为水仙花数

水仙花数是一个三位数,并且每一位数字的三次方的和还等于这个数字。 下面我们来看一下如何用Python判断这个数字是否为水仙花数 工具/原料 电脑 Python开发工具 方法/步骤 1 创建一个变量s,用input代码和用户交互,代码如下: s = input("请输入一个数字:"…

c语言判断一个数是否是素数

1&#xff1a;什么是素数 素数就是一个数只能被1和他本身整除的数我们称之为素数。例如13&#xff0c;17&#xff0c;19一类的数。 2&#xff1a;求出一个数是否是素数的思路 素数是只能被1和本身整除的数&#xff0c;那么如果设这个数为n&#xff0c;那么它就不能被2~n-1整…

python用函数判断一个数是否为素数,python分享是否为素数 python输入并判断一个数是否为素数...

python输入并判断一个数是否为素数 x=int(input("x\n")); i=2; for i in range(2,x+1): if(x%i==0): break;if(i==x and i。 用python 判断一个数是否是素数 小编觉得小编的程序是对的但为什么没办法运行,那个弹出来的窗口是啥意思小编曾千万次的请分享:不要逼小编…

python中判断一个数是否为素数_怎么用python判断一个数是否是素数

先来看下什么是质数&#xff1a; 质数(Prime number)&#xff0c;又称素数&#xff0c;指在大于1的自然数中&#xff0c;除了1和该数自身外&#xff0c;无法被其他自然数整除的数(也可定义为只有1与该数本身两个因数的数)。 简单来说就是&#xff0c;只能除以1和自身的数(需要大…

取到一个数的各个位的方法

计算方式如下&#xff1a; 个位&#xff1a;用这个数除以1对10取余&#xff0c;num / 1 % 10; 因为1除以&#xff08;除了0以外&#xff09;任何数都等于这个数的倒数&#xff0c;所以计算个位可以直接对10取余&#xff08;num%10&#xff09; 来获得。 十位&#xff1a;除以…

得到一个数每一位数字的几种方法

1.&#xff08;最简单暴力&#xff09;直接将数字转换为字符串&#xff0c;然后转换为字符数组输出。 int n12345;char[] charsString.valueOf(n).toCharArray();for(int j0;j<chars.length;j){System.out.print(chars[j]" ");}2.整除法。 int n12345;List<Int…

Html5超链接重置为link状态,去除a标签下划线 html超链接更改颜色和去掉下划线

去掉a标签下划线&#xff1a; 对超链接下划线设置 使用代码"text-decoration" 语法&#xff1a; text-decoration : none || underline || blink || overline || line-through text-decoration参数&#xff1a; none : 无装饰 blink : 闪烁 underline : 下划线 line-…

html5 a标签去下划线,css中如何去掉a标签的下划线?

我们在HTML网页制作过程中&#xff0c;相信大家对css文本超链接这个概念并不陌生。我们都知道想要给某段文本或者指定元素添加一个锚点也就是超链接需要用到HTML中的a标签。 那么有的新手可能就会发现&#xff0c;在使用a标签时文本超链接会自动出现下划线&#xff01;从视觉美…

css中怎么消除a的下划线,如何使用css去掉a标签的下划线?(代码详解)

写html超链接的时候&#xff0c;超链接总是自带下划线&#xff0c;如果不需要下划线&#xff0c;我们需要将其去掉&#xff0c;下面我们就来说一下怎么去掉下划线。 我们在使用超链接的时候&#xff0c;下划线总是伴随着出现&#xff0c;从视觉上来说有着下划线的a标签总是感觉…

MySQL数据库常用命令

活动地址&#xff1a;CSDN21天学习挑战赛 1.对数据库常用命令 1.连接数据库 mysql -u用户名 -p密码 2.显示已有数据库 show databases; 3.创建数据库 create database sqlname; 4.选择数据库 use database sqlname; 5.显示数据库中的表&#xff08;先选择数据库&#xff09; sh…

mysql命令更新数据库_命令操作MySQL数据库

一、连接MYSQL 格式: mysql -h主机地址 -u用户名 -p用户密码 1、 连接到本机上的MYSQL。 首先打开DOS窗口,然后进入目录mysql\bin,再键入命令mysql -u root -p,回车后提示你输密码.注意用户名前可以有空格也可以没有空格,但是密码前必须没有空格,否则让你重新输入密码. 如…