pytorch 计算混淆矩阵

article/2025/8/25 0:01:10

混淆矩阵是评估模型结果的一种指标 用来判断分类模型的好坏

 预测对了 为对角线 

还可以通过矩阵的上下角发现哪些容易出错

从这个 矩阵出发 可以得到 acc != precision recall  特异度?

 

 目标检测01笔记AP mAP recall precision是什么 查全率是什么 查准率是什么 什么是准确率 什么是召回率_:)�东东要拼命的博客-CSDN博客

 acc  是对所有类别来说的

其他三个都是 对于类别来说的

下面给出源码 

import json
import osimport matplotlib.pyplot as plt
import numpy as np
import torch
from prettytable import PrettyTable
from torchvision import datasets
from torchvision.models import MobileNetV2
from torchvision.transforms import transformsclass ConfusionMatrix(object):"""注意版本问题,使用numpy来进行数值计算的"""def __init__(self, num_classes: int, labels: list):self.matrix = np.zeros((num_classes, num_classes))self.num_classes = num_classesself.labels = labelsdef update(self, preds, labels):for p, t in zip(preds, labels):self.matrix[t, p] += 1# 行代表预测标签 列表示真实标签def summary(self):# calculate accuracysum_TP = 0for i in range(self.num_classes):sum_TP += self.matrix[i, i]acc = sum_TP / np.sum(self.matrix)print("acc is", acc)# precision, recall, specificitytable = PrettyTable()table.fields_names = ["", "pre", "recall", "spec"]for i in range(self.num_classes):TP = self.matrix[i, i]FP = np.sum(self.matrix[i, :]) - TPFN = np.sum(self.matrix[:, i]) - TPTN = np.sum(self.matrix) - TP - FP - FNpre = round(TP / (TP + FP), 3)    # round 保留三位小数recall = round(TP / (TP + FN), 3)spec = round(TN / (FP + FN), 3)table.add_row([self.labels[i], pre, recall, spec])print(table)def plot(self):matrix = self.matrixprint(matrix)plt.imshow(matrix, cmap=plt.cm.Blues)  # 颜色变化从白色到蓝色# 设置 x  轴坐标 labelplt.xticks(range(self.num_classes), self.labels, rotation=45)# 将原来的 x 轴的数字替换成我们想要的信息 self.num_classes  x 轴旋转45度# 设置 y  轴坐标 labelplt.yticks(range(self.num_classes), self.labels)# 显示 color bar  可以通过颜色的密度看出数值的分布plt.colorbar()plt.xlabel("true_label")plt.ylabel("Predicted_label")plt.title("ConfusionMatrix")# 在图中标注数量 概率信息thresh = matrix.max() / 2# 设定阈值来设定数值文本的颜色 开始遍历图像的时候一般是图像的左上角for x in range(self.num_classes):for y in range(self.num_classes):# 这里矩阵的行列交换,因为遍历的方向 第y行 第x列info = int(matrix[y, x])plt.text(x, y, info,verticalalignment='center',horizontalalignment='center',color="white" if info > thresh else "black")plt.tight_layout()# 图形显示更加的紧凑plt.show()if __name__ ==' __main__':device = torch.device("cuda:0" if torch.cuda.is_available()else "cpu")print(device)# 使用验证集的预处理方式data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor()transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])data_loot = os.path.abspath(os.path.join(os.getcwd(), "../.."))# get data root pathimage_path = data_loot + "/data_set/flower_data/"# flower data set pathvalidate_dataset = datasets.ImageFolder(root=image_path +"val",transform=data_transform)batch_size = 16validate_loader = torch.utils.data.DataLoder(validate_dataset,batch_size=batch_size,shuffle=False,num_workers=2)net = MobileNetV2(num_classes=5)#加载预训练的权重model_weight_path = "./MobileNetV2.pth"net.load_state_dict(torch.load(model_weight_path, map_location=device))net.to(device)#read class_indicttry:json_file = open('./class_indicts.json', 'r')class_indict = json.load(json_file)except Exception as e:print(e)exit(-1)labels = [label for _, label in class_indict.item()]# 通过json文件读出来的labelconfusion = ConfusionMatrix(num_classes=5, labels=labels)net.eval()# 启动验证模式# 通过上下文管理器  no_grad  来停止pytorch的变量对梯度的跟踪with torch.no_grad():for val_data in validate_loader:val_images, val_labels = val_dataoutputs = net(val_images.to(device))outputs = torch.softmax(outputs, dim=1)outputs = torch.argmax(outputs, dim=1)# 获取概率最大的元素confusion.update(outputs.numpy(), val_labels.numpy())# 预测值和标签值confusion.plot()# 绘制混淆矩阵confusion.summary()# 来打印各个指标信息

