Siamese Network(孪生网络)

article/2025/9/30 21:05:22

 

模型结构

在这里插入图片描述

上图是孪生网络的简单模型结构,思路很简单,就是输入两个样本到同样的网络(参数结构相同),最后计算两个网络输出的距离,如果距离较近就认为是同一类,较远就认为是不同的类别,在这里,我们可以使用两个同样参数的CNN,利用CNN从图像中提取特征。注意这里必须是同样的CNN,不然两个不同的CNN,即使输入相同,输出也可能认为两者不同。

损失函数

  • Constrastive loss

  • Triplet loss

  • Softmax loss

  • 其他损失函数:比如cosine loss,欧式距离等。

模型伪代码

模型代码

class SiameseNetwork(nn.Module):def __init__(self):super(SiameseNetwork, self).__init__()self.cnn1 = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(1, 4, kernel_size=3),nn.ReLU(inplace=True),nn.BatchNorm2d(4),nn.Dropout2d(p=.2),nn.ReflectionPad2d(1),nn.Conv2d(4, 8, kernel_size=3),nn.ReLU(inplace=True),nn.BatchNorm2d(8),nn.Dropout2d(p=.2),nn.ReflectionPad2d(1),nn.Conv2d(8, 8, kernel_size=3),nn.ReLU(inplace=True),nn.BatchNorm2d(8),nn.Dropout2d(p=.2),)self.fc1 = nn.Sequential(nn.Linear(8*100*100, 500),nn.ReLU(inplace=True),nn.Linear(500, 500),nn.ReLU(inplace=True),nn.Linear(500, 5))def forward_once(self, x):output = self.cnn1(x)output = output.view(output.size()[0], -1)output = self.fc1(output)return outputdef forward(self, input1, input2):output1 = self.forward_once(input1)output2 = self.forward_once(input2)return output1, output2
​

损失函数

class ContrastiveLoss(torch.nn.Module):"""Contrastive loss function.Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf"""def __init__(self, margin=2.0):super(ContrastiveLoss, self).__init__()self.margin = margindef forward(self, output1, output2, label):euclidean_distance = F.pairwise_distance(output1, output2)loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2)  (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))return loss_contrastive
​

真实案例

基于ORL人脸数据集,利用孪生网络来进行人脸验证。

数据介绍

ORL人脸数据集共包含40个不同人的400张图像,此数据集下包含40个目录,每个目录下有10张图像,每个目录表示一个不同的人。所有的图像是以PGM格式存储,灰度图,图像大小宽度92,高度为112。对每一个目录下,这些图像实在不同的时间、不同的光照、不同的面部表情和面部细节环境下采集的。

可以从http://www.cl.cam.ac.uk/research/dtg/attarchive/facedatabase.html下载此人脸数据集。

程序包导入

import torch
import torchvision
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import torchvision.utils
import numpy as np
import random
from PIL import Image
import PIL.ImageOps 
print(torch.__version__)  #1.1.0
print(torchvision.__version__)  #0.3.0
​
​
#定义一些超参
train_batch_size = 32        #训练时batch_size
train_number_epochs = 50     #训练的epoch
​
def imshow(img,text=None,should_save=False): #展示一幅tensor图像,输入是(C,H,W)npimg = img.numpy() #将tensor转为ndarrayplt.axis("off")if text:plt.text(75, 8, text, style='italic',fontweight='bold',bbox={'facecolor':'white', 'alpha':0.8, 'pad':10})plt.imshow(np.transpose(npimg, (1, 2, 0))) #转换为(H,W,C)plt.show()    
​
def show_plot(iteration,loss):#绘制损失变化图plt.plot(iteration,loss)plt.show()

自定义Dataset和Dataloader

自定义的Dataset需要实现 __ getitem __ 和 __ len __ 函数。每次读取一对图像,标签表示差异度,0表示同一个人,1表示不是同一人。

