BagNet特征heatmap可视化

article/2025/10/30 21:06:39

BagNet地址:https://github.com/wielandbrendel/bag-of-local-features-models
BagNet是ResNet的变体,显著的区别是将3x3卷积变为1x1卷积来达到构造整体网络具有某个最终的感受野(receptive field)目的。在这里主要讲解对于一张来源于ImageNet的尺寸为224x224的原始图像,如何判断其局部的image patch的重要性大小,并可视化heatmap。

获取heatmap张量

1. 读取预训练的BagNet,并读取原始图像并转化为tensor。将图像tensor输入BagNet得到维度为224x224的2D heatmap。

import bagnets.pytorchnet
from bagnets.utils import plot_heatmap, generate_heatmap_pytorch
import torchvision.transforms as transforms
import numpy as np
import cv2
import torchpytorch_model = bagnets.pytorchnet.bagnet33(pretrained=True).cuda()
pytorch_model.eval()image_path = 'val.JPEG'
raw_image = cv2.imread(image_path)
raw_image = cv2.resize(raw_image, (224,) * 2)
image = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])(raw_image[..., ::-1].copy()) # cv2库读取的为BGR通道,需将其变为RGB
image = torch.unsqueeze(image, 0) # 将单张图像维度由(3,224,224)变为(1,3,224,224)heatmap = generate_heatmap_pytorch(pytorch_model, image, 2, 33)
np.save('heatmap.npy', heatmap)  # 将heatmap张量保存用于之后可视化

generate_heatmap_pytorch函数的内容是根据(3,224,224)原始图像生成对应的2D (224,224)的heatmap,过程如下:

def generate_heatmap_pytorch(model, image, target, patchsize):"""Generates high-resolution heatmap for a BagNet by decomposing theimage into all possible patches and by computing the logits foreach patch.Parameters----------model : Pytorch ModelThis should be one of the BagNets.image : Numpy array of shape [1, 3, X, X]The image for which we want to compute the heatmap.target : intClass for which the heatmap is computed.patchsize : intThe size of the receptive field of the given BagNet."""import torchwith torch.no_grad():#  这里采用9x9的滑动框来生成image patches,为了保证输出尺寸为224x224#  需要pad 0_, c, x, y = image.shapepadded_image = np.zeros((c, x + patchsize - 1, y + patchsize - 1))padded_image[:, (patchsize-1)//2:(patchsize-1)//2 + x, (patchsize-1)//2:(patchsize-1)//2 + y] = image[0]image = padded_image[None].astype(np.float32)# turn to torch tensorinput = torch.from_numpy(image).cuda()# extract patchespatches = input.permute(0, 2, 3, 1)# 这个语句负责生成patches# patches:(1,224,224,3)# 设num_H==num_W=(224+2*paddings)/patchsize# patches.unfold(1, patchsize, 1):(1,num_H,224,3,patchsize)# patches.unfold(1, patchsize, 1).unfold(2, patchsize, 1):# (1,num_H,num_W,3,patchsize,patchsize)patches = patches.unfold(1, patchsize, 1).unfold(2, patchsize, 1)num_rows = patches.shape[1]num_cols = patches.shape[2]patches = patches.contiguous().view((-1, 3, patchsize, patchsize))# compute logits for each patchlogits_list = []for batch_patches in torch.split(patches, 1000):logits = model(batch_patches)logits = logits[:, target]logits_list.append(logits.data.cpu().numpy().copy())logits = np.hstack(logits_list)return logits.reshape((224, 224))

可视化heatmap

方法一: 这里采用bagnet的方法,将原图padding之后裁剪成224* 224个小片,然后依次进入网络得到 logits值,于是得到224*224个数,直接reshape就可以得到heatmap无需插值

