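OS: Windows 11
PyTorch version: 1.11.0
torchvision version: 0.12.0
Test image: person1.jpg from the pytorch/vision GitHub repository ("vision/person1.jpg at main · pytorch/vision · GitHub"); the code below reads a local copy from the gallery/assets folder of the vision 0.12.0 source tree.
Code: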
import torch
import torchvision.transforms
from torchvision.models.detection import keypointrcnn_resnet50_fpn
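from torchvision.io import read_image

# Original gallery version (requires `from pathlib import Path`):
# person_int = read_image(str(Path("assets") / "person1.jpg"))
person_int = read_image(r"E:\git_rep\vision-0.12.0\gallery\assets\person1.jpg")  # uint8 tensor, shape [3, H, W]

# Convert the uint8 image into a float tensor in [0, 1], which the detection model expects
transforms1 = torchvision.transforms.ToPILImage()
transforms2 = torchvision.transforms.ToTensor()
person_float = transforms1(person_int)
person_float = transforms2(person_float)

# Keypoint R-CNN with COCO-pretrained weights (torchvision 0.12 API)
model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False)
model = model.eval()

# The model takes a list of images and returns one dict per image; no gradients are needed for inference
with torch.no_grad():
    outputs = model([person_float])
print(outputs)

kpts = outputs[0]['keypoints']  # shape [N, 17, 3]: (x, y, visibility) for each keypoint
scores = outputs[0]['scores']   # shape [N]: confidence score of each detected person
print(kpts)
print(scores)

# Keep only the detections whose score exceeds the threshold
detect_threshold = 0.75
idx = torch.where(scores > detect_threshold)
keypoints = kpts[idx]
print(keypoints)

import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F

plt.rcParams["savefig.bbox"] = 'tight'


def show(imgs):
    # Display one or more [C, H, W] image tensors side by side
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()  # required when running as a plain script; the docs gallery renders figures automatically


from torchvision.utils import draw_keypoints

# draw_keypoints draws on the original uint8 image, not the float copy fed to the model
res = draw_keypoints(person_int, keypoints, colors="blue", radius=3)
show(res)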
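In eval mode the model returns one dict per input image; its keypoints entry has shape [N, 17, 3] (one (x, y, visibility) row per keypoint, in COCO person-keypoint order) and its scores entry holds one confidence value per detected person. The thresholding step in the code can be sketched without torch as below. This is a minimal sketch under stated assumptions: plain Python lists stand in for outputs[0], and the helper label_confident_keypoints and the sample numbers are hypothetical, used only to show the filtering and indexing logic.

```python
# COCO person-keypoint order used by torchvision's Keypoint R-CNN (17 keypoints)
COCO_KEYPOINT_NAMES = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle",
]


def label_confident_keypoints(keypoints, scores, threshold=0.75):
    """Hypothetical helper: keep persons with score > threshold and name their keypoints.

    keypoints: list of N persons, each a list of 17 [x, y, v] triples
    scores:    list of N detection scores
    Returns a list of dicts mapping keypoint name -> (x, y).
    """
    people = []
    for person_kpts, score in zip(keypoints, scores):
        if score <= threshold:  # same rule as torch.where(scores > detect_threshold)
            continue
        people.append({
            name: (x, y)  # drop the visibility flag v
            for name, (x, y, _v) in zip(COCO_KEYPOINT_NAMES, person_kpts)
        })
    return people


# Made-up sample standing in for outputs[0]: two detections, one below the threshold
fake_keypoints = [
    [[100.0 + i, 50.0 + 2 * i, 1.0] for i in range(17)],
    [[300.0 + i, 80.0 + i, 1.0] for i in range(17)],
]
fake_scores = [0.98, 0.40]

people = label_confident_keypoints(fake_keypoints, fake_scores)
print(len(people))               # 1
print(people[0]["nose"])         # (100.0, 50.0)
print(people[0]["right_ankle"])  # (116.0, 82.0)
```

With real model output, the same idea applies row by row to kpts.tolist() and scores.tolist().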
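Result (the figure produced by show(res), the input image with the detected keypoints drawn in blue, was not preserved):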
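The code passes no connectivity argument, so draw_keypoints draws the keypoints as isolated dots. draw_keypoints also accepts connectivity, a list of (start, end) keypoint-index pairs that it draws as line segments; the torchvision visualization gallery uses this to draw a skeleton. A minimal sketch, assuming the COCO keypoint order below; writing the edges by name and converting them with the hypothetical helper edges_to_indices is an illustrative choice, not part of the torchvision API.

```python
# COCO person-keypoint order (index = position in this list)
COCO_KEYPOINT_NAMES = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle",
]

# Skeleton edges written with names for readability (illustrative set of limbs)
SKELETON_BY_NAME = [
    ("nose", "left_eye"), ("nose", "right_eye"),
    ("left_eye", "left_ear"), ("right_eye", "right_ear"),
    ("nose", "left_shoulder"), ("nose", "right_shoulder"),
    ("left_shoulder", "left_elbow"), ("right_shoulder", "right_elbow"),
    ("left_elbow", "left_wrist"), ("right_elbow", "right_wrist"),
    ("left_shoulder", "left_hip"), ("right_shoulder", "right_hip"),
    ("left_hip", "left_knee"), ("right_hip", "right_knee"),
    ("left_knee", "left_ankle"), ("right_knee", "right_ankle"),
]


def edges_to_indices(edges, names):
    """Hypothetical helper: turn (name, name) pairs into (index, index) pairs."""
    index = {name: i for i, name in enumerate(names)}
    return [(index[a], index[b]) for a, b in edges]


connect_skeleton = edges_to_indices(SKELETON_BY_NAME, COCO_KEYPOINT_NAMES)
print(connect_skeleton[:4])  # [(0, 1), (0, 2), (1, 3), (2, 4)]
```

With torchvision available, the list would then be passed as draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=3, width=3) and displayed with show(...). Naming the edges keeps a typo in an index from silently connecting the wrong joints.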
References:
1. keypointrcnn_resnet50_fpn, Torchvision 0.13 documentation
2. Visualization utilities, Torchvision 0.13 documentation
3. Models and pre-trained weights, Torchvision 0.13 documentation
4. Human keypoint detection (Keypoints Detection)