torchvision中datasets.MNIST介绍

article/2025/9/22 19:57:46

用法介绍

torchvisiondatasets中所有封装的数据集都是torch.utils.data.Dataset的子类,它们都实现了__getitem____len__方法。因此,它们都可以用torch.utils.data.DataLoader进行数据加载。以datasets.MNIST类为例,具体参数和用法如下所示:

CLASS torchvision.datasets.MNIST(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)

  • root (string): 表示数据集的根目录,其中根目录存在MNIST/processed/training.ptMNIST/processed/test.pt的子目录
  • train (bool, optional): 如果为True,则从training.pt创建数据集,否则从test.pt创建数据集
  • download (bool, optional): 如果为True,则从internet下载数据集并将其放入根目录。如果数据集已下载,则不会再次下载
  • transform (callable, optional): 接收PIL图片并返回转换后版本图片的转换函数
  • target_transform (callable, optional): 接收PIL接收目标并对其进行变换的转换函数

datasets.MNIST(“mnist-data”)下载mnist数据集之后生成的目录结构如下所示

代码实例

以下是用datasets.MNIST()类加载mnist数据集,并进行训练的完整代码

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as Fclass CNN(nn.Module):def __init__(self):super().__init__()self.Sq1 = nn.Sequential(         nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),   # (16, 28, 28)                           #  output: (16, 28, 28)nn.ReLU(),                    nn.MaxPool2d(kernel_size=2),    # (16, 14, 14))self.Sq2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),  # (32, 14, 14)nn.ReLU(),                      nn.MaxPool2d(2),                # (32, 7, 7))self.out = nn.Linear(32 * 7 * 7, 10)   def forward(self, x):x = self.Sq1(x)x = self.Sq2(x)x = x.view(x.size(0), -1)          output = self.out(x)return outputdef train():epoches = 1mnist_net = CNN()mnist_net.train()loss_fn = nn.CrossEntropyLoss()opitimizer = optim.SGD(mnist_net.parameters(), lr=0.002)mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= 5, shuffle=True)loss = 0 for epoch in range(epoches):for batch_X, batch_Y in train_loader:opitimizer.zero_grad()outputs = mnist_net(batch_X)loss = loss_fn(outputs, batch_Y)loss.backward()opitimizer.step()loss += loss.item()print(loss)if __name__ == '__main__':train()

http://chatgpt.dhexx.cn/article/Ghs6YrJE.shtml

相关文章

万物皆用MNIST---MNIST数据集及创建自己的手写数字数据集

刚刚接触到人工智能的我们,必定会遇到一个非常非常非常熟悉的朋友------MNIST 这是一套流行的手写数字图片,常常被用来测试我们的思想和算法。这个数据集称为手写数字的MNIST数据库,从研究员Yann LeCun 的网站,可以得到这个…

Pytorch 之 MNIST 数据集实现

目录 1. 数据集介绍2. 代码2. 读代码(个人喜欢的顺序)2.1. 导入模块部分:2.2. Main 函数: 1. 数据集介绍 一般而言,MNIST 数据集测试就是机器学习和深度学习当中的"Hello World"工程。几乎是所…

MNIST数据集手写数字识别(CNN)

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

MNIST数据集详解及可视化处理(pytorch)

MNIST数据集详解及可视化处理(pytorch) MNIST 数据集已经是一个被”嚼烂”了的数据集, 作为机器学习在视觉领域的“hello world”,很多教程都会对它”下手”, 几乎成为一个 “典范”。 不过有些人可能对它还不是很了解, 下面来介绍一下。 MN…

Mnist数据集介绍

Mnist数据集已经是一个被"嚼烂"了的数据集了,很多关于神经网络的教程都会对它下手。因此在开始深度学习之前,先对这个数据集介绍一下。 Mnist数据集图片格式介绍 Mnist数据集分为两部分,分别含有60000张训练图片和10000张测试图片…

使用MNIST数据集

