Pytorch入门--详解Mnist手写字识别

article/2025/9/22 20:26:59

1 什么是Mnist?

        Mnist是计算机视觉领域中最为基础的一个数据集。

        MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10,000个样本的测试集。Mnist中所有样本都会将原本28*28的灰度图转换为长度为784的一维向量作为输入,其中每个元素分别对应了灰度图中的灰度值。Mnist使用一个长度为10的向量作为该样本所对应的标签,其中向量索引值对应了该样本以该索引为结果的预测概率。
 

2、代码实现

需导入的python库

import torch
import scipy.misc
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torch import optimimport torch.nn as nn
import torch.nn.functional as F

构建模型

# 构建模型(简单的卷积神经网络)
class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size =5, padding = 2) # 卷积self.conv2 = nn.Conv2d(6, 16, 5)# Linear(in_feactures(输入的二维张量大小), out_feactures)self.fc1   = nn.Linear(16*5*5, 120) # 全连接self.fc2   = nn.Linear(120, 84)self.fc3   = nn.Linear(84, 10) # 最后输出10个类def forward(self, x):# 激活函数out = F.relu(self.conv1(x))# max_pool2d(input, kernel_size(卷积核), stride(卷积核步长)=None, padding=0, dilation=1, ceil_mode(空间输入形状)=False, return_indices=False)out = F.max_pool2d(out, kernel_size = 2) # 池化out = F.relu(self.conv2(out))out = F.max_pool2d(out, 2)# 将多维的的数据平铺为一维out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out

训练集

def train():# 学习率0.001learning_rate = 1e-3# 单次大小batch_size = 100# 总的循环epoches = 50lenet = LeNet()# 1、数据集准备# 这个函数包括了两个操作:transforms.ToTensor()将图片转换为张量,transforms.Normalize()将图片进行归一化处理trans_img = transforms.Compose([transforms.ToTensor()])# path = './data/'数据集下载后保存的目录,下载训练集trainset = MNIST('./data', train=True, transform=trans_img, download=True)# 构建数据集的DataLoader,# Pytorch自提供了DataLoader的方法来进行训练,该方法自动将数据集打包成为迭代器,能够让我们很方便地进行后续的训练处理# 迭代器(iterable)是一个超级接口! 是可以遍历集合的对象,trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=10)# 2、构建迭代器与损失函数criterian = nn.CrossEntropyLoss(reduction='sum')  # loss(损失函数)optimizer = optim.SGD(lenet.parameters(), lr=learning_rate)  # optimizer(迭代器)# 如果网络能在GPU中训练,就使用GPU;否则使用CPU进行训练device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#lenet.to("cpu")# 3、训练for i in range(epoches):running_loss = 0.running_acc = 0.for (img, label) in trainloader:  # 将图像和标签传输进device中optimizer.zero_grad()  # 求梯度之前对梯度清零以防梯度累加output=lenet(img)  # 对模型进行前向推理loss=criterian(output,label)  # 计算本轮推理的Loss值loss.backward()    # loss反传存到相应的变量结构当中optimizer.step()   # 使用计算好的梯度对参数进行更新running_loss+=loss.item()#print(output)_,predict=torch.max(output,1)  # 计算本轮推理的准确率correct_num=(predict==label).sum()running_acc+=correct_num.item()running_loss/=len(trainset)running_acc/=len(trainset)print("[%d/%d] Loss: %.5f, Acc: %.2f" % (i + 1, epoches, running_loss,100 * running_acc))return lenet

 测试集

def test(lenet):batch_size = 100trans_img = transforms.Compose([transforms.ToTensor()])testset = MNIST('./data', train=False, transform=trans_img, download=True)testloader = DataLoader(testset, batch_size, shuffle=False, num_workers=10)running_acc = 0.for (img, label) in testloader:output = lenet(img)_, predict = torch.max(output, 1)correct_num = (predict == label).sum()running_acc += correct_num.item()running_acc /= len(testset)return running_acc

主函数

if __name__ == '__main__':lenet = train()torch.save(lenet, 'lenet.pkl') # save modellenet = torch.load('lenet.pkl') # load modeltest_acc = test(lenet)print("Test Accuracy:Loss: %.2f" % test_acc)