是这样的 这篇算是一个学习笔记,其中的基础图都源于我的导师

 霹雳吧啦Wz的个人空间_哔哩哔哩_bilibili

欢迎无依无靠的CV同学加入 

讲的非常好 代码其实也是导师给的 

我能做的就是读懂每一行加点注释

给不想看视频的同学留点时间


http://chatgpt.dhexx.cn/article/vLZstRiu.shtml

相关文章

Code::Blocks 相关

文库上的使用教程 http://blog.csdn.net/JGood/article/details/5252119 使用手册 http://blog.csdn.net/liquanhai/article/details/6618300 一.Code::blocks Code::blocks集成开发环境是一个支持编译、链接、调试许多种语言的IDE,支持VS6.0到VS200…

mysql8 sql_mode去掉only_full_group_by

1.查询版本与sqlmode: select version(), sql_mode; 2.修改sqlmode,执行下面两句代码: set global sql_modeSTRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION;set session sql_modeSTRICT_TRANS_TABLES,NO_…

- MySQL - 存储过程 Stored Procedure

文章目录 存储过程是什么存储过程的优点存储过程的缺点存储过程分类存储过程的查询语句存储过程的创建和调用语法实例 - IN、OUT、INOUT条件语句循环语句迭代 存储过程是什么 存储过程是一组为了完成特定功能的SQL语句集,存储在数据库中,一次编译多次使…

mysql的delete语句_mysql删除语句

展开全部 mysql删除语句如下: 1、delete删除一行:delete from student where id1。 2、delete删除多行:delete from student where in (1,2,3)3。 3、删62616964757a686964616fe78988e69d8331333433623162除表的所有数据:delete f…

mysql any all some

假设现在有两张表 a , b 如下 SELECT * FROM a WHERE id > ANY(SELECT id FROM b )# any 等价于 some 这里有两个id sql的意思大概是查询a表的所有,在where中a的id > 子表 b的id 这里用到了any(some) 他们的意思是一样的,就是说a表…

mysql 的 sql_mode.only_full_group_by属性解析

文章目录 1. 初始条件2. 现象3. 解决①:关闭sql_mode 的 only_full_group_by模式②:使用 ANY_VALUE() 抑制 ONLY_FULL_GROUP_BY 的影响 mysql8.0官网:处理 group by 1. 初始条件 现在有这样一张表,其中有两条 name 裤子 的数据…

MYSQL 删除语句

删除数据(DELETE) 如果你失忆了,希望你能想起曾经为了追求梦想的你。 数据库存储数据,总会有一些垃圾数据,也会有一些不需要用的数据了,这些情况下,我们就可以删除这些数据,释放出一定的空间,给…

MySql递归RECURSIVE的详解

背景: 在实际开发的过程中,我们会遇到一些数据是层级关系的、要展示数据子父级关系的时候, 第一个解决方案:将数据库中的所有数据都查询出来用Java代码进行处理。 第二个解决方案:可以考虑MySql中的RECURSIVE递归进行…

【MYSQL WITH recursive使用】

