关键点检测——heatmap热力图法

article/2025/10/30 21:28:44

一、数据集格式

 二、解析xml文件,生成data_center.txt

from PIL import Image
import math,os
from xml.etree import ElementTree as ETdef keep_image_size_open(path, size=(256, 256)):img = Image.open(path)temp = max(img.size)mask = Image.new('RGB', (temp, temp), (0, 0, 0))mask.paste(img, (0, 0))mask = mask.resize(size)return maskdef make_data_center_txt(xml_dir):with open('data_center.txt', 'a') as f:f.truncate(0)path=r'data/images'xml_names = os.listdir(xml_dir)for xml in xml_names:xml_path = os.path.join(xml_dir, xml)in_file = open(xml_path)tree = ET.parse(in_file)root = tree.getroot()image_path = root.find('path')polygon = root.find('outputs/object/item/polygon')data = []c_data = []data_str = ''print(xml)for i in polygon:data.append(int(i.text))data_str = data_str + ' ' + str(i.text)for i in range(0, len(data), 2):c_data.append((data[i], data[i + 1]))data_str = os.path.join(path,image_path.text.split('\\')[-1]) +data_strf.write(data_str + '\n')if __name__ == '__main__':make_data_center_txt('data/xml')

 三、加载数据集

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Imagefrom heatmap_label import CenterLabelHeatMaptf = transforms.Compose([  #标准化处理transforms.ToTensor()
])class MyDataset(Dataset):def __init__(self,root): #传入路径f=open(root,'r')self.dataset=f.readlines() #读所有行def __len__(self):return len(self.dataset) #返回数据集长度def __getitem__(self, index):data=self.dataset[index] #取当前数据img_path=data.split(' ')[0] #以空格划分,并取出文件名,即data/images\0.pngimg_data=Image.open(img_path).resize((256, 256)) #打开图片# points = data.split(' ')[1:-2]  # 取出后面5个点的x,y坐标,-2是取不到的points=data.split(' ')[1:] #取出后面5个点的x,y坐标# print(img_data, points)#将坐标映射到256*256大小的图片上points = [int(points[0])*256/774, int(points[1])*256/434, int(points[2])*256/774, int(points[3])*256/434, int(points[4])*256/774, int(points[5])*256/434]# points=[int(i)/100 for i in points] #图像宽高为100,int(i)/100进行归一化# print(img_data, points)label = []for i in range(0, len(points), 2):heatmap = CenterLabelHeatMap(256, 256, points[i], points[i+1], 5)label.append(heatmap)#一个关键点会生成一个通道,3个关键点生成3个通道label = np.stack(label) #将列表转成数组的形式return tf(img_data), torch.Tensor(label) #将img_data标准化,将points转化为tensor格式if __name__ == '__main__':data=MyDataset('data_center.txt')for i in data:print(i[0].shape)print(i[1].shape)

四、构建网络

