Tensorflow加载Vgg预训练模型

article/2025/8/25 0:21:12

    很多深度神经网络模型需要加载预训练过的Vgg参数,比如说:风格迁移、目标检测、图像标注等计算机视觉中常见的任务。那么到底如何加载Vgg模型呢?Vgg文件的参数到底有何意义呢?加载后的模型该如何使用呢?本文将以Vgg19为例子,详细说明Tensorflow如何加载Vgg预训练模型。

实验环境

  • GTX1050-ti,  cuda9.0
  • Window10,  Tensorflow 1.12

 

展示Vgg19构造

import tensorflow as tfimport numpy as np
import scipy.iodata_path = 'model/vgg19.mat'  # data_path指下载下来的Vgg19预训练模型的文件地址# 读取Vgg19文件
data = scipy.io.loadmat(data_path)
# 打印Vgg19的数据类型及其组成
print("type: ", type(data))
print("data.keys: ", data.keys())# 得到对应卷积核的矩阵
weights = data['layers'][0]
# 定义Vgg19的组成
layers = ('conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1','conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2','conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3','relu3_3', 'conv3_4', 'relu3_4', 'pool3','conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3','relu4_3', 'conv4_4', 'relu4_4', 'pool4','conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3','relu5_3', 'conv5_4', 'relu5_4'
)# 打印Vgg19不同卷积层所对应的维度
for i, name in enumerate(layers):kind = name[:4]if kind == 'conv':print("%s: %s" % (name, weights[i][0][0][2][0][0].shape))elif kind == 'relu':print(name)elif kind == 'pool':print(name)代码输出结果如下:
type:  <class 'dict'>
data.keys:  dict_keys(['__header__', '__version__', '__globals__', 'layers', 'meta'])conv1_1: (3, 3, 3, 64)
relu1_1
conv1_2: (3, 3, 64, 64)
relu1_2
pool1
conv2_1: (3, 3, 64, 128)
relu2_1
conv2_2: (3, 3, 128, 128)
relu2_2
pool2
conv3_1: (3, 3, 128, 256)
relu3_1
conv3_2: (3, 3, 256, 256)
relu3_2
conv3_3: (3, 3, 256, 256)
relu3_3
conv3_4: (3, 3, 256, 256)
relu3_4
pool3
conv4_1: (3, 3, 256, 512)
relu4_1
conv4_2: (3, 3, 512, 512)
relu4_2
conv4_3: (3, 3, 512, 512)
relu4_3
conv4_4: (3, 3, 512, 512)
relu4_4
pool4
conv5_1: (3, 3, 512, 512)
relu5_1
conv5_2: (3, 3, 512, 512)
relu5_2
conv5_3: (3, 3, 512, 512)
relu5_3
conv5_4: (3, 3, 512, 512)
relu5_4

    那么Vgg19真实的网络结构是怎么样子的呢,如下图所示:

Vgg19结构图
Vgg19 结构图

    在本文,主要讨论卷积模块,大家通过对比可以发现,我们打印出来的Vgg19结构及其卷积核的构造的确如论文中给出的Vgg19结构一致。

 

构建Vgg19模型

def _conv_layer(input, weights, bias):conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1),padding='SAME')return tf.nn.bias_add(conv, bias)def _pool_layer(input):return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1),padding='SAME')class VGG19:layers = ('conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1','conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2','conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3','relu3_3', 'conv3_4', 'relu3_4', 'pool3','conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3','relu4_3', 'conv4_4', 'relu4_4', 'pool4','conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3','relu5_3', 'conv5_4', 'relu5_4')def __init__(self, data_path):data = scipy.io.loadmat(data_path)self.weights = data['layers'][0]def feed_forward(self, input_image, scope=None):# 定义net用来保存模型每一步输出的特征图net = {}current = input_imagewith tf.variable_scope(scope):for i, name in enumerate(self.layers):kind = name[:4]if kind == 'conv':kernels = self.weights[i][0][0][2][0][0]bias = self.weights[i][0][0][2][0][1]kernels = np.transpose(kernels, (1, 0, 2, 3))bias = bias.reshape(-1)current = _conv_layer(current, kernels, bias)elif kind == 'relu':current = tf.nn.relu(current)elif kind == 'pool':current = _pool_layer(current)# 在每一步都保存当前输出的特征图net[name] = currentreturn net

    在上面的代码中,我们定义了一个Vgg19的类别专门用来加载Vgg19模型,并且将每一层卷积得到的特征图保存到net中,最后返回这个net,用于代码后续的处理。

 

