pytorch 图片分类,python 图片分类,resnet18 图片分类

article/2025/11/7 13:44:59

pytorch 图片分类,python 图片分类,resnet18 图片分类,深度学习 图片分类

pytorch版本:1.5.0+cu101

全部源码,可以直接运行。

下载地址:https://download.csdn.net/download/TangLingBo/12598435

网络是用 resnet18 ,可以修改图片的大小,默认是32 x32 

如果出现需要下载的文件或者问题可以联系:QQ 1095788063

图片结构:

测试结果:

训练代码:

import torch as t
import torchvision as tv
import os
import time
import numpy as np
from tqdm import tqdm# 一些参数配置
class DefaultConfigs(object):data_dir = "./imageData/"  # 图片目录data_list = ["train", "test"]  # train=训练集,test=测试集lr = 0.001  # 学习率(默认值:1e-3epochs = 51  # 训练次,越多就越好num_classes = 10  # 分类image_size = 32  # 图片大小 ,可以改,因为用的是 resnet18 的网络,越大就越慢batch_size = 40  # 批量大小,看自己电脑的配置,需要占用 CPU或者GPU资源channels = 3  # 通道数use_gpu = t.cuda.is_available()  # 启用gpu,如果电脑不支持,直接设置为 False ,GPU 训练效果最好config = DefaultConfigs()
config.use_gpu = False  # 我的电脑不支持,设置为 False# 对Tensor进行变换 颜色转换   mean=给定均值:(R,G,B) std=方差:(R,G,B)
normalize = tv.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])# Train数据需要进行随机裁剪,Test数据不要进行裁剪了
transform = {# tv.transforms.Resize 用于重设图片大小   train  训练集数据# tv.transforms.CenterCrop([224,224])   将给定的PIL.Image进行中心切割config.data_list[0]: tv.transforms.Compose([tv.transforms.Resize([config.image_size, config.image_size]),tv.transforms.CenterCrop([config.image_size,config.image_size]),tv.transforms.ToTensor(), normalize]),# test 测试数据config.data_list[1]: tv.transforms.Compose([tv.transforms.Resize([config.image_size, config.image_size]),tv.transforms.ToTensor(),normalize])
}# 数据集
datasets = {x: tv.datasets.ImageFolder(root=os.path.join(config.data_dir, x), transform=transform[x])for x in config.data_list
}# 数据加载器
dataloader = {x: t.utils.data.DataLoader(dataset=datasets[x],batch_size=config.batch_size,shuffle=True)for x in config.data_list
}# 构建网络模型 resnet18
def get_model(num_classes):#resnet18 好像要下载什么的,忘记了,可以联系我model = tv.models.resnet18(pretrained=True)# 梯度什么的,电脑硬件支持,可以把下述代码屏蔽,则训练整个网络,最终准确率会上升,训练数据会变慢# for parma in model.parameters():#  parma.requires_grad = Falsemodel.fc = t.nn.Sequential(t.nn.Dropout(p=0.3), t.nn.Linear(512, num_classes))return model# 训练模型(支持自动GPU加速)
def train(epochs):model = get_model(config.num_classes)loss_f = t.nn.CrossEntropyLoss()# GPUif config.use_gpu:model = model.cuda()loss_f = loss_f.cuda()opt = t.optim.Adam(model.fc.parameters(), lr=config.lr)# 时间time_start = time.time()for epoch in range(epochs):train_loss = []train_acc = []test_loss = []test_acc = []model.train(True)  # 将模块设置为训练模式print("Epoch {}/{}".format(epoch + 1, epochs))for batch, datas in tqdm(enumerate(iter(dataloader["train"]))):x, y = datas# 开启GPU 加速if config.use_gpu:x, y = x.cuda(), y.cuda()y_ = model(x)# print(x.shape, y.shape, y_.shape)_, pre_y_ = t.max(y_, 1)pre_y = y# print(y_.shape)loss = loss_f(y_, pre_y)# print(y_.shape)acc = t.sum(pre_y_ == pre_y)loss.backward()opt.step()opt.zero_grad()if config.use_gpu:loss = loss.cpu()acc = acc.cpu()train_loss.append(loss.data)train_acc.append(acc)time_end = time.time()print("正式 批次 {}, Train 损失:{:.4f}, Train 准确率:{:.4f}, 训练时间: {}".format(batch + 1,np.mean(train_loss) / config.batch_size,np.mean(train_acc) / config.batch_size,(time_end - time_start)))model.train(False)  # 关闭训练模式for batch, datas in tqdm(enumerate(iter(dataloader["test"]))):x, y = datasif config.use_gpu:x, y = x.cuda(), y.cuda()y_ = model(x)# print(x.shape,y.shape,y_.shape)_, pre_y_ = t.max(y_, 1)pre_y = y# print(y_.shape)loss = loss_f(y_, pre_y)acc = t.sum(pre_y_ == pre_y)if config.use_gpu:loss = loss.cpu()acc = acc.cpu()test_loss.append(loss.data)test_acc.append(acc)print("测试 批次 {}, 损失:{:.4f}, 准确率:{:.4f}".format(batch + 1, np.mean(test_loss) / config.batch_size,np.mean(test_acc) / config.batch_size))t.save(model, 'model/' + str(epoch + 1) + "_ttmodel.pkl")  # 保存整个神经网络的结构和模型参数t.save(model.state_dict(), 'model/' + str(epoch + 1) + "_ttmodel_params.pkl")  # 只保存神经网络的模型参数print('训练结束')#开始训练
if __name__ == "__main__":train(config.epochs)