import torch
from torch import nn
from torch.nn import functional as Fclass Conv_Block(nn.Module):def __init__(self,in_channel,out_channel):super(Conv_Block, self).__init__()self.layer=nn.Sequential(nn.Conv2d(in_channel,out_channel,3,1,1,padding_mode='reflect',bias=False),nn.BatchNorm2d(out_channel),nn.Dropout2d(0.3),nn.LeakyReLU(),nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode='reflect', bias=False),nn.BatchNorm2d(out_channel),nn.Dropout2d(0.3),nn.LeakyReLU())def forward(self,x):return self.layer(x)class DownSample(nn.Module):def __init__(self,channel):super(DownSample, self).__init__()self.layer=nn.Sequential(nn.Conv2d(channel,channel,3,2,1,padding_mode='reflect',bias=False),nn.BatchNorm2d(channel),nn.LeakyReLU())def forward(self,x):return self.layer(x)class UpSample(nn.Module):def __init__(self,channel):super(UpSample, self).__init__()self.layer=nn.Conv2d(channel,channel//2,1,1)def forward(self,x,feature_map):up=F.interpolate(x,scale_factor=2,mode='nearest')out=self.layer(up)return torch.cat((out,feature_map),dim=1)class UNet(nn.Module):def __init__(self,num_classes):super(UNet, self).__init__()self.c1=Conv_Block(3,64)self.d1=DownSample(64)self.c2=Conv_Block(64,128)self.d2=DownSample(128)self.c3=Conv_Block(128,256)self.d3=DownSample(256)self.c4=Conv_Block(256,512)self.d4=DownSample(512)self.c5=Conv_Block(512,1024)self.u1=UpSample(1024)self.c6=Conv_Block(1024,512)self.u2 = UpSample(512)self.c7 = Conv_Block(512, 256)self.u3 = UpSample(256)self.c8 = Conv_Block(256, 128)self.u4 = UpSample(128)self.c9 = Conv_Block(128, 64)self.out=nn.Conv2d(64,3, 3, 1, 1)def forward(self,x):R1=self.c1(x)R2=self.c2(self.d1(R1))R3 = self.c3(self.d2(R2))R4 = self.c4(self.d3(R3))R5 = self.c5(self.d4(R4))O1=self.c6(self.u1(R5,R4))O2 = self.c7(self.u2(O1, R3))O3 = self.c8(self.u3(O2, R2))O4 = self.c9(self.u4(O3, R1))return self.out(O4)if __name__ == '__main__':x=torch.randn(2,3,256,256)net=UNet(num_classes=3)print(net(x).shape)

五、开始训练

import osfrom torch import nn,optim
import torch
from dataset import *
from net import *
from torch.utils.data import DataLoaderif __name__ == '__main__':device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')net=UNet(num_classes=3).to(device) #实例化网络并指认到设备上weights='params/unet.pth'if os.path.exists(weights): #如果有初始权值就加载net.load_state_dict(torch.load(weights)) #加载权重print('loading successfully')opt=optim.Adam(net.parameters()) #指定优化器并传入参数# loss_fun=nn.BCELoss() #定义损失函数loss_fun=nn.BCEWithLogitsLoss()dataset=MyDataset('data_center.txt') #实例化数据集data_loader=DataLoader(dataset,batch_size=2,shuffle=True) #加载数据集epoch = 1while True:for i,(image,label) in enumerate(data_loader): #用枚举的方式遍历数据集image,label=image.to(device),label.to(device) #将图片和标签指认到设备上# print(image.shape, label.shape)out=net(image) #将图片输入网络train_loss=loss_fun(out,label) #预测值和真是标签做损失print(f'{epoch}-{i}-train_loss:{train_loss.item()}') #打印当前轮次当前批次的训练损失opt.zero_grad() #梯度清零train_loss.backward() #反向传播opt.step() #更新梯度if epoch % 10 == 0: #每10轮保存一次权重torch.save(net.state_dict(),f'params/unet.pth') #保存参数print('save successfully')epoch += 1

 六、利用训练好的权重进行预测

import osimport torch
from PIL import Image,ImageDraw
from dataset import *
from net import *    #import * 代表导入所有path='test_image'
net=UNet(num_classes=3) #实例化网络
net.load_state_dict(torch.load('params/unet.pth')) #加载训练好的权重
net.eval() #测试模式
for j in os.listdir(path):img=Image.open(os.path.join(path,j)).resize((256, 256))draw=ImageDraw.Draw(img) #创建画板img_data=tf(img) #标准化img_data=torch.unsqueeze(img_data,dim=0) #设置批次维度out=net(img_data)out=out.squeeze()d=torch.max_pool2d(out, 256).squeeze()print(d)rst = []for i in range(3): #有3个关键点,故有3个通道h,w=np.where(out[i]==out[i].max()) #当前通道恒等于当前通道的最大值,就取其索引# rst.append((w[0], h[0]))draw.ellipse((w[0]*774/256-2, h[0]*434/256-2, w[0]*774/256+2, h[0]*434/256+2),(255,0,0)) #画半径为2的圆img.show()img.save(f'test_result/{j}')

reference

>>>>>来自B站大佬

【深度学习关键点回归(直接回归法&heatmap热力图法)】 https://www.bilibili.com/video/BV1sS4y197J1/?p=2&share_source=copy_web&vd_source=95705b32f23f70b32dfa1721628d5874


http://chatgpt.dhexx.cn/article/45ZvSwsL.shtml

相关文章

Learn OpenCV之Heatmap

本文是利用热图(Heatmap)分析视频序列的标定。 注意,这里目的不是标定而是分析标定好的数据,或者也可以是检测的结果数据 文章结构是这样的,先详细的解释一下热图分析有什么用,根据一些具体的应用实例给出…

python heatmap画法

任务描述 将一个归一化的分数以热图的形式显示出来,分数高的地方颜色深,分数小的地方颜色浅 注意:使用单一颜色无法实现这种渐变过程 原理 将单通道的0-1之间的score值映射到三通道的颜色空间 原料 一个单通道的score矩阵颜色空间列表&a…

python heatmap总结

基础使用 import seaborn as sns; sns.set_theme(color_codesTrue) iris sns.load_dataset("iris") species iris.pop("species") g sns.clustermap(iris)取消行列分类树 import seaborn as sns; sns.set_theme(color_codesTrue) import matplotlib.p…

seaborn绘制heatmap

【seaborn.heatmap整理】 用处:将数据绘制为颜色方格(编码矩阵)。 引用形式: seaborn.heatmap(data, vminNone, vmaxNone, cmapNone, centerNone, robustFalse, annotNone, fmt’.2g’, annot_kwsNone, linewidths0, linecolor‘…

Heatmap

前言 目前所说的模型可视化或者模型可解释说到是对某一类别具有可解释性,直接画出来特征图并不能说明模型学到了某种特征,对一个深层的卷积神经网络而言,通过多次卷积和池化以后,它的最后一层卷积层包含了最丰富的空间和语义信息…

R | 可视化 | 热图(Heatmap)

1 基础绘制 R绘制热图时&#xff0c;数据需要输入一个矩阵&#xff0c;可以用as.matrix()把它转换成矩阵。这里利用R自带的数据集绘制热图。 > # 数据 > data <- as.matrix(mtcars) > > # 绘制热图 > heatmap(data) OUTPUT: 热图的每一列是一个变量&…

科研作图-heatmap(一)

1.简介 在科研中有很多地方为了可解释给审稿人提供了热图,便于知道深度学习中到底是哪部分在起作用,或者是在机器学习中分析不同的特征之间是否存在相关性?存在多大的相关性;或者是直观的展示场景热力图…总之,用处很多,我正好现在也需要用,就先总结下:绘制HeatMap的库有很多,…

「C#」生成HeatMap(热度图)的实现

1、什么是Heatmap 其实不用多言&#xff0c;需要这个的人自然知道这是什么。基于一系列点生成的热度图&#xff0c;放张图感受一下&#xff1a; ma...大概就是这种样子。 2、生成&#xff08;计算&#xff09;原理 实现方式实际上是在每个点上叠加高斯矩阵。高斯矩阵就是在二…

关键点检测的heatmap介绍

开始学关键点检测的时候&#xff0c;到处找找不到heatmap的解释。现在大概有些懂了&#xff0c;干脆自己写一个。部分转载。 关键点定位任务两种做法&#xff1a;heatmap和fully connected回归&#xff08;Heapmap-based和Regression-Based&#xff09; heatmap得到一张类似热…

python绘制热度图(heatmap)

1、简单的代码 from matplotlib import pyplot as plt import seaborn as sns import numpy as np import pandas as pd#练习的数据&#xff1a; datanp.arange(25).reshape(5,5) datapd.DataFrame(data)#绘制热度图&#xff1a; plotsns.heatmap(data)plt.show() 查看效果&a…

热图(Heatmap)绘制(matplotlib与seaborn)

热图是数据统计中经常使用的一种数据表示方法&#xff0c;它能够直观地反映数据特征&#xff0c;查看数据总体情况&#xff0c;在诸多领域具有广泛应用。 一&#xff1a;matplotlib绘制方法 1.基础绘制 热图用以表示的是矩阵数据&#xff0c;例如相关阵、协差阵等方阵&#…

‘0’ 和 '\0'

48是0对应的ascii值。

KEIL/MDK编译优化optimization选项注意事项

KEIL编译器C语言编译选项优化等级说明 -Onum Specifies the level of optimization to be used when compiling source files. Syntax -Onum Where num is one of the following: 0 Minimum optimization. Turns off most optimizations. When debugging is enabled, this opt…

0,'\0','0'

#include <iostream> using namespace std; int main(void) { cout<<__FILE__<<\t<<__LINE__<<endl;cout<<"内 容:\t"<<"0"<<\t<<"\\\0\"<<\t<<"\0\"<<…

Odoo

狭路相逢 勇者胜 Odoo 是用于经营公司的最好的管理软件。 数百万用户使用我们的集成应用可以更好地开展工作 现在开始。免费的。 重新定义可扩展性 一个需求&#xff0c;一个应用程式。整合从来没有那么顺畅 促进销售量 客户关系管理POS销售 整合您的服务 项目工时表帮助…

0 、 '0' 、 0 、 ’\0’ 区别

转载自&#xff1a;https://blog.csdn.net/qnavy123/article/details/93901631 ① ‘0’ 代表 字符0 &#xff0c;对应ASCII码值为 0x30 (也就是十进制 48) ② ‘\0’ 代表 空字符(转义字符)【输出为空】 &#xff0c;对应ASCII码值为 0x00(也就是十进制 0)&#xff0c; …

Linux的内核编译用O0是编译不过的

最近在ATF的升级过程中遇到了一个编译问题&#xff0c;最后是通过编译优化解决的&#xff0c;然后一百度这个优化全是在Linux中的。于是就借着Linux编译优化来学学。 内容来自 宋宝华老师&#xff1a; 关于Linux编译优化几个必须掌握的姿势 1、编译选项和内核编译 首先我们都…

alert uuid does not exits. Dropping to a shell!

ALERT&#xff01;UUID does not exit. Dropping to a shell&#xff01; 服务器系统ubuntu16.04server&#xff0c;非自然断电后开机进入initramfs模式&#xff0c;服务器磁盘阵列是raid1和raid5。初步分析是硬盘坏道或掉盘&#xff0c;进入raid卡里看到硬盘一切正常&#xf…

跟着团子学SAP PS:如何查询PS模块中的user exits以及相关BAdIs SE80/SMOD/CNEX006/CNEX007/CNEX008

在PS很多标准字段或功能无法满足客户需求的时候往往需要通过SAP标准的user exits或者BAdI进行开发以满足业务需要&#xff0c;所以今天介绍下如何查询PS模块中的用户出口以及BAdIs&#xff1a; &#xff08;1&#xff09;查询PS模块中的user exits: 执行SE80&#xff0c;在菜…

EXT

ext的核心是store&#xff0c;存储数据用的。调试时可以先把store这块先屏蔽掉&#xff0c;先看页面的&#xff0c;页面出来了再调试store。这样会调试起来很快。 init: function () { var view this.getView(), // var store Global.getStore(app.store.L…