首先,必须向各位强调的是:该数据集名字叫MNIST,而非MINIST~ 我之前就一直弄错了! 哈哈~ 网上有很多使用MNIST数据集的教程,要么太麻烦,要么需要翻墙下载,很慢。 在这里分…

Fashion MNIST进行分类

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

MNIST数据集简介与使用

MNIST数据集简介 MNIST数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。训练集(training set)由来自250个不同人手写的数字构成,其中50%是高中学生,50%来自人口普查局(t…

详解 MNIST 数据集

MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范”. 不过有些人可能对它还不是很了解, 下面来介绍一下. MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分: Training set images: train-images-idx3-…

Mnist数据集简介

MNIST是一个手写体数字的图片数据集,该数据集来由美国国家标准与技术研究所(National Institute of Standards and Technology (NIST))发起整理,一共统计了来自250个不同的人手写数字图片,其中50%是高中生,…

[转]MNIST机器学习入门

MNIST机器学习入门 转自:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html?plg_nld1&plg_uin1&plg_auth1&plg_nld1&plg_usr1&plg_vkey1&plg_dev1 这个教程的目标读者是对机器学习和TensorFlow都不太了解的新手。如…

从手写数字识别入门深度学习丨MNIST数据集详解

就像无数人从敲下“Hello World”开始代码之旅一样,许多研究员从“MNIST数据集”开启了人工智能的探索之路。 MNIST数据集(Mixed National Institute of Standards and Technology database)是一个用来训练各种图像处理系统的二进制图像数据…

Pytorch入门--详解Mnist手写字识别

1 什么是Mnist? Mnist是计算机视觉领域中最为基础的一个数据集。 MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10…

MNIST数据集

一、MNIST数据集介绍 MNIST数据集是NIST(National Institute of Standards and Technology,美国国家标准与技术研究所)数据集的一个子集,MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取,主要包括四个文件&…

面试官: 你知道 JWT、JWE、JWS 、JWK嘛?

想起了 之前做过的 很多 登录授权 的项目 它相比原先的session、cookie来说,更快更安全,跨域也不再是问题,更关键的是更加优雅 ,所以今天总结了一篇文章来介绍他 JWT 指JSON Web Token,如果在项目中通过 jjwt 来支持 J…

java jwe/jws_一篇文章带你分清楚JWT,JWS与JWE

随着移动互联网的兴起,传统基于session/cookie的web网站认证方式转变为了基于OAuth2等开放授权协议的单点登录模式(SSO),相应的基于服务器session浏览器cookie的Auth手段也发生了转变,Json Web Token出现成为了当前的热门的Token Auth机制。 …

JWS实现WebService

WebService估计大家都有听过或者使用过。Java有几种常用的方式实现webservice,本文主要是讨论JWS实现。 什么是webservice 简单而言,webservice就是通过SOAP协议在Web上提供的服务,使用WSDL文件进行说明。其特点是走SOAP协议而不是http协议&…

WebService 理论详解、JWS(Java Web Service) 快速入门

目录 WebService (web服务)概述 WebService 平台技术 WebService 工作原理 WebService 开发流程 常见 Web Service 框架 JWS(Java Web Service) 概述 JWS(Java Web Service) 快速入门 WebService (web服务)概述 1、WebService(Web服务)是一种跨语…

一文理解 JWT、JWS、JWE、JWA、JWK、JOSE

原文收录 GitBook——统一接口认证解决方案 JsonWebToken 关于JsonWebToken的专业名词解释: unsecured JWT:默认头部{“alg”: “none”}的jwt令牌JWS(SignedJWT):已签名的jwt,包含标准jwt结构:header、payload、signatureJWE…

JWS入门

JWS简介 JWS主要用来通过网络部署你的应用程序,它具有安全、稳定、易维护、易使用的特点。用户访问用JWS部署应用程序的站点,下载发布的应用程序,既可以在 线运行,也可以通过JWS的客户端离线运行已下载的应用程序。对同一个应用程…