#自定义Dataset类,__getitem__(self,index)每次返回(img1, img2, 0/1)
class SiameseNetworkDataset(Dataset):def __init__(self,imageFolderDataset,transform=None,should_invert=True):self.imageFolderDataset = imageFolderDataset    self.transform = transformself.should_invert = should_invertdef __getitem__(self,index):img0_tuple = random.choice(self.imageFolderDataset.imgs) #37个类别中任选一个should_get_same_class = random.randint(0,1) #保证同类样本约占一半if should_get_same_class:while True:#直到找到同一类别img1_tuple = random.choice(self.imageFolderDataset.imgs) if img0_tuple[1]==img1_tuple[1]:breakelse:while True:#直到找到非同一类别img1_tuple = random.choice(self.imageFolderDataset.imgs) if img0_tuple[1] !=img1_tuple[1]:break
​img0 = Image.open(img0_tuple[0])img1 = Image.open(img1_tuple[0])img0 = img0.convert("L")img1 = img1.convert("L")if self.should_invert:img0 = PIL.ImageOps.invert(img0)img1 = PIL.ImageOps.invert(img1)
​if self.transform is not None:img0 = self.transform(img0)img1 = self.transform(img1)return img0, img1, torch.from_numpy(np.array([int(img1_tuple[1]!=img0_tuple[1])],dtype=np.float32))def __len__(self):return len(self.imageFolderDataset.imgs)#定义文件dataset
training_dir = "./data/faces/training/"  #训练集地址
folder_dataset = torchvision.datasets.ImageFolder(root=training_dir)
​
#定义图像dataset
transform = transforms.Compose([transforms.Resize((100,100)), #有坑,传入int和tuple有区别transforms.ToTensor()])
siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset,transform=transform,should_invert=False)
​
#定义图像dataloader
train_dataloader = DataLoader(siamese_dataset,shuffle=True,batch_size=train_batch_size)

可视化数据集

vis_dataloader = DataLoader(siamese_dataset,shuffle=True,batch_size=8)
example_batch = next(iter(vis_dataloader)) #生成一批图像
#其中example_batch[0] 维度为torch.Size([8, 1, 100, 100])
concatenated = torch.cat((example_batch[0],example_batch[1]),0) 
imshow(torchvision.utils.make_grid(concatenated, nrow=8))
print(example_batch[2].numpy())

img

注意torchvision.utils.make_grid用法:将若干幅图像拼成一幅图像。内部机制是铺成网格状的tensor,其中输入tensor必须是四维(B,C,H,W)。后续还需要调用numpy()和transpose(),再用plt显示。

# https://pytorch.org/docs/stable/_modules/torchvision/utils.html#make_grid
torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)
​
#示例
t = torchvision.utils.make_grid(concatenated, nrow=8)
concatenated.size()  #torch.Size([16, 1, 100, 100])
t.size() #torch.Size([3, 206, 818]) 对于(batch,1,H,W)的tensor,重复三个channel,详见官网文档源码

准备模型

自定义模型和损失函数