结果:

 继上面对minst手写数字集进行训练和测试的完成,现将黑底白字的0~9的数字图片,进行识别

识别函数代码如下

def practice(img_path):img = Image.open(img_path)img = img.convert('L')prac_img = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor()])pracset = MNIST('./data', train=True, transform=prac_img, download=True)img = prac_img(img)img = torch.reshape(img, (1, 1, 28, 28))lenet = torch.load('lenet.pkl')  # load modeloutput = lenet(img)output = output.argmax(1)dict_target = pracset.class_to_idxdict_target = [indx for indx, vale in dict_target.items()]  # 获得标签字典print('识别类型为{}'.format(dict_target[output]))

主函数调用(图片路径按自身修改)

if __name__ == '__main__':practice('0.jpg')practice('1.jpg')practice('2.jpg')practice('3.jpg')practice('4.jpg')practice('5.jpg')practice('6.jpg')practice('7.jpg')practice('8.jpg')practice('9.jpg')

识别结果如下所示:


http://chatgpt.dhexx.cn/article/a78hYlgo.shtml

相关文章

MNIST数据集

一、MNIST数据集介绍 MNIST数据集是NIST(National Institute of Standards and Technology,美国国家标准与技术研究所)数据集的一个子集,MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取,主要包括四个文件&…

面试官: 你知道 JWT、JWE、JWS 、JWK嘛?

想起了 之前做过的 很多 登录授权 的项目 它相比原先的session、cookie来说,更快更安全,跨域也不再是问题,更关键的是更加优雅 ,所以今天总结了一篇文章来介绍他 JWT 指JSON Web Token,如果在项目中通过 jjwt 来支持 J…

java jwe/jws_一篇文章带你分清楚JWT,JWS与JWE

随着移动互联网的兴起,传统基于session/cookie的web网站认证方式转变为了基于OAuth2等开放授权协议的单点登录模式(SSO),相应的基于服务器session浏览器cookie的Auth手段也发生了转变,Json Web Token出现成为了当前的热门的Token Auth机制。 …

JWS实现WebService

WebService估计大家都有听过或者使用过。Java有几种常用的方式实现webservice,本文主要是讨论JWS实现。 什么是webservice 简单而言,webservice就是通过SOAP协议在Web上提供的服务,使用WSDL文件进行说明。其特点是走SOAP协议而不是http协议&…

WebService 理论详解、JWS(Java Web Service) 快速入门

目录 WebService (web服务)概述 WebService 平台技术 WebService 工作原理 WebService 开发流程 常见 Web Service 框架 JWS(Java Web Service) 概述 JWS(Java Web Service) 快速入门 WebService (web服务)概述 1、WebService(Web服务)是一种跨语…

一文理解 JWT、JWS、JWE、JWA、JWK、JOSE

原文收录 GitBook——统一接口认证解决方案 JsonWebToken 关于JsonWebToken的专业名词解释: unsecured JWT:默认头部{“alg”: “none”}的jwt令牌JWS(SignedJWT):已签名的jwt,包含标准jwt结构:header、payload、signatureJWE…

JWS入门

JWS简介 JWS主要用来通过网络部署你的应用程序,它具有安全、稳定、易维护、易使用的特点。用户访问用JWS部署应用程序的站点,下载发布的应用程序,既可以在 线运行,也可以通过JWS的客户端离线运行已下载的应用程序。对同一个应用程…

【C语言】判断一个数是否是完全平方数(两种解法)

题目: 判断一个数是否是完全平方数。 以下数字为完全平方数:42*2,93*3,14412*12,16913*13 有两个方法,可以求完全平方数: 方法一:输入一个数,遍历所有比这个数小的数,只要有其中一个数满足条件…

C语言 输入10个数,将其中最小的数与第一个数对换,将最大的数与最后一个数对换

#include <stdio.h> void input(int *number){ //定义输入10个数的函数int i;printf("请输入10个整数:\n");for(i0;i<10;i)scanf("%d",&number[i]); } void max_min_value(int *number){ //交换函数int *max,*min,*p,temp;maxminnumber; //开…

C语言判断一个数是奇数还是偶数