MYSQL WITH recursive使用 MYSQL WITH recursive使用语法用法1:输出1~n或者求 1~n的和用法2 父求子创建table:user求张三后代 MYSQL WITH recursive使用 由于在项目中有使用到recursive,因此在此做记录 语法 WITH recursive 表名 AS ( 初始…

U 盘安装 CentOS7 时盘符和安装源不可用问题

记录一次在物理机上安装 CentOS7 遇到的问题及其解决办法,主要有两个问题卡住了很久,一个是盘符问题,一个就是安装源不可用的问题 。 1 No such device 1.1 问题描述 用 U 盘在物理机上安装 CentOS7 的时候,出现 could not ins…

Vmware Workstation17 安装centos7(详细教程)

1、为什么安装Vmware Vmware Workstation可以帮我们他们创建虚拟机,模拟生产环境(linux),搭建集群等。作为一个开发人员特别是后端开发人员是需要懂一些运维的,不需要精通,遇到问题能定位排查。我自己想做一个自己的项目&#xf…

用大白菜装centos7_大白菜安装centos7 踩坑记

1.准备一个U盘,安装大白菜。这个去大白菜官网下载安装就可以了 安装大白菜的时候最好选择FAT32(2021.1.7记录) 2.U盘装完大白菜后U盘会被分为两个主分区 一个盘是大白菜系统的,另外一个盘放一些工具的。 DBC里面就是放的一些工具 比如磁盘管理工具 3.把Centos7的镜像放入到DB…

Windows10安装Centos7双系统

Windows10安装Centos7双系统 1.摘要2.制作Centos 7系统盘3.Windows磁盘管理为Centos系统留出空间4.bios配置使得计算机系统选择从U盘进入5.安装Centos 7系统6.如何在Windows和Centos系统间切换7.一些问题与补救方法8.参考文章 1.摘要 本篇博客主要整理记录了在Win10 OS下安装使…

VM16上安装CentOS7详细安装教程【附图】

在VM16上安装Centos7 下载Centos安装Centos快照拍摄 下载Centos 进入Centos官网 Centos官网的下载地址 点击Download 点击x86_64 自己选择镜像下载 上面选择好自己的镜像后选择后缀名为 iso 的CentOS下载 安装Centos 点击创建新的虚拟机 选择好自定义后点击下一步…

VMware 16安装centos 7详细教程

VMware 16安装centos 7详细教程 前言:之前在VMware15 pro上安装centos7,但是启动虚拟机出现蓝屏,身边有好多小伙伴也遇到了这个问题,经过一番排查,找到了最简单的办法就是升级到VMware16pro,在启动就没有出…

在Vmware虚拟机中安装CentOS 7

前言:材料和工具 1. 安装好的VMWare虚拟机软件: ​VMWare16下载地址,获取码:ye1a 2.CentOS7下载地址:官方镜像下载 (centos.org) (官方镜像站下载比较慢) 清华大学镜像站:清华…

物理服务器安装CentOS 7操作系统

目录 1、下载系统镜像 2、制作安装盘 2.1 方法一:光盘制作 2.2 方法二:U盘制作 3、更改bios启动顺序 4、安装CentOS 7操作系统 4.1 安装命令选择,及常见错误解决 4.2 语言选择 4.3 时区选择 4.4 软件选择 4.5 安装位置选择 4.6 手…

M1芯片Macbook虚拟机安装centos7

目录 一、安装parallels Desktop、centos7 二、安装Parallels Tools 三、安装VNC server服务 四、进程占用问题 一、安装parallels Desktop、centos7 由于centon7内核版本问题用PD18等版本安装centos7进入默认是命令行安装。 命令界面安装: 1、选数字5 回车 再…

虚拟机安装centos7

1、简介 这里虚拟机采用VMware15.1.0,镜像采用CentOS7版本,官网或国内镜像可直接下载。 https://mirrors.aliyun.com/centos/7/isos/x86_64/CentOS-7-x86_64-DVD-2009.iso 2、安装流程 1、打开vmware软件,点击 创建新的虚拟机。 2、选择 典…

Mac(2) Parallels Desktop 安装 CentOS7

文章目录 一、前言二、准备三、Parallels Desktop安装CentOS7四、CentOS7配置1、网络配置 -- 设置固定ip2、关闭防火墙3、关闭SELinux4、更新yum源5、安装ifconfig6、其它 一、前言 本文将通过Parallels Desktop安装CentOS7 二、准备 Parallels Desktop下载安装 https://www…