调用代码:

import torch as t
import torchvision as tv
from PIL import Image
import matplotlib.pyplot as plt
from torch.autograd import Variable
import numpy as npbCuda = t.cuda.is_available()  # 是否开启 GPU
bCuda = False  # 不启用GPU 我的电脑不支持
device = t.device("cuda:0" if bCuda else "cpu")img_size = 32  # 图片大小,可以改# 对Tensor进行变换 颜色转换   mean=给定均值:(R,G,B) std=方差:(R,G,B)
normalize = tv.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = tv.transforms.Compose([tv.transforms.Resize([img_size, img_size]), tv.transforms.CenterCrop([img_size, img_size]),tv.transforms.ToTensor(), normalize])# 分类数组
classes = ['凹下标志-0', '凸上标志-1', '打滑标志-2', '左弯标志-3', '右弯标志-4', '连续转弯标志-5', '00020-6', '00021-7', '00022-8', '00023-9']# 显示图片方法
def imshow(img):plt.imshow(img)plt.show()# 单张图片调用
def prediect(model, img_path, imgType, isShowSoftmax=False, isShowImg=False):t.no_grad()image_PIL = Image.open(img_path)# imshow(image_PIL)image_tensor = transform(image_PIL)# 以下语句等效于 img = torch.unsqueeze(image_tensor, 0)image_tensor.unsqueeze_(0)# 没有这句话会报错image_tensor = image_tensor.to(device)out = model(image_tensor)# 得到预测结果,并且从大到小排序_, indices = t.sort(out, descending=True)# 返回每个预测值的百分数percentage = t.nn.functional.softmax(out, dim=1)[0] * 100# 是否显示每个分类的预测值item = indices[0]if isShowSoftmax:for idx in item:ss = percentage[idx]value = ss.item();name = classes[idx]print('名称:', name, '预测值:', value)# 预测最大值_, predicted = t.max(out.data, 1)maxPredicted = classes[predicted.item()]maxAccuracy = percentage[item[0]].item()if imgType == maxPredicted:print('预测正确,预测结果:', maxPredicted, '预测值:', maxAccuracy)else:print('预测错误,正确结果:', imgType, ',预测结果:', maxPredicted, '预测值:', maxAccuracy, '图片:', img_path)if isShowImg:plt.imshow(image_PIL)plt.show()# 构建网络模型 resnet18
def get_model(num_classes):# resnet18 好像要下载什么的,忘记了,可以联系我model = tv.models.resnet18(pretrained=True)# 梯度什么的,电脑硬件支持,可以把下述代码屏蔽,则训练整个网络,最终准确率会上升,训练数据会变慢# for parma in model.parameters():#  parma.requires_grad = Falsemodel.fc = t.nn.Sequential(t.nn.Dropout(p=0.3), t.nn.Linear(512, num_classes))return model# 测试集
def loadtestdata():path = "./imageData/test/"testset = tv.datasets.ImageFolder(path, transform=transform)testloader = t.utils.data.DataLoader(testset, batch_size=40, shuffle=True, num_workers=6)return testloader# 测试全部
def testAll(model):testloader = loadtestdata()dataiter = iter(testloader)images, labels = dataiter.next()print(labels)print('真实值: ', " ".join('%5s' % classes[labels[j]] for j in range(25)))  # 打印前25个GT(test集里图片的标签)outputs = model(Variable(images))_, predicted = t.max(outputs.data, 1)print('预测值: ', " ".join('%5s' % classes[predicted[j]] for j in range(25)))# 打印前25个预测值imshow2(tv.utils.make_grid(images, nrow=5))  # nrow是每行显示的图片数量,缺省值为8def imshow2(img):img = img / 2 + 0.5  # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()if __name__ == '__main__':# 直接加载model = t.load('model/51_ttmodel.pkl')# 加载2 ,看官方的解释# model = get_model(classes.__len__())  # 10 分类数量# load_weights = t.load('model/51_ttmodel_params.pkl', map_location='cpu')# model.load_state_dict(load_weights)model = model.to(device)  # GPUmodel.eval()  # 运行模式# 测试全部图片testAll(model)# 测试一张图片# # 凹下标志-0# prediect(model,'imageData/test/00000/01160_00000.png', classes[0], False, False)# prediect(model,'imageData/test/00000/01160_00001.png', classes[0], False, False)# prediect(model,'imageData/test/00000/01160_00002.png', classes[0], False, False)# prediect(model,'imageData/test/00000/01798_00000.png', classes[0], False, False)# prediect(model,'imageData/test/00000/01798_00001.png', classes[0], False, False)# prediect(model,'imageData/test/00000/01798_00002.png', classes[0], False, False)## # 凸上标志-1# prediect(model,'imageData/test/00001/00029_00000.png', classes[1], False, False)# prediect(model,'imageData/test/00001/00029_00001.png', classes[1], False, False)# prediect(model,'imageData/test/00001/00029_00002.png', classes[1], False, False)# prediect(model,'imageData/test/00001/00079_00000.png', classes[1], False, False)# prediect(model,'imageData/test/00001/00079_00002.png', classes[1], False, False)# prediect(model,'imageData/test/00001/00079_00001.png', classes[1], False, False)## # 打滑标志-2# prediect(model,'imageData/test/00002/01503_00000.png', classes[2], False, False)# prediect(model,'imageData/test/00002/01503_00001.png', classes[2], False, False)# prediect(model,'imageData/test/00002/01503_00002.png', classes[2], False, False)# prediect(model,'imageData/test/00002/01515_00000.png', classes[2], False, False)# prediect(model,'imageData/test/00002/01515_00001.png', classes[2], False, False)# prediect(model,'imageData/test/00002/01515_00002.png', classes[2], False, False)## # 左弯标志-3# prediect(model,'imageData/test/00003/00207_00000.png', classes[3], False, False)# prediect(model,'imageData/test/00003/00207_00001.png', classes[3], False, False)# prediect(model,'imageData/test/00003/00207_00002.png', classes[3], False, False)# prediect(model,'imageData/test/00003/00211_00000.png', classes[3], False, False)# prediect(model,'imageData/test/00003/00211_00001.png', classes[3], False, False)# prediect(model,'imageData/test/00003/00211_00002.png', classes[3], False, False)# prediect(model,'imageData/test/00003/02664_00000.png', classes[3], False, False)# prediect(model,'imageData/test/00003/02664_00001.png', classes[3], False, False)# prediect(model,'imageData/test/00003/02664_00002.png', classes[3], False, False)## # 右弯标志-4# prediect(model,'imageData/test/00004/00214_00000.png', classes[4], False, False)# prediect(model,'imageData/test/00004/00214_00001.png', classes[4], False, False)# prediect(model,'imageData/test/00004/00214_00002.png', classes[4], False, False)# prediect(model,'imageData/test/00004/00282_00000.png', classes[4], False, False)# prediect(model,'imageData/test/00004/00282_00001.png', classes[4], False, False)# prediect(model,'imageData/test/00004/00282_00002.png', classes[4], False, False)# prediect(model,'imageData/test/00004/02567_00000.png', classes[4], False, False)# prediect(model,'imageData/test/00004/02567_00001.png', classes[4], False, False)# prediect(model,'imageData/test/00004/02567_00002.png', classes[4], False, False)# prediect(model,'imageData/test/00004/02660_00000.png', classes[4], False, False)# prediect(model,'imageData/test/00004/02660_00001.png', classes[4], False, False)# prediect(model,'imageData/test/00004/02660_00002.png', classes[4], False, False)## # 连续转弯标志-5# prediect(model,'imageData/test/00005/00575_00000.png', classes[5], False, False)# prediect(model,'imageData/test/00005/00575_00001.png', classes[5], False, False)# prediect(model,'imageData/test/00005/00575_00002.png', classes[5], False, False)# prediect(model,'imageData/test/00005/01893_00000.png', classes[5], False, False)# prediect(model,'imageData/test/00005/01893_00001.png', classes[5], False, False)# prediect(model,'imageData/test/00005/01893_00002.png', classes[5], False, False)# prediect(model,'imageData/test/00005/02225_00000.png', classes[5], False, False)# prediect(model,'imageData/test/00005/02225_00001.png', classes[5], False, False)# prediect(model,'imageData/test/00005/02225_00002.png', classes[5], False, False)### # 00020-6# prediect(model,'imageData/test/00020/00230_00000.png', classes[6], False, False)# prediect(model, 'imageData/test/00020/00230_00001.png', classes[6], True, True)# prediect(model,'imageData/test/00020/00230_00002.png', classes[6], False, False)# prediect(model,'imageData/test/00020/00231_00000.png', classes[6], False, False)# prediect(model,'imageData/test/00020/00231_00001.png', classes[6], False, False)# prediect(model,'imageData/test/00020/00231_00002.png', classes[6], False, False)## # 00021-7# prediect(model, 'imageData/test/00021/00375_00000.png', classes[7], False, False)# prediect(model, 'imageData/test/00021/00375_00001.png', classes[7], False, False)# prediect(model, 'imageData/test/00021/00375_00002.png', classes[7], False, False)# prediect(model, 'imageData/test/00021/00478_00000.png', classes[7], False, False)# prediect(model, 'imageData/test/00021/00478_00001.png', classes[7], False, False)# prediect(model, 'imageData/test/00021/00478_00002.png', classes[7], False, False)## # 00022-8# prediect(model, 'imageData/test/00022/00020_00000.png', classes[8], False, False)# prediect(model, 'imageData/test/00022/00020_00001.png', classes[8], False, False)# prediect(model, 'imageData/test/00022/00020_00002.png', classes[8], False, False)# prediect(model, 'imageData/test/00022/00048_00000.png', classes[8], False, False)# prediect(model, 'imageData/test/00022/00048_00001.png', classes[8], False, False)# prediect(model, 'imageData/test/00022/00048_00002.png', classes[8], False, False)## # 00023-9# prediect(model, 'imageData/test/00023/00465_00000.png', classes[9], False, False)# prediect(model, 'imageData/test/00023/00465_00001.png', classes[9], False, False)# prediect(model, 'imageData/test/00023/00465_00002.png', classes[9], False, False)# prediect(model, 'imageData/test/00023/00535_00000.png', classes[9], False, False)# prediect(model, 'imageData/test/00023/00535_00001.png', classes[9], False, False)# prediect(model, 'imageData/test/00023/00535_00002.png', classes[9], False, False)

 