import numpy as np
import matplotlib.pyplot as plt
from skimage import feature, transform
import cv2def plot_heatmap(heatmap, original, ax1, ax2, ax3, cmap='RdBu_r',percentile=99, dilation=0.5, alpha=0.25):"""Plots the heatmap on top of the original image(which is shown by most important edges).Parameters----------heatmap : Numpy Array of shape [X, X]Heatmap to visualise.original : Numpy array of shape [X, X, 3]Original image for which the heatmap was computed.ax : Matplotlib axisAxis onto which the heatmap should be plotted.cmap : Matplotlib color mapColor map for the visualisation of the heatmaps (default: RdBu_r)percentile : float between 0 and 100 (default: 99)Extreme values outside of the percentile range are clipped.This avoids that a single outlier dominates the whole heatmap.dilation : floatResizing of the original image. Influences the edge detector andthus the image overlay.alpha : float in [0, 1]Opacity of the overlay image."""dx, dy = 0.05, 0.05xx = np.arange(0.0, heatmap.shape[1], dx)yy = np.arange(0.0, heatmap.shape[0], dy)xmin, xmax, ymin, ymax = np.amin(xx), np.amax(xx), np.amin(yy), np.amax(yy)extent = xmin, xmax, ymin, ymaxcmap_original = plt.get_cmap('Greys_r')cmap_original.set_bad(alpha=0)# Compute edges (to overlay to heatmaps later)original_greyscale = original if len(original.shape) == 2 else np.mean(original, axis=-1)# dilation=0.5,图像由(224,224)缩放为(112,112),这样做的目的是找出更粗略的边缘纹理in_image_upscaled = transform.rescale(original_greyscale, dilation, mode='constant',multichannel=False, anti_aliasing=True)# 找到图像的边缘纹理特征edges = feature.canny(in_image_upscaled).astype(float)edges[edges < 0.5] = np.nanedges[:5, :] = np.nanedges[-5:, :] = np.nanedges[:, :5] = np.nanedges[:, -5:] = np.nanoverlay = edges  # 找出图像的边缘特征显示在heatmap上,便于对照原图特征# 最大值设为99%处,若设为真正的最大值,heatmap的重要处颜色不是特别深abs_max = np.percentile(np.abs(heatmap), percentile)abs_min = abs_maxa1 = ax1.imshow(heatmap, extent=extent, interpolation='nearest', cmap=cmap, vmin=-abs_min, vmax=abs_max)a2 = ax2.imshow(overlay,  extent=extent, interpolation='nearest', cmap=cmap_original, alpha=alpha)cb = fig.colorbar(a1, ax=ax1, ticks=[1, 2, 3])cb.set_ticks([-abs_min, abs_max])cb.set_ticklabels(['Low', 'High'])a3 = ax3.imshow(heatmap, extent=extent, interpolation='nearest', cmap=cmap, vmin=-abs_min, vmax=abs_max)ax3.imshow(overlay, extent=extent, interpolation='nearest', cmap=cmap_original, alpha=alpha)heatmap = np.load('heatmap.npy')
heatmap = cv2.resize(heatmap, (224, 224), interpolation=cv2.INTER_NEAREST)image_path = 'val.JPEG'
raw_image = cv2.imread(image_path)
original_image = cv2.resize(raw_image, (224, 224))fig, _axs = plt.subplots(nrows=2, ncols=2)
axs = _axs.flatten()axs[0].set_title('original')
# matplotlib的imshow的RGB 3通道表示与cv2库(BGR)的顺序不同
axs[0].imshow(original_image[..., ::-1] / 255.)
axs[0].axis('off')  # 不显示坐标尺寸axs[1].set_title('heatmap')
axs[1].axis('off')  # 不显示坐标尺寸axs[2].set_title('feature canny')
axs[2].axis('off')  # 不显示坐标尺寸axs[3].set_title('heatmap+feature canny')
axs[3].axis('off')  # 不显示坐标尺寸plot_heatmap(heatmap, original_image, axs[1], axs[2], axs[3], dilation=0.5, percentile=99, alpha=.25)fig.tight_layout()
plt.show()

在这里插入图片描述
方法二: 将global average pooling前的3D 特征图根据FC层的权值进行加权(参加CAM方法),得到2D的特征图。由于此时的分辨率是小于224* 224的,此时一般需要进行插值来resize。以下给出一个2D特征图的npy文件,将其进行可视化。
python

import numpy as np
import matplotlib.pyplot as plt
import cv2
import matplotlib.cm as cm# 升采样map
map = cv2.resize(map, (224, 224))
image_path = 'val.JPEG'
raw_image = cv2.imread(image_path)
original_image = cv2.resize(raw_image, (224, 224))# 标准化到[0,1]
map = (map- map.min()) / (map.max()-map.min())
# 使用jet_r映射为RGB的heatmap
heatmap3 = cm.jet_r(map)[..., :3] * 255.0
# 与原图进行结合显示
gcam = (heatmap3.astype(np.float) + original_image.astype(np.float)) / 2
cv2.imwrite('heatmap.jpg', np.uint8(gcam))

在这里插入图片描述
也可以通过Near插值来得到类似马赛克的heatmap:

map = cv2.resize(map, (224, 224), interpolation=cv2.INTER_NEAREST)
map = (map- map.min()) / (map.max()-map.min())
heatmap = cm.jet_r(map)[..., :3] * 255.0
cv2.imwrite('heatmap.jpg', np.uint8(heatmap ))

在这里插入图片描述

标注重点的image patch

根据生成的heatmap对应到原始图像的image patch,并使用矩形框标注,这里使用的是33x33的框规模:

