TensorFlow搭建VGG-Siamese网络

article/2025/9/30 22:04:36

TensorFlow搭建VGG-Siamese网络


  • Siamese原理

Siamese网络,中文称为孪生网络。大致结构如下图所示:在这里插入图片描述

Siamese网络有两个输入,一个输出。其中,两个输入经过相同的网络层知道成为一个n维向量,再对这个n维向量进行求距离,对此距离应用softmax函数,得到输出的结果。

例如,使用Siamese做一个人脸识别,那么输入就是两个人脸图像,若是同一个人输出1,若是不同的人则输出0。

首先,我们制作一个输入为(h, w, c),输出为(1, 128)的VGG模型,这里不使用完整的模型,我称为VGG-lite版。

import tensorflow as tf
from tensorflow import keras
from keras import backend as K
from tensorflow.keras import layers, Sequential
from tensorflow.keras.layers import Conv2D, ZeroPadding2D, Activation, MaxPooling2D, Dropout, Flatten, Dense, Lambda, Input
from tensorflow.keras.models import Model# 这里实现一个VGG网络,返回的是一个128维向量,用于siamese的输入
def VGG(X_input):X = X_inputX = Conv2D(64, (3,3), padding = 'same',activation='relu')(X)X = Conv2D(64, (3,3), padding = 'same',activation='relu')(X)X = MaxPooling2D(pool_size=(2,2), strides=2)(X)X = Conv2D(128, (3,3), padding = 'same',activation='relu')(X)X = Conv2D(128, (3,3), padding = 'same',activation='relu')(X)X = Dropout(0.4)(X)X = MaxPooling2D(pool_size=(2,2), strides=2)(X)X = Conv2D(256, (3,3), padding = 'same',activation='relu')(X)# X = MaxPooling2D(pool_size=(2,2), strides=2)(X)X = Conv2D(256, (3,3), padding = 'same',activation='relu')(X)X = Conv2D(256, (3,3), padding = 'same',activation='relu')(X)X = Dropout(0.4)(X)X = MaxPooling2D(pool_size=(2,2), strides=2)(X)X = Conv2D(512, (3,3), padding = 'same',activation='relu')(X)X = MaxPooling2D(pool_size=(2,2), strides=2)(X)X = Conv2D(512, (3,3), padding = 'same',activation='relu')(X)X = Conv2D(512, (3,3), padding = 'same',activation='relu')(X)X = Dropout(0.4)(X)X = MaxPooling2D(pool_size=(2,2), strides=2)(X)X = Flatten()(X)X = Dense(1024, activation="relu")(X)X = Dense(128, activation="relu")(X)X = Lambda(lambda  x: K.l2_normalize(x,axis=1))(X)return X

这里对模型不再详细解释,只解释下对X的最后一步操作:通过keras.layers中的Lambda将128维的X进行L2正则化再输出。

对于模型构建其他部分的疑问,可以参考我的前两份文章。

接下来,我们要制作一个可以接受两个输入的模型。在TensorFlow中,只需要在定义模型的函数中,使用多次Input()即可获得多个输入。

def VGG_Siamese(input_shape):# 接收两个输入,X1_input和X2_input.X1_input = Input(input_shape)X2_input = Input(input_shape)X1 = ZeroPadding2D((3, 3), name='layer1')(X1_input)X2 = ZeroPadding2D((3, 3), name='layer2')(X2_input)X1 = VGG(X1)X2 = VGG(X2)print(X1)print(X2)l1_distance_layer = Lambda(lambda tensors: K.abs(tensors[0] - tensors[1]))l1_distance = l1_distance_layer([X1, X2])    X = Dense(512, activation='relu')(l1_distance)X = Dense(2, activation='softmax')(X)model = Model(inputs = [X1_input, X2_input], outputs = X)return model

在使用Input()获得两个输入后,将两个输入一同经过了VGG()函数,这说明两个输入会经历相同的卷积网络成为两个128维向量。而

	l1_distance_layer = Lambda(lambda tensors: K.abs(tensors[0] - tensors[1]))l1_distance = l1_distance_layer([X1, X2]) 

这两句是将得到的两个128维向量进行距离求和,使用差值绝对值求得,得到的结果也是一个128维向量。

再之后,将得到的128维向量经过全连接层与512维、2维(即classes维)连接,得到一个二维向量,这个二维向量使用"softmax"激活函数,得到预测结果。


通过上面的两个函数,我们已经完成了模型的构建,接下来,我们从处理数据集开始,讲解如何对此模型进行训练。

笔者选用的数据集是LFW数据集,各位可以自行选择数据集,下面介绍一种简单的数据集处理方法(LFW数据集有pairs.txt文件,处理方式与下面介绍的不一致,这并不影响,因为得到的数据集形式是相同的):

  • 因为不同数据集可能有不同的初步获取方式,因此这里假设我们获得了dataset_x(图像)、dataset_y(标签).