http://chatgpt.dhexx.cn/article/159NPvXr.shtml

相关文章

深度学习图片分类实战学习

开始记录学习深度学习的点点滴滴 深度学习图片分类实战学习 前言一、深度学习二、使用步骤1. 自建网络模型2. 进行深度学习的学习迁移 注意事项 前言 随着人工智能的不断发展,这门技术也越来越重要,很多人都开启了学习人工智能,本人开始记录…

关于图片的多标签分类(1)

最近还在处理人脸附件(眼镜,刘海,口罩,帽子)的multi-label分类。给自己普及一下常识性问题: 1)什么是multi-label分类? multi-label分类,常见一张图片中可以存在多个目…

svm实现图片分类(python)

目录 前言 knn vs. svm svm & linear classifier bias trick loss function regularization optimization 代码主体 导入数据及预处理 svm计算loss_function和梯度 验证梯度公式是否正确 比较运行时间 svm训练及预测,结果可视化 通过corss-validat…

图片分类-python

目的:做一个简易的图片分类。 使用到的算法:hog、surfsvm 图片集:cifar-10、cifar-100、stl-10、自制图片集 分类完整代码链接 使用说明: 1.cifar-10、cifar-100和stl-10直接解压 2.自制图片集文件夹结构: ├…

CNN图片分类