import numpy as np
import cv2image_path = 'val.JPEG'
image = cv2.imread(image_path)
image = cv2.resize(image, (224,) * 2)heatmap = np.load('heatmap.npy')
maximum = 0
pos_list = []
# 选取>99.95位置的数才标注出对应的image patch
threshold = np.percentile(heatmap, 99.95)
for i in range(heatmap.shape[0]):for j in range(heatmap.shape[1]):if heatmap[i, j] > threshold:pos_list.append((i, j))padding = 33//2
for pos in pos_list:# 注意cv2库中的图像坐标和numpy数组中的不同pt1 = (pos[1] - padding , pos[0] - padding)pt2 = (pt1[0] + 33-1, pt1[1] + 33-1)# (0, 255, 0)表示RGB中的绿色,1表示框的宽度cv2.rectangle(image, pt1, pt2, (0, 255, 0), 1)
cv2.imshow('label', image)
cv2.waitKey() # 等待按键才退出

在这里插入图片描述

计算bbox的IOU

def IOU(bboxA, bboxB):x1 = bboxA[0]y1 = bboxA[1]width1 = bboxA[2] - bboxA[0]height1 = bboxA[3] - bboxA[1]x2 = bboxB[0]y2 = bboxB[1]width2 = bboxB[2] - bboxB[0]height2 = bboxB[3] - bboxB[1]endx = max(x1 + width1, x2 + width2)startx = min(x1, x2)width = width1 + width2 - (endx - startx)endy = max(y1 + height1, y2 + height2)starty = min(y1, y2)height = height1 + height2 - (endy - starty)if width <= 0 or height <= 0:ratio = 0  # 重叠率为 0else:Area = width * height  # 两矩形相交面积Area1 = width1 * height1Area2 = width2 * height2ratio = Area * 1. / (Area1 + Area2 - Area)return ratioimage_path = 'val.JPEG'
image = cv2.imread(image_path)
image = cv2.resize(image, (224,) * 2)
pt1 = (0, 10)
pt2 = (pt1[0] + 33, pt1[1] +33)pt3 = (20, 15)
pt4 = (pt3[0] + 33, pt3[1] + 33)
print(IOU(pt1+pt2, pt3+pt4))cv2.rectangle(image, pt1, pt2, (0, 255, 0), 2)
cv2.rectangle(image, pt3, pt4, (0, 255, 0), 2)
cv2.imshow('label', image)
cv2.waitKey() # 等待按键才退出

在这里插入图片描述


http://chatgpt.dhexx.cn/article/Xt1tiZVH.shtml

相关文章

关键点检测——heatmap热力图法

一、数据集格式 二、解析xml文件&#xff0c;生成data_center.txt from PIL import Image import math,os from xml.etree import ElementTree as ETdef keep_image_size_open(path, size(256, 256)):img Image.open(path)temp max(img.size)mask Image.new(RGB, (temp, te…

Learn OpenCV之Heatmap

本文是利用热图&#xff08;Heatmap&#xff09;分析视频序列的标定。 注意&#xff0c;这里目的不是标定而是分析标定好的数据&#xff0c;或者也可以是检测的结果数据 文章结构是这样的&#xff0c;先详细的解释一下热图分析有什么用&#xff0c;根据一些具体的应用实例给出…

python heatmap画法

任务描述 将一个归一化的分数以热图的形式显示出来&#xff0c;分数高的地方颜色深&#xff0c;分数小的地方颜色浅 注意&#xff1a;使用单一颜色无法实现这种渐变过程 原理 将单通道的0-1之间的score值映射到三通道的颜色空间 原料 一个单通道的score矩阵颜色空间列表&a…

python heatmap总结

基础使用 import seaborn as sns; sns.set_theme(color_codesTrue) iris sns.load_dataset("iris") species iris.pop("species") g sns.clustermap(iris)取消行列分类树 import seaborn as sns; sns.set_theme(color_codesTrue) import matplotlib.p…

seaborn绘制heatmap

【seaborn.heatmap整理】 用处&#xff1a;将数据绘制为颜色方格&#xff08;编码矩阵&#xff09;。 引用形式&#xff1a; seaborn.heatmap(data, vminNone, vmaxNone, cmapNone, centerNone, robustFalse, annotNone, fmt’.2g’, annot_kwsNone, linewidths0, linecolor‘…

Heatmap

前言 目前所说的模型可视化或者模型可解释说到是对某一类别具有可解释性&#xff0c;直接画出来特征图并不能说明模型学到了某种特征&#xff0c;对一个深层的卷积神经网络而言&#xff0c;通过多次卷积和池化以后&#xff0c;它的最后一层卷积层包含了最丰富的空间和语义信息…

R | 可视化 | 热图(Heatmap)

1 基础绘制 R绘制热图时&#xff0c;数据需要输入一个矩阵&#xff0c;可以用as.matrix()把它转换成矩阵。这里利用R自带的数据集绘制热图。 > # 数据 > data <- as.matrix(mtcars) > > # 绘制热图 > heatmap(data) OUTPUT: 热图的每一列是一个变量&…