#搭建模型
class SiameseNetwork(nn.Module):def __init__(self):super().__init__()self.cnn1 = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(1, 4, kernel_size=3),nn.ReLU(inplace=True),nn.BatchNorm2d(4),nn.ReflectionPad2d(1),nn.Conv2d(4, 8, kernel_size=3),nn.ReLU(inplace=True),nn.BatchNorm2d(8),
​nn.ReflectionPad2d(1),nn.Conv2d(8, 8, kernel_size=3),nn.ReLU(inplace=True),nn.BatchNorm2d(8),)
​self.fc1 = nn.Sequential(nn.Linear(8*100*100, 500),nn.ReLU(inplace=True),
​nn.Linear(500, 500),nn.ReLU(inplace=True),
​nn.Linear(500, 5))
​def forward_once(self, x):output = self.cnn1(x)output = output.view(output.size()[0], -1)output = self.fc1(output)return output
​def forward(self, input1, input2):output1 = self.forward_once(input1)output2 = self.forward_once(input2)return output1, output2#自定义ContrastiveLoss
class ContrastiveLoss(torch.nn.Module):"""Contrastive loss function.Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf"""
​def __init__(self, margin=2.0):super(ContrastiveLoss, self).__init__()self.margin = margin
​def forward(self, output1, output2, label):euclidean_distance = F.pairwise_distance(output1, output2, keepdim = True)loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +(label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
​return loss_contrastive

训练

net = SiameseNetwork().cuda() #定义模型且移至GPU
criterion = ContrastiveLoss() #定义损失函数
optimizer = optim.Adam(net.parameters(), lr = 0.0005) #定义优化器
​
counter = []
loss_history = [] 
iteration_number = 0
​
​
#开始训练
for epoch in range(0, train_number_epochs):for i, data in enumerate(train_dataloader, 0):img0, img1 , label = data#img0维度为torch.Size([32, 1, 100, 100]),32是batch,label为torch.Size([32, 1])img0, img1 , label = img0.cuda(), img1.cuda(), label.cuda() #数据移至GPUoptimizer.zero_grad()output1,output2 = net(img0, img1)loss_contrastive = criterion(output1, output2, label)loss_contrastive.backward()optimizer.step()if i % 10 == 0 :iteration_number +=10counter.append(iteration_number)loss_history.append(loss_contrastive.item())print("Epoch number: {} , Current loss: {:.4f}\n".format(epoch,loss_contrastive.item()))show_plot(counter, loss_history)

img

测试

现在用testing文件夹中3个任务的图像进行测试,注意:模型从未见过这3个人的图像。

#定义测试的dataset和dataloader
​
#定义文件dataset
testing_dir = "./data/faces/testing/"  #测试集地址
folder_dataset_test = torchvision.datasets.ImageFolder(root=testing_dir)
​
#定义图像dataset
transform_test = transforms.Compose([transforms.Resize((100,100)), transforms.ToTensor()])
siamese_dataset_test = SiameseNetworkDataset(imageFolderDataset=folder_dataset_test,transform=transform_test,should_invert=False)
​
#定义图像dataloader
test_dataloader = DataLoader(siamese_dataset_test,shuffle=True,batch_size=1)
​
​
#生成对比图像
dataiter = iter(test_dataloader)
x0,_,_ = next(dataiter)
​
for i in range(10):_,x1,label2 = next(dataiter)concatenated = torch.cat((x0,x1),0)output1,output2 = net(x0.cuda(),x1.cuda())euclidean_distance = F.pairwise_distance(output1, output2)imshow(torchvision.utils.make_grid(concatenated),'Dissimilarity: {:.2f}'.format(euclidean_distance.item()))

img

参考

  • https://github.com/harveyslash/Facial-Similarity-with-Siamese-Networks-in-Pytorch

  • ORL Faces Database介绍_网络资源是无限的-CSDN博客_orl数据集

  • https://github.com/adambielski/siamese-triplet


http://chatgpt.dhexx.cn/article/dM8J3Gym.shtml

相关文章

51、目标的相似度检测模型Siamese部署rk3399pro、ncnn、mnn进行推理加速

基本思想:需要一个判断目标相似度的模型,来比对被检测目标和既定目标的相似度,测试图片仅有的几张图片,感觉一般,量化图片尽量多点对于rknn 链接: https://pan.baidu.com/s/1NFjnCBh5RqJXDxEjl9TzHg?pwdxev4 提取码:…

SPCL:Siamese Prototypical Contrastive Learning

论文链接:https://arxiv.org/abs/2208.08819 BMVC 2021 abstract CSL(Contrastive Self-supervised Learning)的一个缺点是,对比损失函数需要大量的负样本,以提供更好的理想互信息边界。 通过变大batch size来增加负样本数理,同…

Exploring Simple Siamese Representation Learning论文笔记

写在前面 大三狗随手记录,不喜勿喷。 主要思想 Siamese network常常被用来计算图像的两个增强之间的相似性,但可能会造成模型坍塌(即输出恒定)。作者在本文提出了一个非常简单的Simple Siamese network,简称Simsiam…

SiamCAR:Siamese Fully Convolutional Classification and Regression for Visual Tracking

文章目录 AbstractIntroductionProposed MethodFeature ExtractionBounding Box PredictionThe Tracking Phase 值得关注的几个问题Q1:输入的图片大小不一?Q2:在两者做相关性之前,如何得到特征图?Q3:两者的相关性计算是如何实现的&#xff1f…

TensorFlow搭建VGG-Siamese网络

TensorFlow搭建VGG-Siamese网络 Siamese原理 Siamese网络,中文称为孪生网络。大致结构如下图所示: Siamese网络有两个输入,一个输出。其中,两个输入经过相同的网络层知道成为一个n维向量,再对这个n维向量进行求距离&…

mesa 概述

技术关键词:mesa、OpenGL、dri、gpu、kmd、xsever 目录 一、mesa概述 二、mesa架构 1. 架构设计 2. 模块划分 三、mesa与linux图形系统中的其他模块的关系 四、mesa的编译 五、链接资源 总结 一、mesa概述 mesa是OpenGL、OpenGL ES、Vulkan、OpenCL的一个开…

Siamese 网络(Siamese network)

来源:Coursera吴恩达深度学习课程 上个文章One-Shot学习/一次学习(One-shot learning)中函数d的作用就是输入两张人脸图片,然后输出相似度。实现这个功能的一个方式就是用Siamese网络。 上图是常见的卷积网络,输入图片…

MISF:Multi-level Interactive Siamese Filtering for High-Fidelity Image Inpainting 论文解读与感想

深度学习模型被广泛应用于各种视觉任务的同时,似乎传统的图像处理方式已经被人们渐渐遗忘,然而很多时候传统图像处理方式的稳定性和可解释性依然是深度学习模型所不能达到的。本文是CVPR2022的一篇将传统与深度相结合进行inpainting的文章。 在图像inpa…

Siamese系列文章

说明 在学习目标追踪方面,慢慢读懂论文,记录论文的笔记,同时贴上一些别人写的非常优秀的帖子。 文章目录 说明综述类型笔记SiamFC笔记 SiamRPN笔记 DaSiamRPN笔记 SiamRPN笔记复现 SiamDW笔记 SiamFC笔记 UpdateNet笔记 SiamBAN笔记 SiamMa…

SiamRPN阅读笔记:High Performance Visual Tracking with Siamese Region Proposal Network

这是来自商汤的一篇文章 发表在CVPR2018上 论文地址 目录: 文章目录 摘要1.引言2.相关工作2.2 RPN2.3 One-shot learning 3.Siamese-RPN framework3.1 孪生特征提取子网络3.2 候选区域提取子网络3.3 训练阶段:端到端训练孪生RPN 4. Tracking as one-sho…

【度量学习】Siamese Network

基于2-channel network的图片相似度判别 一、相关理论 本篇博文主要讲解2015年CVPR的一篇关于图像相似度计算的文章:《Learning to Compare Image Patches via Convolutional Neural Networks》,本篇文章对经典的算法Siamese Networks 做了改进。学习这…

【论文阅读】Learning to Rank Proposals for Siamese Visual Tracking

Learning to Rank Proposals for Siamese Visual Tracking:2021 TIP 引入 There are two main challenges for visual tracking: 首先,待跟踪目标具有类不可知性和任意性,关于目标的先验信息很少。 其次,仅仅向跟踪器…

深度学习笔记-----多输入网络 (Siamese网络,Triplet网络)

目录 1,什么时候需要多个输入 2,常见的多输入网络 2.1 Siamese网络(孪生网络) 2.1 Triplet网络 1,什么时候需要多个输入 深度学习网络一般是输入都是一个,或者是一段视频切片,因为大部分的内容是对一张图像或者一段…

Siamese networks

Siamese Network 是一种神经网络的架构,而不是具体的某种网络,就像Seq2Seq一样,具体实现上可以使用RNN也可以使用CNN。Siamese Network 就像“连体的神经网络”,神经网络的“连体”是通过共享权值来实现的(共享权值即左…

Siamese Network理解(附代码)

author:DivinerShi 文章地址:http://blog.csdn.net/sxf1061926959/article/details/54836696 提起siamese network一般都会引用这两篇文章: 《Learning a similarity metric discriminatively, with application to face verification》和《 Hamming D…

详解Siamese网络

摘要 Siamese网络用途,原理,如何训练? 背景 在人脸识别中,存在所谓的one-shot问题。举例来说,就是对公司员工进行人脸识别,每个员工只给你一张照片(训练集样本少),并且…

Siamese网络(孪生网络)

1. Why Siamese 在人脸识别中,存在所谓的one-shot问题。举例来说,就是对公司员工进行人脸识别,每个员工只有一张照片(因为每个类别训练样本少),并且员工会离职、入职(每次变动都要重新训练模型…

Siamese网络(孪生神经网络)详解

SiameseFC Siamese网络(孪生神经网络)本文参考文章:Siamese背景 Siamese网络解决的问题要解决什么问题?用了什么方法解决?应用的场景: Siamese的创新Siamese的理论Siamese的损失函数——Contrastive Loss损…

8.HttpEntity,ResponseEntity

RequestBody请求体,获取一个请求的请求体内容就不用RequestParam RequestMapping("/testRequestBody")public String testRequestBody(RequestBody String body){System.out.println("请求体: "body);return "success";}只有表单才有…

使用restTemplate进行feign调用new HttpEntity<>报错解决方案

使用restTemplate进行feign调用new HttpEntity<>报错解决方案 问题背景HttpEntity<>标红解决方案心得Lyric&#xff1a; 沙漠之中怎么会有泥鳅 问题背景 今天才知道restTemplate可以直接调用feign&#xff0c;高级用法呀&#xff0c;但使用restTemplate进行feign调…