最近在阅读一些AI项目,写入markdown,持续更新,算是之后也能回想起做法 项目 https://github.com/calssion/Fun_AI image classify(图片分类) CNN classify dogs and cats(猫狗二分类) Tutorial(教程):https://developers.google.com/mach…

深度学习之图像分类

第一篇CSDN文章,写的不好,还请各位大佬指正。万事开头难,千里之行始于足下! 1.什么是图像分类 图像分类,核心是从给定的分类集合中给图像分配一个标签的任务。实际上,这意味着我们的任务是分析一个输入图…

关于图像分类

https://www.zhihu.com/question/57075015/answer/194397802https://www.zhihu.com/question/57075015/answer/194397802 先定义一下图像分类,一般而言,图像分类分为通用类别分类以及细粒度图像分类 那什么是通用类别以及细粒度类别呢?这里…

(一)图像分类任务介绍 Image Classification

目录 一、什么是图像分类任务?它有哪些应用场景? 二、图像分类任务的难点? 三、基于规则的方法是否可行? 四、什么是数据驱动的图像分类范式? 数据集构建 分类器设计与学习 分类器决策 五、常用的分类任务评价指…

图像分类的数据集

图像分类的数据集 1. MNIST2. Fashion-MNIST3.CIFAR-10和CIFAR-1004. Caltech 1015. ImageNet5.1 ImageNet是什么?5.2 ILSVRC 6. 各个数据集上的最新进展其他参考资料 1. MNIST MNIST数据集的一个样例 一般机器学习框架都使用MNIST作为入门,就像"He…