#include <stdio.h> void main() { int n; scanf("%d",&n); //运用scanf函数可以输入想要的数字 //也可以采用int n&#xff08;取一个数&#xff09;进行运算 if(n%20)//if函数注意&#xff0c;%是取余 printf("%d是一个偶…

python判断三位数水仙花数_Python如何判断一个数字是否为水仙花数

水仙花数是一个三位数,并且每一位数字的三次方的和还等于这个数字。 下面我们来看一下如何用Python判断这个数字是否为水仙花数 工具/原料 电脑 Python开发工具 方法/步骤 1 创建一个变量s,用input代码和用户交互,代码如下: s = input("请输入一个数字:"…

c语言判断一个数是否是素数

1&#xff1a;什么是素数 素数就是一个数只能被1和他本身整除的数我们称之为素数。例如13&#xff0c;17&#xff0c;19一类的数。 2&#xff1a;求出一个数是否是素数的思路 素数是只能被1和本身整除的数&#xff0c;那么如果设这个数为n&#xff0c;那么它就不能被2~n-1整…

python用函数判断一个数是否为素数,python分享是否为素数 python输入并判断一个数是否为素数...

python输入并判断一个数是否为素数 x=int(input("x\n")); i=2; for i in range(2,x+1): if(x%i==0): break;if(i==x and i。 用python 判断一个数是否是素数 小编觉得小编的程序是对的但为什么没办法运行,那个弹出来的窗口是啥意思小编曾千万次的请分享:不要逼小编…

python中判断一个数是否为素数_怎么用python判断一个数是否是素数

先来看下什么是质数&#xff1a; 质数(Prime number)&#xff0c;又称素数&#xff0c;指在大于1的自然数中&#xff0c;除了1和该数自身外&#xff0c;无法被其他自然数整除的数(也可定义为只有1与该数本身两个因数的数)。 简单来说就是&#xff0c;只能除以1和自身的数(需要大…

取到一个数的各个位的方法

计算方式如下&#xff1a; 个位&#xff1a;用这个数除以1对10取余&#xff0c;num / 1 % 10; 因为1除以&#xff08;除了0以外&#xff09;任何数都等于这个数的倒数&#xff0c;所以计算个位可以直接对10取余&#xff08;num%10&#xff09; 来获得。 十位&#xff1a;除以…

得到一个数每一位数字的几种方法

1.&#xff08;最简单暴力&#xff09;直接将数字转换为字符串&#xff0c;然后转换为字符数组输出。 int n12345;char[] charsString.valueOf(n).toCharArray();for(int j0;j<chars.length;j){System.out.print(chars[j]" ");}2.整除法。 int n12345;List<Int…

Html5超链接重置为link状态,去除a标签下划线 html超链接更改颜色和去掉下划线

去掉a标签下划线&#xff1a; 对超链接下划线设置 使用代码"text-decoration" 语法&#xff1a; text-decoration : none || underline || blink || overline || line-through text-decoration参数&#xff1a; none : 无装饰 blink : 闪烁 underline : 下划线 line-…

html5 a标签去下划线,css中如何去掉a标签的下划线?

我们在HTML网页制作过程中&#xff0c;相信大家对css文本超链接这个概念并不陌生。我们都知道想要给某段文本或者指定元素添加一个锚点也就是超链接需要用到HTML中的a标签。 那么有的新手可能就会发现&#xff0c;在使用a标签时文本超链接会自动出现下划线&#xff01;从视觉美…

css中怎么消除a的下划线,如何使用css去掉a标签的下划线?(代码详解)

写html超链接的时候&#xff0c;超链接总是自带下划线&#xff0c;如果不需要下划线&#xff0c;我们需要将其去掉&#xff0c;下面我们就来说一下怎么去掉下划线。 我们在使用超链接的时候&#xff0c;下划线总是伴随着出现&#xff0c;从视觉上来说有着下划线的a标签总是感觉…

MySQL数据库常用命令

活动地址&#xff1a;CSDN21天学习挑战赛 1.对数据库常用命令 1.连接数据库 mysql -u用户名 -p密码 2.显示已有数据库 show databases; 3.创建数据库 create database sqlname; 4.选择数据库 use database sqlname; 5.显示数据库中的表&#xff08;先选择数据库&#xff09; sh…