测试Vgg19模型

    在给出Vgg19的构造模型后,我们下一步就是如何用它,我们的思路如下:

  • 加载本地图片
  • 定义Vgg19模型,传入本地图片
  • 得到返回每一层的特征图
image_path = "data/test.jpg" # 本地的测试图片image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)代码结果如下所示:
{'conv1_1': <tf.Tensor 'vgg19_1/BiasAdd:0' shape=(1, ?, ?, 64) dtype=float32>,'relu1_1': <tf.Tensor 'vgg19_1/Relu:0' shape=(1, ?, ?, 64) dtype=float32>,'conv1_2': <tf.Tensor 'vgg19_1/BiasAdd_1:0' shape=(1, ?, ?, 64) dtype=float32>,'relu1_2': <tf.Tensor 'vgg19_1/Relu_1:0' shape=(1, ?, ?, 64) dtype=float32>,'pool1': <tf.Tensor 'vgg19_1/MaxPool:0' shape=(1, ?, ?, 64) dtype=float32>,'conv2_1': <tf.Tensor 'vgg19_1/BiasAdd_2:0' shape=(1, ?, ?, 128) dtype=float32>,'relu2_1': <tf.Tensor 'vgg19_1/Relu_2:0' shape=(1, ?, ?, 128) dtype=float32>,'conv2_2': <tf.Tensor 'vgg19_1/BiasAdd_3:0' shape=(1, ?, ?, 128) dtype=float32>,'relu2_2': <tf.Tensor 'vgg19_1/Relu_3:0' shape=(1, ?, ?, 128) dtype=float32>,'pool2': <tf.Tensor 'vgg19_1/MaxPool_1:0' shape=(1, ?, ?, 128) dtype=float32>,'conv3_1': <tf.Tensor 'vgg19_1/BiasAdd_4:0' shape=(1, ?, ?, 256) dtype=float32>,'relu3_1': <tf.Tensor 'vgg19_1/Relu_4:0' shape=(1, ?, ?, 256) dtype=float32>,'conv3_2': <tf.Tensor 'vgg19_1/BiasAdd_5:0' shape=(1, ?, ?, 256) dtype=float32>,'relu3_2': <tf.Tensor 'vgg19_1/Relu_5:0' shape=(1, ?, ?, 256) dtype=float32>,'conv3_3': <tf.Tensor 'vgg19_1/BiasAdd_6:0' shape=(1, ?, ?, 256) dtype=float32>,'relu3_3': <tf.Tensor 'vgg19_1/Relu_6:0' shape=(1, ?, ?, 256) dtype=float32>,'conv3_4': <tf.Tensor 'vgg19_1/BiasAdd_7:0' shape=(1, ?, ?, 256) dtype=float32>,'relu3_4': <tf.Tensor 'vgg19_1/Relu_7:0' shape=(1, ?, ?, 256) dtype=float32>,'pool3': <tf.Tensor 'vgg19_1/MaxPool_2:0' shape=(1, ?, ?, 256) dtype=float32>,'conv4_1': <tf.Tensor 'vgg19_1/BiasAdd_8:0' shape=(1, ?, ?, 512) dtype=float32>,'relu4_1': <tf.Tensor 'vgg19_1/Relu_8:0' shape=(1, ?, ?, 512) dtype=float32>,'conv4_2': <tf.Tensor 'vgg19_1/BiasAdd_9:0' shape=(1, ?, ?, 512) dtype=float32>,'relu4_2': <tf.Tensor 'vgg19_1/Relu_9:0' shape=(1, ?, ?, 512) dtype=float32>,'conv4_3': <tf.Tensor 'vgg19_1/BiasAdd_10:0' shape=(1, ?, ?, 512) dtype=float32>,'relu4_3': <tf.Tensor 'vgg19_1/Relu_10:0' shape=(1, ?, ?, 512) dtype=float32>,'conv4_4': <tf.Tensor 'vgg19_1/BiasAdd_11:0' shape=(1, ?, ?, 512) dtype=float32>,'relu4_4': <tf.Tensor 'vgg19_1/Relu_11:0' shape=(1, ?, ?, 512) dtype=float32>,'pool4': <tf.Tensor 'vgg19_1/MaxPool_3:0' shape=(1, ?, ?, 512) dtype=float32>,'conv5_1': <tf.Tensor 'vgg19_1/BiasAdd_12:0' shape=(1, ?, ?, 512) dtype=float32>,'relu5_1': <tf.Tensor 'vgg19_1/Relu_12:0' shape=(1, ?, ?, 512) dtype=float32>,'conv5_2': <tf.Tensor 'vgg19_1/BiasAdd_13:0' shape=(1, ?, ?, 512) dtype=float32>,'relu5_2': <tf.Tensor 'vgg19_1/Relu_13:0' shape=(1, ?, ?, 512) dtype=float32>,'conv5_3': <tf.Tensor 'vgg19_1/BiasAdd_14:0' shape=(1, ?, ?, 512) dtype=float32>,'relu5_3': <tf.Tensor 'vgg19_1/Relu_14:0' shape=(1, ?, ?, 512) dtype=float32>,'conv5_4': <tf.Tensor 'vgg19_1/BiasAdd_15:0' shape=(1, ?, ?, 512) dtype=float32>,'relu5_4': <tf.Tensor 'vgg19_1/Relu_15:0' shape=(1, ?, ?, 512) dtype=float32>}

    本文提供的测试代码是完成正确的,已经避免了很多使用Vgg19预训练模型的坑操作,比如:给图片添加维度,转换读取图片的的格式等,为什么这么做的详细原因可参考我的另一篇博客:Tensorflow加载Vgg预训练模型的几个注意事项。

    到这里,如何使用tensorflow读取Vgg19模型结束了,若是大家有其他疑惑,可在评论区留言,会定时回答。

 