对于数据处理的思想是:首先取数据集中的任一图片,然后再随机取另一张图片(不要与第一张图片相同),将第一张图片加入X_L(这是一个list),将第二张图片加入X_R,如果两张图片是同一个人,将1加入labels(这是标签集),如果两张图片不是同一个人,将0加入labels。具体操作如下:

X_L = []
X_R = []
labels = []
for i in range(dataset_x.shape[0]):for j in range(4):  # 每个数据与四个其他数据对比a = random.randint(0,dataset_x.shape[0]-1)while a == i:a = random.randint(0,dataset_x.shape[0]-1)X_L.append(dataset_x[i])X_R.append(dataset_x[a])if dataset_y[i] == dataset_y[a]:labels.append(1)else:labels.append(0)

这样,我们得到了一个具有两个图片并且已经标志其是否为同一人的数据集。但是我们对于数据集的处理还没有完成,如果使用以上的数据集去进行训练,会有多个错误产生。

  • TensorFlow的模型训练应接收带有shape方法的数据集,而我们上面的数据集是list类型,不具有shape方法,要使其得到此方法,可按如下处理:
import numpy as np
X_L = np.array(X_L)
X_R = np.array(X_R)
labels = np.array(labels)

numpy.array()方法将list转化为array,具有shape方法。到这里,数据处理仍没有结束。还记得我们模型最后的输出吗?应该是(?, 2)维的向量,而我们的labels是(?, 1)维向量,这是怎么回事?
这里我们的labels向量使用0和1代表两种结果,因此对于每对图片都只有一个标签。要处理这个问题,有两种解决方案。

  • 第一种解决方案是,将模型最后的输出激活函数换为’sigmoid’并改为1维。这样便与标签集维数相同。
  • 第二种解决方案是,将标签转为2维,并且要与softmax输出匹配,即转化为独热编码。(0->(1,0), 1->(0,1)).

这里我们采用第二种解决方案

labels = to_categorical(labels, num_classes=2)

现在,我们可以获取我们的模型了:

model = VGG_Siamese(input_shape=x_train[0].shape)

设置模型参数:

model.compile(optimizer='adam', loss="categorical_crossentropy", metrics=['accuracy'])
# 如果刚才采用第一种解决方案,将loss改为'binary_crossentropy'

参数设置完毕后,可以开始训练模型了:

model.fit([X_L, X_R], labels, validation_split=0.2, batch_size=32, epochs=30, verbose=1)

这里只为了演示如何构建Siamese模型,因此选用的模型较简单,训练效果并不优秀,但是便于理解Siamese的工作原理和创建方式,为了优化训练效果,可以自己动手尝试更换模型进行训练。

  • 训练完成后,可以将模型保存:
save_path = "./weights/my_weight" # 填文件地址和名称
model.save_weights(save_path) # 保存权重
model.save(save_path+'h5') # 保存模型和权重

http://chatgpt.dhexx.cn/article/aaxLAKB7.shtml

相关文章

mesa 概述

技术关键词:mesa、OpenGL、dri、gpu、kmd、xsever 目录 一、mesa概述 二、mesa架构 1. 架构设计 2. 模块划分 三、mesa与linux图形系统中的其他模块的关系 四、mesa的编译 五、链接资源 总结 一、mesa概述 mesa是OpenGL、OpenGL ES、Vulkan、OpenCL的一个开…

Siamese 网络(Siamese network)

来源:Coursera吴恩达深度学习课程 上个文章One-Shot学习/一次学习(One-shot learning)中函数d的作用就是输入两张人脸图片,然后输出相似度。实现这个功能的一个方式就是用Siamese网络。 上图是常见的卷积网络,输入图片…

MISF:Multi-level Interactive Siamese Filtering for High-Fidelity Image Inpainting 论文解读与感想

深度学习模型被广泛应用于各种视觉任务的同时,似乎传统的图像处理方式已经被人们渐渐遗忘,然而很多时候传统图像处理方式的稳定性和可解释性依然是深度学习模型所不能达到的。本文是CVPR2022的一篇将传统与深度相结合进行inpainting的文章。 在图像inpa…

Siamese系列文章

说明 在学习目标追踪方面,慢慢读懂论文,记录论文的笔记,同时贴上一些别人写的非常优秀的帖子。 文章目录 说明综述类型笔记SiamFC笔记 SiamRPN笔记 DaSiamRPN笔记 SiamRPN笔记复现 SiamDW笔记 SiamFC笔记 UpdateNet笔记 SiamBAN笔记 SiamMa…

SiamRPN阅读笔记:High Performance Visual Tracking with Siamese Region Proposal Network

这是来自商汤的一篇文章 发表在CVPR2018上 论文地址 目录: 文章目录 摘要1.引言2.相关工作2.2 RPN2.3 One-shot learning 3.Siamese-RPN framework3.1 孪生特征提取子网络3.2 候选区域提取子网络3.3 训练阶段:端到端训练孪生RPN 4. Tracking as one-sho…