科研作图-heatmap(一)

1.简介 在科研中有很多地方为了可解释给审稿人提供了热图,便于知道深度学习中到底是哪部分在起作用,或者是在机器学习中分析不同的特征之间是否存在相关性?存在多大的相关性;或者是直观的展示场景热力图…总之,用处很多,我正好现在也需要用,就先总结下:绘制HeatMap的库有很多,…

「C#」生成HeatMap(热度图)的实现

1、什么是Heatmap 其实不用多言&#xff0c;需要这个的人自然知道这是什么。基于一系列点生成的热度图&#xff0c;放张图感受一下&#xff1a; ma...大概就是这种样子。 2、生成&#xff08;计算&#xff09;原理 实现方式实际上是在每个点上叠加高斯矩阵。高斯矩阵就是在二…

关键点检测的heatmap介绍

开始学关键点检测的时候&#xff0c;到处找找不到heatmap的解释。现在大概有些懂了&#xff0c;干脆自己写一个。部分转载。 关键点定位任务两种做法&#xff1a;heatmap和fully connected回归&#xff08;Heapmap-based和Regression-Based&#xff09; heatmap得到一张类似热…

python绘制热度图(heatmap)

1、简单的代码 from matplotlib import pyplot as plt import seaborn as sns import numpy as np import pandas as pd#练习的数据&#xff1a; datanp.arange(25).reshape(5,5) datapd.DataFrame(data)#绘制热度图&#xff1a; plotsns.heatmap(data)plt.show() 查看效果&a…

热图(Heatmap)绘制(matplotlib与seaborn)

热图是数据统计中经常使用的一种数据表示方法&#xff0c;它能够直观地反映数据特征&#xff0c;查看数据总体情况&#xff0c;在诸多领域具有广泛应用。 一&#xff1a;matplotlib绘制方法 1.基础绘制 热图用以表示的是矩阵数据&#xff0c;例如相关阵、协差阵等方阵&#…

‘0’ 和 '\0'

48是0对应的ascii值。

KEIL/MDK编译优化optimization选项注意事项

KEIL编译器C语言编译选项优化等级说明 -Onum Specifies the level of optimization to be used when compiling source files. Syntax -Onum Where num is one of the following: 0 Minimum optimization. Turns off most optimizations. When debugging is enabled, this opt…

0,'\0','0'

#include <iostream> using namespace std; int main(void) { cout<<__FILE__<<\t<<__LINE__<<endl;cout<<"内 容:\t"<<"0"<<\t<<"\\\0\"<<\t<<"\0\"<<…

Odoo

狭路相逢 勇者胜 Odoo 是用于经营公司的最好的管理软件。 数百万用户使用我们的集成应用可以更好地开展工作 现在开始。免费的。 重新定义可扩展性 一个需求&#xff0c;一个应用程式。整合从来没有那么顺畅 促进销售量 客户关系管理POS销售 整合您的服务 项目工时表帮助…

0 、 '0' 、 0 、 ’\0’ 区别

转载自&#xff1a;https://blog.csdn.net/qnavy123/article/details/93901631 ① ‘0’ 代表 字符0 &#xff0c;对应ASCII码值为 0x30 (也就是十进制 48) ② ‘\0’ 代表 空字符(转义字符)【输出为空】 &#xff0c;对应ASCII码值为 0x00(也就是十进制 0)&#xff0c; …

Linux的内核编译用O0是编译不过的

最近在ATF的升级过程中遇到了一个编译问题&#xff0c;最后是通过编译优化解决的&#xff0c;然后一百度这个优化全是在Linux中的。于是就借着Linux编译优化来学学。 内容来自 宋宝华老师&#xff1a; 关于Linux编译优化几个必须掌握的姿势 1、编译选项和内核编译 首先我们都…

alert uuid does not exits. Dropping to a shell!

ALERT&#xff01;UUID does not exit. Dropping to a shell&#xff01; 服务器系统ubuntu16.04server&#xff0c;非自然断电后开机进入initramfs模式&#xff0c;服务器磁盘阵列是raid1和raid5。初步分析是硬盘坏道或掉盘&#xff0c;进入raid卡里看到硬盘一切正常&#xf…

跟着团子学SAP PS:如何查询PS模块中的user exits以及相关BAdIs SE80/SMOD/CNEX006/CNEX007/CNEX008

在PS很多标准字段或功能无法满足客户需求的时候往往需要通过SAP标准的user exits或者BAdI进行开发以满足业务需要&#xff0c;所以今天介绍下如何查询PS模块中的用户出口以及BAdIs&#xff1a; &#xff08;1&#xff09;查询PS模块中的user exits: 执行SE80&#xff0c;在菜…