备注:本文为作者原创,转载需注明出处!


http://chatgpt.dhexx.cn/article/9QxDwb51.shtml

相关文章

混淆矩阵、准确率、F1和召回率的具体实现及混淆矩阵的可视化

utils专栏不会细讲概念性的内容&#xff0c;偏向实际使用&#xff0c;如有问题&#xff0c;欢迎留言。如果对你有帮助就点个赞哈&#xff0c;也不搞什么粉丝可见有的没的&#xff0c;有帮助点个赞就ok 1、混淆矩阵、准确率、F1和召回率的计算 混淆矩阵 对于混淆矩阵的计算…

预编码技术

预编码的基本原理 TD-LTE下行传输采用了MIMO-OFDM的物理层构架&#xff0c;通过最多4个发射天线并行传输多个&#xff08;最多4个&#xff09;数据流&#xff0c;能够有效地提高峰值传输速率。LTE的物理层处理过程中&#xff0c;预编码是其核心功能模块&#xff0c;物理下行共…

pytorch 计算混淆矩阵

混淆矩阵是评估模型结果的一种指标 用来判断分类模型的好坏 预测对了 为对角线 还可以通过矩阵的上下角发现哪些容易出错 从这个 矩阵出发 可以得到 acc &#xff01; precision recall 特异度&#xff1f; 目标检测01笔记AP mAP recall precision是什么 查全率是什么 查准率…

Code::Blocks 相关

文库上的使用教程 http://blog.csdn.net/JGood/article/details/5252119 使用手册 http://blog.csdn.net/liquanhai/article/details/6618300 一&#xff0e;Code::blocks Code::blocks集成开发环境是一个支持编译、链接、调试许多种语言的IDE&#xff0c;支持VS6.0到VS200…

mysql8 sql_mode去掉only_full_group_by

1.查询版本与sqlmode: select version(), sql_mode; 2.修改sqlmode,执行下面两句代码&#xff1a; set global sql_modeSTRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION;set session sql_modeSTRICT_TRANS_TABLES,NO_…

- MySQL - 存储过程 Stored Procedure

文章目录 存储过程是什么存储过程的优点存储过程的缺点存储过程分类存储过程的查询语句存储过程的创建和调用语法实例 - IN、OUT、INOUT条件语句循环语句迭代 存储过程是什么 存储过程是一组为了完成特定功能的SQL语句集&#xff0c;存储在数据库中&#xff0c;一次编译多次使…

mysql的delete语句_mysql删除语句

展开全部 mysql删除语句如下&#xff1a; 1、delete删除一行&#xff1a;delete from student where id1。 2、delete删除多行&#xff1a;delete from student where in (1,2,3)3。 3、删62616964757a686964616fe78988e69d8331333433623162除表的所有数据&#xff1a;delete f…

mysql any all some

假设现在有两张表 a &#xff0c; b 如下 SELECT * FROM a WHERE id > ANY(SELECT id FROM b )# any 等价于 some 这里有两个id sql的意思大概是查询a表的所有&#xff0c;在where中a的id > 子表 b的id 这里用到了any(some) 他们的意思是一样的&#xff0c;就是说a表…

mysql 的 sql_mode.only_full_group_by属性解析

文章目录 1. 初始条件2. 现象3. 解决①&#xff1a;关闭sql_mode 的 only_full_group_by模式②&#xff1a;使用 ANY_VALUE() 抑制 ONLY_FULL_GROUP_BY 的影响 mysql8.0官网&#xff1a;处理 group by 1. 初始条件 现在有这样一张表&#xff0c;其中有两条 name 裤子 的数据…