【度量学习】Siamese Network

基于2-channel network的图片相似度判别 一、相关理论 本篇博文主要讲解2015年CVPR的一篇关于图像相似度计算的文章:《Learning to Compare Image Patches via Convolutional Neural Networks》,本篇文章对经典的算法Siamese Networks 做了改进。学习这…

【论文阅读】Learning to Rank Proposals for Siamese Visual Tracking

Learning to Rank Proposals for Siamese Visual Tracking:2021 TIP 引入 There are two main challenges for visual tracking: 首先,待跟踪目标具有类不可知性和任意性,关于目标的先验信息很少。 其次,仅仅向跟踪器…

深度学习笔记-----多输入网络 (Siamese网络,Triplet网络)

目录 1,什么时候需要多个输入 2,常见的多输入网络 2.1 Siamese网络(孪生网络) 2.1 Triplet网络 1,什么时候需要多个输入 深度学习网络一般是输入都是一个,或者是一段视频切片,因为大部分的内容是对一张图像或者一段…

Siamese networks

Siamese Network 是一种神经网络的架构,而不是具体的某种网络,就像Seq2Seq一样,具体实现上可以使用RNN也可以使用CNN。Siamese Network 就像“连体的神经网络”,神经网络的“连体”是通过共享权值来实现的(共享权值即左…

Siamese Network理解(附代码)

author:DivinerShi 文章地址:http://blog.csdn.net/sxf1061926959/article/details/54836696 提起siamese network一般都会引用这两篇文章: 《Learning a similarity metric discriminatively, with application to face verification》和《 Hamming D…

详解Siamese网络

摘要 Siamese网络用途,原理,如何训练? 背景 在人脸识别中,存在所谓的one-shot问题。举例来说,就是对公司员工进行人脸识别,每个员工只给你一张照片(训练集样本少),并且…

Siamese网络(孪生网络)

1. Why Siamese 在人脸识别中,存在所谓的one-shot问题。举例来说,就是对公司员工进行人脸识别,每个员工只有一张照片(因为每个类别训练样本少),并且员工会离职、入职(每次变动都要重新训练模型…

Siamese网络(孪生神经网络)详解

SiameseFC Siamese网络(孪生神经网络)本文参考文章:Siamese背景 Siamese网络解决的问题要解决什么问题?用了什么方法解决?应用的场景: Siamese的创新Siamese的理论Siamese的损失函数——Contrastive Loss损…

8.HttpEntity,ResponseEntity

RequestBody请求体,获取一个请求的请求体内容就不用RequestParam RequestMapping("/testRequestBody")public String testRequestBody(RequestBody String body){System.out.println("请求体: "body);return "success";}只有表单才有…

使用restTemplate进行feign调用new HttpEntity<>报错解决方案

使用restTemplate进行feign调用new HttpEntity<>报错解决方案 问题背景HttpEntity<>标红解决方案心得Lyric&#xff1a; 沙漠之中怎么会有泥鳅 问题背景 今天才知道restTemplate可以直接调用feign&#xff0c;高级用法呀&#xff0c;但使用restTemplate进行feign调…

HttpClient 源码详解之HttpEntity

HttpClient 源码详解 之HttpEntity 1. 类释义 An entity that can be sent or received with an HTTP message. Entities can be found in some requests and in responses, where they are optional. There are three distinct types of entities in HttpCore, depending on …

System.Net.Http.HttpClient

本文主要是介绍如何用HttpClient请求带参数的服务&#xff0c;请求服务为某翻译API 直接上源码 1.添加using System.Net;的引用 using System.Net; 2.使用HttpClient发送请求 public static async void Fanyin_HttpClient(string fromString) {Console.WriteLine($"F…

ResponseEntity类和HttpEntity及跨平台路径问题

1. 简介 使用spring时&#xff0c;达到同一目的通常有很多方法&#xff0c;对处理http响应也是一样。本文我们学习如何通过ResponseEntity设置http相应内容、状态以及头信息。 ResponseEntity是HttpEntity的扩展&#xff0c;添加一个HttpStatus状态代码。在RestTemplate和Con…

RestTemplate发送HTTP、HTTPS请求

前面我们介绍了如何使用Apache的HttpClient发送HTTP请求,这里我们介绍Spring的Rest客户端(即:RestTemplate) 如何发送HTTP、HTTPS请求。注:HttpClient如何发送HTTPS请求,有机会的话也会再给出示例。 声明:本人一些内容摘录自其他朋友的博客&#xff0c;链接在本文末给出&#…

HttpEntity的用法

关于HttpEntity的用法 HttpEntity表示http的request和resposne实体&#xff0c;它由消息头和消息体组成。 从HttpEntity中可以获取http请求头和回应头&#xff0c;也可以获取http请求体和回应体信息。HttpEntity的使用&#xff0c;与RequestBody 、ResponseBody类似。 HttpEnti…