机器学习——图像分类

1 图像分类的概念 1.1 什么是图像分类? 图像分类,根据图像信息中所反映出来的不同特征,把不同类别的目标区分开来的图像处理方法 1.2 图像分类的难度 ●任何拍摄情 况的改变都将提升分类的难度 1.3 CNN如何进行图像分类 ●数据驱动型方法通…

图像分类算法

图像分类 参考链接1.前言2.K近邻与KMeans算法比较KNN原理和实现过程(1) 计算已知类别数据集中的点与当前点之间的距离:(2) 按照距离递增次序排序(3) 选取与当前点距离最小的k个点(4) 确定前k个点所在类别的出现频率(5) 返回前k个点出现频率最高的类别作为当前点的预…

图像分类方法总结

1. 图像分类问题描述 图像分类问题是计算机视觉领域的基础问题,它的目的是根据图像的语义信息将不同类别图像区分开来,实现最小的分类误差。具体任务要求是从给定的分类集合中给图像分配一个标签的任务。总体来说,对于单标签的图像分类问题&…

9.图片分类数据集

1. 图像分类数据集 MNIST数据集 [LeCun et al., 1998] 是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。 我们将使用类似但更复杂的Fashion-MNIST数据集。 %matplotlib inline import torch import torchvision from torch.utils import data from t…

CNN实现花卉图片分类识别

CNN实现花卉图片分 前言 本文为一个利用卷积神经网络实现花卉分类的项目,因此不会过多介绍卷积神经网络的基本知识。此项目建立在了解卷积神经网络进行图像分类的原理上进行的。 项目简介 本项目为一个图像识别项目,基于tensorflow,利用C…

常用图像分类网络

想对图像分类网络写个简要的概括,如有介绍不当之处,还望指出。 一、VGG网络 更新于2018年10月20日 参考博客:深度学习经典卷积神经网络之VGGNet 论文地址:VERY DEEP CONVOLUTIONAL NETWORKS FOR LARGE-SCALE IMAGE RECOGNITIO…

干货——图像分类(上)

这是译自斯坦福CS231n课程笔记image classification notes,由课程教师Andrej Karpathy授权进行翻译。本篇教程由杜客翻译完成。非常感谢那些无偿奉献的大师,在此代表所有爱好学习者向您们致敬,谢谢! 这是斯坦福大学的课程&#xf…

图像分类

图像物体分类与检测算法综述 转自《计算机学报》 目录 图像物体分类与检测算法综述 目录图像物体分类与检测概述物体分类与检测的难点与挑战物体分类与检测数据库物体分类与检测发展历程 图像物体分类与检测是计算机视觉研究中的两个重要的基本问题,也是图像分割、…

【图像分类数据集】非常全面实用的垃圾分类图片数据集共享

【图像分类数据集】非常全面实用的垃圾分类图片数据集共享 数据集介绍: 训练集 文件夹结构如下(部分: 第0类文件夹下数据展示如下(部分: 测试集 大致如下: 数据集获取方式: 总结&#xf…

python学习(18)--图片分类

图片分类 学习动机. 在这一节中我们会引入图片分类为题。这也是从一个合适的集合中分配给图片一个标记的任务。这是计算机视觉的核心问题之一。鉴于它的简单性,有一大批实用应用。更多的是,我们可以在以后的章节中看到,一些看似分离的计算机…

【OpenMMLab】图片分类发展简史

一、发展简述 图片分类是CV领域的基础任务,也是检测、分割、追踪等任务的基石。简而言之,图片分类就是给定一张图片,判断其类别,一般而言所有的候选类别是预设的。 从数学上描述,图片分类就是寻找一个函数&#xff0…