MYSQL 删除语句

删除数据(DELETE) 如果你失忆了&#xff0c;希望你能想起曾经为了追求梦想的你。 数据库存储数据&#xff0c;总会有一些垃圾数据&#xff0c;也会有一些不需要用的数据了&#xff0c;这些情况下&#xff0c;我们就可以删除这些数据&#xff0c;释放出一定的空间&#xff0c;给…

MySql递归RECURSIVE的详解

背景&#xff1a; 在实际开发的过程中&#xff0c;我们会遇到一些数据是层级关系的、要展示数据子父级关系的时候&#xff0c; 第一个解决方案&#xff1a;将数据库中的所有数据都查询出来用Java代码进行处理。 第二个解决方案&#xff1a;可以考虑MySql中的RECURSIVE递归进行…

【MYSQL WITH recursive使用】

MYSQL WITH recursive使用 MYSQL WITH recursive使用语法用法1&#xff1a;输出1~n或者求 1~n的和用法2 父求子创建table&#xff1a;user求张三后代 MYSQL WITH recursive使用 由于在项目中有使用到recursive&#xff0c;因此在此做记录 语法 WITH recursive 表名 AS ( 初始…

U 盘安装 CentOS7 时盘符和安装源不可用问题

记录一次在物理机上安装 CentOS7 遇到的问题及其解决办法&#xff0c;主要有两个问题卡住了很久&#xff0c;一个是盘符问题&#xff0c;一个就是安装源不可用的问题 。 1 No such device 1.1 问题描述 用 U 盘在物理机上安装 CentOS7 的时候&#xff0c;出现 could not ins…

Vmware Workstation17 安装centos7(详细教程)

1、为什么安装Vmware Vmware Workstation可以帮我们他们创建虚拟机&#xff0c;模拟生产环境(linux)&#xff0c;搭建集群等。作为一个开发人员特别是后端开发人员是需要懂一些运维的&#xff0c;不需要精通&#xff0c;遇到问题能定位排查。我自己想做一个自己的项目&#xf…

用大白菜装centos7_大白菜安装centos7 踩坑记

1.准备一个U盘,安装大白菜。这个去大白菜官网下载安装就可以了 安装大白菜的时候最好选择FAT32(2021.1.7记录) 2.U盘装完大白菜后U盘会被分为两个主分区 一个盘是大白菜系统的,另外一个盘放一些工具的。 DBC里面就是放的一些工具 比如磁盘管理工具 3.把Centos7的镜像放入到DB…

Windows10安装Centos7双系统

Windows10安装Centos7双系统 1.摘要2.制作Centos 7系统盘3.Windows磁盘管理为Centos系统留出空间4.bios配置使得计算机系统选择从U盘进入5.安装Centos 7系统6.如何在Windows和Centos系统间切换7.一些问题与补救方法8.参考文章 1.摘要 本篇博客主要整理记录了在Win10 OS下安装使…

VM16上安装CentOS7详细安装教程【附图】

在VM16上安装Centos7 下载Centos安装Centos快照拍摄 下载Centos 进入Centos官网 Centos官网的下载地址 点击Download 点击x86_64 自己选择镜像下载 上面选择好自己的镜像后选择后缀名为 iso 的CentOS下载 安装Centos 点击创建新的虚拟机 选择好自定义后点击下一步…

VMware 16安装centos 7详细教程

VMware 16安装centos 7详细教程 前言&#xff1a;之前在VMware15 pro上安装centos7&#xff0c;但是启动虚拟机出现蓝屏&#xff0c;身边有好多小伙伴也遇到了这个问题&#xff0c;经过一番排查&#xff0c;找到了最简单的办法就是升级到VMware16pro&#xff0c;在启动就没有出…

在Vmware虚拟机中安装CentOS 7

前言&#xff1a;材料和工具 1. 安装好的VMWare虚拟机软件&#xff1a; ​VMWare16下载地址&#xff0c;获取码&#xff1a;ye1a 2.CentOS7下载地址&#xff1a;官方镜像下载 (centos.org) &#xff08;官方镜像站下载比较慢&#xff09; 清华大学镜像站&#xff1a;清华…

物理服务器安装CentOS 7操作系统

目录 1、下载系统镜像 2、制作安装盘 2.1 方法一&#xff1a;光盘制作 2.2 方法二&#xff1a;U盘制作 3、更改bios启动顺序 4、安装CentOS 7操作系统 4.1 安装命令选择&#xff0c;及常见错误解决 4.2 语言选择 4.3 时区选择 4.4 软件选择 4.5 安装位置选择 4.6 手…